from pytorch_lightning import Trainer, seed_everything
from data_utils.dataloader import DSTDataLoader, DSTDataSet
from model.bert import BERTAlignModel
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader

# Fix the random seed before anything else in this script runs.
seed_everything(2022)

# Forwarded to the dataset classes (need_mlm=...) and set on the model in main().
NEED_MLM = False
# Forwarded to DSTDataLoader (is_finetune=...) and set on the model in main().
IS_FINETUNE = False ### calculate f1 at auto validation

# Backbone identifier; also used in the checkpoint directory and file name.
MODEL_NAME = "roberta-large"

# Training batch size (DSTDataLoader train_batch_size and the few-shot DataLoader).
BATCH_SIZE = 16
# Passed to Trainer(accumulate_grad_batches=...).
ACCUMULATE_GRAD_BATCH = 1
# Passed to Trainer(max_epochs=...).
NUM_EPOCH = 3
# Worker count handed to the data loaders.
NUM_WORKERS = 8
# Optimizer / scheduler settings handed to BERTAlignModel in main().
WARMUP_PROPORTION = 0.06
ADAM_EPSILON = 1e-6
WEIGHT_DECAY = 0.1
LR = 1e-5
# NOTE(review): not referenced anywhere in the visible code (the Trainer in
# main() does not receive it) - confirm whether it should be wired in.
VAL_CHECK_INTERVAL = 1. / 4
# GPU indices passed to Trainer(devices=...); len(DEVICES) also appears in the
# checkpoint file name.
DEVICES=[4,5]
# Checkpoint that main() loads before fine-tuning; earlier choices are kept
# commented out for reference.
# CHECKPOINT_PATH = 'checkpoints/bert/bert_small_mlm_mnli_paws_vitaminc_anli_r1_anli_r2_anli_r3_snli_1.0_16x2x8-v1.ckpt' ### labeled dataset 1.0
# CHECKPOINT_PATH = 'checkpoints/bert-base-uncased/bert-base-uncased_no_mlm_xsum_cnndm_mnli_squad_paws_paws_qqp_vitaminc_race_anli_r1_anli_r2_anli_r3_snli_wikihow_msmarco_paws_unlabeled_wiki103_qqp_stsb_sick_ctc_500000_16x4x4_final.ckpt' ### all datasets 1.0
CHECKPOINT_PATH = 'checkpoints/nli-baseline/roberta-large/NLI_roberta-large_no_mlm_mnli_doc_nli_anli_r1_anli_r2_anli_r3_snli_500000_32x4x1_final.ckpt'
'''
Note:
The original BERT and RoBERTa use regression (MSE) to finetune on STS-B. In this alignment model, we still use 2-label classification loss.
'''

def get_few_shot_dataset_loader(dataset_name, dataset_split, few_shot_k, template, aligned_label, text_key='text', label_key='label'):
    """Build a shuffled DataLoader over ``few_shot_k`` sampled rows per label.

    Args:
        dataset_name: Positional arguments unpacked into ``load_dataset``
            (e.g. a one-element tuple holding the dataset id).
        dataset_split: Key of the split to draw from.
        few_shot_k: Number of rows sampled from each label group.
        template: Text placed, as a single-element list, in ``text_b``.
        aligned_label: Label value mapped to ``orig_label == 1``; every
            other label value maps to 0.
        text_key: Column holding the input text.
        label_key: Column holding the label used for grouping and comparison.

    Returns:
        A ``DataLoader`` over a ``DSTDataSet`` built from the sampled rows,
        using the module-level ``BATCH_SIZE`` and ``NUM_WORKERS``.
    """
    split = load_dataset(*dataset_name)[dataset_split]
    frame = split.to_pandas()
    # Draw the same number of rows from every label group.
    few_shot_rows = frame.groupby(label_key).sample(n=few_shot_k)

    # Every sampled row becomes one "paraphrase"-task record pairing the raw
    # text with the template.
    examples = [
        {
            "task": "paraphrase",
            "text_a": row[text_key],
            "text_b": [template],
            "text_c": [],
            "orig_label": int(row[label_key] == aligned_label),
        }
        for _, row in few_shot_rows.iterrows()
    ]

    few_shot_set = DSTDataSet(dataset=examples, model_name=MODEL_NAME, need_mlm=NEED_MLM)
    return DataLoader(few_shot_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

### exp: 10k --> 50k, show quantity effect || 50k has_label_task --> 50k no_label_task, show the label effect || total 4 exps
# DATA_SIZE = 50000
# TRAINING_DATASETS = {
#     'cnndm': {'task_type': 'summarization', 'data_path': 'data/cnndm.json', 'size': DATA_SIZE},
#     'mnli': {'task_type': 'nli', 'data_path': 'data/mnli.json', 'size': DATA_SIZE},
#     'squad': {'task_type': 'qa', 'data_path': 'data/squad.json', 'size': DATA_SIZE},
#     'paws': {'task_type': 'paraphrase', 'data_path': 'data/paws.json', 'size':DATA_SIZE},
#     'vitaminc': {'task_type': 'fact_checking', 'data_path': 'data/vitaminc.json', 'size':DATA_SIZE},
#     'xsum': {'task_type': 'summarization', 'data_path': 'data/xsum.json', 'size':DATA_SIZE}
# }
def main():
    """Fine-tune the checkpointed alignment model on the configured datasets.

    Reads ``TRAINING_DATASETS``, ``VAL_TRAINING_DATASETS`` and ``DATA_SIZE``
    from module globals, so they must be assigned (as the ``__main__`` guard
    does) before this function is called. All other settings come from the
    module-level constants.

    Steps: build and set up the data module, restore the model from
    ``CHECKPOINT_PATH``, then run ``Trainer.fit`` with a ``ModelCheckpoint``
    callback writing under ``checkpoints/nli-finetune/<MODEL_NAME>/``.
    """
    dm = DSTDataLoader(
        dataset_config=TRAINING_DATASETS,
        val_dataset_config=VAL_TRAINING_DATASETS,
        model_name=MODEL_NAME,
        need_mlm=NEED_MLM,
        is_finetune=IS_FINETUNE,
        train_batch_size=BATCH_SIZE,
        eval_batch_size=16,
        num_workers=NUM_WORKERS,
        train_eval_split=0.95,
    )
    dm.setup()

    # dataset = get_few_shot_dataset_loader(('SetFit/SentEval-CR',), 'train', few_shot_k=8, template='It was great.', aligned_label="positive", label_key='label_text')

    # load_from_checkpoint is a classmethod: call it on the class and pass the
    # fine-tuning hyperparameters as override kwargs. Building an instance
    # first and calling it through that instance throws the instance (and the
    # hyperparameters given to it) away, and newer Lightning releases refuse
    # the call outright.
    model = BERTAlignModel.load_from_checkpoint(
        checkpoint_path=CHECKPOINT_PATH,
        strict=False,
        model=MODEL_NAME,
        adam_epsilon=ADAM_EPSILON,
        learning_rate=LR,
        weight_decay=WEIGHT_DECAY,
        warmup_steps_portion=WARMUP_PROPORTION,
    )
    model.need_mlm = NEED_MLM
    model.is_finetune = IS_FINETUNE

    # torch.nn.init.xavier_uniform_(model.model.cls.seq_relationship.weight)
    # nn.init.constant_(model.model.cls.seq_relationship.bias.data, 0.0)

    # Checkpoint name encodes the run:
    # <model>_finetune[_no_mlm]_<datasets>_<size>_<batch>x<devices>x<grad-accum>
    training_dataset_used = '_'.join(TRAINING_DATASETS)
    mlm_tag = '' if NEED_MLM else '_no_mlm'
    checkpoint_callback = ModelCheckpoint(
        dirpath=f'checkpoints/nli-finetune/{MODEL_NAME}/',
        filename=f"{MODEL_NAME}_finetune{mlm_tag}_{training_dataset_used}_{DATA_SIZE}_{BATCH_SIZE}x{len(DEVICES)}x{ACCUMULATE_GRAD_BATCH}", # batch_size, device, acc_batch ## standard 16x2x8
        # monitor='f1',
        # mode='max'
    )

    # NOTE(review): VAL_CHECK_INTERVAL is defined at module level but is not
    # passed to the Trainer - confirm whether val_check_interval should be set.
    trainer = Trainer(
        max_epochs=NUM_EPOCH,
        accelerator='gpu',
        devices=DEVICES,
        strategy="dp",
        precision=32,
        callbacks=[checkpoint_callback],
        accumulate_grad_batches=ACCUMULATE_GRAD_BATCH,
    )

    trainer.fit(model, datamodule=dm)
    # trainer.fit(model, dataset)

if __name__ == "__main__":
    # Experiment configuration. main() reads DATA_SIZE, TRAINING_DATASETS and
    # VAL_TRAINING_DATASETS as module globals, so all three must be assigned
    # before main() is called. Datasets are switched on and off by
    # (un)commenting their entries.

    # Written into every dataset entry's 'size' field and into the checkpoint
    # file name. Presumably a fraction of each dataset (1.0 = all of it) -
    # confirm against DSTDataLoader.
    DATA_SIZE = 1.0
    # NOTE: this first TRAINING_DATASETS assignment is overwritten by the
    # second one further down and has no effect; with every entry commented
    # out it evaluates to an empty dict.
    TRAINING_DATASETS = {
        # 'stsb': {'task_type': 'sts', 'data_path': 'data/stsb.json', 'size': DATA_SIZE},
        # 'sick': {'task_type': 'sts', 'data_path': 'data/sick.json', 'size': DATA_SIZE},
        # 'mrpc': {'task_type': 'paraphrase', 'data_path': 'data/mrpc.json', 'size': DATA_SIZE},
        # 'paws_qqp': {'task_type': 'paraphrase', 'data_path': 'data/paws_qqp.json', 'size': DATA_SIZE},
        # 'paws': {'task_type': 'paraphrase', 'data_path': 'data/paws.json', 'size': DATA_SIZE},
        # 'paws_unlabeled': {'task_type': 'paraphrase', 'data_path': 'data/paws_unlabeled.json', 'size': DATA_SIZE},
        # 'qqp': {'task_type': 'paraphrase', 'data_path': 'data/qqp.json', 'size': DATA_SIZE},
        # 'race': {'task_type': 'multiple_choice_qa', 'data_path': 'data/race.json', 'size': DATA_SIZE},

        # 'wmt18': {'task_type': 'wmt', 'data_path': 'data/wmt18.json', 'size': DATA_SIZE},
        # 'wmt17': {'task_type': 'wmt', 'data_path': 'data/wmt17.json', 'size': DATA_SIZE},
        # 'wmt16': {'task_type': 'wmt', 'data_path': 'data/wmt16.json', 'size': DATA_SIZE},
        # 'wmt15': {'task_type': 'wmt', 'data_path': 'data/wmt15.json', 'size': DATA_SIZE},

        # 'mnli': {'task_type': 'nli', 'data_path': 'data/mnli.json', 'size': DATA_SIZE},
        # 'snli': {'task_type': 'nli', 'data_path': 'data/snli.json', 'size': DATA_SIZE},
        # 'ctc': {'task_type': 'ctc', 'data_path': 'data/ctc.json', 'size': DATA_SIZE},

        # 'ctc_dialog': {'task_type': 'ctc', 'data_path': 'data/ctc_dialog.json', 'size': DATA_SIZE},
        # 'correct_mnli': {'task_type': 'nli', 'data_path': 'data/correct_wrong_examples/correct_mnli.json', 'size': DATA_SIZE},
        # 'wrong_mnli': {'task_type': 'nli', 'data_path': 'data/correct_wrong_examples/wrong_mnli.json', 'size': DATA_SIZE},

        # 'vitaminc': {'task_type': 'fact_checking', 'data_path': 'data/vitaminc.json', 'size':DATA_SIZE},
        # 'anli_r1': {'task_type': 'nli', 'data_path': 'data/anli_r1.json', 'size': DATA_SIZE},
        # 'anli_r2': {'task_type': 'nli', 'data_path': 'data/anli_r2.json', 'size': DATA_SIZE},
        # 'anli_r3': {'task_type': 'nli', 'data_path': 'data/anli_r3.json', 'size': DATA_SIZE},

        # 'cr': None
    }
    # Passed to DSTDataLoader as val_dataset_config; currently an empty dict
    # because every entry is commented out.
    VAL_TRAINING_DATASETS = {
        # 'mrpc_val': {'task_type': 'paraphrase', 'data_path': 'data/mrpc_val.json', 'size': 1.0},
        # 'paws_val': {'task_type': 'paraphrase', 'data_path': 'data/paws_val.json', 'size': 1.0},
        # 'qqp_val': {'task_type': 'paraphrase', 'data_path': 'data/qqp_val.json', 'size': 1.0},
        # 'race_val': {'task_type': 'multiple_choice_qa', 'data_path': 'data/race_val.json', 'size': DATA_SIZE},
    }
    # Effective training configuration (replaces the assignment above): only
    # the three ANLI rounds are enabled. The keys, in this order, are joined
    # with '_' to form part of the checkpoint file name in main().
    TRAINING_DATASETS = {
        # 'xsum': {'task_type': 'summarization', 'data_path': 'data/xsum.json', 'size':DATA_SIZE},
        # 'cnndm': {'task_type': 'summarization', 'data_path': 'data/cnndm.json', 'size': DATA_SIZE},

        # 'mnli': {'task_type': 'nli', 'data_path': 'data/mnli.json', 'size': DATA_SIZE},
        # 'nli_fever': {'task_type': 'fact_checking', 'data_path': 'data/nli_fever.json', 'size': DATA_SIZE},
        # 'doc_nli': {'task_type': 'qa', 'data_path': 'data/doc_nli.json', 'size': DATA_SIZE},

        # 'squad': {'task_type': 'qa', 'data_path': 'data/squad.json', 'size': DATA_SIZE},

        # 'squad_v2': {'task_type': 'qa', 'data_path': 'data/squad_v2.json', 'size': DATA_SIZE},
        # 'paws': {'task_type': 'paraphrase', 'data_path': 'data/paws.json', 'size':DATA_SIZE},
        # 'paws_qqp': {'task_type': 'paraphrase', 'data_path': 'data/paws_qqp.json', 'size':DATA_SIZE},
        # 'vitaminc': {'task_type': 'fact_checking', 'data_path': 'data/vitaminc.json', 'size':DATA_SIZE},
        # 'race': {'task_type': 'qa', 'data_path': 'data/race.json', 'size': DATA_SIZE},
        'anli_r1': {'task_type': 'nli', 'data_path': 'data/anli_r1.json', 'size': DATA_SIZE},
        'anli_r2': {'task_type': 'nli', 'data_path': 'data/anli_r2.json', 'size': DATA_SIZE},
        'anli_r3': {'task_type': 'nli', 'data_path': 'data/anli_r3.json', 'size': DATA_SIZE},
        # 'snli': {'task_type': 'nli', 'data_path': 'data/snli.json', 'size': DATA_SIZE},
        # 'wikihow': {'task_type': 'summarization', 'data_path': 'data/wikihow.json', 'size': DATA_SIZE},
        # 'msmarco': {'task_type': 'ir', 'data_path': 'data/msmarco.json', 'size': DATA_SIZE},
        # 'paws_unlabeled': {'task_type': 'paraphrase', 'data_path': 'data/paws_unlabeled.json', 'size': DATA_SIZE},
        # 'wiki103': {'task_type': 'paraphrase', 'data_path': 'data/wiki103.json', 'size': DATA_SIZE},
        # 'qqp': {'task_type': 'paraphrase', 'data_path': 'data/qqp.json', 'size': DATA_SIZE},
        # 'stsb': {'task_type': 'sts', 'data_path': 'data/stsb.json', 'size': DATA_SIZE},
        # 'sick': {'task_type': 'sts', 'data_path': 'data/sick.json', 'size': DATA_SIZE},

        # 'ctc': {'task_type': 'ctc', 'data_path': 'data/ctc.json', 'size': DATA_SIZE},
        # 'mrpc': {'task_type': 'paraphrase', 'data_path': 'data/mrpc.json', 'size':DATA_SIZE},
    }
    # Disabled alternative: run main() once per dataset instead of once on
    # the combined configuration.
    # for one_dataset in ALL_TRAINING_DATASETS:
    #     TRAINING_DATASETS = {one_dataset: ALL_TRAINING_DATASETS[one_dataset]}
    #     main()

    main()
