import os, gc, time 
from typing import Dict, Type
from importlib import import_module
import multiprocessing as mp

# Set before `vllm` is imported further down in this file.
# NOTE(review): presumably this makes vLLM start its worker processes with
# 'spawn' instead of the platform default -- confirm against the vLLM docs.
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

import torch, hydra 
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import destroy_model_parallel
import numpy as np
from omegaconf import OmegaConf
import asyncio

from huggingface_hub import snapshot_download

from evalplus.sanitize import sanitize
from inference_rlhf.code.tasks.gsm8k import *
from inference_rlhf.code.helpers.io import * 
from inference_rlhf.code.helpers.utils import hf_login, set_seeds
from inference_rlhf.code.helpers.api import get_client, async_chat_completion_generate
from inference_rlhf.code.helpers.constructors import dataloader_factory

def get_sampling_params(cfg, stop):
    """Translate a sampling config node into a vLLM ``SamplingParams`` object.

    Args:
        cfg: Config node exposing ``k``, ``temperature``, ``top_p``, ``top_k``,
            ``seed``, ``max_tokens``, ``logprobs`` and ``min_p``.
        stop: Value forwarded unchanged as ``stop_token_ids``.

    Returns:
        The constructed ``SamplingParams`` instance.

    Note:
        ``cfg.stop_strs`` is not forwarded; the ``stop=`` argument is left unset.
    """
    options = {
        "n": cfg.k,  # number of completions sampled per prompt
        "temperature": cfg.temperature,
        "top_p": cfg.top_p,
        "top_k": cfg.top_k,
        "seed": cfg.seed,
        "max_tokens": cfg.max_tokens,
        "logprobs": cfg.logprobs,
        "stop_token_ids": stop,
        "min_p": cfg.min_p,
    }
    return SamplingParams(**options)

def download_branch(branch_name: str,
                         repo_id: str,
                         cache_root: str = "checkpoints") -> str:
    """Fetch one revision of a Hugging Face Hub repository.

    Args:
        branch_name: Revision passed to ``snapshot_download``; also used as the
            name of the per-branch cache sub-directory.
        repo_id: Hub repository identifier.
        cache_root: Directory under which the per-branch cache is created.

    Returns:
        Whatever ``snapshot_download`` returns (annotated here as ``str``).
    """
    # Each branch gets its own cache directory: <cache_root>/<branch_name>.
    branch_cache = os.path.join(cache_root, branch_name)
    download_options = {
        "revision": branch_name,
        "cache_dir": branch_cache,
        "local_dir_use_symlinks": False,
    }
    return snapshot_download(repo_id, **download_options)


def _mean_logprob(completion):
    """Return the per-token mean log-probability of a completion, or None.

    ``None`` is returned when the completion has no tokens or carries no
    cumulative log-probability, so that one degenerate sample cannot abort
    the whole run with a ZeroDivisionError / TypeError.
    """
    n_tokens = len(completion.token_ids)
    if completion.cumulative_logprob is None or n_tokens == 0:
        return None
    return completion.cumulative_logprob / n_tokens


def _generate_remote(cfg, dl, task_dir, policy):
    """Sample ``cfg.sampling.k`` responses for one prompt through the remote chat API.

    Args:
        cfg: Full Hydra config.
        dl: Dataloader returned by ``dataloader_factory``.
        task_dir: ``<save_root>/data/<task>``; results go into a sub-directory
            named after ``cfg.aoai.model_name``.
        policy: ``cfg.policy.name``; occurrences in the file prefix are replaced
            by the remote model name.

    Raises:
        ValueError: If ``cfg.task.generation.generation_idx`` is not set.
    """
    gen_idx = cfg.task.generation.generation_idx
    if gen_idx is None:
        raise ValueError("Remote generation requires cfg.task.generation.generation_idx to be set.")

    task_name = cfg.task.name
    model_name = cfg.aoai.model_name
    prompt_idx = dl.idxs[gen_idx]

    # Build the directory from its components rather than string-replacing the
    # policy name inside the whole path, which would also rewrite any matching
    # text in the save root or task name.
    save_path = os.path.join(task_dir, model_name)
    os.makedirs(save_path, exist_ok=True)

    seed = cfg.sampling.seed
    print(f"\n===== SEED {seed} =====\n")
    prefix = cfg.io.prefix + f'--prompt-idx-{prompt_idx}--temp-{cfg.sampling.temperature}--top-p-{cfg.sampling.top_p}'
    prefix = prefix.replace(policy, model_name)

    # mbpp results are written as JSON lines, everything else as one JSON file;
    # the skip check must look at the file this run would actually write.
    is_mbpp = task_name == 'mbpp'
    extension = 'jsonl' if is_mbpp else 'json'
    out_file = os.path.join(save_path, f'{prefix}-generations.{extension}')
    if os.path.exists(out_file) and cfg.io.overwrite is False:
        print(f"Prompt idx {prompt_idx} for seed {seed} already exists. Skipping...\n\n")
        return

    prompts = dl.build_queries(apply_chat_template=False)[gen_idx:gen_idx + 1]
    print(f'\n# prompts {len(prompts)} x # samples {cfg.sampling.k}')

    # Duplicate the prompt k times: one request per sample.
    prompts = prompts * cfg.sampling.k

    client, deployment_name = get_client(
        api_version=cfg.aoai.api_version,
        model_name=model_name,
        model_version=cfg.aoai.model_version,
    )
    outputs = asyncio.run(async_chat_completion_generate(client, deployment_name, prompts, max_concurrent=100))

    if is_mbpp:
        entrypoint = dl.entrypoints[gen_idx]
        task_idx = dl.idx_to_task_idx[gen_idx]
        # Sanitize everything first so a failure leaves no partial file behind.
        sanitized_responses = [sanitize(response, entrypoint=entrypoint) for response in outputs]
        with open(out_file, 'w', encoding='utf-8') as f:
            for solution in sanitized_responses:
                f.write(json.dumps({
                    'task_idx': task_idx,
                    'solution': solution,
                }) + '\n')
        return

    patterns = cfg.policy.answer_patterns
    ground_truth = dl.answers[prompt_idx]
    # Only the game24 task passes the question text to the answer extractor.
    questions = None
    if task_name == 'game24':
        questions = [dl.questions[prompt_idx]] * len(outputs)

    answers = extract_answers(outputs, patterns, task_name, strict=False, questions=questions)
    strict_answers = extract_answers(outputs, patterns, task_name, strict=True, questions=questions)
    results = extract_results([ground_truth] * len(answers), answers, task_name)
    results_strict = extract_results([ground_truth] * len(strict_answers), strict_answers, task_name)

    parsed_outputs = [{
        'prompt_idx': prompt_idx,
        'response': response,
        'correct': result,
        'strict_correct': strict_result,
        'extracted_answer': answer,
        'strict_extracted_answer': strict_answer,
        'ground_truth': ground_truth,
    } for response, answer, strict_answer, result, strict_result in zip(outputs, answers, strict_answers, results, results_strict)]
    json_dump(parsed_outputs, out_file)


def _generate_local(cfg, dl, save_path, llm_kwargs):
    """Sample ``cfg.sampling.k`` responses per prompt with a local vLLM engine.

    Writes one JSON file per prompt into ``save_path``. When
    ``cfg.task.generation.generation_idx`` is set only that prompt is
    generated; otherwise every prompt from the dataloader is.

    Args:
        cfg: Full Hydra config.
        dl: Dataloader returned by ``dataloader_factory``.
        save_path: Output directory (created if missing).
        llm_kwargs: Extra keyword arguments forwarded to ``LLM(...)``.

    Raises:
        RuntimeError: If the number of returned completions is not
            ``len(prompts) * cfg.sampling.k``.
    """
    os.makedirs(save_path, exist_ok=True)
    prompts = dl.build_queries()
    prompt_idxs = dl.idxs

    gen_idx = cfg.task.generation.generation_idx
    if gen_idx is not None:
        prompts = prompts[gen_idx:gen_idx + 1]
        prompt_idxs = prompt_idxs[gen_idx:gen_idx + 1]

    k = cfg.sampling.k
    log.info(f'Number of prompts: {len(prompts)}')
    log.info(f'Number of samples: {cfg.sampling.k}')
    log.info(f"Total number of samples to generate: {len(prompts) * cfg.sampling.k}")

    log.info(f"Loading model {cfg.policy.model} in VLLM.")
    llm = LLM(
        cfg.policy.model,
        gpu_memory_utilization=cfg.sampling.gpu_memory_utilization,
        swap_space=cfg.sampling.swap_space,
        trust_remote_code=True,
        tensor_parallel_size=cfg.sampling.tensor_parallel_size,
        **llm_kwargs,
    )

    # Stop generation at the tokenizer's EOS token.
    tok = llm.get_tokenizer()
    stop_tokens = [tok.eos_token_id]
    sampling_parameters = get_sampling_params(cfg.sampling, stop_tokens)

    log.info(f"Setting sampling seed to {cfg.sampling.seed}.")
    seed = cfg.sampling.seed
    sampling_parameters.seed = seed

    # One output-file prefix per prompt.
    prefixes = [
        cfg.io.prefix + f'--seed-{seed}--prompt-idx-{prompt_idx}--temp-{cfg.sampling.temperature}--top-p-{cfg.sampling.top_p}--max-response-length-{cfg.sampling.max_tokens}'
        for prompt_idx in prompt_idxs
    ]

    log.info(f'Generating with k={cfg.sampling.k} at temp={cfg.sampling.temperature}')
    start_time = time.time()
    outputs = llm.generate(prompts, sampling_parameters)
    end_time = time.time()
    log.info(f'Generation took {end_time - start_time:.0f} seconds.')

    # Flatten to one entry per sampled completion, grouped by prompt.
    completions = [completion for output in outputs for completion in output.outputs]
    responses = [completion.text for completion in completions]

    # Each prompt index repeated k times, aligned with the flattened completions.
    duplicated_prompt_idxs = [idx for idx in prompt_idxs for _ in range(k)]
    if len(duplicated_prompt_idxs) != len(responses):
        # Explicit check (not `assert`) so it still runs under `python -O`;
        # the per-prompt slicing below relies on exactly k completions each.
        raise RuntimeError(
            f"Expected {len(duplicated_prompt_idxs)} completions but got {len(responses)}."
        )

    # NOTE(review): the "strict" answers/results are computed with exactly the
    # same calls as the non-strict ones, so they are always identical here.
    answers = [dl.extract_answer(response) for response in responses]
    strict_answers = [dl.extract_answer(response) for response in responses]
    results = [dl.judge_correct(answer, idx) for answer, idx in zip(answers, duplicated_prompt_idxs)]
    results_strict = [dl.judge_correct(answer, idx) for answer, idx in zip(strict_answers, duplicated_prompt_idxs)]

    records = [
        {
            'prompt_idx': prompt_idx,
            'response': completion.text,
            'correct': result,
            'strict_correct': strict_result,
            'extracted_answer': answer,
            'strict_extracted_answer': strict_answer,
            'ground_truth': dl.answers[prompt_idx],
            'sum_logprob': completion.cumulative_logprob,
            'avg_logprob': _mean_logprob(completion),
            'finish_reason': completion.finish_reason,
        }
        for prompt_idx, completion, answer, strict_answer, result, strict_result in zip(
            duplicated_prompt_idxs, completions, answers, strict_answers, results, results_strict
        )
    ]

    # Write the k records of each prompt to that prompt's own file.
    # NOTE(review): the '-123' suffix is hard-coded and differs from the remote
    # path's '-generations.json'; kept as-is so existing consumers still match.
    for i, prefix in enumerate(prefixes):
        json_dump(records[i * k:(i + 1) * k], os.path.join(save_path, f'{prefix}-generations-123.json'))

    # Drop the engine reference first so the collection below can reclaim it.
    del llm
    gc.collect()
    torch.cuda.empty_cache()


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg):
    """Generate responses for a task, either via a remote API or a local vLLM model.

    Seeds the run, adjusts the config (branch suffix on the file prefix,
    output root override when ``cfg.amlt`` is set), builds the dataloader and
    dispatches on ``cfg.generate.remote``.

    Args:
        cfg: Hydra config loaded from ``../../configs/master``.
    """
    print(OmegaConf.to_yaml(cfg))
    log.info(f"Generating for {cfg.policy.name} on dataset {cfg.task.name}.")

    log.info(f"Setting seed to {cfg.seed}.")
    set_seeds(cfg.seed)

    task = cfg.task.name
    policy = cfg.policy.name

    # Extra LLM constructor arguments; these two models get an explicit
    # max_model_len of 16384.
    kwargs = {}
    if policy == 'qwen-25-32b' or cfg.policy.model == 'microsoft/Phi-3-medium-128k-instruct':
        kwargs['max_model_len'] = 16384

    # Non-main branches are tagged in every output file name.
    if cfg.policy.branch != "main":
        cfg.io.prefix += f"-{cfg.policy.branch}"

    if cfg.amlt:  # For file I/O on amulet: write under AMLT_OUTPUT_DIR.
        cfg.root = os.environ.get('AMLT_OUTPUT_DIR')
        cfg.io.save_root = cfg.root

    task_dir = os.path.join(cfg.io.save_root, "data", task)

    dl = dataloader_factory(cfg.task.name, cfg)

    if cfg.generate.remote:
        _generate_remote(cfg, dl, task_dir, policy)
    else:
        _generate_local(cfg, dl, os.path.join(task_dir, policy), kwargs)

# Script entry point. `main` is wrapped by `@hydra.main`, so it is called
# without arguments here and receives `cfg` from the decorator.
if __name__ == '__main__':
    main()
