import numpy as np
import torch.nn as nn
import torch
from contextlib import contextmanager
from HIB.utils import get_tqdm, ifnone



class TorchWelfordEstimator(nn.Module):
    """Running per-feature mean and standard deviation over streamed batches.

    The module is called with an ``(att, value)`` pair. ``value`` is flattened
    to ``(-1, value.size(-1))`` and every resulting row is treated as one
    sample. The accumulators are created lazily on the first call because the
    feature size and device are only known then.

    Buffers:
        _n_samples: number of rows seen so far (long tensor of shape ``[1]``).
        m: running mean per feature (created on first call).
        s: running sum of squared deviations from the mean (created on first call).
    """

    def __init__(self):
        super().__init__()
        self.device = None  # Defined on first forward pass
        self.shape = None  # Defined on first forward pass
        self.register_buffer('_n_samples', torch.tensor([0], dtype=torch.long))

    def _init(self, shape, device):
        """Create the ``m`` and ``s`` accumulators and move the module to ``device``."""
        self.device = device
        self.shape = shape
        self.register_buffer('m', torch.zeros(*shape))
        self.register_buffer('s', torch.zeros(*shape))
        self.to(device)

    def forward(self, inputs):
        """Fold all rows of ``value`` into the running statistics.

        Args:
            inputs: a 2-item sequence ``(att, value)``; only ``value`` is used.

        Returns:
            ``value`` reshaped to ``(-1, value.size(-1))``.
        """
        _, value = inputs

        x = value.reshape(-1, value.size(-1))

        if self.shape is None:
            # Initialize running mean and std on first datapoint
            self._init(x.shape[1:], x.device)

        n_new = x.shape[0]
        if n_new == 0:
            return x

        # Statistics must never hold on to an autograd graph, and are computed
        # in the promoted dtype so low-precision inputs do not degrade them.
        dtype = torch.promote_types(x.dtype, self.m.dtype)
        batch = x.detach().to(dtype)

        # Merge the whole batch in one step instead of looping row by row
        # (pairwise combination of mean / sum of squared deviations).
        batch_mean = batch.mean(dim=0)
        batch_m2 = (batch - batch_mean).pow(2).sum(dim=0)

        n_old = self._n_samples.to(dtype)
        n_total = n_old + n_new
        delta = batch_mean - self.m

        self.m = self.m + delta * (n_new / n_total)
        self.s = self.s + batch_m2 + delta.pow(2) * (n_old * n_new / n_total)
        self._n_samples += n_new
        return x

    def n_samples(self):
        """ Returns the number of seen samples. """
        return int(self._n_samples.item())

    def mean(self):
        """ Returns the estimate of the mean. """
        return self.m

    def std(self):
        """Returns the estimate of the standard deviation (``n - 1`` denominator).

        The result is not meaningful while fewer than two samples have been seen.
        """
        return torch.sqrt(self.s / (self._n_samples.float() - 1))


def insert_into_sequential(sequential, layer, idx):
    """
    Build a new ``nn.Sequential`` from the children of ``sequential`` with
    ``layer`` placed at position ``idx``.

    The passed ``sequential`` container is not modified; the child modules
    themselves are reused, not copied.
    """
    modules = [*sequential.children()]
    modules.insert(idx, layer)
    rebuilt = nn.Sequential(*modules)
    return rebuilt




class _InterruptExecution(Exception):
    """Private signal used to abort a model forward pass early.

    Raised by ``HIB.forward`` when interruption is enabled and swallowed by the
    ``HIB.interrupt_execution`` context manager.
    """


class _HIBForwardHook:
    """Callable adapter that forwards a hooked layer's inputs/outputs to ``iba``."""

    def __init__(self, iba):
        # Object invoked as ``iba(inputs, outputs)`` every time the hook fires.
        self.iba = iba

    def __call__(self, m, inputs, outputs):
        """Ignore the hooked module ``m`` and return ``iba(inputs, outputs)``."""
        bottleneck = self.iba
        return bottleneck(inputs, outputs)


class HIB(nn.Module):

    def __init__(self,
                 layer_att=None,
                 beta=10,
                 min_std=0.01,
                 optimization_steps=10,
                 lr=1,
                 batch_size=10,
                 lamb=0.1,
                 feature_mean=None,
                 feature_std=None,
                 progbar=True):
        """
        Store the hyper-parameters and optionally hook into ``layer_att``.

        Args:
            layer_att: module to attach to through ``register_forward_hook``;
                when ``None`` no hook is registered.
            beta: default weight of the information loss in ``analyze``.
            min_std: stored as ``self.min_std``.
            optimization_steps: default number of optimisation steps in ``analyze``.
            lr: default learning rate of the optimizer in ``analyze``.
            batch_size: default number of copies of the sample in ``analyze``.
            lamb: noise scale used in ``_do_restrict_information`` and passed
                to ``_kl_div``.
            feature_mean: optional precomputed mean used to normalise the values.
            feature_std: optional precomputed std used to normalise the values.
            progbar: whether progress bars are shown by default.
        """
        super().__init__()

        # Hyper-parameters (defaults for ``analyze``).
        self.beta = beta
        self.min_std = min_std
        self.optimization_steps = optimization_steps
        self.lr = lr
        self.batch_size = batch_size
        self.lamb = lamb
        self.progbar = progbar

        # State that is filled in lazily.
        self.alpha = None  # Initialized on first forward pass
        self._buffer_capacity = None  # Filled on forward pass, used for loss
        self.device = None

        # Feature statistics, either supplied here or produced by ``estimate``.
        self.estimator = TorchWelfordEstimator()
        self._mean = feature_mean
        self._std = feature_std

        # Mode switches toggled by the context managers of this class.
        self._restrict_flow = False
        self._interrupt_execution = False
        self._build_alpha = False
        self._estimate = False
        self._constraint_flow = False

        # Attach to the target layer when one is given.
        self._hook_handle = None
        if layer_att is not None:
            self._hook_handle = layer_att.register_forward_hook(_HIBForwardHook(self))

    def reset_estimate(self):
        """
        Discard the accumulated statistics by swapping in a fresh estimator.

        Useful when the feature distribution has changed, e.g. after the model
        was trained further. The cached ``_mean`` / ``_std`` are left untouched.
        """
        fresh_estimator = TorchWelfordEstimator()
        self.estimator = fresh_estimator
    

    def estimate(self, model, dataloader, processor, device=None, n_samples=10000, progbar=None, reset=True):

        progbar = progbar if progbar is not None else self.progbar
        if progbar:
            tqdm = get_tqdm()
            bar = tqdm(dataloader, total=n_samples)
        else:
            bar = None

        if device is None:
            device = next(iter(model.parameters())).device
        if reset:
            self.reset_estimate()
        for batch in dataloader:
            if isinstance(batch, dict):
                array = batch["audio"]["array"].squeeze(0)
                sampling_rate = batch["audio"]["sampling_rate"]
                array = processor(array, sampling_rate=sampling_rate, return_tensors="pt", padding="longest").input_values
            elif len(batch)==2:
                array, speaker_id = batch
            else:
                array, action, object, location = batch
            if self.estimator.n_samples() > n_samples:
                break
            with torch.no_grad(), self.interrupt_execution(), self.enable_estimation():
                model(array.to(device))
            if bar:
                bar.n = self.estimator.n_samples()
                bar.refresh()
        if bar:
            bar.close()

        # Cache results
        self._mean = self.estimator.mean()
        self._std = self.estimator.std()

    
    def _reset_alpha(self):
        """ Used to reset the mask to train on another sample """
        with torch.no_grad():
            self.alpha = torch.randn_like(self.alpha)

    def build_alpha(self, inputs):
        """
        Initialize alpha with the same shape as the features.
        """
        att, value = inputs
        shape = att.shape
        device = att.device
        self.alpha = nn.Parameter(torch.randn(shape, device=device), requires_grad=True)
    
    def _do_constraint_information(self, inputs):
        """
        IComplete forward propagation using the trained alpha.
        """
        old_att, value = inputs
        new_att = torch.softmax(self.alpha, dim=-1)
        z = torch.bmm(new_att, value)
        return z


    def detach(self):
        """ Remove the bottleneck to restore the original model """
        if self._hook_handle is not None:
            self._hook_handle.remove()
            self._hook_handle = None
        else:
            raise ValueError("Cannot detach hock. Either you never attached or already detached.")

    def forward(self, inputs, outputs):
        """
        You don't need to call this method manually.
        """

        if self._restrict_flow:
            return self._do_restrict_information(inputs)
        if self._build_alpha:
            self.build_alpha(inputs)
        if self._estimate:
            self.estimator(inputs)
        if self._constraint_flow:
            return self._do_constraint_information(inputs)
        if self._interrupt_execution:
            raise _InterruptExecution()
        
        
        return outputs
    
    @contextmanager
    def enable_estimation(self):
        """
        Context manager to enable estimation of the mean and standard derivation.
        """
        self._estimate = True
        try:
            yield
        finally:
            self._estimate = False

    @contextmanager
    def interrupt_execution(self):
        """
        Interrupts the execution of the model
        """
        self._interrupt_execution = True
        try:
            yield
        except _InterruptExecution:
            pass
        finally:
            self._interrupt_execution = False


    @staticmethod
    def _kl_div(value, att, lamb):

        H,T,d = value.shape
        
        lamb = torch.tensor(lamb)

        a = att.reshape(H*T, T)
        aa_T_tr = torch.norm(a, dim=1)**2
        det_mean =  torch.log(aa_T_tr + lamb**2).mean()
        
        capacity = 0.5*(aa_T_tr.mean() + d * lamb**2 - d - det_mean - 2 * (d-1) * torch.log(lamb))

        return capacity

    def _do_restrict_information(self, inputs):

        old_att, value = inputs
        H,T,d = value.shape
        new_att = self.alpha
        
        self._buffer_capacity = self._kl_div(value, new_att, self.lamb)

        if self._mean is not None and self._std is not None:
            value = value.reshape(H*T,d)
            value = (value - self._mean.unsqueeze(0)) / self._std.unsqueeze(0)
            value = value.reshape(H,T,d)
        
        new_att = torch.softmax(new_att, dim=-1) 
        
        z = torch.bmm(new_att, value)
        z = z + self.lamb * torch.randn_like(z)

        return z

    @contextmanager
    def enable_build_alpha(self):
        """
        Context manager to enable build alpha 
        """
        self._build_alpha = True
        try:
            yield
        finally:
            self._build_alpha = False


    @contextmanager
    def restrict_flow(self):
        """
        Context mananger to enable restrict flow.
        """
        self._restrict_flow = True
        try:
            yield
        finally:
            self._restrict_flow = False

    @contextmanager
    def constraint_flow(self):
        """
        Context mananger to enable constraint flow.
        """
        self._constraint_flow = True
        try:
            yield
        finally:
            self._constraint_flow = False

    def analyze(self, input_audio, model_loss_fn, beta=None, optimization_steps=None, min_std=None, lr=None, batch_size=None,):
        """
        Generates HIB for a given sample by optimising ``alpha``.

        The sample is expanded to ``batch_size`` copies, ``alpha`` is reset and
        then optimised with Adam under ``restrict_flow()`` to minimise
        ``model_loss + beta * information_loss``. Per-step diagnostics are kept
        in ``self._loss``, ``self._model_loss``, ``self._information_loss`` and
        ``self._alpha_grads``.

        Args:
            input_audio: tensor whose first dimension must be 1; it is expanded
                with ``expand(batch_size, -1)``.
            model_loss_fn: callable mapping the expanded batch to a scalar loss.
                It must run the hooked model so the capacity buffer is filled.
            beta: weight of the information loss; defaults to ``self.beta``.
            optimization_steps: number of steps; defaults to
                ``self.optimization_steps``.
            min_std: accepted for interface compatibility; not used here.
            lr: Adam learning rate; defaults to ``self.lr``.
            batch_size: number of copies of the sample; defaults to
                ``self.batch_size``.

        Returns:
            ``sigmoid(alpha)`` as a detached CPU tensor.
            NOTE(review): the forward paths apply a softmax to ``alpha`` while
            this returns a sigmoid — confirm this is intended.
        """
        assert input_audio.shape[0] == 1, "We can only fit one sample a time"

        beta = ifnone(beta, self.beta)
        optimization_steps = ifnone(optimization_steps, self.optimization_steps)
        lr = ifnone(lr, self.lr)
        batch_size = ifnone(batch_size, self.batch_size)

        batch = input_audio.expand(batch_size, -1)

        # Reset from previous run or modifications
        self._reset_alpha()
        optimizer = torch.optim.Adam(lr=lr, params=[self.alpha])

        self._loss = []
        self._alpha_grads = []
        self._model_loss = []
        self._information_loss = []

        tqdm = get_tqdm()
        opt_range = tqdm(range(optimization_steps), desc="Training Bottleneck", disable=not self.progbar)

        with self.restrict_flow():
            for _ in opt_range:
                optimizer.zero_grad()
                model_loss = model_loss_fn(batch)
                information_loss = self.capacity().mean()
                loss = model_loss + beta * information_loss
                loss.backward()
                optimizer.step()

                # Copy the gradient: on CPU ``.cpu().numpy()`` shares memory
                # with the live grad buffer, so a later step that reuses the
                # buffer would silently overwrite the recorded history.
                self._alpha_grads.append(self.alpha.grad.detach().cpu().numpy().copy())
                self._loss.append(loss.item())
                self._model_loss.append(model_loss.item())
                self._information_loss.append(information_loss.item())

        return torch.sigmoid(self.alpha.detach().cpu())

    def capacity(self):
        """
        Returns a tensor with the capacity from the last input, averaged
        over the redundant batch dimension.
        """
        return self._buffer_capacity.mean(dim=0)

