#*
# @file Different utility functions
# Copyright (c) Zhewei Yao, Amir Gholami
# All rights reserved.
# This file is part of PyHessian library.
#
# PyHessian is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# PyHessian is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with PyHessian.  If not, see <http://www.gnu.org/licenses/>.
#*

import torch
import math
from torch.autograd import Variable
import numpy as np

from numpy import linalg as LA

from pyhessian.utils import group_product, group_add, normalization, get_params_grad, hessian_vector_product, orthnormal, get_params

import torch.nn.functional as F

from utils import *



from torch.cuda.amp import GradScaler, autocast

class hessian():
    """
    Curvature statistics of a model's loss, built on Hessian-vector products:
        i) the top 1 (n) eigenvalue(s) of the neural network
        ii) the trace of the entire neural network
        iii) the estimated eigenvalue density
    """

    def __init__(self, model, criterion, data=None, dataloader=None, cuda=True):
        """
        model: the model that needs Hessian information (put into eval mode)
        criterion: the loss function, called as criterion(outputs, targets)
        data: a single batch of data, given as an (inputs, targets) pair
        dataloader: an iterable yielding (inputs, targets) batches
        cuda: use the 'cuda' device when True, 'cpu' otherwise

        Exactly one of data / dataloader must be provided.
        """

        # identity checks: `!= None` can be routed to an overloaded
        # (element-wise) comparison by tensor-like arguments
        assert (data is None) != (dataloader is None), \
            'pass exactly one of data / dataloader'

        self.model = model.eval()  # make sure the model is in evaluation mode
        self.criterion = criterion

        self.full_dataset = data is None
        self.data = dataloader if self.full_dataset else data
        self.device = 'cuda' if cuda else 'cpu'

        # pre-processing for single batch case to simplify the computation.
        if not self.full_dataset:
            self.inputs, self.targets = self.data
            if self.device == 'cuda':
                self.inputs, self.targets = self.inputs.cuda(
                ), self.targets.cuda()

            # with a single batch the gradients can be computed once and
            # re-used by every Hessian-vector product; create_graph=True
            # keeps the graph so the gradients can be differentiated again
            outputs = self.model(self.inputs)
            loss = self.criterion(outputs, self.targets)
            loss.backward(create_graph=True)

        # this step is used to extract the parameters from the model
        params, gradsH = get_params_grad(self.model)
        self.params = params
        self.gradsH = gradsH  # gradient used for Hessian computation

    def dataloader_hv_product(self, v):
        """
        Hessian-vector product averaged over every batch of the dataloader.
        v: list of tensors, one per entry of self.params

        Returns (group_product(Hv, v) as a python float, Hv), where Hv is the
        sample-weighted mean of the per-batch products.
        """

        device = self.device
        num_data = 0  # count the number of datum points in the dataloader

        # accumulate result
        THv = [torch.zeros(p.size(), device=device) for p in self.params]
        for inputs, targets in self.data:
            self.model.zero_grad()
            tmp_num_data = inputs.size(0)
            outputs = self.model(inputs.to(device))
            loss = self.criterion(outputs, targets.to(device))
            loss.backward(create_graph=True)
            params, gradsH = get_params_grad(self.model)
            self.model.zero_grad()
            Hv = torch.autograd.grad(gradsH,
                                     params,
                                     grad_outputs=v,
                                     only_inputs=True,
                                     retain_graph=False)
            # weight each batch by its size so the final division yields a
            # per-sample mean even when the last batch is smaller
            THv = [
                THv1 + Hv1 * float(tmp_num_data) + 0.
                for THv1, Hv1 in zip(THv, Hv)
            ]
            num_data += float(tmp_num_data)

        THv = [THv1 / float(num_data) for THv1 in THv]
        eigenvalue = group_product(THv, v).cpu().item()
        return eigenvalue, THv

    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1):
        """
        compute the top_n eigenvalues using power iteration method
        maxIter: maximum iterations used to compute each single eigenvalue
        tol: the relative tolerance between two consecutive eigenvalue computations from power iteration
        top_n: top top_n eigenvalues will be computed

        Returns (eigenvalues, eigenvectors) in the order they were computed.
        """

        assert top_n >= 1

        device = self.device

        eigenvalues = []
        eigenvectors = []

        for computed_dim in range(top_n):
            if top_n > 1:
                print('computed_dim:', computed_dim, end='\r')
            eigenvalue = None
            ratio = None  # relative change measured by the latest iteration
            v = [torch.randn(p.size()).to(device) for p in self.params
                ]  # generate random vector
            v = normalization(v)  # normalize the vector

            for i in range(maxIter):
                # deflation: project out the eigenvectors found so far
                v = orthnormal(v, eigenvectors)
                self.model.zero_grad()

                if self.full_dataset:
                    tmp_eigenvalue, Hv = self.dataloader_hv_product(v)
                else:
                    Hv = hessian_vector_product(self.gradsH, self.params, v)
                    tmp_eigenvalue = group_product(Hv, v).cpu().item()

                v = normalization(Hv)

                if eigenvalue is None:
                    eigenvalue = tmp_eigenvalue
                    continue

                ratio = abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) +
                                                            1e-6)
                if ratio < tol:
                    print('\n within tolerence with iter', i, '---')
                    break
                eigenvalue = tmp_eigenvalue
            else:
                # iteration budget exhausted without convergence; report the
                # last measured relative change (needs at least two
                # iterations to exist)
                if ratio is not None:
                    print('\n maxiter with tolerence', ratio, '---')

            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)

        return eigenvalues, eigenvectors

    def trace(self, maxIter=100, tol=1e-3):
        """
        compute the trace of hessian using Hutchinson's method
        maxIter: maximum iterations used to compute trace
        tol: the relative tolerance

        Returns the list of individual v^T H v samples; their mean is the
        trace estimate.
        """

        device = self.device
        trace_vhv = []
        trace = 0.

        for i in range(maxIter):
            self.model.zero_grad()
            v = [
                torch.randint_like(p, high=2, device=device)
                for p in self.params
            ]
            # generate Rademacher random variables: map {0, 1} to {-1, +1}
            for v_i in v:
                v_i[v_i == 0] = -1

            if self.full_dataset:
                _, Hv = self.dataloader_hv_product(v)
            else:
                Hv = hessian_vector_product(self.gradsH, self.params, v)
            trace_vhv.append(group_product(Hv, v).cpu().item())

            # stop once the running mean has stabilised
            running_mean = np.mean(trace_vhv)
            if abs(running_mean - trace) / (abs(trace) + 1e-6) < tol:
                return trace_vhv
            trace = running_mean

        return trace_vhv

    @staticmethod
    def _tridiagonal_spectrum(alpha_list, beta_list):
        """
        Eigen-decompose the symmetric tridiagonal Lanczos matrix.
        alpha_list: diagonal entries
        beta_list: off-diagonal entries (one fewer than the diagonal)

        Returns (eigenvalues, weights) as float32 numpy arrays, where each
        weight is the squared first component of the matching eigenvector.
        Runs entirely on the host, so it works for both 'cpu' and 'cuda'.
        """
        size = len(alpha_list)
        T = np.zeros((size, size), dtype=np.float32)
        for i, alpha in enumerate(alpha_list):
            T[i, i] = alpha
        for i, beta in enumerate(beta_list[:max(size - 1, 0)]):
            T[i + 1, i] = beta
            T[i, i + 1] = beta
        # T is symmetric, so eigh applies: real eigenvalues (ascending) and
        # orthonormal eigenvectors
        eigen_vals, eigen_vecs = LA.eigh(T)
        return eigen_vals, eigen_vecs[0, :] ** 2

    def density(self, iter=100, n_v=1):
        """
        compute estimated eigenvalue density using stochastic lanczos algorithm (SLQ)
        iter: number of Lanczos iterations per run
        n_v: number of SLQ runs

        Returns (eigen_list_full, weight_list_full): one list of eigenvalues
        and one list of matching weights per run.
        """

        device = self.device
        eigen_list_full = []
        weight_list_full = []

        for k in range(n_v):
            v = [
                torch.randint_like(p, high=2, device=device)
                for p in self.params
            ]
            # generate Rademacher random variables: map {0, 1} to {-1, +1}
            for v_i in v:
                v_i[v_i == 0] = -1
            v = normalization(v)

            # standard lanczos algorithm initialization
            v_list = [v]
            alpha_list = []
            beta_list = []
            ############### Lanczos
            for i in range(iter):
                self.model.zero_grad()
                if i > 0:
                    beta = torch.sqrt(group_product(w, w))
                    beta_list.append(beta.cpu().item())
                    if beta_list[-1] == 0.:
                        # the residual vanished: restart from a fresh random
                        # vector
                        w = [torch.randn(p.size()).to(device)
                             for p in self.params]
                    # re-orthogonalize against every previous Lanczos vector
                    v = orthnormal(w, v_list)
                    v_list.append(v)

                if self.full_dataset:
                    _, w_prime = self.dataloader_hv_product(v)
                else:
                    w_prime = hessian_vector_product(
                        self.gradsH, self.params, v)
                alpha = group_product(w_prime, v)
                alpha_list.append(alpha.cpu().item())
                w = group_add(w_prime, v, alpha=-alpha)
                if i > 0:
                    # three-term recurrence: also remove the previous vector
                    w = group_add(w, v_list[-2], alpha=-beta)

            eigen_list, weight_list = self._tridiagonal_spectrum(
                alpha_list, beta_list)
            eigen_list_full.append(list(eigen_list))
            weight_list_full.append(list(weight_list))

        return eigen_list_full, weight_list_full

    

class my_hessian(hessian):
    """
    Variant of `hessian` whose model is attached later through ready().

    Differences visible in this class:
        i) gradients come from torch.autograd.grad instead of loss.backward
        ii) eigenvalues() warm-starts from the eigenvectors of the previous
            call (self.init_v) and returns the pairs sorted by eigenvalue
        iii) the model can be used in 'train' or 'eval' mode
    """

    def __init__(self, criterion, data=None, dataloader=None, cuda=True, mode='train'):
        """
        criterion: the loss function, called as criterion(outputs, targets)
        data: a single batch of data, given as an (inputs, targets) pair
        dataloader: an iterable yielding (inputs, targets) batches
        cuda: use the 'cuda' device when True, 'cpu' otherwise
        mode: 'eval' puts the model in eval mode in ready(); any other value
              puts it in train mode

        Exactly one of data / dataloader must be provided.
        """
        # identity checks: `!= None` can be routed to an overloaded
        # (element-wise) comparison by tensor-like arguments
        assert (data is None) != (dataloader is None), \
            'pass exactly one of data / dataloader'

        self.full_dataset = data is None
        self.data = dataloader if self.full_dataset else data
        self.device = 'cuda' if cuda else 'cpu'
        self.criterion = criterion

        self.init_v = None  # eigenvectors of the previous eigenvalues() call
        self.mode = mode

    def ready(self, model):
        """
        Attach the model and prepare self.params / self.gradsH.
        model: the model that needs Hessian information; its parameters are
               read through model.module, so it is expected to be wrapped
               (e.g. by a DataParallel-style container)
        """
        if self.mode == 'eval':
            self.model = model.eval()  # make sure the model is in evaluation mode
        else:
            self.model = model.train()  # make sure the model is in training mode

        # pre-processing for single batch case to simplify the computation.
        if not self.full_dataset:
            self.inputs, self.targets = self.data
            if self.device == 'cuda':
                self.inputs, self.targets = self.inputs.cuda(), self.targets.cuda()

            # with a single batch the gradients can be computed once and
            # re-used by every Hessian-vector product
            outputs = self.model(self.inputs)
            loss = self.criterion(outputs, self.targets)

            # create_graph=True keeps the graph so the gradients can be
            # differentiated a second time
            gradsH = torch.autograd.grad(
                outputs=loss,
                inputs=list(self.model.module.parameters()),
                create_graph=True)

            # only the parameter list is needed; .grad is not populated by
            # torch.autograd.grad
            params, _ = get_params_grad(self.model)
        else:
            params, gradsH = get_params_grad(self.model)

        self.params = params
        self.gradsH = gradsH  # gradient used for Hessian computation

    def Hv(self, v):
        """Hessian-vector product for the batch prepared by ready()."""
        return hessian_vector_product(self.gradsH, self.params, v)

    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1):
        """
        compute the top_n eigenvalues using power iteration method
        maxIter: maximum iterations used to compute each single eigenvalue
                 (the first eigenpair gets twice this budget)
        tol: the relative tolerance between two consecutive eigenvalue computations from power iteration
        top_n: top top_n eigenvalues will be computed

        Returns (eigenvalues, eigenvectors) sorted by decreasing eigenvalue.
        The eigenvectors are stored in self.init_v to warm-start the next
        call, and self.total_iter holds the number of iterations executed.
        """

        assert top_n >= 1

        device = self.device

        eigenvalues = []
        eigenvectors = []

        self.total_iter = 0
        for computed_dim in range(top_n):
            if top_n > 1:
                print('computed_dim:', computed_dim, end='\r')
            eigenvalue = None
            ratio = None  # relative change measured by the latest iteration

            if self.init_v is not None and computed_dim < len(self.init_v):
                # warm start (already normalized)
                v = self.init_v[computed_dim]
            else:
                # no stored vector for this index (first call, or top_n grew)
                v = [torch.randn(p.size()).to(device) for p in self.params
                    ]  # generate random vector
                v = normalization(v)  # normalize the vector

            # the first eigenvector is more important: double its budget
            miter = maxIter * 2 if computed_dim == 0 else maxIter

            iterations = 0
            for i in range(miter):
                # deflation: project out the eigenvectors found so far
                v = orthnormal(v, eigenvectors)
                self.model.zero_grad()

                if self.full_dataset:
                    tmp_eigenvalue, Hv = self.dataloader_hv_product(v)
                else:
                    Hv = hessian_vector_product(self.gradsH, self.params, v)
                    tmp_eigenvalue = group_product(Hv, v).cpu().item()

                v = normalization(Hv)
                iterations = i + 1

                if eigenvalue is None:
                    eigenvalue = tmp_eigenvalue
                    continue

                ratio = abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) +
                                                            1e-6)
                if ratio < tol:
                    break
                eigenvalue = tmp_eigenvalue
            else:
                # budget exhausted without convergence; report the relative
                # change of the final iteration
                if ratio is not None:
                    print('\n maxiter with tolerence %.4f'%(ratio), '---')

            self.total_iter += iterations
            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)

        # sort on the eigenvalue only: comparing (value, vector) tuples would
        # fall through to comparing tensors whenever two eigenvalues tie
        order = sorted(range(len(eigenvalues)),
                       key=eigenvalues.__getitem__,
                       reverse=True)
        evls = [eigenvalues[j] for j in order]
        evcs = [eigenvectors[j] for j in order]

        self.init_v = evcs

        return evls, evcs

    def dataloader_hv_product(self, v):
        '''
        Hessian-vector product averaged over every batch of the dataloader.
        Only when self.full_dataset == True
        v: list of tensors, one per entry of self.params

        Returns (group_product(Hv, v) as a python float, Hv).
        '''
        device = self.device
        num_data = 0  # count the number of datum points in the dataloader

        # accumulate result
        THv = [torch.zeros(p.size(), device=device) for p in self.params]
        for inputs, targets in self.data:
            self.model.zero_grad()
            tmp_num_data = inputs.size(0)
            outputs = self.model(inputs.to(device))
            loss = self.criterion(outputs, targets.to(device))

            gradsH = torch.autograd.grad(
                outputs=loss,
                inputs=list(self.model.module.parameters()),
                create_graph=True)
            params, _ = get_params_grad(self.model)

            self.model.zero_grad()
            Hv = torch.autograd.grad(gradsH,
                                     params,
                                     grad_outputs=v,
                                     only_inputs=True,
                                     retain_graph=False)
            # weight each batch by its size so the final division yields a
            # per-sample mean even when the last batch is smaller
            THv = [
                THv1 + Hv1 * float(tmp_num_data) + 0.
                for THv1, Hv1 in zip(THv, Hv)
            ]
            num_data += float(tmp_num_data)

        THv = [THv1 / float(num_data) for THv1 in THv]
        eigenvalue = group_product(THv, v).cpu().item()
        return eigenvalue, THv
    

class ffcv_hessian(my_hessian):
    """
    Single-batch variant of `my_hessian` that runs the forward pass under
    autocast and scales the loss with a GradScaler before differentiating.

    Only eigenvalues() is specialised; the dataloader mode is rejected in
    ready().
    """

    def __init__(self, criterion, data=None, dataloader=None, cuda=True):
        """
        criterion: the loss function, called as criterion(outputs, targets)
        data: a single batch of data, given as an (inputs, targets) pair
        dataloader: accepted for signature compatibility; ready() asserts
                    that it was not used
        cuda: use the 'cuda' device when True, 'cpu' otherwise
        """
        # the parent validates the arguments and sets data, full_dataset,
        # device, criterion and init_v
        super().__init__(criterion, data=data, dataloader=dataloader, cuda=cuda)

        self.scaler = GradScaler()  # scales the loss in ready()
        # unit scale; handed to hessian_vector_product_ffcv
        self.scaler2 = GradScaler(init_scale=1)

    def ready(self, model):
        """
        Attach the model (in train mode) and prepare self.params / self.gradsH.
        model: the model that needs Hessian information; its parameters are
               read through model.module, so it is expected to be wrapped
               (e.g. by a DataParallel-style container)
        """

        self.model = model.train()  # make sure the model is in training mode

        # only the single batch case is supported
        assert not self.full_dataset

        self.inputs, self.targets = self.data
        if self.device == 'cuda':
            self.inputs, self.targets = self.inputs.cuda(), self.targets.cuda()

        # with a single batch the gradients can be computed once and re-used
        # by every Hessian-vector product
        with autocast():
            outputs = self.model(self.inputs)
            loss = self.criterion(outputs, self.targets)
        gradsH = torch.autograd.grad(
            outputs=self.scaler.scale(loss),
            inputs=list(self.model.module.parameters()),
            create_graph=True)

        self.params = get_params(self.model)

        # undo the loss scaling and replace NaN entries by zero
        inv_scale = 1. / self.scaler.get_scale()
        self.gradsH = [
            torch.nan_to_num(g, nan=0.0) * inv_scale for g in gradsH
        ]

    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1):
        """
        compute the top_n eigenvalues using power iteration method
        maxIter: maximum iterations used to compute each single eigenvalue
                 (the first eigenpair gets twice this budget)
        tol: the relative tolerance between two consecutive eigenvalue computations from power iteration
        top_n: top top_n eigenvalues will be computed

        Returns (eigenvalues, eigenvectors) in the order they were computed
        (no sorting). The eigenvectors are stored in self.init_v to
        warm-start the next call, and self.total_iter holds the number of
        iterations executed.
        """

        assert top_n >= 1

        device = self.device

        eigenvalues = []
        eigenvectors = []

        self.total_iter = 0
        for computed_dim in range(top_n):
            if top_n > 1:
                print('computed_dim:', computed_dim, end='\r')
            eigenvalue = None
            ratio = None  # relative change measured by the latest iteration

            if self.init_v is not None and computed_dim < len(self.init_v):
                # warm start (already normalized)
                v = self.init_v[computed_dim]
            else:
                # no stored vector for this index (first call, or top_n grew)
                v = [torch.randn(p.size()).to(device) for p in self.params
                    ]  # generate random vector
                v = normalization(v)  # normalize the vector

            # the first eigenvector is more important: double its budget
            miter = maxIter * 2 if computed_dim == 0 else maxIter

            iterations = 0
            for i in range(miter):
                print('inner hessian', i)
                # deflation: project out the eigenvectors found so far
                v = orthnormal(v, eigenvectors)
                self.model.zero_grad()

                Hv = hessian_vector_product_ffcv(self.gradsH, self.params, v, self.scaler2)
                tmp_eigenvalue = group_product(Hv, v).cpu().item()

                v = normalization(Hv)
                iterations = i + 1

                if eigenvalue is None:
                    eigenvalue = tmp_eigenvalue
                    continue

                ratio = abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) +
                                                            1e-6)
                if ratio < tol:
                    break
                eigenvalue = tmp_eigenvalue
            else:
                # budget exhausted without convergence; report the relative
                # change of the final iteration
                if ratio is not None:
                    print('\n maxiter with tolerence %.4f'%(ratio), '---')

            self.total_iter += iterations
            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)

        # `+ 0.0` maps a negative zero to 0.0; order of computation is kept
        evls = [evl + 0.0 for evl in eigenvalues]
        evcs = list(eigenvectors)

        print('pair list', evls)

        self.init_v = evcs

        return evls, evcs
    
def hessian_vector_product_ffcv(gradsH, params, v, scaler):
    """
    Hessian-vector product Hv obtained by differentiating <v, gradsH>.

    gradsH: gradients at the current point (must still carry their graph)
    params: the variables the gradients were taken with respect to
    v: the vector, as a list of tensors matching params
    scaler: accepted for call compatibility; not used by this function

    Returns a list with one tensor per entry of params, with NaN entries
    replaced by zero.
    """
    # the inner product is formed under autocast ...
    with autocast():
        inner = group_product(v, gradsH)

    # ... while the differentiation itself runs outside of it; the graph is
    # retained so that gradsH can be differentiated again by the next call
    raw_hv = torch.autograd.grad(outputs=inner,
                                 inputs=params,
                                 retain_graph=True)

    return [torch.nan_to_num(component, nan=0.0) for component in raw_hv]
















































class new_ffcv_hessian(hessian):
    """
    Single-batch variant of `hessian` whose model is attached later through
    ready(); the inputs are cast to float and the gradients are taken with
    torch.autograd.grad (no loss scaling).

    Only eigenvalues() is specialised; the dataloader mode is rejected in
    ready().
    """

    def __init__(self, criterion, data=None, dataloader=None, cuda=True):
        """
        criterion: the loss function, called as criterion(outputs, targets)
        data: a single batch of data, given as an (inputs, targets) pair
        dataloader: accepted for signature compatibility; ready() asserts
                    that it was not used
        cuda: use the 'cuda' device when True, 'cpu' otherwise

        Exactly one of data / dataloader must be provided.
        """
        # identity checks: `!= None` can be routed to an overloaded
        # (element-wise) comparison by tensor-like arguments
        assert (data is None) != (dataloader is None), \
            'pass exactly one of data / dataloader'

        self.full_dataset = data is None
        self.data = dataloader if self.full_dataset else data
        self.device = 'cuda' if cuda else 'cpu'
        self.criterion = criterion

        self.init_v = None  # eigenvectors of the previous eigenvalues() call
        self.scaler = GradScaler()  # handed to hessian_vector_product_ffcv

    def ready(self, model):
        """
        Attach the model (in train mode) and prepare self.params / self.gradsH.
        model: the model that needs Hessian information; its parameters are
               read through model.module, so it is expected to be wrapped
               (e.g. by a DataParallel-style container)
        """

        self.model = model.train()  # make sure the model is in training mode

        # only the single batch case is supported
        assert not self.full_dataset

        self.inputs, self.targets = self.data
        if self.device == 'cuda':
            self.inputs, self.targets = self.inputs.cuda(), self.targets.cuda()

        # with a single batch the gradients can be computed once and re-used
        # by every Hessian-vector product
        outputs = self.model(self.inputs.float())
        loss = self.criterion(outputs, self.targets)

        # create_graph=True keeps the graph so the gradients can be
        # differentiated a second time
        gradsH = torch.autograd.grad(
            outputs=loss,
            inputs=list(self.model.module.parameters()),
            create_graph=True)

        self.params = get_params(self.model)  # not scaled (float)
        self.gradsH = gradsH

    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1):
        """
        compute the top_n eigenvalues using power iteration method
        maxIter: maximum iterations used to compute each single eigenvalue
                 (the first eigenpair gets twice this budget)
        tol: the relative tolerance between two consecutive eigenvalue computations from power iteration
        top_n: top top_n eigenvalues will be computed

        Returns (eigenvalues, eigenvectors) in the order they were computed
        (no sorting). The eigenvectors are stored in self.init_v to
        warm-start the next call, and self.total_iter holds the number of
        iterations executed.
        """

        assert top_n >= 1

        device = self.device

        eigenvalues = []
        eigenvectors = []

        self.total_iter = 0
        for computed_dim in range(top_n):
            if top_n > 1:
                print('computed_dim:', computed_dim, end='\r')
            eigenvalue = None
            ratio = None  # relative change measured by the latest iteration

            if self.init_v is not None and computed_dim < len(self.init_v):
                # warm start (already normalized)
                v = self.init_v[computed_dim]
            else:
                # no stored vector for this index (first call, or top_n grew)
                v = [torch.randn(p.size()).to(device) for p in self.params
                    ]  # generate random vector
                v = normalization(v)  # normalize the vector

            # the first eigenvector is more important: double its budget
            miter = maxIter * 2 if computed_dim == 0 else maxIter

            iterations = 0
            for i in range(miter):
                print('inner hessian', i)
                # deflation: project out the eigenvectors found so far
                v = orthnormal(v, eigenvectors)
                self.model.zero_grad()

                Hv = hessian_vector_product_ffcv(self.gradsH, self.params, v, self.scaler)
                tmp_eigenvalue = group_product(Hv, v).cpu().item()

                v = normalization(Hv)
                iterations = i + 1

                if eigenvalue is None:
                    eigenvalue = tmp_eigenvalue
                    continue

                ratio = abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) +
                                                            1e-6)
                if ratio < tol:
                    break
                eigenvalue = tmp_eigenvalue
            else:
                # budget exhausted without convergence; report the relative
                # change of the final iteration
                if ratio is not None:
                    print('\n maxiter with tolerence %.4f'%(ratio), '---')

            self.total_iter += iterations
            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)

        # `+ 0.0` maps a negative zero to 0.0; order of computation is kept
        evls = [evl + 0.0 for evl in eigenvalues]
        evcs = list(eigenvectors)

        print('pair list', evls)

        self.init_v = evcs

        return evls, evcs
    
        
