import sys
import wandb
import torch
import numpy as np
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
from collections import defaultdict

from data import data
from model import model
from utils import utils, losses
from visualization import wandb_utils


def leave_one_out_reg_kernels_one(K_YY, K_QQ, reg):
    '''
    Mean leave-one-out error of kernel ridge regression for a single ``reg``.

    With the hat matrix ``H = (solve(K_YY + reg*I, K_YY)).T`` the per-point
    numerator is ``K_QQ[i,i] + (H K_QQ H)[i,i] - 2 (H K_QQ)[i,i]``. It is
    divided by ``(1 - H[i,i])**2`` and averaged over i.

    Args:
        K_YY: square (n, n) kernel matrix used to build the smoother.
        K_QQ: (n, n) kernel matrix of the regression targets.
        reg: ridge regularisation strength added to the diagonal of K_YY.

    Returns:
        The mean of the per-point leave-one-out terms (NumPy scalar).
    '''
    n = K_YY.shape[0]
    # solve(A, K_YY) = A^{-1} K_YY; the transpose equals K_YY A^{-1} when K_YY is symmetric.
    hat = np.linalg.solve(K_YY + reg * np.eye(n), K_YY).T
    # Form H @ K_QQ once; both diagonal terms below are derived from it.
    hat_kqq = hat @ K_QQ
    # diag(H K_QQ H)_i = sum_j (H K_QQ)_ij * H_ji -- avoids building the full n x n product.
    diag_hkh = np.einsum('ij,ji->i', hat_kqq, hat)
    numer = np.diag(K_QQ) + diag_hkh - 2 * np.diag(hat_kqq)
    return (numer / (1 - np.diag(hat)) ** 2).mean()


def leave_one_out_reg_kernels(K_YY, K_QQ, reg_list):
    '''
    Evaluate ``leave_one_out_reg_kernels_one`` for every regulariser in ``reg_list``.

    Also flags regularisers too small to matter numerically. The threshold
    ``svd_tol = max_singular_value(K_YY) * n * eps`` has the same form as the
    default tolerance of ``numpy.linalg.matrix_rank``.

    Args:
        K_YY: square (n, n) Hermitian kernel matrix.
        K_QQ: (n, n) target kernel matrix.
        reg_list: iterable of regularisation strengths.

    Returns:
        Tuple ``(loos, below_tol, svd_tol)``:
        - ``loos`` is a list of leave-one-out scores, one per entry of ``reg_list``.
        - ``below_tol`` is a boolean array marking ``reg < svd_tol``.
        - ``svd_tol`` is the threshold itself.
    '''
    loos = [leave_one_out_reg_kernels_one(K_YY, K_QQ, reg) for reg in reg_list]
    # Only the singular values are needed, so skip computing the singular vectors.
    sing_vals = np.linalg.svd(K_YY, compute_uv=False, hermitian=True)
    svd_tol = sing_vals.max() * K_YY.shape[0] * np.finfo(sing_vals.dtype).eps
    return loos, np.asarray(reg_list) < svd_tol, svd_tol


class CausalHSIC:
    '''
    Trainer that fits ``self.model`` to predict y from x with an MSE loss plus an
    HSIC penalty.

    The penalty is taken between the model's intermediate features and
    ``res = z - (y @ coef + intercept)``. That residual is what remains of z after
    a per-mode linear y->z regressor read from the dataset.

    Construction builds the dataloaders, kernel settings, y->z regressors, model,
    optimizer and LR scheduler from the three config objects. ``run`` performs
    training/evaluation and checkpointing.
    '''

    def __init__(self, data_cfg, model_cfg, exp_cfg) -> None:
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.exp_cfg = exp_cfg

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.dataloaders = data.create_dataloaders(data_cfg, model_cfg.modes)

        self._set_kernels()  # goes first to correctly setup kernels
        self._get_yz_regressors()
        self._setup_model()
        self._setup_optimizers()
        self._setup_schedulers()

        # Early-stopping / checkpoint bookkeeping used by `save`.
        self.last_best = -1
        self.val_loss = np.inf

    def _get_yz_regressors(self):
        '''
        Cache each mode's dataset-level ``linear_reg`` (``coef_``, ``intercept_``)
        as float32 tensors on ``self.device`` in ``self.yz_reg[mode]``.

        ``torch.tensor(np.asarray(...))`` accepts both arrays and scalar
        intercepts, and it always copies.
        '''
        self.yz_reg = defaultdict(dict)
        for mode in self.model_cfg.modes:
            linear_reg = self.dataloaders[mode].dataset.linear_reg
            self.yz_reg[mode]['coef'] = torch.tensor(
                np.asarray(linear_reg.coef_, dtype=np.float32), device=self.device)
            self.yz_reg[mode]['intercept'] = torch.tensor(
                np.asarray(linear_reg.intercept_, dtype=np.float32), device=self.device)

    def _set_kernels(self):
        '''
        Read the feature and residual kernels from ``model_cfg``.

        Each config entry is a single-entry mapping ``{kernel_name: kernel_args}``;
        only its first entry is used.
        '''
        self.kernel_ft = [*self.model_cfg.kernel_ft.keys()][0]
        self.kernel_ft_args = [*self.model_cfg.kernel_ft.values()][0]
        self.kernel_res = [*self.model_cfg.kernel_res.keys()][0]
        self.kernel_res_args = [*self.model_cfg.kernel_res.values()][0]

    def _setup_model(self):
        '''
        Create the model via ``model.factory`` and optionally load weights.

        Weights are copied from ``exp_cfg.load`` when it is set. If
        ``exp_cfg.wandb`` is set, the model is registered with ``wandb.watch``.
        '''
        print("Initializing networks.")
        self.model = model.factory.create(self.model_cfg.model_key, **{"model_cfg": self.model_cfg}).to(self.device)
        if self.exp_cfg.load is not None:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            saved_model = torch.load(self.exp_cfg.load, map_location=self.device)
            utils.copy_state_dict(self.model.state_dict(), saved_model['model'])
        if self.exp_cfg.wandb:
            wandb.watch(self.model)

    def _setup_optimizers(self):
        '''
        Build ``self.opt`` from ``model_cfg.optimizer``.

        The config is a single-entry mapping ``{ClassName: kwargs}`` resolved in
        ``torch.optim``. When ``exp_cfg.resume`` is set, the optimizer state is
        restored from the checkpoint at ``exp_cfg.load``.

        Raises:
            ValueError: ``exp_cfg.resume`` is set but ``exp_cfg.load`` is None.
        '''
        print("Initializing optimizers.")
        params = list(self.model.parameters())
        optimizer = self.model_cfg.optimizer

        # NOTE(review): eval() executes text built from the config -- safe only for trusted
        # configs; getattr(optim, name)(params, **kwargs) would avoid executing code.
        self.opt = eval("optim.{}(params, **{})".format([*optimizer.keys()][0],
                                                        [*optimizer.values()][0]))
        if self.exp_cfg.resume:
            if self.exp_cfg.load is None:
                raise ValueError("exp_cfg.resume is set but exp_cfg.load is None; nothing to resume from.")
            saved_opt = torch.load(self.exp_cfg.load, map_location=self.device)['optimizer']
            self.opt.load_state_dict(saved_opt)
            # NOTE(review): scheduler state and the epoch counter are not restored here.

    def _setup_schedulers(self):
        '''
        Build ``self.scheduler`` from ``model_cfg.scheduler``.

        The config is a single-entry mapping ``{ClassName: kwargs}`` resolved in
        ``torch.optim.lr_scheduler``.
        '''
        scheduler = self.model_cfg.scheduler
        # NOTE(review): same eval() caveat as in _setup_optimizers.
        self.scheduler = eval("optim.lr_scheduler.{}(self.opt, **{})".format([*scheduler.keys()][0],
                                                                             [*scheduler.values()][0]))

    def _backprop(self, loss):
        '''Run one optimizer step on ``loss``.'''
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

    @staticmethod
    def _select_reg_layers(ft, n_last):
        '''
        Return the last ``n_last`` entries of ``ft``.

        A negative ``n_last`` (e.g. -1) or one larger than ``len(ft)`` selects every
        entry, and 0 selects none.
        '''
        n_layers = len(ft)
        if n_last < 0 or n_last > n_layers:
            n_last = n_layers
        # ft[n_layers - n_last:] rather than ft[-n_last:]: the latter returns ALL
        # entries when n_last == 0.
        return ft[n_layers - n_last:]

    def _hsic_penalty(self, ft, res):
        '''
        Sum ``losses.hsic(feature, res, ...)`` over the regularised layers of ``ft``.

        Returns the int 0 when no layer is selected.
        '''
        layers = self._select_reg_layers(ft, self.model_cfg.n_last_reg_layers)
        return sum((losses.hsic(int_ft, res, self.kernel_ft, self.kernel_ft_args,
                                self.kernel_res, self.kernel_res_args,
                                self.model_cfg.biased)
                    for int_ft in layers), 0)

    def _epoch(self, epochID, mode):
        '''
        Run a single epoch over ``self.dataloaders[mode]``, aggregate losses & log to wandb.

        Any mode whose name contains 'train' updates the model; other modes run
        with gradient tracking disabled.

        Returns:
            ``utils.aggregate(all_losses)['total_loss']``.
        '''
        train = 'train' in mode
        self.model.train(train)  # Module.train(False) is equivalent to Module.eval()

        all_losses = defaultdict(list)
        tqdm_iter = tqdm(self.dataloaders[mode], dynamic_ncols=True)

        for raw_batch in tqdm_iter:
            batch = utils.dict_to_device(raw_batch, self.device)
            x, y, z = batch['x'], batch['y'], batch['z']

            with torch.set_grad_enabled(train):
                ft, y_ = self.model(x)

                # supervised target loss:
                target_loss = F.mse_loss(y_, y)

                # HSIC regularizer against the part of z not linearly explained by y.
                # NOTE(review): torch.mm requires coef shaped (dim_y, dim_z); confirm against
                # how dataset.linear_reg is fit (scikit-learn stores coef_ as (targets, features)).
                res = z - (torch.mm(y, self.yz_reg[mode]['coef']) + self.yz_reg[mode]['intercept'])
                hsic = self._hsic_penalty(ft, res)
                loss = target_loss + self.model_cfg.lamda * hsic

                if train:
                    self._backprop(loss)

            tqdm_iter.set_description("V: {} | Epoch: {} | {} | Loss: {:.4f}".format(
                self.exp_cfg.version, epochID, mode, loss.item()
            ), refresh=True)

            all_losses['target_loss'].append(target_loss.item())
            # hsic is the int 0 when no layer is regularised, so it has no .item().
            all_losses['hsic'].append(hsic.item() if torch.is_tensor(hsic) else float(hsic))
            all_losses['total_loss'].append(loss.item())

        all_losses = utils.aggregate(all_losses)
        if self.exp_cfg.wandb:
            wandb_utils.log_epoch_summary(epochID, mode, all_losses)

        return all_losses['total_loss']

    def save(self, epochID, loss):
        '''
        Save a checkpoint on improvement of ``loss`` and every 5 epochs; apply early stopping.

        The improvement check runs first. If there is no improvement and more than
        ``model_cfg.patience`` epochs have passed since the best one, the process is
        terminated with ``sys.exit``.
        '''
        save = False
        if loss < self.val_loss:
            self.val_loss = loss
            save = True
            self.last_best = epochID
        elif epochID - self.last_best > self.model_cfg.patience:
            sys.exit(f"No improvement in the last {self.model_cfg.patience} epochs. EARLY STOPPING.")
        elif epochID > 0 and epochID % 5 == 0:
            save = True
        if save:
            utils.save_model(self.model, self.opt, epochID, loss, self.exp_cfg.output_location)

    def run(self):
        '''
        Run training/inference over every configured mode for ``model_cfg.epochs`` epochs.

        The scheduler is stepped after the 'train' epoch; ``save`` is called with the
        'val' loss.
        '''
        print("Beginning run:")
        for epoch in range(self.model_cfg.epochs):
            for mode in self.model_cfg.modes:
                loss = self._epoch(epoch, mode)
                if mode == 'train':
                    self.scheduler.step()
                if mode == 'val':
                    self.save(epoch, loss)


class CausalHSICTrainerBuilder:
    '''
    Callable factory that lazily builds a single ``CausalHSIC``.

    The first call constructs the trainer; every later call returns that same
    object and ignores its arguments. Extra keyword arguments are accepted and
    discarded.
    '''

    def __init__(self):
        self._instance = None

    def __call__(self, data_cfg, model_cfg, exp_cfg, **_ignored):
        if self._instance is not None:
            return self._instance
        self._instance = CausalHSIC(data_cfg=data_cfg, model_cfg=model_cfg, exp_cfg=exp_cfg)
        return self._instance
