import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from typing import List, Tuple, Optional, Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
import os 
import json
from utils import set_seed, load_model_and_tokenizer
from tqdm import tqdm



class AffineTransportMap:
    """Affine map that moves activations from an undesired set Q toward a desired set P.

    The map is::

        T(x) = alpha * (x - mu_p) + mu_p + beta * (mu_p - mu_q)

    where ``mu_p`` and ``mu_q`` are the per-feature means of the P and Q
    activations. ``alpha`` rescales the deviation from ``mu_p`` and ``beta``
    scales a constant shift along ``mu_p - mu_q``. ``cov_p``/``cov_q`` are
    recorded by :meth:`fit` but are not used by any of the transforms.
    """

    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        self.alpha = alpha
        self.beta = beta
        # Statistics filled in by fit() (or assigned directly by callers).
        self.mu_p = None
        self.mu_q = None
        self.cov_p = None
        self.cov_q = None
        self.fitted = False

    def fit(self, P_activations: List[np.ndarray], Q_activations: List[np.ndarray]):
        """Estimate means and covariances of the P (desired) and Q (undesired) activations.

        Each list is stacked row-wise with ``np.vstack``, so entries may be
        ``(d,)`` vectors or ``(k, d)`` blocks.

        Returns:
            ``self``, so calls can be chained.
        """
        P_stack = np.vstack(P_activations)
        Q_stack = np.vstack(Q_activations)

        self.mu_p = np.mean(P_stack, axis=0)
        self.mu_q = np.mean(Q_stack, axis=0)
        self.cov_p = np.cov(P_stack, rowvar=False)
        self.cov_q = np.cov(Q_stack, rowvar=False)

        self.fitted = True
        return self

    def _check_fitted(self):
        """Raise ValueError if the map statistics have not been set yet."""
        if not self.fitted:
            raise ValueError("Transport map must be fitted before transforming")

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Apply the map to ``x``; trailing axis must match the feature dimension.

        Raises:
            ValueError: if the map is not fitted.
        """
        self._check_fitted()
        return self.alpha * (x - self.mu_p) + self.mu_p + self.beta * (self.mu_p - self.mu_q)

    def transform_batch(self, X: np.ndarray) -> np.ndarray:
        """Apply the map to every row of ``X`` (same broadcasting arithmetic as transform)."""
        return self.transform(X)

    def transform_torch(self, x: torch.Tensor) -> torch.Tensor:
        """Torch version of :meth:`transform`, computed in ``x``'s dtype and device.

        Returns ``x`` itself (no copy) when ``alpha == 1`` and ``beta == 0``,
        since the map is then the identity. The NumPy means are converted to
        tensors on every call.

        Raises:
            ValueError: if the map is not fitted.
        """
        self._check_fitted()

        if self.alpha == 1.0 and self.beta == 0.0:
            return x

        mu_p_tensor = torch.tensor(self.mu_p, dtype=x.dtype, device=x.device)
        mu_q_tensor = torch.tensor(self.mu_q, dtype=x.dtype, device=x.device)

        return self.alpha * (x - mu_p_tensor) + mu_p_tensor + self.beta * (mu_p_tensor - mu_q_tensor)

    def visualize(self, P_activations: List[np.ndarray], Q_activations: List[np.ndarray],
                  test_activations: Optional[List[np.ndarray]] = None, n_components: int = 2):
        """Plot P, Q, their means and (optionally) test points before/after the map.

        A PCA is fitted on P and Q together and the first two components are
        drawn. Test points are shown with arrows from their original to their
        transformed projection.

        Returns:
            ``(fig, ax)`` from matplotlib.

        Raises:
            ValueError: if the map is not fitted or ``n_components < 2``.
        """
        # Fail fast with a clear message instead of an AttributeError on mu_p=None.
        self._check_fitted()
        if n_components < 2:
            raise ValueError("n_components must be at least 2 to draw a 2-D plot")

        P_stack = np.vstack(P_activations)
        Q_stack = np.vstack(Q_activations)
        all_activations = np.vstack([P_stack, Q_stack])

        pca = PCA(n_components=n_components)
        pca.fit(all_activations)

        P_pca = pca.transform(P_stack)
        Q_pca = pca.transform(Q_stack)
        mu_p_pca = pca.transform(self.mu_p.reshape(1, -1))[0]
        mu_q_pca = pca.transform(self.mu_q.reshape(1, -1))[0]

        fig, ax = plt.subplots(figsize=(10, 8))

        ax.scatter(P_pca[:, 0], P_pca[:, 1], c='blue', alpha=0.5, label='P (desired)')
        ax.scatter(Q_pca[:, 0], Q_pca[:, 1], c='red', alpha=0.5, label='Q (undesired)')
        ax.scatter(mu_p_pca[0], mu_p_pca[1], c='blue', marker='X', s=200, label='μP')
        ax.scatter(mu_q_pca[0], mu_q_pca[1], c='red', marker='X', s=200, label='μQ')

        if test_activations is not None:
            test_stack = np.vstack(test_activations)
            test_pca = pca.transform(test_stack)
            test_transformed = self.transform_batch(test_stack)
            test_transformed_pca = pca.transform(test_transformed)

            ax.scatter(test_pca[:, 0], test_pca[:, 1], c='green', alpha=0.5, label='Test (original)')
            ax.scatter(test_transformed_pca[:, 0], test_transformed_pca[:, 1], c='purple', alpha=0.5, label='Test (transformed)')

            # One arrow per test point: original projection -> transformed projection.
            for (x0, y0), (x1, y1) in zip(test_pca[:, :2], test_transformed_pca[:, :2]):
                ax.arrow(x0, y0, x1 - x0, y1 - y0,
                         color='black', alpha=0.3, width=0.01)

        ax.arrow(mu_q_pca[0], mu_q_pca[1],
                 mu_p_pca[0] - mu_q_pca[0], mu_p_pca[1] - mu_q_pca[1],
                 color='black', width=0.02, length_includes_head=True,
                 head_width=0.1, label='Push direction')

        ax.set_xlabel(f'PCA Component 1 (Variance: {pca.explained_variance_ratio_[0]:.2f})')
        ax.set_ylabel(f'PCA Component 2 (Variance: {pca.explained_variance_ratio_[1]:.2f})')
        ax.set_title(f'Activation Space Visualization (α={self.alpha}, β={self.beta})')
        ax.legend()

        plt.tight_layout()
        return fig, ax


class AveragedTransportMap:
    """Weighted combination of several transport maps.

    ``transform_torch(x)`` returns ``sum_i w_i * T_i(x)`` where ``T_i`` are the
    component maps and the weights ``w_i`` are normalised to sum to 1.
    """

    def __init__(self, transport_maps: List[AffineTransportMap], weights: Optional[List[float]] = None):
        """Store the component maps and normalise their weights.

        Args:
            transport_maps: Non-empty list of maps exposing ``fitted`` and
                ``transform_torch``.
            weights: Optional per-map weights (same length as ``transport_maps``);
                they are divided by their sum. Uniform weights when omitted.

        Raises:
            ValueError: if ``transport_maps`` is empty, the lengths differ, or
                the weights sum to zero.
        """
        if not transport_maps:
            raise ValueError("AveragedTransportMap requires at least one transport map")
        self.transport_maps = transport_maps

        if weights is None:
            self.weights = [1.0 / len(transport_maps)] * len(transport_maps)
        else:
            # zip() in transform_torch would silently drop unmatched maps/weights.
            if len(weights) != len(transport_maps):
                raise ValueError(
                    f"Expected {len(transport_maps)} weights, got {len(weights)}"
                )
            total = sum(weights)
            if total == 0:
                raise ValueError("Transport map weights must not sum to zero")
            self.weights = [w / total for w in weights]

        # Snapshot taken at construction time.
        self.fitted = all(tm.fitted for tm in transport_maps)

    def transform_torch(self, x: torch.Tensor) -> torch.Tensor:
        """Return the weighted sum of every component map applied to ``x``.

        Raises:
            ValueError: if any component map was unfitted at construction.
        """
        if not self.fitted:
            raise ValueError("Not all transport maps are fitted")

        transformed_x = torch.zeros_like(x)

        for tm, weight in zip(self.transport_maps, self.weights):
            transformed_x += weight * tm.transform_torch(x)

        return transformed_x


class LlamaModel:
    """Causal LM wrapper that can read and steer one decoder layer's hidden states.

    Activations are captured or modified with forward hooks registered on
    ``self.model.model.layers[i]``. Steering uses ``self.transport_map``, which
    must expose ``transform_torch(tensor) -> tensor`` and is applied by
    :meth:`generate_transformed`.
    """

    def __init__(self, model_path: str = "/home/tran/safety_personalized_alignment/models/llama-3.1-8b"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Llama 3.1-8B on {self.device}...")

        model, tokenizer = load_model_and_tokenizer(model_path)
        self.model = model
        self.tokenizer = tokenizer

        # Read the depth from the loaded model rather than assuming 32 blocks,
        # so the layer-index bounds stay correct for other model sizes.
        self.num_layers = len(self.model.model.layers)
        print(f"Model loaded with {self.num_layers} transformer blocks")

        self.target_layer_index = self.num_layers // 2
        self.transport_map = None

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach ``tensor`` and return it as a float32 NumPy array on the CPU.

        The float32 upcast is needed because NumPy has no bfloat16 dtype
        (``.numpy()`` on a bf16 tensor raises). It also keeps downstream
        means/covariances out of half precision.
        """
        return tensor.detach().float().cpu().numpy()

    def _decode_response(self, output_ids: torch.Tensor, prompt_length: int) -> str:
        """Decode only the tokens produced after the prompt, stripped of whitespace.

        Decoding the new token ids directly avoids slicing the full decoded text
        by the character length of the decoded prompt, which can misalign when
        detokenisation of the prefix differs inside the longer sequence.
        """
        new_token_ids = output_ids[prompt_length:]
        return self.tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    def extract_activations(self, prompts: List[str], layer_index: Optional[int] = None) -> List[np.ndarray]:
        """Return the last-token hidden state of ``layer_index`` for each prompt.

        Each prompt is run separately, so the result holds one ``(1, hidden)``
        float32 array per prompt, in prompt order.

        Raises:
            ValueError: if ``layer_index`` is outside ``[0, num_layers)``.
        """
        if layer_index is None:
            layer_index = self.target_layer_index

        if layer_index < 0 or layer_index >= self.num_layers:
            raise ValueError(f"Layer index must be between 0 and {self.num_layers-1}")

        target_module = self.model.model.layers[layer_index]
        activations = []

        def hook_fn(module, input, output):
            # Decoder layers may return a tuple whose first item is the hidden states.
            hidden_states = output[0] if isinstance(output, tuple) else output
            last_token_activations = hidden_states[:, -1, :]
            activations.append(self._to_numpy(last_token_activations))

        hook = target_module.register_forward_hook(hook_fn)

        try:
            for prompt in prompts:
                inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
                with torch.no_grad():
                    self.model(**inputs)
        finally:
            hook.remove()

        return activations

    def generate_original(self, prompt: str, max_length: int = 700, temperature: float = 0.2) -> str:
        """Sample an unsteered continuation of ``prompt``.

        ``max_length`` is the number of tokens allowed beyond the prompt (it is
        added to the prompt length before being passed to ``generate``).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_length = inputs.input_ids.size(1)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=prompt_length + max_length,
                temperature=temperature,
                do_sample=True,
                top_p=0.5,
                repetition_penalty=1.1
            )

        return self._decode_response(outputs[0], prompt_length)

    def generate_transformed(self, prompt: str, max_length: int = 700, temperature: float = 0.2) -> str:
        """Sample a continuation while applying ``self.transport_map`` at the target layer.

        Raises:
            ValueError: if no transport map has been set.
        """
        if self.transport_map is None:
            raise ValueError("Transport map has not been fitted. Call fit_transport_map first.")

        target_layer = self.model.model.layers[self.target_layer_index]
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask
        prompt_length = input_ids.shape[1]

        def transform_hook(module, input_tuple, output):
            hidden_states = output[0] if isinstance(output, tuple) else output
            batch_size, seq_len, hidden_dim = hidden_states.shape

            # Only steer passes shorter than the prompt, i.e. (presumably, with the
            # KV cache on) single-token decode steps, and leave the prompt pass alone.
            # NOTE(review): a 1-token prompt, or generation with caching disabled,
            # is never steered by this test; confirm that is intended.
            is_generating = (seq_len < prompt_length)

            if is_generating:
                transformed_token = self.transport_map.transform_torch(hidden_states)

                if isinstance(output, tuple):
                    new_output = list(output)
                    new_output[0] = transformed_token
                    return tuple(new_output)
                else:
                    return transformed_token

            return output

        hook = target_layer.register_forward_hook(transform_hook)

        try:
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_length=prompt_length + max_length,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.5,
                    repetition_penalty=1.1
                )

            return self._decode_response(outputs[0], prompt_length)

        finally:
            hook.remove()

    def compare_generations(self, prompt: str, max_length: int = 700, temperature: float = 0.2) -> Dict[str, str]:
        """Generate unsteered and steered responses for ``prompt`` and return both."""
        original = self.generate_original(prompt, max_length, temperature)
        transformed = self.generate_transformed(prompt, max_length, temperature)

        return {
            "prompt": prompt,
            "original_response": original,
            "transformed_response": transformed
        }


def process_prompt_templates(template: str, queries: List[str], preferences: List[str]) -> List[str]:
    """Render ``template`` once for every (query, preference) combination.

    The template is filled via ``str.format`` with the keywords ``query`` and
    ``preference``. Output is query-major: every preference for the first query,
    then every preference for the second query, and so on.
    """
    return [
        template.format(query=q, preference=pref)
        for q in queries
        for pref in preferences
    ]


def cosine_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Return the cosine distance ``1 - cos(x, y)`` between two vectors.

    Both inputs are flattened, so ``(1, d)`` row vectors and ``(d,)`` vectors
    can be compared directly and a scalar is always returned. The computation
    is done in float64, so half-precision activations with large magnitudes
    do not overflow in the dot product or the norms.

    Returns:
        A float in ``[0, 2]``, or ``nan`` when either vector is all zeros
        (cosine similarity is undefined there).
    """
    x_vec = np.asarray(x, dtype=np.float64).ravel()
    y_vec = np.asarray(y, dtype=np.float64).ravel()
    denom = np.linalg.norm(x_vec) * np.linalg.norm(y_vec)
    if denom == 0:
        return float("nan")
    return float(1.0 - np.dot(x_vec, y_vec) / denom)


def find_cosine_threshold(P_activations: List[np.ndarray], Q_activations: List[np.ndarray], quantile: float) -> float:
    """Return a quantile of the cosine distances from each P activation to the mean Q activation.

    Every activation is squeezed (dropping size-1 axes) before use. The Q
    activations are stacked and averaged into a single centroid. ``quantile``
    is passed straight to ``np.quantile``.
    """
    q_centroid = np.stack([act.squeeze() for act in Q_activations]).mean(axis=0)

    p_distances = []
    for act in P_activations:
        p_distances.append(cosine_distance(act.squeeze(), q_centroid))

    return np.quantile(p_distances, quantile)


def compute_transport_parameters(model: LlamaModel, p_prompts: List[str], q_prompts: List[str],
                               layer_index: Optional[int] = None,
                               quantile: float = 0.99) -> Dict[str, Any]:
    """Extract P/Q activations at one layer and summarise them for a transport map.

    Args:
        model: Model whose ``extract_activations`` is used.
        p_prompts: Prompts representing the desired behaviour (P).
        q_prompts: Prompts representing the undesired behaviour (Q).
        layer_index: Layer to read; defaults to ``model.target_layer_index``.
        quantile: Quantile of the P-to-mean(Q) cosine distances stored as
            ``"threshold"``. Defaults to 0.99.

    Returns:
        Dict with the P/Q means (``mu_p``, ``mu_q``) and covariances
        (``cov_p``, ``cov_q``), the ``threshold``, and the raw activation lists
        (``P_activations``, ``Q_activations``).

    Raises:
        ValueError: if either prompt list is empty. This is checked before any
            model forward pass so no compute is wasted.
    """
    if not p_prompts or not q_prompts:
        raise ValueError("Both p_prompts and q_prompts must be non-empty")

    if layer_index is None:
        layer_index = model.target_layer_index

    P_activations = model.extract_activations(p_prompts, layer_index)
    Q_activations = model.extract_activations(q_prompts, layer_index)

    P_stack = np.vstack(P_activations)
    Q_stack = np.vstack(Q_activations)

    mu_p = np.mean(P_stack, axis=0)
    mu_q = np.mean(Q_stack, axis=0)
    cov_p = np.cov(P_stack, rowvar=False)
    cov_q = np.cov(Q_stack, rowvar=False)
    threshold = find_cosine_threshold(P_activations, Q_activations, quantile)

    return {
        "mu_p": mu_p,
        "mu_q": mu_q,
        "cov_p": cov_p,
        "cov_q": cov_q,
        "threshold": threshold,
        "P_activations": P_activations,
        "Q_activations": Q_activations
    }


def compute_separate_transport_maps(model: LlamaModel, violence_p_prompts: List[str], violence_q_prompts: List[str],
                                  politics_p_prompts: List[str], politics_q_prompts: List[str],
                                  layer_index: Optional[int] = None) -> Dict[str, Dict[str, Any]]:
    """Compute transport statistics for the violence and politics domains at one layer.

    Returns a dict keyed by ``"violence"`` and ``"politics"``. Each value is the
    result of ``compute_transport_parameters`` for that domain's P/Q prompts.
    Progress and per-domain activation counts are printed.
    """
    target_layer = model.target_layer_index if layer_index is None else layer_index

    print(f"Extracting activations from layer {target_layer}...")

    # Dict literals evaluate left to right: violence is computed before politics.
    per_domain = {
        "violence": compute_transport_parameters(model, violence_p_prompts, violence_q_prompts, target_layer),
        "politics": compute_transport_parameters(model, politics_p_prompts, politics_q_prompts, target_layer),
    }

    for label, key in (("Violence", "violence"), ("Politics", "politics")):
        stats = per_domain[key]
        print(f"{label}: {len(stats['P_activations'])} desired, {len(stats['Q_activations'])} undesired")

    return per_domain


def create_transport_map(params: Dict[str, Any], alpha: float, beta: float) -> AffineTransportMap:
    """Build an already-fitted AffineTransportMap from precomputed statistics.

    Copies ``mu_p``, ``mu_q``, ``cov_p`` and ``cov_q`` from ``params`` onto a new
    map with the given ``alpha``/``beta`` and marks it fitted, without refitting.
    """
    tmap = AffineTransportMap(alpha=alpha, beta=beta)
    for attr in ("mu_p", "mu_q", "cov_p", "cov_q"):
        setattr(tmap, attr, params[attr])
    tmap.fitted = True
    return tmap


def calculate_dynamic_beta(distance: float, threshold: float, base_threshold: float = 0.7) -> float:
    """Return the push strength for a given cosine distance.

    When ``distance`` exceeds ``threshold`` the strength is 0. Otherwise it is
    ``base_threshold - distance``, clamped below at 0.
    """
    if distance > threshold:
        return 0
    headroom = base_threshold - distance
    return max(0, headroom)


def create_dynamic_averaged_transport_map(model: LlamaModel, test_prompt: str,
                                        pre_computed_params: Dict[str, Dict[str, Any]],
                                        alpha: float = 1.0) -> AveragedTransportMap:
    """Build a prompt-specific averaged transport map from the violence and politics maps.

    The steps are:

    1. Read the test prompt's activation at ``model.target_layer_index``.
    2. For each domain, measure its cosine distance to that domain's ``mu_q``.
    3. Turn each distance into a beta via ``calculate_dynamic_beta``, using base
       thresholds 0.7 (violence) and 0.65 (politics).
    4. Weight the per-domain maps by their share of the total beta (equal
       weights when every beta is 0).

    Distances, thresholds, betas and weights are printed.
    """
    activation = model.extract_activations([test_prompt], model.target_layer_index)[0]
    query_vec = activation.reshape(-1)

    # (display label, key into pre_computed_params, base threshold for the beta)
    domain_specs = [("Violence", "violence", 0.7), ("Politics", "politics", 0.65)]
    domain_params = [pre_computed_params[key] for _, key, _ in domain_specs]

    distances = [cosine_distance(query_vec, params["mu_q"]) for params in domain_params]
    thresholds = [params["threshold"] for params in domain_params]
    betas = [
        calculate_dynamic_beta(dist, thr, base)
        for dist, thr, (_, _, base) in zip(distances, thresholds, domain_specs)
    ]

    for (label, _, _), dist, thr, beta in zip(domain_specs, distances, thresholds, betas):
        print(f"{label} - distance: {dist:.4f}, threshold: {thr:.4f}, beta: {beta:.4f}")

    component_maps = [create_transport_map(params, alpha, beta) for params, beta in zip(domain_params, betas)]

    beta_sum = betas[0] + betas[1]
    weights = [b / beta_sum for b in betas] if beta_sum > 0 else [0.5, 0.5]

    print(f"Transport map weights: Violence={weights[0]:.2f}, Politics={weights[1]:.2f}")

    return AveragedTransportMap(component_maps, weights)


def load_data(file_path: str) -> List[str]:
    """Load and return the JSON content of ``file_path``.

    The file is read as UTF-8 explicitly, so non-ASCII prompts decode the same
    regardless of the platform's default locale encoding. The caller presumably
    expects a JSON list of strings; the content is not validated here.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)


def main():
    """Fit violence/politics steering statistics and compare steered vs. unsteered outputs.

    Reads query lists from ``./p_prompts/violence.json`` and
    ``./p_prompts/politics.json``. Desired (P) and undesired (Q) prompts are
    built by appending preference instructions. For each test prompt, a
    prompt-specific averaged transport map is installed and both generations
    are printed.
    """
    set_seed(42)
    llama_model = LlamaModel("path/to/llama/model")

    template = "{query} {preference}"

    # Both query files live in the same ./p_prompts/ directory.
    violence_queries = load_data("./p_prompts/violence.json")
    politics_queries = load_data("./p_prompts/politics.json")

    violence_p_preferences = ["Please do not mention any content related to violence."]
    violence_q_preferences = ["Be specific and explicit about violent content."]
    politics_p_preferences = ["Please do not mention any content related to political ideologies."]
    politics_q_preferences = ["Be specific and explicit about political content."]

    violence_p_prompts = process_prompt_templates(template, violence_queries, violence_p_preferences)
    violence_q_prompts = process_prompt_templates(template, violence_queries, violence_q_preferences)
    politics_p_prompts = process_prompt_templates(template, politics_queries, politics_p_preferences)
    politics_q_prompts = process_prompt_templates(template, politics_queries, politics_q_preferences)

    print(f"Violence prompts - Desired: {len(violence_p_prompts)}, Undesired: {len(violence_q_prompts)}")
    print(f"Politics prompts - Desired: {len(politics_p_prompts)}, Undesired: {len(politics_q_prompts)}")

    pre_computed_params = compute_separate_transport_maps(
        model=llama_model,
        violence_p_prompts=violence_p_prompts,
        violence_q_prompts=violence_q_prompts,
        politics_p_prompts=politics_p_prompts,
        politics_q_prompts=politics_q_prompts
    )

    test_prompts = [
        "What are some unique eras of Hitler's reign in Germany? Please do not mention explicit content about violence, political ideologies."
    ]

    for prompt in tqdm(test_prompts):
        # The averaged map depends on the prompt's own activation, so rebuild it per prompt.
        averaged_map = create_dynamic_averaged_transport_map(
            model=llama_model,
            test_prompt=prompt,
            pre_computed_params=pre_computed_params,
            alpha=1.0
        )

        llama_model.transport_map = averaged_map

        comparison = llama_model.compare_generations(prompt)
        print("="*80)
        print("Original response:", comparison["original_response"])
        print("="*80)
        print("Transformed response:", comparison["transformed_response"])


if __name__ == "__main__":
    main()