import itertools
from typing import Any
from copy import deepcopy

from sklearn.utils import check_random_state
import numpy as np
import torch
from torch import nn

from braindecode.augmentation.base import Output, Compose

from eeg_augment.diff_aug.relax import RelaxGumbelSoftmax, ControlVariate
from eeg_augment.diff_aug.diff_transforms import POSSIBLE_DIFF_TRANSFORMS
from eeg_augment.diff_aug.diff_transforms import convert_diff_transform


class SampledTransform(torch.autograd.Function):
    """Apply the one transform selected by a one-hot ``weights`` vector.

    The forward pass runs the first transform whose weight equals 1. The
    backward pass re-runs that transform with gradients enabled so that a
    gradient can be returned for ``weights`` and so that gradients get
    accumulated on whatever tensors requiring grad the transform uses.
    """

    @staticmethod
    def forward(ctx, weights, transforms, X, y):
        """Run the selected transform on ``(X, y)``.

        Parameters
        ----------
        ctx : context object
            Autograd context, used to stash what ``backward`` needs.
        weights : torch.Tensor
            1D tensor iterated in lockstep with ``transforms``; the first
            entry equal to 1 designates the transform to apply.
        transforms : sequence of callables
            Each is called as ``op(X, y)``.
        X : torch.Tensor
            Input passed as first argument to the selected transform.
        y : torch.Tensor
            Input passed as second argument to the selected transform.

        Returns
        -------
        The value returned by the selected transform. ``backward`` indexes
        it with ``[0]`` and receives two incoming gradients, so it is
        expected to be a pair of tensors.

        Raises
        ------
        ValueError
            If no entry of ``weights`` paired with a transform equals 1.
        """
        for idx, (b, op) in enumerate(zip(weights, transforms)):
            if b == 1:
                res = op(X, y)
                # Detached copies are enough: backward recomputes the
                # transform from scratch.
                ctx.save_for_backward(
                    X.detach(), y.detach(), weights.detach()
                )
                ctx.op = op
                ctx.idx = idx
                return res
        # Falling through used to return None silently, which only failed
        # later when the caller tried to unpack the result.
        raise ValueError(
            "SampledTransform expects one-hot weights, but no weight equal "
            "to 1 was found."
        )

    @staticmethod
    def backward(ctx, grad_output, grad_y):
        """Return the gradient for ``weights`` (None for the other inputs).

        The selected transform is applied again with autograd enabled; the
        scalar ``s = sum(op(X, y)[0] * grad_output)`` is backpropagated,
        which accumulates gradients on the tensors requiring grad that the
        transform uses. The value of ``s`` is then returned as the gradient
        of the selected entry of ``weights``; all other entries get 0.
        ``grad_y`` is not used.
        """
        X, y, weights = ctx.saved_tensors
        with torch.enable_grad():
            s = (ctx.op(X, y)[0] * grad_output).sum()
            # Side effect on purpose: populates .grad of the tensors the
            # transform depends on.
            s.backward()
        grad_weights = torch.zeros_like(weights)
        # Only the value of s is needed here, not its graph.
        grad_weights[ctx.idx] = s.detach()
        return grad_weights, None, None, None


class DADAPolicy(nn.Module):
    """DADA differentiable augmentation policy, proposed in [1]_

    The policy holds one candidate subpolicy per possible sequence of
    ``subpolicy_len`` transform classes taken from
    ``POSSIBLE_DIFF_TRANSFORMS`` (cartesian product), together with one
    learnable weight per candidate. Each forward pass samples a single
    candidate with ``RelaxGumbelSoftmax`` and applies it.

    Parameters
    ----------
    n_subpolicies : int
        Stored as ``self.n_subpolicies``. Note that the number of candidate
        subpolicies built here is determined by ``subpolicy_len`` and
        ``POSSIBLE_DIFF_TRANSFORMS``, not by this value.
    subpolicy_len : int, optional
        Number of consecutive DiffTransforms making up a subpolicy. Defaults to
        2.
    temperature : float, optional
        Used for setting how steep to do the softmax operator in the subpolicy
        stages. Defaults to 0.05.
    ch_names : list, optional
        Ordered list of channels in the dataset. Defaults to None. Forwarded
        as ``ordered_ch_names`` to the transform classes that list this key
        in ``POSSIBLE_DIFF_TRANSFORMS``.
    sfreq : float, optional
        Sampling rate in the dataset. Defaults to None. Forwarded as
        ``sfreq`` to the transform classes that list this key in
        ``POSSIBLE_DIFF_TRANSFORMS``.
    batchwise : bool, optional
        Whether to sample subpolicies per batch (True) or per sample (False).
        Defaults to True. Stored as ``self.batchwise``; it is not read by
        any other code in this class.
    random_state : int | numpy.random.RandomState, optional
        Used for setting the transforms and sampling the subpolicies. Defaults
        to None.
    device : str | torch.device, optional
        Device on which the temperature, the subpolicy weights and the
        control variate network are placed. Defaults to "cpu".
    **kwargs
        Ignored. A warning is printed when any are passed.

    References
    ----------
    .. [1] Li et al., (2020). DADA: Differentiable Automatic Data Augmentation.
    In proceedings of ECCV 2020.
    """
    def __init__(
        self,
        n_subpolicies: int,
        subpolicy_len: int = 2,
        temperature: float = 0.05,
        ch_names: list = None,
        sfreq: float = None,
        batchwise: bool = True,
        random_state: Any = None,
        device: "str | torch.device" = "cpu",
        **kwargs,
    ):
        super().__init__()
        # Only warn when something was actually dropped
        if kwargs:
            print(
                f"WARNING! Unused kwargs passed to DADAPolicy init: {kwargs}"
            )
        self.rng = check_random_state(random_state)
        self.batchwise = batchwise
        self.n_subpolicies = n_subpolicies
        self.device = device
        self.temperature = torch.as_tensor(
            temperature,
            device=self.device,
        )

        # So that init does not change self.rng state
        # Important to ensure independence between forward passes and init
        init_rng = deepcopy(self.rng)

        # In order to compute all possible subpolicies, we first fetch possible
        # differentiable transform classes
        all_tf_classes = list(POSSIBLE_DIFF_TRANSFORMS.keys())

        # And compute the cartesian product of possible sequences of those
        tf_classes_n_times = [all_tf_classes] * subpolicy_len
        all_tf_cls_combinations = itertools.product(*tf_classes_n_times)

        # Before finally looping across each possible sequence of transform,
        # initializing each obj differently with the init_rng and composing
        # them into a subpolicy with Compose
        possible_additional_params = {
            "sfreq": sfreq,
            "ordered_ch_names": ch_names,
        }

        # NOTE(review): this is a plain Python list, so if Compose is an
        # nn.Module its parameters are not registered as children of this
        # module (not returned by self.parameters(), not moved by .to()).
        # Confirm this is intended before changing it.
        self.diff_subpolicies = [
            Compose([
                cls(
                    random_state=self.rng,
                    initial_probability=init_rng.uniform(),
                    initial_magnitude=init_rng.uniform(),
                    # Only pass the extra params this class declares
                    **{
                        key: value
                        for key, value in possible_additional_params.items()
                        if key in POSSIBLE_DIFF_TRANSFORMS[cls]
                    }
                ) for cls in cls_sequence
            ])
            for cls_sequence in all_tf_cls_combinations
        ]

        # Init subpolicy weights, drawn uniformly between -bound and bound
        n = len(self.diff_subpolicies)
        bound = np.sqrt(6) / np.sqrt(n + 1)
        self._weights = nn.Parameter(
            torch.as_tensor(init_rng.rand(n), dtype=torch.float)
            * 2 * bound - bound
        )
        # Move the underlying data while keeping the same Parameter object
        self._weights.data = self._weights.to(self.device)

        # We need to instantiate a NN-based control variate when using
        # RELAX estimator
        self.variate_net = ControlVariate(
            self._weights.shape[0]
        ).to(self.device)

    def forward(self, X, y) -> Output:
        """Sample one subpolicy and apply it to ``(X, y)``.

        Parameters
        ----------
        X : torch.Tensor
            Batch of inputs handed to the sampled subpolicy.
        y : torch.Tensor
            Corresponding targets handed to the sampled subpolicy.

        Returns
        -------
        Output
            Result of ``SampledTransform.apply`` for the sampled subpolicy.
        """
        # Sample a subpolicy and apply it (and use RELAX to estimate the
        # gradient)
        weights = self.sample_weights()

        # In order to have gradients backprop through weights, a special
        # sampling is needed
        return SampledTransform.apply(weights, self.diff_subpolicies, X, y)

    def sample_weights(self):
        """Sample a single operation using the RELAX gradient estimator

        The gradients of the control variate network are zeroed before
        sampling.

        Returns
        -------
        torch.Tensor
            Onehot encoded vector corresponding to the sampled transform.
        """
        self.variate_net.zero_grad()
        # TODO: Temperature could be learnt too
        return RelaxGumbelSoftmax.apply(
            self._weights,
            self.temperature,
            self.variate_net,
        )

    def perturbation_on(self, eps=None):
        """Not implemented"""
        pass

    def perturbation_off(self):
        """Not implemented"""
        pass

    @property
    def weights(self):
        """Softmax of the raw subpolicy weights divided by the temperature."""
        return self._weights.div(self.temperature).softmax(dim=0)

    @property
    def all_probabilities(self):
        """Flat array of ``probability`` values of all transforms, in
        subpolicy order."""
        return np.hstack([
                    [tf.probability.item() for tf in subpolicy.transforms]
                    for subpolicy in self.diff_subpolicies
                ])

    @property
    def all_magnitudes(self):
        """Flat array of ``magnitude`` values of all transforms, skipping
        those whose magnitude is None."""
        return np.hstack([
                    [
                        tf.magnitude.item() for tf in subpolicy.transforms
                        if tf.magnitude is not None
                    ] for subpolicy in self.diff_subpolicies
                ])

    @property
    def all_weights(self):
        """Alias of ``weights``."""
        return self.weights

    @property
    def all_prob_grads(self):
        """Flat array of the gradients of ``_probability`` of all
        transforms, skipping those without a gradient."""
        return np.hstack([
                    [
                        tf._probability.grad.item()
                        for tf in subpolicy.transforms
                        if tf._probability.grad is not None
                    ] for subpolicy in self.diff_subpolicies
                ])

    @property
    def all_mag_grads(self):
        """Flat array of the gradients of ``_magnitude`` of all transforms,
        skipping those without magnitude or without a gradient."""
        return np.hstack([
                    [
                        tf._magnitude.grad.item()
                        for tf in subpolicy.transforms
                        if tf.magnitude is not None and
                        tf._magnitude.grad is not None
                    ] for subpolicy in self.diff_subpolicies
                ])

    @property
    def all_weight_grads(self):
        """Gradient currently stored on the raw subpolicy weights (may be
        None)."""
        return self._weights.grad


def convert_dada_subpolicy(dada_subpolicy, random_state):
    """Turn a DADA subpolicy into a regular ``Compose`` of transforms.

    For each differentiable transform in ``dada_subpolicy.transforms``, a
    one-line summary (position, class name, probability, magnitude) is
    printed and the transform is converted with ``convert_diff_transform``.

    Parameters
    ----------
    dada_subpolicy : object
        Subpolicy exposing a ``transforms`` iterable of differentiable
        transforms.
    random_state : object
        Passed as second argument to ``convert_diff_transform`` for every
        transform.

    Returns
    -------
    Compose
        Composition of the converted transforms, in the original order.
    """
    def _report_and_convert(position, diff_tf):
        # Summarize the transform, then convert it
        raw_magnitude = diff_tf._magnitude
        magnitude = None if raw_magnitude is None else raw_magnitude.item()
        print(
            f"{position}: {type(diff_tf).__name__}",
            f"p={diff_tf._probability.item()}",
            f"m={magnitude}",
        )
        return convert_diff_transform(diff_tf, random_state)

    return Compose([
        _report_and_convert(position, diff_tf)
        for position, diff_tf in enumerate(dada_subpolicy.transforms)
    ])
