# from accelerate import init_empty_weights, load_checkpoint_and_dispatch

import fire
import os
import re

import torch

from accelerate.utils import is_xpu_available
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

import json
from custom_dataset import get_custom_dataset_eval
import numpy as np
from src.xlogominiprog.utils.remove_nulls import remove_nulls_from_json
from src.xlogominiprog.evaluate import eval_model, eval_model_parallel


def main(
        model_name,
        peft_model: str = None,
        max_new_tokens=256,  # The maximum numbers of tokens to generate
        seed: int = 42,  # seed value for reproducibility
        top_p: float = 1.0,
        temperature: float = 1.0,
        top_k: int = 50,
        repetition_penalty: float = 1.0,  # 1.0 means no penalty.
        length_penalty: int = 1,
        vllm_batch_size: int = 2,
        dataset_path: str = None,  # path to the dataset
        sample_rate: float = 1,  # fraction of the dataset to sample
        use_emulator_sample: bool = False,  # emulator-driven sample weight
        emulator_weight_save_path: str = None,  # where to save the emulator-driven sample weights
        **kwargs
):
    """CLI entry point: run LoRA inference, evaluate it, and optionally save sample weights.

    The generation arguments are forwarded unchanged to ``vllm_inference``. The
    JSON file it writes is then scored with ``eval_model_parallel``, and the
    file name and evaluation summary are printed.

    Args:
        model_name: Base model name/path for vLLM.
        peft_model: LoRA adapter path; must contain ``/checkpoints/<template>``.
        max_new_tokens: Maximum number of generated tokens per prompt.
        seed: Seed forwarded to ``vllm_inference``.
        top_p: Nucleus-sampling threshold (1.0 disables it).
        temperature: Sampling temperature.
        top_k: Number of highest-probability tokens kept for top-k filtering.
        repetition_penalty: Repetition penalty (1.0 means no penalty).
        length_penalty: Length penalty forwarded to the sampling parameters.
        vllm_batch_size: Maximum number of concurrent sequences in vLLM.
        dataset_path: Evaluation dataset path.
        sample_rate: Fraction of the dataset to evaluate (``>= 1`` keeps all).
        use_emulator_sample: If true, derive one weight per evaluated sample
            (1 if its result has a truthy ``'success'``, else 2), print
            statistics, and save the weights as a JSON list.
        emulator_weight_save_path: Output file for the weights; defaults to
            ``./results/emulator-weight.json``. Missing parent directories
            are created.
        **kwargs: Accepted and ignored.
    """
    _, filename = vllm_inference(model_name,
                                 peft_model=peft_model,
                                 max_new_tokens=max_new_tokens,
                                 seed=seed,
                                 top_p=top_p,
                                 temperature=temperature,
                                 top_k=top_k,
                                 repetition_penalty=repetition_penalty,
                                 length_penalty=length_penalty,
                                 vllm_batch_size=vllm_batch_size,
                                 dataset_path=dataset_path,
                                 sample_rate=sample_rate)

    results, summary = eval_model_parallel(filename)

    print("=== File: ", filename)
    print("=== Summary: ", summary)

    if not use_emulator_sample:
        return

    # Assumes each evaluation result is a mapping with a 'success' flag and that
    # `results` is in the same row order as the (sub-sampled) dataset — TODO confirm.
    # Failed samples get double weight.
    sample_weight = [1 if r['success'] else 2 for r in results]
    n_total = len(sample_weight)
    n_success = sample_weight.count(1)
    n_failed = sample_weight.count(2)

    print("======= Statistics ======")
    print(f"Prediction dataset size: {len(results)}")
    print(f"Samples with weight == 2: {n_failed}")
    print(f"Samples with weight == 1: {n_success}")
    # Guard against an empty evaluation (would otherwise be ZeroDivisionError).
    print(f"Success Rate: {n_success / n_total if n_total else float('nan')}")

    if emulator_weight_save_path is None:
        emulator_weight_save_path = './results/emulator-weight.json'
    # A bare file name has an empty dirname, and os.makedirs('') raises.
    save_dir = os.path.dirname(emulator_weight_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(emulator_weight_save_path, 'w') as f:
        json.dump(sample_weight, f)
    print(f"Saved sample weights to {emulator_weight_save_path}")


def vllm_inference(model_name,
                   peft_model: str = None,
                   max_new_tokens=256,  # The maximum numbers of tokens to generate
                   seed: int = 42,  # seed value for reproducibility
                   top_p: float = 1.0,
                   temperature: float = 1.0,
                   top_k: int = 50,  # number of highest-probability tokens kept for top-k filtering
                   repetition_penalty: float = 1.0,  # 1.0 means no penalty.
                   length_penalty: int = 1,
                   vllm_batch_size: int = 2,
                   dataset_path: str = None,
                   sample_rate: float = 1,  # fraction to sample from the dataset (1 means all data)
                   **kwargs):
    """Generate completions for an evaluation dataset with vLLM + a LoRA adapter and save them.

    The prompt template name is the path component right after ``/checkpoints/``
    in ``peft_model``. Results are written as JSON to the ``peft_model`` path with
    ``/checkpoints/`` replaced by ``/inference/`` and suffixed with
    ``_<dataset file stem>.json``.

    Args:
        model_name: Base model passed to ``vllm.LLM``.
        peft_model: LoRA adapter path. Required despite the ``None`` default;
            must contain ``/checkpoints/<template>``.
        max_new_tokens: Maximum number of generated tokens per prompt.
        seed: Seed for torch and for the dataset sub-sampling.
        top_p, temperature, top_k, repetition_penalty, length_penalty:
            Forwarded to ``vllm.SamplingParams``.
        vllm_batch_size: Forwarded to ``LLM`` as ``max_num_seqs``.
        dataset_path: Dataset path for ``get_custom_dataset_eval``. Required;
            its file stem also names the output file.
        sample_rate: Fraction of rows to keep; ``>= 1`` keeps every row.
        **kwargs: Accepted and ignored.

    Returns:
        ``(records, filename)``. ``records`` holds one dict per prompt with keys
        ``prompt``, ``code``, ``output``, ``task_json``, ``code_json`` and
        ``constraints``. ``filename`` is the JSON file they were written to.

    Raises:
        ValueError: if ``peft_model`` or ``dataset_path`` is missing, or
            ``peft_model`` has no ``/checkpoints/<template>`` component.
    """
    print(f"Using model: {model_name}")

    # Validate the arguments the run depends on before loading anything, so a
    # bad call fails immediately instead of after generation has finished.
    if peft_model is None:
        raise ValueError("peft_model is required (expected a path containing /checkpoints/<template>)")
    if dataset_path is None:
        raise ValueError("dataset_path is required (its file name is used to name the output file)")

    match = re.search(r'/checkpoints/([^/]+)', peft_model)
    if match is None:
        raise ValueError(f"prompt template not found in {peft_model}")
    prompt_template = match.group(1)

    # The output path mirrors the adapter path, with /checkpoints/ -> /inference/.
    dataset_name, _ = os.path.splitext(os.path.basename(dataset_path))
    output_base = peft_model.replace('/checkpoints/', '/inference/')
    filename = f"{output_base}_{dataset_name}.json"

    dataset = get_custom_dataset_eval(template=prompt_template, dataset_path=dataset_path)

    print("=== DEBUG: first sample in the dataet ===")
    print("Length of the dataset: ", len(dataset))
    print("prompt template: ", prompt_template)
    print(f"dataset[0].keys(): {dataset[0].keys()}")
    print(f"dataset[0]['prompt']: {dataset[0]['prompt']}")
    print(f"dataset[0]['code']: {dataset[0]['code']}")

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(seed)
    else:
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    sample_size = int(len(dataset) * sample_rate)
    print("Sampling the dataset...")
    print("Sample rate: ", sample_rate)
    print("Sample size: ", sample_size)

    if sample_rate < 1:
        # Draw from a generator seeded with `seed` so the selected subset, and hence
        # the row order of the saved results, is reproducible across runs. The
        # global np.random state is not seeded anywhere in this function.
        rng = np.random.default_rng(seed)
        sample_indices = np.sort(rng.choice(len(dataset), size=sample_size, replace=False))
        dataset = dataset.select(sample_indices)

    model = LLM(model=model_name,
                # NOTE(review): CUDA-only count; on an XPU-only or CPU host this is 0 — confirm intended.
                tensor_parallel_size=torch.cuda.device_count(),
                max_num_seqs=vllm_batch_size,
                enable_lora=True,
                max_lora_rank=64)

    # NOTE(review): newer vLLM releases dropped `length_penalty` from SamplingParams
    # (it only applied to beam search) — confirm against the pinned vLLM version.
    sampling_params = SamplingParams(top_p=top_p,
                                     temperature=temperature,
                                     max_tokens=max_new_tokens,
                                     top_k=top_k,
                                     length_penalty=length_penalty,
                                     repetition_penalty=repetition_penalty)

    outputs = model.generate(dataset['prompt'], sampling_params=sampling_params,
                             lora_request=LoRARequest("lora_request", 1, peft_model))
    # Keep only the first completion for each prompt.
    output_texts = [o.outputs[0].text for o in outputs]
    inps_outs = [{'prompt'     : p,
                  'code'       : c,
                  'output'     : o,
                  'task_json'  : remove_nulls_from_json(task),
                  'code_json'  : remove_nulls_from_json(code),
                  'constraints': remove_nulls_from_json(cons),
                  } for p, c, o, task, code, cons in
                 zip(dataset['prompt'], dataset['code'], output_texts, dataset['task_json'], dataset['code_json'],
                     dataset['constraints'])]

    # `filename` always contains '/inference/', so its dirname is never empty.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'w') as f:
        json.dump(inps_outs, f)
    print(f"Results saved to {filename}")

    return inps_outs, filename


if __name__ == "__main__":
    # Expose `main` as a command-line interface: fire maps CLI flags onto its
    # keyword arguments (e.g. --model_name, --peft_model, --dataset_path).
    fire.Fire(component=main)
