
import json
import os
from itertools import cycle

import autocuda
import numpy as np
import torch
import tqdm
from transformers import AutoTokenizer
from load_text_dataset import Dataset
from model import DetModel


# Function to train the model
def train(
    model,
    optimizer,
    loss_fn,
    curr_train_loader,
    curr_test_loader=None,
    history_train_loader=None,
    history_test_loader=None,
    device=None,
    num_epochs=10,
    **kwargs,
):
    """Train the detection head, optionally replaying history-task data.

    The phase is selected by ``history_test_loader``:

    * Standalone phase (``history_test_loader is None``): train on
      ``curr_train_loader`` alone. The ``__main__`` block uses this phase to
      train on the *history* task (``tgt_dataset``/``tgt_attack``), so results
      are recorded under ``"{tgt_dataset}-{tgt_attack}"`` and the weights are
      checkpointed in that task's directory. If the task is already recorded
      and its checkpoint exists, the weights are restored instead of retrained.
    * Continual phase: train on ``curr_train_loader`` while replaying one
      ``history_train_loader`` batch per step (if given). After each epoch it
      records accuracy and the history task's forgetting rate under
      ``"{src_dataset}-{src_attack}__from__{tgt_dataset}-{tgt_attack}"``.

    Args:
        model: module whose forward ``model(input_ids, attention_mask)``
            returns a mapping with ``"det_logits"``, and which exposes
            ``model.model.base_model`` (its class name names the checkpoint dir).
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss applied to ``(det_logits, det_label)``.
        curr_train_loader: loader for the task being trained.
        curr_test_loader: test loader for the task being trained.
        history_train_loader: optional replay loader, iterated alongside
            ``curr_train_loader``.
        history_test_loader: history-task test loader; ``None`` selects the
            standalone phase.
        device: device every batch is moved to.
        num_epochs: number of passes over ``curr_train_loader``.
        **kwargs: must contain ``src_dataset``, ``src_attack``, ``tgt_dataset``
            and ``tgt_attack``; may contain ``save_dir`` (defaults to the
            module-level ``save_dir``).

    Returns:
        The model (trained, restored from a checkpoint, or untouched when the
        pair has already been recorded).
    """
    # Prefer an explicit output dir; fall back to the module-level global.
    out_dir = kwargs.get("save_dir") or save_dir
    src_dataset = kwargs["src_dataset"]
    src_attack = kwargs["src_attack"]
    tgt_dataset = kwargs["tgt_dataset"]
    tgt_attack = kwargs["tgt_attack"]

    standalone = history_test_loader is None
    # Standalone result of the history task; the continual phase measures
    # forgetting against this entry's accuracy.
    task_key = f"{tgt_dataset}-{tgt_attack}"
    pair_key = f"{src_dataset}-{src_attack}__from__{tgt_dataset}-{tgt_attack}"
    backbone = model.model.base_model.__class__.__name__
    run_key = task_key if standalone else pair_key
    ckpt_path = f"{out_dir}/{run_key}/{backbone}/state_dict.pt"

    records = _load_records(out_dir)
    if pair_key in records:
        return model  # this pair has already been trained and evaluated
    if (
        standalone
        and task_key in records
        and _restore_checkpoint(model, ckpt_path, device)
    ):
        # Reuse the cached standalone detector instead of retraining it.
        return model

    use_amp = (
        torch.cuda.is_available()
        and device is not None
        and torch.device(device).type == "cuda"
    )
    # fp16 autocast needs loss scaling so small gradients do not underflow.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    history_iter = (
        _endless(history_train_loader) if history_train_loader is not None else None
    )

    if history_iter is not None:
        curr_det_acc, history_det_acc = evaluate(
            model, curr_test_loader, history_test_loader, device
        )
        print(f"Initial Current Task Detection Accuracy: {curr_det_acc:.4f}")
        print(f"Initial History Task Detection Accuracy: {history_det_acc:.4f}")
        print(
            f"Initial Average Accuracy (ACC): {np.mean([curr_det_acc, history_det_acc]):.4f}"
        )
        print("Initial Forgetting Rate (F): N.A.")

    max_det_acc = 0
    for epoch in range(num_epochs):
        # Evaluation puts the model in eval mode; make sure training mode
        # (dropout etc.) is active for this epoch.
        model.train()
        for batch in tqdm.tqdm(curr_train_loader, desc=f"Epoch {epoch + 1} Training"):
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(dtype=torch.float16, enabled=use_amp):
                loss = _det_loss(model, loss_fn, batch, device)
                if history_iter is not None:
                    # Experience replay: add the loss on one history batch.
                    loss = loss + _det_loss(
                        model, loss_fn, next(history_iter), device
                    )
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        if standalone:
            curr_det_acc, history_det_acc = evaluate(
                model, curr_test_loader, None, device
            )
            records = _load_records(out_dir)
            records[task_key] = {
                "curr_det_acc": curr_det_acc,
                "history_det_acc": history_det_acc,
                "F": 0,
                "ACC": 0,
            }
            _dump_records(out_dir, records)
            _save_checkpoint(model, ckpt_path)
            max_det_acc = max(max_det_acc, curr_det_acc)
            print(
                f"Epoch {epoch + 1} Current Task Detection Accuracy: {max_det_acc:.4f}"
            )
            print(
                f"Epoch {epoch + 1} History Task Detection Accuracy: {history_det_acc:.4f}"
            )
        else:
            curr_det_acc, history_det_acc = evaluate(
                model, curr_test_loader, history_test_loader, device
            )
            records = _load_records(out_dir)
            acc = float(np.mean([curr_det_acc, history_det_acc]))
            # Forgetting = history accuracy before continual training minus
            # after; undefined when no standalone baseline was recorded.
            baseline = records.get(task_key, {}).get("curr_det_acc")
            forgetting = None if baseline is None else float(baseline - history_det_acc)
            records[pair_key] = {
                "curr_det_acc": curr_det_acc,
                "history_det_acc": history_det_acc,
                "F": forgetting,
                "ACC": acc,
            }
            print(f"Average Accuracy (ACC): {acc:.4f}")
            if forgetting is None:
                print("Forgetting Rate (F): N.A.")
            else:
                print(f"Forgetting Rate (F): {forgetting:.4f}")
            _dump_records(out_dir, records)
            _save_checkpoint(model, ckpt_path)

    return model


def _det_loss(model, loss_fn, batch, device):
    """Return ``loss_fn`` applied to the model's detection logits for one batch."""
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    det_label = batch["det_label"].to(device)
    det_logits = model(input_ids, attention_mask)["det_logits"]
    return loss_fn(det_logits, det_label)


def _endless(loader):
    """Yield batches from ``loader`` indefinitely, starting a fresh pass each time.

    Unlike ``itertools.cycle`` this does not cache batches, so a shuffling
    loader is reshuffled on every pass. Stops if a pass yields nothing.
    """
    while True:
        empty = True
        for batch in loader:
            empty = False
            yield batch
        if empty:
            return


def _load_records(out_dir):
    """Return the results cache ``{out_dir}/records.json`` (empty if missing)."""
    path = f"{out_dir}/records.json"
    if not os.path.exists(path):
        return {}
    with open(path, "r") as fin:
        return json.load(fin)


def _dump_records(out_dir, records):
    """Write ``records`` to ``{out_dir}/records.json`` with keys sorted."""
    os.makedirs(out_dir, exist_ok=True)
    with open(f"{out_dir}/records.json", "w") as fout:
        json.dump(dict(sorted(records.items(), key=lambda item: item[0])), fout, indent=4)


def _save_checkpoint(model, path):
    """Save ``model.state_dict()`` to ``path``, creating parent directories."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)


def _restore_checkpoint(model, path, device):
    """Load shape-compatible weights from ``path`` into ``model``.

    Returns False (leaving ``model`` untouched) when no checkpoint exists.
    """
    if not os.path.exists(path):
        return False
    saved = torch.load(path, map_location=device)
    current = model.state_dict()
    # The model is built with a task-dependent ``num_labels``, so some tensors
    # may not match the run that wrote the checkpoint; skip those.
    compatible = {
        name: tensor
        for name, tensor in saved.items()
        if name in current and current[name].shape == tensor.shape
    }
    model.load_state_dict(compatible, strict=False)
    return True


# Function to evaluate the model
def evaluate(model, curr_test_loader, history_test_loader=None, device="cuda"):
    """Measure detection accuracy on the current and, optionally, history task.

    Args:
        model: module whose forward ``model(input_ids, attention_mask)``
            returns a mapping with a ``"det_logits"`` entry.
        curr_test_loader: loader of current-task batches with ``input_ids``,
            ``attention_mask`` and ``det_label`` entries.
        history_test_loader: optional loader with the same batch layout for
            the history task.
        device: device every batch is moved to.

    Returns:
        ``(curr_det_acc, history_det_acc)``: fractions of samples whose argmax
        detection logit equals ``det_label``. ``history_det_acc`` is 0 when no
        history loader is given.

    The model's train/eval mode is restored before returning, so training can
    resume directly after an evaluation.
    """
    was_training = model.training
    model.eval()
    try:
        curr_det_acc = _detection_accuracy(
            model, curr_test_loader, device, "Current Task Testing"
        )
        if history_test_loader is None:
            history_det_acc = 0
        else:
            history_det_acc = _detection_accuracy(
                model, history_test_loader, device, "History Task Testing"
            )
    finally:
        model.train(was_training)
    return curr_det_acc, history_det_acc


def _detection_accuracy(model, loader, device, desc):
    """Return the fraction of samples in ``loader`` whose argmax det logit matches ``det_label``."""
    total, correct = 0, 0
    with torch.no_grad():
        for batch in tqdm.tqdm(loader, desc=desc):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            det_label = batch["det_label"].to(device)
            det_logits = model(input_ids, attention_mask)["det_logits"]
            total += det_label.size(0)
            correct += (det_logits.argmax(dim=1) == det_label).sum().item()
    return correct / total


# Main function
if __name__ == "__main__":
    # Experiment configuration.
    max_memory_size = 1  # replay buffer: history_train_dataset[:max_memory_size]
    task_name = "defense"
    base_model_name = "bert-base-uncased"
    batch_size = 32
    learning_rate = 2e-5
    num_epochs = 5
    max_length = 128
    num_labels = {"sst2": 2, "agnews": 4, "amazon": 2, "yahoo": 10}
    device = autocuda.auto_cuda()

    datasets = ["sst2", "agnews", "amazon", "yahoo"]
    attacks = ["bae", "pwws", "textfooler"]

    # ``save_dir`` (and ``performance`` below) stay module-level names because
    # ``train`` may look them up as globals.
    save_dir = f"detectors{max_memory_size}"
    os.makedirs(save_dir, exist_ok=True)
    records_path = f"{save_dir}/records.json"
    if not os.path.exists(records_path):
        with open(records_path, "w") as fout:
            json.dump({}, fout)

    # Loaded once: the tokenizer and loss do not depend on the task pair.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    loss_fn = torch.nn.CrossEntropyLoss()

    def build_dataset(dataset_name, attack, split):
        """Tokenized ``split`` of ``dataset_name`` under ``attack``."""
        return Dataset(
            tokenizer,
            task_name,
            dataset_name,
            split,
            attack=attack,
            max_length=max_length,
        )

    def build_loader(dataset, shuffle):
        """DataLoader over ``dataset`` with the configured batch size."""
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle
        )

    for curr_dataset in datasets:
        for curr_attack in attacks:
            for history_dataset in datasets:
                for history_attack in attacks:
                    pair_key = f"{curr_dataset}-{curr_attack}__from__{history_dataset}-{history_attack}"
                    with open(records_path, "r") as fin:
                        records = json.load(fin)

                    if (curr_dataset, curr_attack) == (history_dataset, history_attack):
                        # A task is not its own history: persist a zero placeholder.
                        records[pair_key] = {
                            "curr_det_acc": 0,
                            "history_det_acc": 0,
                            "F": 0,
                            "ACC": 0,
                        }
                        with open(records_path, "w") as fout:
                            json.dump(
                                dict(sorted(records.items(), key=lambda item: item[0])),
                                fout,
                                indent=4,
                            )
                        continue

                    print(pair_key)
                    if pair_key in records:
                        # train() would return immediately; skip the setup cost.
                        continue

                    model = DetModel(
                        model_name=base_model_name, num_labels=num_labels[curr_dataset]
                    ).to(device)
                    backbone = model.model.base_model.__class__.__name__

                    curr_train_set = build_dataset(curr_dataset, curr_attack, "train")
                    curr_test_set = build_dataset(curr_dataset, curr_attack, "test")
                    curr_train_loader = build_loader(curr_train_set, shuffle=True)
                    curr_test_loader = build_loader(curr_test_set, shuffle=False)

                    history_train_dataset = build_dataset(
                        history_dataset, history_attack, "train"
                    )
                    history_test_dataset = build_dataset(
                        history_dataset, history_attack, "test"
                    )
                    history_train_loader = build_loader(
                        history_train_dataset, shuffle=True
                    )
                    history_test_loader = build_loader(
                        history_test_dataset, shuffle=False
                    )

                    kwargs = {
                        "src_dataset": curr_dataset,
                        "src_attack": curr_attack,
                        "tgt_dataset": history_dataset,
                        "tgt_attack": history_attack,
                    }

                    tokenizer.save_pretrained(f"{save_dir}/{pair_key}/{backbone}")
                    tokenizer.save_pretrained(
                        f"{save_dir}/{curr_dataset}-{curr_attack}/{backbone}"
                    )

                    # Phase 1: standalone detector on the history task.
                    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
                    model = train(
                        model,
                        optimizer,
                        loss_fn,
                        history_train_loader,
                        curr_test_loader=history_test_loader,
                        device=device,
                        num_epochs=num_epochs,
                        **kwargs,
                    )

                    # Keep only a small replay buffer of history samples.
                    history_train_dataset = history_train_dataset[:max_memory_size]
                    history_train_loader = (
                        build_loader(history_train_dataset, shuffle=True)
                        if len(history_train_dataset)
                        else None
                    )
                    performance = []

                    # Phase 2: current task with replay of the history buffer.
                    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
                    model = train(
                        model,
                        optimizer,
                        loss_fn,
                        curr_train_loader,
                        curr_test_loader,
                        history_train_loader,
                        history_test_loader,
                        device,
                        num_epochs,
                        **kwargs,
                    )

                    del (
                        model,
                        optimizer,
                        curr_train_set,
                        curr_test_set,
                        curr_train_loader,
                        curr_test_loader,
                        history_train_dataset,
                        history_test_dataset,
                        history_train_loader,
                        history_test_loader,
                    )
