from copy import deepcopy
from scipy  import linalg
import torch

import utils
from trainer import validate

# Analyze the relation between Gram matrix and Fisher matrix
class Analyzer:
    """Collect per-sample output gradients and residuals of a network.

    The constructor allocates an ``n x p`` buffer ``grads`` (one row per
    training sample, one column per parameter as counted by
    ``utils.num_para``) and a length-``n`` buffer ``errs``.
    ``compute_grads`` fills both; every statistic below is derived from them.

    ``hessian_dot_fisher``, ``hessian_fro`` and ``fisher_fro`` are hooks that
    a subclass has to provide; ``alpha``, ``beta`` and ``mu`` combine them.

    Args:
        net: network under analysis; it is called on one sample at a time and
            must produce a single-element output for that sample.
        train_loader: object exposing the full training inputs and targets as
            the attributes ``X`` and ``y``.
        test_loader: optional loader that is only handed to ``validate`` in
            ``test_err``.
    """

    def __init__(self, net, train_loader, test_loader=None):
        self.net = net
        self.X = train_loader.X
        self.y = train_loader.y
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.num_samples = self.y.shape[0]
        self.num_parameters = utils.num_para(net)
        self.grads = torch.zeros(self.num_samples, self.num_parameters)  # n x p matrix
        self.errs = torch.zeros(self.num_samples)  # n residuals

        # Becomes True once compute_grads() has filled the buffers above.
        self.comp = False

    def compute_grads(self):
        """Fill ``grads`` and ``errs`` using the network's current weights.

        Each sample is pushed through the network on its own and the output
        is back-propagated. Row ``i`` of ``grads`` receives whatever
        ``utils.vecterize_grad`` returns after that backward pass (presumably
        the flattened parameter gradient - confirm against ``utils``), and
        ``errs[i]`` receives output minus target.
        """
        for i in range(self.num_samples):
            self.net.zero_grad()
            xi, yi = self.X[i:i+1], self.y[i:i+1]
            fi = self.net(xi)
            ei = fi - yi
            # backward() without arguments requires a single-element output.
            fi.backward()

            self.grads[i] = utils.vecterize_grad(self.net)
            self.errs[i] = ei.item()

        self.comp = True

    def prepare_grads(self):
        """Run ``compute_grads`` unless the buffers have already been filled."""
        if not self.comp:
            self.compute_grads()

    def compute_feature_norms(self):
        """Return, per sample, the mean squared entry of its Gram-matrix column.

        Returns:
            A fresh length-``n`` tensor; callers may modify it in place.
        """
        self.prepare_grads()
        G = self.grads @ self.grads.t()
        chi = (G * G).mean(dim=0)
        return chi.clone()

    def compute_diversity(self):
        """Return ``s_max**2 / sum(s_i**2)`` for the eigenvalues of ``G G^T / n``.

        Returns:
            A zero-dimensional tensor.
        """
        self.prepare_grads()
        G = (self.grads @ self.grads.t()) / self.num_samples
        # eigvalsh returns ascending eigenvalues; flip so s[0] is the largest.
        s = torch.linalg.eigvalsh(G).flip(dims=(0,))

        return s[0] * s[0] / s.pow(2).sum()

    def compute_hessian(self):
        """Store ``grads^T grads / n`` (a ``p x p`` matrix) in ``self.hessian``."""
        self.prepare_grads()
        self.hessian = (self.grads.t() @ self.grads) / self.num_samples

    def compute_fisher(self):
        """Store ``Z^T Z / n`` in ``self.fisher``, where row ``i`` of ``Z`` is
        ``errs[i] * grads[i]``."""
        self.prepare_grads()
        Z = self.grads * self.errs.view(-1, 1)
        self.fisher = (Z.t() @ Z) / self.num_samples

    def loss(self):
        """Return half the mean squared residual as a zero-dimensional tensor."""
        self.prepare_grads()
        return (self.errs * self.errs).mean() / 2

    def hessian_dot_fisher(self):
        """Hook: inner product of Hessian and Fisher matrices (subclass must override)."""
        raise NotImplementedError(
            'hessian_dot_fisher must be implemented by a subclass')

    def hessian_fro(self):
        """Hook: Frobenius norm of the Hessian (subclass must override)."""
        raise NotImplementedError(
            'hessian_fro must be implemented by a subclass')

    def fisher_fro(self):
        """Hook: Frobenius norm of the Fisher matrix (subclass must override)."""
        raise NotImplementedError(
            'fisher_fro must be implemented by a subclass')

    def hessian_trace(self):
        """Return ``trace(grads^T grads / n)`` as a Python float."""
        self.prepare_grads()
        # The trace equals the sum of squared entries of grads divided by n,
        # so the p x p matrix never has to be formed.
        tr = (self.grads * self.grads).sum() / self.num_samples
        return tr.item()

    def alpha(self):
        """Return ``hessian_dot_fisher / (hessian_fro * fisher_fro)``."""
        H_dot_F = self.hessian_dot_fisher()
        H_fro = self.hessian_fro()
        F_fro = self.fisher_fro()

        return H_dot_F / (H_fro * F_fro)

    def beta(self):
        """Return ``fisher_fro / (2 * loss * hessian_fro)``."""
        F_fro = self.fisher_fro()
        H_fro = self.hessian_fro()
        L = self.loss().item()

        return F_fro / (2 * L * H_fro)

    def mu(self):
        """Return ``hessian_dot_fisher / (2 * loss * hessian_fro**2)``."""
        H_dot_F = self.hessian_dot_fisher()
        H_fro = self.hessian_fro()
        L = self.loss().item()

        return H_dot_F / (2 * L * H_fro * H_fro)

    def gamma(self):
        """Return the smallest feature norm after dividing by their mean."""
        chi = self.compute_feature_norms()
        chi /= chi.mean()
        return chi.min().item()

    def train_err(self):
        """Return the result of ``validate`` on the training loader."""
        return validate(self.net, self.train_loader)

    def test_err(self):
        """Return the result of ``validate`` on the test loader."""
        return validate(self.net, self.test_loader)

    def track_checkpoints(self, history, track_inf=None, verbose=True):
        """Evaluate selected statistics at every stored checkpoint.

        Args:
            history: mapping with key ``'net'`` (iterable of state dicts) and
                key ``'iter'`` (copied into the result unchanged).
            track_inf: names of methods of this object to call after each
                checkpoint is loaded; ``None`` tracks nothing.
            verbose: print a progress line after each checkpoint.

        Returns:
            Dict with ``'iter'`` plus, for each tracked name, the list of
            values returned by that method, one per checkpoint.
        """
        # Materialise once: tolerates the None default and one-shot iterables.
        keys = [] if track_inf is None else list(track_inf)

        stat = {'iter': history['iter']}
        for key in keys:
            stat[key] = []

        for k, states in enumerate(history['net']):
            self.net.load_state_dict(states)
            # Weights changed, so the cached gradients must be recomputed.
            self.compute_grads()

            for key in keys:
                method = getattr(self, key)
                stat[key].append(method())
            if verbose:
                print('{:}-th checkpoints are proceed!'.format(k+1))

        return stat


class AnalyzeNet(Analyzer):
    """Analyzer that evaluates the Hessian/Fisher hooks through ``n x n`` products.

    With ``G`` the ``n x p`` gradient matrix and ``Z`` its rows scaled by the
    residuals, the quantities for ``H = G^T G / n`` and ``F = Z^T Z / n`` are
    obtained from ``G G^T``, ``Z Z^T`` and ``Z G^T``, so no ``p x p`` matrix
    is ever built.
    """

    def __init__(self, net, train_loader, test_loader=None):
        super().__init__(net, train_loader, test_loader)

    def hessian_fro(self):
        """Return ``||G G^T||_F / n`` as a Python float."""
        self.prepare_grads()
        G = self.grads @ self.grads.t()
        fro = (G * G).sum().sqrt() / self.num_samples
        return fro.item()

    def fisher_fro(self):
        """Return ``||Z Z^T||_F / n`` as a Python float."""
        self.prepare_grads()
        z = self.grads * self.errs.view(-1, 1)
        F = z @ z.t()
        fro = (F * F).sum().sqrt() / self.num_samples
        return fro.item()

    def hessian_dot_fisher(self):
        """Return ``||Z G^T / n||_F ** 2`` as a Python float.

        This equals ``trace(H F)`` for the matrices described on the class.
        """
        # alpha() and mu() call this hook first; without this guard a fresh
        # instance would evaluate it on the all-zero gradient buffer.
        self.prepare_grads()
        G = self.grads
        F = self.grads * self.errs.view(-1, 1)
        Z = (F @ G.t()) / self.num_samples
        return (Z * Z).sum().item()

class MinimaAnalyzer:
    """Hold per-sample output gradients of a network and decompose them.

    Args:
        net: network whose output is differentiated; called on one sample at
            a time.
        X: inputs; the first dimension indexes samples.
        y: targets; stored but not read by the methods of this class.
    """

    def __init__(self, net, X, y):
        self.net = net
        self.X = X
        self.y = y
        self.num_samples = X.shape[0]
        self.num_parameters = utils.num_para(net)
        # One row per sample, one column per parameter.
        self.grads = torch.zeros(self.num_samples, self.num_parameters)

    def compute_grads(self):
        """Back-propagate the output of every sample and store the result row-wise.

        Row ``k`` of ``grads`` receives whatever ``utils.vecterize_grad``
        returns after the backward pass for sample ``k``.
        """
        for row in range(self.num_samples):
            sample = self.X[row:row+1]
            self.net.zero_grad()
            output = self.net(sample)
            output.backward()
            self.grads[row] = utils.vecterize_grad(self.net)

    def hessian_PCA(self):
        """Return right singular vectors of ``grads`` and scaled squared singular values.

        Returns:
            Tuple ``(V, lam)``: the columns of ``V`` are the right singular
            vectors of ``grads`` and ``lam`` is the squared singular values
            divided by the number of samples.
        """
        _, singular, right_h = torch.linalg.svd(self.grads, full_matrices=False)
        directions = right_h.transpose(0, 1)
        spectrum = (singular * singular) / self.num_samples
        return directions, spectrum



def track_info(model, ckpts, info_func, X, y):
    """Evaluate ``info_func`` on ``model`` after loading each checkpoint in turn.

    Args:
        model: object with a ``load_state_dict`` method.
        ckpts: iterable of state dicts, loaded in order.
        info_func: callable invoked as ``info_func(model, X, y)``.
        X: passed through to ``info_func`` unchanged.
        y: passed through to ``info_func`` unchanged.

    Returns:
        List with one ``info_func`` result per checkpoint, in order.
    """
    collected = []
    for position, state in enumerate(ckpts, start=1):
        model.load_state_dict(state)
        collected.append(info_func(model, X, y))
        # Progress line uses 1-based numbering.
        print('{:}-th ckpt'.format(position))
    return collected


def Gram_spectrum(net, X, y):
    """Return the eigenvalues of the per-sample gradient Gram matrix ``G G^T / n``.

    Args:
        net: network called on one sample at a time; its output is
            back-propagated for every sample.
        X: inputs; the first dimension indexes samples.
        y: unused; accepted so the signature matches ``info_func`` callbacks.

    Returns:
        List of eigenvalues in the order produced by ``scipy.linalg.eigvalsh``.
    """
    count = X.shape[0]
    jac = torch.zeros(count, utils.num_para(net))

    for row in range(count):
        net.zero_grad()
        output = net(X[row:row+1])
        output.backward()
        jac[row] = utils.vecterize_grad(net)

    gram = (jac @ jac.t()).numpy() / count
    # The index range [0, count-1] selects the complete spectrum.
    spectrum = linalg.eigvalsh(gram, subset_by_index=[0, count - 1])
    return list(spectrum)





