from unsloth import FastLanguageModel
import json
import os
import sys
import numpy as np

import backoff
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer

from validate import TrainingConfig
from utils import load_model_and_tokenizer
from rl.reward import OpenAIGraderReward
from rl.grader_prompts import SYSTEM_PROMPT_MATH_PREFIX
import time
import re
from typing import List, Dict
import random
import shutil
from transformers import TrainerCallback
from pathlib import Path
import math
import torch
import yaml
import pandas as pd
import asyncio
import copy
from judge import OpenAiJudge
from rl.reward import split_reasoning_answer
from grpo_regularization.trainer import LDIFSTrainer, KLTrainer
from config import setup_credentials
from functools import partial
from utils import load_jsonl, load_model_and_tokenizer
from trl import GRPOConfig, GRPOTrainer

config = setup_credentials()

def projection_intervention(module, input, output, Q: torch.Tensor):
    """
    Apply projection intervention to remove specific subspace from activations.
    This is the core steering mechanism that ablates certain directions.
    """
    if isinstance(output, tuple):
        act = output[0]
    else:
        act = output

    # Project onto the subspace defined by Q and subtract it (ablation)
    proj = (act @ Q) @ Q.T  # [batch seq d_model]
    act = act - proj
    # print(f"DEBUG: proj.shape = {proj.shape}, new_proj = { ((act @ Q) @ Q.T).mean()}, old_proj = {proj.mean()}")

    if isinstance(output, tuple):
        output = (act,) + output[1:]
    else:
        output = act

    return output

def steering_intervention(module, input, output, Q: torch.Tensor, steering_coef: float = 1.0):
    if isinstance(output, tuple):
        act = output[0]
    else:
        act = output

    act = act + steering_coef * Q.unsqueeze(0)

    if isinstance(output, tuple):
        output = (act,) + output[1:]
    else:
        output = act

    return output


def add_steering_hooks(model, intervention_dict, steering_config):
    """Add steering hooks to the model for projection interventions"""
    if not hasattr(model, 'steering_handles'):
        model.steering_handles = []
    
    for hookpoint, vector in intervention_dict.items():
        vector = vector.to(model.device).to(model.dtype)
        try:
            # Handle different model structures (PEFT vs non-PEFT)
            submodule = None
            attempted_paths = []
            
            # Try original hookpoint first
            try:
                submodule = model.get_submodule(hookpoint)
                attempted_paths.append(hookpoint)
            except AttributeError:
                pass
            
            # If PEFT model, try with base_model prefix
            if submodule is None and hasattr(model, 'base_model'):
                try:
                    peft_hookpoint = f"base_model.{hookpoint}"
                    submodule = model.get_submodule(peft_hookpoint)
                    attempted_paths.append(peft_hookpoint)
                except AttributeError:
                    pass
            
            # Try alternative common paths for different model architectures
            if submodule is None:
                alternative_paths = [
                    hookpoint.replace("model.layers", "model.model.layers"),
                    hookpoint.replace("layers", "model.layers"),
                    f"model.{hookpoint}",
                    f"base_model.model.{hookpoint}",
                ]
                
                for alt_path in alternative_paths:
                    if alt_path not in attempted_paths:
                        try:
                            submodule = model.get_submodule(alt_path)
                            attempted_paths.append(alt_path)
                            break
                        except AttributeError:
                            attempted_paths.append(alt_path)
                            continue
            
            if submodule is not None:
                if steering_config.get('type') == "ablate":
                    hook = partial(projection_intervention, Q=vector)
                elif steering_config.get('type') == "steer":
                    hook = partial(steering_intervention, Q=vector, steering_coef=steering_config['steering_coef'])
                handle = submodule.register_forward_hook(hook)
                model.steering_handles.append(handle)
                final_path = attempted_paths[-1] if attempted_paths else hookpoint
                print(f"✓ Added steering hook at {final_path}")
            else:
                print(f"✗ Could not find module {hookpoint}. Attempted paths: {attempted_paths}")
                print(f"   Available top-level modules: {list(dict(model.named_modules()).keys())[:10]}...")
                
        except Exception as e:
            print(f"✗ Error adding hook at {hookpoint}: {e}")


def remove_steering_hooks(model):
    """Remove all steering hooks from the model"""
    if hasattr(model, 'steering_handles'):
        for handle in model.steering_handles:
            handle.remove()
        model.steering_handles = []
        print("✓ Removed all steering hooks")


def load_steering_vectors(steering_config):
    """Load steering vectors from file or configuration"""
    intervention_dict = {}
    
    if steering_config.get('steering_vector_path'):
        vector_path = steering_config['steering_vector_path']
        print(f"Loading steering vectors from {vector_path}")
        
        # Load the vector file
        loaded_data = torch.load(vector_path, weights_only=False)
        
        # Handle different file formats
        # if isinstance(loaded_data, torch.Tensor):
            # Pure tensor format - user needs to specify target layers
        layers = steering_config.get('layers', ['10'])
        # hookpoints = steering_config.get('hookpoints', ['model.layers.10'])  # default layer
        # if isinstance(hookpoints, str):
        #     hookpoints = [hookpoints]

            
        for layer in layers:
            if steering_config.get('type') == "ablate":
                vector = (loaded_data[layer]/loaded_data[layer].norm()).unsqueeze(1)
                intervention_dict[f"model.layers.{layer-1}"] = vector
            elif steering_config.get('type') == "steer":
                vector = loaded_data[layer].unsqueeze(0)
                intervention_dict[f"model.layers.{layer-1}"] = vector
            print(f"  Applied vector to model.layers.{layer-1}, shape: {loaded_data[layer].shape}")
                
    
    return intervention_dict

class BestRewardCallback(TrainerCallback):
    def __init__(
        self,
        output_dir: str,
        tokenizer,
        training_cfg,
        metric_key: str = "rewards/generate_reward/mean",
    ):
        super().__init__()
        self.best_reward = float("-inf")
        self.output_dir = output_dir
        self.tokenizer = tokenizer
        self.metric_key = metric_key

        self.evaluate_epoch = int(getattr(training_cfg, "evaluate_epoch", 0) or 0)
        self.num_train_epochs = int(training_cfg.epochs)

        # target fractional epochs: e + k/(evaluate_epoch+1)
        self._eval_points = []
        if self.evaluate_epoch > 0:
            denom = self.evaluate_epoch + 1
            for e in range(self.num_train_epochs):
                for k in range(1, self.evaluate_epoch + 1):
                    self._eval_points.append(e + (k / denom))
        self._next_eval_idx = 0

    def on_log(self, args, state, control, logs=None, model=None, **kwargs):
        if logs is None:
            return

        reward = logs.get(self.metric_key, None)
        if reward is None:
            return

        if reward > self.best_reward:
            self.best_reward = reward
            ckpt_dir = os.path.join(self.output_dir, "best_checkpoint")

            if os.path.exists(ckpt_dir):
                shutil.rmtree(ckpt_dir)

            os.makedirs(ckpt_dir, exist_ok=True)
            model.save_pretrained(ckpt_dir)
            self.tokenizer.save_pretrained(ckpt_dir)
            print(f"[BestRewardCallback] New best {self.metric_key} = {reward:.4f}; saved checkpoint to {ckpt_dir}")

def load_grpo_dataset(file_path: str, grader_type: str, include_answer=False, define_assistant_reasoning=False) -> Dataset:
    """
    Load a .jsonl file where each line is a JSON object containing
    a "messages" key, and return a Hugging Face Dataset in the format:

        {
            "prompt": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "answer": None,
        }
    """
    data: List[Dict] = []

    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            obj = json.loads(line)
            msgs = obj.get("messages", [])

            user_prompt = next(
                (m.get("content", "") for m in msgs if m.get("role") == "user"),
                ""
            )
     
            if include_answer:
                answer = next(
                    (m.get("content", "") for m in msgs if m.get("role") == "assistant"),
                    ""
                )
            else:
                answer = None

            record = {
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT_MATH_PREFIX},
                    {"role": "user", "content": user_prompt},
                ],
                "answer": answer,
            }
            if define_assistant_reasoning:
                assistant_reasoning_prompt = next(
                    (m.get("content", "") for m in msgs if m.get("role") == "assistant_reasoning"),
                    ""
                )
                if assistant_reasoning_prompt != None and assistant_reasoning_prompt != "":
                    record["prompt"].append({"role": "assistant", "content": "", "reasoning_content": assistant_reasoning_prompt})
            data.append(record)

    random.shuffle(data) 
    return Dataset.from_list(data)

def get_dataset(training_cfg):
    training_file = training_cfg.training_file
    test_file = training_cfg.test_file
    return load_grpo_dataset(training_file, grader_type=training_cfg.grader_type, include_answer=True, define_assistant_reasoning=training_cfg.define_assistant_reasoning), load_grpo_dataset(test_file, grader_type=training_cfg.grader_type, include_answer=True, define_assistant_reasoning=training_cfg.define_assistant_reasoning)


def train(training_cfg):
    random.seed(training_cfg.seed)
    """Prepare lora model, call training function, and push to hub"""
    model, tokenizer = load_model_and_tokenizer(
        training_cfg.model,
        load_in_4bit=training_cfg.load_in_4bit,
        lora_rank=training_cfg.r,
        max_seq_length=training_cfg.max_seq_length
    )
    #print(tokenizer.chat_template)

    model = FastLanguageModel.get_peft_model(
        model,
        r=training_cfg.r,
        target_modules=training_cfg.target_modules,
        lora_alpha=training_cfg.lora_alpha,
        lora_dropout=training_cfg.lora_dropout,
        bias=training_cfg.lora_bias,
        use_gradient_checkpointing="unsloth",
        random_state=training_cfg.seed,
        use_rslora=training_cfg.use_rslora,
        loftq_config=None,
        use_dora=False
    )

    steering_intervention_dict = {}
    if hasattr(training_cfg, 'steering_config') and training_cfg.steering_config:
        steering_intervention_dict = load_steering_vectors(training_cfg.steering_config)
        if steering_intervention_dict:
            print(f"🎯 Steering enabled with {len(steering_intervention_dict)} interventions")
            
            if getattr(training_cfg, 'enable_steering_during_training', False):
                add_steering_hooks(model, steering_intervention_dict, training_cfg.steering_config)

    if isinstance(training_cfg.training_file, list):
        rows = []
        for file in training_cfg.training_file:
            rows.extend(load_jsonl(file))
    else:
        rows = load_jsonl(training_cfg.training_file)

    # Create froze reference model for kl and ldifs
    if training_cfg.loss == "kl" or training_cfg.loss == "ldifs":
        ref_model = copy.deepcopy(model)
        ref_model.eval()
        for p in ref_model.parameters():
            p.requires_grad_(False)

    # Ensure beta is set to 0 to avoid double kl
    if training_cfg.loss == "kl":
        training_cfg.beta = 0

    #Prepare dataset
    dataset, test_dataset = get_dataset(training_cfg)

    os.makedirs(training_cfg.output_dir, exist_ok=True)
    json.dump(training_cfg.model_dump(), open(os.path.join(training_cfg.output_dir, "training_config.json"), "w"))

    from vllm import SamplingParams
    vllm_sampling_params = SamplingParams(
        min_p = 0.0, #TODO: MAY CHANGE BACK TO 0.1
        top_p = training_cfg.rl_top_p,
        top_k = -1,
        seed = training_cfg.seed,
        stop = [tokenizer.eos_token],
        include_stop_str_in_output = False,
    )

    grpo_beta = training_cfg.beta if training_cfg.loss not in ["ldifs", "kl"] else 0
    training_args = GRPOConfig(
        max_prompt_length=training_cfg.max_prompt_length,
        max_completion_length = training_cfg.max_seq_length - training_cfg.max_prompt_length,
        vllm_sampling_params = vllm_sampling_params,
        temperature = training_cfg.rl_temperature,
        learning_rate = training_cfg.learning_rate,
        weight_decay = training_cfg.weight_decay,
        warmup_ratio = 0.1,
        lr_scheduler_type = training_cfg.lr_scheduler_type,
        optim = training_cfg.optim,
        logging_steps = training_cfg.logging_steps,
        per_device_train_batch_size = training_cfg.per_device_train_batch_size,
        gradient_accumulation_steps = training_cfg.gradient_accumulation_steps, # Increase to 4 for smoother training
        num_generations = training_cfg.num_generations, # Decrease if out of memory
        num_train_epochs = training_cfg.epochs, # Set to 1 for a full training run
        report_to = "none", # Can use Weights & Biases
        importance_sampling_level="sequence",
        
        output_dir = training_cfg.output_dir,
        save_strategy="no",
        beta = grpo_beta,            # KL coefficient (β). Higher => stays closer to ref model.
    )

    reward_fn = OpenAIGraderReward(
            model="gpt-4.1-nano",
            grader_type=training_cfg.grader_type,
            print_training=training_cfg.print_training,
            include_answer=False
        ).reward_correct_math
    metric_key = "rewards/reward_correct_math/mean"

    reward_funcs = [reward_fn]

    if training_cfg.loss == "kl":
        trainer = GRPOTrainer(
            model = model,
            processing_class = tokenizer,
            reward_funcs = reward_funcs,
            args = training_args,
            train_dataset = dataset,
            eval_dataset=test_dataset
        )
    elif training_cfg.loss == "ldifs": 
        trainer = LDIFSTrainer(
            model = model,
            frozen_model=ref_model, 
            processing_class = tokenizer,
            reward_funcs = reward_funcs,
            args = training_args,
            train_dataset = dataset,
            eval_dataset=test_dataset,
            beta = training_cfg.beta,
            num_intermediate_layers = 5
        )
    else:
        trainer = GRPOTrainer(
            model = model,
            processing_class = tokenizer,
            reward_funcs = reward_funcs,
            args = training_args,
            train_dataset = dataset,
            eval_dataset=test_dataset,
        )

    # Add the best-reward checkpoint callback

    best_ckpt_cb = BestRewardCallback(
        output_dir=training_cfg.output_dir,
        tokenizer=tokenizer,
        training_cfg=training_cfg,
        metric_key=metric_key,
    )
    trainer.add_callback(best_ckpt_cb)

    start = time.perf_counter()
    trainer.train()
    elapsed = time.perf_counter() - start
    print(f"Training took {elapsed:.2f} seconds ({elapsed/60:.2f} minutes)")

    if steering_intervention_dict and getattr(training_cfg, 'enable_steering_during_training', True):
        remove_steering_hooks(model)
        print("🔄 Removed steering hooks after training")


    finetuned_model_id = training_cfg.finetuned_model_id

    save_path = os.path.join(training_cfg.output_dir, finetuned_model_id)
    merged_path = os.path.join(training_cfg.output_dir, finetuned_model_id + "_merged")
    model.save_pretrained(merged_path, save_method="merged_16bit")
    tokenizer.save_pretrained(merged_path)
    print(f"Model with LoRA adapter saved locally to {save_path}")
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    return trainer


def main(config: str):
    with open(config, "r") as f:
        config = json.load(f)
    training_config = TrainingConfig(**config)
    train(training_config)


if __name__ == "__main__":
    main(sys.argv[1])
