from xmeta.utils.evaluation import accuracy
import numpy as np
import torch
from torch import nn
from xmeta.utils.higher_grad import higher_grad
from xmeta.utils.tensor import tensor2numpy
from xmeta.utils.opa import CrossEntropyHessian, inverse_psdmat
from xmeta.utils.gd_inverse import dot_generalized_inv_gd, dot_inv_gd, repeat_dot_inv_gd
import pandas as pd
from tqdm import tqdm
import pickle
import os
from torchsummary import summary
import learn2learn as l2l
import sys


class ExplainerBase:
    """Accumulate gradients/Hessians of source test errors and score source tasks.

    Typical flow, as far as the methods below show:

    1. ``add_src_test_error`` / ``set_src_test_errors`` collect, per error
       tensor, its gradient w.r.t. ``meta_params`` (stacked row-wise) and its
       Hessian w.r.t. ``meta_params`` (summed).
    2. ``set_src_param_matrix`` or ``set_src_generalized_matrix`` compute
       ``src_param_matrix = -H^{-1} G^T`` (exact inverse or pseudo-inverse).
    3. ``set_trg_param_matrix`` stores a matrix that maps parameter gradients
       to per-source-task scores.
    4. ``calc_src_task_scores`` / ``explain`` score and rank source tasks for
       a given output ``y``.
    """

    def __init__(self,
                 model,
                 params=None,
                 meta_params=None,
                 savedir=None,
                 tag='',
                 ):
        """Store the model/parameter handles and reset all cached matrices.

        Args:
            model: Model object; only stored and returned by ``clone_model``.
            params: Parameters that ``calc_src_task_scores`` differentiates
                ``y`` against (may be replaced by ``set_trg_param_matrix``).
            meta_params: Parameters that source test errors are
                differentiated against.
            savedir: Stored as-is; not used by this class itself.
            tag: Suffix used to build ``self.name`` (``'expl_' + tag``).
        """
        self.model = model
        self.params = params
        self.meta_params = meta_params
        # Sum of the Hessians of all added source test errors.
        self.src_test_hessian = None
        # -H^{-1} G^T, set by set_src_param_matrix / set_src_generalized_matrix.
        self.src_param_matrix = None
        # Matrix multiplied with the gradient of y in calc_src_task_scores.
        self.trg_param_matrix = None
        # The two attributes below are initialised here but never assigned
        # elsewhere in this class.
        self.test_errors = None
        self.n_src_task = None
        # len(src_param_matrix), filled in lazily by set_trg_param_matrix.
        self.n_meta_param = None
        # Gradients of the source test errors, one row per added error.
        self.grad_src_test_errors = None
        self.savedir = savedir
        self.tag = tag
        self.name = 'expl_' + self.tag

    def clone_model(self):
        """Return the model; here it is the stored object itself, not a copy.

        Should be overridden if the returned model is going to be updated.
        """
        return self.model

    def set_src_test_errors(self, src_test_errors: list):
        """Add every error in ``src_test_errors`` via ``add_src_test_error``."""
        for error in src_test_errors:
            self.add_src_test_error(error)

    def add_src_test_error(self, error: torch.Tensor):
        """Accumulate the gradient and Hessian of one source test error.

        The gradient is appended as a new row of ``grad_src_test_errors`` and
        the Hessian is added to ``src_test_hessian``.
        """
        # NOTE(review): higher_grad is defined outside this file; it is
        # presumably returning a flat gradient vector, so the second call
        # yields a square Hessian -- confirm against xmeta.utils.higher_grad.
        v = higher_grad(error, self.meta_params)
        h = higher_grad(v, self.meta_params)
        v = v.detach().to('cpu').numpy()
        h = h.detach().to('cpu').numpy()
        if self.grad_src_test_errors is None:
            self.grad_src_test_errors = v
        else:
            self.grad_src_test_errors =\
                np.vstack([self.grad_src_test_errors, v])

        if self.src_test_hessian is None:
            self.src_test_hessian = h
        else:
            self.src_test_hessian += h

    def discard_intermediate(self):
        """Drop the accumulated Hessian and gradients and prefix the name with ``'di_'``."""
        self.src_test_hessian = None
        self.grad_src_test_errors = None
        self.name = 'di_' + self.name

    def set_src_param_matrix(self):
        """Set ``src_param_matrix = -inv(H) @ G.T`` using an exact inverse.

        Raises:
            AssertionError: If no source test error has been added yet.
            numpy.linalg.LinAlgError: If the accumulated Hessian is singular.
        """
        assert self.src_test_hessian is not None
        assert self.grad_src_test_errors is not None

        print('Inverting hessian')
        inverse_hessian = np.linalg.inv(self.src_test_hessian)
        print(f'done  (shape {inverse_hessian.shape})')
        self.src_param_matrix =\
            - inverse_hessian.dot(self.grad_src_test_errors.T)
        print(f'set src_param_matrix (shape {self.src_param_matrix.shape})')

    def set_src_generalized_matrix(self,
                                   min_positive_ev: float = None,
                                   min_eigen_ratio: float = 1.0e-9,
                                   n_positive_ev: int = None
                                   ):
        """Set ``src_param_matrix = -pinv(H) @ G.T`` via an eigen-decomposition.

        Eigenvalues below the threshold are treated as zero; the stored
        Hessian is replaced by the matrix rebuilt from the kept eigenvalues,
        and its pseudo-inverse inverts only the non-zero eigenvalues.

        Args:
            min_positive_ev: Eigenvalues smaller than this are zeroed. When
                ``None`` it defaults to ``mean(positive eigenvalues) *
                min_eigen_ratio``.
            min_eigen_ratio: Ratio used to derive ``min_positive_ev``.
            n_positive_ev: If given, keep at most this many of the largest
                eigenvalues and zero the rest.

        Raises:
            AssertionError: If no source test error has been added yet, or
                ``n_positive_ev`` exceeds the number of eigenvalues.
        """
        assert self.src_test_hessian is not None
        assert self.grad_src_test_errors is not None

        # Symmetrise and use eigh: it returns real eigenvalues and orthonormal
        # eigenvectors, which is what the V diag(.) V^T reconstructions below
        # require. A general eig() gives neither guarantee.
        hessian = np.asarray(self.src_test_hessian)
        eigen_values, eigen_mat = np.linalg.eigh((hessian + hessian.T) / 2)

        if min_positive_ev is None:
            positive = eigen_values[eigen_values > 0]
            if positive.size:
                min_positive_ev = positive.mean() * min_eigen_ratio
            else:
                # No positive eigenvalue: nothing can be kept, so zero them
                # all instead of propagating the NaN of an empty mean.
                min_positive_ev = np.inf
        eigen_values[eigen_values < min_positive_ev] = 0.

        if n_positive_ev is not None:
            assert n_positive_ev <= len(eigen_values)
            idx_small = np.argsort(eigen_values)[:len(eigen_values) - n_positive_ev]
            eigen_values[idx_small] = 0.

        # Keep src_test_hessian real symmetric by rebuilding it from the kept
        # eigenvalues. np.linalg.inv() is avoided because it does not
        # guarantee symmetry and may produce NaN elements.
        self.src_test_hessian = (eigen_mat * eigen_values) @ eigen_mat.T

        print('Computing pseudo-inverse of hessian')
        # Pseudo-inverse: invert the kept eigenvalues, leave zeroed ones at 0
        # (this also avoids 1/0 when the threshold is not strictly positive).
        kept = eigen_values != 0
        inv_eigen_values = np.zeros_like(eigen_values)
        inv_eigen_values[kept] = 1 / eigen_values[kept]
        inverse_hessian = (eigen_mat * inv_eigen_values) @ eigen_mat.T
        print(f'done  (shape {inverse_hessian.shape})')

        self.src_param_matrix =\
            - inverse_hessian.dot(self.grad_src_test_errors.T)
        print(f'set src_param_matrix (shape {self.src_param_matrix.shape})')

    def set_trg_param_matrix(self, trg_param_matrix=None,
                             hessian=None, gradient=None, trg_train_error=None,
                             params=None):
        """Store ``params`` and an explicitly supplied ``trg_param_matrix``.

        Args:
            trg_param_matrix: Matrix to store. The base class cannot derive
                it, so it is required here.
            hessian, gradient, trg_train_error: Accepted for interface
                compatibility; unused in this base implementation.
            params: List of parameters; replaces ``self.params``.

        Raises:
            AssertionError: If ``src_param_matrix`` is unset or ``params`` is
                not a list.
            NotImplementedError: If ``trg_param_matrix`` is not given.
        """
        assert self.src_param_matrix is not None
        assert isinstance(params, list)
        self.params = params
        if self.n_meta_param is None:
            self.n_meta_param = len(self.src_param_matrix)

        if trg_param_matrix is None:
            # Raise instead of sys.exit(): exiting would terminate the whole
            # process with a success status from inside library code.
            raise NotImplementedError('set_trg_param_matrix not implemented')
        self.trg_param_matrix = trg_param_matrix

    def calc_src_task_scores(self, y):
        """Return the gradient of ``y`` w.r.t. ``params`` times ``trg_param_matrix``."""
        assert self.trg_param_matrix is not None
        m = higher_grad(y, self.params).detach().to('cpu').numpy()
        scores = m.dot(self.trg_param_matrix)
        return scores

    def explain(self, y,
                top_k=5):
        """Return the indices and scores of the ``top_k`` highest-scoring entries.

        Returns:
            ``(idxes, top_scores)`` as Python lists, sorted by descending score.
        """
        # assumes scores is a 1-D array (one score per source task) -- TODO confirm
        scores = self.calc_src_task_scores(y)
        idxes = np.argsort(-scores)[:top_k].tolist()
        top_scores = scores[idxes].tolist()
        return idxes, top_scores


class ExplainerOPA(ExplainerBase):
    """Explainer whose Hessian is a ``CrossEntropyHessian`` built from model outputs.

    Gradients of the source test errors are stacked exactly as in the base
    class, but the Hessian is accumulated through ``CrossEntropyHessian``
    objects and inverted with ``inverse_psdmat`` or the gradient-descent
    solvers from ``xmeta.utils.gd_inverse``.
    """

    def __init__(self, num_hessian_elements: int = None,
                 ortho_vectors: bool = True,
                 **kwargs):
        """Initialise the base explainer and derive an OPA-specific name.

        Args:
            num_hessian_elements: Forwarded as ``max_num_elements`` to the
                first ``CrossEntropyHessian``; also encoded in the name.
            ortho_vectors: Forwarded to the first ``CrossEntropyHessian``;
                also encoded in the name.
            **kwargs: Passed unchanged to ``ExplainerBase.__init__``.
        """
        ExplainerBase.__init__(self, **kwargs)
        self.num_hessian_elements = num_hessian_elements
        self.ortho_vectors = ortho_vectors

        # Name layout: expl_opa_[nh<N>_][ov_]<tag>
        pieces = ['expl_opa_']
        if self.num_hessian_elements is not None:
            pieces.append(f'nh{self.num_hessian_elements}_')
        if self.ortho_vectors:
            pieces.append('ov_')
        self.name = ''.join(pieces) + self.tag

    def add_src_test_error(self, error: torch.Tensor, model_output: torch.Tensor):
        """Stack the gradient of ``error`` and accumulate a Hessian from ``model_output``."""
        grad_row = higher_grad(error, self.meta_params)
        grad_row = grad_row.detach().to('cpu').numpy()
        stacked = self.grad_src_test_errors
        self.grad_src_test_errors = (
            grad_row if stacked is None else np.vstack([stacked, grad_row])
        )

        if self.src_test_hessian is not None:
            # NOTE(review): later Hessians are built without max_num_elements /
            # ortho_vectors, unlike the first one -- confirm this is intended.
            self.src_test_hessian.add(
                CrossEntropyHessian(model_output, self.meta_params)
            )
            return

        self.src_test_hessian = CrossEntropyHessian(
            model_output,
            self.meta_params,
            max_num_elements=self.num_hessian_elements,
            ortho_vectors=self.ortho_vectors,
        )

    def normalize_src_test_hessian(self):
        """Delegate to ``normalize()`` of the accumulated Hessian object."""
        self.src_test_hessian.normalize()

    def set_src_param_matrix(self, pseudo_inv=True, taylor_series=False, **kwargs):
        """Set ``src_param_matrix`` to minus the inverse Hessian applied to ``G.T``.

        Args:
            pseudo_inv: Only consulted when ``taylor_series`` is true; selects
                ``dot_generalized_inv_gd`` instead of ``dot_inv_gd``.
            taylor_series: When false, invert with ``inverse_psdmat``; when
                true, use ``repeat_dot_inv_gd`` with ``max_repeat=1000``.
            **kwargs: Forwarded to whichever solver is used.
        """
        assert self.src_test_hessian is not None
        assert self.grad_src_test_errors is not None

        print('multiplying inverted hessian and grad_srs_test_errors')
        if taylor_series:
            solver = dot_generalized_inv_gd if pseudo_inv else dot_inv_gd
            product = repeat_dot_inv_gd(m=self.src_test_hessian,
                                        v=self.grad_src_test_errors.T,
                                        algo=solver,
                                        max_repeat=1000, **kwargs)
        else:
            hessian_inv = inverse_psdmat(m=self.src_test_hessian, **kwargs)
            product = hessian_inv.dot(self.grad_src_test_errors.T)
        self.src_param_matrix = - product

        print(f'set src_param_matrix (shape {self.src_param_matrix.shape})')
    

def save_explainer(explainer,
                   prefix: str = None,
                   postfix: str = None,
                   model_path: str = None,
                   ):
    """Pickle ``explainer`` into ``explainer.savedir``.

    The file name is ``[<prefix>_]<explainer.name>[_<postfix>].pkl``.

    Args:
        explainer: Object to pickle; must expose ``name``, ``savedir`` and
            ``model`` attributes.
        prefix: Optional string prepended to the file name.
        postfix: Optional string appended to the file name (before ``.pkl``).
        model_path: If given, ``explainer.model`` is temporarily replaced by
            this value while pickling, so the pickle stores the path instead
            of the model object. The original model is always restored.
    """
    pkl_name = explainer.name
    if prefix is not None:
        pkl_name = prefix + '_' + pkl_name
    if postfix is not None:
        pkl_name = pkl_name + '_' + postfix
    pkl_name = pkl_name + '.pkl'
    path = os.path.join(explainer.savedir, pkl_name)

    # Remember the model unconditionally so the restore below is always valid,
    # including when model_path is None.
    original_model = explainer.model
    if model_path is not None:
        explainer.model = model_path
    try:
        with open(path, 'wb') as f:
            pickle.dump(explainer, f)
        print(f'saved {path}')
    finally:
        # Restore even if pickling fails, so the caller's explainer is intact.
        explainer.model = original_model
