# hsic.py

import torch
import time
import torch.nn as nn
import pytorch_lightning as pl
from collections import OrderedDict

import exps.models as models
import exps.losses as losses
import exps.metrics as metrics
import matplotlib.pyplot as plt
import matplotlib.colors as pltc
import numpy as np
from scipy.linalg import eigh
# Public names exported by this module via `from <module> import *`.
__all__ = ['LinearReg']


class LinearReg(pl.LightningModule):
    """Closed-form kernel projection followed by a trainable model head.

    On construction the training array ``./data/gaussian/data_train.npy`` is
    loaded, a centred kernel matrix ``K = H k(X, X) H`` is built on its first
    column, and a fixed projection ``theta`` (``r`` rows, one column per
    training sample) is fitted in closed form from ``K``, the centred targets
    ``y`` and the centred attribute ``s`` (see ``_fit_projection``).  Each
    batch is embedded as ``z = k(x, X) H theta^T`` and fed to ``self.model``;
    only ``self.model``'s parameters are optimised, ``theta`` stays fixed.

    Args:
        opts: configuration namespace.  Fields read by this class:
            ``kernel_type``, ``sigma_x``, ``sigma_y``, ``sigma_s``, ``tau``,
            ``r``, ``model_type``, ``model_options``, ``loss_type``,
            ``loss_options``, ``evaluation_type``, ``evaluation_options``,
            ``fairness_type``, ``fairness_options``, ``result_path``,
            ``optim_method``, ``optim_options``, ``learning_rate``,
            ``scheduler_method`` and ``scheduler_options``.
    """

    def __init__(self, opts):
        super().__init__()
        self.save_hyperparameters()
        self.opts = opts
        self.kernelx = getattr(models, opts.kernel_type)(sigma=opts.sigma_x)
        self.kernely = getattr(models, opts.kernel_type)(sigma=opts.sigma_y)
        self.kernels = getattr(models, opts.kernel_type)(sigma=opts.sigma_s)

        # Sets self.x, self.y, self.s, self.label from the training file.
        self._load_training_data('./data/gaussian/data_train.npy')
        # Sets self.H as a side effect and returns theta of shape (r, n).
        self.theta = self._fit_projection(opts.tau, opts.r)

        self.model = getattr(models, opts.model_type)(**opts.model_options)

        self.criterion = {}
        self.criterion['trn_loss'] = getattr(losses, opts.loss_type)(**opts.loss_options)
        self.criterion['val_loss'] = getattr(losses, opts.loss_type)(**opts.loss_options)

        self.acc_trn = getattr(metrics, opts.evaluation_type)(**opts.evaluation_options)
        self.acc_val = getattr(metrics, opts.evaluation_type)(**opts.evaluation_options)
        self.acc_tst = getattr(metrics, opts.evaluation_type)(**opts.evaluation_options)
        if opts.fairness_type is not None:
            self.dep_s_val = getattr(metrics, opts.fairness_type)(**opts.fairness_options)
            self.dep_y_val = getattr(metrics, opts.fairness_type)(**opts.fairness_options)
            self.dep_s = getattr(metrics, opts.fairness_type)(**opts.fairness_options)
            self.dep_y = getattr(metrics, opts.fairness_type)(**opts.fairness_options)
        else:
            # The step/epoch hooks test these against None, so they must exist
            # even when no fairness metric is configured.
            self.dep_s_val = None
            self.dep_y_val = None
            self.dep_s = None
            self.dep_y = None
            self.fair_met = None
            self.fair_met_val = None

    # ------------------------------------------------------------------ #
    # Construction helpers
    # ------------------------------------------------------------------ #
    def _load_training_data(self, path):
        """Load the training array and derive the x, y, s and label tensors."""
        # NOTE: allow_pickle=True is kept from the original; only load trusted files.
        data = torch.from_numpy(np.load(path, allow_pickle=True)).float()
        # Input x: column 0 (cloned so it shares no storage with the others).
        self.x = data[:, 0:1].clone()
        # Target y: column 0 as-is and column 1 cubed.
        self.y = torch.stack((data[:, 0], torch.pow(data[:, 1], 3)), dim=1)
        # Attribute s: column 0 cubed, kept 2-D as (n, 1).
        self.s = torch.pow(data[:, 0:1], 3)
        self.label = data[:, 2].long()

    def _fit_projection(self, tau, r):
        """Fit the fixed projection ``theta`` in closed form.

        Stores the centring matrix on ``self.H`` and returns ``theta``
        transposed so that ``z = K_x @ H @ theta.t()``.

        Args:
            tau: weight on the ``s`` term; the ``y`` term is weighted ``tau - 1``.
            r: number of eigen-directions (smallest eigenvalues) to keep.
        """
        K_x = self.kernelx(self.x, self.x)
        n = K_x.shape[0]
        Y_bar = self.y - torch.mean(self.y, dim=0)
        S_bar = self.s - torch.mean(self.s, dim=0)

        # Centring matrix H = I - (1/n) 1 1^T.
        H = torch.eye(n) - torch.ones(n, n) / n
        self.H = H
        K = torch.mm(torch.mm(H, K_x), H)

        U, Sigma, V = torch.svd(K)
        # Numerical rank with torch.matrix_rank's default tolerance
        # (max singular value * max(dim) * eps), reusing this SVD instead of
        # the deprecated/removed torch.matrix_rank.
        tol = Sigma.max() * n * torch.finfo(Sigma.dtype).eps
        d = int((Sigma > tol).sum().item())
        L = U[:, :d]
        V_d = V[:, :d]
        Sigma_d_inv = torch.diag(1 / Sigma[:d])

        B2 = torch.mm(tau * torch.mm(L.t(), S_bar), S_bar.t())
        B4 = torch.mm((tau - 1) * torch.mm(L.t(), Y_bar), Y_bar.t())
        B = torch.mm(B2 + B4, L)

        # B = L^T (tau S S^T + (tau - 1) Y Y^T) L is symmetric by construction:
        # remove round-off asymmetry and use the symmetric solver, which returns
        # eigenvalues in ascending order (torch.eig is removed in recent PyTorch).
        B = (B + B.t()) / 2
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
            _, eigvecs = torch.linalg.eigh(B)
        else:  # PyTorch < 1.8
            _, eigvecs = torch.symeig(B, eigenvectors=True)
        # Ascending order => the first r columns are the r smallest eigenvalues.
        G = eigvecs[:, :r]

        theta = torch.mm(L, G)
        theta = torch.mm(torch.mm(torch.mm(V_d, Sigma_d_inv), L.t()), theta)
        return theta.t()

    # ------------------------------------------------------------------ #
    # Shared step helpers
    # ------------------------------------------------------------------ #
    def _embed(self, x):
        """Return the embedding ``z = k(x, X_train) @ H @ theta^T`` for a batch."""
        # x_train, theta and H are plain tensor attributes (not registered
        # buffers), so they are moved to the batch's device here.
        self.x = self.x.to(device=x.device)
        self.theta = self.theta.to(device=x.device)
        self.H = self.H.to(device=x.device)
        K_x = torch.mm(self.kernelx(x, self.x), self.H)
        return torch.mm(K_x, self.theta.t())

    @staticmethod
    def _standardize(t):
        """Column-wise z-score: subtract the mean and divide by the std over dim 0."""
        return (t - torch.mean(t, dim=0)) / torch.std(t, dim=0)

    def _log_scatter(self, tag, xs, ys, label, step):
        """Scatter ``(xs, ys)`` coloured red/blue by ``label`` and log it as a figure."""
        fig, ax = plt.subplots()
        ax.scatter(xs, ys, c=label.detach().cpu().numpy(),
                   cmap=pltc.ListedColormap(['red', 'blue']), marker='x')
        ax.axis('equal')
        fig.tight_layout()
        # Assumes a TensorBoard SummaryWriter: pass close=True by keyword so the
        # figure is released (the old code put trainer.log_dir in that slot).
        self.logger.experiment.add_figure(tag, fig, global_step=step, close=True)

    # ------------------------------------------------------------------ #
    # Lightning hooks
    # ------------------------------------------------------------------ #
    def training_step(self, batch, batch_idx):
        x, y, _, _ = batch
        z = self._embed(x)
        out = self.model(z)
        loss = self.criterion['trn_loss'](out, y)
        acc = self.acc_trn(out, y)

        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

        return OrderedDict({
            'loss': loss,
            'acc': acc
        })

    def validation_step(self, batch, batch_idx):
        x, y, s, label = batch
        z = self._embed(x)
        out = self.model(z)
        loss = self.criterion['val_loss'](out, y)
        acc = self.acc_val(out, y)

        # z is plotted against random horizontal jitter; scatter needs z to
        # have exactly x.size(0) elements, i.e. opts.r == 1.
        self._log_scatter('embedding_val', np.random.rand(x.size(0)),
                          z.detach().cpu(), label, batch_idx)
        self._log_scatter('y_val', y[:, 0].detach().cpu(), y[:, 1].detach().cpu(),
                          label, batch_idx)

        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

        if self.dep_s_val is not None:
            # Called for its state update; the value is read in validation_epoch_end.
            self.dep_s_val(self._standardize(z), self._standardize(s))

        return OrderedDict({
            'loss': loss,
            'acc': acc
        })

    def validation_epoch_end(self, outputs):
        if self.dep_s_val is not None:
            self.log('val_dep_s', self.dep_s_val.compute(), on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        x, y, s, label = batch
        z = self._embed(x)
        out = self.model(z)
        loss = self.criterion['val_loss'](out, y)
        acc = self.acc_tst(out, y)

        self._log_scatter('embedding_test', np.random.rand(x.size(0)),
                          z.detach().cpu(), label, batch_idx)
        self._log_scatter('output_test', out[:, 0].detach().cpu(), out[:, 1].detach().cpu(),
                          label, batch_idx)

        self.log('test_loss', loss, on_step=False, on_epoch=True)
        self.log('test_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

        # Truncate on the first batch and append afterwards so the files hold
        # every test batch instead of only the last one.
        mode = 'w' if batch_idx == 0 else 'a'
        with open(self.opts.result_path + '/test_out_kernel.out', mode) as fh:
            np.savetxt(fh, out.detach().cpu().numpy(), fmt='%10.5f')
        with open(self.opts.result_path + '/test_z_kernel.out', mode) as fh:
            np.savetxt(fh, z.detach().cpu().numpy(), fmt='%10.5f')

        if self.dep_s is not None:
            # Called for its state update; the value is read in test_epoch_end.
            self.dep_s(self._standardize(z), self._standardize(s))

    def test_epoch_end(self, outputs):
        if self.dep_s is not None:
            self.log('dep_s', self.dep_s.compute(), on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """Build the optimiser (and optional LR scheduler) for ``self.model`` only."""
        optimizer = getattr(torch.optim, self.opts.optim_method)(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.opts.learning_rate, **self.opts.optim_options)
        if self.opts.scheduler_method is not None:
            scheduler = getattr(torch.optim.lr_scheduler, self.opts.scheduler_method)(
                optimizer, **self.opts.scheduler_options
            )
            return [optimizer], [scheduler]
        return [optimizer]
