import matplotlib
matplotlib.use('Agg')
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as T
import torch.nn.functional as F
import torch.optim as optim
import random
# Seed every RNG this script draws from so that runs are repeatable.
seed = 100
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Prefer the GPU, but fall back to the CPU so the script can still start on a
# machine without CUDA instead of failing at the first ``.to(device)`` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment hyper-parameters.
num_tasks = 2       # number of tasks
buffer_size = 1000  # fixed buffer size
train_bs = 64       # train batch size
test_bs = 1000      # test batch size
lr = 1.0            # learning rate
gamma = 0.7         # decay





class Net(nn.Module):
    """Small two-convolution CNN that classifies 3x32x32 images into 10 classes.

    Layout: conv(3->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU ->
    2x2 max-pool -> dropout -> flatten -> fc(64*14*14->128) -> ReLU ->
    dropout -> fc(128->10) -> log-softmax.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D feature maps after pooling.
        self.dropout1 = nn.Dropout2d(0.25)
        # The output of fc1 is 2-D (batch, features). Dropout2d expects
        # 3-D/4-D inputs and PyTorch deprecates feeding it 2-D tensors, so the
        # fully-connected activations use plain element-wise dropout.
        self.dropout2 = nn.Dropout(0.5)
        # Two unpadded 3x3 convolutions take 32 -> 30 -> 28; pooling halves
        # that to 14, hence 64 * 14 * 14 input features.
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Tensor whose elements can be arranged as (N, 3, 32, 32); the
                flattened (N, 3072) layout is accepted as well.

        Returns:
            Tensor of shape (N, 10) holding log-softmax scores.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(-1, 3, 32, 32)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)




class ReplayBuffer:
    """Fixed-capacity store of ``(input, label)`` samples used for rehearsal.

    Samples are added one task at a time with an equal per-class quota; when
    the total exceeds ``buffer_size`` the buffer is uniformly subsampled back
    down to capacity.
    """

    def __init__(self, buffer_size=6000):
        """Create an empty buffer holding at most ``buffer_size`` samples."""
        self.buffer_size = buffer_size
        self.buffer = []
        # Retained for backward compatibility; nothing in this class fills it.
        self.class_distribution = {}

    def add_data(self, task_data, task_classes):
        """Add a class-balanced random subset of ``task_data`` to the buffer.

        Args:
            task_data: Re-iterable collection of ``(x, y)`` pairs.
            task_classes: Labels present in ``task_data``; each gets a quota
                of ``buffer_size // len(task_classes)`` samples.
        """
        num_classes = len(task_classes)
        if num_classes == 0:
            # Nothing to add; also avoids dividing by zero below.
            return
        samples_per_class = self.buffer_size // num_classes
        for cls in task_classes:
            class_samples = [(x, y) for x, y in task_data if y == cls]
            # Random sampling without replacement, capped by availability.
            quota = min(samples_per_class, len(class_samples))
            self.buffer.extend(random.sample(class_samples, quota))
        if len(self.buffer) > self.buffer_size:
            # Over capacity: keep a uniform random subset of old and new data.
            self.buffer = random.sample(self.buffer, self.buffer_size)

    def get_loader(self, batch_size):
        """Return a shuffling DataLoader over the buffer, or None when empty."""
        if not self.buffer:
            return None
        data = torch.stack([x for x, _ in self.buffer])
        labels = torch.tensor([y for _, y in self.buffer])
        replay_dataset = torch.utils.data.TensorDataset(data, labels)
        return torch.utils.data.DataLoader(replay_dataset, batch_size=batch_size, shuffle=True)




def _train_on_loader(model, device, loader, optimizer):
    """Run one optimisation pass over ``loader``; return (num_correct, num_seen)."""
    correct = 0
    seen = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # Accuracy is measured on the outputs used for this update step.
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        seen += len(data)
    return correct, seen


def train(model, device, train_loader, optimizer, replay_loader=None):
    """Train ``model`` for one epoch, optionally followed by a replay pass.

    Args:
        model: Network whose output is log-probabilities (NLL loss is used).
        device: Device that each batch is moved to.
        train_loader: Iterable of ``(data, target)`` batches for the task.
        optimizer: Optimizer stepped once per batch.
        replay_loader: Optional iterable of replayed ``(data, target)``
            batches, processed after all of ``train_loader``.

    Returns:
        Training accuracy in percent over every sample seen in both passes,
        or 0.0 if no samples were seen.
    """
    model.train()
    correct, total_samples = _train_on_loader(model, device, train_loader, optimizer)
    if replay_loader is not None:
        replay_correct, replay_seen = _train_on_loader(model, device, replay_loader, optimizer)
        correct += replay_correct
        total_samples += replay_seen
    if total_samples == 0:
        # Empty loaders: report 0 instead of raising ZeroDivisionError.
        return 0.0
    # Calculate the accuracy
    return 100. * correct / total_samples





def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    test_acc = 100. * correct / len(test_loader.dataset)
    return test_loss, test_acc





def test_all(model, device, tasks, current_task_idx=None):
    """Evaluate ``model`` on the test split of every task and print the results.

    Args:
        model: Network to evaluate.
        device: Device the evaluation batches are moved to.
        tasks: Sequence of ``[train_split, test_split]`` pairs.
        current_task_idx: Accepted for call-site compatibility; not used here.

    Returns:
        ``(accs, losses)``: two lists with one entry per task, in task order.
    """
    accs, losses = [], []
    for task_number, (_, test_split) in enumerate(tasks, start=1):
        loader = torch.utils.data.DataLoader(test_split, batch_size=test_bs)
        test_loss, test_acc = test(model, device, loader)
        accs.append(test_acc)
        losses.append(test_loss)
        print(f'test on task {task_number}: test loss {test_loss:.4f}, test acc {test_acc:.2f}')
    return accs, losses





def generate_tasks_with_overlap(overlap, data_root='/root/data/hhjl/URCL-main/dataset/cifar10'):
    """Generate two CIFAR-10 tasks whose class sets share ``overlap`` classes.

    Task 1 always uses classes 0-4. Task 2 reuses the first ``overlap`` of
    those and fills up to five classes with the first ``5 - overlap`` of
    classes 5-9.

    Args:
        overlap: Number of task-1 classes reused by task 2 (0 to 5 inclusive).
        data_root: Directory where CIFAR-10 is stored (downloaded if missing).

    Returns:
        ``[[task1_train, task1_test], [task2_train, task2_test]]``, each entry
        a ``torch.utils.data.Subset`` of the CIFAR-10 train or test split.

    Raises:
        ValueError: If ``overlap`` is outside ``0..5``; slicing would otherwise
            silently yield a task with the wrong number of classes.
    """
    task1_classes = [0, 1, 2, 3, 4]
    if not 0 <= overlap <= len(task1_classes):
        raise ValueError(
            f'overlap must be between 0 and {len(task1_classes)}, got {overlap}')
    overlap_classes = task1_classes[:overlap]
    new_classes = [5, 6, 7, 8, 9][:5 - overlap]
    task2_classes = overlap_classes + new_classes
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # Samples are stored flat; the model restores the image shape itself.
        T.Lambda(lambda x: torch.flatten(x))
    ])
    full_train = datasets.CIFAR10(data_root, train=True, download=True, transform=transform)
    full_test = datasets.CIFAR10(data_root, train=False, download=True, transform=transform)

    def filter_by_class(dataset, classes):
        """Return the subset of ``dataset`` whose labels are in ``classes``."""
        wanted = set(classes)
        # Read labels from ``targets`` rather than iterating the dataset, which
        # would decode and transform every image just to inspect its label.
        indices = [i for i, label in enumerate(dataset.targets) if label in wanted]
        return torch.utils.data.Subset(dataset, indices)

    task1_train = filter_by_class(full_train, task1_classes)
    task1_test = filter_by_class(full_test, task1_classes)
    task2_train = filter_by_class(full_train, task2_classes)
    task2_test = filter_by_class(full_test, task2_classes)
    return [[task1_train, task1_test], [task2_train, task2_test]]






def test1(model, device, tasks, buffer_size, num_epochs=100):
    """Train ``model`` on each task in sequence with replay and record metrics.

    For every task the replay buffer is updated, a fresh Adadelta optimizer
    and StepLR scheduler are created, the model is trained for ``num_epochs``
    epochs, and then evaluated on the test split of all tasks.

    Args:
        model: Network to train; updated in place.
        device: Device used for training and evaluation.
        tasks: Sequence of ``[train_split, test_split]`` pairs, one per task.
        buffer_size: Capacity of the replay buffer.
        num_epochs: Epochs to train on each task (defaults to 100).

    Returns:
        ``(acc_df, loss_df)``: DataFrames whose row ``After Task i`` and column
        ``Task j`` hold the test accuracy / loss on task j after training
        on task i.
    """
    acc_metrics = []
    loss_metrics = []
    replay_buffer = ReplayBuffer(buffer_size=buffer_size)
    for i in range(num_tasks):
        print( "=" * 22 + f' Train on task {i + 1} ' + "=" * 22)
        train_df, test_df = tasks[i]
        # Materialise the training split once: every pass over the dataset
        # re-runs the transforms, so reuse it for both the class scan and the
        # replay-buffer update.
        train_samples = list(train_df)
        current_classes = list({label for _, label in train_samples})
        # NOTE(review): the buffer receives the current task's samples before
        # training, so from task 2 onwards the replay loader mixes current-task
        # data with earlier tasks' data -- confirm this is intended.
        replay_buffer.add_data(train_samples, current_classes)
        del train_samples  # the buffer keeps only the samples it selected
        optimizer = optim.Adadelta(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
        train_loader = torch.utils.data.DataLoader(train_df, batch_size=train_bs, shuffle=True, num_workers=4,
                                                   pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_df, batch_size=test_bs, num_workers=4, pin_memory=True)
        # No replay while training on the first task.
        replay_loader = replay_buffer.get_loader(train_bs) if i > 0 else None
        for epoch in range(num_epochs):
            train_acc = train( model, device, train_loader, optimizer, replay_loader)
            test_loss, test_acc = test(model, device, test_loader)
            print(f'epoch{epoch + 1}, train acc：{train_acc:.2f}, test loss: {test_loss:.4f}, test acc：{test_acc:.2f}')
            scheduler.step()
        accs, losses = test_all(model, device, tasks, current_task_idx=i)
        acc_metrics.append(accs)
        loss_metrics.append(losses)
    row_labels = [f'After Task {i + 1}' for i in range(num_tasks)]
    column_labels = [f'Task {j + 1}' for j in range(num_tasks)]
    acc_df = pd.DataFrame(acc_metrics, index=row_labels, columns=column_labels)
    loss_df = pd.DataFrame(loss_metrics, index=row_labels, columns=column_labels)
    return acc_df, loss_df





def run_overlap_experiment(overlap_levels):
    """Run the two-task experiment once per overlap level and tabulate errors.

    For each overlap a fresh model is trained through ``test1`` and three
    loss-based quantities are derived from the resulting loss table for every
    task t:
      * adaptation: loss on task t right after training on task t;
      * memory: mean increase in loss on tasks 1..t-1 relative to the loss
        each had right after it was trained (NaN for the first task);
      * generalization: mean loss over tasks 1..t after training on task t.

    Args:
        overlap_levels: Iterable of overlap counts passed to
            ``generate_tasks_with_overlap``.

    Returns:
        DataFrame with one row per overlap level.
    """
    rows = []
    for overlap in overlap_levels:
        print("=" * 15 + f'Running experiment with {overlap} overlapping classes' + "=" * 15)
        tasks = generate_tasks_with_overlap(overlap)
        model = torch.nn.DataParallel(Net().to(device))
        # Train and test; only the loss table feeds the metrics below.
        _, loss_results = test1(model, device, tasks, buffer_size)

        def loss_at(after, task):
            """Loss on ``task`` measured after training on task ``after``."""
            return loss_results.loc[f'After Task {after}', f'Task {task}']

        adaptation = []
        memory = []
        generalization = []
        for t in range(1, num_tasks + 1):
            # Adaptation
            adaptation.append(loss_at(t, t))

            # Memory
            if t > 1:
                drift = 0.0
                for earlier in range(1, t):
                    drift += (loss_at(t, earlier) - loss_at(earlier, earlier))
                memory.append(drift / (t - 1))
            else:
                memory.append(float('nan'))

            # Generalization
            running = 0.0
            for seen in range(1, t + 1):
                running += loss_at(t, seen)
            generalization.append(running / t)

        row = {
            'overlap': overlap,
            'adaptation_task1': adaptation[0],
            'adaptation_task2': adaptation[1],
            'memory_task2': memory[1],
            'generalization_task1': generalization[0],
            'generalization_task2': generalization[1]
        }
        rows.append(row)
    return pd.DataFrame(rows)


# Overlap levels to sweep: 0 through 5 shared classes.
overlap_levels = range(0, 6, 1)

# Guard the entry point: the DataLoaders use worker processes, and on
# platforms that start workers with "spawn" each worker re-imports this
# module, which must not launch the experiment again.
if __name__ == "__main__":
    experiment_results = run_overlap_experiment(overlap_levels)
    # Show the summary table, which was previously computed and then dropped.
    print(experiment_results)





