from src.daemon import DaemonDecoding, ProposalParams
import json
import argparse 
import os
from utils.eval_utils import read_csv, write_csv


# Default device for the auxiliary evaluation models referenced below
# (the SimCSE model of the "coherence" scorer and the LM of the
# "information" scorer).
EVAL_MODEL_CUDA = "cuda:2"

# Default sampling hyper-parameters. They are also baked into the default
# output filename built by parse_config().
TOPK = 0
TOPP = 1.0
TEMP = 0.97


# Per-scorer configuration handed to DaemonDecoding.sample_importance_resample,
# keyed by scorer name. Insertion order matches the original hand-written table:
# seq_rep/*, tok_rep/*, coherence, diversity, information.
scorer_configs = {
    # seq_rep scorers, parameterised by n.
    **{f"seq_rep/{n}": {"n": n, "requires_detokenize": True} for n in (2, 3, 4)},
    # tok_rep scorers, parameterised by l.
    **{f"tok_rep/{size}": {"l": size, "requires_detokenize": True} for size in (8, 16, 32)},
    "coherence": {
        "prefix_len": 32,
        "max_length": 256,
        "device": EVAL_MODEL_CUDA,
        "simcse_model_path": "sup-simcse-roberta-base",
    },
    "diversity": {"requires_detokenize": True},
    "information": {
        "max_length": 256,
        "lm_model_path": "gpt2-xl",
        "device": EVAL_MODEL_CUDA,
    },
}


def parse_config(argv=None):
    """Build the CLI parser for this decoding script and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real command
            line, exactly as before.

    Returns:
        argparse.Namespace holding all data, model and sampling options.
    """
    parser = argparse.ArgumentParser()

    # Data / output locations.
    parser.add_argument("--data_path", type=str, default="data/eval_data/wikitext")
    parser.add_argument("--pmt_data_name", type=str, default="test_pmt.csv")
    parser.add_argument("--save_gen_name", type=str, default=f"test_sir_k{TOPK}p{TOPP}t{TEMP}_pool25.csv")
    # NOTE(review): not read by this script's main(); kept for CLI compatibility.
    parser.add_argument("--batch_size", type=int, default=1)

    # Models and devices.
    parser.add_argument("--main_model_ckpt", type=str, default="gpt2-xl")
    # NOTE(review): not read by this script's main(); kept for CLI compatibility.
    parser.add_argument("--proposal_model_ckpt", type=str, default="", help="empty string default to using main model as proposal")
    # main() dereferences args.mu unconditionally, so require it up front and
    # let argparse report a clear error instead of an AttributeError on None.
    parser.add_argument("--mu", type=str, required=True, help="path to optimal mu")
    parser.add_argument("--main_model_device", type=str, default="cuda:3")
    parser.add_argument("--eval_model_device", type=str, default=EVAL_MODEL_CUDA)
    parser.add_argument("--dtype", type=str, default="float16")

    # Sampling hyper-parameters.
    parser.add_argument("--max_len", type=int, default=256)
    parser.add_argument("--mc_len", type=int, default=256)
    parser.add_argument("--mc_num", type=int, default=25)
    parser.add_argument("--top_k", type=int, default=TOPK)
    parser.add_argument("--top_p", type=float, default=TOPP)
    parser.add_argument("--temperature", type=float, default=TEMP)
    parser.add_argument("--seq_top_k", type=int, default=0)

    # Boolean switches. no_importance_weight defaults to True, so the
    # store_true flag alone could never turn it off. --importance_weight
    # writes False to the same destination to make it controllable.
    parser.add_argument("--no_importance_weight", action="store_true")
    parser.add_argument("--importance_weight", dest="no_importance_weight", action="store_false",
                        help="enable importance weights (overrides the default no_importance_weight=True)")
    parser.add_argument("--corr_reduce", action="store_true")
    # NOTE(review): defaults to True and is not read by this script's main().
    parser.add_argument("--return_one", action="store_true")
    parser.set_defaults(no_importance_weight=True, corr_reduce=False, return_one=True)

    return parser.parse_args(argv)

def _save_name_from_mu(mu_path, suffix):
    """Return the mu file's name without directory or extension, followed by ``suffix``."""
    stem = os.path.splitext(os.path.basename(mu_path))[0]
    return stem + suffix


def _scorer_configs_for_device(configs, device):
    """Return a copy of ``configs`` with every ``"device"`` entry set to ``device``.

    Entries without a ``"device"`` key are copied unchanged. The input mapping
    is not mutated, so the module-level defaults stay intact.
    """
    return {
        name: ({**cfg, "device": device} if "device" in cfg else dict(cfg))
        for name, cfg in configs.items()
    }


def main():
    """Run ``DaemonDecoding.sample_importance_resample`` on the prompt CSV.

    The output CSV is written under ``args.data_path``. Its name is the mu
    file's stem prefixed to ``args.save_gen_name``.
    """
    args = parse_config()
    # The original ``[:-5]`` slice assumed a 5-char ".json" suffix; splitext
    # gives the same result for ".json" and the right one for any extension.
    args.save_gen_name = _save_name_from_mu(args.mu, args.save_gen_name)
    print("save filename ", args.save_gen_name)

    # Load the small inputs before the (large) model so a bad path fails fast.
    with open(args.mu, "r", encoding="utf-8") as f:
        optimal_mu = json.load(f)
    pmt_data = read_csv(os.path.join(args.data_path, args.pmt_data_name))

    model_name = args.main_model_ckpt.split("/")[-1]
    decoder = DaemonDecoding(model_name, args.main_model_ckpt, args.main_model_device, args.dtype)

    # Honour --eval_model_device. The module-level configs hard-code
    # EVAL_MODEL_CUDA, which is also that flag's default, so default runs
    # are unchanged.
    configs = _scorer_configs_for_device(scorer_configs, args.eval_model_device)

    print(f"reduce correlation == {args.corr_reduce}")
    print(f"no importance weight == {args.no_importance_weight}")
    outs = decoder.sample_importance_resample(pmt_data, configs, optimal_mu,
                            max_len=args.max_len,
                            mc_num=args.mc_num,
                            mc_len=args.mc_len,
                            top_k=args.top_k,
                            top_p=args.top_p,
                            temperature=args.temperature,
                            corr_reduce=args.corr_reduce,
                            seq_top_k=args.seq_top_k,
                            no_IS=args.no_importance_weight,
                            use_cache=True)

    write_csv(outs, os.path.join(args.data_path, args.save_gen_name))

if __name__ == "__main__":
    # Script entry point: run only when executed directly, not on import.
    main()
    