
import os

import numpy as np
import torch
import tqdm
from transformers import AutoTokenizer

from load_text_dataset import CLSDataset
from model import CLSModel


# Function to train the model
def train(
    model,
    optimizer,
    loss_fn,
    train_loader,
    test_loader=None,
    device=None,
    num_epochs=10,
):

    max_cls_acc, max_cls_acc = 0, 0
    for epoch in range(num_epochs):
        for batch in tqdm.tqdm(train_loader, desc=f"Epoch {epoch + 1} Training"):
            optimizer.zero_grad()

            # Current task data
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)

            with torch.cuda.amp.autocast(dtype=torch.float16):
                logits = model(input_ids, attention_mask)["logits"]
                loss = loss_fn(logits, label)

            loss.backward()
            optimizer.step()

        # Evaluation
        accuracy = evaluate(model, test_loader, device)
        if accuracy > max_cls_acc:
            with open(
                f"classifiers/{dataset_name}/"
                + model.model.base_model.__class__.__name__
                + "/state_dict.pt",
                "wb",
            ) as fout:
                torch.save(model.state_dict(), fout)
        max_cls_acc = max(max_cls_acc, accuracy)

        print(f"Epoch {epoch + 1} Classification Accuracy: {accuracy:.4f}")
        print(f"Max Classification Accuracy: {max_cls_acc:.4f}")

    return model


# Function to evaluate the model
def evaluate(model, test_loader, device="cuda"):
    model.eval()

    # Current task evaluation
    total, cls_correct = 0, 0
    with torch.no_grad():
        for batch in tqdm.tqdm(test_loader, desc="Current Task Testing"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)
            logits = model(input_ids, attention_mask)
            total += label.size(0)
            cls_correct += (logits.argmax(dim=1) == label).sum().item()
    curr_cls_acc = cls_correct / total

    return curr_cls_acc


# Main function
if __name__ == "__main__":
    # Parameters
    dataset_name = "yahoo"

    task_name = "classification"
    # base_model_name = 'bert-base-uncased'
    base_model_name = "roberta-base"
    batch_size = 32
    learning_rate = 2e-5
    num_epochs = 10
    num_labels = {"sst2": 2, "agnews": 4, "amazon": 2, "yahoo": 10}
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model and tokenizer
    model = CLSModel(
        model_name=base_model_name, num_labels=num_labels[dataset_name]
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.save_pretrained(
        f"classifiers/{dataset_name}/" + model.model.base_model.__class__.__name__
    )
    # Datasets and loaders
    curr_train_set = CLSDataset(
        tokenizer, task_name, dataset_name, "train", max_length=128
    )
    curr_test_set = CLSDataset(
        tokenizer, task_name, dataset_name, "test", max_length=128
    )
    curr_train_loader = torch.utils.data.DataLoader(
        curr_train_set, batch_size=batch_size, shuffle=True
    )
    curr_test_loader = torch.utils.data.DataLoader(
        curr_test_set, batch_size=batch_size, shuffle=False
    )
    # Loss function
    loss_fn = torch.nn.CrossEntropyLoss()

    max_cls_acc = 0
    performance = []
    # Current Model Training
    model = train(
        model,
        optimizer,
        loss_fn,
        curr_train_loader,
        curr_test_loader,
        device,
        num_epochs,
    )
