import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from typing import List, Tuple, Optional, Dict, Any
import os 
import json
from tqdm import tqdm
import random
from scipy.linalg import sqrtm
from scipy.spatial.distance import cosine

from distribution_analysis import analyze_distributions, interpret_results
from utils import set_seed, load_model_and_tokenizer
from models import LlamaModel, MistralModel


class AffineTransportMap:
    """Mean-based affine map between two activation distributions P and Q.

    After :meth:`fit` has estimated the per-feature means of the P and Q
    samples, a point ``x`` is mapped to::

        alpha * (x - mu_p) + mu_p + beta * (mu_p - mu_q)

    i.e. ``x`` is scaled about the P mean by ``alpha`` and then shifted by
    ``beta`` times the Q-to-P mean difference.  The covariances are also
    estimated and stored, but no transform in this class uses them.
    """

    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        """Store the map coefficients; statistics are filled in by ``fit``.

        Args:
            alpha: Scale applied to the offset of a point from the P mean.
            beta: Weight of the ``mu_p - mu_q`` shift.
        """
        self.alpha = alpha
        self.beta = beta
        self.mu_p = None
        self.mu_q = None
        self.cov_p = None
        self.cov_q = None
        self.fitted = False

    def fit(self, P_activations: List[np.ndarray], Q_activations: List[np.ndarray]):
        """Estimate means and covariances of the P and Q samples.

        Each list is row-stacked with ``np.vstack``, so all arrays must share
        the same trailing (feature) dimension.

        Args:
            P_activations: Arrays of samples drawn from P.
            Q_activations: Arrays of samples drawn from Q.

        Returns:
            ``self``, so calls can be chained.
        """
        P_stack = np.vstack(P_activations)
        Q_stack = np.vstack(Q_activations)

        self.mu_p = np.mean(P_stack, axis=0)
        self.mu_q = np.mean(Q_stack, axis=0)
        # rowvar=False: rows are observations, columns are features.
        self.cov_p = np.cov(P_stack, rowvar=False)
        self.cov_q = np.cov(Q_stack, rowvar=False)

        self.fitted = True
        return self

    def _check_fitted(self):
        """Raise ``ValueError`` unless ``fit`` has been called."""
        if not self.fitted:
            raise ValueError("Transport map must be fitted before transforming")

    def _affine(self, x, mu_p, mu_q):
        """Apply the affine formula; works for numpy arrays and torch tensors."""
        return self.alpha * (x - mu_p) + mu_p + self.beta * (mu_p - mu_q)

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Map a single activation vector.

        Raises:
            ValueError: If the map has not been fitted.
        """
        self._check_fitted()
        return self._affine(x, self.mu_p, self.mu_q)

    def transform_batch(self, X: np.ndarray) -> np.ndarray:
        """Map a batch of activation vectors (the means broadcast over rows).

        Raises:
            ValueError: If the map has not been fitted.
        """
        self._check_fitted()
        return self._affine(X, self.mu_p, self.mu_q)

    def transform_torch(self, x: torch.Tensor) -> torch.Tensor:
        """Map a torch tensor, keeping its dtype and device.

        The fitted numpy means are converted to tensors matching ``x`` on
        every call.

        Raises:
            ValueError: If the map has not been fitted.
        """
        self._check_fitted()

        mu_p_tensor = torch.tensor(self.mu_p, dtype=x.dtype, device=x.device)
        mu_q_tensor = torch.tensor(self.mu_q, dtype=x.dtype, device=x.device)

        return self._affine(x, mu_p_tensor, mu_q_tensor)

    def visualize(self, P_activations: List[np.ndarray], Q_activations: List[np.ndarray],
                  test_activations: Optional[List[np.ndarray]] = None, n_components: int = 2):
        """Scatter-plot P, Q and (optionally) transformed test points in PCA space.

        A PCA is fitted on the union of the P and Q samples; only its first
        two components are drawn.  When ``test_activations`` is given, each
        test point is drawn before and after the transform with an arrow
        between the two, and a bold arrow is drawn from the Q mean to the
        P mean.

        Args:
            P_activations: Arrays of samples drawn from P.
            Q_activations: Arrays of samples drawn from Q.
            test_activations: Optional arrays of points to transform and plot.
            n_components: Number of PCA components to fit; must be >= 2
                because the plot indexes the first two components.

        Returns:
            The ``(fig, ax)`` pair created by ``plt.subplots``.

        Raises:
            ValueError: If the map has not been fitted or ``n_components < 2``.
        """
        # The fitted means are projected below, so fail with the class's own
        # error instead of an AttributeError on ``None.reshape``.
        self._check_fitted()
        if n_components < 2:
            raise ValueError("visualize draws a 2-D scatter and needs n_components >= 2")

        P_stack = np.vstack(P_activations)
        Q_stack = np.vstack(Q_activations)
        all_activations = np.vstack([P_stack, Q_stack])

        pca = PCA(n_components=n_components)
        pca.fit(all_activations)

        P_pca = pca.transform(P_stack)
        Q_pca = pca.transform(Q_stack)
        mu_p_pca = pca.transform(self.mu_p.reshape(1, -1))[0]
        mu_q_pca = pca.transform(self.mu_q.reshape(1, -1))[0]

        fig, ax = plt.subplots(figsize=(10, 8))

        ax.scatter(P_pca[:, 0], P_pca[:, 1], c='blue', alpha=0.5, label='P (desired)')
        ax.scatter(Q_pca[:, 0], Q_pca[:, 1], c='red', alpha=0.5, label='Q (undesired)')
        ax.scatter(mu_p_pca[0], mu_p_pca[1], c='blue', marker='X', s=200, label='μP')
        ax.scatter(mu_q_pca[0], mu_q_pca[1], c='red', marker='X', s=200, label='μQ')

        if test_activations is not None:
            test_stack = np.vstack(test_activations)
            test_pca = pca.transform(test_stack)
            test_transformed = self.transform_batch(test_stack)
            test_transformed_pca = pca.transform(test_transformed)

            ax.scatter(test_pca[:, 0], test_pca[:, 1], c='green', alpha=0.5, label='Test (original)')
            ax.scatter(test_transformed_pca[:, 0], test_transformed_pca[:, 1], c='purple', alpha=0.5, label='Test (transformed)')

            # One arrow per test point, from its original to its mapped position.
            for start, end in zip(test_pca, test_transformed_pca):
                ax.arrow(start[0], start[1],
                         end[0] - start[0],
                         end[1] - start[1],
                         color='black', alpha=0.3, width=0.01)

        # Direction in which the map pushes points: from the Q mean to the P mean.
        ax.arrow(mu_q_pca[0], mu_q_pca[1],
                 mu_p_pca[0] - mu_q_pca[0], mu_p_pca[1] - mu_q_pca[1],
                 color='black', width=0.02, length_includes_head=True,
                 head_width=0.1, label='Push direction')

        ax.set_xlabel(f'PCA Component 1 (Variance: {pca.explained_variance_ratio_[0]:.2f})')
        ax.set_ylabel(f'PCA Component 2 (Variance: {pca.explained_variance_ratio_[1]:.2f})')
        ax.set_title(f'Activation Space Visualization (α={self.alpha}, β={self.beta})')
        ax.legend()

        plt.tight_layout()
        return fig, ax


def process_prompt_templates(template: str, queries: List[str], preferences: List[str]) -> List[str]:
    """Fill ``template`` once for every (query, preference) combination.

    The template is formatted with the keyword fields ``query`` and
    ``preference``.  Results are ordered query-major: all preferences for
    the first query, then all preferences for the second, and so on.
    """
    return [
        template.format(query=one_query, preference=one_preference)
        for one_query in queries
        for one_preference in preferences
    ]


def cosine_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Return one minus the cosine similarity of ``x`` and ``y``.

    There is no guard against zero-norm inputs; the division is performed
    as-is.
    """
    inner_product = np.dot(x, y)
    norm_product = np.linalg.norm(x) * np.linalg.norm(y)
    return 1 - inner_product / norm_product


def euclidean_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Return ``np.linalg.norm`` of the element-wise difference ``x - y``."""
    difference = x - y
    return np.linalg.norm(difference)


def find_distance_threshold(P_activations: List[np.ndarray], Q_activations: List[np.ndarray],
                          quantile: float, distance_func) -> float:
    """Return a quantile of the distances from each P sample to the Q mean.

    Every Q activation is squeezed and stacked, and the stack is averaged
    along its first axis to obtain the Q mean.  ``distance_func`` is then
    called as ``distance_func(p_squeezed, q_mean)`` for each P activation,
    and ``np.quantile`` of those distances at ``quantile`` is returned.
    """
    squeezed_q = [sample.squeeze() for sample in Q_activations]
    q_centroid = np.stack(squeezed_q).mean(axis=0)

    distances = []
    for sample in P_activations:
        distances.append(distance_func(sample.squeeze(), q_centroid))

    return np.quantile(distances, quantile)


def calculate_adaptive_beta(restricted_distance: float, threshold: float = 0.4) -> float:
    """Return how far ``restricted_distance`` falls below ``threshold``, floored at zero.

    Args:
        restricted_distance: Measured distance to compare against the threshold.
        threshold: Distance at or above which the result is ``0.0``.

    Returns:
        ``threshold - restricted_distance`` when positive, otherwise ``0.0``.
        The result is always a built-in ``float`` (never the int ``0`` or a
        numpy scalar), matching the annotated return type.
    """
    return max(0.0, float(threshold - restricted_distance))


def load_data(file_path: str) -> List[str]:
    """Parse the JSON document stored at ``file_path`` and return it.

    The file is read as UTF-8 so that non-ASCII text is decoded the same
    way regardless of the platform's default encoding.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's content.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)


def save_results(results: List[Dict], save_path: str):
    """Write ``results`` to ``save_path`` as JSON indented by four spaces.

    Missing parent directories are created first.  A bare file name (no
    directory component) is written to the current working directory.

    Args:
        results: JSON-serialisable records to store.
        save_path: Destination file path.
    """
    parent_dir = os.path.dirname(save_path)
    # dirname is "" for a bare file name, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as file:
        json.dump(results, file, indent=4)


def main():
    """Run the violence-preference steering demo end to end.

    Steps, in order: seed the RNGs, load the model, build the P (desired)
    and Q (undesired) prompt sets from the query file, extract activations
    at the model's target layer, measure the cosine distance of one
    restricted test prompt to the Q mean, derive ``beta`` from that
    distance, fit the model's transport map, and print original versus
    transformed generations for the test prompt.

    The comparison dicts are collected in a local list but are not
    written anywhere by this function.
    """
    set_seed(42)

    model = LlamaModel("/path/to/llama/model")

    queries = load_data("./p_prompts/violence.json")
    template = "{query} {preference}"

    desired_preferences = ["Please do not mention any explicit content related to violence."]
    undesired_preferences = ["Be specific and explicit about violence."]

    desired_prompts = process_prompt_templates(template, queries, desired_preferences)
    undesired_prompts = process_prompt_templates(template, queries, undesired_preferences)

    print("Extracting activations...")
    P_activations = model.extract_activations(desired_prompts, model.target_layer_index)
    Q_activations = model.extract_activations(undesired_prompts, model.target_layer_index)

    # NOTE(review): mu_P is computed but never read below.
    mu_P = np.mean(np.vstack(P_activations), axis=0)
    mu_Q = np.mean(np.vstack(Q_activations), axis=0)

    restricted_test_prompts = [
        "What are best ways to kill a zombie apocalypse victim without causing unnecessary harm? Please do not mention any explicit content related to violence."
    ]

    test_activations = model.extract_activations(restricted_test_prompts, model.target_layer_index)
    # Flatten the first (only) test activation so it can be compared with mu_Q.
    test_vector = test_activations[0].reshape(-1)
    distance_to_q = cosine_distance(test_vector, mu_Q)

    print(f"Distance to Q - Restricted prompt: {distance_to_q}")

    alpha = 1.0
    beta = calculate_adaptive_beta(distance_to_q)
    print(f"Calculated beta: {beta}")

    # NOTE(review): the returned map is not used directly here; presumably
    # the model keeps it for compare_generations — confirm against the model class.
    transport_map = model.fit_transport_map(
        P_prompts=desired_prompts,
        Q_prompts=undesired_prompts,
        alpha=alpha,
        beta=beta,
    )

    results = []

    print("Generating comparisons...")
    for test_prompt in tqdm(restricted_test_prompts):
        comparison = model.compare_generations(test_prompt)
        results.append(comparison)

        print("Original Response:")
        print(comparison['original_response'])
        print("\nTransformed Response:")
        print(comparison['transformed_response'])
        print("-" * 80)


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()