from unsloth import FastLanguageModel
from unsloth.chat_templates import standardize_sharegpt
import os
import re
import json
import torch
import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from datasets import load_dataset, DatasetDict, Dataset
from vllm import LLM, SamplingParams
from tqdm import tqdm


def rename_messages_to_conversations_test(example):
    """Prepare one dataset record for test-time prompting.

    Only the first two entries of ``example["messages"]`` are kept, which
    drops the reference 'assistant' message. The same truncated list is then
    stored under both ``"messages"`` and ``"conversations"``.

    The record is modified in place and also returned.
    """
    prompt_turns = example["messages"][:2]
    for key in ("messages", "conversations"):
        example[key] = prompt_turns
    return example


def _extract_first_json_object(text):
    """Decode the JSON object that begins at the first '{' in *text*.

    Returns the decoded object, or ``None`` when *text* has no '{' that is
    followed (possibly much later) by a '}'. Raises ``json.JSONDecodeError``
    when the text starting at that brace is not valid JSON. Anything after
    the end of the decoded object is ignored, so trailing commentary that
    itself contains braces does not break parsing.
    """
    start = text.find('{')
    if start == -1 or text.rfind('}') < start:
        return None
    json_obj, _end = json.JSONDecoder().raw_decode(text, start)
    return json_obj


@hydra.main(config_path=None, config_name=None, version_base=None)
def main(cfg: DictConfig):
    """Generate responses for a test set with vLLM and save them per scene.

    Reads from ``cfg``:
      * ``inference.model_path``, ``inference.data_path``, ``inference.out_dir``
      * ``inference.max_new_tokens``, ``inference.temperature``,
        ``inference.top_p``, ``inference.top_k``, ``inference.batch_size``
      * ``model.max_seq_length``, ``model.load_in_4bit``

    For every test example the raw generated text is written to
    ``<out_dir>/responses/<run_name>/<scene_id>.txt``. If a JSON object can be
    decoded from the text, it is written to
    ``<out_dir>/forecast_pred/<run_name>/<scene_id>.json``.
    """
    # Function-scope import so this block does not depend on the file header.
    import gc

    model_path = cfg.inference.model_path
    if os.path.exists(model_path):
        # Local checkpoint: name the run after the Hydra config file.
        hydra_cfg = HydraConfig.get()
        run_name = hydra_cfg.job.config_name.split('/')[-1].split('.')[0]
    else:
        # Non-local identifier such as "org/name": use the part after the first
        # '/'. Fall back to the whole string when there is no '/' at all.
        path_parts = model_path.split('/')
        run_name = path_parts[1] if len(path_parts) > 1 else path_parts[0]
    output_dir_json = os.path.join(cfg.inference.out_dir, 'forecast_pred', run_name)
    output_dir_txt = os.path.join(cfg.inference.out_dir, 'responses', run_name)
    os.makedirs(output_dir_json, exist_ok=True)
    os.makedirs(output_dir_txt, exist_ok=True)

    # Only the tokenizer is needed here; generation is done by vLLM below.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=cfg.model.max_seq_length,
        load_in_4bit=cfg.model.load_in_4bit,
    )
    # from_pretrained also returns the model. Drop it right away so whatever
    # memory it holds can be reclaimed before vLLM loads its own copy.
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    non_reasoning_dataset = load_dataset("json", data_files={"test": cfg.inference.data_path})

    dataset = non_reasoning_dataset.map(rename_messages_to_conversations_test)
    test_dataset = standardize_sharegpt(dataset["test"])
    dataset = DatasetDict({"test": test_dataset})

    conversations = tokenizer.apply_chat_template(
        dataset["test"]["conversations"],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )

    input_texts = list(conversations)
    # Read the column directly instead of materialising every full row.
    scene_ids = list(dataset["test"]["scene_id"])

    # Initialize vLLM model
    llm = LLM(
        model=model_path,
        tensor_parallel_size=1,
        pipeline_parallel_size=torch.cuda.device_count(),
        gpu_memory_utilization=0.8,
    )

    sampling_params = SamplingParams(
        max_tokens=cfg.inference.max_new_tokens,
        temperature=cfg.inference.temperature,
        top_p=cfg.inference.top_p,
        top_k=cfg.inference.top_k,
    )

    batch_size = cfg.inference.batch_size

    for i in tqdm(range(0, len(input_texts), batch_size)):
        batch_texts = input_texts[i:i + batch_size]
        batch_scene_ids = scene_ids[i:i + batch_size]

        # Skip scene if already processed (only done for batch_size=1; larger
        # batches are always regenerated and overwrite existing files).
        if batch_size == 1:
            scene_id = batch_scene_ids[0]
            if os.path.exists(os.path.join(output_dir_json, f'{scene_id}.json')):
                continue

        outputs = llm.generate(
            batch_texts,
            sampling_params=sampling_params,
        )

        for scene_id, output in zip(batch_scene_ids, outputs):
            generated_text = output.outputs[0].text

            output_file_txt = os.path.join(output_dir_txt, f'{scene_id}.txt')
            output_file_json = os.path.join(output_dir_json, f'{scene_id}.json')

            # Always keep the raw response, even when no JSON can be extracted.
            with open(output_file_txt, "w", encoding="utf-8") as f:
                f.write(generated_text)

            # Keep the try body limited to parsing so that file-system errors
            # from the write below are not mistaken for JSON problems.
            try:
                json_obj = _extract_first_json_object(generated_text)
            except json.JSONDecodeError as e:
                print(f"Failed to parse JSON for scene {scene_id}: {e}")
                continue

            if json_obj is None:
                print(f"No JSON found in output for scene {scene_id}")
                continue

            with open(output_file_json, "w", encoding="utf-8") as f:
                json.dump(json_obj, f, indent=4)
            print(f"Saved JSON for scene {scene_id}")


# Script entry point. main() is wrapped by @hydra.main, so it is invoked here
# without arguments and receives its `cfg` from the decorator.
if __name__ == '__main__':
    main()
