import logging
from collections import defaultdict
from typing import Dict, List, Tuple, Callable
from importlib import import_module
from multiprocessing import Process

import numpy as np
from tqdm import tqdm
from vllm import LLM, SamplingParams
import ray
import asyncio

from inference_rlhf.code.coreset.coreset import Coreset
from inference_rlhf.code.helpers.utils import timing
from inference_rlhf.code.tasks.math import judge_correct
from inference_rlhf.code.generate import get_sampling_params
from inference_rlhf.code.helpers.api import get_client, async_chat_completion_generate
from inference_rlhf.code.helpers.vllm import vllm_generate
from inference_rlhf.code.helpers.constructors import construct_query_builder

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Task description for the judge model; passed as `task_desc` to
# `construct_query_builder` in `SingleExampleCoresetBuilder.__init__`.
SYSTEM_PROMPT = """
You are an expert diversity judge and are tasked with determining whether a response is diverse compared to a coreset of responses.
"""

# Per-candidate judge prompt. `SingleExampleCoresetBuilder.build_prompts` fills
# the `{question}`, `{response}` and `{coreset}` placeholders with `str.format`.
# The `<answer>0</answer>` / `<answer>1</answer>` tags requested at the end are
# what `SingleExampleCoresetBuilder.parse_outputs` looks for in the reply.
# NOTE(review): "it's diversity" below should read "its diversity"; the text is
# left untouched here because editing the prompt changes what the model sees.
PROMPT = """
You are given this question:

{question}

Your task is to check if the response below is diverse compared to the responses in the coreset. Do not consider the correctness of the response. Only consider it's diversity compared to the coreset.

RESPONSE TO CHECK:

{response}

END OF RESPONSE TO CHECK

CURRENT CORESET:

{coreset}

END OF CORESET

To decide if the response is diverse, look only at its final answer. If this final answer is already in any response in the coreset, output <answer>0</answer> (not diverse). If it's not in the coreset, output <answer>1</answer> (diverse). Provide a brief explanation before stating your answer. Ignore how varied the coreset responses are among themselves.
"""

# Define a Ray remote class for LLM inference
# Each actor instance requests 1 GPU and 10 CPUs from Ray.
@ray.remote(num_gpus=1, num_cpus=10)
class LLMWorker:
    """Ray actor that owns one `SingleExampleCoresetBuilder` and forwards calls to it."""

    def __init__(self, cfg):
        """Construct the wrapped `SingleExampleCoresetBuilder` from `cfg` inside the actor."""
        self.single_example_coreset_builder = SingleExampleCoresetBuilder(cfg)
        
    def ready(self):
        """Return True; callers `ray.get` this to block until `__init__` has finished."""
        return True
    
    def single_example_coreset(
        self, 
        question: str, 
        responses: List[str], 
        model_answers: List[int], 
        model_reward_scores: List[float], 
        max_k: int
    ) -> Tuple[List[str], List[int]]:
        """Delegate to `SingleExampleCoresetBuilder.single_example_coreset`.

        Returns:
            The delegate's result unchanged: a `(coreset_responses, coreset_answers)`
            pair, in that order.
        """
        return self.single_example_coreset_builder.single_example_coreset(
            question, 
            responses, 
            model_answers, 
            model_reward_scores, 
            max_k
        )

class SingleExampleCoresetBuilder:
    """Greedy coreset selection over the responses to a single question.

    The first pick is the response with the highest reward score. For every
    further pick, an LLM judge is asked which of the remaining candidates are
    "diverse" with respect to the coreset built so far; among those, the one
    with the highest reward score is added.
    """

    def __init__(self, cfg):
        """Build the query builder and the judge backend described by ``cfg``.

        When ``cfg.coreset.llm_binary_quality_diversity.local`` is truthy a
        vLLM engine is created from ``cfg.policy.model`` (stored as
        ``self.llm`` together with ``self.sampling_params``); otherwise an API
        client is obtained through ``get_client`` (stored as ``self.client``
        and ``self.deployment_name``).
        """
        self.cfg = cfg
        judge_cfg = cfg.coreset.llm_binary_quality_diversity

        # Construct query builder
        self.qb = construct_query_builder(
            judge_cfg.model_name,
            cfg=cfg.policy,
            task_desc=SYSTEM_PROMPT,
            shots=cfg.shots,
            question_format="{question}",
            answer_format="{answer}",
            sep="\n",
        )

        if judge_cfg.local:
            policy = cfg.policy.name
            kwargs = {}
            # Model-specific engine overrides.
            if policy in ['llama-3-3b', 'mistral-7b']:
                kwargs['max_model_len'] = cfg.sampling.max_model_len
                print(f"Setting max_model_len={cfg.sampling.max_model_len}")
            if policy in ['phi-3-small']:
                kwargs['enable_chunked_prefill'] = False
                print(f"Setting enable_chunked_prefill={kwargs['enable_chunked_prefill']}")

            self.llm = LLM(
                cfg.policy.model,
                gpu_memory_utilization=cfg.sampling.gpu_memory_utilization,
                swap_space=cfg.sampling.swap_space,
                trust_remote_code=True,
                tensor_parallel_size=cfg.sampling.tensor_parallel_size,
                **kwargs
            )

            # Sampling params: stop on the tokenizer's EOS token.
            tok = self.llm.get_tokenizer()
            stop_tokens = [tok.eos_token_id]
            self.sampling_params = get_sampling_params(cfg.sampling, stop_tokens)

        else:
            # Set up the remote (API) judge.
            client, deployment_name = get_client(
                api_version=judge_cfg.api_version,
                model_name=judge_cfg.model_name,
                model_version=judge_cfg.model_version
            )
            self.client = client
            self.deployment_name = deployment_name

    def build_prompts(self, question: str, responses: List[str], coreset: List[str], qualified_idxs: List[int]) -> List[str]:
        """Build one judge query per candidate index.

        Args:
            question: The question all responses answer.
            responses: Every response for the question.
            coreset: Responses already selected; rendered once, numbered from 1.
            qualified_idxs: Indices into ``responses`` that should be judged.

        Returns:
            One query per entry of ``qualified_idxs``, in the same order, as
            produced by the query builder.
        """
        stringified_coreset = '\n\n'.join(
            f"CORESET RESPONSE #{i}:\n{r}" for i, r in enumerate(coreset, 1)
        )
        return [
            self.qb.build_query(
                PROMPT.format(question=question, response=responses[idx], coreset=stringified_coreset)
            )
            for idx in qualified_idxs
        ]

    @staticmethod
    def parse_outputs(outputs: List[str], qualified_idxs: List[int]) -> List[int]:
        """Return the candidate indices the judge marked as diverse.

        ``outputs[i]`` is the judge's reply for ``qualified_idxs[i]``. A reply
        counts as diverse only if the text between its first ``<answer>`` tag
        and the following ``</answer>`` parses to the integer 1. Replies with
        no ``<answer>`` tag are silently treated as not diverse; replies that
        are not strings, do not parse as an integer, or carry a value other
        than 0/1 are logged and treated as not diverse.
        """
        if len(outputs) != len(qualified_idxs):
            log.warning(
                f"Got {len(outputs)} outputs for {len(qualified_idxs)} candidates; extra entries are ignored."
            )

        diverse_idxs = []
        for candidate_idx, text in zip(qualified_idxs, outputs):
            if not isinstance(text, str):
                log.warning(f"Error parsing output: {text}")
                continue
            if "<answer>" not in text:
                continue

            try:
                answer = int(text.split("<answer>")[1].split("</answer>")[0])
            except ValueError:
                log.warning(f"Error parsing output: {text}")
                continue

            # Explicit check instead of `assert`, which is stripped under -O.
            if answer not in (0, 1):
                log.warning(f"Error parsing output: {text}")
                continue
            if answer == 1:
                diverse_idxs.append(candidate_idx)

        return diverse_idxs

    @timing
    def classify_responses(self, question: str, responses: List[str], coreset: List[str], qualified_idxs: List[int]) -> List[int]:
        """Judge the candidates with this builder's own backend.

        Uses the local vLLM engine when the ``local`` config flag is set and
        the API client otherwise.

        Returns:
            The subset of ``qualified_idxs`` judged diverse, as a new list.
        """
        if not qualified_idxs:
            # Nothing to judge; skip the backend call entirely.
            return []

        prompts = self.build_prompts(question, responses, coreset, qualified_idxs)

        judge_cfg = self.cfg.coreset.llm_binary_quality_diversity
        if judge_cfg.local:
            outputs = vllm_generate(prompts, self.llm, self.sampling_params)
        else:
            outputs = asyncio.run(
                async_chat_completion_generate(
                    self.client,
                    self.deployment_name,
                    prompts,
                    max_concurrent=judge_cfg.max_concurrent,
                )
            )

        return SingleExampleCoresetBuilder.parse_outputs(outputs, qualified_idxs)

    @timing
    def classify_responses_parallel(self, question: str, responses: List[str], coreset: List[str], qualified_idxs: List[int]) -> List[int]:
        """Judge the candidates by fanning prompts out over ``self.workers``.

        Prompts are split into contiguous chunks, one per worker, and results
        are concatenated in chunk order so they stay aligned with
        ``qualified_idxs``.

        NOTE(review): this class never assigns ``self.workers``, and the
        workers must expose a remote ``generate`` method; whoever enables this
        path has to attach suitable Ray actors to the builder first.

        Raises:
            RuntimeError: If no workers are attached to this builder.
        """
        if not qualified_idxs:
            # Avoids a zero chunk size (and thus a zero `range` step) below.
            return []

        workers = getattr(self, "workers", None)
        if not workers:
            raise RuntimeError(
                "classify_responses_parallel requires a non-empty `workers` attribute "
                "(Ray actors exposing `generate`), but none is attached to this builder."
            )

        prompts = self.build_prompts(question, responses, coreset, qualified_idxs)

        num_workers = len(workers)
        chunk_size = (len(prompts) + num_workers - 1) // num_workers  # Ceiling division
        prompt_chunks = [prompts[i:i + chunk_size] for i in range(0, len(prompts), chunk_size)]

        # Dispatch every chunk first so the workers run concurrently ...
        futures = [worker.generate.remote(chunk) for worker, chunk in zip(workers, prompt_chunks)]

        # ... then collect in submission order to preserve prompt order.
        worker_results = [ray.get(future) for future in futures]
        outputs = [text for chunk in worker_results for text in chunk]

        return SingleExampleCoresetBuilder.parse_outputs(outputs, qualified_idxs)

    @timing
    def single_example_coreset(
        self,
        question: str,
        responses: List[str],
        model_answers: List[int],
        model_reward_scores: List[float],
        max_k: int,
    ) -> Tuple[List[str], List[int]]:
        """Greedily select up to ``max_k`` responses for one question.

        Args:
            question: The question the responses answer.
            responses: Candidate responses.
            model_answers: Answer for each response, index-aligned with ``responses``.
            model_reward_scores: Reward score for each response, index-aligned
                with ``responses``.
            max_k: Maximum coreset size.

        Returns:
            ``(coreset_responses, coreset_answers)``, both in selection order.
            Fewer than ``max_k`` items are returned when the judge finds no
            further diverse candidate or the candidates run out; both lists
            are empty when ``responses`` is empty.
        """
        coreset_responses = []
        coreset_answers = []
        if not responses:
            return coreset_responses, coreset_answers

        if self.cfg.coreset.llm_binary_quality_diversity.classify_responses_parallel:
            classify = self.classify_responses_parallel
        else:
            classify = self.classify_responses

        candidate_idxs = list(range(len(responses)))
        for k in range(max_k):
            if k > 0:
                # After the first pick, keep only candidates the judge deems
                # diverse with respect to the current coreset.
                candidate_idxs = classify(question, responses, coreset_responses, candidate_idxs)

            if not candidate_idxs:
                break

            # Highest reward wins; `max` keeps the earliest candidate on ties,
            # matching a stable descending sort followed by taking the head.
            chosen_idx = max(candidate_idxs, key=lambda i: model_reward_scores[i])
            candidate_idxs.remove(chosen_idx)

            coreset_responses.append(responses[chosen_idx])
            coreset_answers.append(model_answers[chosen_idx])

        return coreset_responses, coreset_answers


class LLMBinaryQualityDiversityCoreset(Coreset):
    """Coreset strategy that combines reward scores with an LLM diversity judge.

    Per-example selection is done by `SingleExampleCoresetBuilder`, either
    in-process (one example at a time) or spread across a pool of `LLMWorker`
    Ray actors, depending on
    ``cfg.coreset.llm_binary_quality_diversity.coreset_parallel``.

    NOTE(review): ``self.cfg`` and ``self.ks`` are read below but not assigned
    in this class; presumably the `Coreset` base class provides them.
    """

    def __init__(self, cfg):
        """Create either the Ray worker pool or a single in-process builder."""
        super().__init__(cfg)

        self.num_avg_trials = 1 # TODO: read from cfg

        if cfg.coreset.llm_binary_quality_diversity.coreset_parallel:
            # Initialize Ray
            if not ray.is_initialized():
                ray.init()

            # Create a pool of LLM workers
            self.workers = [LLMWorker.remote(cfg) for _ in range(cfg.sampling.num_workers)]  # Adjust num_workers in cfg

            # Block until every actor has finished its constructor.
            ray.get([worker.ready.remote() for worker in self.workers])
            log.info("===== All workers initialized =====")

        else:
            self.single_example_coreset_builder = SingleExampleCoresetBuilder(cfg)

    @timing
    def coreset(
        self,
        questions: List[str],
        model_responses: Dict[int, List[str]],
        model_answers: Dict[int, List[int]],
        model_reward_scores: Dict[int, List[float]]
    ) -> Dict[int, List[int]]:
        """
        Compute the coreset sequentially for each example.

        NOTE: It's still possible to parallelize the coreset computation within each example.

        Returns:
            Mapping from example id to the answers of the selected responses,
            in selection order. Each coreset holds at most ``max(self.ks)`` items.
        """
        coreset_answers = defaultdict(list)
        if not model_responses:
            return coreset_answers

        max_k = max(self.ks)
        for example_id, responses in tqdm(model_responses.items(), desc="Computing coreset sequentially ..."):
            # The builder returns (responses, answers); only answers are needed.
            _, selected_answers = self.single_example_coreset_builder.single_example_coreset(
                questions[example_id],
                responses,
                model_answers[example_id],
                model_reward_scores[example_id],
                max_k
            )
            coreset_answers[example_id].extend(selected_answers)

        return coreset_answers

    @timing
    def coreset_parallel(
        self,
        questions: List[str],
        model_responses: Dict[int, List[str]],
        model_answers: Dict[int, List[int]],
        model_reward_scores: Dict[int, List[float]]
    ) -> Dict[int, List[int]]:
        """
        Compute the coreset in parallel (i.e. each example is processed in parallel) using Ray.

        Every worker holds at most one example at a time; as soon as a worker
        finishes, it is handed the next pending example.

        Returns:
            Mapping from example id to the answers of the selected responses,
            in selection order. Keys are inserted in completion order.
        """
        coreset_answers = defaultdict(list)
        if not model_responses:
            return coreset_answers

        max_k = max(self.ks)
        pending = list(model_responses.items())
        # future -> (example_id, worker that is running it)
        in_flight = {}

        def submit(worker):
            example_id, responses = pending.pop()
            log.info(f"Submitting job for example {example_id} ...")
            future = worker.single_example_coreset.remote(
                questions[example_id],
                responses, model_answers[example_id],
                model_reward_scores[example_id],
                max_k=max_k
            )
            in_flight[future] = (example_id, worker)

        # Seed each worker with one job; stop early if there are fewer
        # examples than workers.
        for worker in self.workers:
            if not pending:
                break
            submit(worker)

        while in_flight:
            log.info(f"Waiting for {len(in_flight)} futures to finish ...")
            finished, _ = ray.wait(list(in_flight), num_returns=1, timeout=None)
            log.info(f"Future finished: {finished}")
            finished_future = finished[0]

            example_id, worker = in_flight.pop(finished_future)

            # The worker returns (responses, answers), in that order.
            _, selected_answers = ray.get(finished_future)
            coreset_answers[example_id].extend(selected_answers)

            # Keep the worker that just became free busy.
            if pending:
                submit(worker)

        return coreset_answers

    @timing
    def pass_at_k(
        self,
        questions: List[str],
        model_responses: Dict[int, List[str]],
        model_answers: Dict[int, List[str]],
        gt_answers: Dict[int, List[str]],
        model_reward_scores: Dict[int, List[float]]
    ) -> Dict[int, float]:
        """
        Compute the pass@k for the given answers using the LLM quality/diversity coreset.
        Expects answers in the following format:
        {
            example_id: [answer1, answer2, answer3],
            example_id: [answer1, answer2, answer3],
            ...
        }

        For every ``k`` in ``self.ks`` an example counts as passed when
        ``judge_correct`` accepts any of its first ``k`` coreset answers
        against the ground truth. Scores are averaged over examples and then
        over ``self.num_avg_trials`` trials.

        Returns:
            Mapping from ``k`` to the mean pass@k.
        """
        use_parallel = self.cfg.coreset.llm_binary_quality_diversity.coreset_parallel
        build_coreset = self.coreset_parallel if use_parallel else self.coreset

        single_trial_pass_at_k = []
        for _ in tqdm(range(self.num_avg_trials), desc="Computing pass@k ..."):
            k_to_coreset_results = defaultdict(list)

            # Compute coreset
            coreset_answers = build_coreset(questions, model_responses, model_answers, model_reward_scores)

            # Compute pass@k
            for k in self.ks:
                log.info(f"Computing pass@k for k={k} ...")
                for example_id, answers in coreset_answers.items():
                    gt = gt_answers[example_id]
                    k_to_coreset_results[k].append(float(any(judge_correct(gt, a) for a in answers[:k])))
            single_trial_pass_at_k.append({k: np.mean(v) for k, v in k_to_coreset_results.items()})

        return {k: np.mean([d[k] for d in single_trial_pass_at_k]) for k in single_trial_pass_at_k[0]}
            
