import matplotlib.pyplot as plt
import math
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from transformers import TrainerCallback
# from pyhessian import hessian
import wandb

from typing import Dict
import torch
import math
from torch.autograd import Variable
import numpy as np

from pyhessian.utils import group_product, group_add, normalization, get_params_grad, hessian_vector_product, orthnormal


# Number of parameter tensors (not scalar weights) passed to each
# hessian_vector_product call in the single-batch code paths of `hessian`;
# gradients, parameters and the probe vector are sliced with this stride.
CHUNK_SIZE = 500


def hessian_vector_product(gradsH, params, v):
    """
    Differentiate the gradients once more to obtain the Hessian-vector product Hv.

    gradsH: gradients at the current point; they must still carry their
            autograd graph (e.g. produced by backward(create_graph=True)),
    params: the variables the gradients are differentiated with respect to,
    v: the vector, given as one tensor per entry of gradsH.

    Returns the tuple produced by torch.autograd.grad, one tensor per entry
    of params.
    """
    # retain_graph=True keeps the graph alive so the same gradsH can be
    # reused for further products.
    grad_options = dict(grad_outputs=v, only_inputs=True, retain_graph=True)
    return torch.autograd.grad(gradsH, params, **grad_options)

def get_params_grad(model):
    """
    Collect the trainable parameters of `model` and a copy of their gradients.

    Parameters with requires_grad=False are skipped. Returns (params, grads),
    two lists of equal length; a parameter whose .grad is None contributes the
    float 0. as its gradient entry.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    # `+ 0.` yields a new tensor, so the returned entries are not the .grad
    # objects themselves.
    grad_copies = [p.grad + 0. if p.grad is not None else 0. for p in trainable]
    return trainable, grad_copies


class hessian():
    """
    The class used to compute :
        i) the top 1 (n) eigenvalue(s) of the neural network
        ii) the trace of the entire neural network
        iii) the estimated eigenvalue density

    `eigenvalues`, `trace` and `density` only support the single-batch mode
    (``data=...``); in dataloader mode they raise NotImplementedError.

    NOTE: Hessian-vector products are evaluated CHUNK_SIZE parameter tensors
    at a time, and each call pairs a slice of the gradients with the SAME
    slice of the parameters and of the vector. When the model has more than
    CHUNK_SIZE parameter tensors, the products therefore only contain the
    within-chunk (block-diagonal) part of the Hessian.
    """

    def __init__(self, model, criterion=None, data=None, dataloader=None, cuda=True):
        """
        model: the model that needs Hessian information (switched to eval mode here)
        criterion: the loss function; only needed if the model output has no "loss" entry
        data: a single batch of data as (inputs, targets); inputs is a tensor
              or a mapping of keyword arguments for the model, targets may be None
        dataloader: the data loader including bunch of batches of data
        cuda: if True, inputs/targets and all probe vectors are placed on 'cuda'
        """
        # Local import so the class does not depend on a new module-level import.
        from collections.abc import Mapping

        # make sure we either pass a single batch or a dataloader
        assert (data is None) != (dataloader is None), \
            "pass exactly one of `data` (single batch) or `dataloader`"

        self.model = model.eval()  # make sure the model is in evaluation mode
        self.criterion = criterion

        self.full_dataset = data is None
        self.data = dataloader if self.full_dataset else data
        self.device = 'cuda' if cuda else 'cpu'

        # pre-processing for single batch case to simplify the computation.
        if not self.full_dataset:
            self.inputs, self.targets = self.data
            if self.device == 'cuda':
                if isinstance(self.inputs, torch.Tensor):
                    self.inputs = self.inputs.cuda()
                elif isinstance(self.inputs, Mapping):
                    # Mapping (not just dict) so dict-like batch containers
                    # are accepted; non-tensor values are left untouched.
                    for k in self.inputs:
                        if isinstance(self.inputs[k], torch.Tensor):
                            self.inputs[k] = self.inputs[k].cuda()
                else:
                    raise NotImplementedError
                if self.targets is not None:
                    self.targets = self.targets.cuda()

            # if we only compute the Hessian information for a single batch data, we can re-use the gradients.
            if isinstance(self.inputs, torch.Tensor):
                outputs = self.model(self.inputs)
            else:
                outputs = self.model(**self.inputs)

            # Tensor outputs raise on `"loss" in outputs`; treat that as "no loss provided".
            try:
                has_loss = "loss" in outputs
            except (TypeError, RuntimeError):
                has_loss = False

            if has_loss:
                loss = outputs["loss"]
            else:
                assert self.targets is not None, \
                    "targets should not be None if model does not calculate loss!"
                assert self.criterion is not None, \
                    "criterion should not be None if model does not calculate loss!"
                # TODO outputs may not be the correct thing to pass into criterion
                loss = self.criterion(outputs, self.targets)
            # create_graph=True keeps the graph of the gradients so they can
            # be differentiated again for Hessian-vector products.
            loss.backward(create_graph=True)

        # this step is used to extract the parameters from the model
        params, gradsH = get_params_grad(self.model)
        self.params = params
        self.gradsH = gradsH  # gradient used for Hessian computation

    def _require_single_batch(self):
        """Raise if the object was built from a dataloader (unsupported here)."""
        if self.full_dataset:
            raise NotImplementedError("full dataset mode is not correctly implemented for full parameter finetuning!")

    def _chunked_hv(self, v):
        """
        Yield (block, Hv_block) for consecutive slices of CHUNK_SIZE parameter tensors.

        `block` is the slice applied to self.gradsH, self.params and `v`;
        `Hv_block` is the tuple returned by hessian_vector_product for it.
        Yielding lazily lets callers drop each chunk before the next is computed.
        """
        for start in range(0, len(self.params), CHUNK_SIZE):
            block = slice(start, start + CHUNK_SIZE)
            yield block, hessian_vector_product(self.gradsH[block],
                                                self.params[block],
                                                v[block])

    def _hv_product(self, v):
        """Return the chunk-wise Hessian-vector product as one list aligned with self.params."""
        self._require_single_batch()
        Hv = []
        for _, Hv_block in self._chunked_hv(v):
            Hv.extend(Hv_block)
        return Hv

    def dataloader_hv_product(self, v):
        """
        Hessian-vector product averaged over every batch of the dataloader.

        Expects self.data to yield (inputs, targets) tensor pairs and
        self.criterion to be set. Returns (v^T H v, Hv). The other methods of
        this class currently do not call it (see their NotImplementedError).
        """

        device = self.device
        num_data = 0  # count the number of datum points in the dataloader

        THv = [torch.zeros(p.size()).to(device) for p in self.params
              ]  # accumulate result
        for inputs, targets in self.data:
            self.model.zero_grad()
            tmp_num_data = inputs.size(0)
            outputs = self.model(inputs.to(device))
            loss = self.criterion(outputs, targets.to(device))
            loss.backward(create_graph=True)
            params, gradsH = get_params_grad(self.model)
            self.model.zero_grad()
            Hv = torch.autograd.grad(gradsH,
                                     params,
                                     grad_outputs=v,
                                     only_inputs=True,
                                     retain_graph=False)
            # weight each batch by its size so the final division gives a per-sample mean
            THv = [
                THv1 + Hv1 * float(tmp_num_data) + 0.
                for THv1, Hv1 in zip(THv, Hv)
            ]
            num_data += float(tmp_num_data)

        THv = [THv1 / float(num_data) for THv1 in THv]
        eigenvalue = group_product(THv, v).cpu().item()
        return eigenvalue, THv

    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1):
        """
        compute the top_n eigenvalues using power iteration method
        maxIter: maximum iterations used to compute each single eigenvalue
        tol: the relative tolerance between two consecutive eigenvalue computations from power iteration
        top_n: top top_n eigenvalues will be computed

        Returns (eigenvalues, eigenvectors): a list of floats and a list of
        vectors (each a list of tensors aligned with self.params).
        """

        assert top_n >= 1

        device = self.device

        eigenvalues = []
        eigenvectors = []

        while len(eigenvalues) < top_n:
            eigenvalue = None
            v = [torch.randn(p.size()).to(device) for p in self.params
                ]  # generate random vector
            v = normalization(v)  # normalize the vector

            for _ in range(maxIter):
                # deflate against the eigenvectors that were already found
                v = orthnormal(v, eigenvectors)
                self.model.zero_grad()

                Hv = self._hv_product(v)
                tmp_eigenvalue = group_product(Hv, v).cpu().item()

                v = normalization(Hv)

                if eigenvalue is not None and \
                        abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) + 1e-6) < tol:
                    break
                eigenvalue = tmp_eigenvalue
            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)

        return eigenvalues, eigenvectors

    def trace(self, maxIter=100, tol=1e-3):
        """
        compute the trace of hessian using Hutchinson's method
        maxIter: maximum iterations used to compute trace
        tol: the relative tolerance

        Returns the list of per-probe estimates v^T H v; their mean is the
        trace estimate.
        """

        device = self.device
        trace_vhv = []
        trace = 0.

        for _ in range(maxIter):
            self.model.zero_grad()
            v = [
                torch.randint_like(p, high=2, device=device)
                for p in self.params
            ]
            # generate Rademacher random variables
            for v_i in v:
                v_i[v_i == 0] = -1

            self._require_single_batch()
            # accumulate chunk by chunk so only one chunk of Hv is alive at a time
            trace_tmp = 0.0
            for block, Hv in self._chunked_hv(v):
                trace_tmp += group_product(Hv, v[block]).cpu().item()
            trace_vhv.append(trace_tmp)

            running_mean = np.mean(trace_vhv)
            # abs() in the denominator: without it a negative running estimate
            # makes the ratio negative and the loop stops immediately.
            if abs(running_mean - trace) / (abs(trace) + 1e-6) < tol:
                return trace_vhv
            trace = running_mean

        return trace_vhv

    def density(self, iter=100, n_v=1):
        """
        compute estimated eigenvalue density using stochastic lanczos algorithm (SLQ)
        iter: number of Lanczos iterations per run
        n_v: number of SLQ runs

        Returns (eigen_list_full, weight_list_full): for every run, the
        eigenvalues of the Lanczos tridiagonal matrix and the squared first
        components of its eigenvectors.
        """

        device = self.device
        eigen_list_full = []
        weight_list_full = []

        for _ in range(n_v):
            v = [
                torch.randint_like(p, high=2, device=device)
                for p in self.params
            ]
            # generate Rademacher random variables
            for v_i in v:
                v_i[v_i == 0] = -1
            v = normalization(v)

            # standard lanczos algorithm initialization
            v_list = [v]
            alpha_list = []
            beta_list = []
            w = None
            beta = None
            ############### Lanczos
            for step in range(iter):
                self.model.zero_grad()
                if step > 0:
                    beta = torch.sqrt(group_product(w, w))
                    beta_list.append(beta.cpu().item())
                    if beta_list[-1] == 0.:
                        # breakdown: restart from a new random vector
                        w = [torch.randn(p.size()).to(device) for p in self.params]
                    # re-orthogonalize against the vectors that are kept
                    v = orthnormal(w, v_list)
                    v_list.append(v)

                w_prime = self._hv_product(v)
                alpha = group_product(w_prime, v)
                alpha_list.append(alpha.cpu().item())
                w = group_add(w_prime, v, alpha=-alpha)
                if step > 0:
                    w = group_add(w, v_list[-2], alpha=-beta)
                    # only the most recent vectors are kept to bound memory
                    v_list = v_list[-3:]

            T = torch.zeros(iter, iter).to(device)
            for i in range(len(alpha_list)):
                T[i, i] = alpha_list[i]
                if i < len(alpha_list) - 1:
                    T[i + 1, i] = beta_list[i]
                    T[i, i + 1] = beta_list[i]

            # T is symmetric by construction, so the symmetric solver applies.
            # torch.eig is only used as a fallback for builds without torch.linalg.eigh.
            if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
                eigen_list, eigvecs = torch.linalg.eigh(T)
            else:
                a_, eigvecs = torch.eig(T, eigenvectors=True)
                eigen_list = a_[:, 0]
            weight_list = eigvecs[0, :]**2
            eigen_list_full.append(list(eigen_list.cpu().numpy()))
            weight_list_full.append(list(weight_list.cpu().numpy()))

        return eigen_list_full, weight_list_full



def get_esd_plot(eigenvalues, weights):
    """
    Draw the eigenvalue spectral density on the current pyplot figure.

    eigenvalues, weights: per-run lists as returned by hessian.density().

    The current figure is cleared first. Returns the `plt` module itself
    (not a Figure), so the caller decides how to save or log the plot.
    """
    plt.clf()
    density, grids = density_generate(eigenvalues, weights)
    # small offset keeps zero-density bins strictly positive on the log axis
    plt.semilogy(grids, density + 1.0e-7)
    plt.ylabel('Density (Log Scale)', fontsize=14, labelpad=10)
    plt.xlabel('Eigenvalue', fontsize=14, labelpad=10)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    # x-range padded by 1 on each side; y-limits are passed as None
    plt.axis([np.min(eigenvalues) - 1, np.max(eigenvalues) + 1, None, None])
    plt.tight_layout()
    return plt

def density_generate(eigenvalues,
                     weights,
                     num_bins=10000,
                     sigma_squared=1e-5,
                     overhead=0.01):
    """
    Smooth per-run (eigenvalue, weight) pairs into a density on a uniform grid.

    eigenvalues, weights: 2-D array-likes indexed as [run, eigenvalue index].
    num_bins: number of grid points.
    sigma_squared: base value passed to gaussian() as its variance argument;
                   it is scaled by the grid span when the span exceeds 1.
    overhead: margin added on both sides of the mean min/max eigenvalue.

    Returns (density, grids); density is averaged over runs and normalized so
    that sum(density) * grid spacing equals 1.
    """
    eigenvalues = np.array(eigenvalues)
    weights = np.array(weights)

    lambda_max = np.mean(np.max(eigenvalues, axis=1), axis=0) + overhead
    lambda_min = np.mean(np.min(eigenvalues, axis=1), axis=0) - overhead

    grids = np.linspace(lambda_min, lambda_max, num=num_bins)
    sigma = sigma_squared * max(1, (lambda_max - lambda_min))

    num_runs = eigenvalues.shape[0]
    density_output = np.zeros((num_runs, num_bins))

    # Evaluate every grid point of a run in one broadcasted call instead of a
    # Python loop over num_bins: kernel[j, k] = gaussian(eigenvalues[run, k], grids[j]).
    grid_column = grids[:, np.newaxis]
    for run in range(num_runs):
        kernel = gaussian(eigenvalues[run, :], grid_column, sigma)
        density_output[run, :] = np.sum(kernel * weights[run, :], axis=1)

    density = np.mean(density_output, axis=0)
    bin_width = grids[1] - grids[0]
    density = density / (np.sum(density) * bin_width)
    return density, grids


def gaussian(x, x0, sigma_squared):
    """
    Evaluate the normal density with mean x0 and variance sigma_squared at x.

    Works element-wise (with NumPy broadcasting) on scalars or arrays.
    """
    exponent = -(x0 - x)**2 / (2.0 * sigma_squared)
    norm_const = np.sqrt(2 * np.pi * sigma_squared)
    return np.exp(exponent) / norm_const

class PyHessianCallback(TrainerCallback):
    """
    Trainer callback that computes Hessian statistics at the end of each epoch.

    When ``args.pyhessian`` is truthy and this is the world-zero process, the
    top eigenvalue, the mean trace estimate and an eigenvalue-density plot are
    computed on one training batch and stored as a dict in ``state.eigen_info``
    (keys: "top_eigenvalues", "trace", "eigen_density").
    """

    def _compute_eigen_info(self, model, inputs):
        """Run the Hessian analysis for one batch and return the info dict."""
        model.zero_grad()
        # `hessian` switches the model to eval mode; remember the current mode
        # so it can be restored afterwards.
        was_training = model.training
        # `hessian` places inputs and probe vectors on 'cuda' when cuda=True,
        # so only request that when the model itself lives on a CUDA device.
        use_cuda = any(p.is_cuda for p in model.parameters())
        try:
            # data=(inputs, labels), criterion=LossFunction
            # since `model` directly gives loss, there is no need to pass targets and criterion
            hessian_comp = hessian(model, criterion=None, data=(inputs, None), cuda=use_cuda)
            top_eigenvalues, eigenvectors = hessian_comp.eigenvalues()
            top_eigenvalue = top_eigenvalues[-1]
            # the eigenvectors are as large as the model; release them before continuing
            del eigenvectors
            torch.cuda.empty_cache()
            trace = hessian_comp.trace()
            density_eigen, density_weight = hessian_comp.density()
        finally:
            model.train(was_training)
            # drop whatever gradients the analysis left on the parameters
            model.zero_grad()
        figure = get_esd_plot(density_eigen, density_weight)

        return {
            "top_eigenvalues": top_eigenvalue,
            "trace": np.mean(trace),
            "eigen_density": wandb.Image(figure),
        }

    @staticmethod
    def _enabled(args, state):
        """True when Hessian logging is requested and this is the main process."""
        return bool(getattr(args, "pyhessian", False) and state.is_world_process_zero)

    def on_epoch_end(self, args, state, control, model=None, train_dataloader=None, **kwargs):
        """Compute the statistics on a batch that is sampled once and reused every epoch."""
        if not self._enabled(args, state):
            return
        if not hasattr(self, "sampled_inputs"):
            first_batch = next(iter(train_dataloader), None)
            if first_batch is None:
                print("train_dataloader is empty, skip eigen calculation")
                return
            self.sampled_inputs = first_batch
            print("keep sampled inputs for eigen calculation")

        eigen_info = self._compute_eigen_info(model, self.sampled_inputs)
        print(eigen_info)
        print("eigen info calculated!")
        state.eigen_info = eigen_info

    def on_epoch_end_old(self, args, state, control, model=None, train_dataloader=None, **kwargs):
        """Previous variant: use the first batch the dataloader yields on every call."""
        if not self._enabled(args, state):
            return
        inputs = next(iter(train_dataloader), None)  # use first batch only
        if inputs is None:
            print("train_dataloader is empty, skip eigen calculation")
            return

        eigen_info = self._compute_eigen_info(model, inputs)
        print(eigen_info)
        print("eigen info calculated!")
        state.eigen_info = eigen_info