import json
import logging
import os
import re
import signal
import subprocess
import sys
from collections import defaultdict
from datetime import datetime
from importlib import import_module

import hydra
from vllm import LLM, SamplingParams
import torch
from tqdm import tqdm
import wandb
from matplotlib import cm
from omegaconf import OmegaConf
import randomname
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import TopPLogitsWarper, TemperatureLogitsWarper, TopKLogitsWarper

from inference_rlhf.code.helpers.io import json_load
from inference_rlhf.code.helpers.utils import set_seeds, rget_json_files_from_dir
from inference_rlhf.code.query_builders.qwen import QwenQueryBuilder
from inference_rlhf.code.query_builders.llama import LlamaQueryBuilder
from inference_rlhf.code.logits_processor.elliptical_logits_processor import EllipticalLogitsProcessor
from inference_rlhf.code.helpers.amlt import wandb_login
from inference_rlhf.code.helpers.utils import construct_sparse_matrix
from inference_rlhf.code.helpers.constructors import dataloader_factory

# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Flags shared between main() and the SIGUSR1 handler it installs.
# in_critical is True while main() is updating its per-generation state after
# a generation finishes; a signal received in that window is not acted on
# immediately but recorded in pending_signal, and main() invokes the handler
# itself once the critical section has ended.
in_critical = False
pending_signal = False

# Instruction text asking the model for 64 diverse solutions in a single
# response, each under a "### Solution <n>" header and ending in \boxed{...}.
# NOTE(review): not referenced by any code visible in this file.
ONE_SHOT_DIVERSITY_PROMPT = """
Please provide 64 diverse solutions in the same format as illustrated above to the following problem. You should start with the solution that is most likely to be correct, then provide a different solution that gives a different answer but is still promising. Then keep doing this until you have 64 full solutions.
In order for your solutions to be diverse they should aim to give different final answers. Essentially, when generating a new solution, assume all previous solutions are incorrect so that you try something different (which could be as small as changing the outcome of a hard algebraic operation that could have been done incorrectly in the previous solutions, or trying a completely different approach to the problem).
Also make sure to restart every solution attempt from scratch without reusing results or information from previous attempts. Every solution should stand on its own.
Finally, start each solution with the header ### Solution <solution_number>. Your output should look as follows:

### Solution 1
<full solution 1 text ending in \\boxed{<answer for solution 1>}>

### Solution 2
<full solution 2 text ending in \\boxed{<answer for solution 2>}>

...

### Solution 64
<full solution 64 text ending in \\boxed{<answer for solution 64>}>
"""

# Template used by main() in the 'llm' generation mode: from the second round
# onwards it is formatted with the previous attempts for a question (the
# {past_attempts} placeholder) and assigned to the query builder's task_desc.
MULTI_SHOT_DIVERSITY_PROMPT = """
The following are past attempts at solving the problem:

{past_attempts}

Now provide a new solution to the problem that is different from the past attempts, meaning it should result in a different final answer.
Also make sure to restart your solution attempt from scratch without reusing results or information from previous attempts. Every solution should stand on its own.
"""

# DIVERSITY_PROMPT = """
# Please provide 10 diverse solutions in the same format as illustrated above to the following problem. You should start with the solution that is most likely to be correct, then provide a different solution that gives a different answer but is still promising. Then keep doing this until you have 10 full solutions.
# In order for your solutions to be diverse they should aim to give different final answers. 
# Also make sure to restart every solution attempt from scratch without reusing results or information from previous attempts. Every solution should stand on its own.
# Finally, start each solution with the header ### Solution <solution_number>. Your output should look as follows:

# ### Solution 1
# <full solution 1 text ending in "The final answer is <answer for solution 1>">

# ### Solution 2
# <full solution 2 text ending in "The final answer is <answer for solution 2>">

# ...

# ### Solution 10
# <full solution 10 text ending in "The final answer is <answer for solution 10>">
# """

def get_model_sorted_indices(task: str, model: str, drop_unsolved: bool = True) -> np.ndarray:
    """
    Returns the dataset indices for a given task and model, sorted by the model's performance (pass@1 score).

    Every file in ``data/<task>/<model>`` is loaded with ``json_load`` and is
    expected to hold the list of responses for one question. The pass@1 of a
    question is the fraction of its responses whose ``strict_correct`` field
    is truthy. Questions are returned in ascending pass@1 order (hardest
    first); questions with equal pass@1 are ordered by file name so the result
    does not depend on the filesystem's directory listing order.

    Parameters
    ----------
    task : str
        The name of the task.
    model : str
        The name of the model.
    drop_unsolved : bool, optional
        If True, exclude unsolved questions (pass@1 == 0) from the results (default is True).

    Returns
    -------
    np.ndarray
        Array of prompt indices (the ``prompt_idx`` of the first response in
        each file) sorted by the model's performance.

    Raises
    ------
    AssertionError
        If the data directory does not exist.
    """
    data_dir = os.path.join('data', task, model)
    assert os.path.exists(data_dir), f"Data directory {data_dir} does not exist"

    # Compute pass@1 for each question. os.listdir returns entries in arbitrary
    # order, so sort the names first: the sort by score below is stable, which
    # makes the order among equal scores reproducible across machines.
    scored = []
    for file in sorted(os.listdir(data_dir)):
        data = json_load(os.path.join(data_dir, file))
        if len(data) == 0:
            # No responses means no pass@1 and no prompt_idx to report.
            logging.getLogger(__name__).warning("Skipping %s: it contains no responses", file)
            continue
        p_at_1 = np.mean([1 if response["strict_correct"] else 0 for response in data])
        scored.append((p_at_1, data[0]['prompt_idx']))

    # Sort by pass@1 (ascending)
    scored.sort(key=lambda entry: entry[0])

    if drop_unsolved:
        scored = [entry for entry in scored if entry[0] > 0.0]

    # Return the prompt indices
    return np.array([prompt_idx for _, prompt_idx in scored])

def get_sampling_params(cfg, stop):
    """Build the vLLM ``SamplingParams`` for a sampling config node.

    ``cfg`` supplies ``k`` (number of samples per prompt), ``temperature``,
    ``top_p``, ``top_k``, ``seed``, ``max_tokens`` and ``logprobs``; ``stop``
    is passed through as the stop token ids.
    """
    sampling_kwargs = {
        'n': cfg.k,
        'temperature': cfg.temperature,
        'top_p': cfg.top_p,
        'top_k': cfg.top_k,
        'seed': cfg.seed,
        'max_tokens': cfg.max_tokens,
        'logprobs': cfg.logprobs,
        'stop_token_ids': stop,
    }
    return SamplingParams(**sampling_kwargs)

def log_to_wandb(cfg, wandb_log_entries, wandb_table_entries, prompt_idx: int):
    """Upload the buffered metrics and generations of one question to wandb.

    Opens a run whose name encodes the main hyper-parameters, replays every
    dict in ``wandb_log_entries`` as one ``wandb.log`` call, logs
    ``wandb_table_entries`` (rows of generated text, strict answer, strict
    correctness and k) as a table under ``generations_table``, and closes the
    run.
    """
    elliptical_sampling_cfg = cfg.sampling.elliptical
    run_name = (
        f'{cfg.policy.model}-{cfg.task.name}'
        f'-direct-{cfg.task.generation.direct_coreset_type}-coreset'
        f'--temp-{cfg.sampling.temperature}'
        f'--beta-{cfg.coreset.elliptical.beta}'
        f'--lamb-{cfg.coreset.elliptical.lamb}'
        f'--top-p-{cfg.sampling.top_p}'
        f'--normalize-bonuses-per-step-{elliptical_sampling_cfg.normalize_bonuses_per_step}'
        f'--center-hidden-states-per-step-{elliptical_sampling_cfg.center_hidden_states_per_step}'
        f'--generation-idx-{cfg.task.generation.generation_idx}'
        f'--prompt-idx-{prompt_idx}'
        f'--seed-{cfg.seed}'
    )

    wandb.init(
        project='llm-exploration',
        entity='anonymous',
        name=run_name,
        config=OmegaConf.to_container(cfg, resolve=True),
        tags=['gpt-4o-mini-hard'],
    )

    # Replay the per-generation metrics in the order they were recorded.
    for metrics in wandb_log_entries:
        wandb.log(metrics)

    # One table row per generation.
    generations = wandb.Table(columns=["generated_text", "strict_answer", "strict_correct", "k"])
    for row in wandb_table_entries:
        generations.add_data(*row)
    wandb.log({'generations_table': generations})

    wandb.finish()

def _build_query_builder(cfg, tm):
    """Instantiate the query builder matching the policy family in ``cfg.policy.name``.

    Raises ``ValueError`` for a policy name that starts with neither 'qwen'
    nor 'llama', instead of leaving the builder undefined.
    """
    if cfg.policy.name.startswith('qwen'):
        builder_cls = QwenQueryBuilder
    elif cfg.policy.name.startswith('llama'):
        builder_cls = LlamaQueryBuilder
    else:
        raise ValueError(
            f"Unsupported policy '{cfg.policy.name}': expected a name starting with 'qwen' or 'llama'"
        )

    return builder_cls(
        cfg=cfg.policy,
        task_desc=cfg.task.TASK_DESC,
        shots=cfg.shots,
        question_format=tm.QUESTION_FORMAT,
        answer_format=tm.ANSWER_FORMAT,
        sep=tm.SEP,
    )


def _summary_stats(name, values):
    """Return min/max/mean/median/std of ``values`` keyed as ``<stat>_<name>``."""
    return {
        f'min_{name}': min(values),
        f'max_{name}': max(values),
        f'mean_{name}': np.mean(values),
        f'median_{name}': np.median(values),
        f'std_{name}': np.std(values),
    }


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg):
    """Generate a "direct coreset" of responses for the configured task.

    The behaviour depends on ``cfg.task.generation.direct_coreset_type``:

    * ``'llm'``: uses vLLM to run 64 rounds over all questions. From the second
      round on, the task description is replaced by
      ``MULTI_SHOT_DIVERSITY_PROMPT`` filled with the question's previous
      attempts. The attempts of each question are written to one JSON file
      under ``<load_root>/data/<task>/<policy>/generations``.
    * ``'elliptical'``: uses a Hugging Face model to sample up to
      ``cfg.coreset.max_k`` responses for a single question (selected by
      ``cfg.task.generation.generation_idx`` in the precomputed
      ``cfg.task.gpt_4o_sorted_indices`` order). An
      ``EllipticalLogitsProcessor`` is added to the sampling pipeline from the
      second response on when ``beta`` is non-zero, and the inverse covariance
      and running mean of the generated tokens' hidden states are updated
      after every response. Generation stops at the first strictly correct
      answer; metrics are then uploaded with ``log_to_wandb``.

    In the elliptical mode a SIGUSR1 handler checkpoints the state with
    ``torch.save``, resubmits the job through ``sbatch`` and exits. When
    ``cfg.checkpoint_resume_path`` is non-empty, the state is restored from
    that checkpoint.
    """
    global in_critical, pending_signal

    log.info(f"SLURM_JOB_NAME: {os.getenv('SLURM_JOB_NAME')}")
    log.info(f"SLURM_ARRAY_TASK_ID: {os.getenv('SLURM_ARRAY_TASK_ID')}")

    set_seeds(cfg.seed + cfg.seed_shift)
    task = cfg.task.name

    checkpoint_state = None
    if cfg.checkpoint_resume_path != '':
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # Only resume from checkpoints written by this script's signal handler.
        checkpoint_state = torch.load(cfg.checkpoint_resume_path, weights_only=False)
        log.info(f"Loaded checkpoint from {cfg.checkpoint_resume_path}")
    else:
        log.info("No checkpoint resume path provided. Starting from scratch ...")
        set_seeds(cfg.seed)

    dl = dataloader_factory(cfg.task.name, cfg)

    tm = import_module(f"inference_rlhf.code.tasks.{cfg.task.name}", package='code')
    query_builder = _build_query_builder(cfg, tm)

    if cfg.amlt:
        cfg.io.save_root = os.environ.get('AMLT_OUTPUT_DIR')
        cfg.io.load_root = cfg.blob_root
        log.info(f"AMLT_OUTPUT_DIR: {cfg.io.save_root}")
        wandb_login()
        log.info('Logged in to wandb!')

    # Get load path
    policy = cfg.policy.name
    load_path = os.path.join(cfg.io.load_root, 'data', task, policy, 'generations')
    log.info(f"Reconstructed load path: {load_path}")
    os.makedirs(load_path, exist_ok=True)

    direct_coreset_type = cfg.task.generation.direct_coreset_type

    if direct_coreset_type == 'llm':
        # The dataloader's parallel idxs/questions/answers sequences are the
        # same ones the elliptical mode indexes into.
        # NOTE(review): assumes this mode is meant to cover every question and
        # that prompt indices are JSON-serializable - confirm.
        prompt_idxs = dl.idxs
        questions = dl.questions
        gt_answers = dl.answers

        llm = LLM(
            cfg.policy.model,
            gpu_memory_utilization=0.95,
            trust_remote_code=True,
            tensor_parallel_size=2
        )

        stop_tokens = [query_builder.tokenizer.eos_token_id]
        sampling_parameters = get_sampling_params(cfg.sampling, stop_tokens) # TODO: make sure to set greedy decoding

        num_rounds = 64
        past_attempts = defaultdict(list)
        for round_idx in range(num_rounds):
            # rebuild prompts
            prompts = []
            for prompt_idx, question in zip(prompt_idxs, questions):
                if round_idx > 0:
                    past_attempts_str = "\n".join(
                        f"### Solution {j+1}\n{attempt}"
                        for j, attempt in enumerate(past_attempts[prompt_idx])
                    )
                    query_builder.task_desc = MULTI_SHOT_DIVERSITY_PROMPT.format(past_attempts=past_attempts_str)
                prompts.append(query_builder.build_query(question))

            outputs = llm.generate(prompts, sampling_params=sampling_parameters)

            # Only the first sample of each request is kept as that round's attempt
            for prompt_idx, output in zip(prompt_idxs, outputs):
                past_attempts[prompt_idx].append(output.outputs[0].text)

        # save as json in same path as generation file
        for prompt_idx, question, gt_answer in zip(prompt_idxs, questions, gt_answers):
            results_list = [
                {
                    'response': attempt,
                    'question': question,
                    'prompt_idx': prompt_idx,
                    'ground_truth': gt_answer,
                }
                for attempt in past_attempts[prompt_idx]
            ]
            file_to_save = os.path.join(load_path, f'{cfg.policy.name}--prompt-idx-{prompt_idx}--direct-coreset.json')
            with open(file_to_save, 'w') as f:
                json.dump(results_list, f)

    elif direct_coreset_type == 'elliptical':
        # Questions are visited in a precomputed order (sorted by gpt-4o-mini pass@1).
        # NOTE: the sorting function is not called here since the sorted indices are precomputed:
        # sorted_indices = get_model_sorted_indices(task, model='gpt-4o-mini', drop_unsolved=(task != 'aime_2025'))
        sorted_indices = cfg.task.gpt_4o_sorted_indices

        # Get the question for the current generation index
        dataset_idx = sorted_indices[cfg.task.generation.generation_idx]
        prompt_idx = dl.idxs[dataset_idx]
        question = dl.questions[dataset_idx]

        # Load the model
        model = AutoModelForCausalLM.from_pretrained(
            cfg.policy.model,
            torch_dtype=torch.bfloat16,
            device_map="cuda:0",
            attn_implementation="flash_attention_2"
        )

        d = model.config.hidden_size
        lamb = cfg.coreset.elliptical.lamb
        beta = cfg.coreset.elliptical.beta
        use_sparse_projection = cfg.coreset.elliptical.perform_sparse_projection

        # All state read by the signal handler is bound here, BEFORE the
        # handler is installed, so an early signal cannot hit an unbound name.
        if checkpoint_state is not None:
            cov_inv = checkpoint_state['cov_inv'].cuda()
            hidden_mean = checkpoint_state['hidden_mean'].cuda()
            hidden_mean_counter = checkpoint_state['hidden_mean_counter']
            sparse_matrix = checkpoint_state['sparse_matrix'].cuda() if use_sparse_projection else None
            k = checkpoint_state['k']
            all_strict_answers = checkpoint_state['all_strict_answers']
            wandb_log_entries = checkpoint_state['wandb_log_entries']
            wandb_table_entries = checkpoint_state['wandb_table_entries']
        else:
            # Statistics live in the projected space when sparse projection is on
            feature_dim = cfg.coreset.elliptical.sparse_dim if use_sparse_projection else d
            cov_inv = torch.eye(feature_dim, dtype=torch.float64).cuda() * lamb ** -1
            hidden_mean = torch.zeros(feature_dim, dtype=torch.float64).cuda()
            hidden_mean_counter = 0
            sparse_matrix = (
                construct_sparse_matrix(torch.zeros(1, d), cfg.coreset.elliptical.sparse_dim).cuda()
                if use_sparse_projection else None
            )
            k = 0
            all_strict_answers = set()
            wandb_log_entries = []
            wandb_table_entries = []
        all_strict_correct = []

        def signal_handler(sig, frame):
            """Checkpoint the generation state, resubmit the job and exit.

            If the main loop is inside its critical section the signal is only
            recorded in ``pending_signal``; the loop calls this handler again
            once the state is consistent.
            """
            global in_critical, pending_signal
            log.info(f"Entering signal handler based on signal {sig} ...")
            if in_critical:
                log.info(f"Received SLURM signal {sig} while in critical section. Ignoring ...")
                pending_signal = True
                return

            log.info("Not in critical section. Saving job state ...")
            state = {
                'cov_inv': cov_inv.cpu(),  # Move to CPU for saving
                'hidden_mean': hidden_mean.cpu(),
                'hidden_mean_counter': hidden_mean_counter,
                'sparse_matrix': sparse_matrix.cpu() if sparse_matrix is not None else None,
                'k': k,
                'wandb_log_entries': wandb_log_entries,
                'wandb_table_entries': wandb_table_entries,
                'all_strict_answers': all_strict_answers
            }
            checkpoint_dir = './checkpoints'
            os.makedirs(checkpoint_dir, exist_ok=True)

            if cfg.checkpoint_resume_path == '':
                checkpoint_path = os.path.join(checkpoint_dir, f'{randomname.get_name()}-{datetime.now().strftime("%Y%m%d_%H%M%S")}.pt')
            else:
                # reuse the same checkpoint path to save space
                checkpoint_path = cfg.checkpoint_resume_path

            torch.save(state, checkpoint_path)
            log.info(f"Saved checkpoint to {checkpoint_path}")

            # Drop SLURM_CPU_BIND from the environment handed to sbatch.
            # (Running `unset` through os.system only affects a throwaway
            # subshell, so the variable has to be removed from the env mapping.)
            resubmit_env = {name: value for name, value in os.environ.items() if name != 'SLURM_CPU_BIND'}

            # Now submit a new job that will resume from the checkpoint
            resubmit_cmd = [
                'sbatch',
                f'--job-name={os.getenv("SLURM_JOB_NAME")}',
                f'--array={os.getenv("SLURM_ARRAY_TASK_ID")}',
                'scripts/generate_direct_coreset_anonymous.slurm',
                f'{cfg.coreset.elliptical.beta}',
                f'{cfg.policy.name}',
                f'{cfg.seed}',
                f'{cfg.coreset.elliptical.lamb}',
                f'{checkpoint_path}',
                f'{cfg.seed_shift + 1}',
                f'{task}',
                f'{cfg.token_level.batch_size}',
            ]

            log.info(f"Resubmitting job with command: {resubmit_cmd}")

            subprocess.run(resubmit_cmd, check=True, env=resubmit_env)

            sys.exit(0)

        signal.signal(signal.SIGUSR1, signal_handler)

        # Rebuild prompts. NOTE: assumes only one question
        prompt = query_builder.build_query(question)
        inputs = query_builder.tokenizer(prompt, return_tensors="pt").to("cuda")
        prompt_length = len(inputs['input_ids'][0])

        while k < cfg.coreset.max_k:
            log.info(f"Starting generation for k={k} ...")

            # Construct elliptical logits processor
            elliptical_logits_processor = EllipticalLogitsProcessor(
                cov_inv,
                model,
                beta,
                cfg.sampling.temperature,
                normalize_bonuses_per_step=cfg.sampling.elliptical.normalize_bonuses_per_step,
                center_hidden_states_per_step=cfg.sampling.elliptical.center_hidden_states_per_step,
                hidden_mean=hidden_mean,
                hidden_mean_counter=hidden_mean_counter,
                sparse_matrix=sparse_matrix,
                batch_size=cfg.token_level.batch_size
            )

            # Setup built-in logits processors manually to control the order:
            # temperature -> top-p -> (optional) top-k -> (optional) elliptical
            all_logits_processors = [
                TemperatureLogitsWarper(temperature=cfg.sampling.temperature),
                TopPLogitsWarper(top_p=cfg.sampling.top_p),
            ]
            if cfg.sampling.top_k > 0:
                all_logits_processors.append(TopKLogitsWarper(top_k=cfg.sampling.top_k))

            # NOTE: only use elliptical logits processor for k > 0 (and non-zero beta)
            if k != 0 and beta != 0:
                all_logits_processors.append(elliptical_logits_processor)

            # Generate response
            gen_outputs = model.generate(
                **inputs,
                max_new_tokens=cfg.sampling.max_tokens,
                do_sample=True,
                temperature=1.0,
                logits_processor=all_logits_processors,
                top_k=None, # NOTE: needs to be explicitly set to None to avoid top_k sampling
                top_p=None, # NOTE: needs to be explicitly set to None to avoid top_p sampling
            )

            # Enter critical section: from here until k is incremented the
            # checkpointable state is being updated and must not be saved.
            in_critical = True

            # Forward full output to get all the token hidden states
            if beta > 0.0:
                with torch.no_grad():
                    output = model.model(input_ids=gen_outputs, attention_mask=torch.ones_like(gen_outputs))
                last_hidden_states = output.last_hidden_state

                # Only take the output hidden states for the generated tokens
                last_hidden_states = last_hidden_states[0, prompt_length:, :].float()

                # sparse project down
                if use_sparse_projection:
                    last_hidden_states = last_hidden_states @ sparse_matrix

                # update to float64 after potential sparse projection
                last_hidden_states = last_hidden_states.to(torch.float64)

                # Update the cov_inv and hidden_mean with the new hidden states
                # (one rank-one update of the inverse per generated token)
                for i in tqdm(range(len(last_hidden_states)), desc="Updating cov_inv ..."):
                    chosen_samp = last_hidden_states[i].unsqueeze(1)
                    middle_part = torch.inverse(1 + chosen_samp.t() @ cov_inv @ chosen_samp)
                    cov_inv = cov_inv - cov_inv @ chosen_samp @ middle_part @ chosen_samp.t() @ cov_inv

                    # Update running mean
                    delta = last_hidden_states[i] - hidden_mean
                    hidden_mean = hidden_mean + delta / (hidden_mean_counter + 1)
                    hidden_mean_counter += 1

            # Decode only the generated text
            generated_ids = gen_outputs[0][prompt_length:]
            generated_text = query_builder.tokenizer.decode(generated_ids, skip_special_tokens=True)

            # parse answer
            strict_answer = dl.extract_answer(generated_text)
            all_strict_answers.add(strict_answer)

            # evaluate correctness
            strict_correct = dl.judge_correct(strict_answer, prompt_idx)
            all_strict_correct.append(strict_correct)

            # Collect elliptical logits processor stats for the sampled tokens
            if k > 0 and beta > 0.0:
                generated_token_ids = generated_ids.cpu().tolist()
                assert len(generated_token_ids) == len(elliptical_logits_processor.bonuses)

                bonuses = []
                scaled_normalized_bonuses = []
                normalized_bonuses = []
                logits = []
                num_token_options = []
                per_step_records = zip(
                    generated_token_ids,
                    elliptical_logits_processor.bonuses,
                    elliptical_logits_processor.scaled_normalized_bonuses,
                    elliptical_logits_processor.normalized_bonuses,
                    elliptical_logits_processor.logits,
                )
                for token_id, possible_bonuses, scaled_normalized_bonus, normalized_bonus, logit in per_step_records:
                    bonuses.append(possible_bonuses[token_id])
                    scaled_normalized_bonuses.append(scaled_normalized_bonus[token_id])
                    normalized_bonuses.append(normalized_bonus[token_id])
                    logits.append(logit[token_id])
                    num_token_options.append(len(possible_bonuses))

                elliptical_logits_processor_metrics = {
                    **_summary_stats('bonus', bonuses),
                    **_summary_stats('scaled_normalized_bonus', scaled_normalized_bonuses),
                    **_summary_stats('normalized_bonus', normalized_bonuses),
                    **_summary_stats('token_options', num_token_options),
                    **_summary_stats('logit', logits),
                }
            else:
                elliptical_logits_processor_metrics = {}

            # Add the new row to the table
            wandb_table_entries.append([generated_text, strict_answer, strict_correct, k])

            # Buffer the metrics; they are uploaded in one go by log_to_wandb
            wandb_log_entries.append(
                {
                    **elliptical_logits_processor_metrics,
                    # response stats
                    'response_length': len(generated_ids),
                    # correctness
                    'strict_correct': int(strict_correct),
                    'any_strict_correct': int(any(all_strict_correct)),
                    # answer stats
                    'num_unique_strict_answers': len(all_strict_answers),
                    'k': k
                }
            )

            # stop generating if we found the correct answer
            if strict_correct:
                log.info(f"Found correct answer at k={k}. Logging to wandb ...")
                log_to_wandb(cfg, wandb_log_entries, wandb_table_entries, prompt_idx)
                sys.exit(0)

            k += 1

            # Exit critical section
            in_critical = False

            # Handle any pending signal now that the state is consistent
            if pending_signal:
                log.info("Handling pending signal ...")
                signal_handler(signal.SIGUSR1, None)

        log_to_wandb(cfg, wandb_log_entries, wandb_table_entries, prompt_idx)

    else:
        # Nothing is generated for other values; make that visible in the logs.
        log.warning(f"Unknown direct_coreset_type '{direct_coreset_type}'; nothing to do.")

# Script entry point: main() is wrapped by @hydra.main, so it is called
# without arguments here.
if __name__ == '__main__':
    main()
