# hsic.py

import torch
import time
import torch.nn as nn
import pytorch_lightning as pl
from collections import OrderedDict

import exps.models as models
import exps.losses as losses
import exps.metrics as metrics
import matplotlib.pyplot as plt
import matplotlib.colors as pltc
import numpy as np
from scipy.linalg import eigh
__all__ = ['KernelReg']


class KernelReg(pl.LightningModule):
    """Closed-form kernel embedding followed by a trainable model head.

    At construction time the kernel encoder is solved in closed form from the
    training data: the embedding coefficients ``theta`` come from a
    generalized eigenproblem built from centered kernels on x, y and s.  Only
    ``self.model`` (the head applied to the embedding) is handed to the
    optimizer; ``theta`` is a plain tensor attribute and is never updated.

    Args:
        opts: options namespace.  Attributes read in this class:
            ``kernel_type``, ``sigma_x``, ``sigma_y``, ``sigma_s``,
            ``cholesky_factor``, ``r``, ``tau``, ``lam``, ``model_type``,
            ``model_options``, ``loss_type``, ``loss_options``,
            ``evaluation_type``, ``evaluation_options``, ``fairness_type``,
            ``fairness_options``, ``result_path``, ``optim_method``,
            ``learning_rate``, ``optim_options``, ``scheduler_method``,
            ``scheduler_options`` and, optionally, ``data_path`` (defaults to
            ``'./data/gaussian/data_train.npy'``).
    """

    def __init__(self, opts):
        super().__init__()
        self.save_hyperparameters()
        self.opts = opts
        self.kernelx = getattr(models, opts.kernel_type)(sigma=opts.sigma_x)
        self.kernely = getattr(models, opts.kernel_type)(sigma=opts.sigma_y)
        self.kernels = getattr(models, opts.kernel_type)(sigma=opts.sigma_s)

        ################################### Kernel Encoder ####################################
        # Column layout used below: x = col 0, y = [col 0, col 1 ** 3],
        # label = col 2, s = col 0 ** 3.
        # NOTE: allow_pickle=True lets np.load unpickle objects, which can run
        # arbitrary code -- only point data_path at trusted files.
        data_path = getattr(opts, 'data_path', './data/gaussian/data_train.npy')
        data = torch.from_numpy(np.load(data_path, allow_pickle=True)).float()
        self.x = data[:, 0:1].clone()
        # Clone so cubing column 1 of y does not write back into `data`.
        self.y = data[:, 0:2].clone()
        self.y[:, 1] = torch.pow(data[:, 1], 3)
        self.label = data[:, 2].long()
        self.s = torch.pow(data[:, 0:1], 3)

        K_x = self.kernelx(self.x, self.x)
        K_y = self.kernely(self.y, self.y)
        K_s = self.kernels(self.s, self.s)

        n = K_x.shape[0]
        # Centering matrix H = I - (1/n) * 1 1^T.
        H = torch.eye(n) - torch.ones(n, n) / n
        self.H = H
        K_x = torch.mm(torch.mm(H, K_x), H)
        K_y = torch.mm(torch.mm(H, K_y), H)
        K_s = torch.mm(torch.mm(H, K_s), H)

        # Low-rank factor L_x (n x d) of the centered K_x from its SVD:
        # L_x = V[:, :d] * diag(Sigma[:d]) ** 0.5.
        V1, Sigma, _ = torch.svd(K_x)
        # Numerical rank with torch.matrix_rank's default tolerance
        # (sigma_max * max(shape) * eps), computed from the singular values
        # already in hand instead of a second SVD; torch.matrix_rank itself is
        # deprecated/removed in recent PyTorch releases.
        tol = Sigma.max() * max(K_x.shape) * torch.finfo(Sigma.dtype).eps
        rank = int((Sigma > tol).sum().item())
        # At least 3*r columns because torch.lobpcg refuses problems with fewer
        # than 3*k rows; at most n columns, otherwise torch.eye(d) below would
        # not match the n-column L_x.
        d = int(max(opts.cholesky_factor * rank, int(3 * opts.r)))
        d = min(d, n)

        L_x = torch.mm(V1[:, 0:d], torch.pow(torch.diag(Sigma[0:d]), 0.5))

        # Objective matrix: (1 - tau) * K_y - tau * K_s, projected onto L_x.
        B1 = (1 - opts.tau) * K_y - opts.tau * K_s
        B = torch.mm(torch.mm(L_x.t(), B1), L_x)
        A = (1 - opts.lam) * torch.mm(torch.mm(L_x.t(), H), L_x) + opts.lam * torch.eye(d)
        # Generalized eigenproblem B u = eig * A u; torch.lobpcg returns the
        # k largest eigenpairs by default.
        eig, U = torch.lobpcg(B, k=opts.r, B=A)

        # Discard every direction whose eigenvalue is negative (not only the
        # r-th one), leaving those embedding dimensions identically zero.
        U[:, eig < 0] = 0

        # theta = sqrt(n) * L_x (L_x^T L_x)^{-1} U, stored transposed: (r, n).
        temp = torch.mm(L_x, torch.inverse(torch.mm(L_x.t(), L_x)))
        theta = np.sqrt(n) * torch.mm(temp, U)
        self.theta = theta.t()

        ####################################################################################################
        self.model = getattr(models, opts.model_type)(**opts.model_options)

        self.criterion = {}
        self.criterion['trn_loss'] = getattr(losses, opts.loss_type)(**opts.loss_options)
        self.criterion['val_loss'] = getattr(losses, opts.loss_type)(**opts.loss_options)

        self.acc_trn = getattr(metrics, opts.evaluation_type)(**opts.evaluation_options)
        self.acc_val = getattr(metrics, opts.evaluation_type)(**opts.evaluation_options)
        self.acc_tst = getattr(metrics, opts.evaluation_type)(**opts.evaluation_options)
        if opts.fairness_type is not None:
            self.dep_s_val = getattr(metrics, opts.fairness_type)(**opts.fairness_options)
            self.dep_y_val = getattr(metrics, opts.fairness_type)(**opts.fairness_options)
            self.dep_s = getattr(metrics, opts.fairness_type)(**opts.fairness_options)
            self.dep_y = getattr(metrics, opts.fairness_type)(**opts.fairness_options)
        else:
            # The step / epoch-end hooks test these attributes for None, so
            # they must exist even when no fairness metric is configured.
            self.dep_s_val = None
            self.dep_y_val = None
            self.dep_s = None
            self.dep_y = None
            self.fair_met = None
            self.fair_met_val = None

    def _embed(self, x):
        """Return the kernel embedding ``z = (K_x(x, X_train) H) theta^T``.

        The cached training inputs, ``theta`` and ``H`` are plain tensor
        attributes (not registered buffers), so they are moved to ``x``'s
        device here.  Assumes ``kernelx(a, b)`` returns a (len(a), len(b))
        Gram matrix -- TODO confirm -- which makes ``z`` (batch, r).
        """
        self.x = self.x.to(device=x.device)
        self.theta = self.theta.to(device=x.device)
        self.H = self.H.to(device=x.device)
        K_x = torch.mm(self.kernelx(x, self.x), self.H)
        return torch.mm(K_x, self.theta.t())

    @staticmethod
    def _standardize(t):
        """Zero-mean, unit-std scaling of ``t`` along dim 0."""
        return (t - torch.mean(t, dim=0)) / torch.std(t, dim=0)

    @staticmethod
    def _embedding_points(z_np):
        """Return (n, 2) plot coordinates for an embedding array of shape (n, r).

        For r == 1 the first coordinate is uniform random jitter in [0, 1)
        and the second is z; otherwise the first two embedding dimensions
        are used.
        """
        if z_np.shape[1] == 1:
            return np.column_stack((np.random.rand(z_np.shape[0]), z_np[:, 0]))
        return z_np[:, :2]

    @staticmethod
    def _scatter_figure(points, label):
        """Scatter the first two columns of ``points``, colored red/blue by ``label``."""
        fig, ax = plt.subplots()
        ax.scatter(points[:, 0], points[:, 1],
                   c=label.detach().cpu().numpy(),
                   cmap=pltc.ListedColormap(['red', 'blue']), marker='x')
        ax.axis('equal')
        fig.tight_layout()
        return fig

    def training_step(self, batch, batch_idx):
        x, y, s, _ = batch
        z = self._embed(x)

        out = self.model(z)
        loss = self.criterion['trn_loss'](out, y)
        acc = self.acc_trn(out, y)

        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

        output = OrderedDict({
            'loss': loss,
            'acc': acc
        })
        return output

    def validation_step(self, batch, batch_idx):
        x, y, s, label = batch
        z = self._embed(x)

        out = self.model(z)
        loss = self.criterion['val_loss'](out, y)
        acc = self.acc_val(out, y)

        # Assumes a TensorBoard logger (SummaryWriter).  Only global_step is
        # passed: add_figure's 4th positional parameter is `close` (default
        # True), which closes each figure after logging so none accumulate.
        writer = self.logger.experiment
        z_points = self._embedding_points(z.detach().cpu().numpy())
        writer.add_figure('embedding_val', self._scatter_figure(z_points, label), batch_idx)
        writer.add_figure('y_val', self._scatter_figure(y.detach().cpu().numpy(), label), batch_idx)

        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

        output = OrderedDict({
            'loss': loss,
            'acc': acc
        })

        if self.dep_s_val is not None:
            # Update the dependence metric between standardized z and s;
            # the aggregate is read in validation_epoch_end.
            self.dep_s_val(self._standardize(z).squeeze(1), self._standardize(s).squeeze(1))

        return output

    def validation_epoch_end(self, outputs):
        if self.dep_s_val is not None:
            # NOTE(review): compute() is not followed by a reset here; if
            # dep_s_val is a torchmetrics Metric its state may accumulate
            # across epochs -- confirm.
            self.log('val_dep_s', self.dep_s_val.compute(), on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        x, y, s, label = batch
        z = self._embed(x)

        out = self.model(z)
        loss = self.criterion['val_loss'](out, y)
        acc = self.acc_tst(out, y)

        z_np = z.detach().cpu().numpy()
        out_np = out.detach().cpu().numpy()
        z_points = self._embedding_points(z_np)

        # See validation_step: close defaults to True, so figures are released.
        writer = self.logger.experiment
        writer.add_figure('embedding_test', self._scatter_figure(z_points, label), batch_idx)
        writer.add_figure('output_test', self._scatter_figure(out_np, label), batch_idx)

        self.log('test_loss', loss, on_step=False, on_epoch=True)
        self.log('test_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

        # 1-D embeddings are saved as (jitter, z) -- the same points plotted
        # above; multi-dimensional embeddings are saved as-is.
        # NOTE: these paths are rewritten on every test batch, so each file
        # ends up holding only the last batch.
        embedding = z_points if z_np.shape[1] == 1 else z_np
        np.savetxt(self.opts.result_path + '/test_out_kernel.out', out_np, fmt='%10.5f')
        np.savetxt(self.opts.result_path + '/test_z_kernel.out', embedding, fmt='%10.5f')

        if self.dep_s is not None:
            self.dep_s(self._standardize(z).squeeze(1), self._standardize(s).squeeze(1))

    def test_epoch_end(self, outputs):
        if self.dep_s is not None:
            self.log('dep_s', self.dep_s.compute(), on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """Optimize only the trainable parameters of ``self.model``.

        Returns ``[optimizer]``, or ``([optimizer], [scheduler])`` when
        ``opts.scheduler_method`` is set.
        """
        optimizer = getattr(torch.optim, self.opts.optim_method)(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.opts.learning_rate, **self.opts.optim_options)
        if self.opts.scheduler_method is not None:
            scheduler = getattr(torch.optim.lr_scheduler, self.opts.scheduler_method)(
                optimizer, **self.opts.scheduler_options
            )
            return [optimizer], [scheduler]
        else:
            return [optimizer]
