import torch
import numpy as np
import torch.nn as nn



class Hessian:
    """Curvature statistics of the training loss of a two-branch (EHR + CXR) model.

    The loss is the mean of ``nn.BCELoss(reduction='none')`` between
    ``model(inputs)['predictions']`` and the labels. Hessian-vector products
    are taken with respect to three parameter lists:

    * ``params_ehr`` -- trainable parameters of ``model.ehr_model``;
    * ``params_cxr`` -- trainable parameters of ``model.cxr_model_spec``;
    * ``params``     -- ``params_ehr + params_cxr`` (the joint Hessian).

    The per-branch methods (``*_modal``, ``*_uni``) use the diagonal Hessian
    block of one branch; ``eigenvalues``, ``trace`` and ``density`` use the
    joint parameter list.

    Two modes are supported:

    * single batch (``data=...``): the gradient graph is built once in
      ``__init__`` and reused by every Hessian-vector product;
    * full dataset (``dataloader=...``): every Hessian-vector product is a
      batch-size-weighted average over one pass of the dataloader, whose
      batches are dicts with ``'ehr_ts'``, ``'cxr_imgs'``, ``'labels'``,
      ``'seq_len'`` and ``'has_cxr'`` entries.
    """

    def __init__(self, model, data=None, dataloader=None, cuda=True, norm=False):
        """
        Args:
            model: module whose forward takes an input dict and returns a dict
                with a ``'predictions'`` entry (fed to ``BCELoss``). In
                dataloader mode it must expose ``ehr_model`` and
                ``cxr_model_spec``; in single-batch mode each branch is only
                analysed if the attribute exists.
            data: a single input dict containing ``'labels'``; mutually
                exclusive with ``dataloader``.
            dataloader: iterable of batch dicts; mutually exclusive with ``data``.
            cuda: place dataloader batches and random probe vectors on
                ``'cuda'`` instead of ``'cpu'``.
            norm: divide traces by the branch gradient norm (single-batch only).

        Raises:
            ValueError: unless exactly one of ``data`` / ``dataloader`` is given.
        """
        if (data is None) == (dataloader is None):
            raise ValueError("pass exactly one of `data` or `dataloader`")

        self.model = model.eval()
        self.criterion = nn.BCELoss(reduction='none')
        self.norm = norm
        self.full_dataset = data is None
        self.data = dataloader if self.full_dataset else data
        self.device = 'cuda' if cuda else 'cpu'

        if self.full_dataset:
            # Gradients are recomputed batch by batch later; only the
            # parameter lists are fixed here.
            self.params_ehr = [p for p in self.model.ehr_model.parameters() if p.requires_grad]
            self.params_cxr = [p for p in self.model.cxr_model_spec.parameters() if p.requires_grad]
            self.params = self.params_ehr + self.params_cxr
            return

        # Single batch: build the gradient graph once (create_graph=True) so
        # every later Hessian-vector product can differentiate it again.
        outputs = self.model(self.data)
        loss = self.criterion(outputs['predictions'], self.data['labels']).mean()
        loss.backward(create_graph=True)

        if hasattr(self.model, 'ehr_model'):
            self.params_ehr, self.grads_ehr = self._params_and_grads(self.model.ehr_model)
            self.grad_norm_ehr = self.get_grad_norm(self.grads_ehr)
        if hasattr(self.model, 'cxr_model_spec'):
            self.params_cxr, self.grads_cxr = self._params_and_grads(self.model.cxr_model_spec)
            self.grad_norm_cxr = self.get_grad_norm(self.grads_cxr)

        # Joint parameter set used by eigenvalues / trace / density.
        self.params = getattr(self, 'params_ehr', []) + getattr(self, 'params_cxr', [])
        self.gradsH = getattr(self, 'grads_ehr', []) + getattr(self, 'grads_cxr', [])

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _params_and_grads(module):
        """Return trainable params of ``module`` that have a gradient, plus cloned grads.

        The grads are cloned so the stored tensors are separate from ``p.grad``,
        which later ``model.zero_grad()`` calls reset; the clones remain part
        of the autograd graph.
        """
        params, grads = [], []
        for p in module.parameters():
            if p.requires_grad and p.grad is not None:
                params.append(p)
                grads.append(p.grad.clone())
        return params, grads

    @staticmethod
    def _relative_change(old, new):
        """Return ``|old - new| / (|old| + 1e-6)``, the convergence measure of every loop."""
        return abs(old - new) / (abs(old) + 1e-6)

    @staticmethod
    def _aligned_hv(params, v, retain_graph):
        """Compute H v for ``params`` from the gradients currently stored in ``p.grad``.

        Parameters with no gradient, or whose gradient does not depend on
        any parameter, contribute a zero block. The result therefore always
        lines up index-by-index with ``params`` and ``v``.
        """
        hv = [torch.zeros_like(p) for p in params]
        used = [i for i, p in enumerate(params)
                if p.grad is not None and p.grad.requires_grad]
        if used:
            result = torch.autograd.grad([params[i].grad for i in used],
                                         [params[i] for i in used],
                                         grad_outputs=[v[i] for i in used],
                                         only_inputs=True,
                                         retain_graph=retain_graph)
            for i, h in zip(used, result):
                hv[i] = h
        return hv

    def _prepare_batch(self, batch):
        """Move a batch dict's tensors to ``self.device``.

        Returns the model input dict and the labels. ``seq_len`` and
        ``has_cxr`` are passed through unchanged.
        """
        labels = batch['labels'].to(self.device)
        inputs = {'ehr_ts': batch['ehr_ts'].to(self.device),
                  'cxr_imgs': batch['cxr_imgs'].to(self.device),
                  'labels': labels,
                  'seq_len': batch['seq_len'],
                  'has_cxr': batch['has_cxr']}
        return inputs, labels

    def _dataloader_hv(self, groups):
        """Compute dataset-averaged Hessian-vector products for several parameter groups.

        Args:
            groups: list of ``(params, v)`` pairs.

        Returns:
            One list of tensors per group, averaged over all samples; each
            batch is weighted by its size.

        Raises:
            ValueError: if the dataloader yields no samples.
        """
        totals = [[torch.zeros(p.size()).to(self.device) for p in params]
                  for params, _ in groups]
        num_data = 0.
        last = len(groups) - 1
        for batch in self.data:
            inputs, labels = self._prepare_batch(batch)
            batch_size = float(inputs['ehr_ts'].size(0))

            self.model.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs['predictions'], labels).mean()
            loss.backward(create_graph=True)

            for k, (params, v) in enumerate(groups):
                # Keep the graph alive until the last group has used it.
                hv = self._aligned_hv(params, v, retain_graph=k < last)
                totals[k] = [t + h * batch_size for t, h in zip(totals[k], hv)]
            num_data += batch_size

        if num_data == 0:
            raise ValueError("dataloader yielded no samples")
        return [[t / num_data for t in total] for total in totals]

    def _rademacher(self, params):
        """Return random +/-1 tensors shaped like ``params`` (Hutchinson probe vectors)."""
        v = [torch.randint_like(p, high=2, device=self.device) for p in params]
        for v_i in v:
            v_i[v_i == 0] = -1
        return v

    def _modal_hv(self, v_ehr, v_cxr):
        """Compute per-branch H v, dispatching on the single-batch / dataloader mode."""
        if self.full_dataset:
            return self.dataloader_hv_product_modal(v_ehr, v_cxr)
        return self.hv_product_modal(v_ehr, v_cxr)

    def _uni_hv(self, model_name, v):
        """Compute H v for one branch (``'ehr'`` or anything else meaning CXR)."""
        params = self.params_ehr if model_name == 'ehr' else self.params_cxr
        if self.full_dataset:
            return self._dataloader_hv([(params, v)])[0]
        grads = self.grads_ehr if model_name == 'ehr' else self.grads_cxr
        return self.hv_product_uni(v, grads, params)

    def _joint_hv(self, v):
        """Compute H v for the joint parameter list ``self.params``."""
        if self.full_dataset:
            return self._dataloader_hv([(self.params, v)])[0]
        return hessian_vector_product(self.gradsH, self.params, v)

    def _grad_norm(self, model_name):
        """Return the stored gradient norm of a branch, or fail with a clear message."""
        attr = 'grad_norm_ehr' if model_name == 'ehr' else 'grad_norm_cxr'
        if not hasattr(self, attr):
            raise RuntimeError("norm=True needs gradient norms, which are only "
                               "computed when Hessian is built from `data`")
        return getattr(self, attr)

    # ------------------------------------------------------------------
    # Hessian-vector products
    # ------------------------------------------------------------------
    def get_grad_norm(self, grads):
        """Return the joint L2 norm of a list of gradient tensors, as a 0-d CPU tensor."""
        return group_product(grads, grads).cpu() ** 0.5

    def dataloader_hv_product(self, v):
        """Compute the dataset-averaged H v for the joint parameter list.

        Returns:
            ``(v^T H v, Hv)``.
        """
        (THv,) = self._dataloader_hv([(self.params, v)])
        eigenvalue = group_product(THv, v).cpu().item()
        return eigenvalue, THv

    def dataloader_hv_product_modal(self, v_ehr, v_cxr):
        """Compute the dataset-averaged H v for the EHR and CXR diagonal blocks.

        Returns:
            ``(THv_ehr, THv_cxr)``, aligned with ``params_ehr`` / ``params_cxr``.
        """
        THv_ehr, THv_cxr = self._dataloader_hv([(self.params_ehr, v_ehr),
                                                (self.params_cxr, v_cxr)])
        return THv_ehr, THv_cxr

    def hv_product_modal(self, v_ehr, v_cxr):
        """Compute single-batch H v for the EHR and CXR blocks, reusing the ``__init__`` graph.

        The graph is retained so the same gradients can be differentiated
        again on later calls.
        """
        Hv_ehr = torch.autograd.grad(self.grads_ehr, self.params_ehr, grad_outputs=v_ehr,
                                     only_inputs=True, retain_graph=True)
        Hv_cxr = torch.autograd.grad(self.grads_cxr, self.params_cxr, grad_outputs=v_cxr,
                                     only_inputs=True, retain_graph=True)
        return Hv_ehr, Hv_cxr

    def hv_product_uni(self, v, gradsH, params):
        """Compute single-batch H v for an explicit ``(gradsH, params)`` pair, retaining the graph."""
        return torch.autograd.grad(gradsH, params, grad_outputs=v,
                                   only_inputs=True, retain_graph=True)

    # ------------------------------------------------------------------
    # eigenvalues
    # ------------------------------------------------------------------
    def eigenvalues_modal_strange(self, maxIter=100, tol=1e-3, top_n=1):
        """Variant of ``eigenvalues_modal`` whose probe vector is never advanced.

        NOTE(review): the random probe drawn for each eigenpair is only
        re-orthonormalised and never replaced by H v. Each reported value is
        therefore the Rayleigh quotient ``v^T H v`` of that fixed probe, not a
        power-iteration estimate. The normalised ``H v`` is what gets stored
        as the "eigenvector". This behaviour is kept as-is; confirm it is
        intended.

        Returns:
            ``(eigenvalues_ehr, eigenvalues_cxr, mean_ehr, mean_cxr)``.
        """
        if top_n < 1:
            raise ValueError("top_n must be >= 1")

        device = self.device
        eigenvalues_ehr, eigenvectors_ehr = [], []
        eigenvalues_cxr, eigenvectors_cxr = [], []

        for _ in range(top_n):
            eigenvalue_ehr = eigenvalue_cxr = None
            # Reset for every eigenpair so a branch that converged earlier is
            # sampled again for the next one.
            sample_ehr = sample_cxr = True

            v_ehr_norm = normalization([torch.randn(p.size()).to(device) for p in self.params_ehr])
            v_cxr_norm = normalization([torch.randn(p.size()).to(device) for p in self.params_cxr])

            for _ in range(maxIter):
                if sample_ehr:
                    v_ehr_orth = orthnormal(v_ehr_norm, eigenvectors_ehr)
                if sample_cxr:
                    v_cxr_orth = orthnormal(v_cxr_norm, eigenvectors_cxr)
                self.model.zero_grad()

                Hv_ehr, Hv_cxr = self._modal_hv(v_ehr_orth, v_cxr_orth)

                if sample_ehr:
                    tmp_eigenvalue_ehr = group_product(Hv_ehr, v_ehr_orth).cpu().item()
                    v_ehr_for_eigenvectors = normalization(Hv_ehr)
                if sample_cxr:
                    tmp_eigenvalue_cxr = group_product(Hv_cxr, v_cxr_orth).cpu().item()
                    v_cxr_for_eigenvectors = normalization(Hv_cxr)

                if eigenvalue_ehr is None:
                    eigenvalue_ehr, eigenvalue_cxr = tmp_eigenvalue_ehr, tmp_eigenvalue_cxr
                    continue

                if self._relative_change(eigenvalue_ehr, tmp_eigenvalue_ehr) < tol:
                    sample_ehr = False
                else:
                    eigenvalue_ehr = tmp_eigenvalue_ehr
                if self._relative_change(eigenvalue_cxr, tmp_eigenvalue_cxr) < tol:
                    sample_cxr = False
                else:
                    eigenvalue_cxr = tmp_eigenvalue_cxr

                # Stop only once both branches have converged.
                if not (sample_ehr or sample_cxr):
                    break

            eigenvalues_ehr.append(eigenvalue_ehr)
            eigenvectors_ehr.append(v_ehr_for_eigenvectors)
            eigenvalues_cxr.append(eigenvalue_cxr)
            eigenvectors_cxr.append(v_cxr_for_eigenvectors)

        return eigenvalues_ehr, eigenvalues_cxr, np.mean(eigenvalues_ehr), np.mean(eigenvalues_cxr)

    def eigenvalues_modal(self, maxIter=100, tol=1e-3, top_n=1):
        """Compute the top eigenvalues of the EHR and CXR Hessian blocks by power iteration.

        Both branches are iterated together. A branch stops updating its
        estimate once the relative change drops below ``tol``. The iteration
        for an eigenpair ends when both branches have converged or after
        ``maxIter`` steps. Probes are kept orthogonal to the eigenvectors
        already found.

        Args:
            maxIter: maximum power iterations per eigenvalue.
            tol: relative tolerance between two consecutive estimates.
            top_n: number of eigenvalues to compute per branch (>= 1).

        Returns:
            ``(eigenvalues_ehr, eigenvalues_cxr, mean_ehr, mean_cxr)``.
        """
        if top_n < 1:
            raise ValueError("top_n must be >= 1")

        device = self.device
        eigenvalues_ehr, eigenvectors_ehr = [], []
        eigenvalues_cxr, eigenvectors_cxr = [], []

        for _ in range(top_n):
            eigenvalue_ehr = eigenvalue_cxr = None
            # Reset for every eigenpair so a branch that converged earlier is
            # sampled again for the next one.
            sample_ehr = sample_cxr = True

            v_ehr = normalization([torch.randn(p.size()).to(device) for p in self.params_ehr])
            v_cxr = normalization([torch.randn(p.size()).to(device) for p in self.params_cxr])

            for _ in range(maxIter):
                v_ehr = orthnormal(v_ehr, eigenvectors_ehr)
                v_cxr = orthnormal(v_cxr, eigenvectors_cxr)
                self.model.zero_grad()

                Hv_ehr, Hv_cxr = self._modal_hv(v_ehr, v_cxr)
                if sample_ehr:
                    tmp_eigenvalue_ehr = group_product(Hv_ehr, v_ehr).cpu().item()
                    v_ehr = normalization(Hv_ehr)
                if sample_cxr:
                    tmp_eigenvalue_cxr = group_product(Hv_cxr, v_cxr).cpu().item()
                    v_cxr = normalization(Hv_cxr)

                if eigenvalue_ehr is None:
                    eigenvalue_ehr, eigenvalue_cxr = tmp_eigenvalue_ehr, tmp_eigenvalue_cxr
                    continue

                if self._relative_change(eigenvalue_ehr, tmp_eigenvalue_ehr) < tol:
                    sample_ehr = False
                else:
                    eigenvalue_ehr = tmp_eigenvalue_ehr
                if self._relative_change(eigenvalue_cxr, tmp_eigenvalue_cxr) < tol:
                    sample_cxr = False
                else:
                    eigenvalue_cxr = tmp_eigenvalue_cxr

                # Stop only once both branches have converged.
                if not (sample_ehr or sample_cxr):
                    break

            eigenvalues_ehr.append(eigenvalue_ehr)
            eigenvectors_ehr.append(v_ehr)
            eigenvalues_cxr.append(eigenvalue_cxr)
            eigenvectors_cxr.append(v_cxr)

        return eigenvalues_ehr, eigenvalues_cxr, np.mean(eigenvalues_ehr), np.mean(eigenvalues_cxr)

    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1):
        """Compute the top eigenvalues of the joint Hessian over ``self.params`` by power iteration.

        Returns:
            ``(eigenvalues, eigenvectors)``, each a list of length ``top_n``.
        """
        if top_n < 1:
            raise ValueError("top_n must be >= 1")

        eigenvalues, eigenvectors = [], []
        for _ in range(top_n):
            eigenvalue = None
            v = normalization([torch.randn(p.size()).to(self.device) for p in self.params])

            for _ in range(maxIter):
                v = orthnormal(v, eigenvectors)
                self.model.zero_grad()

                Hv = self._joint_hv(v)
                tmp_eigenvalue = group_product(Hv, v).cpu().item()
                v = normalization(Hv)

                if eigenvalue is not None and self._relative_change(eigenvalue, tmp_eigenvalue) < tol:
                    break
                eigenvalue = tmp_eigenvalue

            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)

        return eigenvalues, eigenvectors

    # ------------------------------------------------------------------
    # traces
    # ------------------------------------------------------------------
    def trace(self, maxIter=100, tol=1e-3):
        """Estimate the joint Hessian trace with Hutchinson's estimator.

        Returns:
            The list of per-probe samples ``v^T H v``; their mean is the
            estimate. Sampling stops when the running mean changes by less
            than ``tol`` (relative) or after ``maxIter`` probes.
        """
        trace_vhv = []
        running_mean = 0.
        for _ in range(maxIter):
            self.model.zero_grad()
            v = self._rademacher(self.params)
            Hv = self._joint_hv(v)
            trace_vhv.append(group_product(Hv, v).cpu().item())
            if self._relative_change(running_mean, np.mean(trace_vhv)) < tol:
                break
            running_mean = np.mean(trace_vhv)
        return trace_vhv

    def trace_modal(self, maxIter=100, tol=1e-3):
        """Estimate the traces of the EHR and CXR Hessian blocks with Hutchinson's estimator.

        Each branch stops collecting samples once its running mean converges.
        The loop ends when both branches have converged or after ``maxIter``
        probes.

        Returns:
            ``(trace_ehr, trace_cxr)``. When ``self.norm`` is set, each trace
            is divided by that branch's stored gradient norm.
        """
        trace_vhv_ehr, trace_vhv_cxr = [], []
        trace_ehr = trace_cxr = 0.
        sample_ehr = sample_cxr = True

        for _ in range(maxIter):
            self.model.zero_grad()
            v_ehr = self._rademacher(self.params_ehr)
            v_cxr = self._rademacher(self.params_cxr)

            THv_ehr, THv_cxr = self._modal_hv(v_ehr, v_cxr)

            if sample_ehr:
                trace_vhv_ehr.append(group_product(THv_ehr, v_ehr).cpu().item())
            if sample_cxr:
                trace_vhv_cxr.append(group_product(THv_cxr, v_cxr).cpu().item())

            if self._relative_change(trace_ehr, np.mean(trace_vhv_ehr)) < tol:
                sample_ehr = False
            else:
                trace_ehr = np.mean(trace_vhv_ehr)
            if self._relative_change(trace_cxr, np.mean(trace_vhv_cxr)) < tol:
                sample_cxr = False
            else:
                trace_cxr = np.mean(trace_vhv_cxr)

            if not (sample_ehr or sample_cxr):
                break

        self.model.zero_grad()
        if self.norm:
            return (np.mean(trace_vhv_ehr) / self._grad_norm('ehr'),
                    np.mean(trace_vhv_cxr) / self._grad_norm('cxr'))
        return np.mean(trace_vhv_ehr), np.mean(trace_vhv_cxr)

    # ------------------------------------------------------------------
    # spectral density
    # ------------------------------------------------------------------
    def density(self, iter=100, n_v=1):
        """Estimate the joint Hessian spectral density with ``iter`` Lanczos steps per probe.

        Args:
            iter: number of Lanczos iterations (size of the tridiagonal T).
            n_v: number of independent Rademacher starting vectors.

        Returns:
            ``(eigen_list_full, weight_list_full)``: for each probe, the
            eigenvalues of T and the squared first components of T's
            eigenvectors.
        """
        device = self.device
        eigen_list_full, weight_list_full = [], []

        for _ in range(n_v):
            v = normalization(self._rademacher(self.params))
            v_list = [v]
            alpha_list, beta_list = [], []

            for i in range(iter):
                self.model.zero_grad()
                if i > 0:
                    beta = torch.sqrt(group_product(w, w))
                    beta_list.append(beta.cpu().item())
                    if beta_list[-1] == 0.:
                        # Breakdown: the residual vanished, so restart from a random direction.
                        w = [torch.randn(p.size()).to(device) for p in self.params]
                    v = orthnormal(w, v_list)
                    v_list.append(v)

                w_prime = self._joint_hv(v)
                alpha = group_product(w_prime, v)
                alpha_list.append(alpha.cpu().item())
                w = group_add(w_prime, v, alpha=-alpha)
                if i > 0:
                    w = group_add(w, v_list[-2], alpha=-beta)

            T = torch.zeros(iter, iter).to(device)
            for i, a in enumerate(alpha_list):
                T[i, i] = a
                if i < len(alpha_list) - 1:
                    T[i + 1, i] = beta_list[i]
                    T[i, i + 1] = beta_list[i]
            # T is symmetric, so eigh applies (torch.eig no longer exists in PyTorch).
            eigen_list, eigvecs = torch.linalg.eigh(T)
            weight_list = eigvecs[0, :] ** 2
            eigen_list_full.append(list(eigen_list.cpu().numpy()))
            weight_list_full.append(list(weight_list.cpu().numpy()))

        return eigen_list_full, weight_list_full

    # ------------------------------------------------------------------
    # single-branch statistics
    # ------------------------------------------------------------------
    def eigenvalues_uni(self, model_name, maxIter=100, tol=1e-3, top_n=1):
        """Compute the mean of the top ``top_n`` eigenvalues of one branch's Hessian block.

        Args:
            model_name: ``'ehr'`` for the EHR branch; any other value selects CXR.
        """
        if top_n < 1:
            raise ValueError("top_n must be >= 1")

        params = self.params_ehr if model_name == 'ehr' else self.params_cxr
        eigenvalues, eigenvectors = [], []

        for _ in range(top_n):
            eigenvalue = None
            v = normalization([torch.randn(p.size()).to(self.device) for p in params])

            for _ in range(maxIter):
                self.model.zero_grad()
                v = orthnormal(v, eigenvectors)
                Hv = self._uni_hv(model_name, v)
                tmp_eigenvalue = group_product(Hv, v).cpu().item()
                v = normalization(Hv)

                if eigenvalue is not None and self._relative_change(eigenvalue, tmp_eigenvalue) < tol:
                    break
                eigenvalue = tmp_eigenvalue

            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)

        return np.mean(eigenvalues)

    def trace_uni(self, model_name, maxIter=100, tol=1e-3):
        """Estimate one branch's Hessian-block trace with Hutchinson's estimator.

        Args:
            model_name: ``'ehr'`` for the EHR branch; any other value selects CXR.

        Returns:
            The mean of the ``v^T H v`` samples. When ``self.norm`` is set, it
            is divided by the branch's stored gradient norm.
        """
        params = self.params_ehr if model_name == 'ehr' else self.params_cxr
        trace_vhv = []
        trace_value = 0.

        for _ in range(maxIter):
            self.model.zero_grad()
            v = self._rademacher(params)
            Hv = self._uni_hv(model_name, v)
            trace_vhv.append(group_product(Hv, v).cpu().item())
            if self._relative_change(trace_value, np.mean(trace_vhv)) < tol:
                break
            trace_value = np.mean(trace_vhv)

        self.model.zero_grad()
        final_trace = np.mean(trace_vhv)
        if self.norm:
            final_trace = final_trace / self._grad_norm(model_name).cpu().item()
        return final_trace



def group_product(xs, ys):
    """Return the inner product of two lists of tensors, summed over all pairs.

    Returns the int ``0`` when the lists are empty, otherwise a 0-d tensor.
    """
    total = 0
    for x, y in zip(xs, ys):
        total = total + torch.sum(x * y)
    return total


def group_add(params, update, alpha=1):
    """Do ``params[i] += alpha * update[i]`` in place for every index and return ``params``.

    The tensors inside ``params`` are mutated through ``.data``, so callers
    holding the same list see the change.
    """
    for idx, tensor in enumerate(params):
        scaled = update[idx] * alpha
        tensor.data.add_(scaled)
    return params


def normalization(v):
    """Return a new list with every tensor divided by the joint L2 norm of ``v``.

    ``1e-6`` is added to the norm to avoid division by zero.
    """
    squared = sum([torch.sum(t * t) for t in v])
    norm = (squared ** 0.5).cpu().item()
    return [t / (norm + 1e-6) for t in v]


def get_params_grad(model):
    """Collect the trainable parameters of ``model`` and copies of their gradients.

    A parameter without a gradient contributes the float ``0.`` in place of
    a tensor.
    """
    params = [param for param in model.parameters() if param.requires_grad]
    grads = [0. if param.grad is None else param.grad + 0. for param in params]
    return params, grads


def hessian_vector_product(gradsH, params, v):
    """Differentiate the gradients ``gradsH`` w.r.t. ``params`` along ``v`` to get H v.

    The autograd graph is retained so the same gradients can be reused.
    """
    return torch.autograd.grad(
        outputs=gradsH,
        inputs=params,
        grad_outputs=v,
        only_inputs=True,
        retain_graph=True,
    )


def orthnormal(w, v_list):
    """Remove from ``w`` its components along each vector in ``v_list``, then normalise.

    The projection coefficient is the raw inner product, so this assumes
    each entry of ``v_list`` is already unit-norm. ``w``'s tensors are
    modified in place before a normalised copy is returned.
    """
    for basis in v_list:
        projection = group_product(w, basis)
        w = group_add(w, basis, alpha=-projection)
    return normalization(w)