import torch
import torch.nn as nn
from msc_tools import zero_linear

class StackedNN:
    """Stack of structurally identical MLPs evaluated together in one batched pass.

    Every ``nn.Linear`` of the source network becomes one stacked layer: its
    weight is stored transposed with shape ``(n_stacked, in_dim, out_dim)`` and
    its bias with shape ``(n_stacked, 1, out_dim)``. ``nn.BatchNorm1d`` modules
    are folded into the preceding linear layer using their running statistics
    (i.e. eval-mode behavior). Every other module is appended, in order, to the
    list of non-linearities; the k-th one is applied after the k-th stacked
    linear layer, and layers without one get ``nn.Identity``.

    At most ``max_size`` networks are kept: pushing onto a full stack evicts
    the oldest network first.
    """

    def __init__(self, ann, max_size):
        # ann should be an iterable of torch.nn.Module (e.g. an nn.Sequential).
        self.sweights = []  # one entry per linear layer: (n_stacked, in_dim, out_dim)
        self.sbiases = []  # one entry per linear layer: (n_stacked, 1, out_dim)
        self.non_lin = []
        for f in ann:
            if isinstance(f, nn.Linear):
                w, b = self._linear_params(f)
                self.sweights.append(w[None, ...])
                self.sbiases.append(b[None, None, ...])
            elif isinstance(f, nn.BatchNorm1d):  # fold into the previous linear layer
                self._fold_batchnorm(len(self.sweights) - 1, f)
            else:
                self.non_lin.append(f)
        self.max_size = max_size
        # Linear layers with no following non-linearity (e.g. an output head) stay linear.
        if len(self.non_lin) < len(self.sweights):
            self.non_lin += [nn.Identity()] * (len(self.sweights) - len(self.non_lin))

    @staticmethod
    def _linear_params(f):
        """Return detached copies of ``(weight^T, bias)`` of ``nn.Linear`` ``f``.

        A layer built with ``bias=False`` gets an all-zero bias so that every
        stacked layer has one.
        """
        w = f.weight.t().detach().clone()  # (in_dim, out_dim)
        if f.bias is None:
            b = torch.zeros(f.out_features, dtype=w.dtype, device=w.device)
        else:
            b = f.bias.detach().clone()  # (out_dim,)
        return w, b

    def _fold_batchnorm(self, idx, bn):
        """Fold ``bn`` (with its running statistics) into the newest network of stacked layer ``idx``.

        BN(x W + b) = gam * (x W + b - running_mean) + beta, where
        gam = weight / sqrt(running_var + eps); W and b are rewritten in place.
        """
        # no_grad: bn.weight / bn.bias are trainable parameters; the folded copy must not track them.
        with torch.no_grad():
            scale = bn.weight if bn.weight is not None else 1.  # affine=False -> unit scale
            shift = bn.bias if bn.bias is not None else 0.  # affine=False -> zero shift
            gam = (scale / (bn.running_var + bn.eps).pow(.5))[None, ...]  # (1, out_dim)
            self.sweights[idx][-1] *= gam  # scales output columns of W
            self.sbiases[idx][-1] = gam * (self.sbiases[idx][-1] - bn.running_mean[None, ...]) + shift

    def __call__(self, x):
        """Evaluate all stacked networks on ``x``.

        ``x`` is assumed to be ``(batch, in_dim)`` (shared by all networks) or
        ``(n_stacked, batch, in_dim)``; ``torch.matmul`` broadcasts it against
        the stacked weights, so the result is ``(n_stacked, batch, out_dim)``.
        """
        for w, b, nl in zip(self.sweights, self.sbiases, self.non_lin):
            x = nl(torch.matmul(x, w) + b)
        return x

    def push(self, ann):
        """Append the parameters of ``ann`` as the newest stacked network.

        ``ann`` must have the same layer structure as the network the stack
        was built from. When the stack already holds ``max_size`` networks the
        oldest one is dropped first, so at most ``max_size`` are ever kept.
        """
        idx = 0  # index of the next stacked linear layer to extend
        for f in ann:
            if isinstance(f, nn.Linear):
                w_new, b_new = self._linear_params(f)
                w, b = self.sweights[idx], self.sbiases[idx]
                # '>=' (not '>'): after the append below the stack must not exceed max_size.
                if w.shape[0] >= self.max_size:
                    w, b = w[1:], b[1:]  # evict the oldest network
                self.sweights[idx] = torch.cat((w, w_new[None, ...]), 0)
                self.sbiases[idx] = torch.cat((b, b_new[None, None, ...]), 0)
                idx += 1
            elif isinstance(f, nn.BatchNorm1d):  # applies to the layer just pushed
                self._fold_batchnorm(idx - 1, f)

    def decay(self, c):
        """Multiply every stacked weight and bias by ``c`` in place."""
        for w, b in zip(self.sweights, self.sbiases):
            w.data *= c
            b.data *= c

    def add(self, snn, c: float):
        """Add ``c`` times the parameters of ``snn`` to this stack, in place.

        ``snn`` must hold no more stacked networks than ``self``; its networks
        are added onto the first (oldest) ``snn`` entries of this stack.
        """
        for w, b, w2, b2 in zip(self.sweights, self.sbiases, snn.sweights, snn.sbiases):
            w.data[:w2.shape[0]] += c * w2
            b.data[:b2.shape[0]] += c * b2


class DWExNNvanilla:
    """Expandable Q-network: a trainable MLP plus stacks of frozen past MLPs.

    The trainable part is ``train_feat`` (an MLP with ``nb_hidden`` layers of
    width ``expand_rate``) followed by the linear head ``train_q`` (built with
    ``zero_linear``). ``update_sigq`` copies the current trainable networks into
    ``StackedNN`` stacks (``froz_feat`` for features, ``sig_q`` for heads) that
    keep up to ``max_expand`` networks. On every merge the previously stored
    heads are multiplied by ``decay``, and ``get_logits`` returns their summed
    outputs scaled by ``eta * w_correction``.
    """

    def __init__(self, input_size, output_size, nb_hidden, expand_rate, max_expand, kl_weight, entropy_weight, use_w_correction,
                 device=torch.device('cpu'), nl=nn.ReLU(inplace=True), last_layer_nl=nn.ReLU(inplace=True)):
        """
        Args:
            input_size: input feature dimension.
            output_size: number of outputs of the Q head / logits.
            nb_hidden: number of hidden layers of the feature MLP (must be > 0).
            expand_rate: width of every hidden layer.
            max_expand: capacity of the frozen stacks.
            kl_weight, entropy_weight: define ``eta = 1 / (kl + ent)`` and
                ``decay = kl / (kl + ent)``.
            use_w_correction: if True, logits are additionally scaled by
                ``1 / (1 - decay ** max_expand)``.
            device: device the networks are created on.
            nl: activation after every hidden layer except the last.
            last_layer_nl: activation after the last hidden layer.

        NOTE: the default activation modules are single instances shared by all
        objects and layers; that is fine for parameter-free modules like ReLU.
        """
        super().__init__()

        assert nb_hidden > 0, "Number of hidden layers must be greater than 0."
        self.nb_hidden = nb_hidden  # total number of hidden layers (remains fixed)
        self.input_size = input_size
        self.output_size = output_size
        self.expand_rate = expand_rate
        self.max_expand = max_expand
        self._kl_weight = kl_weight
        self._entropy_weight = entropy_weight
        self.use_w_correction = use_w_correction
        self._update_step_sizes()

        print(f'step size eta {self.eta} and decay {self.decay}')
        self.device = device
        self.nl = nl
        self.last_layer_nl = last_layer_nl

        # Frozen nets (StackedNN), created by the first update_sigq call
        self.froz_feat = None
        self.sig_q = None

        # Trainable mlp
        self.train_feat, self.train_q = self._get_new_mlps()

    def _update_step_sizes(self):
        """Recompute ``eta``, ``decay`` and ``w_correction`` from the current weights.

        Raises ZeroDivisionError if ``kl_weight + entropy_weight == 0`` or, with
        the correction enabled, if ``decay ** max_expand == 1``.
        """
        total = self._kl_weight + self._entropy_weight
        self.eta = 1. / total
        self.decay = self._kl_weight / total
        if self.use_w_correction:
            # 1 / (1 - decay^K): rescales the truncated geometric series of K decayed heads.
            self.w_correction = 1. / (1. - self.decay ** self.max_expand)
        else:
            self.w_correction = 1.

    def __call__(self, x):
        """Return the Q values of the trainable (newest) network only."""
        return self.train_q(self.train_feat(x))

    def get_logits(self, x, no_old=False):
        """Return the weighted sum of the frozen networks' outputs for ``x``.

        Only old (frozen) networks contribute. Before the first ``update_sigq``
        the result is zeros of shape ``(len(x), output_size)``.

        Args:
            x: input batch fed to the frozen feature stack.
            no_old: debug only (for computing KL(pik, tilde{pik}) in the paper):
                once the stack is full, the oldest frozen network is left out.
        """
        if self.froz_feat is None:  # no old features yet
            return torch.zeros(len(x), self.output_size, device=self.device)
        q = self.sig_q(self.froz_feat(x))  # stacked network index along dim 0
        # '>=' so the debug path applies whenever the stack is full, not at one transient size.
        if no_old and self.sig_q.sweights[0].shape[0] >= self.max_expand:
            q = q[1:]  # drop the oldest network
        return q.sum(0) * self.eta * self.w_correction

    def _get_new_mlps(self):
        """Build a fresh ``(feature MLP, linear head)`` pair on ``self.device``."""
        insize = self.input_size
        ops = []
        for k in range(self.nb_hidden):
            ops.append(nn.Linear(insize, self.expand_rate))
            ops.append(self.last_layer_nl if k == self.nb_hidden - 1 else self.nl)
            insize = self.expand_rate

        return nn.Sequential(*ops).to(self.device), zero_linear(nn.Linear(self.expand_rate, self.output_size).to(self.device))

    def parameters(self):
        """Return the trainable parameters (feature MLP then head), or None if there is no trainable net."""
        if self.train_feat is not None:
            return [*self.train_feat.parameters(), *self.train_q.parameters()]
        else:
            return None

    def train(self, train_mode):
        """Set train/eval mode of the trainable networks (the frozen stacks have no mode)."""
        if self.train_feat is not None:
            self.train_feat.train(train_mode)
            self.train_q.train(train_mode)

    def set_entropy_weight(self, weight):
        """Change the entropy weight and recompute ``eta``, ``decay`` and ``w_correction``."""
        self._entropy_weight = weight
        self._update_step_sizes()

    def update_sigq(self):
        """Copy the current trainable networks into the frozen stacks.

        Switches ``train_feat`` to eval mode. The first call creates the
        stacks; later calls multiply all stored heads by ``decay`` and then
        push the current networks (the stacks evict their oldest network when
        full). The trainable networks themselves are not re-initialized here.
        """
        self.train_feat.train(False)
        if self.froz_feat is None:  # building first frozen feat network
            self.froz_feat = StackedNN(self.train_feat, self.max_expand)
            self.sig_q = StackedNN([self.train_q], self.max_expand)
        else:
            self.froz_feat.push(self.train_feat)
            self.sig_q.decay(self.decay)
            self.sig_q.push([self.train_q])

