from itertools import cycle

import numpy as np

import autocuda
import torch
import tqdm
from transformers import AutoTokenizer

from load_text_dataset import Dataset
from model import ADModel

if __name__ == "__main__":

    dataset_name = "sst2"
    history_dataset_name = "agnews"
    max_memory_size = 100

    task_name = "defense"
    kword = "pwws"
    base_model_name = "bert-base-uncased"
    batch_size = 32
    learning_rate = 2e-5
    num_epochs = 5
    num_labels = {
        "sst2": 2,
        "agnews": 4,
        "amazon": 2,
        "yahoo": 10,
        "yelp": 2,
        "imdb": 2,
    }
    device = autocuda.auto_cuda()
    model = ADModel(model_name=base_model_name, num_labels=num_labels[dataset_name]).to(
        device
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    curr_train_set = Dataset(
        tokenizer, task_name, dataset_name, "train", attack=kword, max_length=128
    )
    curr_test_set = Dataset(
        tokenizer, task_name, dataset_name, "test", attack=kword, max_length=128
    )
    curr_dev_set = Dataset(
        tokenizer, task_name, dataset_name, "dev", attack=kword, max_length=128
    )

    history_train_dataset = Dataset(
        tokenizer,
        task_name,
        history_dataset_name,
        "train",
        attack=kword,
        max_length=128,
    )
    history_test_dataset = Dataset(
        tokenizer, task_name, history_dataset_name, "test", attack=kword, max_length=128
    )
    history_train_dataset = history_train_dataset[
        :max_memory_size
    ]  # shuffle and select a subset of the history dataset
    history_train_loader = torch.utils.data.DataLoader(
        history_train_dataset, batch_size=batch_size, shuffle=True
    )
    history_test_loader = torch.utils.data.DataLoader(
        history_test_dataset, batch_size=batch_size, shuffle=False
    )
    history_train_loader = iter(cycle(history_train_loader))
    history_test_loader = history_test_loader

    curr_train_loader = torch.utils.data.DataLoader(
        curr_train_set, batch_size=batch_size, shuffle=True
    )
    curr_test_loader = torch.utils.data.DataLoader(
        curr_test_set, batch_size=batch_size, shuffle=False
    )
    curr_dev_loader = torch.utils.data.DataLoader(
        curr_dev_set, batch_size=batch_size, shuffle=False
    )
    loss_fn = torch.nn.CrossEntropyLoss()
    max_det_acc, max_cls_acc = 0, 0

    # 初始化性能矩阵
    num_tasks = 2  # 当前任务和历史任务
    performance_matrix = np.zeros((num_epochs, num_tasks))
    for epoch in range(num_epochs):
        for batch in tqdm.tqdm(curr_train_loader, desc=f"Epoch {epoch + 1} Training"):
            # 当前任务数据
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)
            adv_label = batch["adv_label"].to(device)
            det_label = batch["det_label"].to(device)

            # 从历史数据加载器中获取Replay数据
            history_batch = next(history_train_loader)
            replay_input_ids = history_batch["input_ids"].to(device)
            replay_attention_mask = history_batch["attention_mask"].to(device)
            replay_label = history_batch["label"].to(device)
            replay_adv_label = history_batch["adv_label"].to(device)
            replay_det_label = history_batch["det_label"].to(device)

            # 将Replay数据与当前任务数据结合
            combined_input_ids = torch.cat([input_ids, replay_input_ids], dim=0)
            combined_attention_mask = torch.cat(
                [attention_mask, replay_attention_mask], dim=0
            )
            # combined_label = torch.cat([label, replay_label], dim=0)
            combined_label = torch.cat([label], dim=0)
            combined_det_label = torch.cat([det_label, replay_det_label], dim=0)

            optimizer.zero_grad()
            with torch.cuda.amp.autocast(dtype=torch.float16):
                det_logits = model(combined_input_ids, combined_attention_mask)
                loss = loss_fn(det_logits, combined_det_label)

            loss.backward()
            optimizer.step()

        with torch.no_grad():
            model.eval()
            total = 0
            cls_correct = 0
            det_correct = 0
            for batch in tqdm.tqdm(curr_test_loader, desc="Testing"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                label = batch["label"].to(device)
                adv_label = batch["adv_label"].to(device)
                det_label = batch["det_label"].to(device)
                det_logits = model(input_ids, attention_mask)
                total += label.size(0)
                det_correct += (det_logits.argmax(dim=1) == det_label).sum().item()
            det_acc = max(det_correct / total, max_det_acc)
            print(f"Epoch {epoch+1} Test Detection Accuracy: {det_acc:.4f}")

            # 历史任务的测试准确率
            cls_correct, det_correct, total = 0, 0, 0
            with torch.no_grad():
                for batch in history_test_loader:
                    input_ids = batch["input_ids"].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    label = batch["label"].to(device)
                    det_label = batch["det_label"].to(device)
                    det_logits = model(input_ids, attention_mask)
                    total += label.size(0)
                    det_correct += (det_logits.argmax(dim=1) == det_label).sum().item()
            history_det_acc = det_correct / total

            # 更新性能矩阵
            performance_matrix[epoch, 0] = det_acc
            performance_matrix[epoch, 1] = history_det_acc

            # 平均准确率：最后一个 epoch 的各任务平均准确率
            ACC = np.mean(performance_matrix[-1])
            print(f"Average Accuracy (ACC): {ACC:.4f}")

            # 遗忘率：历史任务的最大准确率减去最终准确率
            forgetting_rates = (
                np.max(performance_matrix[:-1, 1:], axis=0) - performance_matrix[-1, 1:]
            )
            F = np.mean(forgetting_rates)
            print(f"Forgetting Rate (F): {F:.4f}")

        model.train()
