from datasets import load_dataset
import transformers
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, get_scheduler, BertConfig, default_data_collator
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm.auto import tqdm
import evaluate
import wandb
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import random
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import logging
import itertools
from transformers.models.bert import linear_bert

from utils import *

# Column names holding the text input(s) of each GLUE task, as
# (first_text_column, second_text_column). The second entry is None for
# tasks that have a single text column.
_SINGLE_SENTENCE = ("sentence", None)
_SENTENCE_PAIR = ("sentence1", "sentence2")

task_to_keys = {
    "cola": _SINGLE_SENTENCE,
    "mnli": ("premise", "hypothesis"),
    "mrpc": _SENTENCE_PAIR,
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": _SENTENCE_PAIR,
    "sst2": _SINGLE_SENTENCE,
    "stsb": _SENTENCE_PAIR,
    "wnli": _SENTENCE_PAIR,
}

# Command-line interface.
# --task_name, --seed, --cal_var_freq and --eval_freq are marked required:
# the script indexes task_to_keys with the task name, seeds torch with the
# seed, and takes `completed_steps % <freq>` on every step, so omitting any of
# them previously failed later with an opaque KeyError/TypeError.
parser = argparse.ArgumentParser()
parser.add_argument("--task_name", type=str, choices=list(task_to_keys.keys()), required=True,
                    help="GLUE task to fine-tune on")
parser.add_argument("--batch_size", type=int, default=32,
                    help="batch size of the training and evaluation dataloaders")
parser.add_argument("--max_length", type=int, default=128,
                    help="maximum tokenized length; longer inputs are truncated")
parser.add_argument("--log_name", type=str,
                    help="run name: log file is logs/<log_name>.log, also used as the wandb run name")
parser.add_argument("--seed", type=int, required=True,
                    help="seed passed to torch.manual_seed / torch.cuda.manual_seed")
parser.add_argument("--cal_var_freq", type=int, required=True,
                    help="run the variance pass every this many training steps")
parser.add_argument("--eval_freq", type=int, required=True,
                    help="evaluate on the validation split every this many training steps")
parser.add_argument("--cal_var_m", type=int,
                    help="number of batches, and of sampled repeats per batch, in the variance pass")
parser.add_argument("--s", type=float, help="forwarded to linear_bert.S")
parser.add_argument("--act_var_tolerance", type=float, help="forwarded to linear_bert.ACT_VAR_TOLERANCE")
parser.add_argument("--weight_var_tolerance", type=float, help="forwarded to linear_bert.WEIGHT_VAR_TOLERANCE")
parser.add_argument("--s_update_step", type=float, help="forwarded to linear_bert.S_UPDATE_STEP")
parser.add_argument("--weight_ratio_multiplier", type=float, help="forwarded to linear_bert.WEIGHT_RATIO_MULTIPLIER")
parser.add_argument("--wandb", action="store_true", help="log metrics to Weights & Biases")
parser.add_argument("--wandb_group", type=str, help="group label recorded in the run config")
args = parser.parse_args()

# ---- Logging -------------------------------------------------------------
LOG_FORMAT = "[%(asctime)s] %(message)s"
log_file = f"logs/{args.log_name}.log"
# The file handler opened by basicConfig does not create missing parent
# directories, so a fresh checkout without logs/ used to fail right here.
Path(log_file).parent.mkdir(parents=True, exist_ok=True)
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, filename=log_file)
# log = logging.getLogger(__name__)

# ---- Weights & Biases (placeholders; replace with your own project/entity) --
PROJECT = "your_project"
ENTITY = "your_entity"

if args.wandb:
    # The full argparse namespace is stored as the run configuration.
    wandb.init(name=args.log_name, project=PROJECT, config=args, entity=ENTITY)

# ---- Fixed fine-tuning hyper-parameters -----------------------------------
NUM_EPOCHS = 3
LR = 2e-5

# ---- Seeding ---------------------------------------------------------------
SEED = args.seed
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# ---- linear_bert configuration ---------------------------------------------
# linear_bert is configured by assigning module-level attributes; the values
# come straight from the command line.
CAL_VAR_FREQ = args.cal_var_freq
CAL_VAR_M = args.cal_var_m
linear_bert.CAL_VAR_M = CAL_VAR_M

linear_bert.S = args.s
linear_bert.ACT_VAR_TOLERANCE = args.act_var_tolerance
linear_bert.WEIGHT_VAR_TOLERANCE = args.weight_var_tolerance
linear_bert.S_UPDATE_STEP = args.s_update_step
linear_bert.WEIGHT_RATIO_MULTIPLIER = args.weight_ratio_multiplier
logging.info(f"S: {linear_bert.S}, ACT_VAR_TOLERANCE: {linear_bert.ACT_VAR_TOLERANCE}, WEIGHT_VAR_TOLERANCE: {linear_bert.WEIGHT_VAR_TOLERANCE}, S_UPDATE_STEP: {linear_bert.S_UPDATE_STEP}, WEIGHT_RATIO_MULTIPLIER: {linear_bert.WEIGHT_RATIO_MULTIPLIER}")

logging.info("loading dataset")
raw_datasets = load_dataset("glue", args.task_name)

# Labels: "stsb" is treated as regression (one output); every other task is
# classification, with the label names read from the dataset's label feature.
is_regression = args.task_name == "stsb"
if not is_regression:
    label_list = raw_datasets["train"].features["label"].names
    num_labels = len(label_list)
else:
    num_labels = 1

checkpoint = "bert-base-uncased"

logging.info("loading model")
# NOTE: `config` is loaded separately and is NOT the object the model uses;
# the model builds its own config, available as `model.config`.
config = BertConfig.from_pretrained(checkpoint, num_labels=num_labels)
tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)

logging.info("preprocessing dataset")
sentence1_key, sentence2_key = task_to_keys[args.task_name]

if not is_regression:
    # Store the real label names on the model's config. The inverse map is
    # derived from model.config.label2id (just assigned above) so that the two
    # mappings agree; the stand-alone `config` still holds the default labels.
    model.config.label2id = {label: idx for idx, label in enumerate(label_list)}
    model.config.id2label = {idx: label for label, idx in model.config.label2id.items()}


def preprocess_function(examples):
    """Tokenize a batch of examples for the selected task.

    Reads the text column(s) named by ``sentence1_key`` / ``sentence2_key``,
    tokenizes them with truncation at ``args.max_length``, and copies the
    ``label`` column (when present) to ``labels``.
    """
    text_columns = [examples[sentence1_key]]
    if sentence2_key is not None:
        # Two-column task: the second column is passed as the paired text.
        text_columns.append(examples[sentence2_key])

    encoded = tokenizer(*text_columns, max_length=args.max_length, truncation=True)

    if "label" in examples:
        encoded["labels"] = examples["label"]
    return encoded

# Tokenize every split in batches; the raw columns are dropped so that only
# the tokenizer output (plus "labels") remains.
processed_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
    desc="Running tokenizer on dataset",
)

train_dataset = processed_datasets["train"]
# For MNLI the "validation_matched" split is used for evaluation.
if args.task_name == "mnli":
    eval_dataset = processed_datasets["validation_matched"]
else:
    eval_dataset = processed_datasets["validation"]

# Log a few random samples from the training set:
# for index in random.sample(range(len(train_dataset)), 3):
#     logging.info(f"Sample {index} of the training set: {train_dataset[index]}.")

# Batches are padded by the collator at collation time.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=args.batch_size,
    collate_fn=data_collator,
)

# AdamW with two parameter groups: parameters whose name contains one of the
# `no_decay` substrings get no weight decay, everything else gets WEIGHT_DECAY.
WEIGHT_DECAY = 0.01
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": WEIGHT_DECAY,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=LR)

# Linear schedule over the whole run, warming up for the first 10% of steps.
num_per_epoch = len(train_dataloader)
num_training_steps = NUM_EPOCHS * num_per_epoch
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    # The scheduler API takes an integer step count; 0.1 * n is a float.
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps,
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

def setup_seed(seed):
    """Seed Python's ``random``, NumPy and torch (CPU and CUDA) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same generators, same order as listed above.
    for seed_fn in seeders:
        seed_fn(seed)

progress_bar = tqdm(range(num_training_steps))
completed_steps = 0
starting_epoch = 0

# Timing uses CUDA events: `start` marks the beginning of the segment being
# measured, `end` its end.
# NOTE(review): torch.cuda.Event / torch.cuda.synchronize() need a CUDA device,
# so the CPU fallback chosen for `device` will not get past the timing code —
# confirm whether CPU runs are meant to be supported.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# Accumulated times in milliseconds:
#   pure_train_time_elapsed - ordinary optimisation steps only
#   train_time_elapsed      - optimisation steps plus the variance pass
#   eval_time_elapsed       - evaluation
pure_train_time_elapsed = 0
train_time_elapsed = 0
eval_time_elapsed = 0

# Running mean of `ratio` over the variance passes done so far.
ratio_N = 0
ratio_avg = 0

# Sum of training losses since the last evaluation.
total_loss = 0


def _lap_ms():
    """Return the milliseconds elapsed since the last ``start.record()``."""
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)


logging.info(f"using device:{device}, start training")
start.record()
for epoch in range(NUM_EPOCHS):
    model.train()
    for step, batch in enumerate(train_dataloader):
        # ---- one ordinary optimisation step ----
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss.mean()
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.update(1)
        completed_steps += 1

        # ---- periodic variance pass ----
        if completed_steps % args.cal_var_freq == 0:
            # Everything since the last start.record() was ordinary training.
            pure_train_time_elapsed += _lap_ms()

            linear_bert.test = True
            linear_bert.prepare(device)

            # setup_seed() below reseeds Python, NumPy, torch CPU *and* CUDA,
            # so all four generator states are saved here and restored after
            # the pass. Restoring only the CPU torch state would leave the
            # CUDA generator in a near-identical state after every pass.
            org_state = torch.get_rng_state().clone()
            cuda_rng_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
            py_rng_state = random.getstate()
            np_rng_state = np.random.get_state()

            # Take the first cal_var_m batches of a fresh pass over the
            # training data. For each: one backward pass with
            # linear_bert.sample off, then cal_var_m backward passes with it
            # on. Gradients are discarded (zero_grad without optimizer.step).
            # NOTE(review): every repeat is reseeded with the same value `i`;
            # the repeats only differ if linear_bert draws from a generator
            # that setup_seed does not reset — confirm.
            for i, probe_batch in enumerate(itertools.islice(train_dataloader, args.cal_var_m)):
                probe_batch = {k: v.to(device) for k, v in probe_batch.items()}

                linear_bert.sample = False
                setup_seed(i)
                model(**probe_batch).loss.mean().backward()
                optimizer.zero_grad()

                linear_bert.sample = True
                for _ in range(args.cal_var_m):
                    setup_seed(i)
                    model(**probe_batch).loss.mean().backward()
                    optimizer.zero_grad()
            linear_bert.test = False

            # Let linear_bert turn what it recorded into variances and update
            # its ratios, then clear its buffers.
            linear_bert.cal_var()
            sgd_var, act_var = linear_bert.update_activation_ratio()
            weight_var = linear_bert.update_weight_ratio()
            linear_bert.reset_dict()

            # Put every generator back where it was before the pass.
            torch.set_rng_state(org_state)
            if cuda_rng_states is not None:
                torch.cuda.set_rng_state_all(cuda_rng_states)
            random.setstate(py_rng_state)
            np.random.set_state(np_rng_state)

            S = linear_bert.S
            activation_ratio_schedule = linear_bert.activation_ratio_schedule
            weight_ratio_dict = linear_bert.weight_ratio_dict

            # Scalar summary `ratio` of the current schedules.
            # weight_ratio_dict is read in consecutive groups of 6 entries,
            # one group per activation_ratio_schedule entry (hard-coded 6).
            n_sched = len(activation_ratio_schedule)
            mean_act_ratio = sum(activation_ratio_schedule) / n_sched
            mean_weighted_ratio = sum([
                sum(weight_ratio_dict[i] for i in range(6 * j, 6 * j + 6)) / 6 * activation_ratio_schedule[j]
                for j in range(n_sched)
            ]) / n_sched
            # First term: schedule-dependent part. Last two terms: scale with
            # the variance-pass settings (cal_var_m, cal_var_freq).
            ratio = (
                (1 + mean_act_ratio + mean_weighted_ratio) / 3
                + args.cal_var_m / args.cal_var_freq
                + args.cal_var_m ** 2 * (1 + 2 * mean_act_ratio) / 3 / args.cal_var_freq
            )
            ratio_avg = (ratio_avg * ratio_N + ratio) / (ratio_N + 1)
            logging.info(f"ratio: {ratio}, ratio_avg: {ratio_avg}")
            ratio_N += 1

            sample_metric = {
                "sgd_var": sgd_var,
                "act_var": act_var,
                "weight_var": weight_var,
                "S": S,
                "activation_ratio_first": activation_ratio_schedule[0],
                "activation_ratio_last": activation_ratio_schedule[-1],
                **{f"weight_ratio[{k}]": weight_ratio_dict[k] for k in range(6)},
                "ratio": ratio,
                "ratio_avg": ratio_avg,
            }
            if args.wandb:
                update_summary(completed_steps, sample_metric, {})

            # `start` was not re-recorded above, so this lap covers the
            # training segment plus the variance pass.
            train_time_elapsed += _lap_ms()
            start.record()

        # ---- periodic evaluation ----
        if completed_steps % args.eval_freq == 0:
            # Close the running training segment. It contains only ordinary
            # steps, so it counts towards both training totals.
            segment_ms = _lap_ms()
            train_time_elapsed += segment_ms
            pure_train_time_elapsed += segment_ms

            start.record()
            eval_loss, eval_acc = eval(model, eval_dataloader, device)
            # The evaluation helper lives in utils and is not visible here; in
            # case it leaves the model in eval mode, switch back so the rest
            # of the epoch trains in train mode.
            model.train()

            avg_train_loss = total_loss / args.eval_freq
            train_metric = {"loss": avg_train_loss}
            eval_metric = {"loss": eval_loss, "acc": eval_acc}
            if args.wandb:
                update_summary(completed_steps, train_metric, eval_metric)
            logging.info(f"\nEpoch {epoch} - Step {completed_steps} - Train loss: {avg_train_loss} - Eval loss: {eval_loss} - Eval acc: {eval_acc}\n")
            total_loss = 0

            eval_time_elapsed += _lap_ms()
            start.record()

    # End of epoch: close the running training segment and start a new one, so
    # the next lap does not count this segment a second time.
    segment_ms = _lap_ms()
    train_time_elapsed += segment_ms
    pure_train_time_elapsed += segment_ms
    start.record()

logging.info(f"training finished, train time: {train_time_elapsed/1000/60} min, eval time: {eval_time_elapsed/1000/60} min, pure train time: {pure_train_time_elapsed/1000/60} min")   


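# Optional post-training diagnostics (disabled): confusion matrix and
# classification report on the evaluation set. Predictions are taken with
# argmax over the logits. Uses logging.info directly because no `log` object
# is defined in this file.
# model.eval()
# predictions = []
# labels = []
# for batch in eval_dataloader:
#     batch = {k: v.to(device) for k, v in batch.items()}
#     with torch.no_grad():
#         outputs = model(**batch)
#     logits = outputs.logits
#     predictions.extend(torch.argmax(logits, dim=-1).tolist())
#     labels.extend(batch["labels"].tolist())
# logging.info("confusion matrix:")
# logging.info(confusion_matrix(labels, predictions))
# logging.info("classification report:")
# logging.info(classification_report(labels, predictions, digits=4))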

