import torch
import numpy as np


class Hessian():
    """Hessian statistics of ``criterion(model(x), y)`` w.r.t. the model parameters.

    The Hessian is never materialised; every quantity is obtained from
    Hessian-vector products computed with ``torch.autograd``.  Exactly one of
    ``data`` (a single ``(inputs, targets)`` batch) or ``dataloader`` (an
    iterable of such batches) must be supplied.
    """

    def __init__(self, model, criterion, data=None, dataloader=None, cuda=True):
        """Store the model/criterion and prepare the gradients used later.

        Args:
            model: module exposing ``eval()``, ``parameters()`` and
                ``zero_grad()``; it is switched to eval mode here.
            criterion: callable ``criterion(outputs, targets)`` returning a
                scalar loss.
            data: a single ``(inputs, targets)`` pair, or ``None``.
            dataloader: an iterable yielding ``(inputs, targets)`` pairs, or
                ``None``.
            cuda: when true the device string is ``'cuda'``, otherwise
                ``'cpu'``.  The model itself is not moved by this class.

        Raises:
            AssertionError: unless exactly one of ``data`` / ``dataloader`` is
                given.
        """
        # Identity tests (rather than ``!=``) so that objects overriding
        # comparison operators cannot confuse the "exactly one source" check.
        assert (data is not None and dataloader is None) or (
            data is None and dataloader is not None)

        self.model = model.eval()
        self.criterion = criterion

        if data is not None:
            self.data = data
            self.full_dataset = False
        else:
            self.data = dataloader
            self.full_dataset = True

        self.device = 'cuda' if cuda else 'cpu'

        if not self.full_dataset:
            # Single-batch mode: run one forward/backward pass now and keep
            # the graph so the gradients can be differentiated again.
            self.inputs, self.targets = self.data
            if self.device == 'cuda':
                self.inputs = self.inputs.cuda()
                self.targets = self.targets.cuda()

            outputs = self.model(self.inputs)
            loss = self.criterion(outputs, self.targets)
            loss.backward(create_graph=True)

        params, gradsH = get_params_grad(self.model)
        self.params = params
        self.gradsH = gradsH  # gradient used for Hessian computation

    @staticmethod
    def _symmetric_eig(T):
        """Return ``(eigenvalues, eigenvectors)`` of the symmetric matrix *T*.

        Eigenvectors are stored as columns.  ``torch.linalg.eigh`` is used
        when available because ``torch.eig`` no longer exists in current
        PyTorch releases; the legacy call is kept only as a fallback for old
        installations.
        """
        linalg = getattr(torch, 'linalg', None)
        if linalg is not None and hasattr(linalg, 'eigh'):
            return linalg.eigh(T)
        values, vectors = torch.eig(T, eigenvectors=True)
        # Legacy API returns (n, 2) [real, imag] pairs; T is symmetric so
        # only the real column is meaningful.
        return values[:, 0], vectors

    def _rademacher_like_params(self):
        """Return one random +1/-1 tensor per parameter on ``self.device``."""
        v = [
            torch.randint_like(p, high=2, device=self.device)
            for p in self.params
        ]
        for v_i in v:
            v_i[v_i == 0] = -1
        return v

    def _hv_product(self, v):
        """Return the Hessian-vector product for *v* in the active data mode."""
        if self.full_dataset:
            _, Hv = self.dataloader_hv_product(v)
            return Hv
        return hessian_vector_product(self.gradsH, self.params, v)

    def dataloader_hv_product(self, v):
        """Average the Hessian-vector product over every batch of ``self.data``.

        Args:
            v: list of tensors, one per entry of ``self.params``.

        Returns:
            ``(eigenvalue, THv)`` where ``THv`` is the batch-size-weighted
            mean of the per-batch Hessian-vector products and ``eigenvalue``
            is the scalar ``<THv, v>`` as a Python float.
        """
        device = self.device
        num_data = 0

        THv = [torch.zeros(p.size()).to(device) for p in self.params]
        for inputs, targets in self.data:
            self.model.zero_grad()
            tmp_num_data = inputs.size(0)
            outputs = self.model(inputs.to(device))
            loss = self.criterion(outputs, targets.to(device))
            loss.backward(create_graph=True)
            params, gradsH = get_params_grad(self.model)
            self.model.zero_grad()
            Hv = torch.autograd.grad(gradsH,
                                     params,
                                     grad_outputs=v,
                                     only_inputs=True,
                                     retain_graph=False)
            # Weight each batch by its size so the final division yields a
            # per-sample average even when the last batch is smaller.
            THv = [
                THv1 + Hv1 * float(tmp_num_data)
                for THv1, Hv1 in zip(THv, Hv)
            ]
            num_data += float(tmp_num_data)

        THv = [THv1 / float(num_data) for THv1 in THv]
        eigenvalue = group_product(THv, v).cpu().item()
        return eigenvalue, THv

    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1):
        """Estimate the ``top_n`` dominant eigenpairs by power iteration.

        Each new vector is kept orthogonal to the eigenvectors already found.

        Args:
            maxIter: maximum number of iterations per eigenpair.
            tol: relative change of the eigenvalue estimate below which the
                iteration for the current eigenpair stops.
            top_n: number of eigenpairs to compute (must be >= 1).

        Returns:
            ``(eigenvalues, eigenvectors)``: a list of floats and a list of
            eigenvectors, each eigenvector being a list of tensors shaped
            like ``self.params``.
        """
        assert top_n >= 1

        device = self.device

        eigenvalues = []
        eigenvectors = []

        computed_dim = 0

        while computed_dim < top_n:
            eigenvalue = None
            v = [torch.randn(p.size()).to(device) for p in self.params]
            v = normalization(v)

            for _ in range(maxIter):
                v = orthnormal(v, eigenvectors)
                self.model.zero_grad()

                if self.full_dataset:
                    tmp_eigenvalue, Hv = self.dataloader_hv_product(v)
                else:
                    Hv = hessian_vector_product(self.gradsH, self.params, v)
                    tmp_eigenvalue = group_product(Hv, v).cpu().item()

                v = normalization(Hv)

                if eigenvalue is None:
                    eigenvalue = tmp_eigenvalue
                elif abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) +
                                                         1e-6) < tol:
                    break
                else:
                    eigenvalue = tmp_eigenvalue
            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)
            computed_dim += 1

        return eigenvalues, eigenvectors

    def trace(self, maxIter=100, tol=1e-3):
        """Collect stochastic trace samples ``v^T H v`` with +1/-1 probes.

        Args:
            maxIter: maximum number of random probes.
            tol: relative change of the running mean below which sampling
                stops early.

        Returns:
            The list of per-probe ``v^T H v`` values; its mean is the trace
            estimate.
        """
        trace_vhv = []
        trace = 0.

        for _ in range(maxIter):
            self.model.zero_grad()
            v = self._rademacher_like_params()

            Hv = self._hv_product(v)
            trace_vhv.append(group_product(Hv, v).cpu().item())

            running_mean = np.mean(trace_vhv)
            if abs(running_mean - trace) / (abs(trace) + 1e-6) < tol:
                return trace_vhv
            trace = running_mean

        return trace_vhv

    def density(self, iter=100, n_v=1):
        """Compute eigenvalue/weight pairs from a Lanczos-style recurrence.

        For each of ``n_v`` random starting vectors a tridiagonal matrix of
        size ``iter`` is built from the recurrence coefficients and then
        eigendecomposed.

        Args:
            iter: number of recurrence steps, i.e. the size of the
                tridiagonal matrix.  (The name shadows the builtin but is
                kept for API compatibility.)
            n_v: number of independent random starting vectors.

        Returns:
            ``(eigen_list_full, weight_list_full)``: two lists with one entry
            per starting vector, each entry a list of ``iter`` eigenvalues /
            matching weights (squared first components of the eigenvectors).
        """
        device = self.device
        eigen_list_full = []
        weight_list_full = []

        for _ in range(n_v):
            v = normalization(self._rademacher_like_params())

            v_list = [v]
            alpha_list = []
            beta_list = []
            for i in range(iter):
                self.model.zero_grad()
                if i == 0:
                    w_prime = self._hv_product(v)
                    alpha = group_product(w_prime, v)
                    alpha_list.append(alpha.cpu().item())
                    w = group_add(w_prime, v, alpha=-alpha)
                else:
                    beta = torch.sqrt(group_product(w, w))
                    beta_list.append(beta.cpu().item())
                    if beta_list[-1] == 0.:
                        # Residual vanished: restart from a fresh random
                        # direction before orthonormalising.
                        w = [
                            torch.randn(p.size()).to(device)
                            for p in self.params
                        ]
                    v = orthnormal(w, v_list)
                    v_list.append(v)
                    w_prime = self._hv_product(v)
                    alpha = group_product(w_prime, v)
                    alpha_list.append(alpha.cpu().item())
                    w_tmp = group_add(w_prime, v, alpha=-alpha)
                    w = group_add(w_tmp, v_list[-2], alpha=-beta)

            # Tridiagonal matrix: alphas on the diagonal, betas next to it.
            T = torch.zeros(iter, iter).to(device)
            for i in range(len(alpha_list)):
                T[i, i] = alpha_list[i]
                if i < len(alpha_list) - 1:
                    T[i + 1, i] = beta_list[i]
                    T[i, i + 1] = beta_list[i]
            eigen_list, eigen_vectors = self._symmetric_eig(T)

            weight_list = eigen_vectors[0, :]**2
            eigen_list_full.append(list(eigen_list.cpu().numpy()))
            weight_list_full.append(list(weight_list.cpu().numpy()))

        return eigen_list_full, weight_list_full
    

def group_product(xs, ys):
    """Return the inner product of two equally structured tensor lists.

    Pairs are taken with ``zip``; each pair contributes the sum of its
    element-wise product.  With no pairs the result is the integer ``0``.
    """
    total = 0
    for left, right in zip(xs, ys):
        total = total + torch.sum(left * right)
    return total


def group_add(params, update, alpha=1):
    """Add ``alpha * update[i]`` to every ``params[i]`` in place.

    The tensors in ``params`` are modified through ``.data`` and the same
    ``params`` container is returned.
    """
    for index, tensor in enumerate(params):
        scaled = update[index] * alpha
        tensor.data.add_(scaled)
    return params


def normalization(v):
    """Return a new list with every tensor of ``v`` divided by the joint norm.

    The norm is the square root of ``group_product(v, v)``; ``1e-6`` is added
    to the denominator so a zero vector does not divide by zero.
    """
    squared_norm = group_product(v, v)
    norm = (squared_norm**0.5).cpu().item()
    denominator = norm + 1e-6
    return [component / denominator for component in v]


def get_params_grad(model):
    """Collect the trainable parameters of ``model`` and their gradients.

    Returns:
        ``(params, grads)``: parameters with ``requires_grad`` set, and for
        each one either a copy of its ``.grad`` (``grad + 0.``) or the float
        ``0.`` when no gradient is present.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    gradients = [
        p.grad + 0. if p.grad is not None else 0.
        for p in trainable
    ]
    return trainable, gradients


def hessian_vector_product(gradsH, params, v):
    """Differentiate ``gradsH`` w.r.t. ``params`` weighted by ``v``.

    ``gradsH`` must carry an autograd graph (i.e. come from a backward pass
    with ``create_graph=True``).  The graph is retained so the function can
    be called repeatedly on the same gradients.

    Returns:
        Tuple of tensors, one per entry of ``params``.
    """
    return torch.autograd.grad(
        gradsH,
        params,
        grad_outputs=v,
        retain_graph=True,
        only_inputs=True,
    )


def orthnormal(w, v_list):
    """Remove from ``w`` its components along each vector in ``v_list``, then normalise.

    The projection removal goes through ``group_add`` and therefore updates
    the tensors of ``w`` in place; the normalised result is a new list.
    """
    for basis in v_list:
        coefficient = group_product(w, basis)
        w = group_add(w, basis, alpha=-coefficient)
    return normalization(w)

def _modal_converged(previous, current, tol):
    """Return True when the relative change from *previous* to *current* is below *tol*."""
    return abs(previous - current) / (abs(previous) + 1e-6) < tol


def eigenvalues_modal(self, maxIter=100, tol=1e-3, top_n=1):
    """Estimate top eigenvalues separately for two parameter groups.

    Power iteration is run in lock-step for ``self.params_ehr`` and
    ``self.params_cxr``; a group stops updating once its eigenvalue estimate
    has converged, and the loop ends when both have.

    ``self`` must provide ``device``, ``model``, ``params_ehr``,
    ``params_cxr`` and ``hv_product_modal(v_ehr, v_cxr)`` returning the pair
    ``(Hv_ehr, Hv_cxr)``.
    # NOTE(review): "ehr"/"cxr" presumably name two input modalities -- confirm.

    Args:
        maxIter: maximum number of iterations per eigenpair.
        tol: relative-change threshold used as the convergence test.
        top_n: number of eigenpairs to compute per group (must be >= 1).

    Returns:
        ``(eigenvalues_ehr, eigenvalues_cxr, mean_ehr, mean_cxr)`` where the
        first two are lists of floats and the last two their ``np.mean``.
    """
    assert top_n >= 1

    device = self.device

    eigenvalues_ehr = []
    eigenvectors_ehr = []

    eigenvalues_cxr = []
    eigenvectors_cxr = []

    computed_dim = 0

    while computed_dim < top_n:
        # The "still iterating" flags must be reset for every eigenpair.
        # Previously they were set once before this loop, so after the first
        # eigenpair converged all later ones skipped the iteration entirely
        # and merely repeated the first eigenvalue.
        sample_ehr = True
        sample_cxr = True

        eigenvalue_ehr = None
        eigenvalue_cxr = None

        # initial v and normalize
        v_ehr = [torch.randn(p.size()).to(device) for p in self.params_ehr]
        v_cxr = [torch.randn(p.size()).to(device) for p in self.params_cxr]
        v_ehr = normalization(v_ehr)
        v_cxr = normalization(v_cxr)

        for _ in range(maxIter):
            # Deflate against the eigenvectors already found for each group.
            if sample_ehr:
                v_ehr = orthnormal(v_ehr, eigenvectors_ehr)
            if sample_cxr:
                v_cxr = orthnormal(v_cxr, eigenvectors_cxr)
            self.model.zero_grad()

            # calculate in each batch
            Hv_ehr, Hv_cxr = self.hv_product_modal(v_ehr, v_cxr)

            if sample_ehr:
                tmp_eigenvalue_ehr = group_product(Hv_ehr, v_ehr).cpu().item()
                v_ehr = normalization(Hv_ehr)
            if sample_cxr:
                tmp_eigenvalue_cxr = group_product(Hv_cxr, v_cxr).cpu().item()
                v_cxr = normalization(Hv_cxr)

            if eigenvalue_ehr is None:
                # First iteration: nothing to compare against yet.
                eigenvalue_ehr = tmp_eigenvalue_ehr
                eigenvalue_cxr = tmp_eigenvalue_cxr
                continue

            # Only groups that are still iterating have a fresh estimate.
            if sample_ehr:
                if _modal_converged(eigenvalue_ehr, tmp_eigenvalue_ehr, tol):
                    sample_ehr = False
                else:
                    eigenvalue_ehr = tmp_eigenvalue_ehr

            if sample_cxr:
                if _modal_converged(eigenvalue_cxr, tmp_eigenvalue_cxr, tol):
                    sample_cxr = False
                else:
                    eigenvalue_cxr = tmp_eigenvalue_cxr

            if not sample_ehr and not sample_cxr:
                break

        eigenvalues_ehr.append(eigenvalue_ehr)
        eigenvectors_ehr.append(v_ehr)
        eigenvalues_cxr.append(eigenvalue_cxr)
        eigenvectors_cxr.append(v_cxr)
        computed_dim += 1

    return eigenvalues_ehr, eigenvalues_cxr, np.mean(eigenvalues_ehr), np.mean(eigenvalues_cxr)