from datasets import load_dataset
from openicl import DatasetReader, PromptTemplate, GenInferencer
from openicl import EMEvaluator, RougeEvaluator, BleuEvaluator
from util.noise import gen_noise_generation
from util.template import gen_template_generation
from util.retriever import get_retriever
from util.verification import data_verification
import argparse
import os
import json

# Hugging Face Hub dataset id loaded (via ``load_dataset``) for each
# ``--dataset`` name. Every task must point at its own repository.
data2usename = {
    "mtop": "KaiLv/UDR_MTOP",
    "smcalflow": "KaiLv/UDR_SMCalFlow",
    "cnn": "KaiLv/UDR_CNNDailyMail",
    "pubmed": "KaiLv/UDR_PubMed",
    "reddit": "KaiLv/UDR_Reddit",
    "commongen": "KaiLv/UDR_CommonGen",
    "rocstory": "KaiLv/UDR_RocStory",
    "rocending": "KaiLv/UDR_RocEnding",
    # Was a copy-paste of the RocStory id, which made "go" silently run on
    # RocStory data; follows the UDR_Python / UDR_Java / UDR_PHP naming.
    "go": "KaiLv/UDR_Go",
    "python": "KaiLv/UDR_Python",
    "java": "KaiLv/UDR_Java",
    "php": "KaiLv/UDR_PHP",
    "dart": "KaiLv/UDR_DART",
    "e2e": "KaiLv/UDR_E2E",
}


# Evaluator that print_score() uses to compare predictions with references,
# keyed by --dataset name. Each entry is its own evaluator instance.
data2metric = dict(
    # exact match
    mtop=EMEvaluator(),
    smcalflow=EMEvaluator(),
    # ROUGE
    cnn=RougeEvaluator(),
    pubmed=RougeEvaluator(),
    reddit=RougeEvaluator(),
    # BLEU; the num_gram setting varies per dataset
    commongen=BleuEvaluator(num_gram=3),
    rocstory=BleuEvaluator(num_gram=1),
    rocending=BleuEvaluator(num_gram=1),
    go=BleuEvaluator(num_gram=1),
    python=BleuEvaluator(num_gram=1),
    java=BleuEvaluator(num_gram=1),
    php=BleuEvaluator(num_gram=1),
    dart=BleuEvaluator(num_gram=4),
    e2e=BleuEvaluator(num_gram=4),
)


# Input column for each dataset: handed to DatasetReader as the single input
# column and bound to the "</text>" placeholder of the prompt template.
# Everything defaults to "question"; the exceptions are listed explicitly
# (an overriding key keeps the position it got from fromkeys()).
data2clmname = {
    **dict.fromkeys(
        (
            "mtop",
            "smcalflow",
            "cnn",
            "pubmed",
            "reddit",
            "commongen",
            "rocstory",
            "rocending",
            "go",
            "python",
            "java",
            "php",
            "dart",
            "e2e",
        ),
        "question",
    ),
    "smcalflow": "user_utterance",
    "cnn": "article",
    "commongen": "joined_concepts",
}

# Reference-output column for each dataset, passed to gen_noise_generation()
# as ``tarname``. mtop, smcalflow and cnn use dataset-specific names; all
# other datasets use "target".
# NOTE(review): cnn/pubmed/reddit carried a bare "# 2" tag in an earlier
# version of this table; its meaning is not evident here — confirm.
data2tarname = dict(
    mtop="logical_form",
    smcalflow="lispress",
    cnn="highlights",
    pubmed="target",
    reddit="target",
    commongen="target",
    rocstory="target",
    rocending="target",
    go="target",
    python="target",
    java="target",
    php="target",
    dart="target",
    e2e="target",
)

# Number of in-context examples per dataset, passed to the retriever as
# ``ice_num``. The default is 8; the datasets below override it (an
# overriding key keeps the position it got from fromkeys()).
data2iclnum = {
    **dict.fromkeys(
        (
            "mtop",
            "smcalflow",
            "cnn",
            "pubmed",
            "reddit",
            "commongen",
            "rocstory",
            "rocending",
            "go",
            "python",
            "java",
            "php",
            "dart",
            "e2e",
        ),
        8,
    ),
    # single demonstration for these three (they are also the ROUGE-scored sets)
    "cnn": 1,
    "pubmed": 1,
    "reddit": 1,
    "python": 7,
    "java": 6,
}


def run(args):
    """Run in-context-learning generation for ``args.dataset`` and save outputs.

    Pipeline: load the dataset from the Hugging Face Hub, pass it through
    ``data_verification``, inject label noise into the ``train`` split with
    probability ``args.noise_p``, build the prompt template and retriever
    from ``args``, and generate predictions with ``args.model``.

    Two JSON files are written under ``args.log_dir`` (suffix ``debug`` when
    ``args.debug`` is set, ``run`` otherwise):
    ``prediction_<suffix>.json`` holds the generated predictions and
    ``label_<suffix>.json`` holds ``DatasetReader.references``.
    """
    name = args.dataset
    mode = "debug" if args.debug else "run"

    raw = load_dataset(data2usename[name])
    raw = data_verification(raw, debug=args.debug)

    # Noise is applied to the train split only. The reader below reads its
    # labels from "new_target", presumably produced here — confirm in util.noise.
    raw = gen_noise_generation(
        raw,
        p=args.noise_p,
        split="train",
        tarname=data2tarname[name],
        dataname=name,
    )
    raw_template = gen_template_generation(args)
    retriever_cls = get_retriever(args)

    input_col = data2clmname[name]
    reader = DatasetReader(raw, input_columns=[input_col], output_column="new_target")

    ice_template = PromptTemplate(
        raw_template,
        {input_col: "</text>", "new_target": "</label>"},
        ice_token="</E>",
    )
    retriever_obj = retriever_cls(reader, ice_num=data2iclnum[name])
    generator = GenInferencer(model_name=args.model)
    predictions = generator.inference(retriever_obj, ice_template=ice_template)

    pred_path = f"{args.log_dir}/prediction_{mode}.json"
    with open(pred_path, "w") as fh:
        json.dump(predictions, fh)

    label_path = f"{args.log_dir}/label_{mode}.json"
    with open(label_path, "w") as fh:
        json.dump(reader.references, fh)


def print_score(args):
    """Score saved predictions against saved labels, print and return the score.

    Reads ``prediction_<suffix>.json`` and ``label_<suffix>.json`` from
    ``args.log_dir``, where the suffix is ``debug`` when ``args.debug`` is set
    and ``run`` otherwise. These are the files ``run`` writes. Scores them with
    the evaluator registered for ``args.dataset`` in ``data2metric``.

    Returns:
        Whatever the evaluator's ``score`` method returns. It is also printed.
        It is returned so callers can use it programmatically; the original
        implementation returned None.

    Raises:
        FileNotFoundError: if either file is missing, e.g. when scoring
            without ``--run`` before any run has produced outputs.
        KeyError: if ``args.dataset`` has no entry in ``data2metric``.
    """
    mode = "debug" if args.debug else "run"
    prediction_file = f"{args.log_dir}/prediction_{mode}.json"
    label_file = f"{args.log_dir}/label_{mode}.json"

    with open(prediction_file, "r") as f:
        predictions = json.load(f)
    with open(label_file, "r") as f:
        references = json.load(f)

    evaluator = data2metric[args.dataset]
    score = evaluator.score(predictions=predictions, references=references)
    print(score)
    return score


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="In-context-learning generation inference and evaluation."
    )
    parser.add_argument(
        "--noise_p", type=float, default=0.0, help="Noise Flipping probability"
    )
    # The previous default ("sst2") is not in any dataset table and always
    # crashed with a KeyError; restrict to the supported names instead.
    parser.add_argument(
        "--dataset",
        type=str,
        default="mtop",
        choices=list(data2usename),
        help="Dataset name",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="EleutherAI/gpt-neo-2.7B",
        help="Pretrained LLM model name",
    )
    parser.add_argument(
        "--retriever", type=str, default="random", help="Retriever Type"
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        default="./icl_inference_output",
        help="Logging directory",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Pass debug=True to data verification and use *_debug.json outputs",
    )
    parser.add_argument(
        "--run",
        action="store_true",
        help="Run inference before scoring; otherwise only score existing outputs",
    )
    args = parser.parse_args()

    # Model ids like "org/name" contain "/", which would add a directory level.
    model = args.model.replace("/", "_")
    args.log_dir = (
        f"{args.log_dir}/{args.dataset}/"
        f"model={model}_noise={args.noise_p}_retriever={args.retriever}"
    )
    os.makedirs(args.log_dir, exist_ok=True)

    if args.run:
        run(args)
    print_score(args)
