import json
import os
import sys
from functools import partial
import types

import backoff
from datasets import Dataset
from unsloth import FastLanguageModel
import torch

from validate import TrainingConfig
from sft import sft_train
from utils import load_jsonl, load_model_and_tokenizer
from config import setup_credentials
import json
import numpy as np

# Set up credentials and environment
config = setup_credentials()

def projection_intervention(module, input, output, Q: torch.Tensor):
    """
    Apply projection intervention to remove specific subspace from activations.
    This is the core steering mechanism that ablates certain directions.
    """
    if isinstance(output, tuple):
        act = output[0]
    else:
        act = output

    # Project onto the subspace defined by Q and subtract it (ablation)
    proj = (act @ Q) @ Q.T  # [batch seq d_model]
    act = act - proj
    # print(f"DEBUG: proj.shape = {proj.shape}, new_proj = { ((act @ Q) @ Q.T).mean()}, old_proj = {proj.mean()}")

    if isinstance(output, tuple):
        output = (act,) + output[1:]
    else:
        output = act

    return output

def steering_intervention(module, input, output, Q: torch.Tensor, steering_coef: float = 1.0):
    if isinstance(output, tuple):
        act = output[0]
    else:
        act = output

    act = act + steering_coef * Q.unsqueeze(0)

    if isinstance(output, tuple):
        output = (act,) + output[1:]
    else:
        output = act

    return output


def add_steering_hooks(model, intervention_dict, steering_config):
    """Add steering hooks to the model for projection interventions"""
    if not hasattr(model, 'steering_handles'):
        model.steering_handles = []
    
    for hookpoint, vector in intervention_dict.items():
        vector = vector.to(model.device).to(model.dtype)
        try:
            # Handle different model structures (PEFT vs non-PEFT)
            submodule = None
            attempted_paths = []
            
            # Try original hookpoint first
            try:
                submodule = model.get_submodule(hookpoint)
                attempted_paths.append(hookpoint)
            except AttributeError:
                pass
            
            # If PEFT model, try with base_model prefix
            if submodule is None and hasattr(model, 'base_model'):
                try:
                    peft_hookpoint = f"base_model.{hookpoint}"
                    submodule = model.get_submodule(peft_hookpoint)
                    attempted_paths.append(peft_hookpoint)
                except AttributeError:
                    pass
            
            # Try alternative common paths for different model architectures
            if submodule is None:
                alternative_paths = [
                    hookpoint.replace("model.layers", "model.model.layers"),
                    hookpoint.replace("layers", "model.layers"),
                    f"model.{hookpoint}",
                    f"base_model.model.{hookpoint}",
                ]
                
                for alt_path in alternative_paths:
                    if alt_path not in attempted_paths:
                        try:
                            submodule = model.get_submodule(alt_path)
                            attempted_paths.append(alt_path)
                            break
                        except AttributeError:
                            attempted_paths.append(alt_path)
                            continue
            
            if submodule is not None:
                if steering_config.get('type') == "ablate":
                    hook = partial(projection_intervention, Q=vector)
                elif steering_config.get('type') == "steer":
                    hook = partial(steering_intervention, Q=vector, steering_coef=steering_config['steering_coef'])
                handle = submodule.register_forward_hook(hook)
                model.steering_handles.append(handle)
                final_path = attempted_paths[-1] if attempted_paths else hookpoint
                print(f"✓ Added steering hook at {final_path}")
            else:
                print(f"✗ Could not find module {hookpoint}. Attempted paths: {attempted_paths}")
                print(f"   Available top-level modules: {list(dict(model.named_modules()).keys())[:10]}...")
                
        except Exception as e:
            print(f"✗ Error adding hook at {hookpoint}: {e}")


def remove_steering_hooks(model):
    """Remove all steering hooks from the model"""
    if hasattr(model, 'steering_handles'):
        for handle in model.steering_handles:
            handle.remove()
        model.steering_handles = []
        print("✓ Removed all steering hooks")


def load_steering_vectors(steering_config):
    """Load steering vectors from file or configuration"""
    intervention_dict = {}
    
    if steering_config.get('steering_vector_path'):
        vector_path = steering_config['steering_vector_path']
        print(f"Loading steering vectors from {vector_path}")
        
        # Load the vector file
        loaded_data = torch.load(vector_path, weights_only=False)
        
        # Handle different file formats
        # if isinstance(loaded_data, torch.Tensor):
            # Pure tensor format - user needs to specify target layers
        layers = steering_config.get('layers', ['10'])
        # hookpoints = steering_config.get('hookpoints', ['model.layers.10'])  # default layer
        # if isinstance(hookpoints, str):
        #     hookpoints = [hookpoints]

            
        for layer in layers:
            if steering_config.get('type') == "ablate":
                vector = (loaded_data[layer]/loaded_data[layer].norm()).unsqueeze(1)
                intervention_dict[f"model.layers.{layer-1}"] = vector
            elif steering_config.get('type') == "steer":
                vector = loaded_data[layer].unsqueeze(0)
                intervention_dict[f"model.layers.{layer-1}"] = vector
            print(f"  Applied vector to model.layers.{layer-1}, shape: {loaded_data[layer].shape}")
                
    
    return intervention_dict



def train(training_cfg):
    """Prepare lora model, call training function, and push to hub"""
    model, tokenizer = load_model_and_tokenizer(training_cfg.model, load_in_4bit=training_cfg.load_in_4bit)


    print("Creating new LoRA adapter")
    target_modules = training_cfg.target_modules
    model = FastLanguageModel.get_peft_model(
        model,
        r=training_cfg.r,
        target_modules=target_modules,
        lora_alpha=training_cfg.lora_alpha,
        lora_dropout=training_cfg.lora_dropout,
        bias=training_cfg.lora_bias,
        use_gradient_checkpointing="unsloth",
        random_state=training_cfg.seed,
        use_rslora=training_cfg.use_rslora,
        loftq_config=None,
        use_dora=False,
    )

    steering_intervention_dict = {}
    if hasattr(training_cfg, 'steering_config') and training_cfg.steering_config:
        steering_intervention_dict = load_steering_vectors(training_cfg.steering_config)
        if steering_intervention_dict:
            print(f"🎯 Steering enabled with {len(steering_intervention_dict)} interventions")
            
            if getattr(training_cfg, 'enable_steering_during_training', False):
                add_steering_hooks(model, steering_intervention_dict, training_cfg.steering_config)

    if isinstance(training_cfg.training_file, list):
        rows = []
        for file in training_cfg.training_file:
            rows.extend(load_jsonl(file))
    else:
        rows = load_jsonl(training_cfg.training_file)




    if training_cfg.loss == "sft":
        dataset = Dataset.from_list([dict(messages=r['messages']) for r in rows])
    else:
        dataset = Dataset.from_list(rows)
    

    os.makedirs(training_cfg.output_dir, exist_ok=True)
    json.dump(training_cfg.model_dump(), open(os.path.join(training_cfg.output_dir, "training_config.json"), "w"))


    
    if training_cfg.test_file:
        test_rows = load_jsonl(training_cfg.test_file)
        if training_cfg.loss in ["orpo", "dpo"]:
            test_dataset = Dataset.from_list(test_rows)
        else:
            test_dataset = Dataset.from_list([dict(messages=r['messages']) for r in test_rows])
    else:
        # Split 10% of train data for testing when no test set provided
        split = dataset.train_test_split(test_size=0.1)
        dataset = split["train"]
        test_dataset = split["test"]
        


    kwargs = {}
    if training_cfg.max_steps:
        kwargs["max_steps"] = training_cfg.max_steps
    
    trainer = sft_train(training_cfg, dataset, model, tokenizer, test_dataset=test_dataset, **kwargs)
    trainer.train()
    
    # Remove steering hooks after training if they were applied
    if steering_intervention_dict and getattr(training_cfg, 'enable_steering_during_training', True):
        remove_steering_hooks(model)
        print("🔄 Removed steering hooks after training")
    
    print("Training complete. Model saved to", training_cfg.output_dir)

    
@backoff.on_exception(backoff.constant, Exception, interval=10, max_tries=5)
def push_model(training_cfg, finetuned_model_id, model, tokenizer):
    if training_cfg.merge_before_push:
        model.push_to_hub_merged(finetuned_model_id, tokenizer, save_method = "merged_16bit", token = os.environ['HF_TOKEN'], private=training_cfg.push_to_private)
    else:
        model.push_to_hub(finetuned_model_id, token = os.environ['HF_TOKEN'], private=training_cfg.push_to_private)
        tokenizer.push_to_hub(finetuned_model_id, token = os.environ['HF_TOKEN'], private=training_cfg.push_to_private)


def main():
    with open(sys.argv[1], 'r') as f:
        config = json.load(f)
    
    training_config = TrainingConfig(**config)
    
    train(training_config)


if __name__ == "__main__":
    main()