#!/usr/bin/env python3

import json
import numpy as np
from pathlib import Path
from typing import Dict, Tuple
from scipy.linalg import svd
import logging
from collections import defaultdict

# Configure the root logger so INFO-level (and more severe) records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by PerTraitTransferAnalysis and main() below.
logger = logging.getLogger(__name__)

class PerTraitTransferAnalysis:
    """Per-trait evaluation of cross-model vector alignment.

    For every ordered pair of models, the vectors stored for each trait under
    ``<base_dir>/vectors/raw`` are loaded, paired row-by-row, split into
    train/test sets, reduced with PCA (separately per model), aligned with a
    scaled orthogonal Procrustes fit, and scored with the cosine similarity
    between aligned source rows and target rows on the held-out set.

    Args:
        base_dir: Directory that contains ``vectors/raw``.
        k_dims: Upper bound on the number of PCA components kept per model.
        seed: Seed applied to NumPy's *global* RNG when the object is built.
        max_vectors_per_trait: At most this many leading rows are used from
            each vector file.
        train_fraction: Fraction of the stacked rows used for fitting; the
            remainder is the test set.
    """

    def __init__(
        self,
        base_dir='.',
        k_dims: int = 1300,
        seed: int = 42,
        max_vectors_per_trait: int = 2500,
        train_fraction: float = 0.8,
    ):
        self.base_dir = Path(base_dir)
        self.vectors_dir = self.base_dir / 'vectors' / 'raw'

        self.models = ['gemma', 'llama3', 'mistral']
        self.traits = [
            'accessibility', 'assertiveness', 'authority', 'certainty', 'clarity',
            'concreteness', 'creativity', 'directness', 'emotional_tone', 'empathy',
            'enthusiasm', 'formality', 'hedging', 'humor', 'inclusivity',
            'objectivity', 'optimism', 'persuasiveness', 'politeness', 'precision',
            'professionalism', 'register', 'specificity', 'technical_complexity',
            'urgency', 'verbosity'
        ]

        # Traits reported with source 'HF' in the summary table; every other
        # trait is reported as 'Manual'.
        self.hf_traits = [
            'accessibility', 'assertiveness', 'authority', 'clarity', 'directness',
            'emotional_tone', 'enthusiasm', 'formality', 'inclusivity', 'objectivity',
            'optimism', 'professionalism', 'register', 'specificity', 'verbosity'
        ]

        self.k_dims = k_dims
        self.seed = seed
        self.max_vectors_per_trait = max_vectors_per_trait
        self.train_fraction = train_fraction

        # NOTE: this seeds NumPy's global RNG, which prepare_data() consumes via
        # np.random.shuffle. Splits are therefore reproducible only for a fixed
        # sequence of prepare_data() calls after construction.
        np.random.seed(self.seed)

    @staticmethod
    def _safe_mean(values) -> float:
        """Return the mean of *values* as a plain float, or 0.0 when empty."""
        values = list(values)
        if not values:
            return 0.0
        return float(np.mean(values))

    def load_vectors(self, model: str, trait: str) -> np.ndarray:
        """Load the stored vectors for one (model, trait) combination.

        Args:
            model: Model name used as the file-name prefix.
            trait: Trait name used in the file name.

        Returns:
            The first ``max_vectors_per_trait`` rows of the 2-D array stored in
            ``<vectors_dir>/<model>_<trait>_vectors_webscale.npy``, or an empty
            array when the file is missing, unreadable, or not 2-D.
        """
        file_path = self.vectors_dir / f"{model}_{trait}_vectors_webscale.npy"

        if not file_path.exists():
            # A missing file silently removes the trait from the analysis, so
            # make that visible in the log.
            logger.warning("Vector file not found, skipping: %s", file_path)
            return np.array([])

        try:
            vectors = np.load(file_path)
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort the run.
            logger.error(f"Error loading {file_path}: {e}")
            return np.array([])

        if getattr(vectors, 'ndim', None) != 2:
            # Downstream code stacks rows and indexes shape[1]; anything that is
            # not a (rows, dims) matrix cannot be used.
            logger.error("Expected a 2-D array in %s, skipping", file_path)
            return np.array([])

        return vectors[:self.max_vectors_per_trait]

    def prepare_data(self, source_model: str, target_model: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict]:
        """Build row-paired train/test matrices for one model pair.

        For each trait, the source and target arrays are truncated to the same
        number of rows and stacked. Both stacked matrices are then truncated to
        their common number of columns and split at random (using NumPy's
        global RNG) into train and test rows.

        Args:
            source_model: Name of the model whose vectors are mapped from.
            target_model: Name of the model whose vectors are mapped to.

        Returns:
            ``(X_train, Y_train, X_test, Y_test, test_trait_map)`` where
            ``test_trait_map`` maps a test-row position to its trait name.
            All four arrays are empty and the map is ``{}`` when no trait has
            vectors for both models.
        """
        source_blocks = []
        target_blocks = []
        row_traits = []  # trait name of every stacked row, in stacking order

        for trait in self.traits:
            source_vecs = self.load_vectors(source_model, trait)
            target_vecs = self.load_vectors(target_model, trait)

            if len(source_vecs) == 0 or len(target_vecs) == 0:
                continue

            # Rows are paired by position, so both sides must have equal length.
            n_rows = min(len(source_vecs), len(target_vecs))
            source_blocks.append(source_vecs[:n_rows])
            target_blocks.append(target_vecs[:n_rows])
            row_traits.extend([trait] * n_rows)

        if not source_blocks:
            return np.array([]), np.array([]), np.array([]), np.array([]), {}

        X_full = np.vstack(source_blocks)
        Y_full = np.vstack(target_blocks)

        # Keep only the leading columns both matrices have in common.
        min_dim = min(X_full.shape[1], Y_full.shape[1])
        X_full = X_full[:, :min_dim]
        Y_full = Y_full[:, :min_dim]

        n_samples = len(X_full)
        n_train = int(self.train_fraction * n_samples)

        indices = np.arange(n_samples)
        np.random.shuffle(indices)

        train_indices = indices[:n_train]
        test_indices = indices[n_train:]

        X_train = X_full[train_indices]
        Y_train = Y_full[train_indices]
        X_test = X_full[test_indices]
        Y_test = Y_full[test_indices]

        # Direct lookup through the per-row labels instead of scanning every
        # trait range for every test row.
        test_trait_map = {
            position: row_traits[row_idx]
            for position, row_idx in enumerate(test_indices)
        }

        return X_train, Y_train, X_test, Y_test, test_trait_map

    def similarity_procrustes(self, X_train, Y_train, X_test, Y_test, test_trait_map) -> Dict:
        """Fit PCA + scaled orthogonal Procrustes on train rows, score test rows.

        Args:
            X_train: Source training matrix, shape ``(n_train, d_x)``.
            Y_train: Target training matrix, row-paired with ``X_train``.
            X_test: Source test matrix.
            Y_test: Target test matrix, row-paired with ``X_test``.
            test_trait_map: Maps a test-row position to a trait name; rows
                without an entry are grouped under ``'unknown'``.

        Returns:
            A dict with plain-Python values:

            * ``'per_trait'``: mean test cosine per trait in ``self.traits``
              (0.0 for traits without any scored test row).
            * ``'per_trait_counts'``: number of scored test rows per trait, so
              callers can tell a real 0.0 from "no data".
            * ``'overall_mean'``: mean cosine over all scored test rows
              (0.0 when there are none).
            * ``'scale_factor'``: the fitted Procrustes scale.

        Raises:
            ValueError: If fewer than two training rows are supplied.
        """
        if X_train.shape[0] < 2:
            raise ValueError("similarity_procrustes needs at least 2 training samples")

        # Centered data with n rows has rank at most n - 1; both sides must be
        # able to provide k components.
        k = min(self.k_dims, X_train.shape[1], Y_train.shape[1], X_train.shape[0] - 1)

        # PCA of the source side via SVD of the centered training matrix.
        X_mean = np.mean(X_train, axis=0)
        X_centered = X_train - X_mean
        _, _, Vt_x = svd(X_centered, full_matrices=False)
        X_components = Vt_x[:k]
        X_train_pca = X_centered @ X_components.T

        # PCA of the target side.
        Y_mean = np.mean(Y_train, axis=0)
        Y_centered = Y_train - Y_mean
        _, _, Vt_y = svd(Y_centered, full_matrices=False)
        Y_components = Vt_y[:k]
        Y_train_pca = Y_centered @ Y_components.T

        # The PCA scores are projections of centered data and are therefore
        # already centered; no second centering step is needed.
        #
        # Orthogonal Procrustes for ROW vectors: the R minimising
        # ||s * X R - Y||_F is U @ Vt from the SVD of X^T Y. (Using Y^T X here
        # would yield the transpose, i.e. the inverse rotation.)
        M = X_train_pca.T @ Y_train_pca
        U, Sigma, Vt = svd(M, full_matrices=False)
        R = U @ Vt

        energy = np.sum(X_train_pca * X_train_pca)
        # Guard the degenerate all-zero case instead of producing NaN.
        s = float(np.sum(Sigma) / energy) if energy > 0 else 0.0

        X_test_pca = (X_test - X_mean) @ X_components.T
        Y_test_pca = (Y_test - Y_mean) @ Y_components.T
        X_test_aligned = s * (X_test_pca @ R)

        # Row-wise cosine similarity, computed in one vectorised pass.
        aligned_norms = np.linalg.norm(X_test_aligned, axis=1)
        target_norms = np.linalg.norm(Y_test_pca, axis=1)
        dots = np.sum(X_test_aligned * Y_test_pca, axis=1)

        per_trait_cosines = defaultdict(list)
        for i, (dot, norm_x, norm_y) in enumerate(zip(dots, aligned_norms, target_norms)):
            # Cosine is undefined for zero vectors; such rows are not scored.
            if norm_x > 0 and norm_y > 0:
                trait = test_trait_map.get(i, 'unknown')
                per_trait_cosines[trait].append(float(dot / (norm_x * norm_y)))

        per_trait_means = {}
        per_trait_counts = {}
        for trait in self.traits:
            cosines = per_trait_cosines.get(trait, [])
            per_trait_counts[trait] = len(cosines)
            per_trait_means[trait] = self._safe_mean(cosines)

        overall_cosines = [
            value for cosines in per_trait_cosines.values() for value in cosines
        ]

        return {
            'per_trait': per_trait_means,
            'per_trait_counts': per_trait_counts,
            'overall_mean': self._safe_mean(overall_cosines),
            'scale_factor': s
        }

    def run_all_pairs(self) -> Dict:
        """Evaluate every ordered model pair and print a per-trait summary.

        Pairs with fewer than two training rows are skipped. A trait
        contributes a value for a pair only if that pair actually scored test
        rows for it, so missing data does not count as a cosine of 0.

        Returns:
            A JSON-serialisable dict with keys ``'per_trait'`` (per-trait mean,
            source label and per-pair values), ``'model_pairs'`` (raw result of
            :meth:`similarity_procrustes` per pair), ``'overall_mean'``,
            ``'hf_mean'`` and ``'manual_mean'``. The three means are 0.0 when
            there is nothing to average.
        """
        # All ordered pairs of distinct models, in the order of self.models.
        model_pairs = [
            (source, target)
            for source in self.models
            for target in self.models
            if source != target
        ]

        all_results = {}

        for source, target in model_pairs:
            pair_key = f"{source}_to_{target}"
            logger.info(f"Processing {pair_key}")

            X_train, Y_train, X_test, Y_test, test_trait_map = self.prepare_data(source, target)

            if len(X_train) < 2:
                logger.warning("Skipping %s: not enough training data", pair_key)
                continue

            results = self.similarity_procrustes(X_train, Y_train, X_test, Y_test, test_trait_map)
            all_results[pair_key] = results

            logger.info(f"  Overall mean: {results['overall_mean']:.4f}")

        hf_traits = set(self.hf_traits)

        trait_summaries = {}
        for trait in self.traits:
            trait_values = [
                float(pair_result['per_trait'][trait])
                for pair_result in all_results.values()
                if pair_result['per_trait_counts'].get(trait, 0) > 0
            ]

            if trait_values:
                trait_summaries[trait] = {
                    'mean': float(np.mean(trait_values)),
                    'source': 'HF' if trait in hf_traits else 'Manual',
                    'values': trait_values
                }

        sorted_traits = sorted(trait_summaries.items(), key=lambda x: x[1]['mean'], reverse=True)

        print("\nPer-trait transfer performance (Table 3):")
        print("="*60)
        print(f"{'Trait':<25} {'Source':<8} {'Mean Cosine':<12}")
        print("-"*60)
        if len(sorted_traits) <= 10:
            # Top-5 and bottom-5 would overlap; print every row exactly once.
            for trait, data in sorted_traits:
                print(f"{trait:<25} {data['source']:<8} {data['mean']:.3f}")
        else:
            for trait, data in sorted_traits[:5]:
                print(f"{trait:<25} {data['source']:<8} {data['mean']:.3f}")
            print("...")
            for trait, data in sorted_traits[-5:]:
                print(f"{trait:<25} {data['source']:<8} {data['mean']:.3f}")
        print("-"*60)
        overall_mean = self._safe_mean(d['mean'] for d in trait_summaries.values())
        print(f"{'Overall Mean':<25} {'':<8} {overall_mean:.3f}")

        hf_mean = self._safe_mean(
            d['mean'] for t, d in trait_summaries.items() if t in hf_traits
        )
        manual_mean = self._safe_mean(
            d['mean'] for t, d in trait_summaries.items() if t not in hf_traits
        )
        print(f"\nHuggingFace traits mean: {hf_mean:.3f}")
        print(f"Manual traits mean: {manual_mean:.3f}")

        return {
            'per_trait': trait_summaries,
            'model_pairs': all_results,
            'overall_mean': overall_mean,
            'hf_mean': hf_mean,
            'manual_mean': manual_mean
        }

def _json_default(obj):
    """Convert NumPy scalars/arrays to built-in types for ``json.dump``.

    ``json`` only calls this for objects it cannot serialise itself, e.g.
    ``np.float32`` values produced when the stored vectors are float32.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main():
    """Run the per-trait transfer analysis and save it as JSON.

    The results of ``PerTraitTransferAnalysis.run_all_pairs()`` are written to
    ``results/per_trait_results.json`` (the directory is created if needed).

    Returns:
        The results dictionary that was written to disk.
    """
    logger.info("Computing per-trait transfer statistics")

    analyzer = PerTraitTransferAnalysis()
    results = analyzer.run_all_pairs()

    output_file = Path('results/per_trait_results.json')
    output_file.parent.mkdir(exist_ok=True, parents=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        # default= keeps the dump from failing on NumPy scalar/array values.
        json.dump(results, f, indent=2, default=_json_default)

    logger.info(f"\nResults saved to {output_file}")

    return results

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()