from datasets import load_dataset
import transformers
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, get_scheduler
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm.auto import tqdm
import evaluate
import wandb
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import random
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import logging
import itertools
from transformers.models.bert import linear_bert

from utils import *

# ---------------------------------------------------------------------------
# Command-line configuration. Every hyper-parameter of linear_bert's adaptive
# sampling is passed in from the command line and copied onto the module below.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--log_name", type=str)
# torch.manual_seed(None) raises, so a seed must always be supplied.
parser.add_argument("--seed", type=int, required=True)
parser.add_argument("--s", type=float)
parser.add_argument("--act_var_tolerance", type=float)
parser.add_argument("--weight_var_tolerance", type=float)
parser.add_argument("--s_update_step", type=float)
parser.add_argument("--weight_ratio_multiplier", type=float)
parser.add_argument("--wandb", action="store_true")
args = parser.parse_args()

LOG_FORMAT = "[%(asctime)s] %(message)s"
# basicConfig(filename=...) opens a FileHandler, which does not create missing
# parent directories; make sure logs/ exists first.
Path("logs").mkdir(parents=True, exist_ok=True)
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, filename=f"logs/{args.log_name}.log")
log = logging.getLogger(__name__)

PROJECT = "your_project"
ENTITY = "your_entity"

if args.wandb:
    wandb.init(name=args.log_name, project=PROJECT, config=args, entity=ENTITY)

NUM_EPOCHS = 3
LR = 5e-5

SEED = args.seed
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Every CAL_VAR_FREQ training steps the gradient variances are re-estimated
# from CAL_VAR_M batches (and CAL_VAR_M sampled repeats per batch).
CAL_VAR_FREQ = 300
CAL_VAR_M = 3
linear_bert.CAL_VAR_M = CAL_VAR_M

linear_bert.S = args.s
linear_bert.ACT_VAR_TOLERANCE = args.act_var_tolerance
linear_bert.WEIGHT_VAR_TOLERANCE = args.weight_var_tolerance
linear_bert.S_UPDATE_STEP = args.s_update_step
linear_bert.WEIGHT_RATIO_MULTIPLIER = args.weight_ratio_multiplier
log.info(f"S: {linear_bert.S}, ACT_VAR_TOLERANCE: {linear_bert.ACT_VAR_TOLERANCE}, WEIGHT_VAR_TOLERANCE: {linear_bert.WEIGHT_VAR_TOLERANCE}, S_UPDATE_STEP: {linear_bert.S_UPDATE_STEP}, WEIGHT_RATIO_MULTIPLIER: {linear_bert.WEIGHT_RATIO_MULTIPLIER}")

# Hugging Face checkpoint shared by the tokenizer here and the model loaded later.
checkpoint = "bert-base-uncased"

# GLUE MNLI: train plus the matched / mismatched validation splits used below.
log.info("loading dataset")
dataset = load_dataset("glue", "mnli")

log.info("loading tokenizer")
tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
def tokenize(example):
    """Tokenize MNLI rows as (premise, hypothesis) sentence pairs.

    Used with ``dataset.map(tokenize, batched=True)``, so each field is a list
    of strings. Pairs are truncated to 128 tokens; no padding is applied here
    (the DataCollatorWithPadding handles that per batch).
    """
    premises, hypotheses = example["premise"], example["hypothesis"]
    return tokenizer(premises, hypotheses, truncation=True, max_length=128)
# Tokenize every split, drop the raw-text/id columns and rename the target to
# "labels" so the model's forward pass returns `outputs.loss`.
tokenized_dataset = (
    dataset.map(tokenize, batched=True)
    .remove_columns(["premise", "hypothesis", "idx"])
    .rename_column("label", "labels")
)
tokenized_dataset.set_format("torch")

# tokenize() truncates but does not pad, so batches are padded by the collator.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Only the training split is shuffled.
train_dataloader, matched_dataloader, mismatched_dataloader = (
    DataLoader(tokenized_dataset[split], shuffle=shuffle, batch_size=32, collate_fn=data_collator)
    for split, shuffle in (
        ("train", True),
        ("validation_matched", False),
        ("validation_mismatched", False),
    )
)

log.info("loading model")
model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=3)

optimizer = Adam(model.parameters(), lr=LR)

# Linear schedule without warm-up over one scheduler step per training batch.
num_per_epoch = len(train_dataloader)
num_training_steps = num_per_epoch * NUM_EPOCHS
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# A pair of timing-enabled CUDA events used to measure total training time.
# NOTE(review): `device` above may fall back to CPU, but these are CUDA events,
# so presumably the script only supports GPU runs — confirm.
start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))

def setup_seed(seed):
    """Seed python's `random`, NumPy, torch CPU and torch CUDA generators.

    Args:
        seed: integer seed applied to every generator (CUDA: current device and
            all devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def _flat_grad(model):
    """Return all parameter gradients concatenated into one flat tensor.

    Assumes every parameter received a gradient in the preceding backward pass
    (``p.grad`` is not None) — NOTE(review): would fail for frozen/unused params.
    """
    # torch.cat always allocates a new tensor, so no extra .clone() is needed.
    return torch.cat([p.grad.flatten() for p in model.parameters()])


def _estimate_gradient_variances(model, optimizer, dataloader, device, num_samples):
    """Estimate SGD gradient variance and the extra variance from linear_bert sampling.

    Draws ``num_samples`` batches from ``dataloader``. For each batch it runs one
    backward pass with ``linear_bert.sample = False`` (the reference gradient)
    and ``num_samples`` passes with ``linear_bert.sample = True``. It calls
    ``setup_seed(0)`` before every pass.

    Returns:
        (sgd_var, act_var) as scalar tensors:
        * sgd_var: sum over coordinates of the variance of the reference
          gradient across the drawn batches;
        * act_var: squared L2 distance between sampled and reference gradients,
          averaged over the repeats and then over the batches.

    Side effects: reseeds random/numpy/torch via setup_seed (the caller must
    restore RNG state), leaves ``linear_bert.sample`` set to True, and leaves
    parameter gradients zeroed.
    """
    grad_sum = None
    grad_sq_sum = None
    act_var = 0
    for cal_batch in itertools.islice(dataloader, num_samples):
        cal_batch = {k: v.to(device) for k, v in cal_batch.items()}

        # Reference gradient with sampling disabled.
        linear_bert.sample = False
        setup_seed(0)
        model(**cal_batch).loss.mean().backward()
        grad = _flat_grad(model)
        if grad_sum is None:
            grad_sum = grad
            grad_sq_sum = grad ** 2
        else:
            grad_sum += grad
            grad_sq_sum += grad ** 2
        optimizer.zero_grad()

        # Gradients with sampling enabled, compared against the reference.
        # NOTE(review): reseeding with 0 before *every* repeat means that if
        # linear_bert's sampling draws from torch's global RNG, all repeats are
        # identical and averaging over them adds nothing — confirm intent.
        linear_bert.sample = True
        act_var_tmp = 0
        for _ in range(num_samples):
            setup_seed(0)
            model(**cal_batch).loss.mean().backward()
            act_var_tmp += torch.sum((_flat_grad(model) - grad) ** 2)
            optimizer.zero_grad()
        act_var_tmp /= num_samples
        act_var += act_var_tmp

    # In-place division keeps peak memory at the size of the two accumulators.
    grad_mean = grad_sum.div_(num_samples)
    grad_sq_mean = grad_sq_sum.div_(num_samples)
    sgd_var = torch.sum(grad_sq_mean - grad_mean ** 2)
    act_var /= num_samples
    return sgd_var, act_var


def _overall_ratio(activation_ratio_schedule, weight_ratio_dict):
    """Combine linear_bert's per-layer ratios into one scalar for logging.

    Averages three terms:
    * 1;
    * the mean activation ratio;
    * the mean over layers of (mean weight ratio of the layer's 6 entries)
      times that layer's activation ratio.

    Assumes 12 layers, with weight_ratio_dict keys 6*j .. 6*j+5 belonging to
    layer j (hard-coded, presumably for bert-base) — TODO confirm against
    linear_bert.
    """
    num_layers, weights_per_layer = 12, 6
    act_mean = sum(activation_ratio_schedule) / len(activation_ratio_schedule)
    layer_terms = [
        sum(weight_ratio_dict[k] for k in range(weights_per_layer * j, weights_per_layer * j + weights_per_layer))
        / weights_per_layer
        * activation_ratio_schedule[j]
        for j in range(num_layers)
    ]
    return (1 + act_mean + sum(layer_terms) / num_layers) / 3


def _snapshot_rng_state():
    """Capture python, NumPy, torch CPU and torch CUDA generator states."""
    return (
        random.getstate(),
        np.random.get_state(),
        torch.get_rng_state(),
        torch.cuda.get_rng_state_all(),
    )


def _restore_rng_state(state):
    """Restore generator states captured by _snapshot_rng_state()."""
    py_state, np_state, cpu_state, cuda_states = state
    random.setstate(py_state)
    np.random.set_state(np_state)
    torch.set_rng_state(cpu_state)
    torch.cuda.set_rng_state_all(cuda_states)


model.train()
log.info(f"using device:{device}, start training")
progress_bar = tqdm(range(num_training_steps))
start.record()

ratio_N = 0
ratio_avg = 0
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    i = 0
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss.mean()
        total_loss += loss.item()
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        i += 1
        progress_bar.update(1)

        if i % CAL_VAR_FREQ != 0:
            continue

        # --- periodic variance calibration -----------------------------------
        linear_bert.test = True
        # The calibration calls setup_seed(0), which reseeds *every* generator
        # (python, numpy, torch CPU and CUDA). Snapshot them all so training
        # randomness (e.g. GPU dropout) resumes where it left off instead of
        # restarting from seed 0 after every calibration.
        rng_state = _snapshot_rng_state()
        sgd_var, act_var = _estimate_gradient_variances(
            model, optimizer, train_dataloader, device, CAL_VAR_M
        )
        log.info(f"sgd_var: {sgd_var}, act_var: {act_var}")
        linear_bert.test = False

        linear_bert.update_activation_ratio(sgd_var, act_var)
        linear_bert.update_weight_ratio()

        _restore_rng_state(rng_state)

        # --- periodic evaluation ---------------------------------------------
        matched_loss, matched_acc = eval(model, matched_dataloader, device)
        mismatched_loss, mismatched_acc = eval(model, mismatched_dataloader, device)
        # NOTE(review): utils.eval is not visible here. If it switches the model
        # to eval mode, training would silently continue without dropout.
        # Re-entering train mode is a no-op otherwise.
        model.train()

        S = linear_bert.S
        activation_ratio_schedule = linear_bert.activation_ratio_schedule
        weight_ratio_dict = linear_bert.weight_ratio_dict
        ratio = _overall_ratio(activation_ratio_schedule, weight_ratio_dict)
        # Running mean of `ratio` over all calibrations so far.
        ratio_avg = (ratio_avg * ratio_N + ratio) / (ratio_N + 1)
        log.info(f"ratio: {ratio}, ratio_avg: {ratio_avg}")
        ratio_N += 1

        train_metric = {
            "loss": total_loss / CAL_VAR_FREQ,
            "sgd_var": sgd_var,
            "act_var": act_var,
            "S": S,
            "activation_ratio_first": activation_ratio_schedule[0],
            "activation_ratio_last": activation_ratio_schedule[-1],
            **{f"weight_ratio[{k}]": weight_ratio_dict[k] for k in range(6)},
            "ratio": ratio,
            "ratio_avg": ratio_avg,
        }
        eval_metric = {
            "matched_loss": matched_loss,
            "matched_acc": matched_acc,
            "mismatched_loss": mismatched_loss,
            "mismatched_acc": mismatched_acc,
        }
        if args.wandb:
            update_summary(i + epoch * num_per_epoch, train_metric, eval_metric)
        log.info(f"epoch: {epoch}, step: {i}, train loss: {total_loss / CAL_VAR_FREQ}, matched loss: {matched_loss}, matched acc: {matched_acc}, mismatched loss: {mismatched_loss}, mismatched acc: {mismatched_acc}\n")
        total_loss = 0

progress_bar.close()
# Record the end event first, then wait for it: elapsed_time() requires both
# events to have completed on the device.
end.record()
torch.cuda.synchronize()
log.info(f"\ntraining finished, time elapsed: {start.elapsed_time(end)/1000/60:.4f} mins\n")


def _collect_predictions(dataloader):
    """Run the model over `dataloader` without gradients.

    Returns:
        (labels, predictions): python lists of gold labels and argmax-logit
        predictions, in dataloader order.
    """
    predictions = []
    labels = []
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(**batch).logits
            predictions.extend(torch.argmax(logits, dim=-1).tolist())
            labels.extend(batch["labels"].tolist())
    return labels, predictions


model.eval()

# MNLI matched validation: confusion matrix plus per-class report.
labels, predictions = _collect_predictions(matched_dataloader)
log.info("mnli-m confusion matrix:")
log.info(confusion_matrix(labels, predictions))
log.info("mnli-m classification report:")
log.info(classification_report(labels, predictions, digits=4))

# MNLI mismatched validation: per-class report only.
labels, predictions = _collect_predictions(mismatched_dataloader)
log.info("mnli-mm classification report:")
log.info(classification_report(labels, predictions, digits=4))

