import os
import sys
import time
import logging
import json

import fire
import torch

from transformers import AutoTokenizer
from esa_utils import (
    TEMPLATE_GEMBA_ESA_RANKING_SRC,
    TEMPLATE_GEMBA_ESA_RANKING_REF,
    TEMPLATE_GEMBA_ESA_RANKING_JOINT,
    parse_thinking_results,
)
from gemba_mqm_utils import apply_template
from thinmqm_utils import parse_thinking_results
from vllm import LLM, SamplingParams

logging.getLogger().setLevel(logging.INFO)


# Anchor sentence in the ESA prompt after which the optional boxed-answer
# instruction is appended when ``use_box`` is set.
_BOX_ANCHOR = '100="Perfect meaning and grammar".'
_BOX_HINT = (
    _BOX_ANCHOR
    + " Please reason step by step, and put your final answer within \\boxed{}."
)
# Template names whose prompts include the reference translation.
_TEMPLATES_NEEDING_REF = ("nosrc", "all")


def _read_stripped_lines(path):
    """Return the lines of the UTF-8 text file at *path*, each stripped of surrounding whitespace."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def _prompt_template(template):
    """Map a template name to its GEMBA-ESA prompt template; raise NotImplementedError if unknown."""
    if template == "original":
        return TEMPLATE_GEMBA_ESA_RANKING_SRC
    if template == "nosrc":
        return TEMPLATE_GEMBA_ESA_RANKING_REF
    if template == "all":
        return TEMPLATE_GEMBA_ESA_RANKING_JOINT
    raise NotImplementedError(f"false template para.: {template!r}")


def _build_messages(prompt_template, template, fields, use_box):
    """Render one segment's fields into chat messages, dropping the first message from apply_template."""
    if template == "original":
        # Source-only prompt: the reference must not be offered to the template.
        fields = {key: value for key, value in fields.items() if key != "reference_seg"}
    messages = apply_template(prompt_template, fields)
    if use_box:
        messages[1]["content"] = messages[1]["content"].replace(_BOX_ANCHOR, _BOX_HINT)
    # NOTE(review): messages[0] is discarded as in the original script;
    # presumably a system/few-shot turn produced by apply_template — confirm.
    return messages[1:]


def main(
    model_name,
    source_path,
    hyp_dir,
    eval_dir,
    pred_dir,
    source_lang,
    target_lang,
    files,
    ref_path: str = None,
    template: str = "original",
    quantization: str = None,  # vLLM quantization method name, passed straight to vllm.LLM
    max_new_tokens=8192,  # The maximum numbers of tokens to generate
    gpus: str = None,
    seed: int = 42,  # seed value for reproducibility
    do_sample: bool = True,  # Whether or not to use sampling ; use greedy decoding otherwise.
    min_length: int = None,  # Unused by the vLLM backend; kept for CLI compatibility.
    use_cache: bool = True,  # Unused by the vLLM backend; kept for CLI compatibility.
    top_p: float = 0.95,  # Nucleus-sampling threshold (only applied when do_sample is True).
    temperature: float = 0.6,  # Sampling temperature; forced to 0 when do_sample is False.
    top_k: int = 20,  # Top-k filtering (only applied when do_sample is True).
    repetition_penalty: float = 1.0,  # The parameter for repetition penalty. 1.0 means no penalty.
    length_penalty: int = 1,  # Unused by the vLLM backend; kept for CLI compatibility.
    use_box: bool = False,
    **kwargs,
):
    """Score every hypothesis file with a GEMBA-ESA ranking prompt and save the raw generations.

    For each whitespace-separated name in ``files``, the function:

    1. reads ``hyp_dir/<name>``;
    2. reads the matching error-span results ``eval_dir/<name with .txt -> .json>``;
    3. builds one chat prompt per segment;
    4. generates with vLLM;
    5. writes ``[{"prompt_input", "output"}, ...]`` to ``pred_dir/<system>.json``.

    Args:
        model_name: Model/tokenizer identifier given to vLLM and AutoTokenizer.
        source_path: Text file with one source segment per line.
        hyp_dir: Directory containing the hypothesis files.
        eval_dir: Directory containing the per-system error-span JSON files.
        pred_dir: Output directory (created if missing).
        source_lang, target_lang: Language names inserted into the prompt.
        files: Whitespace-separated hypothesis file names.
        ref_path: Optional reference file, one segment per line. Required
            for the "nosrc" and "all" templates.
        template: "original" (source only), "nosrc" (reference only) or
            "all" (source + reference). The historical misspelling
            "orginal" is accepted as an alias for "original".
        gpus: Value for CUDA_VISIBLE_DEVICES. A tuple/list (as produced by
            fire for "0,1") is joined with commas.
        use_box: Append a boxed-answer instruction to the user prompt.
        kwargs: Accepted and ignored.

    Raises:
        NotImplementedError: Unknown ``template``.
        ValueError: Missing reference for a reference-based template, or
            input files whose line counts disagree.
    """
    # Validate the configuration before the (expensive) model load.
    if template == "orginal":
        template = "original"
    prompt_template = _prompt_template(template)
    if template in _TEMPLATES_NEEDING_REF and ref_path is None:
        raise ValueError(f"template {template!r} requires ref_path")

    # Restrict visible GPUs before anything below queries CUDA.
    if gpus is not None:
        if isinstance(gpus, (tuple, list)):
            gpus = ",".join(map(str, gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpus)

    # Set the seeds for reproducibility
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    num_gpus = torch.cuda.device_count()
    logging.info(f"visible CUDA devices: {num_gpus}")

    # quantization=None is vLLM's own default, so one call covers both cases.
    model = LLM(
        model_name,
        tokenizer=model_name,
        tensor_parallel_size=num_gpus,
        quantization=quantization,
        seed=seed,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    source = _read_stripped_lines(source_path)
    if ref_path is not None:
        references = _read_stripped_lines(ref_path)
        if len(references) != len(source):
            raise ValueError("Source and reference files must have the same number of lines.")
    else:
        references = [None] * len(source)

    # Sampling knobs are only passed when sampling; greedy decoding uses temperature 0.
    if do_sample:
        sampling_kwargs = {"temperature": temperature, "top_p": top_p, "top_k": top_k}
    else:
        sampling_kwargs = {"temperature": 0}
        logging.info("greedy decoding")
    generation_settings = SamplingParams(
        max_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        **sampling_kwargs,
    )

    if pred_dir:
        os.makedirs(pred_dir, exist_ok=True)

    for hyp_file in files.split():
        hypo_path = os.path.join(hyp_dir, hyp_file)
        logging.info(f"reading {hypo_path}")
        hypothesis = _read_stripped_lines(hypo_path)
        if len(hypothesis) != len(source):
            raise ValueError("Source and hypothesis files must have the same number of lines.")

        logging.info("=====ESA prompt: {ppt}=====".format(ppt=template))

        eval_path = os.path.join(eval_dir, hyp_file.replace(".txt", ".json"))
        logging.info(f"reading {eval_path}")
        # NOTE(review): parse_thinking_results is imported from both esa_utils
        # and thinmqm_utils at file top; the later thinmqm_utils import shadows
        # the first. Confirm which parser is intended.
        error_spans = parse_thinking_results(eval_path)
        if len(error_spans) != len(source):
            raise ValueError(f"{eval_path} does not have one entry per source segment.")

        processed_pair = []
        for src, hyp, ref, errors in zip(source, hypothesis, references, error_spans):
            fields = {
                "source_lang": source_lang,
                "source_seg": src,
                "reference_seg": ref,
                "target_lang": target_lang,
                "target_seg": hyp,
                "error_spans": errors,
            }
            processed_pair.append(_build_messages(prompt_template, template, fields, use_box))

        formatted_inputs = tokenizer.apply_chat_template(
            processed_pair, tokenize=False, add_generation_prompt=True
        )

        # Log up to two example prompts (indices 2 and 3), tolerating short inputs.
        for label, sample in enumerate(formatted_inputs[2:4], start=1):
            logging.info(f"=====samples {label} ====")
            logging.info(sample)
        logging.info("=====samples end====")
        assert len(source) == len(processed_pair) == len(formatted_inputs)

        start = time.perf_counter()
        outputs = model.generate(formatted_inputs, generation_settings)

        results = []
        for idx, out in enumerate(outputs):
            response = out.outputs[0].text
            if idx <= 2:
                logging.info(f"======={idx}=======")
                logging.info(f"Input: {out.prompt}\n")
                logging.info(f"Output: {response}\n")
            results.append({"prompt_input": out.prompt, "output": response})

        e2e_inference_time = time.perf_counter() - start
        logging.info(f"the inference time is {e2e_inference_time} s")

        # Use only the file name, so ".txt" in a directory name cannot truncate it.
        sys_name = os.path.basename(hypo_path).split(".txt")[0]
        pred_path = os.path.join(pred_dir, f"{sys_name}.json")
        with open(pred_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        logging.info(f"write to {pred_path}")

        # NOTE(review): purpose of this pause is not evident from the code; kept as-is.
        time.sleep(5)


if __name__ == "__main__":
    # Turn ``main``'s parameters into command-line flags via python-fire.
    fire.Fire(component=main)
