import torch
import copy
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Subset
from multiprocessing import Process, Manager
import multiprocessing as mp
import sys
import os

sys.path.append(os.path.abspath('fxh/seq_unlearn'))
from utils import EarlyStopping, split_indices
from models import CustomModel
import time


def select_dataset(num_model, train_dataset, subsets_index):
    """Wrap the first ``num_model`` index collections as ``Subset`` views.

    Args:
        num_model: Number of subsets to build.
        train_dataset: Dataset handed to every ``Subset``.
        subsets_index: Indexable collection; entry ``i`` holds the sample
            indices that make up the ``i``-th subset.

    Returns:
        A list with ``num_model`` ``Subset`` objects. The length of each
        subset is printed as it is created.
    """
    shards = []
    for shard_id in range(num_model):
        shard = Subset(train_dataset, subsets_index[shard_id])
        shards.append(shard)
        print(f'{shard_id}-th subset length: {len(shard)}')
    return shards


def train_model(model, loader, mask, epochs, lr, device, model_id, results, early_stopping_tag=True):
    """Train ``model`` on ``loader`` (unless masked) and store its weights.

    Args:
        model: ``nn.Module`` to optimise. It is moved to the CPU before its
            state dict is stored.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        mask: When truthy, training is skipped entirely and the current
            weights are stored unchanged.
        epochs: Maximum number of passes over ``loader``.
        lr: Learning rate for SGD (momentum 0.9, weight decay 5e-4).
        device: Device every batch is moved to before the forward pass.
        model_id: Key under which the state dict is written to ``results``.
        results: Mutable mapping that receives ``model.state_dict()``.
        early_stopping_tag: When true, an ``EarlyStopping`` monitor
            (patience 5, delta 0.01) is fed the mean epoch loss and may end
            training early.

    Returns:
        None. The outcome is published through ``results[model_id]``.
    """
    if not mask:
        loss_fn = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        early_stopping = (
            EarlyStopping(patience=5, verbose=True, delta=0.01) if early_stopping_tag else None
        )

        model.train()
        for epoch in range(epochs):
            total_epoch_loss = 0.0
            samples = 0
            for inputs, targets in loader:
                # Single-sample batches are skipped (presumably because
                # batch-norm layers cannot train on one sample - confirm).
                if inputs.shape[0] == 1:
                    continue
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_fn(outputs, targets)
                loss.backward()
                optimizer.step()

                batch_samples = inputs.size(0)
                # Detach before accumulating so the running total does not
                # keep the autograd history of every batch alive.
                total_epoch_loss += loss.detach() * batch_samples
                samples += batch_samples

            if samples == 0:
                # Empty loader or only single-sample batches: nothing was
                # trained, and further epochs would not change that.
                print(f'Model {model_id}: Epoch {epoch} - no trainable batches, stopping training')
                break

            mean_epoch_loss = float(total_epoch_loss) / samples
            print(f'Model {model_id}: Epoch {epoch} - Loss: {mean_epoch_loss}')

            if early_stopping is not None:
                early_stopping(mean_epoch_loss)
                if early_stopping.early_stop:
                    print(f'Early stopping at epoch: {epoch}')
                    break
    model.cpu()
    results[model_id] = model.state_dict()


def parallel_train_model(sub_models, subsets, masks, epochs, lr, batch_size, device, early_stopping_tag=True):
    """Train every sub-model in its own process and collect the state dicts.

    Args:
        sub_models: Models, one per entry of ``subsets``.
        subsets: Datasets; ``subsets[i]`` is used to train ``sub_models[i]``.
        masks: Per-model flags forwarded to ``train_model``; a truthy value
            skips training for that model.
        epochs: Maximum number of epochs per model.
        lr: Learning rate forwarded to ``train_model``.
        batch_size: Batch size of each training ``DataLoader``.
        device: Device each model is moved to before its process starts.
        early_stopping_tag: Forwarded to ``train_model``.

    Returns:
        A plain ``dict`` mapping model index to state dict, ordered by model
        index. A worker that died before publishing its result leaves its
        index absent.

    NOTE(review): models are moved to ``device`` in the parent process before
    the child is started; with a CUDA device this likely requires the
    'spawn' start method - confirm how callers configure multiprocessing.
    """
    processes = []
    # The context manager shuts the manager's server process down once the
    # results have been copied out.
    with Manager() as manager:
        shared_results = manager.dict()

        for model_id, subset in enumerate(subsets):
            sub_model = sub_models[model_id].to(device)
            subset_loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
            p = Process(target=train_model, args=(sub_model, subset_loader, masks[model_id],
                                                  epochs, lr, device, model_id, shared_results,
                                                  early_stopping_tag))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # Workers publish in completion order; rebuild the mapping in model
        # order so positional consumers pair each state with the right model.
        return {model_id: shared_results[model_id] for model_id in sorted(shared_results.keys())}


def sequence_train_model(sub_models, subsets, masks, epochs, lr, batch_size, device, early_stopping_tag=True):
    """Train the sub-models one after another in the current process.

    Args:
        sub_models: Models, one per entry of ``subsets``.
        subsets: Datasets; ``subsets[i]`` is used to train ``sub_models[i]``.
        masks: Per-model flags forwarded to ``train_model``.
        epochs: Maximum number of epochs per model.
        lr: Learning rate forwarded to ``train_model``.
        batch_size: Batch size of each shuffled training ``DataLoader``.
        device: Device each model is moved to before training.
        early_stopping_tag: Forwarded to ``train_model``.

    Returns:
        A ``dict`` mapping model index to the state dict that
        ``train_model`` stored for it.
    """
    trained_states = {}

    for model_id, shard in enumerate(subsets):
        current_model = sub_models[model_id].to(device)
        shard_loader = DataLoader(shard, batch_size=batch_size, shuffle=True)
        train_model(
            current_model,
            shard_loader,
            masks[model_id],
            epochs,
            lr,
            device,
            model_id,
            trained_states,
            early_stopping_tag,
        )

    return trained_states


def flatten(nested_list):
    """Return a new flat list with every nested ``list`` expanded in order.

    Only ``list`` instances are expanded; any other element (tuples
    included) is copied over unchanged.
    """
    collected = []
    for item in nested_list:
        collected += flatten(item) if isinstance(item, list) else [item]
    return collected


def remove_sample(subsets, Ft_indices):
    """
    Drop every element listed in `Ft_indices` from each subset, in place.

    Args:
    subsets (list of lists): Subsets to filter. A subset that contains nested lists is flattened first.
        The list is modified in place: each entry is replaced by its filtered version.
    Ft_indices (list): Elements to remove from every subset.

    Returns:
    tuple: A tuple containing three elements:
        - The same `subsets` list, now holding the filtered subsets (list of lists).
        - Number of elements removed from each subset (list).
        - The removed elements of each subset, in their original order (list of lists).

    Description:
    For every subset the function prints how many of its elements occur in `Ft_indices`, records those
    elements, and keeps only the remaining ones. Membership is tested against a set built once from
    `Ft_indices`.
    """
    forget_set = set(Ft_indices)
    removed_per_subset = []
    removed_counts = []

    for position, subset in enumerate(subsets):
        # Only subsets that actually hold nested lists go through flatten().
        members = flatten(subset) if any(isinstance(entry, list) for entry in subset) else subset

        removed = []
        kept = []
        for value in members:
            if value in forget_set:
                removed.append(value)
            else:
                kept.append(value)

        removed_per_subset.append(removed)
        removed_counts.append(len(removed))

        print(f"subsets[{position}] contains {len(removed)} elements from Ft_indices.")

        subsets[position] = kept

    return subsets, removed_counts, removed_per_subset


def load_param(sub_models_states, sub_models):
    """Load each stored state dict into its corresponding model.

    Args:
        sub_models_states: Mapping of state dicts. When it contains the keys
            ``0 .. len(sub_models) - 1``, the state under key ``i`` is loaded
            into ``sub_models[i]``. Otherwise the states are paired with the
            models in the mapping's iteration order.
        sub_models: Models that receive the parameters.

    Raises:
        ValueError: If the number of states and models differ.
        TypeError: If a stored state is not a ``dict`` instance.
    """
    if len(sub_models_states) != len(sub_models):
        raise ValueError("The number of state dictionaries and models must match.")

    positions = range(len(sub_models))
    if all(position in sub_models_states for position in positions):
        # Pair by model index so the mapping's insertion order (e.g. the order
        # in which parallel workers finished) cannot mix up the models.
        ordered_states = [sub_models_states[position] for position in positions]
    else:
        ordered_states = list(sub_models_states.values())

    for model_state, model in zip(ordered_states, sub_models):
        if not isinstance(model_state, dict):
            raise TypeError(
                f"Expected state_dict to be dict-like, got {type(model_state)}. Ensure all elements in sub_models_states are state dicts.")

        model.load_state_dict(model_state)
        print("Model parameters updated successfully.")


def aggregate_model_outputs(models, batch_inputs):
    """Average the outputs of several models on the same batch.

    Args:
        models: Non-empty sequence of callables (models); each is applied to
            ``batch_inputs`` and must produce tensors of one common shape.
        batch_inputs: Batch passed unchanged to every model.

    Returns:
        The element-wise mean of the model outputs, computed under
        ``torch.no_grad()`` so it carries no autograd history.
    """
    with torch.no_grad():
        stacked_outputs = torch.stack([model(batch_inputs) for model in models])
        # The mean is already a fresh, gradient-free tensor; re-wrapping it in
        # torch.tensor() would only copy it again and emit a UserWarning.
        return stacked_outputs.mean(dim=0)


def evaluate_models_accuracy(models, loader):
    """Compute the accuracy of the averaged ensemble prediction on ``loader``.

    Args:
        models: Non-empty sequence of ``nn.Module`` objects. Batches are moved
            to the device of the first model's first parameter.
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Fraction of samples whose arg-max of the averaged outputs equals the
        target, or ``0.0`` when ``loader`` yields no samples.
    """
    correct = 0
    total = 0
    device = next(models[0].parameters()).device

    # Evaluate in eval mode (dropout / batch-norm layers behave differently
    # while training) and restore each model's previous mode afterwards.
    previous_modes = [model.training for model in models]
    for model in models:
        model.eval()

    try:
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = aggregate_model_outputs(models, inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
    finally:
        for model, was_training in zip(models, previous_modes):
            model.train(was_training)

    if total == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0
    return correct / total


def sisa_baseline(train_dataset, indices, subset_indexs, T, eta, batch_size, epochs, num_classes,
                  model_type, split_type='uniform', num_model=5, pretrained_tag=True, early_stopping=True,
                  device="cuda:0", test_loader=None):
    """Run the sharded-training baseline over ``T`` sequential unlearning requests.

    The training indices are split into ``num_model`` shards, one sub-model is
    trained per shard, and for every request the forget samples are removed
    from the shards and the sub-models whose shard lost samples are trained
    again. Accuracy and elapsed time are printed after every request.

    Args:
        train_dataset: Dataset indexed by the values in ``indices``.
        indices: Indices of all training samples.
        subset_indexs: Sequence whose entry ``t - 1`` holds the indices to
            forget at request ``t``.
        T: Number of unlearning requests to process.
        eta: Learning rate.
        batch_size: Batch size for training and evaluation loaders.
        epochs: Maximum number of epochs per (re)training run.
        num_classes: Passed to ``CustomModel``.
        model_type: Passed to ``CustomModel`` as ``model_name``.
        split_type: Split strategy forwarded to ``split_indices``.
        num_model: Number of shards / sub-models.
        pretrained_tag: Passed to ``CustomModel`` as ``pretrained``.
        early_stopping: Enables early stopping during training.
        device: Device used for model construction and training.
        test_loader: Optional loader for test accuracy; when ``None`` the test
            accuracy is reported as ``nan``.

    Returns:
        Always ``0``; results are reported through printed output only.

    NOTE(review): sub-models whose shard lost samples continue training from
    their current weights rather than being re-initialised - confirm this is
    the intended baseline behaviour.
    """
    # setting up multi-process parallelism
    # mp.set_start_method('spawn')

    prev_F_t_1_indices = []

    # split dataset
    _, subsets_index = split_indices(indices, num_model, split_type)
    subsets = select_dataset(num_model, train_dataset, subsets_index)

    # initial sub-models: every shard needs its own copy. `[model] * n` would
    # repeat one shared object, so all shards would train the same weights.
    base_model = CustomModel(model_name=model_type, num_classes=num_classes, pretrained=pretrained_tag).to(device)
    sub_models = [copy.deepcopy(base_model) for _ in range(num_model)]
    masks = [False] * num_model
    sub_models_states = sequence_train_model(sub_models, subsets, masks, epochs, eta, batch_size, device,
                                             early_stopping)
    load_param(sub_models_states, sub_models)

    for t in range(1, T + 1):

        start = time.time()

        # get unlearning samples id
        Ft_indices = subset_indexs[t - 1]
        Rt_indices = np.setdiff1d(indices, [*prev_F_t_1_indices, *Ft_indices])

        # remove forget samples; a shard that lost nothing is masked so its
        # sub-model is not trained again
        subsets_index, inter_counts, _ = remove_sample(subsets_index, Ft_indices)
        subsets = select_dataset(num_model, train_dataset, subsets_index)
        masks = [inter_counts[i] == 0 for i in range(num_model)]

        # retrain
        sub_models_states = sequence_train_model(sub_models, subsets, masks, epochs, eta, batch_size, device,
                                                 early_stopping)
        load_param(sub_models_states, sub_models)

        end = time.time()

        efficiency = end - start

        # evaluate
        print(
            f'indices length: {len(indices)}, Ft_indices length:{len(Ft_indices)},  Rt_indices length: {len(Rt_indices)}, prev_F_t_1_indices length: {len(prev_F_t_1_indices)}')

        if len(Ft_indices) == 0:
            Acc_Ft = 0
        else:
            Ft_loader = DataLoader(Subset(train_dataset, Ft_indices), batch_size=batch_size, shuffle=False)
            Acc_Ft = evaluate_models_accuracy(sub_models, Ft_loader)

        if len(Rt_indices) == 0:
            Acc_Rt = 0
        else:
            Rt_loader = DataLoader(Subset(train_dataset, Rt_indices), batch_size=batch_size)
            Acc_Rt = evaluate_models_accuracy(sub_models, Rt_loader)

        if t == 1 or len(prev_F_t_1_indices) == 0:
            Acc_F_t_1 = 0.0
        else:
            prev_loader = DataLoader(Subset(train_dataset, prev_F_t_1_indices), batch_size=batch_size,
                                     shuffle=False)
            Acc_F_t_1 = evaluate_models_accuracy(sub_models, prev_loader)

        prev_F_t_1_indices.extend(Ft_indices)

        # test_loader is optional; without it there is nothing to evaluate
        if test_loader is None:
            Acc_test = float('nan')
        else:
            Acc_test = evaluate_models_accuracy(sub_models, test_loader)

        print(f"SISA time {t}: Acc_Ft: {Acc_Ft:.4f}, Acc_Rt: {Acc_Rt:.4f}, Acc_F_t-1: {Acc_F_t_1:.4f}, Acc_test: {Acc_test:.4f}, time: {efficiency:.4f}")

    return 0
