import logging
import argparse
import os
import json
import uuid
from pathlib import Path
from simpletransformers.seq2seq import Seq2SeqModel
from utils import read_data_source_target
def eval_res(a, b):
    """Score a prediction ``a`` against the gold string ``b`` as 0 or 1.

    The comparison depends on the markers present in ``b``:

    * no ``</a>``: ``a`` only has to start with ``b``;
    * ``</a>`` but no ``<a>`` before it: the text preceding ``</a>`` must be
      identical in both strings;
    * exactly one ``<a>`` before ``</a>``: only the text between ``<a>`` and
      ``</a>`` is compared, and ``a`` must itself contain exactly one ``<a>``
      before its first ``</a>``.

    Raises:
        AssertionError: if ``b`` contains more than one ``</a>``, or more
            than one ``<a>`` before its ``</a>``.
    """
    closing_tags = b.count("</a>")
    assert closing_tags in [0, 1]

    # Gold has no answer terminator: plain prefix match.
    if closing_tags == 0:
        return int(a.startswith(b))

    # Everything after the first "</a>" is ignored on both sides.
    gold_head = b.split("</a>")[0]
    pred_head = a.split("</a>")[0]
    opening_tags = gold_head.count("<a>")

    if opening_tags == 0:
        return int(pred_head == gold_head)

    if opening_tags == 1:
        # The prediction must have the same single-"<a>" shape to be comparable.
        if pred_head.count("<a>") != 1:
            return 0
        pred_answer = pred_head.split("<a>")[1]
        gold_answer = gold_head.split("<a>")[1]
        return int(pred_answer == gold_answer)

    # More than one "<a>" in the gold string is not a supported format.
    assert False

def parse_target(target_text):
    """Extract ``(attr, head, tail)`` from a tagged target string.

    The text before the first ``</a>`` must contain exactly one ``<a>`` and,
    ahead of it, exactly one ``<q>``.  Whatever precedes ``<q>`` is the
    attribute; the text between ``<q>`` and ``</q>`` must consist of exactly
    three ``<...>`` tokens, of which the first (head) and last (tail) are
    returned.  Surrounding angle brackets are stripped from all three values.

    Raises:
        ValueError: if any of the splits does not yield the expected number
            of pieces.
    """
    before_answer_end = target_text.split("</a>")[0]
    question_part, _answer = before_answer_end.split("<a>")
    raw_attr, raw_query = question_part.split("<q>")

    # Keep only the query body, then peel off the outermost "<" and ">".
    query_body = raw_query.split("</q>")[0].strip("><")
    head, _, tail = query_body.split("><")

    return raw_attr.strip("><"), head, tail
def eval_items(all_items, partition_atomic=False, test_entities=None):
    """Group per-example correctness flags by example type.

    Args:
        all_items: iterable of items supporting ``in`` and ``[]`` lookups.
            Each item stores the prediction and gold string under either
            ``"model_output"``/``"target_text"`` or the space-separated
            ``"model output"``/``"target text"`` keys, and may carry a
            ``"type"`` label (``"test_inferred"`` is used when it is absent).
        partition_atomic: when true, items typed ``"train_atomic"`` are
            re-labelled ``"test_atomic"`` if their head entity is listed in
            ``test_entities`` under their relation.
        test_entities: mapping ``relation -> container of head entities``.
            Must be provided when ``partition_atomic`` is true and a
            ``"train_atomic"`` item is present.

    Returns:
        dict mapping each example type to the list of ``eval_res`` results
        (0 or 1) for its items, in input order.  Empty input gives ``{}``.
    """
    acc = dict()   # maps each type of example to the corresponding list of eval results
    for item in all_items:
        t = item['type'] if 'type' in item else 'test_inferred'

        if "model_output" in item:
            pred, gold = item["model_output"], item["target_text"]
        else:
            pred, gold = item["model output"], item["target text"]

        if t == 'train_atomic' and partition_atomic:
            # Tokenise the part of the gold string before "<a>", drop the first
            # and last token, and expect exactly three tokens to remain.
            head, rel, _ = gold.split("<a>")[0].strip("><").split("><")[1:-1]
            # Atomic facts whose head is held out for this relation count as test facts.
            if rel in test_entities and head in test_entities[rel]:
                t = "test_atomic"

        acc.setdefault(t, []).append(eval_res(pred, gold))
    return acc


def eval_items_score(all_items):
    """Return the accuracy of the first example type found in ``all_items``.

    This helper used to be declared as a second ``eval_items``.  That
    redefinition replaced the grouping function above and then called itself,
    so every call ended in ``RecursionError``.  It now has its own name and
    delegates to ``eval_items``.

    Args:
        all_items: same item collection accepted by ``eval_items``.

    Returns:
        Mean of the 0/1 results for the first type encountered (dict insertion
        order), rounded to three decimals.

    Raises:
        ValueError: if ``all_items`` is empty, since there is nothing to score.
    """
    acc = eval_items(all_items)
    if not acc:
        raise ValueError("eval_items_score() requires at least one item")
    first_results = next(iter(acc.values()))
    return round(sum(first_results) / len(first_results), 3)
    
# Script entry point: fine-tune an existing Seq2SeqModel checkpoint on a
# subsample of the counterfactual training set (cf_train.json) found in
# --data_dir, using the standard and counterfactual test splits as eval data.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # --- data / model selection -------------------------------------------
    parser.add_argument("--data_dir", default=None, type=str, required=True, help="Input data dir. {train/valid/test}.json files for the task.")
    parser.add_argument("--model_type", default='gpt2', type=str, help="lm type")
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True, help="lm name or path")
    # NOTE(review): --init_weights and --add_tokens are parsed but the model
    # below is always built with init_weights=False and new_tokens=None.
    parser.add_argument("--init_weights", action="store_true", help="whether fresh init the weights of the model")
    parser.add_argument("--add_tokens", action="store_true", help="whether add the tokens in vocab.json in data_dir to the vocabulary.")

    # --- architecture switches (forwarded to Seq2SeqModel) ------------------
    parser.add_argument("--no_dropout", action="store_true", help="Whether disable dropout.")
    parser.add_argument("--n_layer", default=None, type=int, help="number of layers, only used when init weight")
    parser.add_argument("--n_head", default=None, type=int, help="number of heads, only used when init weight")
    parser.add_argument("--n_inner", default=None, type=int, help="inner dimension of MLP")
    parser.add_argument("--no_ln", action="store_true", help="Whether disable layernorm.")
    parser.add_argument("--no_mlp", action="store_true", help="Whether disable mlp layers.")
    parser.add_argument("--share_mlp", action="store_true", help="Whether share mlp weights across layers.")
    parser.add_argument("--add_recurrence", action="store_true", help="Whether run the layers twice.")
    parser.add_argument("--re_embed", action="store_true", help="Whether add re-embedding during recurrence.")
    parser.add_argument("--re_embed_temp", default=1.0, type=float, help="softmax temperature for re-embedding")
    parser.add_argument("--relation_mean_shift", action="store_true", help="Whether perform OOD relation mean shift w.r.t. ID relations in lm_head")
    parser.add_argument("--add_memory", action="store_true", help="Whether add shared mlp memory.")
    parser.add_argument("--memory_dim", default=1536, type=int, help="inner dimension of add shared mlp memory")

    # --- training / evaluation control --------------------------------------
    parser.add_argument("--n_train_data", default = 1000, type = int, help="the number of training samples to use from the FT")
    parser.add_argument("--fp16", action="store_true", help="whether use half-precision training")
    parser.add_argument("--do_train", action="store_true", help="Whether run training.")
    # NOTE(review): --do_eval, --do_predict, --predict_during_training,
    # --prediction_dir, --custom_test, --restrict_cf_dataset and
    # --prediction_cutoff are accepted but not read anywhere in this script.
    parser.add_argument("--do_eval", action="store_true", help="Whether run validation.")
    parser.add_argument("--do_predict", action="store_true", help="Whether to run prediction on the test set.")
    parser.add_argument("--overwrite_output_dir", action="store_true", help="Whether to overwrite on the existing output dir")
    parser.add_argument("--save_best_model", action="store_true", help="Whether to save the best model on validation")
    parser.add_argument("--use_multiprocessed_decoding", action="store_true", help="Whether to use multiprocess when decoding")
    parser.add_argument("--save_model_every_epoch", action="store_true", help="Whether to save model every epoch")
    parser.add_argument("--evaluate_during_training", action="store_true", help="Whether to eval model during training")
    parser.add_argument("--predict_during_training", action="store_true", help="Whether to predict on test set during training")
    parser.add_argument("--weight_decay", default=0.01, type=float, help="weight decay")
    parser.add_argument("--warmup_steps", default=2000, type=int, help="Warmup step. 0 for using warmup ratio.")
    parser.add_argument("--save_epoch_interval", default=0, type=int, help="Save checkpoint every X epochs. 0 for no saving")
    parser.add_argument("--scheduler", default='linear_schedule_with_warmup', type=str, help="scheduler type")
    parser.add_argument("--output_dir", default='output_dir/', type=str, help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--prediction_dir", default="/home/gghosal/parametric_cf/repro/", type=str, help="The output directory where the predictions results will be written.")
    parser.add_argument("--custom_test", default=None, type=str, help="Override the default test set (test.json)")
    parser.add_argument("--save_step", default=0, type=int, help="Save checkpoint every X updates steps. 0 for no saving")
    parser.add_argument("--save_step_dense", default=-1, type=int, help="If not -1, save via every save_step_dense_interval steps till specified")
    parser.add_argument("--save_step_dense_interval", default=2000, type=int, help="")
    parser.add_argument("--train_batch_size", default=16, type=int, help="Size of each train batch")
    parser.add_argument("--eval_batch_size", default=16, type=int, help="Size of each eval/predict batch")
    parser.add_argument("--restrict_cf_dataset", type = int, default = -1, help="Check ")
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="gradient accumulation steps")
    parser.add_argument("--learning_rate", default=4e-5, type=float, help="learning rate")
    parser.add_argument("--max_steps", default=0, type=int, help="Number of train steps")
    parser.add_argument("--num_train_epochs", default=20, type=int, help="Number of train epochs")
    parser.add_argument('--dataloader_num_workers', default=0, type=int, help='the number of cpus used in collecting data in dataloader. Note that if it is large than cpu number, the program may be stuck')
    parser.add_argument('--manual_seed', default=42, type=int, help='random seed')
    parser.add_argument("--max_seq_length", default=None, type=int, help="Max input seq length")
    parser.add_argument("--max_length", default=None, type=int, help="Max output seq length")
    parser.add_argument("--max_gen_length", default=None, type=int, help="Max seq length appending during generation")
    parser.add_argument("--block_size", default=None, type=int, help="block size")
    parser.add_argument("--prediction_cutoff", default=None, type=int, help="if set, only predict on the first # of prediction examples")

    # --- distributed training (collected into ddp_args below) ---------------
    parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
    parser.add_argument('--dist-url', default='env://', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--local_rank', default=-1, type=int, help='local rank for distributed training')
    parser.add_argument('--gpu', default=None, type=int)
    args = parser.parse_args()

    # Every run gets its own directory: <model_name_or_path>/finetune/<uuid1>.
    # mkdir is called without parents=True, so <model_name_or_path> must
    # already exist as a directory on disk.
    output_dir = (Path(args.model_name_or_path) / "finetune/")
    output_dir.mkdir(exist_ok=True)
    this_experiment_id = str(uuid.uuid1())
    output_dir = (output_dir / this_experiment_id)
    output_dir.mkdir(exist_ok = True)

    # Counterfactual training set, reduced to --n_train_data rows.
    # train_sample_size is not used further in this script.
    train_df, train_sample_size = read_data_source_target(os.path.join(args.data_dir, "cf_train.json"), return_num=True)
    # Seed the subsample with --manual_seed so the selected rows are the same
    # on every run; without random_state the draw depends on the global RNG
    # state at this point.  Assumes train_df is a pandas DataFrame.
    train_df = train_df.sample(n=args.n_train_data, random_state=args.manual_seed)

    # Composition of the sampled subset by relevance ('unrel') and hop count.
    print("Unrel, Hop 1:",len(train_df[train_df['unrel'] & (train_df['hop']==1)]))
    print("Unrel, Hop 2:",len(train_df[train_df['unrel'] & (train_df['hop']==2)]))
    print("Rel, Hop 1:",len(train_df[(~train_df['unrel']) & (train_df['hop']==1)]))
    print("Rel, Hop 2:",len(train_df[(~train_df['unrel']) & (train_df['hop']==2)]))
    print(train_df['input_text'])

    # NOTE(review): train_regular_df / train_regular_size are loaded but not
    # used afterwards in this script.
    train_regular_df, train_regular_size = read_data_source_target(os.path.join(args.data_dir, "train.json"), return_num=True)
    test_df = read_data_source_target(os.path.join(args.data_dir, "cf_test.json"), return_json=True)
    test_regular_df = read_data_source_target(os.path.join(args.data_dir, "test.json"), return_json=True)

    # Split the counterfactual test items into four buckets: relevance x hops.
    test_df_relevant_1_cf = [x for x in test_df if not x['unrel'] and x['hop']==1]
    test_df_relevant_2_cf = [x for x in test_df if not x['unrel'] and x['hop']==2]
    test_df_irrelevant_1_cf = [x for x in test_df if x['unrel'] and x['hop']==1]
    test_df_irrelevant_2_cf = [x for x in test_df if x['unrel'] and x['hop']==2]
    eval_df = {
        "Standard": test_regular_df,
        "Irrelevant CF, 1 Hop": test_df_irrelevant_1_cf,
        "Irrelevant CF, 2 Hop": test_df_irrelevant_2_cf,
        "Relevant CF, 1 Hop": test_df_relevant_1_cf,
        "Relevant CF, 2 Hop": test_df_relevant_2_cf,
    }
    print("test_sample_size", len(test_df))

    ddp_args = {
        "local_rank": args.local_rank,
        "rank": args.rank,
        "gpu": args.gpu,
        "world_size": args.world_size,
        "dist_url": args.dist_url,
        "dist_backend": args.dist_backend,
    }
    model_args = {
        "reprocess_input_data": True,
        "overwrite_output_dir": args.overwrite_output_dir,
        "max_seq_length": args.max_seq_length,
        "max_length": args.max_length,
        "max_gen_length": args.max_gen_length,
        "block_size": args.block_size,
        "train_batch_size": args.train_batch_size,
        "eval_batch_size": args.eval_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "learning_rate": args.learning_rate,
        "num_train_epochs": args.num_train_epochs,
        "save_eval_checkpoints": False,
        "save_steps": args.save_step,
        "use_multiprocessing": False,
        # Per-experiment directory created above.
        "output_dir": str(output_dir.as_posix()),
        "manual_seed": args.manual_seed,
        "fp16": args.fp16,
        "truncation": True,
        "dataloader_num_workers":args.dataloader_num_workers,
        "use_multiprocessed_decoding":args.use_multiprocessed_decoding,
        "save_best_model": args.save_best_model,
        "save_model_every_epoch": args.save_model_every_epoch,
        "save_epoch_interval": args.save_epoch_interval,
        "scheduler": args.scheduler,
        "weight_decay": args.weight_decay,
        "evaluate_during_training": args.evaluate_during_training,
        # Fixed to False regardless of --predict_during_training.
        "predict_during_training":False,
        "mlm": False,
        "warmup_steps": args.warmup_steps,
        "max_steps": args.max_steps,
        "n_layer": args.n_layer,
        "n_inner": args.n_inner,
        "n_head": args.n_head,
        "memory_dim": args.memory_dim,
    }
    model = Seq2SeqModel(
        model_type=args.model_type,
        model_name=args.model_name_or_path,
        args=model_args,
        ddp_args=ddp_args,
        new_tokens=None,
        init_weights=False,
        no_dropout=args.no_dropout,
        no_ln=args.no_ln,
        no_mlp=args.no_mlp,
        share_mlp=args.share_mlp,
        add_memory=args.add_memory,
        add_recurrence=args.add_recurrence,
        re_embed=args.re_embed,
        re_embed_temp=args.re_embed_temp,
        relation_mean_shift=args.relation_mean_shift,
    )
    if args.do_train:
        # NOTE(review): train_model receives --output_dir here, while
        # model_args["output_dir"] points at the per-experiment directory
        # created above; confirm which one train_model actually writes to.
        model.train_model(train_data=train_df, eval_data=eval_df,output_dir=args.output_dir,
                          save_step_dense=args.save_step_dense, save_step_dense_interval=args.save_step_dense_interval,finetune = True)