import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, AutoConfig
from datasets import DatasetDict, Dataset
import datasets
import os
import click
from utils import sp_auac, sp_auroc, top1_acc
import sys
import pickle
import wandb
import colorama
import random

def compute_metrics(eval_pred):
    """Compute the evaluation metrics reported during training.

    Args:
        eval_pred: The evaluation prediction object handed over by the
            trainer; it is forwarded unchanged to each metric helper.

    Returns:
        A dict with the keys ``"sp_auac"``, ``"accuracy"`` and ``"sp_auroc"``
        holding the results of ``sp_auac``, ``top1_acc`` and ``sp_auroc``.
    """
    metrics = {}
    metrics["sp_auac"] = sp_auac(eval_pred)
    metrics["accuracy"] = top1_acc(eval_pred)
    metrics["sp_auroc"] = sp_auroc(eval_pred)
    return metrics

# Temporary column holding each OE example's original row position. It lets the
# per-example weights be realigned after shuffle/select and is removed before
# the dataset is handed to the trainer.
_OE_INDEX_COLUMN = "__oe_row_index__"

# Fine-tuning presets keyed by the exact ``--model_type`` value.
_MODEL_PRESETS = {
    "deberta": {
        "base_model": "microsoft/deberta-v3-base",
        "batch_size": 20,
        "adam_eps": 1e-6,
        "adam_betas": (0.9, 0.999),
        "lr": 2e-05,
        "grad_steps": 2,
        "weight_decay": 0.0,
    },
    "roberta": {
        "base_model": "roberta-base",
        "batch_size": 40,
        "adam_eps": 1e-6,
        "adam_betas": (0.9, 0.98),
        "lr": 2e-05,
        "grad_steps": 1,
        "weight_decay": 0.01,
    },
    "roberta_large": {
        "base_model": "roberta-large",
        "batch_size": 40,
        "adam_eps": 1e-6,
        "adam_betas": (0.9, 0.98),
        "lr": 1e-05,
        "grad_steps": 1,
        "weight_decay": 0.1,
    },
    "bert_large": {
        "base_model": "bert-large-cased",
        "batch_size": 40,
        "adam_eps": 1e-8,
        "adam_betas": (0.9, 0.999),
        "lr": 2e-05,
        "grad_steps": 1,
        "weight_decay": 0.01,
    },
    "bert": {
        "base_model": "bert-base-cased",
        "batch_size": 40,
        "adam_eps": 1e-8,
        "adam_betas": (0.9, 0.999),
        "lr": 2e-05,
        "grad_steps": 1,
        "weight_decay": 0.01,
    },
}
# "roberta_base" is an alias for "roberta".
_MODEL_PRESETS["roberta_base"] = _MODEL_PRESETS["roberta"]


def _resolve_preset(model_type):
    """Return the hyperparameter preset for ``model_type``.

    Exact names are looked up in ``_MODEL_PRESETS``. Any other name that
    contains the substring "bert" falls back to the bert-base-cased preset
    (this includes unknown names such as "roberta_foo").

    Raises:
        click.BadParameter: If ``model_type`` is neither a known preset nor
            contains "bert".
    """
    preset = _MODEL_PRESETS.get(model_type)
    if preset is not None:
        return preset
    # Explicit raise instead of ``assert`` so validation survives ``python -O``.
    if "bert" not in model_type:
        raise click.BadParameter(
            f"Unsupported model type {model_type!r}; expected one of "
            f"{sorted(_MODEL_PRESETS)} or a name containing 'bert'.",
            param_hint="--model_type",
        )
    return _MODEL_PRESETS["bert"]


def _load_oe_dataset(ood_path, limit_train_ood, weighted):
    """Load the (untokenized) outlier-exposure dataset.

    Returns:
        ``(oe_dataset, oe_weights)``. Both are ``None`` when no OE data is
        requested. ``oe_weights`` is a list with one weight per loaded example
        (in file order) when ``weighted`` is set and the data comes from a
        ``.pkl`` file, otherwise ``None``. When weights are returned the
        dataset additionally carries ``_OE_INDEX_COLUMN``.

    Raises:
        ValueError: If ``weighted`` is set but the pickled records have no
            ``"cls_prob"`` entry.
    """
    if ood_path is None or limit_train_ood == 0:
        return None, None

    if not ood_path.endswith(".pkl"):
        oe_dataset = Dataset.load_from_disk(ood_path)
        if "label" in oe_dataset.features:
            oe_dataset = oe_dataset.remove_columns(["label"])
        return oe_dataset, None

    # SECURITY: pickle.load can execute arbitrary code; only point --ood_path
    # at files from a trusted source.
    with open(ood_path, "rb") as f:
        records = pickle.load(f)

    # Inspect the first record instead of ``.features`` so that both a plain
    # list of dicts and a ``datasets.Dataset`` are accepted.
    first = records[0] if len(records) > 0 else {}
    text_key = "decoded_sequence" if "decoded_sequence" in first else "text"
    columns = {"text": [x[text_key] for x in records]}

    oe_weights = None
    if weighted:
        if "cls_prob" not in first:
            raise ValueError("Weights are not specified in OE dataset!")
        oe_weights = [1 - x["cls_prob"] for x in records]  # p_OOD = 1 - p_ID
        columns[_OE_INDEX_COLUMN] = list(range(len(oe_weights)))
    return Dataset.from_dict(columns), oe_weights


@click.command()
@click.argument("data_path")
@click.argument("model_path")
@click.option("--ood_path", type=str, default=None)
@click.option("--limit_train_ood", type=int, default=None)
@click.option("--num_training_steps", type=int, default=10000)
@click.option("--checkpoint_steps", type=int, default=5000)
@click.option("--label_smoothing/--no_label_smoothing", type=bool, default=False)
@click.option("--do_early_stopping/--no_early_stopping", type=bool, default=False)
@click.option("--seed", type=int, default=123)
@click.option("--model_type", type=str, default="bert")
@click.option("--weighted/--unweighted", default=False)
def main(data_path, ood_path, model_path, limit_train_ood, num_training_steps, checkpoint_steps, seed, weighted, do_early_stopping, label_smoothing, model_type):
    # The form feed below ends the part of the docstring that click shows in --help.
    """Fine-tune a sequence classifier, optionally with outlier-exposure data.

    \f
    Args:
        data_path: Directory of a saved ``DatasetDict`` with "train" and
            "validation" splits that contain a "label" column and either a
            "text" column or "sentence1"/"sentence2" columns.
        ood_path: Optional OE data; a ``.pkl`` file of records or a saved
            ``Dataset`` directory.
        model_path: Output directory; training is skipped if it already
            contains ``pytorch_model.bin``.
        limit_train_ood: Keep at most this many OE examples (after a seeded
            shuffle); ``0`` disables OE data, ``None`` keeps everything.
        num_training_steps: Total number of optimisation steps.
        checkpoint_steps: Interval (in steps) for saving and evaluating.
        seed: Seed for the OE shuffle and for ``TrainingArguments``.
        weighted: Weight pickled OE examples by ``1 - cls_prob``.
        do_early_stopping: Passed as ``load_best_model_at_end``.
        label_smoothing: Accepted for compatibility; currently not applied.
        model_type: Preset name, see ``_MODEL_PRESETS``.
    """
    colorama.init()
    # Check for an existing model before touching wandb so that a skipped job
    # does not create an empty run.
    if os.path.exists(os.path.join(model_path, "pytorch_model.bin")):
        print(colorama.Fore.RED + f"Model at {model_path} exists, skipping training!" + colorama.Style.RESET_ALL)
        sys.exit(0)

    preset = _resolve_preset(model_type)
    base_model = preset["base_model"]
    batch_size = preset["batch_size"]
    adam_betas = preset["adam_betas"]

    if label_smoothing:
        # The flag was never wired into the training arguments; say so rather
        # than silently ignoring it.
        print(colorama.Fore.YELLOW + "Warning: --label_smoothing is accepted but not applied by this script." + colorama.Style.RESET_ALL)

    wandb.init(project="selective_prediction", name=model_path, settings=wandb.Settings(start_method="fork"))
    os.makedirs(model_path, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    tokenizer.add_special_tokens({"additional_special_tokens": [f"[unused{k}]" for k in range(1000)]})

    def tokenize_function(batch):
        """Add tokenizer outputs to an example; unknown schemas pass through untouched."""
        if "text" in batch.keys():
            encoded = tokenizer(batch["text"], truncation=True, padding=True, max_length=128)
        elif "sentence1" in batch.keys() and "sentence2" in batch.keys():
            encoded = tokenizer(batch["sentence1"], batch["sentence2"], padding=True, truncation=True, max_length=128)
        else:
            return batch
        for k, v in encoded.items():
            batch[k] = v
        return batch

    dataset = DatasetDict.load_from_disk(data_path)
    dataset = dataset.map(tokenize_function, batched=False)

    oe_dataset, oe_weights = _load_oe_dataset(ood_path, limit_train_ood, weighted)
    if oe_dataset is None:
        print(colorama.Fore.BLUE + f"Starting training with {len(dataset['train'])} ID examples")
    else:
        oe_dataset = oe_dataset.map(tokenize_function, batched=False)
        if limit_train_ood is not None:
            oe_dataset = oe_dataset.shuffle(seed=seed)
            oe_dataset = oe_dataset.select(range(min(limit_train_ood, len(oe_dataset))))
        if oe_weights is not None:
            # Pick the weights of the examples that survived shuffle/select,
            # in their new order, so weight i belongs to example i.
            oe_weights = [oe_weights[i] for i in oe_dataset[_OE_INDEX_COLUMN]]
            oe_dataset = oe_dataset.remove_columns([_OE_INDEX_COLUMN])
        print(colorama.Fore.BLUE + f"Starting training with {len(dataset['train'])} ID examples and {len(oe_dataset)} OOD examples!")
    print(colorama.Style.RESET_ALL)

    # Sort so the label <-> id mapping is identical across runs (set iteration
    # order is not stable for strings under hash randomisation).
    labels = sorted(set(dataset["train"]["label"]))
    label2id = {lb: i for i, lb in enumerate(labels)}
    id2label = {i: lb for lb, i in label2id.items()}

    def map_label2id(ex):
        """Replace the label by its id; labels unseen in training become -1 (outlier)."""
        ex["label"] = label2id.get(ex["label"], -1)
        return ex

    for phase in ["train", "validation"]:
        dataset[phase] = dataset[phase].map(map_label2id)

    config = AutoConfig.from_pretrained(base_model, id2label=id2label, label2id=label2id, num_labels=len(labels))
    model = AutoModelForSequenceClassification.from_pretrained(base_model, config=config)
    # Special tokens were added for every tokenizer, so grow the embedding
    # matrix whenever the tokenizer now has ids beyond it (not only for RoBERTa).
    if len(tokenizer) > model.get_input_embeddings().num_embeddings:
        model.resize_token_embeddings(len(tokenizer))

    training_args = TrainingArguments(
        output_dir=model_path,
        max_steps=num_training_steps,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=preset["grad_steps"],
        warmup_steps=500,
        adam_epsilon=preset["adam_eps"],
        max_grad_norm=float('inf'),  # infinite threshold, i.e. no gradient clipping
        adam_beta1=adam_betas[0],
        adam_beta2=adam_betas[1],
        weight_decay=preset["weight_decay"],
        logging_dir='./logs',
        logging_strategy="steps",
        logging_steps=10,
        learning_rate=preset["lr"],
        save_strategy="steps",
        save_total_limit=1,
        evaluation_strategy="steps",
        save_steps=checkpoint_steps,
        eval_steps=checkpoint_steps,
        report_to="wandb",
        seed=seed,
        load_best_model_at_end=do_early_stopping,
        fp16=True,
        metric_for_best_model="accuracy"
    )

    # NOTE(review): ``oe_dataset`` / ``oe_weights`` are not arguments of the
    # stock ``transformers.Trainer``; presumably a patched transformers build
    # is installed -- confirm.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        oe_dataset=oe_dataset,
        oe_weights=oe_weights,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    model.save_pretrained(model_path)

# Script entry point: invoke the click command only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
