
from collections import defaultdict
import copy

from .utils import merge
from .exceptions import (err_exit, DXError, DXCLIError)
from .compat import basestring

'''
System Requirements
+++++++++++++++++++

A module containing utility methods useful for packing and unpacking
system requirements.
'''

class SystemRequirementsDict(object):
    """
    A class representing system requirements that can be passed as
    "systemRequirements" or an entry of the "systemRequirementsByExecutable"
    to the class-xxxx/run or job/new API call (after converting
    it to a dictionary with as_dict()).
    """

    def __init__(self, entrypoints):
        """
        Example of the entrypoints input:
        {"main":
            {"instanceType": "mem2_hdd2_x2"},
         "other_function":
            {"instanceType": "mem2_hdd2_x1",
             "clusterSpec": {"type": "spark",
                             "version": "2.4.0",
                             "initialInstanceCount": 2}}}
        """
        if entrypoints is not None and not isinstance(entrypoints, dict):
            raise DXError("Expected entrypoints to be a dict or None")
        self.entrypoints = copy.deepcopy(entrypoints)

    @classmethod
    def from_instance_count(cls, instance_count_arg, entrypoint="*"):
        """
        Returns a SystemRequirementsDict that can be passed as a
        "systemRequirements" input or an entry of the "systemRequirementsByExecutable" mapping
        to job/new or run/ API calls.
        The instance_count_arg should be either a:
        * string or int eg. "6" or 8
        * dictionary, eg. {"main": 4, "other_function": 2}
        """
        try:
            if instance_count_arg is None:
                return cls(None)
            if isinstance(instance_count_arg, basestring) or isinstance(instance_count_arg, int):
                return cls({entrypoint: {"clusterSpec": {"initialInstanceCount": int(instance_count_arg)}}})
            if isinstance(instance_count_arg, dict):
                return cls({k: {"clusterSpec": {"initialInstanceCount": int(v)}} for k, v in instance_count_arg.items()})
            raise ValueError
        except ValueError:
            DXError('Expected instance_count field to be either an int, string or a dict')

    @classmethod
    def from_instance_type(cls, instance_type_arg, entrypoint="*"):
        """
        Returns SystemRequirementsDict that can be passed as a
        "systemRequirements" input or an entry of the "systemRequirementsByExecutable" mapping
        to job/new or run/ API calls.
        The instance_type_arg should be either a:
        * string, eg. mem1_ssd1_x2
        * dictionary, eg. {"main": "mem2_hdd2_x2", "other_function": "mem2_hdd2_x1"}
        """
        if instance_type_arg is None:
            return cls(None)
        if isinstance(instance_type_arg, basestring):
            # By default, all entry points ("*") should use this instance type
            return cls({entrypoint: {"instanceType": instance_type_arg}})
        if isinstance(instance_type_arg, dict):
            # instance_type is a map of entry point to instance type
            return cls({fn: {"instanceType": fn_inst} for fn, fn_inst in instance_type_arg.items()})
        raise DXError('Expected instance_type field to be either a string or a dict')

    @classmethod
    def from_sys_requirements(cls, system_requirements, _type='all'):
        """
        Returns SystemRequirementsDict encapsulating system requirements.
        It can extract only entrypoints with specific fields ('clusterSpec',
        'instanceType', 'instanceTypeSelector', etc), depending on the value of _type.
        """
        allowed_types = ['all', 'clusterSpec', 'instanceType', 'instanceTypeSelector', 'fpgaDriver', 'nvidiaDriver']
        if _type not in (allowed_types):
            raise DXError("Expected '_type' to be one of the following: {}".format(allowed_types))

        if _type == 'all':
            return cls(system_requirements)

        extracted = defaultdict(dict)
        for entrypoint, req in system_requirements.items():
            if _type in req:
                extracted[entrypoint][_type] = req[_type]
        return cls(dict(extracted))

    def override_cluster_spec(self, srd):
        """
        Returns SystemRequirementsDict can be passed in a "systemRequirements"
        input or as an entry of the "systemRequirementsByExecutable" mapping
        to app-xxx/run, e.g. {'fn': {'clusterSpec': {initialInstanceCount: 3, version: "2.4.0", ..}}}
        Since full clusterSpec must be passed to the API server, we need to retrieve the cluster
        spec defined in app doc's systemRequirements and overwrite the field initialInstanceCount
        with the value the user passed to dx run for each entrypoint.
        initialInstanceCount is currently the only clusterSpec's field the user is allowed to change
        at runtime.
        A few scenarios when requesting instance count for different entrypoints with dx run 
        and the resulting merged systemRequirements (merged_cluster_spec). The bootstapScript
        field here is only one of many (version, ports, etc) that should be copied from app
        spec to merged_cluster_spec:

        Requested: {"*": 5}
        App doc: {"main": "clusterSpec": {"initialInstanceCount": 7, bootstrapScript: "x.sh"},
                "other": "clusterSpec": {"initialInstanceCount": 9, bootstrapScript: "y.sh"}}
        Merged: {"main": "clusterSpec": {"initialInstanceCount": 5, bootstrapScript: "x.sh"},
                "other": "clusterSpec": {"initialInstanceCount": 5, bootstrapScript: "y.sh"}}
        
        Requested: {"*": 15}
        App doc: {"main": "clusterSpec": {"initialInstanceCount": 7, bootstrapScript: "x.sh"},
                  "other": "clusterSpec": {"initialInstanceCount": 9, bootstrapScript: "y.sh"},
                  "*": "clusterSpec": {"initialInstanceCount": 11, bootstrapScript: "y.sh"}}
        Merged: {"main": "clusterSpec": {"initialInstanceCount": 15, bootstrapScript: "x.sh"},
                 "other": "clusterSpec": {"initialInstanceCount": 15, bootstrapScript: "y.sh"},
                 "*": "clusterSpec": {"initialInstanceCount": 15, bootstrapScript: "y.sh"}}

        Requested: {"main": 12}
        App doc: {"main": "clusterSpec": {"initialInstanceCount": 7, bootstrapScript: "x.sh"},
                  "other": "clusterSpec": {"initialInstanceCount": 9, bootstrapScript: "y.sh"}}
        Merged: {"main": "clusterSpec": {"initialInstanceCount": 12, bootstrapScript: "x.sh"}}

        Requested: {"main": 33}
        App doc: {"*": "clusterSpec": {"initialInstanceCount": 2, bootstrapScript: "z.sh"}}
        Merged: {"main": "clusterSpec": {"initialInstanceCount": 33, bootstrapScript: "z.sh"}}

        Requested: {"main": 22, "*": 11}
        App doc: {"*": "clusterSpec": {"initialInstanceCount": 2, bootstrapScript: "t.sh"}}
        Merged: {"main": "clusterSpec": {"initialInstanceCount": 22, bootstrapScript: "t.sh"},
                 "*": "clusterSpec": {"initialInstanceCount": 11, bootstrapScript: "t.sh"}}
        """

        merged_cluster_spec = copy.deepcopy(self.entrypoints)

        # Remove entrypoints without "clusterSpec"
        merged_cluster_spec = dict([(k, v) for k, v in merged_cluster_spec.items() if v.get("clusterSpec") is not None])

        # Remove entrypoints not provided in requested instance counts
        merged_cluster_spec = dict([(k, v) for k, v in merged_cluster_spec.items() if \
            k in srd.entrypoints or "*" in srd.entrypoints])

        # Overwrite values of self.entrypoints.clusterSpec with the ones from srd
        # Named entrypoint takes precedence over the wildcard
        for entry_pt, req in merged_cluster_spec.items():
            merged_cluster_spec[entry_pt]["clusterSpec"].update(
                srd.entrypoints.get(entry_pt, srd.entrypoints.get("*"))["clusterSpec"])

        # Check if all entrypoints in srd are included in merged_cluster_spec
        # (if a named entrypoint was used in srd and such an entrypoint doesn't exist
        #  in app sys req, we need to take the cluster spec from the app's "*", if it exists)
        for entry_pt, req in srd.entrypoints.items():
            if entry_pt not in merged_cluster_spec and "*" in self.entrypoints and "clusterSpec" in self.entrypoints["*"]:
                merged_cluster_spec[entry_pt] = {"clusterSpec": copy.deepcopy(self.entrypoints["*"]["clusterSpec"])}
                merged_cluster_spec[entry_pt]["clusterSpec"].update(req["clusterSpec"])

        return SystemRequirementsDict(merged_cluster_spec)

    def override_spec(self, requested_spec):
        if "*" in requested_spec.as_dict():
            return requested_spec
        
        entrypoints = self.entrypoints.keys() | requested_spec.entrypoints.keys()
        merged_spec = dict.fromkeys(entrypoints)
        for e in entrypoints:
            merged_spec[e] = requested_spec.entrypoints.get(e) or requested_spec.entrypoints.get("*") or self.entrypoints.get(e) or self.entrypoints.get("*")
        return SystemRequirementsDict(merged_spec)

    def _add_dict_values(self, d1, d2):
        """
        Merges the values of two dictionaries, which are expected to be dictionaries, e.g
        d1 = {'a': {'x': pqr}}
        d2 = {'a': {'y': lmn}, 'b': {'y': rst}}
        will return: {'a': {'x': pqr, 'y': lmn}, 'b': {'y': rst}}.
        Collisions of the keys of the sub-dictionaries are not checked.
        """

        if d1 is None and d2 is None:
            return None

        d1 = d1 or {}
        d2 = d2 or {}

        added = {}
        for key in set(list(d1.keys()) + list(d2.keys())):
            added[key] = dict(d1.get(key, {}), **(d2.get(key, {})))
        return added

    def __add__(self, other):
        if not isinstance(other, SystemRequirementsDict):
            raise DXError("Developer error: SystemRequirementsDict expected")
        added_entrypoints = self._add_dict_values(self.entrypoints, other.entrypoints)
        return SystemRequirementsDict(added_entrypoints)

    def as_dict(self):
        return self.entrypoints
