from datasets import load_dataset
import transformers
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, get_scheduler
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm.auto import tqdm
import evaluate
import wandb
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import random
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import logging
import itertools

from utils import *

# Command-line configuration for this fine-tuning run.
parser = argparse.ArgumentParser(description="Fine-tune bert-base-uncased on GLUE SST-2.")
parser.add_argument(
    "--log_name",
    type=str,
    help="run name; used for the log file name logs/<log_name>.log and the wandb run name",
)
# Without a seed the script used to crash later inside torch.manual_seed(None) with a
# TypeError; requiring it turns that into an immediate, readable usage error.
parser.add_argument("--seed", type=int, required=True, help="random seed")
# NOTE(review): the options below are not read elsewhere in the visible script; they
# reach wandb only through `config=args`. Confirm whether something else consumes them.
parser.add_argument("--s", type=float)
parser.add_argument("--act_var_tolerance", type=float)
parser.add_argument("--weight_var_tolerance", type=float)
parser.add_argument("--s_update_step", type=float)
parser.add_argument("--weight_ratio_multiplier", type=float)
parser.add_argument("--wandb", action="store_true", help="log metrics to Weights & Biases")
args = parser.parse_args()

LOG_FORMAT = "[%(asctime)s] %(message)s"
# basicConfig's FileHandler does not create missing parent directories, so a fresh
# checkout without a logs/ folder would fail with FileNotFoundError; create it first.
Path("logs").mkdir(parents=True, exist_ok=True)
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, filename=f"logs/{args.log_name}.log")
log = logging.getLogger(__name__)

# Weights & Biases destination (placeholder values); only used when --wandb is passed.
PROJECT, ENTITY = "your_project", "your_entity"

if args.wandb:
    # The run is named after --log_name and records every CLI option as its config.
    wandb.init(
        entity=ENTITY,
        project=PROJECT,
        name=args.log_name,
        config=args,
    )

# Training hyperparameters.
NUM_EPOCHS = 3
LR = 5e-5

SEED = args.seed
# Seed every RNG the run may touch, not just torch, so runs are reproducible even if
# Python's `random` or NumPy are used indirectly (same set that setup_seed covers).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Run a validation pass and report metrics every CAL_VAR_FREQ training steps.
CAL_VAR_FREQ = 50

log.info("loading dataset")
dataset = load_dataset("glue", "sst2")

checkpoint = "bert-base-uncased"

log.info("loading tokenizer")
tokenizer = BertTokenizerFast.from_pretrained(checkpoint)


def tokenize(example):
    """Tokenize a batch of rows' "sentence" field, padded/truncated to 128 tokens."""
    return tokenizer(
        example["sentence"],
        padding="max_length",
        truncation=True,
        max_length=128,
    )


# Tokenize, then keep only model inputs plus the target, renamed to the "labels"
# key the model's forward() expects.
tokenized_dataset = (
    dataset.map(tokenize, batched=True)
    .remove_columns(["sentence", "idx"])
    .rename_column("label", "labels")
)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
tokenized_dataset.set_format("torch")  # in-place; returns None, so it cannot be chained

train_dataloader, valid_dataloader = (
    DataLoader(tokenized_dataset["train"], shuffle=True, batch_size=32, collate_fn=data_collator),
    DataLoader(tokenized_dataset["validation"], batch_size=32, collate_fn=data_collator),
)

log.info("loading model")
model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

optimizer = Adam(model.parameters(), lr=LR)

# One scheduler step per optimizer step: a "linear" schedule over the whole run, no warmup.
num_per_epoch = len(train_dataloader)
num_training_steps = num_per_epoch * NUM_EPOCHS
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=0,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Timing events that bracket the training loop.
# NOTE(review): these are CUDA events, so timing presumably needs a GPU even though
# `device` can fall back to CPU.
start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))

def setup_seed(seed):
    """Seed Python's `random`, NumPy, and PyTorch (CPU, current CUDA device, all CUDA devices).

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

model.train()
log.info(f"using device:{device}, start training")
progress_bar = tqdm(range(num_training_steps))
start.record()

# NOTE(review): not referenced anywhere else in the visible script.
ratio_N = 0
ratio_avg = 0
for epoch in range(NUM_EPOCHS):
    # Sum of per-step losses in the current CAL_VAR_FREQ-step window; reset after each report.
    total_loss = 0
    for step, batch in enumerate(train_dataloader, start=1):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        # .mean() leaves a scalar loss unchanged; presumably kept so a per-device loss
        # vector would also reduce — NOTE(review): confirm.
        loss = outputs.loss.mean()
        total_loss += loss.item()
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

        if step % CAL_VAR_FREQ == 0:
            # `eval` and `update_summary` are presumably provided by `utils` (star import);
            # this `eval` shadows the Python builtin of the same name.
            eval_loss, eval_acc = eval(model, valid_dataloader, device)
            # The evaluation helper may switch the model to eval mode (disabling dropout).
            # Restore training mode before the next step; this is a no-op if it is already training.
            model.train()
            avg_train_loss = total_loss / CAL_VAR_FREQ
            train_metric = {"loss": avg_train_loss}
            eval_metric = {"loss": eval_loss, "acc": eval_acc}
            if args.wandb:
                # Global step across epochs.
                update_summary(step + epoch * num_per_epoch, train_metric, eval_metric)
            log.info(f"epoch: {epoch}, step: {step}, train loss: {avg_train_loss}, eval loss: {eval_loss}, eval acc: {eval_acc}\n")
            total_loss = 0
progress_bar.close()
# Record the end marker first, then wait for the GPU to reach it: elapsed_time()
# requires both events to have completed.
end.record()
torch.cuda.synchronize()
log.info(f"\ntraining finished, time elapsed: {start.elapsed_time(end)/1000/60:.4f} mins\n")


# Final pass over the validation split: collect argmax predictions and gold labels,
# then log a confusion matrix and a per-class report.
model.eval()
predictions, labels = [], []
with torch.no_grad():
    for batch in valid_dataloader:
        batch = {key: tensor.to(device) for key, tensor in batch.items()}
        logits = model(**batch).logits
        predictions += logits.argmax(dim=-1).tolist()
        labels += batch["labels"].tolist()

log.info("confusion matrix:")
log.info(confusion_matrix(labels, predictions))
log.info("classification report:")
log.info(classification_report(labels, predictions, digits=4))

