import torch
import copy
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm


class EarlyStopping:
    """Signal when a monitored validation loss has stopped improving.

    The instance is called once per epoch with the current validation loss.
    Internally the loss is negated into a "score" (higher is better). A call
    counts as an improvement unless ``score < best_score + delta``.

    Attributes:
        patience: Number of consecutive non-improving calls after which
            ``early_stop`` is set to True.
        verbose: Stored as given; not read anywhere inside this class.
        delta: Margin added to the best score when testing for improvement.
        best_score: Best (negated) loss seen so far, or None before the
            first call.
        epochs_no_improve: Current count of consecutive non-improving calls.
        early_stop: True once ``epochs_no_improve`` has reached ``patience``.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        # Configuration.
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        # Running state.
        self.best_score = None
        self.epochs_no_improve = 0
        self.early_stop = False

    def __call__(self, val_loss):
        """Record one validation loss and update the stopping state."""
        candidate = -val_loss

        # The very first observation always becomes the baseline.
        if self.best_score is None:
            self.best_score = candidate
            return

        # Not good enough: bump the stall counter and maybe trip the flag.
        if candidate < self.best_score + self.delta:
            self.epochs_no_improve += 1
            if self.epochs_no_improve >= self.patience:
                self.early_stop = True
            return

        # Improvement: remember it and reset the stall counter.
        self.best_score = candidate
        self.epochs_no_improve = 0


def load_datasets(dataset_name, root='autodl-tmp/Sequential-Unlearnig-main/datas',
                  train_batch_size=128, test_batch_size=100, num_workers=2):
    """Build the train/test datasets and loaders for a named dataset.

    Args:
        dataset_name: One of ``'cifar10'``, ``'cifar100'``, ``'celeba'`` or
            ``'mini-fashion'`` (the latter loads torchvision's FashionMNIST
            resized to 32 pixels).
        root: Directory the datasets are stored in / downloaded to. The
            default is the path this function previously hard-coded.
        train_batch_size: Batch size of the (shuffled) training loader.
        test_batch_size: Batch size of the (unshuffled) test loader.
        num_workers: Worker process count used by both loaders.

    Returns:
        Tuple ``(train_dataset, train_loader, test_loader, num_classes)``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # Each branch only selects the pieces that differ between datasets; the
    # datasets themselves are constructed once, below.
    if dataset_name == 'cifar10':
        dataset_cls = torchvision.datasets.CIFAR10
        steps = [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
        train_kwargs, test_kwargs = {'train': True}, {'train': False}
        num_classes = 10
    elif dataset_name == 'cifar100':
        dataset_cls = torchvision.datasets.CIFAR100
        steps = [
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        ]
        train_kwargs, test_kwargs = {'train': True}, {'train': False}
        num_classes = 100
    elif dataset_name == 'celeba':
        dataset_cls = torchvision.datasets.CelebA
        steps = [
            transforms.CenterCrop(178),
            transforms.Resize(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        train_kwargs, test_kwargs = {'split': 'train'}, {'split': 'test'}
        # Assuming binary classification for CelebA attributes.
        # NOTE(review): no target_type is passed, so the target format is
        # whatever torchvision's CelebA default yields - confirm it matches
        # a 2-class setup.
        num_classes = 2
    elif dataset_name == 'mini-fashion':
        dataset_cls = torchvision.datasets.FashionMNIST
        steps = [
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
        train_kwargs, test_kwargs = {'train': True}, {'train': False}
        num_classes = 10
    else:
        raise ValueError("Unsupported dataset")

    # The same transform pipeline is shared by the train and test splits.
    transform = transforms.Compose(steps)
    train_dataset = dataset_cls(root=root, download=True, transform=transform, **train_kwargs)
    test_dataset = dataset_cls(root=root, download=True, transform=transform, **test_kwargs)

    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)
    return train_dataset, train_loader, test_loader, num_classes

def create_pseudo_labels(num_classes, batch_size):
    """Return a ``(batch_size, num_classes)`` tensor of uniform soft labels.

    Every entry equals ``1 / num_classes``, so each row sums to one.
    """
    all_ones = torch.ones(batch_size, num_classes)
    uniform = all_ones / num_classes
    return uniform

def evaluate_accuracy(model, dataloader):
    """Return the top-1 accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode for the duration of the call and its
    previous training/eval mode is restored afterwards, even if evaluation
    raises. Batches are moved to the device of the model's first parameter.

    Args:
        model: Module whose outputs are compared against the targets via an
            arg-max over dimension 1.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Fraction of correctly predicted samples as a float, or ``0.0`` when
        the dataloader yields no samples.
    """
    was_training = model.training
    device = next(model.parameters()).device
    correct = 0
    total = 0

    model.eval()
    try:
        with torch.no_grad():
            for inputs, targets in tqdm(dataloader, desc="Evaluation progress"):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
    finally:
        # Always hand the model back in the mode the caller left it in.
        model.train(mode=was_training)

    # An empty loader would otherwise divide by zero.
    if total == 0:
        return 0.0
    return correct / total


def compute_epoch_grad(loader, model, loss_fn, loader_type = None, pseudo_labels = None, flat_type = True):
    """Compute the gradient of the loss accumulated over a whole loader.

    Each batch contributes ``loss_fn(outputs, labels) * batch_size`` to one
    running total (presumably ``loss_fn`` returns a per-batch mean, making
    this a per-sample sum - confirm against the caller). The total is then
    differentiated once with ``compute_grad``. The result is NOT divided by
    the number of samples.

    Note that the autograd graph of every batch is kept alive until the
    final gradient call, so memory grows with the size of the loader.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches.
        model: Module to differentiate; it is put into eval mode.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar.
        loader_type: When None the loader's own targets are used as labels;
            any other value makes the loss use ``pseudo_labels`` instead.
        pseudo_labels: Labels used when ``loader_type`` is not None. Only the
            first ``outputs.size(0)`` rows are used for each batch.
        flat_type: Forwarded to ``compute_grad``.

    Returns:
        Tuple ``(grad, total_loss)`` where ``grad`` is whatever
        ``compute_grad`` returns and ``total_loss`` is the accumulated loss.

    Raises:
        ValueError: If ``loader_type`` is set but ``pseudo_labels`` is None.
    """
    use_pseudo = loader_type is not None
    if use_pseudo and pseudo_labels is None:
        raise ValueError("pseudo_labels must be provided when loader_type is not None.")

    # Follow the model's device rather than guessing from CUDA availability,
    # so a CPU model on a CUDA-capable machine still works.
    device = next(model.parameters()).device
    model.eval()
    total_loss = torch.tensor(0.0).to(device)

    if use_pseudo:
        # Labels must live on the same device as the model outputs.
        if torch.is_tensor(pseudo_labels):
            pseudo_labels = pseudo_labels.to(device)
        progress_desc = "Compute grad process"
    else:
        progress_desc = "Compute epoch grad process"

    for inputs, targets in tqdm(loader, desc=progress_desc):
        inputs = inputs.to(device)
        outputs = model(inputs)
        if use_pseudo:
            labels = pseudo_labels[:outputs.size(0)]
        else:
            labels = targets.to(device)
        total_loss += loss_fn(outputs, labels) * inputs.size(0)

    flat_grad = compute_grad(total_loss, model.parameters(), flat_type)
    return flat_grad, total_loss


def compute_epoch_hessian_vector_product(model, dataloader, loss_fn, v, num_samples=100, compute_flat_type = True, return_flat_type = True):
    """Hessian-vector product of the loss accumulated over some batches.

    Batches are consumed until at least ``num_samples`` samples have been
    seen. The limit is checked at batch boundaries, so whole batches are
    used and the actual count may exceed ``num_samples``. Each batch adds
    ``loss_fn(outputs, targets) * batch_size`` to a running total, which is
    then passed to ``compute_hessian_vector_product``. The result is NOT
    divided by the sample count.

    Args:
        model: Module to differentiate; it is put into eval mode.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar.
        v: Vector to multiply the Hessian with; forwarded unchanged.
        num_samples: Minimum number of samples to accumulate before stopping.
        compute_flat_type: Forwarded to ``compute_hessian_vector_product``.
        return_flat_type: Forwarded to ``compute_hessian_vector_product``.

    Returns:
        Whatever ``compute_hessian_vector_product`` returns.
    """
    model.eval()
    # Follow the model's device rather than guessing from CUDA availability,
    # so a CPU model on a CUDA-capable machine still works.
    device = next(model.parameters()).device

    count = 0
    total_loss = 0.0

    for inputs, targets in tqdm(dataloader, desc="Compute epoch hessian vector product process"):
        if count >= num_samples:
            break
        inputs, targets = inputs.to(device), targets.to(device)
        model.zero_grad()
        outputs = model(inputs)
        total_loss += loss_fn(outputs, targets) * inputs.size(0)
        count += inputs.size(0)

    return compute_hessian_vector_product(total_loss, model.parameters(), v, compute_flat_type, return_flat_type)


def compute_epoch_hessian(model, dataloader, loss_fn, num_samples=100):
    """Hessian of the loss accumulated over some batches.

    Batches are consumed until at least ``num_samples`` samples have been
    seen. The limit is checked at batch boundaries, so whole batches are
    used and the actual count may exceed ``num_samples``. Each batch adds
    ``loss_fn(outputs, targets) * batch_size`` to a running total, which is
    then passed to ``compute_hessian``. The result is NOT divided by the
    sample count.

    NOTE: ``model.parameters()`` (an iterator over parameter tensors) is
    handed to ``compute_hessian``, so that function must accept an iterable
    of parameters rather than only a single tensor.

    Args:
        model: Module to differentiate; it is put into eval mode.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar.
        num_samples: Minimum number of samples to accumulate before stopping.

    Returns:
        Whatever ``compute_hessian`` returns.
    """
    model.eval()
    # Follow the model's device rather than guessing from CUDA availability,
    # so a CPU model on a CUDA-capable machine still works.
    device = next(model.parameters()).device

    count = 0
    total_loss = 0.0

    for inputs, targets in tqdm(dataloader, desc="Compute epoch hessian process"):
        if count >= num_samples:
            break
        inputs, targets = inputs.to(device), targets.to(device)
        model.zero_grad()
        outputs = model(inputs)
        total_loss += loss_fn(outputs, targets) * inputs.size(0)
        count += inputs.size(0)

    hessian = compute_hessian(total_loss, model.parameters())
    return hessian


def compute_grad(loss, param, flat_type = True):
    """Differentiate ``loss`` with respect to ``param``.

    The autograd graph is retained so the same loss can be differentiated
    again. Parameters that did not take part in the loss (gradient is None)
    are silently dropped, so the result may cover fewer tensors than
    ``param`` contains.

    Args:
        loss: Tensor to differentiate.
        param: Tensor or iterable of tensors to differentiate against.
        flat_type: If truthy, return one 1-D tensor with all gradients
            concatenated; otherwise return a list of per-parameter gradients.

    Returns:
        A flat 1-D tensor or a list of gradient tensors, see ``flat_type``.
    """
    raw_grads = torch.autograd.grad(loss, param, allow_unused=True, retain_graph=True)

    # Keep only the gradients autograd actually produced.
    present = []
    for grad in raw_grads:
        if grad is not None:
            present.append(grad)

    if not flat_type:
        return present

    flattened = [grad.view(-1) for grad in present]
    return torch.cat(flattened)


def compute_hessian_vector_product(loss, param, v, compute_flat_type=True, return_flat_type=True):
    """Compute the Hessian-vector product ``H @ v`` by double back-propagation.

    ``H`` is the Hessian of ``loss`` with respect to all tensors in ``param``
    (taken in iteration order and flattened). The product is obtained without
    materialising ``H``: the first gradient is contracted with ``v`` into a
    scalar, which is then differentiated again with respect to the original
    parameters.

    The second backward pass does not retain the graph, so the same ``loss``
    tensor cannot be differentiated again after this call.

    Args:
        loss: Scalar (0-dim) tensor to differentiate.
        param: A tensor or an iterable of tensors that ``loss`` was computed
            from (e.g. ``model.parameters()``).
        v: The vector to multiply with. If ``compute_flat_type`` is true it is
            a tensor (or sequence of tensors) holding as many elements as all
            parameters together; otherwise it is a sequence with one tensor
            per parameter, each multiplied element-wise with that parameter's
            gradient.
        compute_flat_type: Selects how ``v`` is interpreted, see above.
        return_flat_type: If true return a single 1-D tensor, otherwise a
            list with one tensor per parameter, shaped like the parameter.

    Returns:
        The Hessian-vector product, flat or per-parameter. Parameters the
        loss does not depend on contribute zeros.

    Raises:
        ValueError: If ``loss`` is not a scalar tensor or ``v`` does not
            match the parameters in size.
        RuntimeError: If the loss depends on none of the parameters.
    """
    if not torch.is_tensor(loss) or loss.dim() != 0:
        raise ValueError("loss must be a scalar tensor.")

    # Materialise once: `param` may be a one-shot generator.
    params = [param] if torch.is_tensor(param) else list(param)

    # First backward pass; keep the graph so it can be differentiated again.
    first_grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
    if all(g is None for g in first_grads):
        raise RuntimeError("No valid gradients were computed. Check the dependency of the loss on the parameters.")
    # A parameter that does not influence the loss has a zero gradient.
    first_grads = [torch.zeros_like(p) if g is None else g for p, g in zip(params, first_grads)]

    # Contract the gradient with v into the scalar g^T v.
    if compute_flat_type:
        flat_grad = torch.cat([g.reshape(-1) for g in first_grads])
        if torch.is_tensor(v):
            flat_v = v.reshape(-1)
        else:
            flat_v = torch.cat([part.reshape(-1) for part in v])
        if flat_v.numel() != flat_grad.numel():
            raise ValueError(
                f"v has {flat_v.numel()} elements but the parameters have {flat_grad.numel()}."
            )
        grad_dot_v = (flat_grad * flat_v).sum()
    else:
        v_parts = list(v)
        if len(v_parts) != len(first_grads):
            raise ValueError(
                f"v has {len(v_parts)} entries but there are {len(first_grads)} parameters."
            )
        grad_dot_v = sum((g * v_i).sum() for g, v_i in zip(first_grads, v_parts))

    if torch.is_tensor(grad_dot_v) and grad_dot_v.requires_grad:
        # Second backward pass: d(g^T v)/d(params) == H v.
        second = torch.autograd.grad(grad_dot_v, params, allow_unused=True)
        hessian_v = [torch.zeros_like(p) if h is None else h for p, h in zip(params, second)]
    else:
        # The gradient does not depend on the parameters (e.g. a loss that is
        # linear in them), so the Hessian is identically zero.
        hessian_v = [torch.zeros_like(p) for p in params]

    if return_flat_type:
        hessian_v = torch.cat([h.contiguous().view(-1) for h in hessian_v])

    return hessian_v


def compute_hessian(loss, param):
    """Compute the full Hessian matrix of ``loss`` with respect to ``param``.

    The Hessian is built one row at a time by differentiating each element
    of the first gradient, so the cost is one backward pass per parameter
    element and the result has ``n * n`` entries. This is only practical for
    small ``n``.

    Args:
        loss: Scalar tensor for which the Hessian is calculated.
        param: A tensor of any shape, or an iterable of tensors (e.g.
            ``model.parameters()``), with respect to which the Hessian is
            calculated. Multiple tensors are flattened and concatenated in
            iteration order.

    Returns:
        An ``(n, n)`` tensor, where ``n`` is the total number of elements in
        ``param``, using the dtype and device of the first parameter. Rows
        and columns follow the flattened parameter order. Parameters that
        ``loss`` does not depend on contribute zero rows and columns.

    Raises:
        ValueError: If ``param`` is empty or any tensor in it does not
            require gradients.
    """
    # Materialise once: `param` may be a one-shot generator.
    params = [param] if torch.is_tensor(param) else list(param)
    if not params:
        raise ValueError("`param` must contain at least one tensor.")
    if not all(p.requires_grad for p in params):
        raise ValueError("Input tensor `param` must require gradients.")

    n = sum(p.nelement() for p in params)  # Total number of elements
    hessian = torch.zeros((n, n), dtype=params[0].dtype, device=params[0].device)

    first_grads = torch.autograd.grad(
        loss, params, create_graph=True, retain_graph=True, allow_unused=True
    )
    # Flatten so that every entry indexed below is a scalar, whatever the
    # original parameter shapes are.
    flat_grad = torch.cat([
        torch.zeros_like(p).reshape(-1) if g is None else g.reshape(-1)
        for p, g in zip(params, first_grads)
    ])

    # If the gradient carries no graph (e.g. loss is linear in the
    # parameters) every second derivative is zero.
    if not flat_grad.requires_grad:
        return hessian

    for i in range(n):
        row = torch.autograd.grad(flat_grad[i], params, retain_graph=True, allow_unused=True)
        hessian[i] = torch.cat([
            torch.zeros_like(p).reshape(-1) if r is None else r.reshape(-1)
            for p, r in zip(params, row)
        ])

    return hessian


def clone_and_freeze_model(model):
    """
    Clone a model and freeze all of the clone's parameters.

    Args:
    - model (torch.nn.Module): The model to clone and freeze.

    Returns:
    - torch.nn.Module: A deep copy of ``model`` whose parameters all have
      ``requires_grad`` set to False. The original model is left untouched.
    """
    # A deep copy gives the clone its own, independent parameter tensors.
    frozen = copy.deepcopy(model)

    # Switch off gradient tracking for every parameter of the clone.
    for weight in frozen.parameters():
        weight.requires_grad_(False)

    return frozen


def split_indices(indices, T, split_type, num_classes=None, train_dataset=None, cover_flag=False, max_per=None):
    """Shuffle ``indices`` and partition them into ``T`` subsets.

    NOTE: ``indices`` is shuffled IN PLACE with ``np.random.shuffle``, so the
    caller's array/list is reordered as a side effect.

    Args:
        indices: Sequence of dataset indices (shuffled in place).
        T: Number of subsets to produce.
        split_type: Partitioning strategy:
            * ``'uniform'`` - ``np.array_split`` into ``T`` near-equal parts.
            * ``'random'`` - ``T`` contiguous slices at random cut points
              (slices may be empty when cut points coincide).
            * ``'fixed'`` - ``T`` subsets of exactly ``max_per`` indices;
              surplus indices are dropped.
            * ``'class'`` - all kept indices of one class go to the same
              subset; classes are assigned to subsets in class order.
        num_classes: Number of classes; required for ``'class'``.
        train_dataset: Indexable whose items are ``(sample, label)`` pairs;
            required for ``'class'``.
        cover_flag: For ``'class'`` only. If true, consecutive groups of
            ``num_classes // T`` classes (at least one) share a subset;
            otherwise classes are dealt out one per subset, round-robin.
        max_per: Subset size for ``'fixed'`` (required there); for
            ``'class'`` an optional cap on kept indices per class.

    Returns:
        Tuple ``(indices, subsets)``. ``indices`` is the shuffled input
        (truncated to ``max_per * T`` entries for ``'fixed'``) and
        ``subsets`` is a list of ``T`` index collections.

    Raises:
        ValueError: On an unknown ``split_type`` or missing/insufficient
            arguments for the chosen strategy.
    """
    np.random.shuffle(indices)
    n = len(indices)

    if split_type == 'uniform':
        # Split the indices evenly.
        return indices, np.array_split(indices, T)

    if split_type == 'random':
        # Pick T - 1 random cut points and slice between consecutive ones.
        cut_points = np.sort(np.random.randint(1, n, T - 1))
        bounds = np.concatenate(([0], cut_points, [n]))
        return indices, [indices[bounds[i]:bounds[i + 1]] for i in range(T)]

    if split_type == 'fixed':
        if max_per is None:
            raise ValueError("Given a max_per to specify the number of each group.")
        if n < max_per * T:
            raise ValueError("Not enough data to fill all subsets with the specified max_per size.")
        indices = indices[:max_per * T]  # Truncate to make it divisible
        subsets = [indices[i * max_per:(i + 1) * max_per] for i in range(T)]
        return indices, subsets

    if split_type == 'class':
        if num_classes is None or train_dataset is None:
            raise ValueError("num_classes and train_dataset must be provided for class based split type.")
        # Bucket the indices by class label, optionally capped at max_per.
        class_indices = {i: [] for i in range(num_classes)}
        for idx in indices:
            # The second item of each dataset entry is assumed to be the label.
            _, label = train_dataset[idx]
            # int() lets 0-dim tensor / NumPy scalar labels hit the int keys.
            bucket = class_indices[int(label)]
            if max_per is None or len(bucket) < max_per:
                bucket.append(idx)

        # Group classes into subsets. Clamp to 1 so that T > num_classes does
        # not lead to a division by zero below.
        classes_per_subset = max(1, num_classes // T) if cover_flag else 1
        subsets = [[] for _ in range(T)]
        for class_idx in range(num_classes):
            subset_idx = (class_idx // classes_per_subset) % T
            subsets[subset_idx].extend(class_indices[class_idx])
        return indices, [np.array(subset) for subset in subsets]

    raise ValueError("Unknown split_type provided.")