import torch
from tqdm import tqdm
import torch.nn.functional as F
from collections import defaultdict
import numpy as np
from scipy.linalg import solve as scp_solve

from utils import utils, losses
from visualization import wandb_utils
from trainer.causal_hsic import CausalHSIC


class CausalHSCICorrected(CausalHSIC):
    '''
    Trainer that adds the ``losses.hsic_corrected`` regularizer to an MSE target loss.

    Before training, ``_get_yz_regressors`` solves a ridge system on held-out
    (Y, Z) points and stores the resulting weights (``self.W_1``, ``self.W_2``)
    together with the held-out points; ``_epoch`` passes them to
    ``losses.hsic_corrected`` for every batch.
    '''

    def __init__(self, data_cfg, model_cfg, exp_cfg) -> None:
        '''Forward the three config objects unchanged to ``CausalHSIC.__init__``.'''
        super().__init__(data_cfg, model_cfg, exp_cfg)

    def _set_kernels(self):
        '''
        Read the kernel name and kernel kwargs for features, Z and Y from ``model_cfg``.

        Each of ``model_cfg.kernel_ft`` / ``kernel_z`` / ``kernel_y`` is a mapping;
        only its first entry is used: the key is the kernel name, the value is
        the keyword-argument mapping handed to that kernel.
        '''
        self.kernel_ft = [*self.model_cfg.kernel_ft.keys()][0]
        self.kernel_ft_args = [*self.model_cfg.kernel_ft.values()][0]
        self.kernel_z = [*self.model_cfg.kernel_z.keys()][0]
        self.kernel_z_args = [*self.model_cfg.kernel_z.values()][0]
        self.kernel_y = [*self.model_cfg.kernel_y.keys()][0]
        self.kernel_y_args = [*self.model_cfg.kernel_y.values()][0]

    def _get_yz_regressors(self):
        '''
        Precompute ridge-regression weights from the held-out (Y, Z) points.

        For each of the modes 'train' and 'train_ood' this solves
        ``(Ky + ridge_lambda * I) W = [I | Kz]`` and stores the first
        ``n_points`` columns of ``W`` as ``self.W_1`` and the remaining columns
        as ``self.W_2``, plus ``self.Y_heldout`` / ``self.Z_heldout``, all moved
        to ``self.device``. A later mode overwrites the attributes written by
        an earlier one.

        A mode that cannot be processed (e.g. no such dataloader or no held-out
        data) is skipped on a best-effort basis; the reason is printed so that a
        genuine failure is not silently lost. ``KeyboardInterrupt`` and
        ``SystemExit`` are no longer swallowed.
        '''
        self.yz_reg = defaultdict(dict)

        print('YZ regressors correction')
        # save memory
        for mode in self.model_cfg.modes:
            del self.dataloaders[mode].dataset.linear_reg

        train_modes = ['train', 'train_ood']
        # todo: add other modes
        print('ONLY TRAIN/VAL HSIC IS CORRECT, OOD IS NOT')
        for mode in train_modes:
            try:
                dataset = self.dataloaders[mode].dataset
                self.Y_heldout = torch.FloatTensor(dataset.targets_heldout)
                self.Z_heldout = torch.FloatTensor(dataset.distractors_heldout)

                print('Points saved')
                n_points = self.Y_heldout.shape[0]

                # Kernels are looked up by name on `losses` (replaces eval() on a
                # config-derived string; same result for plain identifier names).
                kernel_y_fn = getattr(losses, f'{self.kernel_y}_kernel')
                kernel_z_fn = getattr(losses, f'{self.kernel_z}_kernel')

                Ky = kernel_y_fn(self.Y_heldout, **self.kernel_y_args)
                print('Ky computed')
                Kz = kernel_z_fn(self.Z_heldout, **self.kernel_z_args)
                print('Kz computed')
                I = torch.eye(n_points, device=Ky.device)
                print('All gram matrices computed')

                # Solve once for both right-hand sides [I | Kz]; assume_a='pos'
                # tells scipy the regularized gram matrix is positive definite.
                # NOTE(review): np.float128 is platform dependent and it is not
                # clear scipy/LAPACK keeps the extended precision -- confirm.
                lhs = np.float128((Ky + self.model_cfg.ridge_lambda * I).cpu().numpy())
                rhs = np.float128(torch.cat((I, Kz), 1).cpu().numpy())
                W_all = torch.tensor(scp_solve(lhs, rhs, assume_a='pos')).float().to(Ky.device)
                print('W_all computed')

                # Release the large intermediates before copying to the training
                # device, so the log message below is accurate.
                del Ky, Kz, I, lhs, rhs
                print('Old matrices deleted')

                self.W_1 = W_all[:, :n_points].to(self.device)
                self.W_2 = W_all[:, n_points:].to(self.device)

                self.Y_heldout = self.Y_heldout.to(self.device)
                self.Z_heldout = self.Z_heldout.to(self.device)
            except Exception as err:
                # Best-effort: keep going with the next mode, but say why.
                print(f'Skipping YZ regressors for mode {mode!r}: {type(err).__name__}: {err}')
                continue

    def _epoch(self, epochID, mode):
        '''
        Run a single epoch, aggregate losses & log to wandb.

        Args:
            epochID: epoch index, used for the progress bar and wandb logging.
            mode: key into ``self.dataloaders``; any mode containing the
                substring 'train' runs in training mode with backprop, every
                other mode runs in eval mode without gradients.

        Returns:
            The 'total_loss' entry of the aggregated per-batch losses.
        '''
        train = 'train' in mode
        if train:
            self.model.train()
        else:
            self.model.eval()

        all_losses = defaultdict(list)

        tqdm_iter = tqdm(self.dataloaders[mode], dynamic_ncols=True)

        for raw_batch in tqdm_iter:
            batch = utils.dict_to_device(raw_batch, self.device)
            x, y, z = batch['x'], batch['y'], batch['z']

            if train:
                ft, y_ = self.model(x)
            else:
                with torch.no_grad():
                    ft, y_ = self.model(x)

            # supervised target loss:
            target_loss = F.mse_loss(y_, y)

            # HSIC regularizer:
            hsic = 0
            # -1 (or a value larger than the number of returned feature layers)
            # means "regularize every layer"; the clamped value is written back
            # to the config.
            # NOTE: a value of 0 slices ft[-0:], which also selects every layer.
            if self.model_cfg.n_last_reg_layers == -1 or self.model_cfg.n_last_reg_layers > len(ft):
                self.model_cfg.n_last_reg_layers = len(ft)
            for int_ft in ft[-self.model_cfg.n_last_reg_layers:]:
                hsic += losses.hsic_corrected(int_ft, z, self.Z_heldout, y, self.Y_heldout, self.W_1, self.W_2,
                                              self.kernel_ft, self.kernel_ft_args,
                                              self.kernel_z, self.kernel_z_args, self.kernel_y, self.kernel_y_args,
                                              self.model_cfg.biased)
            loss = target_loss + self.model_cfg.lamda * hsic

            if train:
                self._backprop(loss)

            tqdm_iter.set_description("V: {} | Epoch: {} | {} | Loss: {:.4f}".format(
                self.exp_cfg.version, epochID, mode, loss.item()
            ), refresh=True)

            # `hsic` is still the plain int 0 when the model returned no feature
            # layers; only tensors have .item().
            hsic_value = hsic.item() if isinstance(hsic, torch.Tensor) else float(hsic)

            all_losses['target_loss'].append(target_loss.item())
            all_losses['hsic_c'].append(hsic_value)
            all_losses['total_loss'].append(loss.item())

        all_losses = utils.aggregate(all_losses)
        if self.exp_cfg.wandb:
            wandb_utils.log_epoch_summary(epochID, mode, all_losses)

        return all_losses['total_loss']


class CausalHSICCorrectedTrainerBuilder:
    '''
    Callable factory that builds a single ``CausalHSCICorrected`` trainer and reuses it.

    The trainer is created on the first call; every later call returns that
    same object and ignores the arguments it is given.
    '''

    def __init__(self):
        # Cached trainer; None until the first call.
        self._instance = None

    def __call__(self, data_cfg, model_cfg, exp_cfg, **_ignored):
        '''
        Return the cached trainer, constructing it from the configs on first use.

        Extra keyword arguments are accepted and discarded.
        '''
        # Compare against None explicitly: a cached trainer that happens to be
        # falsy (e.g. defines __len__/__bool__) must not be rebuilt.
        if self._instance is None:
            self._instance = CausalHSCICorrected(data_cfg=data_cfg, model_cfg=model_cfg, exp_cfg=exp_cfg)
        return self._instance
