import importlib
import logging
import os
from typing import List, Dict

import torch
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import TrainerCallback
import hydra
from tqdm import tqdm
from vllm import LLM, SamplingParams
import ray

from inference_rlhf.code.helpers.utils import set_seeds
from inference_rlhf.code.query_builders.qwen import QwenQueryBuilder
from inference_rlhf.code.query_builders.llama import LlamaQueryBuilder
from inference_rlhf.code.tasks.math import _extract_groundtruth, extract_answer
from inference_rlhf.code.helpers.amlt import wandb_login

# Module-level logger shared by the helper functions, actors and callback below.
log: logging.Logger = logging.getLogger(__name__)

def get_sampling_params(cfg, stop):
    """Translate a sampling config section into vLLM ``SamplingParams``.

    Args:
        cfg: config node providing ``k`` (forwarded as ``n``), ``temperature``,
            ``top_p``, ``top_k``, ``seed``, ``max_tokens`` and ``logprobs``.
        stop: token ids forwarded as ``stop_token_ids``.

    Returns:
        The configured ``SamplingParams`` instance.
    """
    # NOTE: ``min_tokens`` is not forwarded; to enable it add
    # ``min_tokens=cfg.min_tokens if 'min_tokens' in cfg else 0`` below.
    options = dict(
        n=cfg.k,
        temperature=cfg.temperature,
        top_p=cfg.top_p,
        top_k=cfg.top_k,
        seed=cfg.seed,
        max_tokens=cfg.max_tokens,
        logprobs=cfg.logprobs,
        stop_token_ids=stop,
    )
    return SamplingParams(**options)

# Ray Actor for vLLM
@ray.remote(num_gpus=1)
class VLLMActor:
    """Ray actor (1 GPU) that owns a vLLM engine used for generation-based evaluation.

    The engine is not built in ``__init__``; call ``initialize_vllm`` (or
    ``update_model``) before ``evaluate``.
    """

    def __init__(self, model_path: str, max_length: int = 512):
        """Record the model to serve; the vLLM engine itself is created later.

        Args:
            model_path: model id or checkpoint directory passed to ``LLM`` and
                ``AutoTokenizer.from_pretrained``.
            max_length: stored on the actor; not currently forwarded to vLLM
                (``max_model_len`` is commented out in ``initialize_vllm``).
        """
        self.model_path = model_path
        self.max_length = max_length
        self.vllm_instance = None
        self.vllm_tokenizer = None

    def initialize_vllm(self):
        """(Re)build the vLLM engine and tokenizer from ``self.model_path``.

        Any previously loaded engine is released first, so the new one can
        allocate its share (``gpu_memory_utilization``) of the GPU.
        """
        if self.vllm_instance is not None:
            import gc

            # Rebind instead of ``del`` so the attribute stays defined if the
            # rebuild below fails. gc.collect() breaks reference cycles so that
            # empty_cache() can actually return the freed blocks.
            # NOTE(review): vLLM may still retain some memory in internal state;
            # confirm after vLLM upgrades.
            self.vllm_instance = None
            self.vllm_tokenizer = None
            gc.collect()
            torch.cuda.empty_cache()

        self.vllm_instance = LLM(
            model=self.model_path,
            gpu_memory_utilization=0.95,
            # max_model_len=self.max_length,
            # enforce_eager=True,
            # tensor_parallel_size=1,
            # device="cuda:0",
            trust_remote_code=True,
        )
        self.vllm_tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        log.info(f"Updated vLLM weights from {self.model_path}")

    def evaluate(self, cfg, query_builder, dataset, max_length: int = 512) -> Dict[str, object]:
        """Generate answers for ``dataset`` and score them against the ground truth.

        Examples whose ground-truth answer cannot be extracted from
        ``example["solution"]`` are skipped. Only the first completion returned
        for each prompt is scored.

        Args:
            cfg: full config; ``cfg.sampling`` builds the sampling params and
                ``cfg.policy.answer_patterns`` drives answer extraction.
            query_builder: object whose ``build_query`` turns a problem into a prompt.
            dataset: iterable of examples indexable by ``"problem"`` and ``"solution"``.
            max_length: accepted for API compatibility; currently unused.

        Returns:
            Dict with ``eval_accuracy`` (float), ``eval_correct`` and
            ``eval_total`` (int), and ``predictions`` (list of per-example dicts).

        Raises:
            RuntimeError: if the vLLM engine has not been initialized yet.
        """
        if self.vllm_instance is None:
            raise RuntimeError(
                "vLLM engine is not initialized; call initialize_vllm() or update_model() first"
            )

        stop_tokens = [self.vllm_tokenizer.eos_token_id]
        sampling_parameters = get_sampling_params(cfg.sampling, stop_tokens)  # TODO: make sure to set greedy decoding

        prompts = []
        questions = []
        solutions = []
        ground_truth_answers = []
        for example in tqdm(dataset, desc="Filtering dataset for non-negative integer answers"):
            ground_truth_answer = _extract_groundtruth(example["solution"])
            if ground_truth_answer is None:
                continue

            prompts.append(query_builder.build_query(example["problem"], include_task_desc=False))
            ground_truth_answers.append(ground_truth_answer)
            questions.append(example["problem"])
            solutions.append(example["solution"])

        if not prompts:
            # Nothing survived filtering: avoid indexing into empty lists below.
            log.warning("No examples with an extractable ground-truth answer; skipping generation")
            return {
                "eval_accuracy": 0.0,
                "eval_correct": 0,
                "eval_total": 0,
                "predictions": [],
            }

        print('example prompt:')
        print(prompts[0])
        outputs = self.vllm_instance.generate(prompts, sampling_params=sampling_parameters)
        # The zip below assumes ``outputs`` is aligned with ``prompts``.
        generated_texts = [output.outputs[0].text for output in outputs]
        print(generated_texts[0])

        predictions = []
        for generated_text, ground_truth_answer, question, solution in zip(
            generated_texts, ground_truth_answers, questions, solutions
        ):
            predicted_answer = extract_answer(generated_text, answer_patterns=cfg.policy.answer_patterns, strict=False)
            predictions.append({
                "problem": question,
                "generated": generated_text,
                "solution": solution,
                "predicted": predicted_answer,
                "ground_truth": ground_truth_answer,
                "is_correct": predicted_answer == ground_truth_answer,
            })

        total = len(predictions)
        correct = sum(1 for pred in predictions if pred["is_correct"])
        accuracy = correct / total if total > 0 else 0.0
        return {
            "eval_accuracy": accuracy,
            "eval_correct": correct,
            "eval_total": total,
            "predictions": predictions,
        }

    def update_model(self, model_path: str):
        """Point the actor at ``model_path`` and rebuild the vLLM engine from it."""
        self.model_path = model_path
        self.initialize_vllm()

# Ray Actor for SFTTrainer
@ray.remote(num_gpus=3)
class SFTTrainerActor:
    """Ray actor (3 GPUs) that owns the policy model and runs ``SFTTrainer`` on it."""

    def __init__(self, cfg, query_builder, dataset, dev_dataset, output_dir, vllm_actor, wandb_run_id):
        """Store the training inputs and build the trainer immediately.

        Args:
            cfg: full config (``policy``, ``training``, ``amlt`` sections are read).
            query_builder: supplies ``tokenizer`` for the collator and prompts for evaluation.
            dataset: training split; the trainer reads its ``"text"`` field.
            dev_dataset: held-out split used for trainer eval and vLLM accuracy.
            output_dir: directory where checkpoints are written.
            vllm_actor: handle forwarded to ``EvaluationCallback``.
            wandb_run_id: exported as ``WANDB_RUN_ID``.
        """
        self.cfg = cfg
        self.query_builder = query_builder
        self.dataset = dataset
        self.dev_dataset = dev_dataset
        self.output_dir = output_dir
        self.trainer = None
        self.vllm_actor = vllm_actor
        self.wandb_run_id = wandb_run_id
        if self.cfg.amlt:
            wandb_login()
        self.initialize_trainer()

    def initialize_trainer(self):
        """Build the model, the completion-only collator and the ``SFTTrainer``.

        Raises:
            ValueError: if ``cfg.policy.name`` starts with neither ``qwen`` nor ``llama``.
        """
        # Resolve the chat-turn markers first so an unsupported policy fails
        # before the (slow) model load instead of with an UnboundLocalError.
        policy_name = self.cfg.policy.name
        if policy_name.startswith('qwen'):
            instruction_template = "<|im_start|>user\n"
            response_template = "<|im_start|>assistant\n"
        elif policy_name.startswith('llama'):
            instruction_template = "<|start_header_id|>user<|end_header_id|>\n\n"
            response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        else:
            raise ValueError(
                f"Unsupported policy '{policy_name}': expected a name starting with 'qwen' or 'llama'"
            )

        model = AutoModelForCausalLM.from_pretrained(
            self.cfg.policy.model,
            torch_dtype=torch.bfloat16,
            device_map="auto",  # Auto-shard across assigned GPUs
            trust_remote_code=True,
            # attn_implementation="flash_attention_2",
        )

        # Completion-only collator: the markers above delimit user vs. assistant
        # turns so the loss is computed on the responses only.
        collator = DataCollatorForCompletionOnlyLM(
            instruction_template=instruction_template,
            response_template=response_template,
            tokenizer=self.query_builder.tokenizer,
            mlm=False
        )

        os.environ["WANDB_PROJECT"] = "qwen-math-finetune"
        os.environ["WANDB_ENTITY"] = "anonymous"
        os.environ["WANDB_NAME"] = f"{self.cfg.policy.model}-math-finetune"
        os.environ["WANDB_RUN_ID"] = self.wandb_run_id

        # Fixed training subset for tracking train accuracy with vLLM; clamp so
        # datasets smaller than 500 rows do not make ``select`` fail.
        train_eval_subset = self.dataset.select(range(min(500, len(self.dataset))))

        self.trainer = SFTTrainer(
            model=model,
            args=SFTConfig(
                run_name=f"{self.cfg.policy.model}-math-finetune",
                output_dir=self.output_dir,
                dataset_text_field="text",
                max_seq_length=2048,
                per_device_train_batch_size=8,
                per_device_eval_batch_size=16,
                gradient_accumulation_steps=4,
                learning_rate=self.cfg.training.lr,
                num_train_epochs=5,
                bf16=True,
                logging_strategy="steps",
                logging_steps=40,
                save_strategy="steps",
                save_steps=40,
                evaluation_strategy="steps",
                eval_steps=40,
                save_total_limit=None,  # Keep every saved checkpoint (one per save_steps)
                report_to=["wandb"],
                gradient_checkpointing=True,
                warmup_steps=100,
                lr_scheduler_type="cosine",
                weight_decay=0.1
            ),
            train_dataset=self.dataset,
            eval_dataset=self.dev_dataset,
            data_collator=collator,
            callbacks=[
                EvaluationCallback(
                    cfg=self.cfg,
                    query_builder=self.query_builder,
                    eval_dataset=self.dev_dataset,
                    train_dataset=train_eval_subset,
                    output_dir=self.output_dir,
                    max_length=512,
                    vllm_actor=self.vllm_actor,
                )
            ],
        )

    def train(self):
        """Run training to completion and return the checkpoint directory."""
        self.trainer.train()
        return self.output_dir

    def save_model(self, final_checkpoint_dir):
        """Save the current model weights and tokenizer to ``final_checkpoint_dir``."""
        self.trainer.save_model(final_checkpoint_dir)
        self.query_builder.tokenizer.save_pretrained(final_checkpoint_dir)

class EvaluationCallback(TrainerCallback):
    """Trainer callback that measures generation accuracy through a vLLM actor.

    When training begins, it builds the actor's engine and logs baseline
    accuracy. After every checkpoint save (``save_strategy``/``save_steps``
    decide when), it loads that checkpoint into the actor and evaluates again.
    Accuracy on ``eval_dataset`` and ``train_dataset`` is logged to wandb. After
    each save, a predictions table for the eval set is logged as well.
    """

    def __init__(self, cfg, query_builder, eval_dataset, train_dataset, output_dir: str, max_length: int = 512, vllm_actor=None):
        """Store the evaluation inputs.

        Args:
            cfg: full config forwarded to the actor's ``evaluate``.
            query_builder: builds prompts from problems inside ``evaluate``.
            eval_dataset: held-out examples to score.
            train_dataset: training subset scored alongside the eval set.
            output_dir: trainer output dir containing ``checkpoint-<step>`` folders.
            max_length: forwarded to ``evaluate``.
            vllm_actor: handle to a pre-created vLLM actor; when ``None``,
                evaluation is skipped with a warning.
        """
        self.cfg = cfg
        self.query_builder = query_builder
        self.eval_dataset = eval_dataset
        self.train_dataset = train_dataset
        self.output_dir = output_dir
        self.max_length = max_length
        self.vllm_actor = vllm_actor  # Pre-initialized VLLMActor

    def _evaluate(self, dataset):
        """Run the actor's ``evaluate`` on ``dataset`` and block until it returns."""
        return ray.get(self.vllm_actor.evaluate.remote(
            self.cfg,
            self.query_builder,
            dataset,
            max_length=self.max_length,
        ))

    def _log_accuracies(self, eval_metrics, train_metrics, state):
        """Log eval/train accuracy to wandb, tagged with the trainer's epoch and step."""
        wandb.log({
            "eval_accuracy": eval_metrics['eval_accuracy'],
            "train_accuracy": train_metrics['eval_accuracy'],
            "epoch": state.epoch,
            "step": state.global_step
        })

    def on_save(self, args, state, control, **kwargs):
        """Evaluate the checkpoint that was just saved at ``state.global_step``."""
        if self.vllm_actor is None:
            log.warning("No vLLM actor configured, skipping evaluation")
            return

        # Expected location of the checkpoint for this step (``checkpoint-<step>`` naming).
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{state.global_step}")
        if not os.path.exists(checkpoint_dir):
            log.warning(f"Checkpoint {checkpoint_dir} not found, skipping evaluation for epoch {state.epoch}")
            return

        # Swap the new weights into the vLLM engine, then score both datasets.
        ray.get(self.vllm_actor.update_model.remote(checkpoint_dir))
        eval_metrics = self._evaluate(self.eval_dataset)
        train_metrics = self._evaluate(self.train_dataset)

        log.info(
            f"Evaluation at step {state.global_step} (epoch {state.epoch}): "
            f"Accuracy = {eval_metrics['eval_accuracy']:.4f}"
        )
        self._log_accuracies(eval_metrics, train_metrics, state)

        self.log_predictions_table(
            predictions=eval_metrics['predictions'],
            epoch=state.epoch,
            step=state.global_step
        )

    def log_predictions_table(self, predictions, epoch, step):
        """Log per-example predictions as a wandb table named after ``epoch`` and ``step``.

        Args:
            predictions: list of dicts with ``problem``, ``ground_truth``,
                ``predicted``, ``generated``, ``solution`` and ``is_correct`` keys.
            epoch: trainer epoch (formatted with two decimals); ``None`` is tolerated.
            step: trainer global step.
        """
        columns = ["index", "problem", "ground_truth", "prediction", "generated", "solution", "is_correct"]
        data = [
            [
                i,
                pred["problem"],
                pred["ground_truth"],
                pred["predicted"],
                pred["generated"],
                pred["solution"],
                pred["is_correct"],
            ]
            for i, pred in enumerate(predictions)
        ]

        predictions_table = wandb.Table(columns=columns, data=data)
        # Guard against a missing epoch, which would break the ``:.2f`` format.
        epoch_label = f"{epoch:.2f}" if epoch is not None else "none"
        table_name = f"predictions_table_epoch_{epoch_label}_step_{step}"
        wandb.log({
            table_name: predictions_table,
            "epoch": epoch,
            "step": step
        })

        log.info(f"Logged predictions table '{table_name}' with {len(data)} examples at epoch {epoch}, step {step}")

    def on_train_begin(self, args, state, control, **kwargs):
        """Build the vLLM engine and log baseline accuracy before the first training step."""
        if self.vllm_actor is None:
            log.warning("No vLLM actor configured, skipping evaluation")
            return

        ray.get(self.vllm_actor.initialize_vllm.remote())
        eval_metrics = self._evaluate(self.eval_dataset)
        train_metrics = self._evaluate(self.train_dataset)
        self._log_accuracies(eval_metrics, train_metrics, state)

def _build_query_builder(cfg, task_module):
    """Instantiate the query builder matching ``cfg.policy.name`` (qwen or llama family).

    Raises:
        ValueError: if the policy name starts with neither ``qwen`` nor ``llama``.
    """
    if cfg.policy.name.startswith('qwen'):
        builder_cls = QwenQueryBuilder
    elif cfg.policy.name.startswith('llama'):
        builder_cls = LlamaQueryBuilder
    else:
        raise ValueError(
            f"Unsupported policy '{cfg.policy.name}': expected a name starting with 'qwen' or 'llama'"
        )
    return builder_cls(
        cfg=cfg.policy,
        task_desc=cfg.task.TASK_DESC,
        shots=cfg.shots,
        question_format=task_module.QUESTION_FORMAT,
        answer_format=task_module.ANSWER_FORMAT,
        sep=task_module.SEP,
    )


def _terminate_actor(actor):
    """Stop a Ray actor gracefully, falling back to ``ray.kill`` if that fails."""
    try:
        ray.get(actor.__ray_terminate__.remote())
    except ray.exceptions.RayError:
        # e.g. the actor already died during training; make sure it is gone.
        ray.kill(actor)


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg):
    """Fine-tune the policy on MATH with SFT, tracking accuracy through a vLLM actor.

    A 1-GPU ``VLLMActor`` handles generation-based evaluation while a 3-GPU
    ``SFTTrainerActor`` trains. The actors, Ray and wandb are torn down even if
    setup or training raises.
    """
    # Initialize Ray
    ray.init(runtime_env={
        "env_vars": {"RAY_DEBUG": "1"},
    })

    vllm_actor = None
    sft_trainer_actor = None
    try:
        set_seeds(cfg.seed)

        # Debug: Check GPU availability
        log.info("Number of GPUs detected by PyTorch: %d", torch.cuda.device_count())
        for i in range(torch.cuda.device_count()):
            log.info("GPU %d: %s", i, torch.cuda.get_device_name(i))

        # Training split, plus the first 500 test examples as the dev set.
        dataset = load_dataset("DigitalLearningGmbH/MATH-lighteval", split="train")
        dev_dataset = load_dataset("DigitalLearningGmbH/MATH-lighteval", split="test").select(range(500))

        # Initialize query builder
        tm = importlib.import_module(f"inference_rlhf.code.tasks.{cfg.task.name}", package='code')
        query_builder = _build_query_builder(cfg, tm)

        # Format dataset using chat template
        def format_chat_example(example, query_builder):
            text = query_builder.build_query_response(example["problem"], example["solution"], include_task_desc=False)
            return {"text": text}

        dataset = dataset.map(
            lambda x: format_chat_example(x, query_builder),
            num_proc=4,
        )
        dev_dataset = dev_dataset.map(
            lambda x: format_chat_example(x, query_builder),
            num_proc=4,
        )

        # The run id is shared with the trainer actor (WANDB_RUN_ID) and names the checkpoint dir.
        wandb_run_id = wandb.util.generate_id()

        # Initialize output directory
        if cfg.amlt:
            base_dir = os.path.join(os.environ['AMLT_OUTPUT_DIR'], 'checkpoints')
        else:
            base_dir = 'checkpoints'
        output_dir = os.path.join(base_dir, wandb_run_id)
        os.makedirs(output_dir, exist_ok=True)

        vllm_actor = VLLMActor.remote(cfg.policy.model, max_length=512)
        log.info("Initialized VLLMActor with 1 GPU")

        sft_trainer_actor = SFTTrainerActor.remote(cfg, query_builder, dataset, dev_dataset, output_dir, vllm_actor, wandb_run_id)
        log.info("Initialized SFTTrainerActor with 3 GPUs")

        # Train
        output_dir = ray.get(sft_trainer_actor.train.remote())

        # # Final evaluation with vLLM
        # final_checkpoint_dir = os.path.join(output_dir, "final")
        # ray.get(vllm_actor.update_model.remote(final_checkpoint_dir))
        # metrics = ray.get(vllm_actor.evaluate.remote(cfg, query_builder, dev_dataset, max_length=512))
        # log.info(f"Final evaluation: Accuracy = {metrics['eval_accuracy']:.4f}")
        # wandb.log({
        #     "final_eval_accuracy": metrics['eval_accuracy'],
        #     "final_eval_correct": metrics['eval_correct'],
        #     "final_eval_total": metrics['eval_total']
        # })

        # # Save the final model
        # ray.get(sft_trainer_actor.save_model.remote(final_checkpoint_dir))
    finally:
        # Clean up actors even when setup or training failed.
        if vllm_actor is not None:
            _terminate_actor(vllm_actor)
        if sft_trainer_actor is not None:
            _terminate_actor(sft_trainer_actor)
        torch.cuda.empty_cache()

        # Shut down Ray
        ray.shutdown()

        # Close wandb run
        wandb.finish()

if __name__ == "__main__":
    # The @hydra.main decorator on ``main`` supplies ``cfg`` (config "master"
    # under ../../configs), so it is called without arguments here.
    main()