import logging
import argparse
import os
import json
from simpletransformers.seq2seq import Seq2SeqModel
from utils import read_data_source_target
def _build_model_args(args):
    """Assemble the model/training argument dict passed to ``Seq2SeqModel``.

    Fixed settings are hard-coded here; the hyperparameters exposed on the
    command line (batch size, gradient accumulation, weight decay, epochs,
    learning rate, sequence length, wandb identifiers) are taken from ``args``.

    Args:
        args: Parsed command-line namespace. Reads ``seq_length``,
            ``batch_size``, ``grad_accum``, ``wd``, ``n_epochs``, ``lr``,
            ``wandb_proj`` and ``wandb_run_id``.

    Returns:
        dict: Keyword settings for the ``args=`` parameter of ``Seq2SeqModel``.
    """
    return {
        "reprocess_input_data": True,
        # Every length limit is tied to the single --seq_length option.
        "max_seq_length": args.seq_length,
        "max_length": args.seq_length,
        "max_gen_length": args.seq_length,
        "block_size": args.seq_length,
        # Hyperparameters controlled from the command line.
        "train_batch_size": args.batch_size,
        # The key must be spelled exactly "gradient_accumulation_steps";
        # a misspelled key is stored as a separate entry and --grad_accum
        # would then have no effect on training.
        "gradient_accumulation_steps": args.grad_accum,
        "weight_decay": args.wd,
        "num_train_epochs": args.n_epochs,
        "learning_rate": args.lr,
        "max_steps": 0,
        "eval_batch_size": 64,
        "save_eval_checkpoints": False,
        "save_steps": 1e9,
        "use_multiprocessing": False,
        "evaluate_before_start": True,
        # NOTE(review): train_model() below is given an explicit output_dir;
        # presumably that takes precedence over this value -- confirm.
        "output_dir": "./",
        "wandb_project": args.wandb_proj,
        "wandb_run_name": args.wandb_run_id,
        "evaluate_during_training": True,
        "predict_during_training": True,
        "manual_seed": 42,
        "fp16": True,
        "truncation": True,
        "dataloader_num_workers": 0,
        "use_multiprocessed_decoding": False,
        "save_best_model": False,
        "save_model_every_epoch": True,
        "overwrite_output_dir": True,
        "save_epoch_interval": 1,
        "warmup_steps": 0,
        "scheduler": "constant_schedule",
        "mlm": False,
    }


def main(args):
    """Fine-tune a GPT-2 ``Seq2SeqModel`` on the CoT fact dataset.

    Loads the model from ``args.model_name_or_path``, reads
    ``cot_fact_train.json``, ``cot_fact_test.json`` and ``valid.json`` from
    ``args.data_dir``, and runs ``train_model`` with checkpoints written to
    ``<model_name_or_path>/finetuning/<wandb_run_id>/``.

    Args:
        args: Parsed command-line namespace. In addition to the fields read
            by ``_build_model_args``, uses ``model_name_or_path``,
            ``data_dir``, ``train_size``, ``wandb_run_id``,
            ``save_step_dense`` and ``save_step_dense_interval``.
    """
    model_args = _build_model_args(args)

    # Distributed settings are fixed here (values of -1 for the ranks and
    # world size); the distributed command-line flags are not consulted.
    ddp_args = {
        "local_rank": -1,
        "rank": -1,
        "gpu": "cuda:0",
        "world_size": -1,
        "dist_url": "env://",
        "dist_backend": "nccl",
    }
    model = Seq2SeqModel(
        model_type="gpt2",
        model_name=args.model_name_or_path,
        args=model_args,
        ddp_args=ddp_args,
    )

    output_dir = f"{args.model_name_or_path}/finetuning/{args.wandb_run_id}/"
    os.makedirs(output_dir, exist_ok=True)

    train_path = os.path.join(args.data_dir, "cot_fact_train.json")
    test_path = os.path.join(args.data_dir, "cot_fact_test.json")
    valid_path = os.path.join(args.data_dir, "valid.json")

    train_df, train_sample_size = read_data_source_target(train_path, return_num=True)
    print("whole train dataset", train_sample_size)
    # Shuffle all rows, then keep at most --train_size of them. No explicit
    # random_state is given, so the result depends on the global NumPy RNG
    # state. NOTE(review): confirm whether Seq2SeqModel seeds it from
    # manual_seed; the model is deliberately constructed before this line.
    train_df = train_df.sample(frac=1.0).head(args.train_size)

    # The test and train files are also loaded in their JSON form and handed
    # to train_model as named test sets. An earlier setup that split
    # test.json / cf_test.json into per-hop related/unrelated subsets was
    # replaced by these two sets.
    test_df = read_data_source_target(test_path, return_json=True)
    train_json = read_data_source_target(train_path, return_json=True)
    test_data = {"test_cot": test_df, "train_cot": train_json}

    eval_df = read_data_source_target(valid_path)

    # NOTE(review): the keyword "overwrite_ouput_dir" (sic) is kept exactly as
    # it was because the train_model signature is not visible from this file;
    # it looks like a misspelling of "overwrite_output_dir" -- confirm.
    model.train_model(
        train_data=train_df,
        eval_data=eval_df,
        test_data=test_data,
        output_dir=output_dir,
        overwrite_ouput_dir=True,
        finetune=True,
        save_step_dense=args.save_step_dense,
        save_step_dense_interval=args.save_step_dense_interval,
    )
if __name__ == "__main__":
    # Command-line entry point: parse the run configuration and start training.
    parser = argparse.ArgumentParser(
        description="Fine-tune a GPT-2 Seq2SeqModel on the CoT fact dataset."
    )

    # Training hyperparameters (override the corresponding model arguments).
    parser.add_argument("--batch_size", type=int, default=64,
                        help="training batch size")
    parser.add_argument("--grad_accum", type=int, default=1,
                        help="gradient accumulation steps")
    parser.add_argument("--seq_length", type=int, default=25,
                        help="value used for every sequence/generation length limit")
    parser.add_argument("--wd", type=float, default=1e-1,
                        help="weight decay")
    parser.add_argument("--n_epochs", type=float, default=5,
                        help="number of training epochs")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="learning rate")

    # Model and data locations.
    parser.add_argument("--model_name_or_path", type=str,
                        help="model to load; also the root of the output directory")
    # Required: main() joins this directory with the dataset file names, so a
    # missing value would otherwise only fail after the model has been loaded.
    parser.add_argument("--data_dir", type=str, required=True,
                        help="directory containing cot_fact_train.json, "
                             "cot_fact_test.json and valid.json")
    parser.add_argument("--train_size", type=int, default=64 * 500,
                        help="maximum number of shuffled training rows to use")

    # Checkpointing and experiment tracking.
    parser.add_argument("--save_step_dense", type=int, default=-1)
    parser.add_argument("--save_step_dense_interval", type=int, default=100)
    parser.add_argument("--wandb_proj", type=str, default="parametric_cf")
    parser.add_argument("--wandb_run_id", type=str)

    # Distributed-training flags. NOTE(review): these are parsed but main()
    # builds its ddp_args from fixed values, so they currently have no effect.
    parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
    parser.add_argument('--dist-url', default='env://', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--local_rank', default=-1, type=int, help='local rank for distributed training')
    parser.add_argument('--gpu', default=None, type=int)

    args = parser.parse_args()
    main(args)