import torch
torch.set_default_device('cuda')
import torch.nn as nn
from torchvision import datasets, transforms
from collections import defaultdict
import os
from tqdm import tqdm
import math
import torch.nn.functional as F
import copy
from utils import calculate_error_rate

# Debugging aid: makes autograd report the forward op behind a failing backward.
# NOTE(review): PyTorch documents this mode as slowing execution — confirm it is
# meant to stay on for full experiment runs.
torch.autograd.set_detect_anomaly(True)
# Fix the global RNG seed at import time (main() seeds again when it starts).
torch.manual_seed(0)

# Output directories live next to this file and are created at import time.
curr_path = os.path.dirname(os.path.abspath(__file__))
fig_path = os.path.join(curr_path, 'figs')
data_path = os.path.join(curr_path, 'data')
os.makedirs(fig_path, exist_ok=True)
os.makedirs(data_path, exist_ok=True)

# Example definition of global parameters
_eta = 0.0002  # Learning rate passed to every RAdam optimizer in main()
mix = 2  # Number of classes mixed into each training batch
mix_s = 16  # Number of samples drawn per class in each training batch
s = mix * mix_s  # New samples per batch; not referenced elsewhere in the visible file
M = [1, 4, 7]  # Number of experts for each configuration evaluated in main()
topK_list = [1, 1, 1]  # Top-k experts kept by the gate, one entry per configuration in M
flg = len(M)  # Number of configurations
duration_per_class = 400  # Consecutive batches drawn from the same group of classes
T = duration_per_class * 5  # Total number of training steps per configuration
buffer_capacity=1000  # Capacity of the replay buffer used by the training stream
cnt = 10  # Number of runs

# Flattened size of one 1x28x28 image; input width of the experts and the gate.
INPUT_FLATTEN_SIZE = 1 * 28 * 28

def download_mnist(root='./dataset/mnist'):
    """
    Fetch the MNIST train and test splits, downloading them into `root` if needed.

    Every image is converted to a tensor and normalized with the fixed
    mean/std constants below.

    :param root: Directory where torchvision stores the dataset files.
    :return: Tuple (train_dataset, test_dataset).
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.13066047430038452,), (0.30810782313346863,))
    ])
    # The train split is built first, then the test split.
    train_split, test_split = (
        datasets.MNIST(root=root, train=is_train, transform=preprocessing, download=True)
        for is_train in (True, False)
    )

    print("MNIST Dataset Downloaded")
    print(f"Train Set Size: {len(train_split)}")
    print(f"Test Set Size: {len(test_split)}")

    return train_split, test_split

def download_kmnist(root='./dataset/kmnist'):
    """
    Fetch the KMNIST train and test splits, downloading them into `root` if needed.

    Every image is converted to a tensor and normalized with the fixed
    mean/std constants below.

    :param root: Directory where torchvision stores the dataset files.
    :return: Tuple (train_dataset, test_dataset).
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.19176216423511505,), (0.3483428359031677,))
    ])
    # The train split is built first, then the test split.
    train_split, test_split = (
        datasets.KMNIST(root=root, train=is_train, transform=preprocessing, download=True)
        for is_train in (True, False)
    )

    print("KMNIST Dataset Downloaded")
    print(f"Train Set Size: {len(train_split)}")
    print(f"Test Set Size: {len(test_split)}")

    return train_split, test_split

def download_fashion_mnist(root='./dataset/fashion_mnist'):
    """
    Fetch the FashionMNIST train and test splits, downloading them into `root` if needed.

    Every image is converted to a tensor and normalized with the fixed
    mean/std constants below.

    :param root: Directory where torchvision stores the dataset files.
    :return: Tuple (train_dataset, test_dataset).
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.2860405743122101,), (0.3530242443084717,))
    ])
    # The train split is built first, then the test split.
    train_split, test_split = (
        datasets.FashionMNIST(root=root, train=is_train, transform=preprocessing, download=True)
        for is_train in (True, False)
    )

    print("FashionMNIST Dataset Downloaded")
    print(f"Train Set Size: {len(train_split)}")
    print(f"Test Set Size: {len(test_split)}")

    return train_split, test_split

def download_emnist(root='./dataset/emnist'):
    """
    Fetch the EMNIST 'balanced' train and test splits, downloading them into `root` if needed.

    Every image is converted to a tensor and normalized with the fixed
    mean/std constants below.

    :param root: Directory where torchvision stores the dataset files.
    :return: Tuple (train_dataset, test_dataset).
    """
    # NOTE(review): the constants are annotated as measured on the 'bymerge'
    # train set, while the split loaded below is 'balanced' — confirm intended.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1735963225364685,), (0.331650972366333,)) # on bymerge train set
    ])
    # 'balanced' split: 47 classes. The train split is built first, then the test split.
    train_split, test_split = (
        datasets.EMNIST(root=root, split='balanced', train=is_train, transform=preprocessing, download=True)
        for is_train in (True, False)
    )

    print("EMNIST Dataset Downloaded")
    print(f"Train Set Size: {len(train_split)}")
    print(f"Test Set Size: {len(test_split)}")

    return train_split, test_split

class ContinuousDataset:
    """
    Endless, class-incremental batch stream with a replay buffer.

    Images of the wrapped dataset are grouped into per-class buckets. Each call
    to ``next()`` returns the new samples of the currently active group of
    classes followed by the current content of the replay buffer. The active
    group changes every ``duration_per_class`` batches. The schedule of class
    groups is recorded, so re-iterating the object replays the same class order.
    The iterator never raises ``StopIteration``.
    """

    def __init__(self, dataset, num_sample_per_class=32, num_classes_in_batch=2, max_num_per_class=10000, duration_per_class=10, buffer_capacity=1000):
        """
        Initialize the ContinuousDataset.

        :param dataset: The original dataset (e.g., MNIST), iterable as (image, label) pairs.
        :param num_sample_per_class: Number of samples per class in each batch.
        :param num_classes_in_batch: Number of classes in each batch.
        :param max_num_per_class: Maximum number of samples per class in the dataset for avoiding memory overflow.
        :param duration_per_class: Number of consecutive batches drawn from the same group of classes.
        :param buffer_capacity: Buffer capacity for storing images.
        """
        self.dataset = dataset
        self.buckets = defaultdict(list)  # Store images in buckets by class
        for image, label in dataset:
            if len(self.buckets[label]) < max_num_per_class:  # For avoiding memory overflow.
                self.buckets[label].append(image)  # Add images to their respective class buckets
        self.num_class = len(self.buckets)  # Total number of classes
        self.labels = list(self.buckets.keys())  # List of class labels
        self.num_sample_per_class = num_sample_per_class
        self.num_classes_in_batch = num_classes_in_batch
        self.duration_per_class = duration_per_class
        self.index = 0  # Number of batches produced since the last __iter__
        self.order_record = []  # Class group used for every batch index generated so far
        self.class_idx_list = torch.randperm(self.num_class)  # Random order in which classes are visited
        self.class_time = 0  # Offset into class_idx_list of the next class group

        self.capacity = buffer_capacity
        self.buffer = []  # Replay buffer of (image, label) pairs
        self.seen = 0  # Number of new samples produced since the last __iter__

    def __iter__(self):
        """
        Restart the stream and return the iterator object itself.

        The batch counter, replay buffer and seen-sample counter are reset;
        the recorded class schedule is kept so the same order is replayed.
        """
        self.index = 0
        self.buffer = []
        self.seen = 0
        return self

    def __next__(self):
        """
        Generate and return the next batch of data for training.

        :return: Tuple (images, labels). The new samples come first, followed
                 by the replay-buffer content as it was before this call.
        """
        batch = []  # List to store images for the current batch
        batch_labels = []  # List to store labels for the current batch

        if self.index >= len(self.order_record):
            # The schedule is exhausted: pick the next group of classes and
            # register it for the following `duration_per_class` batches.
            class_idx_list = self.class_idx_list[self.class_time: self.class_time + self.num_classes_in_batch].tolist()
            self.order_record += [class_idx_list for _ in range(self.duration_per_class)]
            self.class_time += self.num_classes_in_batch
            if self.class_time >= self.num_class:
                self.class_time = 0
        else:
            class_idx_list = self.order_record[self.index]
        self.index += 1

        for position in class_idx_list:
            class_label = self.labels[int(position)]
            bucket = self.buckets[class_label]
            # Deterministic sliding window over the bucket, advanced by the batch index.
            span = len(bucket) - self.num_sample_per_class
            start = (self.num_sample_per_class * self.index) % span if span > 0 else 0
            samples = bucket[start:start + self.num_sample_per_class]
            batch += samples
            # One label per image actually taken: a bucket smaller than
            # `num_sample_per_class` yields fewer images, and the label list
            # must stay aligned with the image list.
            batch_labels += [class_label] * len(samples)

        # Snapshot of the buffer as it is before it is updated with this batch.
        stored = self.buffer[:self.capacity]
        buffer_batch = [image for image, _ in stored]
        buffer_batch_labels = [label for _, label in stored]

        # update buffer
        self.seen += len(batch)
        if len(self.buffer) < self.capacity:
            self.buffer += zip(batch, batch_labels)
            # Entries past `capacity` are never read nor replaced; drop them.
            del self.buffer[self.capacity:]
        else:
            # Buffer is full: each new sample overwrites a random slot with
            # probability capacity / seen.
            replace_prob = self.capacity / self.seen
            for tensor_sample in zip(batch, batch_labels):
                if torch.rand(1) < replace_prob:
                    replace_idx = torch.randint(0, self.capacity, (1,)).item()
                    self.buffer[replace_idx] = tensor_sample

        return torch.stack(batch + buffer_batch), torch.tensor(batch_labels + buffer_batch_labels)  # Stack images into a single tensor

    def get_test_with_label(self, label, max=100):
        """
        Get a batch of test images with a specific label.

        :param label: The label of the test images.
        :param max: Maximum number of images to return (the first ones in the bucket).
        :return: Tuple (images, labels) where every label equals `label`.
        """
        samples = self.buckets[label][:max]
        return torch.stack(samples), torch.tensor([label for _ in range(len(samples))])

class ExpertModel(nn.Module):
    """
    Expert network: three stacked linear layers followed by a softmax.

    No activation function is applied between the linear layers, and the
    forward pass returns class probabilities rather than logits.
    """

    def __init__(self, num_classes):
        """
        :param num_classes: Size of the output probability vector.
        """
        super().__init__()
        self.linear = nn.Linear(INPUT_FLATTEN_SIZE, 512)
        self.hidden = nn.Linear(512, 512)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """
        :param x: Input batch; every dimension after the first is flattened.
        :return: Softmax over the last dimension of the final layer's output.
        """
        flat = torch.flatten(x, 1)
        logits = self.fc(self.hidden(self.linear(flat)))
        return torch.softmax(logits, dim=-1)

class GatingNetwork(nn.Module):
    """
    Linear gate: a single linear layer followed by a softmax over the experts.
    """

    def __init__(self, num_experts):
        """
        :param num_experts: Number of experts the gate distributes weight over.
        """
        super().__init__()
        self.linear = nn.Linear(INPUT_FLATTEN_SIZE, num_experts)

    def forward(self, x):
        """
        :param x: Input batch; every dimension after the first is flattened.
        :return: Softmax over the last dimension, one weight per expert.
        """
        scores = self.linear(torch.flatten(x, 1))
        return torch.softmax(scores, dim=-1)

def cal_F_and_G(t, loss_record):
    """
    Compute the two summary values recorded at step `t`.

    :param t: Index of the current step (row of `loss_record`).
    :param loss_record: 2-D tensor indexed as [step, class] holding the
                        per-class test loss recorded at every step.
    :return: Tuple (Ft, Gt) of 0-dim floating-point CPU tensors.
             Gt is the mean of row `t`. Ft is the mean, over classes, of the
             gap between the loss at step `t` and the lowest loss recorded for
             that class at any earlier step, clamped at zero; it is zero for
             `t == 0`, where no earlier step exists.
    """
    current = loss_record[t]
    Gt = torch.mean(current)
    if t > 0:
        # Per-class minimum over all earlier steps, computed in one call
        # instead of a Python loop over classes.
        best_so_far = loss_record[:t].min(dim=0).values
        # clamp keeps the result a float tensor in every case.
        Ft = torch.clamp(torch.mean(current - best_so_far), min=0)
    else:
        Ft = torch.zeros((), dtype=Gt.dtype, device=Gt.device)
    return Ft.cpu(), Gt.cpu()


# 2nd-order Renyi entropy
def S2(X):
    """
    Kernel-based estimate of the 2nd-order Renyi entropy of the rows of `X`.

    A Gaussian kernel with bandwidth 1 is applied to the pairwise squared
    Euclidean distances; the diagonal contributes the 1/n term and every
    off-diagonal pair is counted twice through the upper triangle.

    :param X: 2-D tensor with one sample per row.
    :return: 0-dim tensor holding the entropy estimate in bits.
    """
    sample_count = X.shape[0]
    bandwidth = 1

    # Pairwise squared Euclidean distances between samples
    sq_dists = torch.cdist(X, X, p=2) ** 2

    kernel = torch.exp(-sq_dists / (bandwidth ** 2))
    upper_sum = torch.sum(torch.triu(kernel, diagonal=1))
    potential = 1 / sample_count + (2 / (sample_count ** 2)) * upper_sum

    return -torch.log2(potential)

# 2nd-order Renyi entropy
def S2s(*X_list):
    """
    Kernel-based estimate of the joint 2nd-order Renyi entropy of several feature sets.

    The squared Euclidean distances of all feature sets are summed before the
    Gaussian kernel (bandwidth 1) is applied, which equals multiplying the
    per-set kernels. With a single feature set the result matches S2.

    :param X_list: One or more 2-D tensors with the same number of rows
                   (one sample per row).
    :return: 0-dim tensor holding the joint entropy estimate in bits.
    """
    n = X_list[0].shape[0]  # Number of samples
    sigma = 1

    # Calculate the sum of squared Euclidean distances for all feature sets
    total_distances = torch.zeros((n, n), device=X_list[0].device)
    for X in X_list:
        # cdist returns the plain distance, so square it exactly once; squaring
        # the squared distance again would feed d**4 into the kernel.
        total_distances += torch.cdist(X, X, p=2) ** 2

    kernel = torch.exp(-total_distances / (sigma ** 2))
    potential = 1 / n + (2 / (n ** 2)) * torch.sum(torch.triu(kernel, diagonal=1))

    return -torch.log2(potential)

def S2xy(X, Y):
    """
    Variant of the joint 2nd-order Renyi entropy estimate for two feature sets.

    The two squared-distance matrices are merged element-wise as
    max + min * |dx - dy| / (dx + dy + eps), and the merged distances are then
    fed to the same Gaussian-kernel (bandwidth 1) estimator used by S2.

    :param X: 2-D tensor with one sample per row.
    :param Y: 2-D tensor with the same number of rows as `X`.
    :return: 0-dim tensor holding the entropy estimate in bits.
    """
    sample_count = X.shape[0]
    bandwidth = 1
    eps = 1e-6  # Guards the division where both distances are zero (e.g. the diagonal)

    # Pairwise squared Euclidean distances within each feature set
    sq_x = torch.cdist(X, X, p=2) ** 2
    sq_y = torch.cdist(Y, Y, p=2) ** 2

    gap = torch.abs(sq_x - sq_y)
    larger = (sq_x + sq_y + gap) / 2   # element-wise maximum of the two
    smaller = (sq_x + sq_y - gap) / 2  # element-wise minimum of the two
    merged = larger + smaller * gap / (sq_x + sq_y + eps)

    kernel = torch.exp(-merged / (bandwidth ** 2))
    upper_sum = torch.sum(torch.triu(kernel, diagonal=1))
    potential = 1 / sample_count + (2 / (sample_count ** 2)) * upper_sum

    return -torch.log2(potential)

def normalize(tensor):
    """
    Standardize a tensor with its global mean and standard deviation.

    Tensors with a second dimension are additionally scaled by D ** -0.25,
    where D is the size of that dimension. 1-D tensors are standardized and
    reshaped into a column of shape (n, 1). When the standard deviation is not
    above 1e-6 the tensor is only centered, not divided.

    :param tensor: Input tensor; it is converted to float first.
    :return: The normalized float tensor.
    """
    tensor = tensor.float()
    has_feature_axis = tensor.dim() >= 2

    centered = tensor - tensor.mean()
    spread = tensor.std()

    if has_feature_axis:
        result = (tensor.shape[1] ** -0.25) * centered
    else:
        result = centered

    # Skip the division for (near-)constant input to avoid dividing by ~0.
    if spread > 0.000001:
        result = result / spread

    if not has_feature_axis:
        result = torch.reshape(result, (tensor.shape[0], 1))
    return result

def MI(a, b):
    """
    Normalized entropy-based dependence term between `a` and `b`.

    Both inputs are normalized, then (S2s(x, y) - S2(x)) / S2(y) is returned,
    i.e. [H(Y) - I(X,Y)] / H(Y).

    :param a: First tensor, one sample per row.
    :param b: Second tensor with the same number of rows.
    :return: 0-dim tensor with the ratio.
    """
    x, y = normalize(a), normalize(b)
    joint_entropy = S2s(x, y)
    entropy_x = S2(x)
    entropy_y = S2(y)
    return (joint_entropy - entropy_x) / entropy_y

def new_MI(a, b):
    """
    Same ratio as MI, with the joint term computed by S2xy instead of S2s.

    Both inputs are normalized, then (S2xy(x, y) - S2(x)) / S2(y) is returned,
    i.e. [H(Y) - I(X,Y)] / H(Y).

    :param a: First tensor, one sample per row.
    :param b: Second tensor with the same number of rows.
    :return: 0-dim tensor with the ratio.
    """
    x, y = normalize(a), normalize(b)
    joint_entropy = S2xy(x, y)
    entropy_x = S2(x)
    entropy_y = S2(y)
    return (joint_entropy - entropy_x) / entropy_y

def _mixture_forward(models, gating_network, inputs, top_k, num_classes):
    """
    Run the gate and the experts it selects on one batch.

    :param models: List of expert modules.
    :param gating_network: Module mapping the batch to one weight per expert.
    :param inputs: Input batch, already on the device the modules live on.
    :param top_k: Number of experts kept per sample.
    :param num_classes: Width of an expert's output, used for the zero filler.
    :return: Tuple (sparse_gate, expert_outputs, active_experts).
             sparse_gate holds the gate weights with all but the top-k entries
             per sample zeroed. expert_outputs stacks the expert outputs along
             dim 1; experts selected by no sample of the batch are not run and
             contribute zeros. active_experts is the set of expert indices
             selected by at least one sample.
    """
    gate_probs = gating_network(inputs)
    topk_values, topk_indices = torch.topk(gate_probs, top_k, dim=1)
    sparse_gate = torch.zeros_like(gate_probs).scatter(1, topk_indices, topk_values)
    active_experts = set(torch.unique(topk_indices).tolist())
    expert_outputs = torch.stack([
        model(inputs) if i in active_experts
        else torch.zeros((inputs.shape[0], num_classes), device=inputs.device)
        for i, model in enumerate(models)
    ], dim=1)
    return sparse_gate, expert_outputs, active_experts


def main(is_MI, is_label, delta_scale, dataset_name, is_new_MI):
    """
    Train and evaluate the mixture-of-experts configurations on a class-incremental stream.

    For every run and every configuration in `M` / `topK_list`, a fresh set of
    experts and a gate are trained for `T` steps. After each step the per-class
    test error is recorded and summarized with cal_F_and_G. Mean and confidence
    half-widths over the runs are saved to `data_path`. Nothing is computed when
    the result file for this setting already exists.

    :param is_MI: Add the MI term to the gate loss.
    :param is_label: With `is_MI`, compute the MI term against one-hot labels
                     instead of the detached mixture output.
    :param delta_scale: Weight of the MI term in the gate loss.
    :param dataset_name: One of 'MNIST', 'KMNIST', 'FashionMNIST', 'EMNIST'.
    :param is_new_MI: With `is_MI`, use new_MI instead of MI.
    :raises ValueError: If `dataset_name` is not one of the supported names.
    """
    name_pre = f"{dataset_name}-MI-" if is_MI else f"{dataset_name}-Vanilla-"
    torch.manual_seed(0)

    if is_MI:
        name_pre += "Y" if is_label else "Z"
        name_pre += f"G({delta_scale})"
        if is_new_MI:
            name_pre += "-NewMI"

    name_pre += "-OneHot"

    name = f"{name_pre}-top{','.join([str(i) for i in topK_list])}-M{','.join([str(i) for i in M])}-{mix}x{mix_s}-{T}t-x{cnt}-p{_eta}"
    print(name)

    result_path = os.path.join(data_path, f'{dataset_name}-forgetting_error_{name}.pt')
    if os.path.exists(result_path):
        # Results for this exact setting were already produced.
        return

    if dataset_name == 'MNIST':
        train_dataset, test_dataset = download_mnist()
    elif dataset_name == 'KMNIST':
        train_dataset, test_dataset = download_kmnist()
    elif dataset_name == 'FashionMNIST':
        train_dataset, test_dataset = download_fashion_mnist()
    elif dataset_name == 'EMNIST':
        train_dataset, test_dataset = download_emnist()
    else:
        raise ValueError(f"Invalid dataset name: {dataset_name}")

    train_buckets = ContinuousDataset(train_dataset, num_sample_per_class=mix_s, num_classes_in_batch=mix, duration_per_class=duration_per_class, buffer_capacity=buffer_capacity)
    test_buckets = ContinuousDataset(test_dataset, num_sample_per_class=mix_s, num_classes_in_batch=mix)
    num_classes = train_buckets.num_class
    N = num_classes

    # The per-class test batches never change, so build them once instead of
    # re-stacking them at every training step.
    test_batches = []
    for label in test_buckets.labels:
        Xt_t, yt_t = test_buckets.get_test_with_label(label)
        test_batches.append((label, Xt_t.cuda(), yt_t))

    Ft_record_all = torch.zeros(flg, T, cnt).cpu()
    Gt_record_all = torch.zeros(flg, T, cnt).cpu()

    for count in tqdm(range(cnt)):
        Ft_record = torch.zeros(flg, T).cpu()
        Gt_record = torch.zeros(flg, T).cpu()

        for flag in range(flg):
            models = [ExpertModel(num_classes) for _ in range(M[flag])]
            gating_network = GatingNetwork(M[flag])
            optimizers = [torch.optim.RAdam(model.parameters(), lr=_eta) for model in models]
            gating_optimizer = torch.optim.RAdam(gating_network.parameters(), lr=_eta)
            loss_fn = nn.CrossEntropyLoss()

            train_iterator = iter(train_buckets)

            loss_record = torch.zeros(T, N)

            # The gate cannot keep more experts than exist.
            topK = min(topK_list[flag], M[flag])

            for t in range(T):
                Xt, yt = next(train_iterator)
                Xt = Xt.cuda()

                gate_output, expert_outputs, active_experts = _mixture_forward(
                    models, gating_network, Xt, topK, N)

                # Two views of the same mixture: one lets gradients reach the
                # experts, the other (experts detached) trains only the gate.
                final_output_expt = torch.sum(gate_output.unsqueeze(-1) * expert_outputs, dim=1)
                final_output_gate = torch.sum(gate_output.unsqueeze(-1) * expert_outputs.detach(), dim=1)

                # Compute expert loss
                loss_expert = loss_fn(final_output_expt, yt)

                # Compute gate loss
                loss_gate = loss_fn(final_output_gate, yt)

                if is_MI:
                    mi_fn = new_MI if is_new_MI else MI
                    if is_label:
                        mi_target = F.one_hot(yt.detach(), num_classes)
                    else:
                        mi_target = final_output_gate.detach()
                    MI_loss = mi_fn(gate_output, mi_target)
                    loss_gate = loss_gate + MI_loss * delta_scale

                # Gate update first. retain_graph keeps the shared part of the
                # graph alive for the expert backward pass below.
                gating_optimizer.zero_grad()
                loss_gate.backward(retain_graph=True)
                gating_optimizer.step()

                # Only experts that processed this batch are updated.
                for i, optimizer in enumerate(optimizers):
                    if i in active_experts:
                        optimizer.zero_grad()
                # This backward also deposits gradients on the gate (gate_output
                # is not detached here); they are cleared by the gate's
                # zero_grad at the next step and never applied.
                loss_expert.backward()
                for i, optimizer in enumerate(optimizers):
                    if i in active_experts:
                        optimizer.step()

                # Record the loss for each ground truth
                with torch.no_grad():  # evaluation only: no graph needed
                    for label, Xt_t, yt_t in test_batches:
                        # Experts are selected from the TEST batch's own top-k
                        # indices, not from the training batch's.
                        gate_output_test, expert_outputs_test, _ = _mixture_forward(
                            models, gating_network, Xt_t, topK, N)
                        final_output_test = torch.sum(gate_output_test.unsqueeze(-1) * expert_outputs_test, dim=1)
                        loss_test = calculate_error_rate(final_output_test, yt_t)
                        loss_record[t, label] += loss_test

                Ft_record[flag, t], Gt_record[flag, t] = cal_F_and_G(t, loss_record)

            Ft_record_all[flag, :, count] = Ft_record[flag]
            Gt_record_all[flag, :, count] = Gt_record[flag]

    Ft_record_average = torch.mean(Ft_record_all, dim=2)
    Gt_record_average = torch.mean(Gt_record_all, dim=2)
    # Half-width of the interval: 1.645 standard errors of the mean over runs.
    Ft_record_CI = (torch.std(Ft_record_all, dim=2) / math.sqrt(cnt)) * 1.645
    Gt_record_CI = (torch.std(Gt_record_all, dim=2) / math.sqrt(cnt)) * 1.645

    torch.save({
        'Ft_record_average': Ft_record_average,
        'Gt_record_average': Gt_record_average,
        'Ft_record_CI': Ft_record_CI,
        'Gt_record_CI': Gt_record_CI
    }, result_path)
    return

import argparse

def parse_args():
    """
    Parse the command-line options of the script.

    :return: Namespace with the attributes MI, label, delta_scale,
             dataset_name and new_MI.
    """
    parser = argparse.ArgumentParser(description="Configure script parameters.")

    # Settings shared by every on/off switch.
    switch = {"action": "store_true", "default": False}

    # MI loss format
    parser.add_argument("--MI", help="Whether to use MI loss (default: False)", **switch)
    parser.add_argument("--label", help="Whether to use labels (default: False)", **switch)
    parser.add_argument("--delta_scale", type=int, default=1, help="Delta scale value (default: 1)")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default='MNIST',
        choices=['MNIST', 'KMNIST', 'FashionMNIST', 'EMNIST'],
        help="Dataset name (default: MNIST)",
    )
    parser.add_argument("--new_MI", help="Whether to use our new_MI (default: False)", **switch)

    return parser.parse_args()

if __name__ == '__main__':
    # Script entry point: forward the command-line options to main().
    cli = parse_args()

    main(
        is_MI=cli.MI,  # Whether to use MI loss
        is_label=cli.label,  # Whether to use labels
        delta_scale=cli.delta_scale,
        dataset_name=cli.dataset_name,
        is_new_MI=cli.new_MI,
    )