from datasets import load_dataset
import transformers
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, get_scheduler, BertConfig, default_data_collator
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm.auto import tqdm
import evaluate
import wandb
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import random
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import logging
import itertools

from transformers.models.bert import linear_bert_ub_test_2

from utils import *

# GLUE task name -> (first text column, second text column).
# Tasks with a single input sentence carry None in the second slot.
task_to_keys = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)

# Command-line interface.
#
# --task_name, --seed and --eval_freq are required: each one is used
# unconditionally later in the script (task column lookup, torch.manual_seed,
# and the `completed_steps % args.eval_freq` check) and has no usable default,
# so it is better to be rejected by argparse than to crash mid-run.
parser = argparse.ArgumentParser()
# GLUE task to fine-tune on; must be one of the keys of task_to_keys.
parser.add_argument("--task_name", type=str, required=True, choices=list(task_to_keys.keys()))
# Batch size shared by the training and evaluation dataloaders.
parser.add_argument("--batch_size", type=int, default=32)
# Passed to the tokenizer as max_length (with truncation enabled).
parser.add_argument("--max_length", type=int, default=128)
# Used for the log file name (logs/<log_name>.log) and the wandb run name.
parser.add_argument("--log_name", type=str)
# Seed handed to torch.manual_seed / torch.cuda.manual_seed.
parser.add_argument("--seed", type=int, required=True)
# Number of optimizer steps between two evaluation / variance-probe rounds.
parser.add_argument("--eval_freq", type=int, required=True)
# Forwarded to linear_bert_ub_test_2.sample_ratio.
parser.add_argument("--sample_ratio", type=float, default=0.5)
# Enables wandb.init and per-evaluation metric reporting.
parser.add_argument("--wandb", action="store_true")
# Not read directly in this script; it reaches wandb only through config=args.
parser.add_argument("--wandb_group", type=str)
args = parser.parse_args()

LOG_FORMAT = "[%(asctime)s] %(message)s"
LOG_FILE = f"logs/{args.log_name}.log"
# basicConfig opens the log file right away and does not create missing
# directories, so a fresh checkout without "logs/" would fail here.
# Creating the parent also covers a log_name that contains sub-directories.
Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, filename=LOG_FILE)
log = logging.getLogger(__name__)

# Fixed training hyper-parameters.
NUM_EPOCHS = 3
LR = 2e-5

# Weights & Biases destination (placeholders).
PROJECT = "your_project"
ENTITY = "your_entity"

# Publish the requested ratio as a module-level attribute of the custom
# BERT implementation.
linear_bert_ub_test_2.sample_ratio = args.sample_ratio

# Start a wandb run only when --wandb was given; the parsed arguments are
# stored as the run config.
if args.wandb:
    wandb.init(
        name=args.log_name,
        project=PROJECT,
        config=args,
        entity=ENTITY,
    )

# Seed PyTorch's CPU and current-device CUDA generators.
SEED = args.seed
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

logging.info("loading dataset")
raw_datasets = load_dataset("glue", args.task_name)

# "stsb" is handled as regression with a single output; every other task
# takes its classes from the training split's "label" feature.
is_regression = args.task_name == "stsb"
if is_regression:
    num_labels = 1
else:
    label_list = raw_datasets["train"].features["label"].names
    num_labels = len(label_list)

checkpoint = "bert-base-uncased"

logging.info("loading model")
# NOTE: `config` is created on its own and is not handed to the model below;
# the model builds its configuration from the same checkpoint and num_labels.
config = BertConfig.from_pretrained(checkpoint, num_labels=num_labels)
tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
model = BertForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=num_labels,
)

logging.info("preprocessing dataset")
# Column name(s) holding the input text for the selected task; the lookup
# raises KeyError for an unknown (or missing) task name.
sentence1_key, sentence2_key = task_to_keys[args.task_name]

# For classification tasks, record the dataset's label names on the model's
# own config. id2label is derived from the mapping that was just stored on
# model.config so that both directions always agree; the standalone `config`
# object is a different instance and does not carry these names.
if not is_regression:
    model.config.label2id = {label: idx for idx, label in enumerate(label_list)}
    model.config.id2label = {idx: label for label, idx in model.config.label2id.items()}


def preprocess_function(examples):
    """Tokenize a batch of examples for the selected task.

    Reads the column named by ``sentence1_key`` (and ``sentence2_key`` when it
    is not None) from ``examples`` and passes them to the module-level
    ``tokenizer`` with truncation to ``args.max_length``.

    Returns the tokenizer output; when ``examples`` contains a "label" entry
    it is copied into the result under "labels".
    """
    columns = [examples[sentence1_key]]
    if sentence2_key is not None:
        columns.append(examples[sentence2_key])

    encoded = tokenizer(*columns, max_length=args.max_length, truncation=True)

    if "label" in examples:
        encoded["labels"] = examples["label"]
    return encoded

# Tokenize every split in batches and drop the original columns so only the
# tokenizer output (plus "labels") remains.
_raw_columns = raw_datasets["train"].column_names
processed_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=_raw_columns,
    desc="Running tokenizer on dataset",
)

train_dataset = processed_datasets["train"]
# For "mnli" the split named "validation_matched" is used for evaluation;
# all other tasks use "validation".
if args.task_name == "mnli":
    eval_dataset = processed_datasets["validation_matched"]
else:
    eval_dataset = processed_datasets["validation"]

# Log a few random samples from the training set:
# for index in random.sample(range(len(train_dataset)), 3):
#     logging.info(f"Sample {index} of the training set: {train_dataset[index]}.")

# Batches are padded by the collator at collation time.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=args.batch_size,
    collate_fn=data_collator,
)

WEIGHT_DECAY = 0.01
# Parameters whose name contains one of these substrings get no weight decay.
no_decay = ["bias", "LayerNorm.weight"]


def _partition_parameters(named_parameters):
    """Split (name, parameter) pairs into (decayed, exempt) parameter lists.

    A parameter is exempt when its name contains any entry of ``no_decay``.
    The relative order of parameters is preserved in both lists.
    """
    decayed, exempt = [], []
    for param_name, param in named_parameters:
        if any(marker in param_name for marker in no_decay):
            exempt.append(param)
        else:
            decayed.append(param)
    return decayed, exempt


_decayed_params, _exempt_params = _partition_parameters(model.named_parameters())
optimizer_grouped_parameters = [
    {"params": _decayed_params, "weight_decay": WEIGHT_DECAY},
    {"params": _exempt_params, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=LR)

# One scheduler step is taken per training batch, for NUM_EPOCHS passes.
num_per_epoch = len(train_dataloader)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
# Linear schedule; the warm-up length is 10% of all steps (passed as a float).
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=0.1 * num_training_steps,
)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU,
# and move the model's parameters there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def setup_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Applied in the order listed above.
    for apply_seed in seeders:
        apply_seed(seed)

# Progress bar with one tick per optimizer step over the whole run.
progress_bar = tqdm(range(num_training_steps))
completed_steps = starting_epoch = 0

# Pair of CUDA events used as a stopwatch around training / evaluation
# segments (elapsed_time between them is accumulated below).
start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))

# Accumulated timings; the final log line converts them to minutes by
# dividing by 1000 and then 60.
pure_train_time_elapsed = train_time_elapsed = eval_time_elapsed = 0

ratio_N = ratio_avg = 0

# Running sum of training losses since the last evaluation.
total_loss = 0

def _capture_rng_states():
    """Snapshot every random generator that setup_seed() reseeds."""
    return {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch_cpu": torch.get_rng_state().clone(),
        "torch_cuda": torch.cuda.get_rng_state_all(),
    }


def _restore_rng_states(states):
    """Put back the generator states captured by _capture_rng_states()."""
    random.setstate(states["python"])
    np.random.set_state(states["numpy"])
    torch.set_rng_state(states["torch_cpu"])
    torch.cuda.set_rng_state_all(states["torch_cuda"])


# Main loop: NUM_EPOCHS passes over train_dataloader. Every args.eval_freq
# optimizer steps the script (1) runs a variance probe through
# linear_bert_ub_test_2 on 10 training batches, (2) evaluates on
# eval_dataloader and (3) logs / reports the metrics.
#
# Timing: `start`/`end` bracket alternating training and evaluation segments.
# NOTE: pure_train_time_elapsed is not updated anywhere in this loop, so the
# final log line reports it with its initial value.
logging.info(f"using device:{device}, start training")
start.record()
for epoch in range(NUM_EPOCHS):
    model.train()
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss.mean()
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.update(1)
        completed_steps += 1

        if completed_steps % args.eval_freq != 0:
            continue

        # --- Variance probe -------------------------------------------------
        # Each probe batch is run forward/backward twice with the same seed:
        # first with sampling disabled, then enabled. Gradients are discarded
        # (zero_grad without optimizer.step), so weights are not updated.
        linear_bert_ub_test_2.test = True
        linear_bert_ub_test_2.prepare(device)
        # setup_seed() reseeds Python, NumPy, torch CPU and CUDA generators;
        # snapshot all of them so the probe does not leak into training
        # randomness (e.g. dropout masks drawn on the GPU).
        rng_snapshot = _capture_rng_states()
        for probe_idx, probe_batch in enumerate(itertools.islice(train_dataloader, 10)):
            probe_batch = {k: v.to(device) for k, v in probe_batch.items()}
            for use_sampling in (False, True):
                linear_bert_ub_test_2.sample = use_sampling
                setup_seed(probe_idx)
                probe_loss = model(**probe_batch).loss.mean()
                probe_loss.backward()
                optimizer.zero_grad()
        linear_bert_ub_test_2.test = False

        sgd_var, act_var = linear_bert_ub_test_2.cal_var()
        linear_bert_ub_test_2.reset_dict()

        _restore_rng_states(rng_snapshot)

        # Close the current training segment (it includes the probe above).
        end.record()
        torch.cuda.synchronize()
        train_time_elapsed += start.elapsed_time(end)

        # --- Evaluation -----------------------------------------------------
        start.record()
        # `eval` is expected to come from `from utils import *` (shadowing the
        # builtin) and to return (loss, accuracy) -- not visible in this file.
        eval_loss, eval_acc = eval(model, eval_dataloader, device)
        # The helper may have switched the model to evaluation mode; make sure
        # the remaining steps of this epoch train with training-mode behaviour.
        # This is a no-op if the mode was left untouched.
        model.train()
        train_metric = {"loss": total_loss / args.eval_freq, "sgd_var": sgd_var, "act_var": act_var}
        eval_metric = {"loss": eval_loss, "acc": eval_acc}
        if args.wandb:
            update_summary(completed_steps, train_metric, eval_metric)
        logging.info(f"\nEpoch {epoch} - Step {completed_steps} - Train loss: {total_loss / args.eval_freq} - Eval loss: {eval_loss} - Eval acc: {eval_acc}\n")
        total_loss = 0

        end.record()
        torch.cuda.synchronize()
        eval_time_elapsed += start.elapsed_time(end)
        start.record()

    # Close the training segment that runs up to the end of the epoch ...
    end.record()
    torch.cuda.synchronize()
    train_time_elapsed += start.elapsed_time(end)
    # ... and open a new one, otherwise the next measurement would start from
    # the old `start` and count this segment a second time.
    start.record()

logging.info(f"training finished, train time: {train_time_elapsed/1000/60} min, eval time: {eval_time_elapsed/1000/60} min, pure train time: {pure_train_time_elapsed/1000/60} min")   


# model.eval()
# predictions = []
# labels = []
# for batch in eval_dataloader:
#     batch = {k: v.to(device) for k, v in batch.items()}
#     with torch.no_grad():
#         outputs = model(**batch)
#     logits = outputs.logits
#     predictions.extend(torch.argmax(logits, dim=-1).tolist())
#     labels.extend(batch["labels"].tolist())
# log.info("confusion matrix:")
# log.info(confusion_matrix(labels, predictions))
# log.info("classification report:")
# log.info(classification_report(labels, predictions, digits=4))

