from datasets import load_dataset
import transformers
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, get_scheduler, BertConfig, default_data_collator
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm.auto import tqdm
import evaluate
import wandb
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import random
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import logging
import itertools
import math
import collections

from utils import *

from transformers.models.bert import linear_bert_loss_test

# For every supported GLUE task, the dataset column(s) that hold the input
# text. The second entry is None for single-sentence tasks.
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}


def _positive_int(text):
    """argparse ``type`` callable: parse *text* as an integer >= 1."""
    number = int(text)
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
    return number


parser = argparse.ArgumentParser()
# --task_name, --seed and --eval_freq have no meaningful fallback: the task
# selects the dataset, the seed is handed straight to torch.manual_seed, and
# eval_freq is used as a modulus in the training loop. Requiring them makes a
# missing flag fail immediately with a usage message.
parser.add_argument("--task_name", type=str, choices=list(task_to_keys.keys()), required=True,
                    help="GLUE task to fine-tune on")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--max_length", type=int, default=128,
                    help="maximum tokenized sequence length (longer inputs are truncated)")
parser.add_argument("--log_name", type=str,
                    help="run name; used for the log file name and the wandb run name")
parser.add_argument("--seed", type=int, required=True)
parser.add_argument("--eval_freq", type=_positive_int, required=True,
                    help="evaluate every this many optimizer steps (must be >= 1)")
parser.add_argument("--beta", type=float, default=2,
                    help="exponent applied to the loss percentile when computing selection probability")
parser.add_argument("--wandb", action="store_true")
parser.add_argument("--wandb_group", type=str)
args = parser.parse_args()

LOG_FORMAT = "[%(asctime)s] %(message)s"
LOG_FILE = f"logs/{args.log_name}.log"
# basicConfig opens the log file right away and does not create missing
# directories, so make sure the target directory exists first.
Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, filename=LOG_FILE)
# log = logging.getLogger(__name__)

PROJECT = "your_project"
ENTITY = "your_entity"

if args.wandb:
    # Forward --wandb_group so the flag actually takes effect; when it is not
    # given this passes group=None.
    wandb.init(
        name=args.log_name,
        project=PROJECT,
        config=args,
        entity=ENTITY,
        group=args.wandb_group,
    )

NUM_EPOCHS = 3
LR = 2e-5

SEED = args.seed
# Seed the Python and NumPy generators alongside torch so that anything
# drawing from them is reproducible for a given --seed as well.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

logging.info("loading dataset")
raw_datasets = load_dataset("glue", args.task_name)

# Labels: STS-B is handled as regression (a single output); every other task
# is classification over the label names declared by the train split.
is_regression = args.task_name == "stsb"
if not is_regression:
    label_list = raw_datasets["train"].features["label"].names
    num_labels = len(label_list)
else:
    num_labels = 1

checkpoint = "bert-base-uncased"

logging.info("loading model")
config = BertConfig.from_pretrained(checkpoint, num_labels=num_labels)
tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)

logging.info("preprocessing dataset")
sentence1_key, sentence2_key = task_to_keys[args.task_name]

if not is_regression:
    # Store the real label names on the model's own config. The inverse map is
    # derived from the mapping that was just assigned: `config` above is a
    # separate object from `model.config`, so reading `config.label2id` here
    # would invert a different mapping than the one set on the line above.
    model.config.label2id = {label: idx for idx, label in enumerate(label_list)}
    model.config.id2label = {idx: label for label, idx in model.config.label2id.items()}


def preprocess_function(examples):
    """Tokenize one batch of raw examples and copy its labels across.

    Reads the module-level ``sentence1_key`` / ``sentence2_key`` to pick the
    text column(s), tokenizes them with the module-level ``tokenizer``
    (truncating to ``args.max_length``) and, when the batch has a ``"label"``
    column, stores it under ``"labels"`` in the returned encoding.
    """
    # Single-sentence tasks pass one column, sentence-pair tasks pass two.
    text_columns = [examples[sentence1_key]]
    if sentence2_key is not None:
        text_columns.append(examples[sentence2_key])

    encoded = tokenizer(*text_columns, max_length=args.max_length, truncation=True)

    if "label" in examples:
        encoded["labels"] = examples["label"]
    return encoded

# Tokenize every split in batches and drop the columns of the raw train split,
# leaving only what preprocess_function returns.
processed_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
    desc="Running tokenizer on dataset",
)

train_dataset = processed_datasets["train"]
# MNLI is evaluated on its "validation_matched" split; every other task uses
# the split named "validation".
if args.task_name == "mnli":
    eval_dataset = processed_datasets["validation_matched"]
else:
    eval_dataset = processed_datasets["validation"]

# Log a few random samples from the training set:
# for index in random.sample(range(len(train_dataset)), 3):
#     logging.info(f"Sample {index} of the training set: {train_dataset[index]}.")

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Both loaders share the collator and batch size; only training is shuffled.
loader_kwargs = {"collate_fn": data_collator, "batch_size": args.batch_size}
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
eval_dataloader = DataLoader(eval_dataset, **loader_kwargs)

WEIGHT_DECAY = 0.01
no_decay = ["bias", "LayerNorm.weight"]

# Split the parameters into two groups: those whose name contains one of the
# `no_decay` markers get no weight decay, all others get WEIGHT_DECAY.
decayed_params = []
undecayed_params = []
for param_name, param in model.named_parameters():
    if any(marker in param_name for marker in no_decay):
        undecayed_params.append(param)
    else:
        decayed_params.append(param)

optimizer_grouped_parameters = [
    {"params": decayed_params, "weight_decay": WEIGHT_DECAY},
    {"params": undecayed_params, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=LR)

# Linear schedule over the whole run, with the first 10% of steps as warmup.
num_per_epoch = len(train_dataloader)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0.1 * num_training_steps,
    num_training_steps=num_training_steps,
)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def setup_seed(seed):
    """Seed the Python, NumPy, torch CPU and torch CUDA random generators with *seed*."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

class UnboundedHistogram:
    """Sliding window of recent scalar values with a percentile query.

    Despite the name, the window is bounded: values live in a deque with
    ``maxlen=max_history``, so appending beyond that drops the oldest entries.
    """

    def __init__(self, max_history):
        """Create an empty window that keeps at most *max_history* values."""
        self.max_history = max_history
        self.history = collections.deque(maxlen=self.max_history)

    def append(self, value):
        """Record *value*, evicting the oldest entry when the window is full."""
        self.history.append(value)

    def get_count(self, it, score):
        """Return how many items of *it* are strictly less than *score*."""
        return sum(1 for value in it if value < score)

    def percentile_of_score(self, score):
        """Return the percentage (0-100) of stored values strictly below *score*.

        An empty window yields ``0.0`` (no stored value is below the score)
        instead of dividing by zero.
        """
        if not self.history:
            return 0.0
        num_lower_scores = self.get_count(self.history, score)
        return num_lower_scores * 100. / len(self.history)

class LossSelector:
    """Randomly keeps per-example losses, favouring those that are high
    relative to a sliding window of recently seen losses.

    Each loss is kept with probability
    ``max(sampling_min, (percentile / 100) ** beta)``, where ``percentile`` is
    the share of the stored history that lies strictly below that loss.
    """

    def __init__(self, device, sampling_min, history_length, beta):
        """
        Args:
            device: device on which the probability tensor is created.
            sampling_min: lower bound applied to every selection probability.
            history_length: number of recent losses kept for the percentile.
            beta: exponent applied to the percentile fraction.
        """
        self.device = device
        self.historical_losses = UnboundedHistogram(history_length) #collections.deque(maxlen=history_length)
        self.sampling_min = sampling_min
        self.beta = beta

    @staticmethod
    def _loss_values(losses):
        """Return *losses* as a flat list of Python floats.

        A tensor is converted with a single ``tolist()`` call rather than one
        ``.item()`` call per element; any other iterable is expected to yield
        objects that provide ``.item()``.
        """
        if torch.is_tensor(losses):
            return losses.detach().reshape(-1).tolist()
        return [loss.item() for loss in losses]

    def _extend_history(self, values):
        """Append every float in *values* to the loss history."""
        for value in values:
            self.historical_losses.append(value)

    def update_history(self, losses):
        """Add every loss in *losses* to the history window."""
        self._extend_history(self._loss_values(losses))

    def calculate_probability(self, loss):
        """Return ``(percentile / 100) ** beta`` for the scalar *loss*."""
        percentile = self.historical_losses.percentile_of_score(loss)
        return math.pow(percentile / 100., self.beta)

    def get_probability(self, losses):
        """Record *losses* in the history, then return their selection probabilities.

        The history is updated first, so each loss is ranked against a window
        that already contains the current batch. Returns a float32 tensor on
        ``self.device`` with one probability per loss.
        """
        values = self._loss_values(losses)
        self._extend_history(values)
        probabilities = [
            max(self.sampling_min, self.calculate_probability(value)) for value in values
        ]
        return torch.tensor(probabilities, device=self.device, dtype=torch.float32)

    def get_selected_loss(self, losses):
        """Return the mean of a random subset of *losses*.

        Each loss is kept independently with the probability from
        ``get_probability``. If the draw keeps nothing, a zero that is still
        attached to the autograd graph is returned, so ``backward()`` produces
        zero gradients instead of the NaN that the mean of an empty selection
        would give.
        """
        flat_losses = losses.reshape(-1)
        probs = self.get_probability(flat_losses)
        mask = torch.rand_like(probs) < probs
        if not mask.any():
            return flat_losses.sum() * 0.0
        return flat_losses[mask].mean()

    
# Selector settings: probability floor 1e-9, history of 1024 losses, exponent --beta.
loss_selector = LossSelector(device, 1e-9, 1024, args.beta)

progress_bar = tqdm(range(num_training_steps))
completed_steps = 0
starting_epoch = 0

# NOTE(review): timing relies on CUDA events and torch.cuda.synchronize(), so
# this loop presumably needs a CUDA device even though `device` can fall back
# to the CPU -- confirm before running CPU-only.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

pure_train_time_elapsed = 0  # never updated below; reported as 0 in the final log line
train_time_elapsed = 0
eval_time_elapsed = 0

ratio_N = 0
ratio_avg = 0

total_loss = 0

logging.info(f"using device:{device}, start training")
start.record()
for epoch in range(NUM_EPOCHS):
    model.train()
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        # `outputs.loss` is iterated and masked per element by the selector, so
        # it is expected to hold one loss per example rather than a scalar.
        losses = outputs.loss
        total_loss += losses.mean().detach().float()
        selected_loss = loss_selector.get_selected_loss(losses)
        selected_loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.update(1)
        completed_steps += 1

        if completed_steps % args.eval_freq == 0:
            # ---- gradient-variance probe ------------------------------------
            # Runs 10 training batches through `linear_bert_loss_test`: one
            # full-loss backward pass, then 10 selected-loss backward passes
            # per batch, each after reseeding with the batch index. No
            # optimizer.step() is taken; gradients are cleared after every pass.
            # NOTE(review): get_selected_loss() also appends these probe losses
            # to loss_selector's history -- confirm that this is intended.
            linear_bert_loss_test.test = True
            linear_bert_loss_test.prepare(device)

            # setup_seed() reseeds the Python, NumPy, torch CPU and torch CUDA
            # generators, so snapshot all of them and put every one back
            # afterwards; restoring only the torch CPU state would leave the
            # other generators at the probe's seed for the rest of training.
            py_rng_state = random.getstate()
            np_rng_state = np.random.get_state()
            org_state = torch.get_rng_state().clone()
            cuda_rng_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None

            for i, probe_batch in enumerate(itertools.islice(train_dataloader, 10)):
                probe_batch = {k: v.to(device) for k, v in probe_batch.items()}
                linear_bert_loss_test.sample = False
                setup_seed(i)
                probe_outputs = model(**probe_batch)
                loss = probe_outputs.loss.mean()
                loss.backward()
                optimizer.zero_grad()

                linear_bert_loss_test.sample = True
                for _ in range(10):
                    setup_seed(i)
                    probe_outputs = model(**probe_batch)
                    probe_losses = probe_outputs.loss
                    probe_selected_loss = loss_selector.get_selected_loss(probe_losses)
                    probe_selected_loss.backward()
                    optimizer.zero_grad()
            linear_bert_loss_test.test = False

            sgd_var, act_var = linear_bert_loss_test.cal_var()
            linear_bert_loss_test.reset_dict()

            random.setstate(py_rng_state)
            np.random.set_state(np_rng_state)
            torch.set_rng_state(org_state)
            if cuda_rng_states is not None:
                torch.cuda.set_rng_state_all(cuda_rng_states)

            # Close the current training-time interval (it includes the probe).
            end.record()
            torch.cuda.synchronize()
            train_time_elapsed += start.elapsed_time(end)

            # ---- evaluation -------------------------------------------------
            start.record()
            # `eval` is not defined in this file; presumably it comes from
            # `from utils import *` and shadows the builtin.
            eval_loss, eval_acc = eval(model, eval_dataloader, device)
            # `eval` is opaque here: re-assert training mode in case it
            # switched the model to eval mode without switching back.
            model.train()
            train_metric = {"loss": total_loss / args.eval_freq, "sgd_var": sgd_var, "act_var": act_var}
            eval_metric = {"loss": eval_loss, "acc": eval_acc}
            if args.wandb:
                update_summary(completed_steps, train_metric, eval_metric)
            logging.info(f"\nEpoch {epoch} - Step {completed_steps} - Train loss: {total_loss / args.eval_freq} - Eval loss: {eval_loss} - Eval acc: {eval_acc}\n")
            total_loss = 0

            end.record()
            torch.cuda.synchronize()
            eval_time_elapsed += start.elapsed_time(end)
            start.record()
    end.record()
    torch.cuda.synchronize()
    train_time_elapsed += start.elapsed_time(end)
    # Restart the stopwatch: without this the interval that was just added
    # would be counted a second time at the next measurement.
    start.record()

# elapsed_time() reports milliseconds, hence /1000/60 to get minutes.
logging.info(f"training finished, train time: {train_time_elapsed/1000/60} min, eval time: {eval_time_elapsed/1000/60} min, pure train time: {pure_train_time_elapsed/1000/60} min")   


# model.eval()
# predictions = []
# labels = []
# for batch in eval_dataloader:
#     batch = {k: v.to(device) for k, v in batch.items()}
#     with torch.no_grad():
#         outputs = model(**batch)
#     logits = outputs.logits
#     predictions.extend(torch.argmax(logits, dim=-1).tolist())
#     labels.extend(batch["labels"].tolist())
# log.info("confusion matrix:")
# log.info(confusion_matrix(labels, predictions))
# log.info("classification report:")
# log.info(classification_report(labels, predictions, digits=4))

