import argparse
import json
import logging
import random
import os
from typing import List

import datasets
import torch
from tqdm import trange, tqdm
from transformers import is_torch_available, is_tf_available

from converters.base_converter import BaseConverter
from converters.registry import get_converter
from dataset import Example
from demo_selection import RandomDemoSelection, BaseDemoSelection, CosineTopKDemoSelection, MMRDemoSelection
from plugin_dataset import PlugInDataset
from generate import Generator
import numpy as np


# Configure the root logger so INFO-level (and more severe) records are emitted.
logging.basicConfig(level=logging.INFO)
# Maps each value accepted by the `--selector` command-line option to the
# demo-selection class that `run_few_shot` instantiates for it.
SELECTOR = {
    "random": RandomDemoSelection,
    "cos_topk": CosineTopKDemoSelection,
    "mmr": MMRDemoSelection
}

def seed(seed):
    """Seed the random number generators used by this script.

    Python's ``random`` module and NumPy are always seeded. When
    ``is_torch_available()`` reports that PyTorch is usable, the torch
    generator and the generators of all CUDA devices are seeded as well.

    Args:
        seed: Seed value handed to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)

    # Nothing more to do when PyTorch is not installed.
    if not is_torch_available():
        return

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def evaluate(all_generations, all_labels, label_map):
    """Return the accuracy, as a percentage, of generations against labels.

    A generation counts as correct when it is equal to ``label_map[label]``
    for the label at the same position. Generations are compared verbatim;
    any normalisation of the generated text would have to be added here.

    Args:
        all_generations: Predicted answers, one per example.
        all_labels: Gold labels, aligned with ``all_generations``. If the two
            sequences differ in length, only the common prefix is scored.
        label_map: Mapping from a gold label to the answer expected for it.

    Returns:
        Accuracy in the range [0.0, 100.0]. An empty input yields 0.0 instead
        of raising ``ZeroDivisionError``.

    Raises:
        KeyError: If a label is missing from ``label_map`` (when it is a dict).
    """
    # Materialise the pairs once so the inputs may be arbitrary iterables.
    pairs = list(zip(all_generations, all_labels))
    if not pairs:
        return 0.0
    cnt_correct = sum(1 for generation, label in pairs if generation == label_map[label])
    return cnt_correct / len(pairs) * 100.0


def few_shot_generation(dataset: PlugInDataset, converter: BaseConverter,
                        selector: BaseDemoSelection, generator: Generator, args: argparse.Namespace):
    """Generate an answer for every example of ``dataset`` using few-shot prompts.

    For each batch yielded by ``dataset.example_batch()`` the selector picks
    demonstrations, the converter turns demonstrations plus target example
    into a prompt, the generator produces text for the prompts, and the
    converter extracts the answer from each generation.

    Args:
        dataset: Dataset to run generation on; must support ``len()`` and
            ``example_batch()``.
        converter: Provides ``example2code`` (prompt construction) and
            ``code2answer`` (answer extraction).
        selector: Demonstration selector. An ``MMRDemoSelection`` is queried
            with the positions of the batch examples in the dataset; every
            other selector is queried with the examples themselves.
        generator: Text generator exposing ``generate``.
        args: Parsed options; reads ``batch_size``, ``decode_method``,
            ``max_new_tokens``, ``num_generate`` and ``debug``.

    Returns:
        A tuple ``(all_labels, all_generations, all_examples,
        all_prompt_inputs)`` of lists accumulated over all batches.
    """
    all_labels, all_examples, all_generations, all_prompt_inputs = [], [], [], []

    # Ceiling division: a trailing partial batch still counts as one iteration,
    # so the progress bar total is not short by one when the dataset size is
    # not a multiple of the batch size.
    num_batches = (len(dataset) + args.batch_size - 1) // args.batch_size

    for i, batch in tqdm(enumerate(dataset.example_batch()), total=num_batches, desc="Generation:"):
        # Prepare all prompt inputs and labels
        if isinstance(selector, MMRDemoSelection):
            # The MMR selector is addressed by dataset position, not by example
            # (presumably it looks up similarities precomputed per target - confirm).
            batch_start = i * args.batch_size
            batch_demos = selector.batch_get_demo(list(range(batch_start, batch_start + len(batch))))
        else:
            batch_demos = selector.batch_get_demo(batch)

        prompt_inputs = [converter.example2code(demos=demos, target=example)
                         for demos, example in zip(batch_demos, batch)]
        # Batch Generation
        generations = generator.generate(
            prompt_inputs,
            decode_method=args.decode_method,
            max_new_tokens=args.max_new_tokens,
            num_generate=args.num_generate,
        )
        if args.debug:
            # Show the raw generations and up to three worked examples, then
            # block until the user presses Enter before the next batch.
            print(generations)
            for idx, example in enumerate(batch[:3]):
                print("*"*10)
                print("Input:", prompt_inputs[idx])
                print("="*10)
                print("Gen:", generations[idx])
                print("="*10)
                print("Pred:", converter.code2answer(generations[idx]))
                print("Target Label:", example.target_label)
                print("*"*10)
            input()
        all_labels.extend([example.target_label for example in batch])
        all_examples.extend(batch)
        all_generations.extend([converter.code2answer(generation) for generation in generations])
        all_prompt_inputs.extend(prompt_inputs)
    return all_labels, all_generations, all_examples, all_prompt_inputs


def run_few_shot(
        generator: Generator,
        args: argparse.Namespace
):
    """Run few-shot generation on the SST-2 validation split, score it and save the results.

    The demonstrations ("plugin" data) come from the SST-2 train split when
    ``args.plugin_data_path`` is empty, otherwise from ``plugin_set.jsonl``
    inside that directory. After generation the accuracy is printed and, when
    ``args.output_path`` is set, two files are written to
    ``<output_path>/<round>_shots_<n_shots>_selector_<selector>/``:
    ``results_<seed>.json`` (accuracy and plugin-set size) and
    ``outputs_<seed>.json`` (one JSON object per evaluated example).

    Args:
        generator: Text generator passed through to ``few_shot_generation``.
        args: Parsed options; reads ``seed``, ``src_key``, ``tgt_key``,
            ``batch_size``, ``embed_path``, ``plugin_data_path``,
            ``converter``, ``selector``, ``n_shots`` and ``output_path``.
    """
    seed(args.seed)

    # `--output_path` is optional on the command line. Warn before the
    # (expensive) generation starts rather than failing with a TypeError
    # after it has finished.
    if args.output_path is None:
        logging.warning("output_path is not set; generation results will not be saved")

    dataset = datasets.load_dataset('sst2')
    data_type = "validation"
    gen_dataset = PlugInDataset(data_dict=dataset, data_type=data_type, src_key=args.src_key, tgt_key=args.tgt_key,
                                batch_size=args.batch_size, embedding_path=args.embed_path)
    print("\n\n****************************************************************************************")
    print("Run few shot...")
    if args.plugin_data_path == "":
        print("Using Plugin Dataset from Original.")
        plugin_dataset = PlugInDataset(data_dict=dataset, data_type="train", src_key=args.src_key, tgt_key=args.tgt_key,
                                       batch_size=args.batch_size, embedding_path=args.embed_path)
    else:
        plugin_file = os.path.join(args.plugin_data_path, "plugin_set.jsonl")
        print("Using Plugin Dataset from", plugin_file)
        plugin_dataset = PlugInDataset.load_from_json(plugin_file,
                                                      batch_size=args.batch_size, embedding_path=args.embed_path)
    converter = get_converter(args.converter)
    selector = SELECTOR[args.selector](examples=list(plugin_dataset.all_data.values()), n_shots=args.n_shots)
    if args.selector == "mmr":
        # The MMR selector needs the similarities to all evaluation targets up front.
        selector.precompute_similarities(targets=list(gen_dataset.all_data.values()))

    print(f"Running on {data_type} dataset with {len(plugin_dataset)} plugin data.\n")
    all_labels, all_generations, all_examples, all_prompt_inputs = \
        few_shot_generation(gen_dataset, converter, selector, generator, args)

    # Perform evaluation
    accuracy_score = evaluate(all_generations, all_labels, converter.LABEL_MAP)
    print("\nAccuracy score:", accuracy_score)
    print("****************************************************************************************\n\n")

    if args.output_path is None:
        return

    # Save the generation outputs with labels and example ids
    if args.plugin_data_path == "":
        plug_round = "round_0"
    else:
        # The round name is the third "/"-separated component of the plugin
        # path (presumably a layout like <dir>/<task>/<round> - confirm).
        # Shorter paths fall back to their last component instead of raising
        # IndexError once generation has already finished.
        path_parts = args.plugin_data_path.split("/")
        if len(path_parts) > 2:
            plug_round = path_parts[2]
        else:
            plug_round = os.path.basename(os.path.normpath(args.plugin_data_path))
    output_path = os.path.join(args.output_path, f"{plug_round}_shots_{args.n_shots}_selector_{args.selector}")
    # Creates missing parent directories too, and is a no-op when the directory exists.
    os.makedirs(output_path, exist_ok=True)

    with open(os.path.join(output_path, f"results_{args.seed}.json"), "w", encoding="utf-8") as f:
        f.write(json.dumps({"Accuracy": accuracy_score, "Num_plug": len(plugin_dataset)}) + "\n")
    with open(os.path.join(output_path, f"outputs_{args.seed}.json"), "w", encoding="utf-8") as f:
        # One JSON object per line, in evaluation order.
        for generation, label, example, prompt_input in zip(all_generations, all_labels, all_examples, all_prompt_inputs):
            f.write(json.dumps({"generation": generation, "label": label, "example_id": example.example_id,
                                "example_input": example.source_input, "prompt_input": prompt_input})+"\n")

def str2bool(value):
    """Convert a command-line value to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string as True, so
    ``--debug False`` would enable debugging. This parser maps the usual
    "false" spellings (case-insensitive ``""``, ``false``, ``0``, ``no``,
    ``n``, ``off``) to False; any other value stays True, as it was before.

    Args:
        value: The raw option value (a string from the command line) or an
            actual bool.

    Returns:
        The parsed boolean.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in ("", "false", "0", "no", "n", "off")


if __name__ == "__main__":
    # Command-line entry point: parse options, load the model and run the
    # few-shot evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", help="Path to the model", required=True)
    parser.add_argument("--embed_path", help="Path to the embedding model", required=True)
    parser.add_argument(
        "--output_path", help="Path to the output directory (required for batch generation)"
    )
    parser.add_argument(
        "--plugin_data_path", default="", type=str, help="Path to the plugin data directory"
    )
    # Accepts `--debug`, `--debug True` and `--debug False`; see str2bool.
    parser.add_argument(
        "--debug", default=False, type=str2bool, nargs="?", const=True, help="debug print"
    )
    parser.add_argument(
        "--n_shots", default=2, type=int, help="number of few shots"
    )
    parser.add_argument(
        "--seed", default=0, type=int, help="random seed"
    )
    parser.add_argument(
        "--max_length", help="max length to generate", type=int, required=False, default=None
    )
    parser.add_argument(
        "--batch_size", help="batch size to generate", type=int, required=False, default=1
    )
    parser.add_argument(
        "--n_iter", help="filter iterations", type=int, required=False, default=1
    )
    parser.add_argument(
        "--max_new_tokens",
        help="max new tokens to generate",
        type=int,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--src_key", default="question", help="source key in jsonl  (required for batch generation)"
    )
    parser.add_argument(
        "--converter", default="sentiment", help="converter"
    )
    parser.add_argument(
        "--selector", default="random", choices=["random", "cos_topk", "mmr"], help="selector"
    )
    parser.add_argument(
        "--tgt_key", default="answer", help="target key in jsonl  (required for batch generation)"
    )
    parser.add_argument("--add_scores", action="store_true", help="add scores to output")
    parser.add_argument(
        "--decode_method",
        default="greedy",
        help="decode method",
        choices=["greedy", "beam", "sample"],
    )
    parser.add_argument("--from_config", action="store_true", help="load from config")
    parser.add_argument("--config_name", type=str, help="Name of the config to use")
    # NOTE: store_true combined with default=True means this value is True
    # whether or not the flag is passed; it cannot be switched off from the CLI.
    parser.add_argument(
        "--is_autoreg",
        action="store_true",
        help="is the model autoregressive",
        default=True,
    )
    parser.add_argument("--add_io_sep", type=str, default="true", help="add io sep")
    parser.add_argument("--mode", type=str, default="plug", help="default or golden plugin")

    parser.add_argument(
        "--num_generate", type=int, default=1, help="number of generations to generate"
    )

    parser.add_argument("--interactive", action="store_true", help="interactive mode")

    parser.add_argument("--fp16", action="store_true", help="use fp16")

    parser.add_argument("--temperature", type=float, default=1.0, help="temperature for sampling")

    parser.add_argument("--threshold", type=float, default=0.6, help="the threshold to select the golden set")

    parser.add_argument("--nocache", action="store_true", help="do not use cache")

    args = parser.parse_args()
    logging.info("model loading ...")

    # `--add_io_sep` arrives as text; only the (case-insensitive) string "true" enables it.
    args.add_io_sep = args.add_io_sep.lower() == "true"

    # Fall back to the legacy `--max_length` option when `--max_new_tokens` is not given.
    if args.max_length is not None and args.max_new_tokens is None:
        logging.warning(
            "max_new_tokens is not set, using max_length. We recommend using max_new_tokens to be compatible with huggingface"
        )
        args.max_new_tokens = args.max_length

    generator = Generator(
        args.model_path,
        from_config=args.from_config,
        config_name=args.config_name,
        is_autoreg=args.is_autoreg,
        batch_size=args.batch_size,
        fp16=args.fp16,
    )
    # Inference only: clear any gradients and switch the model to eval mode.
    generator.model.zero_grad()
    generator.model.eval()
    logging.info("model loaded")

    logging.info(f"Model tokenizer length = {len(generator.tokenizer)}")
    run_few_shot(
        generator=generator,
        args=args,
    )
