import torch
from contextlib import contextmanager
from collections import namedtuple


class GeneralizedDARTS(torch.nn.Module):
    """Implement a generalized architecture search space.

    Subclasses define a concrete space by implementing ``_initialize_layers``,
    ``_forward_implementation`` and ``_initialize_alphas`` (and optionally
    ``project_onto_constraint`` / ``genotype``). The architecture weights
    ("alphas") are kept as a plain list of tensors, exposed through
    ``arch_parameters()``, rather than being registered as module parameters.
    """

    def __init__(self, in_channels=3, num_classes=10, softmax_normalization=False, setup=None, use_jit=True,
                 temperature=1, norm=None, *args, **kwargs):
        """Initialize with channels/classes of dataset.

        Args:
            in_channels: Number of input channels of the dataset.
            num_classes: Number of target classes of the dataset.
            softmax_normalization: If True, alphas are raw logits and are parametrized
                as ``softmax(alpha * temperature)`` on every call. If False, alphas are
                used directly as probabilities (converted once via a softmax at init).
            setup: Optional mapping with ``'device'`` and ``'dtype'`` entries used to
                place the alphas. If None, alphas are kept as ``_initialize_alphas``
                returns them.
            use_jit: Stored on the instance; not read by this base class.
            temperature: Scalar, or list with one entry per alpha tensor, multiplying
                the alphas before each softmax.
            norm: Normalize operator outputs before weighted sum? Options: 'DARTS',
                'renorm', None. Stored only; presumably consumed by subclasses.
            *args, **kwargs: Forwarded to ``_initialize_layers``.
        """
        super().__init__()
        self.criterion = torch.nn.CrossEntropyLoss()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.setup = setup
        self.use_jit = use_jit

        self.norm = norm  # Normalize operator outputs before weighted sum? Options: 'DARTS', 'renorm', None
        self.softmax_normalization = softmax_normalization
        # Scales alphas before every softmax in return_probabilities (softmax_normalization=True)
        # and before the one-off initial softmax in reset_alphas (softmax_normalization=False).
        self.temperature = temperature

        self._initialize_layers(*args, **kwargs)
        self.reset_alphas()

    # ---- Implement these methods for a new search space: ----

    def _initialize_layers(self, *args, **kwargs):
        """Initialize the model layers here."""
        raise NotImplementedError()
        # return None

    def _forward_implementation(self, input, weights, **kwargs):
        """Implement the forward pass here, assume weights are probabilities."""
        raise NotImplementedError()
        # return outputs

    @torch.no_grad()
    def project_onto_constraint(self):
        """Project onto ||theta_k|| < alpha_k, keeping alpha fixed."""
        def _project(theta, alpha):
            theta.clamp_(-alpha, alpha)

        raise NotImplementedError()
        # return None

    def _initialize_alphas(self):
        """Initialize alphas here; return a list of tensors."""
        raise NotImplementedError()
        # return list_of_alphas

    def genotype(self):
        """Return the discrete architecture currently encoded by the alphas."""
        raise NotImplementedError()
        # return namedtuple('Genotype', 'operations')(...)

    # ---- General methods ----

    def forward(self, input, **kwargs):
        """Run the forward pass weighted by the current alpha probabilities."""
        return self._forward_implementation(input, self.return_probabilities(), **kwargs)

    def forward_argmax(self, input, **kwargs):
        """Evaluate the currently dominating combination.

        Each alpha tensor is replaced by a one-hot tensor of the same shape that
        marks the argmax along its last dimension.
        """
        # Scatter along the same (last) dimension the argmax was taken over, so this
        # works for alphas of any rank, not only 2-D ones.
        binary_weights = [torch.zeros_like(a).scatter_(-1, a.argmax(dim=-1, keepdim=True), 1)
                          for a in self.arch_parameters()]
        return self._forward_implementation(input, binary_weights, **kwargs)

    def _loss(self, input, target):
        """Return the criterion (cross entropy) of this model's logits against target."""
        logits = self(input)
        return self.criterion(logits, target)

    def _scale_by_temperature(self, alphas):
        """Multiply each alpha by its temperature (a shared scalar or a per-alpha list).

        Raises:
            ValueError: if a list of temperatures does not match the number of alphas.
        """
        if isinstance(self.temperature, list):
            if len(self.temperature) != len(alphas):
                raise ValueError(f'Expected {len(alphas)} temperatures (one per alpha), '
                                 f'got {len(self.temperature)}.')
            return [a * t for a, t in zip(alphas, self.temperature)]
        return [a * self.temperature for a in alphas]

    def return_probabilities(self):
        """Return the alphas as probabilities.

        With softmax_normalization the stored alphas are logits and are mapped through
        ``softmax(alpha * temperature)`` along the last dimension; otherwise the stored
        alphas are returned as they are (in a new list).
        """
        if self.softmax_normalization:
            return [scaled.softmax(dim=-1) for scaled in self._scale_by_temperature(self.arch_parameters())]
        return list(self.arch_parameters())

    @contextmanager
    def no_sync(self):
        """Overwrite this in distributed mode."""
        yield

    @torch.no_grad()
    def reset_alphas(self):
        """(Re-)create the architecture weights from ``_initialize_alphas``.

        If a setup is given, alphas are moved to ``setup['device']`` / ``setup['dtype']``.
        Without softmax_normalization they are converted once into probabilities via
        ``softmax(alpha * temperature)``; otherwise they are kept as raw logits. The
        resulting tensors are marked as requiring gradients.
        """
        alphas = self._initialize_alphas()
        if self.setup is not None:
            alphas = [a.to(device=self.setup['device'], dtype=self.setup['dtype']) for a in alphas]
        if not self.softmax_normalization:
            alphas = [scaled.softmax(dim=-1) for scaled in self._scale_by_temperature(alphas)]
        self._arch_parameters = [a.requires_grad_() for a in alphas]

    def reset_alpha_grad(self):
        """Detach the alphas from any autograd history and re-enable their gradients."""
        self._arch_parameters = [a.detach().requires_grad_() for a in self._arch_parameters]

    def arch_parameters(self):
        """Return the list of alpha tensors."""
        return self._arch_parameters

    def mean_entropy(self):
        """Return the mean entropy over all alpha probability tensors (NaN if there are none)."""
        probabilities = self.return_probabilities()
        if not probabilities:
            return float('nan')
        entropy = sum(self.entropy(alpha) for alpha in probabilities)
        return entropy.item() / len(self.arch_parameters())

    def normalized_entropy(self):
        """Return the mean entropy over all alphas, each divided by log(numel(alpha)).

        Returns NaN if there are no alphas.
        """
        probabilities = self.return_probabilities()
        if not probabilities:
            return float('nan')
        entropy = 0
        for alpha in probabilities:
            numel = torch.numel(alpha)
            alpha_entropy = self.entropy(alpha)
            # log(1) == 0: a single-entry alpha cannot be normalized; its entropy
            # (0 for a valid distribution, Inf otherwise) is used unnormalized.
            if numel > 1:
                alpha_entropy = alpha_entropy / torch.log(torch.as_tensor(numel, dtype=torch.float))
            entropy += alpha_entropy
        return entropy.item() / len(self.arch_parameters())

    @staticmethod
    def entropy(y):
        """Return the simplex entropy of an input vector y.

        Returns Inf if y is not a valid probability vector: it has a negative entry
        or its total sum exceeds 1 (+1e-4 tolerance). Note the sum check is over
        all entries of y, not per row.
        """
        if y.sum() > (1 + 1e-4) or (y < 0).any():
            return torch.tensor(float("Inf"), device=y.device, dtype=y.dtype)
        positive = y > 0
        # Take the log of 1 where y == 0 so that neither the forward value nor the
        # backward pass through torch.where produces NaN/Inf (0 * log 0 := 0).
        safe_log = torch.log(torch.where(positive, y, torch.ones_like(y)))
        return torch.where(positive, -y * safe_log, torch.zeros_like(y)).sum()
