import torch
import copy
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Subset, TensorDataset
from collections import defaultdict
from tqdm import tqdm
import sys
import os
sys.path.append(os.path.abspath('autodl-tmp/Sequential-Unlearnig-main'))
from utils import EarlyStopping, evaluate_accuracy, clone_and_freeze_model
from models import CustomModel


def split_dataset_by_class(dataset, num_classes_per_subset):
    """Partition *dataset* into Subsets that each cover ``num_classes_per_subset`` classes.

    Classes are grouped in the order in which they are first encountered while
    iterating the dataset. The last subset may hold fewer classes when the number
    of classes is not a multiple of ``num_classes_per_subset``.

    Args:
        dataset: an iterable/indexable dataset whose items are ``(input, target)``
            pairs; ``target`` may be a hashable label or a scalar tensor.
        num_classes_per_subset: number of distinct classes per subset (>= 1).

    Returns:
        A list of ``torch.utils.data.Subset`` objects over *dataset*.

    Raises:
        ValueError: if ``num_classes_per_subset`` is smaller than 1.
    """
    if num_classes_per_subset < 1:
        raise ValueError(
            f"num_classes_per_subset must be >= 1, got {num_classes_per_subset}"
        )

    class_indices = defaultdict(list)
    for idx, (_, target) in enumerate(dataset):
        # Tensors hash by identity, so a scalar-tensor label would otherwise put
        # every sample into its own "class"; normalise it to a Python scalar.
        if torch.is_tensor(target):
            target = target.item()
        class_indices[target].append(idx)

    subsets = []
    temp = []
    classes_in_temp = 0
    for indices in class_indices.values():
        temp.extend(indices)
        classes_in_temp += 1
        # Close the group once it holds the requested number of *classes*
        # (not samples).
        if classes_in_temp == num_classes_per_subset:
            subsets.append(Subset(dataset, temp))
            temp, classes_in_temp = [], 0
    if temp:
        subsets.append(Subset(dataset, temp))  # Add remaining classes if any
    return subsets

def aggregate_model_outputs(models, batch_inputs):
    """Average the outputs of several models on the same batch (simple ensemble).

    Gradients are disabled while the models run.

    Args:
        models: non-empty sequence of callables (e.g. ``nn.Module``s) that map
            ``batch_inputs`` to output tensors of identical shape.
        batch_inputs: the batch passed unchanged to every model.

    Returns:
        The element-wise mean of the model outputs, on the same device as the
        outputs.
    """
    with torch.no_grad():
        outputs = [model(batch_inputs) for model in models]
        # Stack along a new leading "model" axis and average it away. The result
        # is already a tensor, so no torch.tensor() copy-construction (which
        # warns and copies) is needed.
        return torch.stack(outputs).mean(dim=0)

def remove_sample(subsets, indices):
    """Delete dataset *indices* from the first subset that contains any of them.

    The matching subset is replaced in place within *subsets* by a new ``Subset``
    over the same underlying dataset that excludes the requested indices. This
    happens even when nothing is left, so removed samples can never linger. Only
    the first matching subset is modified; indices that live in later subsets
    are left untouched.

    Args:
        subsets: mutable list of ``torch.utils.data.Subset`` objects (modified in place).
        indices: iterable of integer dataset indices (ints, numpy ints or tensor
            elements) to remove.

    Returns:
        The position in *subsets* of the modified subset, or ``False`` if no
        subset contains any of *indices*. Because ``False == 0``, callers must
        test the result with ``is False`` rather than by truthiness or equality.
    """
    # int() normalises numpy/tensor scalars; tensor elements would otherwise
    # hash by identity and never match.
    idx_to_remove = {int(i) for i in indices}
    for index, subset in enumerate(subsets):
        # Check whether this subset holds any of the indices to delete.
        if idx_to_remove.isdisjoint(subset.indices):
            continue
        # Filter out the indices to delete and always replace the subset: keeping
        # the old one when everything is removed would leak "forgotten" samples.
        filtered_indices = [idx for idx in subset.indices if idx not in idx_to_remove]
        subsets[index] = Subset(subset.dataset, filtered_indices)
        print(f'The subset where the deleted index is located at: subset-{index}, index has been deleted.')
        return index
    return False


def _train_sub_model(model, loader, eta, epochs, device, use_early_stopping):
    """Train one shard model in place with SGD and cross-entropy, and return it.

    Args:
        model: the ``nn.Module`` to train, already on *device*.
        loader: ``DataLoader`` yielding ``(inputs, targets)`` batches.
        eta: SGD learning rate (momentum 0.9, weight decay 5e-4).
        epochs: maximum number of epochs.
        device: device each batch is moved to.
        use_early_stopping: when true, stop once ``EarlyStopping`` (patience 10,
            delta 0.01) fires on the summed epoch loss.

    Returns:
        The same *model* object, left in train mode.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=eta, momentum=0.9, weight_decay=5e-4)
    model.train()

    stopper = EarlyStopping(patience=10, verbose=True, delta=0.01) if use_early_stopping else None

    for epoch in range(epochs):
        total_epoch_loss = torch.tensor(0.0, device=device)
        samples = 0
        for inputs, targets in tqdm(loader, desc="Training progress"):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            # Reuse the computed loss and detach it so the running sum does not
            # keep an autograd graph alive across batches.
            total_epoch_loss += loss.detach() * inputs.size(0)
            samples += inputs.size(0)
        if samples == 0:
            print('SISA train sub-model: empty loader, nothing to train.')
            break
        mean_epoch_loss = total_epoch_loss / samples
        print(f'SISA train sub-model----------epoch: {epoch}---------total epoch loss: {total_epoch_loss}---------mean epoch loss: {mean_epoch_loss}')

        if stopper is not None:
            stopper(total_epoch_loss.item())
            if stopper.early_stop:
                print(f'Early stopping at epoch: {epoch}')
                break
    return model


def _evaluate_ensemble_accuracy(models, loader):
    """Return the top-1 accuracy of the averaged ensemble of *models* on *loader*.

    Returns 0.0 (instead of raising) when there are no models or the loader
    yields no samples.
    """
    if not models:
        return 0.0
    device = next(models[0].parameters()).device
    # Switch to inference mode first. In train mode, layers such as BatchNorm
    # or Dropout (if CustomModel contains them) would behave differently, and
    # BatchNorm would update its running stats with evaluation data.
    for model in models:
        model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in tqdm(loader, desc="Evaluation progress"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = aggregate_model_outputs(models, inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    return correct / total if total else 0.0


def _drop_indices_from_subsets(subsets, forget_indices):
    """Remove *forget_indices* from EVERY subset in place.

    Returns:
        The positions of the subsets that changed.
    """
    # int() normalises numpy/tensor scalars so they compare equal to the
    # plain-int indices stored in the subsets.
    forget = {int(i) for i in forget_indices}
    affected = []
    for pos, subset in enumerate(subsets):
        if forget.isdisjoint(subset.indices):
            continue
        subsets[pos] = Subset(subset.dataset, [i for i in subset.indices if i not in forget])
        affected.append(pos)
        print(f'The subset where the deleted index is located at: subset-{pos}, index has been deleted.')
    return affected


def sisa_baseline(train_dataset, indices, subset_indexs, T, eta, batch_size,
                  epochs, num_classes, model_type, pretrained_tag=True, early_stopping=True,
                  classes_per_subset=2):
    """Run a SISA-style sequential-unlearning baseline and print per-step accuracies.

    The training set is sharded by class (``classes_per_subset`` classes per shard)
    and one ``CustomModel`` is trained per shard. At each step ``t = 1..T``, the
    indices ``subset_indexs[t-1]`` are removed from every shard that contains
    them, and each affected shard's model is retrained from scratch. Then the
    ensemble (mean of model outputs) is evaluated on:
      * ``Ft``: the samples forgotten at step t,
      * ``Rt``: ``indices`` minus everything forgotten so far,
      * ``F_t-1``: the samples forgotten at earlier steps.

    Args:
        train_dataset: dataset of ``(input, target)`` pairs.
        indices: all dataset indices considered for the retained-set metric.
        subset_indexs: sequence with at least ``T`` entries; entry ``t-1`` holds
            the indices to forget at step ``t``.
        T: number of unlearning steps.
        eta: SGD learning rate.
        batch_size: batch size for training and evaluation loaders.
        epochs: maximum training epochs per shard model.
        num_classes, model_type, pretrained_tag: forwarded to ``CustomModel``.
        early_stopping: enable ``EarlyStopping`` during shard training.
        classes_per_subset: number of classes per shard (default 2, as before).

    Returns:
        0 (results are only printed).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    prev_F_t_1_indices = []

    subsets = split_dataset_by_class(train_dataset, classes_per_subset)
    for i, subset in enumerate(subsets):
        print(f'{i}-th subset length: {len(subset)}')

    def _build_sub_model():
        return CustomModel(model_name=model_type, num_classes=num_classes,
                           pretrained=pretrained_tag).to(device)

    # original train sub-model
    models = []
    for index, subset in enumerate(subsets):
        subset_loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
        print(f'-----------------------------Begin train {index}-model-----------------------------')
        trained_model = _train_sub_model(_build_sub_model(), subset_loader, eta, epochs, device, early_stopping)
        print(f'-----------------------------{index}-model training complete-----------------------------')
        models.append(trained_model)

    for t in range(1, T + 1):
        Ft_indices = subset_indexs[t - 1]
        Rt_indices = np.setdiff1d(indices, [*prev_F_t_1_indices, *Ft_indices])

        # Every shard that held a forgotten sample must be retrained, not just
        # the first one. Otherwise the other shards' models still encode the
        # forgotten data.
        affected = _drop_indices_from_subsets(subsets, Ft_indices)
        if not affected:
            print(f'Time {t}: no subset contains the indices to forget; no sub-model retrained.')
        for index in affected:
            if len(subsets[index]) == 0:
                # A fully-forgotten shard has nothing to learn from. Drop its
                # model from the ensemble (and RandomSampler rejects empty
                # datasets anyway).
                print(f'subset-{index} is now empty; its sub-model is removed from the ensemble.')
                models[index] = None
                continue
            new_subset_loader = DataLoader(subsets[index], batch_size=batch_size, shuffle=True)
            models[index] = _train_sub_model(_build_sub_model(), new_subset_loader,
                                             eta, epochs, device, early_stopping)

        # evaluate current models
        active_models = [m for m in models if m is not None]
        print(f'indices length: {len(indices)}, Ft_indices length:{len(Ft_indices)},  Rt_indices length: {len( Rt_indices)}, prev_F_t_1_indices length: {len(prev_F_t_1_indices)}')
        Ft_loader = DataLoader(Subset(train_dataset, Ft_indices), batch_size=batch_size, shuffle=False)
        Acc_Ft = _evaluate_ensemble_accuracy(active_models, Ft_loader)

        if len(Rt_indices) == 0:
            Acc_Rt = 0
        else:
            Rt_loader = DataLoader(Subset(train_dataset, Rt_indices), batch_size=batch_size)
            Acc_Rt = _evaluate_ensemble_accuracy(active_models, Rt_loader)

        if t == 1:
            Acc_F_t_1 = 0.0
        else:
            prev_loader = DataLoader(Subset(train_dataset, prev_F_t_1_indices),
                                     batch_size=batch_size, shuffle=False)
            Acc_F_t_1 = _evaluate_ensemble_accuracy(active_models, prev_loader)

        prev_F_t_1_indices.extend(Ft_indices)

        print(f"Time {t}: Acc_Ft: {Acc_Ft:.4f}, Acc_Rt: {Acc_Rt:.4f}, Acc_F_t-1: {Acc_F_t_1:.4f}")

    return 0