#!/usr/bin/env python3

import json
import numpy as np
from pathlib import Path
from typing import Dict, Tuple
from scipy.linalg import svd
import logging

# Configure the root logger at import time so INFO-level progress messages
# from this script are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module; used by every function below.
logger = logging.getLogger(__name__)

class ScramblingExperiments:
    """Cross-model alignment experiments with shuffled-pairing controls.

    For every ordered (source, target) model pair, vectors are loaded per
    trait, paired row by row, optionally shuffled in the target, split 80/20
    and aligned with PCA followed by an orthogonal Procrustes fit.  The score
    is the mean cosine similarity between aligned source test vectors and
    target test vectors.
    """

    # Upper bound on the number of rows kept from each (model, trait) file.
    MAX_VECTORS_PER_TRAIT = 2500

    def __init__(self, base_dir='.'):
        """Set up paths and experiment constants.

        Args:
            base_dir: Directory containing the ``vectors/raw`` folder.
                Defaults to the current working directory.
        """
        self.base_dir = Path(base_dir)
        self.vectors_dir = self.base_dir / 'vectors' / 'raw'

        self.models = ['gemma', 'llama3', 'mistral']
        self.traits = [
            'accessibility', 'assertiveness', 'authority', 'certainty', 'clarity',
            'concreteness', 'creativity', 'directness', 'emotional_tone', 'empathy',
            'enthusiasm', 'formality', 'hedging', 'humor', 'inclusivity',
            'objectivity', 'optimism', 'persuasiveness', 'politeness', 'precision',
            'professionalism', 'register', 'specificity', 'technical_complexity',
            'urgency', 'verbosity'
        ]

        # Maximum number of PCA components kept before the Procrustes fit.
        self.k_dims = 1300
        self.n_runs = 5
        # One experiment run is performed per seed.
        self.seeds = [42, 123, 456, 789, 1011]

    def load_vectors(self, model: str, trait: str) -> np.ndarray:
        """Load the vectors stored for one (model, trait) combination.

        Args:
            model: Model name used as the file-name prefix.
            trait: Trait name used in the file name.

        Returns:
            The loaded array truncated to ``MAX_VECTORS_PER_TRAIT`` rows, or
            an empty array when the file is missing or cannot be read.
        """
        file_path = self.vectors_dir / f"{model}_{trait}_vectors_webscale.npy"

        if not file_path.exists():
            return np.array([])

        try:
            vectors = np.load(file_path)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole sweep.
            logger.error(f"Error loading {file_path}: {e}")
            return np.array([])

        # Slicing past the end is a no-op, so no explicit length check is needed.
        return vectors[:self.MAX_VECTORS_PER_TRAIT]

    def prepare_data_with_shuffling(self, source: str, target: str, shuffle_type: str, seed: int) -> Tuple:
        """Build paired train/test matrices, optionally shuffling the target.

        Args:
            source: Name of the source model.
            target: Name of the target model.
            shuffle_type: ``'within_trait'`` permutes target rows inside each
                trait block, ``'cross_trait'`` permutes all target rows; any
                other value (e.g. ``'none'``) keeps the original pairing.
            seed: Seed for the shuffles and for the train/test split.

        Returns:
            ``(X_train, Y_train, X_test, Y_test)``, or four ``None`` values
            when no trait has data for both models.
        """
        # A local legacy generator yields the same stream as
        # ``np.random.seed(seed)`` followed by the module-level functions,
        # but leaves the global NumPy random state untouched.
        rng = np.random.RandomState(seed)

        source_vectors = []
        target_vectors = []
        trait_boundaries = []

        start_idx = 0
        for trait in self.traits:
            source_trait_vecs = self.load_vectors(source, trait)
            target_trait_vecs = self.load_vectors(target, trait)

            # Pair rows one-to-one; skip traits missing on either side.
            min_count = min(len(source_trait_vecs), len(target_trait_vecs))
            if min_count == 0:
                continue

            source_vectors.append(source_trait_vecs[:min_count])
            target_vectors.append(target_trait_vecs[:min_count])
            trait_boundaries.append((start_idx, start_idx + min_count))
            start_idx += min_count

        if not source_vectors:
            return None, None, None, None

        X = np.vstack(source_vectors)
        Y = np.vstack(target_vectors)

        if shuffle_type == 'within_trait':
            Y_shuffled = Y.copy()
            for trait_start, trait_end in trait_boundaries:
                trait_indices = np.arange(trait_start, trait_end)
                rng.shuffle(trait_indices)
                Y_shuffled[trait_start:trait_end] = Y[trait_indices]
            Y = Y_shuffled
        elif shuffle_type == 'cross_trait':
            indices = np.arange(len(Y))
            rng.shuffle(indices)
            Y = Y[indices]

        n_samples = len(X)
        n_train = int(0.8 * n_samples)

        indices = rng.permutation(n_samples)
        train_idx = indices[:n_train]
        test_idx = indices[n_train:]

        # Keep only the leading columns shared by both models.
        min_dim = min(X.shape[1], Y.shape[1])
        X = X[:, :min_dim]
        Y = Y[:, :min_dim]

        return X[train_idx], Y[train_idx], X[test_idx], Y[test_idx]

    def similarity_procrustes(self, X_train, Y_train, X_test, Y_test) -> float:
        """Score alignment of X onto Y via PCA + orthogonal Procrustes.

        Both training sets are centered and projected onto their own top-k
        principal directions.  A rotation ``R`` and scale ``s`` minimising
        ``||s * X R - Y||_F`` are fitted on the training scores and then
        applied to the projected test vectors.

        Args:
            X_train: Source training matrix, one sample per row.
            Y_train: Target training matrix, row-aligned with ``X_train``.
            X_test: Source test matrix.
            Y_test: Target test matrix, row-aligned with ``X_test``.

        Returns:
            Mean cosine similarity over test rows with non-zero norm on both
            sides, as a built-in ``float``; ``0.0`` when it cannot be computed.
        """
        k = min(self.k_dims, X_train.shape[1], Y_train.shape[1], X_train.shape[0] - 1)
        if k < 1 or len(X_test) == 0:
            return 0.0

        X_mean = np.mean(X_train, axis=0)
        X_centered = X_train - X_mean
        _, _, Vt_x = svd(X_centered, full_matrices=False)
        X_components = Vt_x[:k]
        X_train_pca = X_centered @ X_components.T

        Y_mean = np.mean(Y_train, axis=0)
        Y_centered = Y_train - Y_mean
        _, _, Vt_y = svd(Y_centered, full_matrices=False)
        Y_components = Vt_y[:k]
        Y_train_pca = Y_centered @ Y_components.T

        # The scores are projections of centered data, so they are already
        # zero-mean and need no second centering step.
        #
        # With X^T Y = U diag(Sigma) Vt, the rotation minimising ||X R - Y||_F
        # is R = U Vt.  Using Y^T X instead would produce R transposed, which
        # is wrong for the row-vector product ``X @ R`` applied below.
        M = X_train_pca.T @ Y_train_pca
        U, Sigma, Vt = svd(M, full_matrices=False)
        R = U @ Vt
        denom = np.sum(X_train_pca * X_train_pca)
        # A zero denominator means the source scores carry no variance.
        s = np.sum(Sigma) / denom if denom > 0 else 0.0

        X_test_pca = (X_test - X_mean) @ X_components.T
        Y_test_pca = (Y_test - Y_mean) @ Y_components.T
        X_test_aligned = s * (X_test_pca @ R)

        norm_x = np.linalg.norm(X_test_aligned, axis=1)
        norm_y = np.linalg.norm(Y_test_pca, axis=1)
        # Rows with a zero vector on either side have no defined cosine.
        valid = (norm_x > 0) & (norm_y > 0)
        if not valid.any():
            return 0.0

        dots = np.sum(X_test_aligned[valid] * Y_test_pca[valid], axis=1)
        cosines = dots / (norm_x[valid] * norm_y[valid])
        # Built-in float keeps the result JSON-serialisable for float32 input.
        return float(np.mean(cosines))

    def run_scrambling_experiment(self, source: str, target: str, shuffle_type: str) -> Dict:
        """Run one shuffle condition for a model pair across all seeds.

        Args:
            source: Name of the source model.
            target: Name of the target model.
            shuffle_type: Shuffle mode passed to
                ``prepare_data_with_shuffling``.

        Returns:
            Dict with ``'mean'`` and ``'std'`` of the per-seed scores and
            ``'n_runs'``, the number of seeds that produced data.  All values
            are built-in Python numbers.
        """
        results = []

        for seed in self.seeds:
            X_train, Y_train, X_test, Y_test = self.prepare_data_with_shuffling(
                source, target, shuffle_type, seed
            )

            if X_train is None:
                continue

            results.append(self.similarity_procrustes(X_train, Y_train, X_test, Y_test))

        if not results:
            return {'mean': 0.0, 'std': 0.0, 'n_runs': 0}

        return {
            'mean': float(np.mean(results)),
            'std': float(np.std(results)),
            'n_runs': len(results)
        }

    def run_all_experiments(self) -> Dict:
        """Run every shuffle condition for every ordered pair of models.

        Returns:
            Dict with per-pair results under ``'proper_pairing'``,
            ``'within_trait'`` and ``'cross_trait'``, plus a ``'summary'``
            entry holding the mean of each condition and the ratio of the
            proper-pairing mean to the within-trait mean.
        """
        # All ordered pairs of distinct models, in the order of self.models.
        model_pairs = [
            (source, target)
            for source in self.models
            for target in self.models
            if source != target
        ]

        results = {
            'proper_pairing': {},
            'within_trait': {},
            'cross_trait': {}
        }

        for source, target in model_pairs:
            pair_key = f"{source}_to_{target}"
            logger.info(f"Processing {pair_key}")

            proper_result = self.run_scrambling_experiment(source, target, 'none')
            results['proper_pairing'][pair_key] = proper_result

            within_result = self.run_scrambling_experiment(source, target, 'within_trait')
            results['within_trait'][pair_key] = within_result

            cross_result = self.run_scrambling_experiment(source, target, 'cross_trait')
            results['cross_trait'][pair_key] = cross_result

            logger.info(f"  Proper: {proper_result['mean']:.4f}")
            logger.info(f"  Within: {within_result['mean']:.4f}")
            logger.info(f"  Cross: {cross_result['mean']:.4f}")

        overall_proper = float(np.mean([r['mean'] for r in results['proper_pairing'].values()]))
        overall_within = float(np.mean([r['mean'] for r in results['within_trait'].values()]))
        overall_cross = float(np.mean([r['mean'] for r in results['cross_trait'].values()]))

        results['summary'] = {
            'proper_pairing_mean': overall_proper,
            'within_trait_mean': overall_within,
            'cross_trait_mean': overall_cross,
            # The ratio is only reported when the within-trait mean is positive.
            'improvement_factor': overall_proper / overall_within if overall_within > 0 else 0
        }

        return results

def main():
    """Run all scrambling experiments, save them as JSON and log a summary.

    Results are written to ``results/scrambling_results.json`` relative to
    the current working directory; the parent directory is created if needed.

    Returns:
        The results dictionary produced by
        ``ScramblingExperiments.run_all_experiments``.
    """
    logger.info("Running scrambling hierarchy experiments")

    runner = ScramblingExperiments()
    results = runner.run_all_experiments()

    def _to_builtin(value):
        # json cannot encode NumPy scalars such as np.float32 (produced when
        # the vectors are stored in single precision) or NumPy arrays.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    output_file = Path('results/scrambling_results.json')
    output_file.parent.mkdir(exist_ok=True, parents=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=_to_builtin)

    summary = results['summary']
    logger.info(f"Results saved to {output_file}")
    logger.info("Summary:")
    logger.info(f"  Proper pairing: {summary['proper_pairing_mean']:.4f}")
    logger.info(f"  Within-trait: {summary['within_trait_mean']:.4f}")
    logger.info(f"  Cross-trait: {summary['cross_trait_mean']:.4f}")

    return results

# Run the experiments only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()