import abc
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from torch.utils import data
from typing import List, Dict
from attribution_linear import hessian_inverse

def _set_attr(obj, names, val):
    """Assign ``val`` to the attribute reached by following ``names`` from ``obj``.

    ``names`` is a sequence of attribute names (e.g. the result of
    ``"encoder.layer.weight".split(".")``). Every name but the last is resolved
    with ``getattr``; the last one is assigned with ``setattr``.
    """
    owner = obj
    # Descend through the intermediate attributes to reach the object that owns the leaf.
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    setattr(owner, names[-1], val)

def _del_attr(obj, names):
    """Delete the attribute reached by following ``names`` from ``obj``.

    ``names`` is a sequence of attribute names. Every name but the last is
    resolved with ``getattr``; the last one is removed with ``delattr``.
    """
    owner = obj
    # Descend through the intermediate attributes to reach the object that owns the leaf.
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    delattr(owner, names[-1])
    
class AutogradInfluenceModule(abc.ABC):
    r"""Influence scores for a multi-task model from an explicitly inverted, damped Hessian.

    Adapted from the original implementation https://github.com/alstonlo/torch-influence to allow multi-task learning.

    The constructor builds the dense Hessian of the cross-entropy training loss with
    respect to every trainable parameter, adds ``damp * I``, and stores the inverse
    returned by ``hessian_inverse`` in ``self.hess_inv``.

    The model must expose ``model.encoder`` and ``model.decoders[task_name]``, and
    ``model(x)`` must be indexable by task name.

    NOTE(review): the block slicing treats the flattened parameter vector as
    ``len(tns)`` consecutive decoder segments of ``decoder_shape`` entries each (in
    ``tns`` order), with the encoder occupying the final ``encoder_shape`` entries.
    Nothing here checks that ``model.named_parameters()`` yields that order --
    confirm against the model definition.
    """

    def __init__(
            self,
            model: nn.Module,
            train_loader: Dict[str, data.DataLoader],
            test_loader: Dict[str, data.DataLoader],
            device: torch.device,
            damp: float,
            tns: List[str], # task names
            check_eigvals: bool = False
    ):
        r"""Compute and store the inverse of the damped training-loss Hessian.

        Args:
            model: the multi-task model; it is put in eval mode and moved to ``device``.
            train_loader: mapping from task name to that task's training loader.
            test_loader: mapping from task name to that task's test loader.
            device: device the model and the assembled inverse blocks live on.
            damp: coefficient of the identity matrix added to the Hessian.
            tns: task names, in the order used to slice the decoder blocks.
            check_eigvals: if True, print the extreme eigenvalues of the damped
                Hessian and raise if any of them is negative.

        Raises:
            ValueError: if no training batch is available for the given tasks, or
                if ``check_eigvals`` is set and the damped Hessian has a negative
                eigenvalue.
        """
        model.eval()
        self.model = model.to(device)
        self.device = device

        self.is_model_functional = False
        self.params_names = tuple(name for name, _ in self._model_params())
        self.params_shape = tuple(p.shape for _, p in self._model_params())

        self.train_loaders = train_loader
        self.test_loaders = test_loader
        self.tns = tns
        self.damp = damp
        self.encoder_shape = sum(p.numel() for p in self.model.encoder.parameters())
        self.decoder_shape = sum(p.numel() for p in self.model.decoders[self.tns[0]].parameters()) # all decoders have the same shape

        params = self._model_make_functional()
        flat_params = self._flatten_params_like(params)
        d = flat_params.shape[0]

        hess = None
        try:
            for tn in self.tns:
                for batch in self.train_loaders[tn]:
                    batch_size = len(batch[1])

                    # NOTE(review): the batch is fed to the model as-is; it is not
                    # moved to ``device`` here -- confirm the loaders already yield
                    # tensors on the right device.
                    def f(theta_, tn=tn, batch=batch):
                        self._model_reinsert_params(self._reshape_like_params(theta_))
                        return F.cross_entropy(self.model(batch[0])[tn], batch[1].squeeze())

                    hess_batch = torch.autograd.functional.hessian(f, flat_params).detach()
                    # Weight by batch size so the later division yields a per-example average.
                    weighted = hess_batch * batch_size
                    hess = weighted if hess is None else hess + weighted
        finally:
            # Always turn the model back into a regular module with registered
            # parameters, even when the Hessian computation fails part-way.
            with torch.no_grad():
                self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)

        if hess is None:
            raise ValueError(
                "no training batches found for tasks %r; cannot build the Hessian" % (list(self.tns),)
            )

        with torch.no_grad():
            n_train = sum(len(loader.dataset) for loader in self.train_loaders.values())
            hess = hess / n_train
            hess = hess + damp * torch.eye(d, device=hess.device)

            if check_eigvals:
                eigvals = np.linalg.eigvalsh(hess.cpu().numpy())
                min_eigval = np.min(eigvals).item()
                max_eigval = np.max(eigvals).item()
                print("hessian min eigval %f" % min_eigval)
                print("hessian max eigval %f" % max_eigval)
                if not bool(np.all(eigvals >= 0)):
                    raise ValueError(
                        "damped Hessian has a negative eigenvalue (min %f); increase damp" % min_eigval
                    )

            dec = self.decoder_shape
            enc = self.encoder_shape
            # One slice per task decoder, plus the trailing encoder slice.
            dec_slices = [slice(dec * k, dec * (k + 1)) for k in range(len(self.tns))]
            enc_slice = slice(d - enc, d)

            # Decoder/decoder diagonal blocks, decoder/encoder blocks, encoder/encoder block.
            H_ss = torch.stack([hess[s, s] for s in dec_slices])
            H_s_m_plus_one = torch.stack([hess[s, enc_slice] for s in dec_slices])
            H_m_plus_one_m_plus_one = hess[enc_slice, enc_slice]

            H_st_inv, H_s_m_plus_one_inv, H_m_plus_one_m_plus_one_inv = hessian_inverse(H_ss, H_s_m_plus_one, H_m_plus_one_m_plus_one, len(self.tns), self.decoder_shape, device)

            # fill in the hessian inverse
            self.hess_inv = torch.zeros_like(hess)
            for i, rows in enumerate(dec_slices):
                for j, cols in enumerate(dec_slices):
                    self.hess_inv[rows, cols] = H_st_inv[i][j]
                self.hess_inv[rows, enc_slice] = H_s_m_plus_one_inv[i]
                self.hess_inv[enc_slice, rows] = H_s_m_plus_one_inv[i].t()
            self.hess_inv[enc_slice, enc_slice] = H_m_plus_one_m_plus_one_inv

    def influence(self, i: int, tn_i: str, j: int, tn_j: str) -> torch.Tensor:
        r"""An example function to show how to compute the influence of the :math:`i`-th training example on the :math:`j`-th test example.

        The score is :math:`-g_i^\top H^{-1} g_j`, where :math:`g_i` and :math:`g_j`
        are the flattened loss gradients of the two examples (zero for parameters
        an example's loss does not touch) and :math:`H^{-1}` is ``self.hess_inv``.

        Args:
            i: the index of the training example.
            j: the index of the test example.
            tn_i: the task name for the training example (key into the loaders
                and into the model output)
            tn_j: the task name for the test example

        Returns:
            The influence of the :math:`i`-th training example on the :math:`j`-th test example.
        """
        self.model.eval()
        grad_i = self._example_grad(self.train_loaders[tn_i], i, tn_i)
        grad_j = self._example_grad(self.test_loaders[tn_j], j, tn_j)
        return -(grad_i @ self.hess_inv @ grad_j)

    def _example_grad(self, loader, idx, tn):
        """Return the flattened gradient of the task-``tn`` loss on ``loader.dataset[idx]``."""
        params = self._model_params(with_names=False)
        # Fetch the item once; datasets with random transforms may differ between reads.
        sample = loader.dataset[idx]
        loss = F.cross_entropy(self.model(sample[0])[tn], sample[1].squeeze())
        # allow_unused: decoders of other tasks do not take part in this loss and
        # come back as None, which _flatten_params_like replaces with zeros.
        grads = torch.autograd.grad(loss, params, allow_unused=True)
        return self._flatten_params_like(grads)

    def _model_params(self, with_names=True):
        """Return the trainable parameters, as ``(name, param)`` pairs or bare tensors."""
        assert not self.is_model_functional
        return tuple((name, p) if with_names else p for name, p in self.model.named_parameters() if p.requires_grad)

    def _model_make_functional(self):
        """Strip the trainable parameters off the model and return detached copies.

        After this call the model holds no registered trainable parameters until
        ``_model_reinsert_params`` puts tensors back under the same names.
        """
        assert not self.is_model_functional
        params = tuple(p.detach().requires_grad_() for p in self._model_params(False))

        for name in self.params_names:
            _del_attr(self.model, name.split("."))
        self.is_model_functional = True

        return params

    def _model_reinsert_params(self, params, register=False):
        """Attach ``params`` to the model under the recorded parameter names.

        With ``register=False`` the tensors are set as plain attributes, so that
        autograd can differentiate through them. With ``register=True`` they are
        wrapped in ``torch.nn.Parameter`` and the model leaves functional mode.
        """
        for name, p in zip(self.params_names, params):
            _set_attr(self.model, name.split("."), torch.nn.Parameter(p) if register else p)
        self.is_model_functional = not register

    def _flatten_params_like(self, params_like):
        """Concatenate per-parameter tensors into one 1-D vector.

        ``None`` entries (e.g. gradients of unused parameters) become zero blocks
        of the matching parameter shape.
        """
        vec = []
        for p, shape in zip(params_like, self.params_shape):
            if p is None:
                vec.append(torch.zeros(shape, device=self.device).view(-1))
            else:
                # reshape rather than view: gradients are not guaranteed contiguous.
                vec.append(p.reshape(-1))
        return torch.cat(vec)

    def _reshape_like_params(self, vec):
        """Split a flat vector back into tensors shaped like the model parameters."""
        pointer = 0
        split_tensors = []
        for dim in self.params_shape:
            num_param = dim.numel()
            split_tensors.append(vec[pointer: pointer + num_param].view(dim))
            pointer += num_param
        return tuple(split_tensors)

