
# hsic.py

import os
import time
from collections import OrderedDict

import matplotlib.colors as pltc
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from scipy.linalg import eigh

import exps.losses as losses
import exps.metrics as metrics
import exps.models as models
# Public API of this module: only the ARLReg Lightning module is exported.
__all__ = ['ARLReg']


class ARLReg(pl.LightningModule):
    """Encoder/target model trained against an adversary that reads the embedding.

    ``self.model`` maps an input ``x`` to an embedding ``z`` and a target prediction
    ``out``; ``self.model_adv`` tries to predict ``s`` from ``z``. Training alternates
    two optimizers (see ``configure_optimizers``):

    * ``optimizer_idx == 0`` fits the adversary by minimising its loss on ``s``;
    * ``optimizer_idx == 1`` fits the encoder/target model by minimising
      ``(1 - tau) * target_loss - tau * adversary_loss``, i.e. it is rewarded for
      making ``s`` hard to predict from ``z``.

    Batches are 4-tuples ``(x, y, s, label)``; ``label`` is only used to colour the
    diagnostic scatter plots sent to ``self.logger.experiment``.

    Args:
        opts: configuration namespace providing the model, loss, metric, optimizer
            and scheduler settings read below (``model_type``, ``loss_type``,
            ``tau``, ``optim_method``, ...).
    """

    # Two-colour map for ``label`` in the scatter plots; presumably labels are
    # binary (0/1) — confirm against the data module.
    _SCATTER_COLORS = ['red', 'blue']

    def __init__(self, opts):
        super().__init__()
        self.save_hyperparameters()
        self.opts = opts

        self.model = getattr(models, opts.model_type)(**opts.model_options)
        self.model_adv = getattr(models, opts.model_adv_type)(**opts.model_options)

        # NOTE(review): a plain dict, so these losses are not registered as
        # submodules (not moved by .to(device), not in state_dict). That is only
        # safe if the loss classes hold no parameters/buffers — confirm.
        self.criterion = {
            name: getattr(losses, opts.loss_type)(**opts.loss_options)
            for name in ('tgt_loss', 'adv_loss', 'val_tgt_loss', 'val_adv_loss')
        }

        def eval_metric():
            """Build a fresh instance of the configured evaluation metric."""
            return getattr(metrics, opts.evaluation_type)(**opts.evaluation_options)

        # One independent metric instance per (quantity, stage) pair.
        self.loss_tgt_trn = eval_metric()
        self.loss_adv_trn = eval_metric()

        self.loss_tgt_val = eval_metric()
        self.loss_adv_val = eval_metric()

        self.loss_tgt_tst = eval_metric()
        self.loss_adv_tst = eval_metric()

        if opts.fairness_type is not None:
            def fair_metric():
                """Build a fresh instance of the configured fairness metric."""
                return getattr(metrics, opts.fairness_type)(**opts.fairness_options)

            self.dep_s_val = fair_metric()
            self.dep_y_val = fair_metric()
            self.dep_s = fair_metric()
            self.dep_y = fair_metric()
        else:
            # The step and epoch-end hooks compare these attributes against None,
            # so they must exist even when no fairness metric is configured.
            self.dep_s_val = None
            self.dep_y_val = None
            self.dep_s = None
            self.dep_y = None
            # Legacy names, kept in case external code still reads them.
            self.fair_met = None
            self.fair_met_val = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _standardize(t):
        """Return ``t`` centred to zero mean and scaled by its std along dim 0."""
        return (t - torch.mean(t, dim=0)) / torch.std(t, dim=0)

    @staticmethod
    def _compute_and_reset(metric):
        """Return ``metric.compute()`` and then clear the metric's accumulated state.

        For torchmetrics-style metrics the state otherwise keeps growing across
        epochs (including the sanity-check pass), so each logged value would mix
        in every earlier epoch. ``reset`` is only called if the metric has one.
        """
        value = metric.compute()
        reset = getattr(metric, 'reset', None)
        if callable(reset):
            reset()
        return value

    def _log_scatter(self, tag, xs, ys, label, step):
        """Scatter-plot ``(xs, ys)`` coloured by ``label`` and send it to the logger."""
        fig, ax = plt.subplots()
        ax.scatter(xs, ys, c=label.detach().cpu().numpy(),
                   cmap=pltc.ListedColormap(self._SCATTER_COLORS), marker='x')
        ax.axis('equal')
        fig.tight_layout()
        # close=True releases the figure once rendered, so repeated batches do not
        # pile up open matplotlib figures.
        self.logger.experiment.add_figure(tag, fig, global_step=step, close=True)

    def _log_embedding(self, tag, z, label, step):
        """Plot the first two embedding dimensions (or 1-D embedding vs. jitter)."""
        pts = z.detach().cpu()
        if pts.size(1) == 1:
            # A 1-D embedding gets a random horizontal coordinate so points spread out.
            xs, ys = np.random.rand(pts.size(0)), pts[:, 0]
        else:
            xs, ys = pts[:, 0], pts[:, 1]
        self._log_scatter(tag, xs, ys, label, step)

    # ------------------------------------------------------------ Lightning API

    def training_step(self, batch, batch_nb, optimizer_idx):
        """Compute the loss for one of the two alternating optimizers.

        Args:
            batch: ``(x, y, s, _)`` tuple.
            batch_nb: batch index (unused).
            optimizer_idx: 0 for the adversary step, 1 for the encoder/target step.

        Returns:
            ``OrderedDict`` whose ``'loss'`` entry is minimised by Lightning, plus
            the evaluation-metric value for the trained head.
        """
        x, y, s, _ = batch
        z, out = self.model(x)
        # NOTE(review): z is not detached for the adversary step; this relies on
        # Lightning only stepping (and toggling) the optimizer for optimizer_idx.
        out_adv = self.model_adv(z)
        adv_loss = self.criterion['adv_loss'](out_adv, s)

        if optimizer_idx == 0:
            loss_adv = self.loss_adv_trn(out_adv, s)
            self.log('train_loss_adv', loss_adv, on_step=False, on_epoch=True, prog_bar=True)
            return OrderedDict({
                'loss': adv_loss,
                'loss_adv': loss_adv
            })

        if optimizer_idx == 1:
            # Trade-off: fit the target while pushing the adversary's loss up.
            loss = (1 - self.opts.tau) * self.criterion['tgt_loss'](out, y) - self.opts.tau * adv_loss
            loss_tgt = self.loss_tgt_trn(out, y)

            self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log('train_loss_tgt', loss_tgt, on_step=False, on_epoch=True, prog_bar=True)
            return OrderedDict({
                'loss': loss,
                'loss_tgt': loss_tgt
            })

    def validation_step(self, batch, batch_idx):
        """Log validation losses, diagnostic plots and (optionally) the s-dependence."""
        x, y, s, label = batch
        z, out = self.model(x)
        out_adv = self.model_adv(z)
        loss = (1 - self.opts.tau) * self.criterion['val_tgt_loss'](out, y) \
               - self.opts.tau * self.criterion['val_adv_loss'](out_adv, s)
        loss_tgt = self.loss_tgt_val(out, y)
        loss_adv = self.loss_adv_val(out_adv, s)

        self._log_embedding('embedding_val', z, label, batch_idx)
        # Assumes y has at least two columns (only the first two are plotted).
        self._log_scatter('y_val', y[:, 0].detach().cpu(), y[:, 1].detach().cpu(),
                          label, batch_idx)

        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_loss_tgt', loss_tgt, on_step=False, on_epoch=True)
        self.log('val_loss_adv', loss_adv, on_step=False, on_epoch=True, prog_bar=True)

        if self.dep_s_val is not None:
            # Called for its side effect: presumably accumulates the state that
            # validation_epoch_end reads via compute().
            self.dep_s_val(self._standardize(z), self._standardize(s))

    def validation_epoch_end(self, outputs):
        """Log the epoch-level z/s dependence, if a fairness metric is configured."""
        if self.dep_s_val is not None:
            self.log('val_dep_s', self._compute_and_reset(self.dep_s_val),
                     on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Log test metrics and plots and dump ``out`` / ``z`` to text files."""
        x, y, s, label = batch
        z, out = self.model(x)
        out_adv = self.model_adv(z)
        loss_tgt = self.loss_tgt_tst(out, y)
        loss_adv = self.loss_adv_tst(out_adv, s)

        self._log_embedding('embedding_test', z, label, batch_idx)
        # Assumes out has at least two columns (only the first two are plotted).
        self._log_scatter('output_test', out[:, 0].detach().cpu(), out[:, 1].detach().cpu(),
                          label, batch_idx)

        self.log('test_loss_tgt', loss_tgt, on_step=False, on_epoch=True)
        self.log('test_loss_adv', loss_adv, on_step=False, on_epoch=True, prog_bar=True)

        # NOTE(review): both files are rewritten on every test batch, so only the
        # last batch survives — presumably the test loader yields one batch; confirm.
        np.savetxt(os.path.join(self.opts.result_path, 'test_out_kernel.out'),
                   out.detach().cpu().numpy(), fmt='%10.5f')
        np.savetxt(os.path.join(self.opts.result_path, 'test_z_kernel.out'),
                   z.detach().cpu().numpy(), fmt='%10.5f')

        if self.dep_s is not None:
            self.dep_s(self._standardize(z), self._standardize(s))

    def test_epoch_end(self, outputs):
        """Log the test-set z/s dependence, if a fairness metric is configured."""
        if self.dep_s is not None:
            self.log('dep_s', self._compute_and_reset(self.dep_s),
                     on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """Build the adversary and encoder optimizers (and optional schedulers).

        Returns:
            ``[optimizer_dec, optimizer_enc]`` — index 0 trains ``model_adv`` and
            index 1 trains ``model``, matching ``optimizer_idx`` in
            ``training_step`` — optionally paired with one scheduler per optimizer.
        """
        optim_cls = getattr(torch.optim, self.opts.optim_method)

        def trainable(module):
            return (p for p in module.parameters() if p.requires_grad)

        optimizer_enc = optim_cls(trainable(self.model),
                                  lr=self.opts.learning_rate, **self.opts.optim_options)
        optimizer_dec = optim_cls(trainable(self.model_adv),
                                  lr=self.opts.learning_rate, **self.opts.optim_options)
        optimizers = [optimizer_dec, optimizer_enc]

        if self.opts.scheduler_method is None:
            return optimizers

        sched_cls = getattr(torch.optim.lr_scheduler, self.opts.scheduler_method)
        scheduler_enc = sched_cls(optimizer_enc, **self.opts.scheduler_options)
        scheduler_dec = sched_cls(optimizer_dec, **self.opts.scheduler_options)
        return optimizers, [scheduler_dec, scheduler_enc]
