from time import time
from collections import defaultdict
import pandas as pd
import pickle, os
import torch
from tqdm import tqdm
import pdb

class IFEngine(object):
    """Compute per-sample scores between training and validation gradients.

    Gradients are supplied as nested dicts ``{sample_id: {weight_name: tensor}}``.
    For every method stored in ``hvp_dict`` ("identity", "proposed",
    "accurate"), ``compute_IF_all`` sums, over weights, the element-wise
    product of that method's per-validation tensors with each training
    gradient and stores the result as a DataFrame in ``IF_dict``.

    Attributes:
        time_dict: runtime in seconds per method name.
        hvp_dict: per-method ``{val_id: {weight_name: tensor}}`` results.
        IF_dict: per-method DataFrame (columns: training ids, index:
            validation ids).
        device: device on which ``n_train_torch`` is placed.
    """

    def __init__(self, device="cuda"):
        self.time_dict = defaultdict(list)
        self.hvp_dict = defaultdict(list)
        self.IF_dict = defaultdict(list)
        self.device = device

    def preprocess_gradients(self, tr_grad_dict, val_grad_dict):
        """Store the gradient dicts, their sizes, and the validation average.

        Args:
            tr_grad_dict: ``{train_id: {weight_name: tensor}}``.
            val_grad_dict: ``{val_id: {weight_name: tensor}}``.
        """
        self.tr_grad_dict = tr_grad_dict
        self.val_grad_dict = val_grad_dict

        self.n_train = len(self.tr_grad_dict)
        self.n_train_torch = torch.tensor(self.n_train, dtype=torch.int).to(
            self.device
        )
        self.n_val = len(self.val_grad_dict)
        self.compute_val_grad_avg()

    def compute_val_grad_avg(self):
        """Average every weight's gradient over the validation samples.

        The result is stored in ``self.val_grad_avg_dict``; it is left empty
        when there are no validation samples.
        """
        self.val_grad_avg_dict = {}
        if not self.val_grad_dict:
            return
        # Use whichever sample comes first as the shape/device template
        # instead of assuming an id equal to 0 exists.
        template = next(iter(self.val_grad_dict.values()))
        for weight_name, ref_grad in template.items():
            avg = torch.zeros(ref_grad.shape, device=ref_grad.device)
            for grads in self.val_grad_dict.values():
                avg += grads[weight_name] / self.n_val
            self.val_grad_avg_dict[weight_name] = avg

    def compute_hvps_all(self, lambda_const_param=10):
        """Run the identity method and the per-validation proposed method."""
        print("Computing HVP for identity method")
        self.compute_hvp_identity()
        print("Computing HVP for proposed method")
        self.compute_hvp_proposed_all(lambda_const_param=lambda_const_param)

    def compute_hvps(self, lambda_const_param=10):
        """Run the identity method followed by the proposed method.

        ``compute_hvp_proposed`` is used when a subclass provides it;
        this class does not define it, so the call otherwise falls back to
        ``compute_hvp_proposed_all`` instead of raising ``AttributeError``.
        """
        print("Computing HVP for identity method")
        self.compute_hvp_identity()
        print("Computing HVP for proposed method")
        proposed = getattr(self, "compute_hvp_proposed", None)
        if proposed is None:
            proposed = self.compute_hvp_proposed_all
        proposed(lambda_const_param=lambda_const_param)

    def compute_hvp_identity(self):
        """Use the validation gradients themselves (shallow copy) as the HVP."""
        start_time = time()
        self.hvp_dict["identity"] = self.val_grad_dict.copy()
        self.time_dict["identity"] = time() - start_time

    def _layer_lambda(self, weight_name, lambda_const_param):
        """Return mean over training samples of mean(grad**2), divided by
        ``lambda_const_param``, for one weight (layer-wise lambda)."""
        sq_means = torch.zeros(len(self.tr_grad_dict))
        # Index by position so training ids need not be the integers 0..n-1.
        for pos, tr_grads in enumerate(self.tr_grad_dict.values()):
            sq_means[pos] = torch.mean(tr_grads[weight_name] ** 2)
        return torch.mean(sq_means) / lambda_const_param

    def _proposed_hvp(self, val_grad, weight_name, lambda_const):
        """Accumulate ``(v - c_i * g_i) / (n_train * lambda)`` over training
        gradients ``g_i``, where ``c_i = <v, g_i> / (lambda + <g_i, g_i>)``."""
        hvp = torch.zeros(val_grad.shape, device=val_grad.device)
        for tr_grads in self.tr_grad_dict.values():
            tr_grad = tr_grads[weight_name]
            coeff = torch.sum(val_grad * tr_grad) / (
                lambda_const + torch.sum(tr_grad ** 2)
            )
            hvp += (val_grad - coeff * tr_grad) / (self.n_train * lambda_const)
        return hvp

    def compute_hvp_proposed_all(self, lambda_const_param=10):
        """Compute the "proposed" HVP for every validation sample and weight.

        Results go to ``self.hvp_dict['proposed'][val_id][weight_name]`` and
        the elapsed time to ``self.time_dict['proposed']``.
        """
        start_time = time()
        hvp_proposed_dict = defaultdict(dict)
        # lambda depends only on the training gradients of a weight, so it is
        # computed once per weight rather than once per validation sample.
        lambda_cache = {}
        for val_id in tqdm(self.val_grad_dict.keys()):
            for weight_name, val_grad in self.val_grad_dict[val_id].items():
                if weight_name not in lambda_cache:
                    lambda_cache[weight_name] = self._layer_lambda(
                        weight_name, lambda_const_param
                    )
                hvp_proposed_dict[val_id][weight_name] = self._proposed_hvp(
                    val_grad, weight_name, lambda_cache[weight_name]
                )
        self.hvp_dict['proposed'] = hvp_proposed_dict
        self.time_dict['proposed'] = time() - start_time

    def _row_gram_eigh(self, weight_name, r):
        """Eigendecompose the sum over training samples of the outer product
        of row ``r`` of ``weight_name``'s gradient with itself.

        Returns ``(eigenvalues, eigenvectors)``. The matrix is symmetric, so
        ``eigh`` is used: it yields real values and orthonormal eigenvectors
        even when eigenvalues repeat.
        """
        gram = None
        for tr_grads in self.tr_grad_dict.values():
            row = tr_grads[weight_name][r]
            outer = torch.outer(row, row)
            if gram is None:
                gram = torch.zeros(outer.shape, device=outer.device)
            gram += outer
        return torch.linalg.eigh(gram)

    def compute_hvp_accurate_all(self, lambda_const_param=10):
        """Compute the "accurate" HVP via a per-row eigendecomposition.

        For each weight in ``self.val_grad_avg_dict`` and each row ``r``, row
        ``r`` of every validation gradient is projected onto the eigenvectors
        ``V`` of the row's Gram matrix, divided by
        ``lambda + eigenvalues / n_train``, and projected back with ``V.T``.
        Rows are concatenated into
        ``self.hvp_dict['accurate'][val_id][weight_name]``.

        The progress bar iterates over weights (not validation samples)
        because each decomposition is shared by all validation samples.

        Raises:
            ValueError: if there are no training gradients.
        """
        start_time = time()
        hvp_accurate_dict = defaultdict(dict)
        if not self.tr_grad_dict:
            raise ValueError("compute_hvp_accurate_all needs training gradients")
        first_train = next(iter(self.tr_grad_dict.values()))
        val_ids = list(self.val_grad_dict.keys())

        for weight_name in tqdm(self.val_grad_avg_dict):
            lambda_const = self._layer_lambda(weight_name, lambda_const_param)
            rows_per_val = {val_id: [] for val_id in val_ids}
            for r in range(len(first_train[weight_name])):
                # Only one decomposition is alive at a time to limit memory.
                eigvals, eigvecs = self._row_gram_eigh(weight_name, r)
                denom = lambda_const + eigvals / self.n_train
                for val_id in val_ids:
                    projected = self.val_grad_dict[val_id][weight_name][[r]] @ eigvecs
                    rows_per_val[val_id].append((projected / denom) @ eigvecs.T)
            for val_id in val_ids:
                rows = rows_per_val[val_id]
                hvp_accurate_dict[val_id][weight_name] = (
                    torch.cat(rows) if rows else None
                )
        self.hvp_dict["accurate"] = hvp_accurate_dict
        self.time_dict["accurate"] = time() - start_time

    def _influence(self, hvps, tr_grads, weight_names):
        """Return, as a float, the sum over ``weight_names`` of
        ``sum(hvps[name] * tr_grads[name])``."""
        total = 0
        for weight_name in weight_names:
            total += torch.sum(hvps[weight_name] * tr_grads[weight_name])
        # A plain float keeps tensors (and their device) out of the DataFrame.
        return float(total)

    def compute_IF_all(self):
        """Fill ``self.IF_dict`` with one DataFrame per method in ``hvp_dict``.

        Each DataFrame has training ids as columns and validation ids as
        index.
        """
        for method_name in self.hvp_dict:
            print("Computing IF for method: ", method_name)
            method_hvps = self.hvp_dict[method_name]
            if_tmp_dict = defaultdict(dict)
            for tr_id in tqdm(self.tr_grad_dict):
                tr_grads = self.tr_grad_dict[tr_id]
                for val_id in self.val_grad_dict:
                    if_tmp_dict[tr_id][val_id] = self._influence(
                        method_hvps[val_id], tr_grads, self.val_grad_dict[val_id]
                    )

            self.IF_dict[method_name] = pd.DataFrame(if_tmp_dict, dtype=float)

    def save_result(self, run_name=""):
        """Pickle runtimes and influence tables to
        ``../results/run_<run_name>/results.pkl``, creating the directory
        if needed."""
        results = {}
        results["runtime"] = self.time_dict
        results["influence"] = self.IF_dict

        out_dir = f"../results/run_{run_name}"
        # exist_ok avoids the check-then-create race of os.path.exists.
        os.makedirs(out_dir, exist_ok=True)

        with open(f"{out_dir}/results.pkl", "wb") as file:
            pickle.dump(results, file)
