import torch
import numpy as np
import ipdb as pdb
import torch.nn as nn
import torch.nn.init as init
import pytorch_lightning as pl
import torch.distributions as D
from torch.nn import functional as F
from components.beta import BetaVAE_MLP
from components.transforms import NormalizingFlow
from metrics.correlation_nonlinear import compute_mcc_nonlinear
from metrics.correlation import compute_mcc
from metrics.block import compute_r2
import random as rd
import quadprog
import torch.optim as opt
import wandb


class themodel(nn.Module):
    """VAE with a content/style latent split, trained domain-by-domain with GEM.

    The latent ``z`` has ``c_dim + s_dim`` coordinates. The first ``c_dim``
    (content) coordinates get a standard-normal prior; the last ``s_dim``
    (style) coordinates get a prior given by a per-class normalizing flow on
    top of a standard multivariate normal. ``training_step`` visits each
    domain in turn and, following Gradient Episodic Memory (GEM), projects the
    current gradient whenever it conflicts with gradients computed on a small
    episodic memory of earlier domains.
    """

    def __init__(
        self,
        input_dim,
        c_dim,
        s_dim,
        nclass,
        n_flow_layers,
        train_dataset,
        test_dataset,
        optimizer="adam",
        embedding_dim=0,
        hidden_dim=128,
        bound=5,
        count_bins=8,
        order='linear',
        lr=1e-4,
        beta=0.0025,
        gamma=0.001,
        sigma=1e-6,
        sigma_x=0.1,
        sigma_y=None,
        vae_slope=0.2,
        use_warm_start=False,
        spline_pth=None,
        decoder_dist='gaussian',
        correlation='Pearson',
        encoder_n_layers=3,
        decoder_n_layers=1,
        scheduler=None,
        lr_factor=0.5,
        lr_patience=10,
        hz_to_z=True,
        max_epochs=10,
        n_mem=3,
        batch_size=256,
        n_domain=10,
        cuda=True,
        margin=0.5,
        save_all=True,
        seed=1,
        save_path=''
    ):
        '''Stationary subspace analysis.

        Args:
            input_dim: dimensionality of the observations ``x``.
            c_dim: number of content latent coordinates (standard-normal prior).
            s_dim: number of style latent coordinates (flow prior).
            nclass: number of classes/domains; one spline flow and one GEM
                gradient column are allocated per class.
            n_flow_layers, bound, count_bins, order: forwarded to every
                ``NormalizingFlow``.
            train_dataset: indexed per domain; ``train_dataset[k][i]`` must
                return a mapping with ``'x'`` and ``'c'`` (used to fill the
                episodic memory).
            test_dataset: stored only.
            optimizer: currently ignored; Adam is always used.
            embedding_dim: size of the optional label embedding that
                ``forward`` concatenates to ``x``.
            hidden_dim, vae_slope, encoder_n_layers, decoder_n_layers:
                forwarded to ``BetaVAE_MLP``.
            lr: Adam learning rate.
            beta, gamma: weights of the content and style KL terms.
            use_warm_start, spline_pth: load the same pretrained flow weights
                into every per-class flow.
            decoder_dist: ``'bernoulli'``, ``'gaussian'`` or
                ``'sigmoid_gaussian'`` (see ``reconstruction_loss``).
            correlation: forwarded to the MCC metrics.
            hz_to_z: direction of the R^2 regression in ``_val_loss``.
            n_mem: number of samples stored per domain in the episodic memory.
            cuda: place the GEM gradient store, the prior buffers and (during
                training) the networks on the GPU.
            margin: GEM margin passed to ``project2cone2``.
            save_all, save_path, seed: checkpointing controls.
            sigma, scheduler, lr_factor, lr_patience, max_epochs, batch_size,
            n_domain: stored as attributes; not read by the training methods
                of this class. ``sigma_x`` and ``sigma_y`` are ignored.
        '''
        super().__init__()
        self.c_dim = c_dim
        self.s_dim = s_dim
        self.z_dim = c_dim + s_dim
        self.input_dim = input_dim
        self.lr = lr
        self.nclass = nclass
        self.beta = beta
        self.gamma = gamma
        self.sigma = sigma
        self.correlation = correlation
        self.decoder_dist = decoder_dist
        self.embedding_dim = embedding_dim
        self.scheduler = scheduler
        self.lr_factor = lr_factor
        self.lr_patience = lr_patience
        # Running bests of the validation metrics (updated in ``_val_loss``).
        self.best_r2 = 0.
        self.best_mcc = 0.
        self.best_sum = 0.
        self.best_sum_mcc = 0.
        self.best_sum_r2 = 0.
        self.r2_at_best_mcc = 0.
        self.mcc_at_best_r2 = 0.
        self.hz_to_z = hz_to_z
        self.max_epochs = max_epochs
        # Stored as ``use_cuda``: assigning a bool to ``self.cuda`` would shadow
        # the ``nn.Module.cuda()`` method on this instance.
        self.use_cuda = cuda
        # Honour the ``cuda`` flag instead of unconditionally forcing the GPU.
        self.gpu = bool(cuda)
        self.margin = margin
        self.save_path = save_path

        # Continual-learning (GEM) parameters.
        self.batch_size = batch_size
        self.n_domain = n_domain
        self.n_mem = n_mem
        self.mem_size = [self.n_domain, self.n_mem, self.input_dim]
        self.mem_label_size = [self.n_domain, self.n_mem, 1]

        # Domains visited so far, in order; the last entry is the domain being
        # trained. ``old_domain`` starts at -1 so that domain 0 is registered
        # too (starting at 0, domain 0 was never recorded and thus never
        # protected by the GEM constraint).
        self.observed_domian = []
        self.mem_cnt = 0
        self.old_domain = -1
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.save_all = save_all
        self.seed = seed
        # Per-content-dimension importance, thresholded in ``get_mask``.
        self.importance = torch.nn.Parameter(torch.ones([1, c_dim]))

        # Embedding of the domain index (only used by ``forward``).
        if self.embedding_dim > 0:
            self.embeddings = nn.Embedding(self.nclass, self.embedding_dim)

        # Inference / generative network.
        self.net = BetaVAE_MLP(input_dim=self.input_dim + self.embedding_dim,
                               output_dim=self.input_dim,
                               z_dim=self.z_dim,
                               slope=vae_slope,
                               encoder_n_layers=encoder_n_layers,
                               decoder_n_layers=decoder_n_layers,
                               hidden_dim=hidden_dim)

        # One spline flow per class, modelling the style-latent prior.
        self.spline_list = nn.ModuleList()
        for _ in range(self.nclass):
            spline = NormalizingFlow(
                input_dim=s_dim,
                n_layers=n_flow_layers,
                bound=bound,
                count_bins=count_bins,
                order=order,
            )
            if use_warm_start:
                spline.load_state_dict(torch.load(spline_pth,
                                                  map_location=torch.device('cpu')))
                print("Load pretrained spline flow", flush=True)
            self.spline_list.append(spline)

        # ``lr`` was previously ignored (Adam's own default was used instead).
        self.optimizer = opt.Adam(self.parameters(), lr=self.lr)

        # Element count of every parameter in ``self.parameters()`` order; used
        # to flatten gradients into / unflatten them out of ``self.grads``.
        self.grad_dims = [param.data.numel() for param in self.parameters()]

        device = torch.device('cuda') if self.gpu else torch.device('cpu')
        # One column of flattened gradients per class/domain. Zero-initialised
        # (the previous ``torch.Tensor(...)`` left it uninitialised).
        self.grads = torch.zeros(sum(self.grad_dims), self.nclass, device=device)

        # Base distribution N(0, I) for the style flows.
        self.register_buffer('base_dist_mean', torch.zeros(self.s_dim, device=device))
        self.register_buffer('base_dist_var', torch.eye(self.s_dim, device=device))

    def forward(self, batch):
        """Encode a batch and return ``(zs, mus, logvars)`` from ``self.net``.

        When ``embedding_dim > 0`` a learned embedding of the label ``c`` is
        concatenated to ``x`` first. Note that ``_loss``/``_val_loss`` call
        ``self.net`` directly and do not apply this embedding.
        """
        x, c = batch['x'], batch['c']
        if self.embedding_dim > 0:
            x = torch.cat([x, self.embeddings(c.squeeze().long())], dim=1)
        _, mus, logvars, zs = self.net(x)
        return zs, mus, logvars

    def reparameterize(self, mean, logvar, random_sampling=True):
        """Return ``mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``,
        or ``mean`` itself when ``random_sampling`` is False."""
        if random_sampling:
            eps = torch.randn_like(logvar)
            std = torch.exp(0.5 * logvar)
            return mean + eps * std
        return mean

    def reconstruction_loss(self, x, x_recon, distribution):
        """Summed reconstruction loss divided by the batch size.

        Args:
            x: targets; dimension 0 is the batch.
            x_recon: decoder output (logits for ``'bernoulli'``, raw values
                for ``'gaussian'``, pre-sigmoid values for
                ``'sigmoid_gaussian'``).
            distribution: one of ``'bernoulli'``, ``'gaussian'``,
                ``'sigmoid_gaussian'``.

        Returns:
            Scalar tensor.

        Raises:
            ValueError: unknown ``distribution`` (previously this surfaced as
                an ``UnboundLocalError``).
        """
        batch_size = x.size(0)
        assert batch_size != 0

        # ``reduction='sum'`` is the non-deprecated spelling of
        # ``size_average=False``.
        if distribution == 'bernoulli':
            recon_loss = F.binary_cross_entropy_with_logits(
                x_recon, x, reduction='sum')
        elif distribution == 'gaussian':
            recon_loss = F.mse_loss(x_recon, x, reduction='sum')
        elif distribution == 'sigmoid_gaussian':
            recon_loss = F.mse_loss(torch.sigmoid(x_recon), x, reduction='sum')
        else:
            raise ValueError(f"Unknown decoder distribution: {distribution!r}")

        return recon_loss.div(batch_size)

    def get_mask(self):
        """Threshold ``importance`` at 0.1 and pad it to the full latent width.

        Returns:
            ``(mask, complement)``: boolean tensors of shape ``[1, z_dim]``.
            ``mask`` is True on the first ``c_dim`` coordinates whose importance
            exceeds 0.1; ``complement`` is its negation. The last ``s_dim``
            coordinates are always False in ``mask``.
        """
        # NOTE(review): the original comment said "if mask=1, it is style",
        # but ``importance`` covers the first c_dim (content) coordinates --
        # confirm the intended semantics.
        mask = (self.importance > 0.1).detach().float()
        # Pad with zeros for the remaining z_dim - c_dim coordinates. This read
        # ``self.i_dim``, which is never defined, so the method always raised.
        additional_mask = torch.zeros([1, self.z_dim - self.c_dim]).to(mask.device)
        full_mask = torch.cat([mask, additional_mask], 1)
        return torch.gt(full_mask, 1e-8), torch.gt(1 - full_mask, 1e-8)

    def get_importance(self):
        """Return the ``importance`` parameter."""
        return self.importance

    @property
    def base_dist(self):
        """Standard multivariate normal used as the base of the style flows."""
        return D.MultivariateNormal(self.base_dist_mean, self.base_dist_var)

    def _elbo_terms(self, x, c):
        """Compute the ELBO components for one batch.

        Args:
            x: observations, unpacked as ``(batch_size, _)``; must already be
                on the right device.
            c: per-sample class labels; squeezed and cast to int64.

        Returns:
            ``(recon_loss, kld_content, kld_style, mus)``; the first three are
            scalar tensors, ``mus`` is the encoder mean.
        """
        batch_size, _ = x.shape
        labels = torch.squeeze(c).to(torch.int64)

        if self.gpu:
            self.net.cuda()
            self.spline_list.cuda()

        # x_recon = g(f(x)); mus/logvars parameterise q(z|x); zs is a sample
        # from N(mus, exp(logvars)).
        x_recon, mus, logvars, zs = self.net(x)
        recon_loss = self.reconstruction_loss(x, x_recon, self.decoder_dist)
        log_qz = D.Normal(mus, torch.exp(logvars / 2)).log_prob(zs)

        # Content KLD: single-sample estimate against a N(0, 1) prior.
        p_dist = D.Normal(torch.zeros_like(mus[:, :self.c_dim]),
                          torch.ones_like(logvars[:, :self.c_dim]))
        log_pz_content = p_dist.log_prob(zs[:, :self.c_dim]).sum(dim=-1)
        log_qz_content = log_qz[:, :self.c_dim].sum(dim=-1)
        kld_content = (log_qz_content - log_pz_content).mean()

        # Style KLD: every class's flow is run on the whole batch; the one-hot
        # label then keeps, per sample, only the output of its own class flow.
        log_qz_style = log_qz[:, self.c_dim:]
        residuals = zs[:, self.c_dim:]
        one_hot = F.one_hot(labels, num_classes=self.nclass).reshape(-1, self.nclass)
        es, logabsdet = [], []
        for flow in self.spline_list:
            es_c, logabsdet_c = flow(residuals)
            es.append(es_c)
            logabsdet.append(logabsdet_c)
        es = (torch.stack(es, dim=1) * one_hot.unsqueeze(-1)).sum(1)
        logabsdet = (torch.stack(logabsdet, dim=1) * one_hot).sum(1)
        es = es.reshape(batch_size, self.s_dim)
        log_pz_style = self.base_dist.log_prob(es) + logabsdet
        kld_style = (log_qz_style.sum(dim=-1) - log_pz_style).mean()

        return recon_loss, kld_content, kld_style, mus

    def _loss(self, data):
        """Training loss ``recon + beta * KL_content + gamma * KL_style``.

        Args:
            data: mapping with ``'x'`` (observations) and ``'c'`` (labels).

        Returns:
            Scalar loss tensor.
        """
        x, c = data['x'], data['c']
        if self.gpu:
            x = x.cuda()
            c = c.cuda()
        recon_loss, kld_content, kld_style, _ = self._elbo_terms(x, c)
        return recon_loss + self.beta * kld_content + self.gamma * kld_style

    def _val_loss(self, data):
        """Compute the validation loss for one batch and log metrics to wandb.

        Besides the loss terms, logs the kernel-regression R^2 on the content
        coordinates, the nonlinear MCC on the style coordinates, and the
        running best R^2 together with the MCC observed at that point.

        Args:
            data: mapping with ``'x'``, ``'y'`` and ``'c'``. ``'y'`` is compared
                column-wise with the encoder means (presumably ground-truth
                latents with content first -- confirm against the dataset).

        Returns:
            Scalar loss tensor.
        """
        x, y, c = data['x'], data['y'], data['c']
        if self.gpu:
            x = x.cuda()
            y = y.cuda()
            c = c.cuda()

        recon_loss, kld_content, kld_style, mus = self._elbo_terms(x, c)
        loss = recon_loss + self.beta * kld_content + self.gamma * kld_style

        # Kernel-regression R^2; ``hz_to_z`` picks the regression direction.
        if self.hz_to_z is False:
            r2 = compute_r2(mus[:, :self.c_dim], y[:, :self.c_dim])
        else:
            r2 = compute_r2(y[:, :self.c_dim], mus[:, :self.c_dim])
        # Mean Correlation Coefficient (MCC) on the style coordinates.
        zt_recon = mus[:, self.c_dim:].T.detach().cpu().numpy()
        zt_true = y[:, self.c_dim:].T.detach().cpu().numpy()
        mcc = compute_mcc_nonlinear(zt_recon, zt_true, self.correlation)

        wandb.log({"val_mcc": mcc})
        wandb.log({"val_r2": r2})
        wandb.log({"val_elbo_loss": loss})
        wandb.log({"val_recon_loss": recon_loss})
        wandb.log({"val_kld_content": kld_content})
        wandb.log({"val_kld_style": kld_style})

        if r2 >= self.best_r2:
            self.best_r2 = r2
            self.mcc_at_best_r2 = mcc
        wandb.log({"best_r2": self.best_r2})
        wandb.log({"mcc_at_best_r2": self.mcc_at_best_r2})

        return loss

    def store_grad(self, pp, grads, grad_dims, tid):
        """
            This stores parameter gradients of past tasks.
            pp: callable returning the parameters (e.g. ``self.parameters``)
            grads: (sum(grad_dims), n_tasks) gradient store, written in place
            grad_dims: list with number of parameters per layer, in ``pp()`` order
            tid: task id (column of ``grads`` to overwrite)
        Parameters whose ``.grad`` is None leave their slice at zero.
        """
        grads[:, tid].fill_(0.0)
        offset = 0  # running start index of the current parameter's slice
        for param, numel in zip(pp(), grad_dims):
            if param.grad is not None:
                grads[offset:offset + numel, tid].copy_(param.grad.data.reshape(-1))
            offset += numel

    def project2cone2(self, gradient, memories, margin=0.5, eps=1e-3):
        """
        Solves the GEM dual QP described in the paper given a proposed
        gradient "gradient", and a memory of task gradients "memories".
        Overwrites "gradient" with the final projected update.
        input:  gradient, (p, 1) tensor (written in place)
        input:  memories, (p, t) tensor, one column per past task
        input:  margin, lower bound on the dual variables
        input:  eps, ridge added to the QP matrix for numerical stability
        """
        memories_np = memories.cpu().t().double().numpy()  # (t, p)
        gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
        t = memories_np.shape[0]
        # quadprog minimises 1/2 v^T P v - q^T v subject to G^T v >= h.
        P = np.dot(memories_np, memories_np.transpose())
        P = 0.5 * (P + P.transpose()) + np.eye(t) * eps
        q = np.dot(memories_np, gradient_np) * -1
        G = np.eye(t)
        h = np.zeros(t) + margin
        v = quadprog.solve_qp(P, q, G, h)[0]
        x = np.dot(v, memories_np) + gradient_np
        gradient.copy_(torch.Tensor(x).view(-1, 1))

    def overwrite_grad(self, pp, newgrad, grad_dims):
        """
        This is used to overwrite the gradients with a new gradient
        vector, whenever violations occur.
        pp: callable returning the parameters (e.g. ``self.parameters``)
        newgrad: corrected flattened gradient
        grad_dims: list storing number of parameters at each layer
        Parameters whose ``.grad`` is None are skipped (their slice is unused).
        """
        offset = 0  # running start index of the current parameter's slice
        for param, numel in zip(pp(), grad_dims):
            if param.grad is not None:
                this_grad = newgrad[offset:offset + numel].contiguous().view(
                    param.grad.data.size())
                param.grad.data.copy_(this_grad)
            offset += numel

    def _sample_memory(self, domain):
        """Draw ``n_mem`` random samples (with replacement) of ``domain``.

        Returns:
            ``{'x': stacked x, 'c': stacked c}``, moved to the GPU if enabled.
        """
        dataset = self.train_dataset[domain]
        xs, cs = [], []
        for _ in range(self.n_mem):
            sample = dataset[rd.randint(0, len(dataset) - 1)]
            xs.append(sample['x'])
            cs.append(sample['c'])
        memory = {'x': torch.stack(xs), 'c': torch.stack(cs)}
        if self.gpu:
            memory['x'] = memory['x'].cuda()
            memory['c'] = memory['c'].cuda()
        return memory

    def _project_gradient(self, k, past_domains):
        """Apply the GEM constraint to the current gradient of domain ``k``.

        The current ``.grad`` values are flattened into column ``k`` of
        ``self.grads``. If their dot product with any past domain's stored
        gradient is negative, the gradient is projected with
        ``project2cone2`` and written back into the parameters' ``.grad``.
        """
        self.store_grad(self.parameters, self.grads, self.grad_dims, k)
        indx = torch.tensor(past_domains, dtype=torch.long, device=self.grads.device)
        past_grads = self.grads.index_select(1, indx)  # a copy, not a view
        dotp = torch.mm(self.grads[:, k].unsqueeze(0), past_grads)
        if (dotp < 0).sum() != 0:
            # ``self.grads[:, k].unsqueeze(1)`` is a view, so the projection is
            # written straight into column k of ``self.grads``.
            self.project2cone2(self.grads[:, k].unsqueeze(1), past_grads, self.margin)
            self.overwrite_grad(self.parameters, self.grads[:, k], self.grad_dims)

    def _save_checkpoint(self, k, suffix):
        """Save ``state_dict`` after domain ``k`` when ``save_all`` is set."""
        if self.save_all:
            model_name = 'd_' + str(k + 1) + '_seed_' + str(self.seed) + suffix
            torch.save(self.state_dict(), self.save_path + model_name + '.pth')

    def _log_mcc_all(self, val_loader_large):
        """Log linear and nonlinear MCC on the style coordinates, per batch."""
        with torch.no_grad():
            for data in val_loader_large:
                x, y = data['x'], data['y']
                if self.gpu:
                    x = x.cuda()
                    y = y.cuda()
                _, mus, _, _ = self.net(x)
                zt_recon = mus[:, self.c_dim:].T.detach().cpu().numpy()
                zt_true = y[:, self.c_dim:].T.detach().cpu().numpy()
                mcc_all_nonlinear = compute_mcc_nonlinear(zt_recon, zt_true, self.correlation)
                mcc_all = compute_mcc(zt_recon, zt_true, self.correlation)
                wandb.log({'MCC_all_nonlinear': mcc_all_nonlinear})
                wandb.log({'MCC_all': mcc_all})

    def training_step(self, train_loader, max_epochs, val_loader_large):
        """Train on each domain in turn with GEM gradient projection.

        Args:
            train_loader: list with one batch iterable per domain (batches are
                mappings with ``'x'`` and ``'c'``; the progress print also reads
                ``loader.dataset.data['x']``).
            max_epochs: epochs per domain.
            val_loader_large: iterable of batches with ``'x'`` and ``'y'``,
                evaluated after each domain.
        """
        # Episodic memory, one entry per domain (filled when it is visited).
        mem_stack = [{'x': [], 'c': []} for _ in range(len(train_loader))]

        for k, train_loader_current in enumerate(train_loader):
            if k != self.old_domain:
                self.observed_domian.append(k)
                self.old_domain = k
            # All observed domains except the current (last) one.
            past_domains = self.observed_domian[:-1]

            mem_stack[k] = self._sample_memory(k)

            for epochs in range(max_epochs):
                wandb.log({"epochs": epochs})
                for batch_idx, data in enumerate(train_loader_current):
                    # Refresh each past domain's gradient on its memory.
                    for past_task in past_domains:
                        self.zero_grad()
                        past_loss = self._loss(mem_stack[past_task])
                        past_loss.backward()
                        self.store_grad(self.parameters, self.grads, self.grad_dims, past_task)

                    self.zero_grad()
                    loss_current = self._loss(data)
                    wandb.log({"loss": loss_current.item()})
                    loss_current.backward()

                    if past_domains:
                        self._project_gradient(k, past_domains)

                    self.optimizer.step()
                    if batch_idx % 10 == 0:
                        print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                            epochs+1, (batch_idx+1) * len(data['x']), len(train_loader_current.dataset.data['x']),
                            loss_current.item()))

            # NOTE(review): the "_mem256" suffix is fixed and does not reflect
            # ``n_mem``; kept unchanged so existing checkpoint names still match.
            self._save_checkpoint(k, '_mem256')
            # Check MCC on the large validation set after finishing every domain.
            self._log_mcc_all(val_loader_large)

    def valdation_step(self, val_loader):
        """Run ``_val_loss`` (which logs to wandb) over every batch of every
        per-domain loader in ``val_loader``, without gradients.

        The (misspelled) method name is kept for existing callers.
        """
        for val_loader_current in val_loader:
            with torch.no_grad():
                for data in val_loader_current:
                    self._val_loss(data)

    def train_step_baseline(self, train_loader, max_epochs, val_loader_large):
        """Plain sequential fine-tuning across domains (no memory, no GEM).

        Arguments are the same as for ``training_step``.
        """
        for k, train_loader_current in enumerate(train_loader):
            for epochs in range(max_epochs):
                wandb.log({"epochs": epochs})
                for data in train_loader_current:
                    loss_current = self._loss(data)
                    wandb.log({"loss": loss_current.item()})
                    loss_current.backward()
                    self.optimizer.step()
                    self.zero_grad()

            self._save_checkpoint(k, '_mem0')
            self._log_mcc_all(val_loader_large)
