from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.strategies.fsdp import FSDPStrategy
from lightning.fabric import Fabric
from training_modules import LitTrainingDataModule, LitT5ForConditionalGeneration, LitLlamaForCausalLM
from alignment_modules import LitAlignmentDataModule, AlignmentT5ForConditionalGeneration, AlignmentLlamaForCausalLM
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, LlamaTokenizer
from pytorch_lightning.loggers import TensorBoardLogger
from datasets import load_from_disk
from utils import save_model_checkpoint
from peft import get_peft_model, PrefixTuningConfig, TaskType
import pandas as pd
import numpy as np
import os
import random
import torch
import argparse
import time
import time


# Benchmark group -> names of the datasets that belong to it. ``test`` uses
# this mapping to average per-dataset accuracies into one figure per group
# (and per setting) when it prints its summary.
categories = {
    "MMLU": [
        "MMLU_General",
    ],
    "BBL-BC": [
        "Play_Dialog",
        "StrategyQA",
        "Strange_Stories",
        "Winowhy",
    ],
    "BBL-MC": [
        "Vitaminc_Fact_Verification",
        "Language_Identification",
    ],
    "BBL-QA": [
        "BBQ_Lite",
        "Code_Line_Description",
        "Logical_Deduction",
        "Known_Unknowns",
        "Hindu_Knowledge",
        "Novel_Concepts",
        "Logic_Grid_Puzzle",
        "Conceptual_Combinations",
    ],
}


def run(args):
    """Build tokenizer, model, data module and trainer from ``args`` and run the job.

    The job is selected by ``args.job_type``:

    * ``"test"``       -- evaluate only and print the elapsed wall-clock time.
    * ``"train"``      -- fit the model and save a checkpoint.
    * ``"train+test"`` -- fit, save, then evaluate the trained model.

    Args:
        args: Parsed command-line namespace. This function reads ``seed``,
            ``model_name_or_path``, ``model_name``, ``output_dir``,
            ``learning_rate``, ``weight_decay``, ``soft_prompt``,
            ``alignment``, ``alignment_weight``, ``negation_alignment``,
            ``min_epoch``, ``max_epoch``, ``precision``, ``devices``,
            ``job_type``, ``accumulate_grad_batches``, ``train_set_dir``,
            ``train_batch_size``, ``test_set_dir``, ``test_batch_size``,
            ``num_virtual_tokens`` and ``vanilla_result_path``.
    """
    seed_everything(args.seed)

    # Paths mentioning "alpaca" or "llama" take the causal-LM (Llama) code
    # path; every other model is treated as a seq2seq model.
    is_llama = "alpaca" in args.model_name_or_path or "llama" in args.model_name_or_path
    if is_llama:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    # Never reuse an existing run directory: append the first free "_<n>"
    # suffix (starting at 1) when <output_dir>/<model_name> already exists.
    model_name = args.model_name
    if os.path.exists(os.path.join(args.output_dir, model_name)):
        idx = 1
        while os.path.exists(os.path.join(args.output_dir, model_name + "_" + str(idx))):
            idx += 1
        model_name = model_name + "_" + str(idx)
    output_dir = os.path.join(args.output_dir, model_name)
    logger = TensorBoardLogger(save_dir=output_dir, name="log")

    hparams = {
        "learning_rate": args.learning_rate,
        "weight_decay": args.weight_decay,
        "soft_prompt": args.soft_prompt,
        "alignment_weight": args.alignment_weight,
        "min_epoch": args.min_epoch,
        "max_epoch": args.max_epoch,
        "negation_alignment": args.negation_alignment,
    }

    if args.precision == "bf16":
        torch.set_float32_matmul_precision("high")
        torch_dtype = torch.bfloat16
        trainer_precision = "bf16-mixed"
    elif args.precision == "fp16":
        # NOTE(review): weights are loaded in float16 while the Trainer stays
        # at precision 32 -- confirm this is intended rather than "16-mixed".
        torch_dtype = torch.float16
        trainer_precision = 32
    else:
        torch_dtype = torch.float32
        trainer_precision = 32

    # Pick the Lightning data/model wrappers for this architecture, switching
    # to the alignment variants when --alignment is set.
    if is_llama:
        if args.alignment:
            DataModule, ModelModule = LitAlignmentDataModule, AlignmentLlamaForCausalLM
        else:
            DataModule, ModelModule = LitTrainingDataModule, LitLlamaForCausalLM
    else:
        if args.alignment:
            DataModule, ModelModule = LitAlignmentDataModule, AlignmentT5ForConditionalGeneration
        else:
            DataModule, ModelModule = LitTrainingDataModule, LitT5ForConditionalGeneration

    multi_device = len(args.devices) > 1
    strategy = "ddp" if multi_device else "auto"
    # Multi-device training of a Llama-family model uses FSDP instead of DDP.
    if multi_device and is_llama and "train" in args.job_type:
        print("FSDP!")
        strategy = FSDPStrategy(cpu_offload=False)

    trainer = Trainer(
        accelerator="gpu",
        devices=args.devices,
        min_epochs=args.min_epoch,
        max_epochs=args.max_epoch,
        logger=logger,
        accumulate_grad_batches=args.accumulate_grad_batches,
        strategy=strategy,
        precision=trainer_precision,
    )

    logger.log_hyperparams(hparams)
    train_set = load_from_disk(args.train_set_dir)
    print("Training set size:", len(train_set))

    data_module = DataModule(
        args.train_batch_size,
        args.test_batch_size,
        train_set,
        args.test_set_dir,
        multi_device,
        tokenizer,
        is_llama,
    )
    data_module.prepare_data()

    # Load the base model; the two architectures differ only in the auto
    # class and the PEFT task type, so the prefix-tuning setup is shared.
    if is_llama:
        auto_model_cls, task_type = AutoModelForCausalLM, TaskType.CAUSAL_LM
    else:
        auto_model_cls, task_type = AutoModelForSeq2SeqLM, TaskType.SEQ_2_SEQ_LM
    model = auto_model_cls.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype)
    if args.soft_prompt:
        peft_config = PrefixTuningConfig(
            task_type=task_type,
            inference_mode=False,
            num_virtual_tokens=args.num_virtual_tokens,
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    model = ModelModule(model, **hparams)
    fabric = Fabric(devices=args.devices)
    fabric.launch()

    if args.job_type == "test":
        started = time.time()
        test(trainer, model, data_module, output_dir, logger, fabric, args.vanilla_result_path)
        elapsed = time.time() - started
        if fabric.local_rank in [-1, 0]:
            print("Total time: {:.2f} minutes".format(elapsed / 60))
    else:
        model = train(trainer, model, tokenizer, data_module, output_dir, strategy, logger, fabric)
        if args.job_type == "train+test":
            # Argument list must match test()'s signature exactly; passing
            # output_dir/fabric twice raised TypeError after training ended.
            test(trainer, model, data_module, output_dir, logger, fabric, args.vanilla_result_path)


def test(trainer, model, data_module, output_dir, logger, fabric: Fabric, vanilla_result_path=None):
    """Evaluate ``model`` on every test dataloader and report the accuracies.

    For each entry returned by ``data_module.test_dataloaders()`` the model is
    run through ``trainer.test`` and its ``"test_acc"`` metric is recorded.
    Local rank -1/0 writes the table to ``<output_dir>/result.csv`` and prints
    a summary. When ``vanilla_result_path`` is given, the CSV at that path is
    printed first and the final summary shows accuracy deltas against it.

    Args:
        trainer: Object whose ``test(model=..., dataloaders=..., verbose=False)``
            returns a list whose first item holds ``"test_acc"``.
        model: Model handed to ``trainer.test``.
        data_module: Provides ``test_dataloaders()``, an iterable of dicts
            with ``"name"``, ``"setting"`` and ``"dataset"`` keys.
        output_dir: Directory that receives ``result.csv``.
        logger: Accepted for signature compatibility; not used here.
        fabric: Only ``local_rank`` is read, to restrict printing and writing.
        vanilla_result_path: Optional CSV with ``Dataset``, ``Setting`` and
            ``Accuracy`` columns holding the pre-training results.
    """
    result_columns = ["Dataset", "Setting", "Accuracy"]

    def on_main_rank():
        """Tell whether this process is the one allowed to print/write."""
        return fabric.local_rank in [-1, 0]

    def print_result(df: pd.DataFrame, msg: str):
        """Print ``msg``, one line per row of ``df``, then per-category means."""
        print(msg)
        grouped = {}
        for _, record in df.iterrows():
            score_text = ": {:.2f}".format(record["Accuracy"] * 100)
            print(record["Dataset"] + " " + record["Setting"] + score_text)
            for group, members in categories.items():
                if record["Dataset"] not in members:
                    continue
                label = "{} {}".format(group, record["Setting"])
                grouped.setdefault(label, []).append(record["Accuracy"])

        print("\n\nResult by Categories:\n")
        for label, scores in grouped.items():
            print("{}: {:.2f}".format(label, np.mean(scores) * 100))

    def combine_result(df_before: pd.DataFrame, df_after: pd.DataFrame):
        """Return accuracy deltas (after - before) per (Dataset, Setting) row of ``df_after``."""
        deltas = pd.DataFrame(columns=result_columns)
        for _, after in df_after.iterrows():
            name, mode = after["Dataset"], after["Setting"]
            same_run = (df_before["Dataset"] == name) & (df_before["Setting"] == mode)
            # The first matching baseline row is used; no match raises IndexError.
            baseline = df_before.loc[same_run, "Accuracy"].values[0]
            deltas.loc[len(deltas.index)] = [name, mode, after["Accuracy"] - baseline]
        return deltas

    has_baseline = vanilla_result_path is not None
    if has_baseline:
        vanilla_pd = pd.read_csv(vanilla_result_path, index_col=None)
        if on_main_rank():
            print_result(vanilla_pd, "Before Training...\n\n\n")

    result_pd = pd.DataFrame(columns=result_columns)
    test_dicts = data_module.test_dataloaders()

    if on_main_rank():
        print("Testing...")
    for entry in test_dicts:
        name, mode = entry["name"], entry["setting"]
        if on_main_rank():
            print("Testing: {} {}".format(name, mode))

        metrics = trainer.test(model=model, dataloaders=entry["dataset"], verbose=False)
        accuracy = metrics[0]["test_acc"]
        if on_main_rank():
            print("Accuracy: {:.2f}".format(accuracy * 100))
        result_pd.loc[len(result_pd.index)] = [name, mode, accuracy]

    # The raw accuracies are saved before any delta computation.
    if on_main_rank():
        result_pd.to_csv(os.path.join(output_dir, "result.csv"), index=False)

    if has_baseline:
        report = combine_result(vanilla_pd, result_pd)
        banner = "Delta Change After Training...\n\n\n"
    else:
        report = result_pd
        banner = "After Training...\n\n\n"
    if on_main_rank():
        print_result(report, banner)


def train(trainer, model, tokenizer, data_module, output_dir, strategy, logger: TensorBoardLogger, fabric: Fabric):
    """Fit ``model`` on the training dataloader, save a checkpoint and return it.

    ``fabric.barrier()`` is called before fitting, after fitting and after the
    checkpoint has been written.

    Args:
        trainer: Object whose ``fit`` method runs the training loop.
        model: Module to train; the same object is returned.
        tokenizer: Forwarded to ``save_model_checkpoint``.
        data_module: Provides ``train_dataloader()``.
        output_dir: Forwarded to ``save_model_checkpoint``.
        strategy: Forwarded to ``save_model_checkpoint``.
        logger: Accepted for signature compatibility; not used here.
        fabric: Supplies the barriers and is forwarded to ``save_model_checkpoint``.

    Returns:
        The ``model`` that was passed in.
    """
    fabric.barrier()
    train_loader = data_module.train_dataloader()
    trainer.fit(model, train_dataloaders=train_loader)

    fabric.barrier()
    save_model_checkpoint(fabric, strategy, model, tokenizer, output_dir)

    fabric.barrier()
    return model


def main():
    """Parse and validate the command-line options, then call ``run``.

    Invalid option combinations are reported through ``parser.error`` (usage
    message on stderr, exit status 2) before any directory is created:

    * ``--alignment`` requires ``--train_set_dir`` to contain "alignment".
    * ``--negation_alignment`` (on by default) requires ``--alignment``.
    """

    def str_to_bool(value):
        """Convert a command-line token such as "true"/"0"/"no" to a bool."""
        token = value.strip().lower()
        if token in ("true", "t", "yes", "y", "1"):
            return True
        # The empty string used to be the only falsy value; keep accepting it.
        if token in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', default="google/flan-t5-xl")
    parser.add_argument('--model_name', default="flan-t5-xl-prefix-negation")
    parser.add_argument('--output_dir', default="./models/negation/")
    parser.add_argument('--seed', default=42, type=int)
    # type=list would split the raw string into characters ("0,1" ->
    # ['0', ',', '1']); take space-separated integers instead: --devices 0 1
    parser.add_argument('--devices', default=[0], type=int, nargs="+",
                        help="device indices, space separated (e.g. --devices 0 1)")
    # "fp32" selects the float32 fallback branch that run() already has.
    parser.add_argument('--precision', default="bf16", choices=["fp16", "bf16", "fp32"], type=str)

    parser.add_argument('--job_type', default="train", choices=["train", "test", "train+test"], type=str)
    parser.add_argument('--vanilla_result_path', default=None, type=str)

    # Architecture
    parser.add_argument('--soft_prompt', default=False, action="store_true")
    parser.add_argument('--alignment', default=False, action="store_true")
    parser.add_argument('--alignment_weight', default=0.3, type=float)
    parser.add_argument('--num_virtual_tokens', default=10, type=int)

    # Training
    parser.add_argument('--train_set_dir', default="./negation_data/flan_train_984_3_alignment")
    parser.add_argument('--train_batch_size', default=4, type=int)
    parser.add_argument('--accumulate_grad_batches', default=4, type=int)

    # Test
    parser.add_argument('--test_set_dir', default="./training_data/alpaca_test")
    parser.add_argument('--test_batch_size', default=16, type=int)

    # Hyperparameters
    parser.add_argument('--min_epoch', default=1, type=int)
    parser.add_argument('--max_epoch', default=5, type=int)

    parser.add_argument('--learning_rate', default=2e-5, type=float)
    parser.add_argument('--weight_decay', default=1e-5, type=float)

    # Same character-splitting problem as --devices; parse a list of ints.
    parser.add_argument('--checkpoint_epochs', default=None, type=int, nargs="+",
                        help="epoch numbers, space separated")

    # Rebuttal Experiment
    # Without a type every supplied value, including "False", was a truthy
    # string. A bare --negation_alignment (no value) means True.
    parser.add_argument('--negation_alignment', default=True, type=str_to_bool,
                        nargs="?", const=True)

    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts vanish under -O.
    if args.alignment and "alignment" not in args.train_set_dir:
        parser.error("--alignment requires a --train_set_dir whose path contains 'alignment'")
    if args.negation_alignment and not args.alignment:
        parser.error("--negation_alignment requires --alignment "
                     "(pass '--negation_alignment false' to disable it)")

    os.makedirs(args.output_dir, exist_ok=True)

    random.seed(args.seed)
    run(args)


if __name__ == "__main__":
    main()

