from src.daemon import DaemonDecoding, MuOptimParams
from tqdm import tqdm

import json
import csv
import argparse 
import os



# Device string used by the model-backed scorers ("coherence" and "information").
EVAL_MODEL_CUDA = "cuda:0"

# Per-scorer keyword settings, keyed by scorer name. Insertion order matters:
# the script walks this dict in order when it derives the output file name.
scorer_configs = {
    # Sequence-level repetition scorers: "seq_rep/<n>".
    **{
        f"seq_rep/{ngram}": {"n": ngram, "requires_detokenize": True}
        for ngram in (2, 3, 4)
    },
    # Token-level repetition scorers: "tok_rep/<l>".
    **{
        f"tok_rep/{window}": {"l": window, "requires_detokenize": True}
        for window in (8, 16, 32)
    },
    "coherence": {
        "prefix_len": 32,
        "max_length": 256,
        "device": EVAL_MODEL_CUDA,
        "simcse_model_path": "sup-simcse-roberta-base",
    },
    "diversity": {"requires_detokenize": True},
    "information": {
        "max_length": 256,
        "lm_model_path": "gpt2-xl",
        "device": EVAL_MODEL_CUDA,
    },
}

def parse_config(argv=None):
    """Parse the command-line options for the optimal-mu computation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, which is what a plain
            ``parse_config()`` call has always done.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # Model checkpoint and reference data.
    parser.add_argument("--ckpt_path", type=str, default="gpt2-xl")
    parser.add_argument("--oracle_data_path", type=str, default="data/eval_data/wikitext/dev_full.csv")

    # Only these three modes are handled by the script, so reject anything
    # else at parse time instead of failing later.
    parser.add_argument("--optimization_type", type=str, default="match_prefix",
                        choices=("match", "mle", "match_prefix"),
                        help="match|mle|match_prefix")
    parser.add_argument("--optimal_solution_dir", type=str, default="optimal_mu")

    # Optimisation hyper-parameters.
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_iters", type=int, default=500000)
    parser.add_argument("--max_len", type=int, default=256)
    parser.add_argument("--min_len", type=int, default=32)
    parser.add_argument("--num_target_samples", type=int, default=512)
    parser.add_argument("--num_proposal_samples", type=int, default=512)
    parser.add_argument("--err_fn", type=str, default="rmsre", help="rse|rmsre")
    parser.add_argument("--min_err", type=float, default=0.001)
    parser.add_argument("--no_prefix_weight", action="store_true")
    # Without an explicit type a value given on the command line would be a
    # str while the default is a number; parse it as a float so both agree.
    parser.add_argument("--mu_clamp", type=float, default=-1)
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--cond_len", type=int, default=0)

    # Sampling settings.
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--decoder_device", type=str, default="cuda:6")

    # NOTE(review): this overrides the store_true default, so
    # args.no_prefix_weight is True whether or not --no_prefix_weight is
    # passed. Kept as-is because changing it would alter output file names.
    parser.set_defaults(no_prefix_weight=True)

    return parser.parse_args(argv)

# Script body: derive an output file name from the run configuration, fit the
# optimal mu with the selected procedure, and write the result as JSON.
args = parse_config()

# Fail fast on an unsupported mode, before the model is loaded and the oracle
# data is tokenized; otherwise no branch below would assign `optimal_mu`.
SUPPORTED_OPTIMIZATION_TYPES = ("match", "mle", "match_prefix")
if args.optimization_type not in SUPPORTED_OPTIMIZATION_TYPES:
    raise ValueError(
        f"Unsupported optimization_type {args.optimization_type!r}; "
        f"expected one of {', '.join(SUPPORTED_OPTIMIZATION_TYPES)}"
    )

# The dataset name is the directory holding the oracle file; the model name
# is the last component of the checkpoint path.
data_name = args.oracle_data_path.split("/")[-2]
model_name = args.ckpt_path.split("/")[-1]

# Pieces of the output file name, joined with "_" below.
short_optim_config = [data_name, model_name]

for scorer_name in scorer_configs:
    # The guard looks for the exact entry "rep", which is never appended, so
    # every repetition scorer adds its own piece (e.g. "seq_rep-2").
    if "rep" in scorer_name and "rep" not in short_optim_config:
        short_optim_config.append(scorer_name.replace("/", "-"))
    # The remaining scorer families add one short tag each, at most once.
    for tag in ("coh", "div", "inf", "len"):
        if tag in scorer_name and tag not in short_optim_config:
            short_optim_config.append(tag)

short_optim_config.extend([
    args.optimization_type,
    f"cond{args.cond_len}",
    f"lr{args.lr}",
    f"iterN{args.max_iters}",
    f"err{args.min_err}",
    args.err_fn,
])
if "match" in args.optimization_type:
    # Sampling settings only appear in the name for the matching procedures.
    short_optim_config.append(f"k{args.top_k}p{args.top_p}t{args.temperature}")
    short_optim_config.append(f"sampN{args.num_proposal_samples}")

if args.no_prefix_weight:
    short_optim_config.append("no_prefixW")


os.makedirs(args.optimal_solution_dir, exist_ok=True)
optimal_solution_path = os.path.join(args.optimal_solution_dir, "_".join(short_optim_config) + ".json")

optim_params = MuOptimParams(lr=args.lr,
                             batch_size=args.batch_size,
                             max_iters=args.max_iters,
                             max_len=args.max_len,
                             min_len=args.min_len,
                             num_target_samples=args.num_target_samples,
                             num_proposal_samples=args.num_proposal_samples,
                             min_err=args.min_err,
                             weight_decay=args.weight_decay,
                             cond_len=args.cond_len,
                             top_k=args.top_k,
                             top_p=args.top_p,
                             temperature=args.temperature)

print(optim_params)
geo_decoder = DaemonDecoding(model_name, args.ckpt_path, args.decoder_device, "float16")

# The oracle text is the first column of each CSV row. Blank lines come back
# from csv.reader as empty lists and are skipped rather than raising
# IndexError. newline="" is what the csv module asks for on files it reads.
with open(args.oracle_data_path, "r", newline="") as oracle_file:
    oracle_texts = [row[0] for row in csv.reader(oracle_file, delimiter=",") if row]
oracle_token_ids = [geo_decoder.tokenizer.encode(x) for x in tqdm(oracle_texts, desc="tokenizing...")]

# `mu_sign_mask` was referenced below without ever being defined, so the
# "match" and "match_prefix" modes raised NameError.
# NOTE(review): None assumes DaemonDecoding treats it as "no sign constraint"
# -- confirm against src.daemon and substitute the intended mask if not.
mu_sign_mask = None

if args.optimization_type == "match":
    optimal_mu = geo_decoder.compute_optimal_mu_match(oracle_token_ids, scorer_configs, optim_params, mu_sign_mask)
elif args.optimization_type == "mle":
    optimal_mu = geo_decoder.compute_optimal_mu_mle(oracle_token_ids, scorer_configs, optim_params)
else:  # "match_prefix" -- the only remaining supported value
    optimal_mu = geo_decoder.compute_optimal_mu_match_prefix(oracle_token_ids, scorer_configs, optim_params, data_name, "import_samp_cache", args.err_fn, no_prefix_weight=args.no_prefix_weight, mu_sign_mask=mu_sign_mask)

# Close the output file deterministically so the JSON is fully flushed.
with open(optimal_solution_path, "w") as solution_file:
    json.dump(optimal_mu, solution_file)

