import torch
import copy
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import sys
import os
sys.path.append(os.path.abspath('fxh/seq_unlearn'))
from utils import compute_grad, compute_epoch_hessian_vector_product, compute_epoch_hessian, compute_hessian, evaluate_accuracy
from models import CustomModel
import time


def update_model_parameters(model, param_updates, hyper_param=1.0):
    """Add ``hyper_param * update`` to every parameter of ``model`` in place.

    Updates are matched to ``model.parameters()`` by position. Because ``zip`` is
    used, pairing stops at the shorter of the two sequences. Each update is moved
    to the device of the model's first parameter before it is applied, and the
    in-place additions run under ``torch.no_grad()`` so they are not recorded by
    autograd.

    Args:
        model: Module whose parameters are modified in place.
        param_updates: Iterable of tensors, one per parameter, in parameter order.
        hyper_param: Scalar multiplier applied to every update (default 1.0).
    """
    target_device = next(model.parameters()).device

    with torch.no_grad():
        for weight, delta in zip(model.parameters(), param_updates):
            scaled = hyper_param * delta.to(target_device)
            weight.add_(scaled)
            

def influence_baseline(train_dataset, indices, subset_indexs, T, eta, batch_size, epochs, num_classes, 
                       model_type, resume_model_path, save_interval=1, device="cuda:0", save_path="", test_loader=None):
    """Sequentially unlearn forget sets from a pretrained model with one influence step per round.

    In round ``t`` (``1..T``) the samples ``subset_indexs[t-1]`` (the forget set F_t)
    are removed. The per-round procedure is:

    1. Compute the summed cross-entropy over F_t plus ``0.5 * lambda * ||theta||^2``.
    2. Take the gradient of that loss (mean over F_t by default).
    3. Pass the gradient through ``compute_epoch_hessian_vector_product`` over a
       loader of all ``indices``.
    4. Scale the result by ``1 / (n - m)`` and add it to the parameters.

    After each round the model is evaluated on F_t, on the remaining set R_t, on all
    previously forgotten samples and on the test set.

    Args:
        train_dataset: Dataset indexed by the entries of ``indices`` / ``subset_indexs``.
        indices: Indices of the full training set.
        subset_indexs: Sequence of index collections; entry ``t-1`` is forgotten in round ``t``.
        T: Number of unlearning rounds.
        eta: Unused in this function.
        batch_size: Batch size for all data loaders built here.
        epochs: Unused in this function.
        num_classes: Forwarded to ``CustomModel``.
        model_type: Forwarded to ``CustomModel``.
        resume_model_path: Forwarded to ``CustomModel`` to build the starting model.
        save_interval: A deep CPU copy of the model is kept every ``save_interval`` rounds.
        device: Device the model and batches are moved to.
        save_path: Output prefix; the snapshots are saved to ``save_path + '-influence.pth'``.
        test_loader: Optional loader for test accuracy. When ``None`` the test accuracy
            is reported as NaN.

    Returns:
        List of CPU model snapshots (one per saved round). The same list is written
        to disk with ``torch.save``.

    Raises:
        ValueError: If a forget set is empty. Otherwise the mean loss would divide by
            zero and the update would fill the weights with inf/NaN.
    """
    baseline_model = CustomModel(model_name = model_type, num_classes=num_classes, pretrained=False, model_path=resume_model_path).to(device)
    loss_fn = nn.CrossEntropyLoss()

    # All samples forgotten in rounds strictly before the current one.
    prev_F_t_1_indices = []

    n = len(indices)
    lambda_val = 0.1  # strength of the L2 term added to the forget-set loss

    # going_type selects the implementation: 'dream' is the idealised exact-Hessian
    # variant (with the constants of the optional Gaussian noise), 'real-world' uses
    # the Hessian-vector-product helper.
    going_type = 'real-world'
    if going_type == 'dream':
        epsilon = 0.01
        delta = 0.1
        M = 1
        L = 1
        m = 128

        # compute γ and σ (σ is only needed by the optional noise step below)
        gamma = (2 * M * m**2 * L**2) / (lambda_val**3 * n**2)
        sigma = (gamma / epsilon) * torch.sqrt(2 * torch.log(torch.tensor(1.25 / delta)))

    # The loader over the full training set does not depend on t, so it is built
    # once. shuffle=True still reshuffles on every pass over it.
    total_loader = DataLoader(Subset(train_dataset, indices), batch_size=batch_size, shuffle=True)

    T_opt_models = []

    # NOTE(review): train() is set only once. If evaluate_accuracy switches the model
    # to eval mode, rounds t >= 2 compute their loss in eval mode — confirm intended.
    baseline_model.train()
    for t in range(1, T+1):
        start = time.time()
        # Forget set of this round, and the remaining set after removing every
        # sample forgotten so far (previous rounds plus this one).
        Ft_indices = subset_indexs[t-1]
        Rt_indices = np.setdiff1d(indices, [*prev_F_t_1_indices, *Ft_indices])
        m = len(Ft_indices)
        if m == 0:
            raise ValueError(f"subset_indexs[{t-1}] is empty; nothing to unlearn at time {t}")

        Ft_loader = DataLoader(Subset(train_dataset, Ft_indices), batch_size=batch_size, shuffle=True)

        # Influence update is a single step, so no training epochs. Accumulate the
        # summed cross-entropy over F_t and keep the graph for the gradient below.
        base_loss = torch.tensor(0.0).to(device)
        total_samples = 0
        for inputs, targets in tqdm(Ft_loader, desc="Training progress"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = baseline_model(inputs)
            # loss_fn returns a batch mean; multiply back to get a per-sample sum.
            base_loss += loss_fn(outputs, targets) * inputs.size(0)
            total_samples += inputs.size(0)
        l2_norm_loss = sum(p.pow(2.0).sum() for p in baseline_model.parameters())
        loss_total = base_loss + 0.5 * lambda_val * l2_norm_loss
        loss_mean = loss_total / total_samples
        print(f'Influence baseline model R{t}----------time: {t}---------base loss: {base_loss}---------l2 loss: {l2_norm_loss}---------total epoch loss: {loss_total}')
        
        loss_compute_type = 'mean'
        if going_type == 'dream':
            F_grad_grad = compute_epoch_hessian(total_loader, baseline_model, loss_fn, loss_type = loss_compute_type, num_samples=len(Ft_indices))
            baseline_model.zero_grad()
            if loss_compute_type == 'sum':
                f_grad = compute_grad(loss_total, list(baseline_model.parameters()))
                baseline_model.zero_grad()
                f_grad_grad = compute_hessian(loss_total, list(baseline_model.parameters()))
                baseline_model.zero_grad()
            else:
                f_grad = compute_grad(loss_mean, list(baseline_model.parameters()))
                baseline_model.zero_grad()
                f_grad_grad = compute_hessian(loss_mean, list(baseline_model.parameters()))
                baseline_model.zero_grad()
            inverse_H_hat = [(1 / (n - m)) * torch.inverse(F_gg - f_gg) for F_gg, f_gg in zip(F_grad_grad, f_grad_grad)]
            # torch.matmul takes two operands: apply the inverse Hessian block to the
            # flattened gradient, then restore the parameter's shape.
            # NOTE(review): assumes each Hessian block is (numel, numel) for its
            # parameter — TODO confirm against compute_hessian / compute_epoch_hessian.
            param_update = [
                (1 / (n - m)) * torch.matmul(ihh, fg.reshape(-1)).reshape(fg.shape)
                for ihh, fg in zip(inverse_H_hat, f_grad)
            ]
        else:
            if loss_compute_type == 'sum':
                f_grad = compute_grad(loss_total, list(baseline_model.parameters()))
            else:
                f_grad = compute_grad(loss_mean, list(baseline_model.parameters()))
            baseline_model.zero_grad()
            # hvp = compute_epoch_hessian_vector_product(baseline_model, total_loader, loss_fn, f_grad, loss_type = loss_compute_type, num_samples=int(0.3 * len(indices)))
            hvp = compute_epoch_hessian_vector_product(baseline_model, total_loader, loss_fn, f_grad, loss_type = loss_compute_type, num_samples=len(indices))
            param_update = [(1 / (n - m)) * grad for grad in hvp]

        # update model param
        update_model_parameters(baseline_model, param_update, hyper_param=1.0)

        # Optional: sample Gaussian Noise
        # noise = torch.normal(0, sigma, size=(d,))

        end = time.time()

        # Wall-clock time of the loss + update step only; evaluation is excluded.
        efficiency = end - start
        
        # save current time optimal model
        if t % save_interval == 0:
            T_opt_models.append(copy.deepcopy(baseline_model).cpu())
            print(f'Optimal model of time {t} has saved.')

        # evaluate
        print(f'indices length: {len(indices)}, Ft_indices length:{len(Ft_indices)},  Rt_indices length: {len( Rt_indices)}, prev_F_t_1_indices length: {len(prev_F_t_1_indices)}')
        if len(Rt_indices) == 0:
            Acc_Rt = 0
        else:
            Rt_loader = DataLoader(Subset(train_dataset, Rt_indices), batch_size=batch_size)
            Acc_Rt = evaluate_accuracy(baseline_model, Rt_loader)
        Acc_Ft = evaluate_accuracy(baseline_model, Ft_loader)

        # Accuracy on everything forgotten in earlier rounds (none exists at t == 1).
        if t == 1:
            Acc_F_t_1 = 0.0
        else:
            Acc_F_t_1 = evaluate_accuracy(baseline_model, DataLoader(Subset(train_dataset, prev_F_t_1_indices), batch_size=batch_size, shuffle=False))
        
        prev_F_t_1_indices.extend(Ft_indices)

        if test_loader is None:
            Acc_test = float('nan')
        else:
            Acc_test = evaluate_accuracy(baseline_model, test_loader)

        print(f"Influence time {t}: Acc_Ft: {Acc_Ft:.4f}, Acc_Rt: {Acc_Rt:.4f}, Acc_F_t-1: {Acc_F_t_1:.4f}, Acc_test: {Acc_test:.4f}, time: {efficiency:.4f}")
    
    current_model_path = save_path + '-influence' + '.pth'
    torch.save(T_opt_models, current_model_path)
    return T_opt_models