#!/usr/bin/env python3

import json
import numpy as np
from pathlib import Path
from typing import Dict, Tuple, List
from scipy.linalg import svd
import logging
from tqdm import tqdm
import time

# Configure the root logger at INFO level (runs at import time) and create
# the module-level logger used by the experiment class below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MultipleRunTransferExperiments:
    """Repeated-seed PCA + Procrustes transfer experiments between models.

    For every ordered (source, target) pair of models, the per-trait vector
    files of both models are loaded, paired row by row, split into train and
    test sets with several seeds, reduced with PCA, and aligned with a scaled
    orthogonal Procrustes map fitted on the training split.  Alignment quality
    is reported as the mean row-wise cosine similarity.
    """

    # Upper bound on the number of rows used from each vector file.
    MAX_VECTORS_PER_TRAIT = 2500
    # Fraction of the paired rows assigned to the training split.
    TRAIN_FRACTION = 0.8

    def __init__(self):
        """Set up paths, model/trait names and experiment hyper-parameters."""
        self.base_dir = Path('.')
        self.vectors_dir = self.base_dir / 'vectors' / 'raw'

        self.models = ['gemma', 'llama3', 'mistral']
        self.traits = [
            'accessibility', 'assertiveness', 'authority', 'certainty', 'clarity',
            'concreteness', 'creativity', 'directness', 'emotional_tone', 'empathy',
            'enthusiasm', 'formality', 'hedging', 'humor', 'inclusivity',
            'objectivity', 'optimism', 'persuasiveness', 'politeness', 'precision',
            'professionalism', 'register', 'specificity', 'technical_complexity',
            'urgency', 'verbosity'
        ]

        self.k_dims = 1300
        self.seeds = [42, 123, 456, 789, 1011]
        # Kept for backward compatibility; the seed list is the source of truth.
        self.n_runs = len(self.seeds)

    def load_vectors(self, model: str, trait: str) -> np.ndarray:
        """Load the vectors stored for one (model, trait) combination.

        Args:
            model: Model name used as the file-name prefix.
            trait: Trait name used in the file name.

        Returns:
            The loaded array, truncated to ``MAX_VECTORS_PER_TRAIT`` rows, or
            an empty array when the file is missing or cannot be read.
        """
        file_path = self.vectors_dir / f"{model}_{trait}_vectors_webscale.npy"

        if not file_path.exists():
            return np.array([])

        try:
            vectors = np.load(file_path)
            if len(vectors) > self.MAX_VECTORS_PER_TRAIT:
                vectors = vectors[:self.MAX_VECTORS_PER_TRAIT]
            return vectors
        except Exception as e:
            # Best effort: an unreadable file is logged and treated as missing.
            logger.error(f"Error loading {file_path}: {e}")
            return np.array([])

    def prepare_data(self, source_model: str, target_model: str, seed: int) -> Tuple:
        """Build a seeded train/test split of paired source/target vectors.

        For each trait available for both models the first ``min(len)`` rows
        of each file are paired.  All traits are stacked, both matrices are
        cut to the same number of columns, and rows are shuffled and split.

        Args:
            source_model: Name of the model whose vectors form ``X``.
            target_model: Name of the model whose vectors form ``Y``.
            seed: Seed controlling the row permutation.

        Returns:
            ``(X_train, Y_train, X_test, Y_test)``, or four ``None`` values
            when no trait has vectors for both models.
        """
        source_blocks = []
        target_blocks = []

        for trait in self.traits:
            source_vecs = self.load_vectors(source_model, trait)
            target_vecs = self.load_vectors(target_model, trait)

            if len(source_vecs) == 0 or len(target_vecs) == 0:
                continue

            n_pairs = min(len(source_vecs), len(target_vecs))
            source_blocks.append(source_vecs[:n_pairs])
            target_blocks.append(target_vecs[:n_pairs])

        if not source_blocks:
            return None, None, None, None

        X = np.vstack(source_blocks)
        Y = np.vstack(target_blocks)

        # Keep only the leading columns shared by both models.
        min_dim = min(X.shape[1], Y.shape[1])
        X = X[:, :min_dim]
        Y = Y[:, :min_dim]

        n = len(X)
        n_train = int(self.TRAIN_FRACTION * n)

        # A local RandomState yields the same permutation as seeding the
        # global legacy generator, without mutating global random state.
        indices = np.random.RandomState(seed).permutation(n)
        train_idx = indices[:n_train]
        test_idx = indices[n_train:]

        return X[train_idx], Y[train_idx], X[test_idx], Y[test_idx]

    @staticmethod
    def _fit_pca(data, k):
        """Return ``(mean, top-k components, projected data)`` for ``data``."""
        mean = np.mean(data, axis=0)
        centered = data - mean
        _, _, vt = svd(centered, full_matrices=False)
        components = vt[:k]
        return mean, components, centered @ components.T

    @staticmethod
    def _mean_row_cosine(A, B):
        """Return ``(mean cosine, row count)`` over rows with non-zero norms."""
        norm_a = np.linalg.norm(A, axis=1)
        norm_b = np.linalg.norm(B, axis=1)
        # Rows with a zero (or NaN) norm have no defined cosine and are skipped.
        valid = (norm_a > 0) & (norm_b > 0)
        if not np.any(valid):
            return 0.0, 0
        dots = np.einsum('ij,ij->i', A[valid], B[valid])
        cosines = dots / (norm_a[valid] * norm_b[valid])
        return float(np.mean(cosines, dtype=np.float64)), int(np.count_nonzero(valid))

    def similarity_procrustes(self, X_train, Y_train, X_test, Y_test) -> Dict:
        """Fit PCA + scaled orthogonal Procrustes on train, score train/test.

        Both spaces are reduced to ``k`` principal components fitted on the
        training rows.  A rotation ``R`` and scale ``s`` minimising
        ``||s * X R - Y||_F`` are fitted on the centred training projections
        and applied to the train and test projections of ``X``.

        Args:
            X_train: Source training matrix, one vector per row.
            Y_train: Target training matrix, rows paired with ``X_train``.
            X_test: Source test matrix.
            Y_test: Target test matrix, rows paired with ``X_test``.

        Returns:
            Dict with mean train/test cosine, the scale factor, the
            train-test gap and the number of rows scored, all as built-in
            Python numbers so the result is JSON serialisable.
        """
        k = min(self.k_dims, X_train.shape[1], Y_train.shape[1], X_train.shape[0] - 1)

        X_mean, X_components, X_train_pca = self._fit_pca(X_train, k)
        Y_mean, Y_components, Y_train_pca = self._fit_pca(Y_train, k)

        X_pca_centered = X_train_pca - np.mean(X_train_pca, axis=0)
        Y_pca_centered = Y_train_pca - np.mean(Y_train_pca, axis=0)

        # Orthogonal Procrustes: with X^T Y = U S V^T the minimiser of
        # ||X R - Y||_F is R = U V^T.  (Building the cross-covariance as
        # Y^T X instead would give R^T, i.e. the inverse rotation.)
        M = X_pca_centered.T @ Y_pca_centered
        U, Sigma, Vt = svd(M, full_matrices=False)
        R = U @ Vt
        x_energy = np.sum(X_pca_centered * X_pca_centered)
        s = float(np.sum(Sigma) / x_energy) if x_energy > 0 else 0.0

        X_train_aligned = s * (X_train_pca @ R)
        train_score, n_train_scored = self._mean_row_cosine(X_train_aligned, Y_train_pca)

        # Test rows reuse the training means, components, rotation and scale.
        X_test_pca = (X_test - X_mean) @ X_components.T
        Y_test_pca = (Y_test - Y_mean) @ Y_components.T
        X_test_aligned = s * (X_test_pca @ R)
        test_score, n_test_scored = self._mean_row_cosine(X_test_aligned, Y_test_pca)

        return {
            'test_cosine': test_score,
            'train_cosine': train_score,
            'scale_factor': s,
            'train_test_gap': train_score - test_score,
            'n_test_vectors': n_test_scored,
            'n_train_vectors': n_train_scored
        }

    def run_single_pair_multiple_times(self, source: str, target: str) -> Dict:
        """Run the experiment for one model pair once per configured seed.

        Args:
            source: Source model name.
            target: Target model name.

        Returns:
            Summary dict with mean/std (population std) of the test cosine,
            scale factor and train-test gap, the number of completed runs and
            the per-run results.  All statistics are 0.0 when no run had data.
        """
        results_per_run = []
        n_seeds = len(self.seeds)

        for run_idx, seed in enumerate(self.seeds):
            logger.info(f"    Run {run_idx + 1}/{n_seeds} with seed {seed}")

            X_train, Y_train, X_test, Y_test = self.prepare_data(source, target, seed)

            if X_train is None:
                logger.error(f"No data for {source} -> {target}")
                continue

            run_results = self.similarity_procrustes(X_train, Y_train, X_test, Y_test)
            run_results['seed'] = seed
            run_results['run_idx'] = run_idx
            results_per_run.append(run_results)

        summary = {}
        for metric in ('test_cosine', 'scale_factor', 'train_test_gap'):
            values = [run[metric] for run in results_per_run]
            # float() keeps the summary JSON serialisable.
            summary[f'{metric}_mean'] = float(np.mean(values)) if values else 0.0
            summary[f'{metric}_std'] = float(np.std(values)) if values else 0.0
        summary['n_runs'] = len(results_per_run)
        summary['individual_runs'] = results_per_run

        return summary

    def run_all_experiments(self) -> Dict:
        """Run every ordered model pair and add an overall summary.

        Returns:
            Dict mapping ``"<source>_to_<target>"`` to the pair summary, plus
            an ``"overall_mean"`` entry averaging the pair means.  Pairs that
            completed no run are left out of the overall averages so their
            placeholder zeros do not bias them.
        """
        results = {}

        # All ordered pairs of distinct models, in the order of self.models.
        model_pairs = [
            (source, target)
            for source in self.models
            for target in self.models
            if source != target
        ]

        start_time = time.time()

        for source, target in model_pairs:
            pair_key = f"{source}_to_{target}"
            logger.info(f"\n{'='*60}")
            logger.info(f"Processing {pair_key} ({len(self.seeds)} runs)")
            logger.info(f"{'='*60}")

            pair_results = self.run_single_pair_multiple_times(source, target)
            results[pair_key] = pair_results

            logger.info(f"\nSummary for {pair_key}:")
            logger.info(f"  Test Cosine: {pair_results['test_cosine_mean']:.4f} ± {pair_results['test_cosine_std']:.4f}")
            logger.info(f"  Scale Factor: {pair_results['scale_factor_mean']:.4f} ± {pair_results['scale_factor_std']:.4f}")
            logger.info(f"  Train-Test Gap: {pair_results['train_test_gap_mean']:.4f} ± {pair_results['train_test_gap_std']:.4f}")

        # Computed before 'overall_mean' is inserted into the same dict.
        completed = [summary for summary in results.values() if summary['n_runs'] > 0]
        overall = {}
        for out_key, pair_key_name in (
            ('test_cosine', 'test_cosine_mean'),
            ('scale_factor', 'scale_factor_mean'),
            ('train_test_gap', 'train_test_gap_mean'),
        ):
            values = [summary[pair_key_name] for summary in completed]
            overall[out_key] = float(np.mean(values)) if values else 0.0
        results['overall_mean'] = overall

        elapsed_time = time.time() - start_time
        logger.info(f"\nTotal runtime: {elapsed_time/60:.1f} minutes")

        return results

    def format_table_output(self, results: Dict) -> str:
        """Format experiment results as LaTeX table rows.

        Args:
            results: Dict in the shape returned by ``run_all_experiments``.

        Returns:
            One row per model pair present in ``results`` followed, when
            available, by a ``\\midrule`` and a bold "Mean" row; rows are
            joined with newlines.
        """
        model_pairs = [
            ('gemma', 'llama3', 'Gemma → LLaMA'),
            ('gemma', 'mistral', 'Gemma → Mistral'),
            ('llama3', 'gemma', 'LLaMA → Gemma'),
            ('llama3', 'mistral', 'LLaMA → Mistral'),
            ('mistral', 'gemma', 'Mistral → Gemma'),
            ('mistral', 'llama3', 'Mistral → LLaMA')
        ]

        table_rows = []
        for source, target, display_name in model_pairs:
            pair_key = f"{source}_to_{target}"
            if pair_key in results:
                r = results[pair_key]
                row = f"{display_name:20s} & {r['test_cosine_mean']:.3f} ± {r['test_cosine_std']:.3f} & {r['scale_factor_mean']:.3f} & {r['train_test_gap_mean']:.3f} \\\\"
                table_rows.append(row)

        if 'overall_mean' in results:
            r = results['overall_mean']
            # The label is padded as a whole; a doubled-brace "{{Mean:20s}}"
            # would print the format spec literally.
            mean_label = "\\textbf{Mean}"
            row = (
                "\\midrule\n"
                f"{mean_label:20s} & \\textbf{{{r['test_cosine']:.3f}}} & "
                f"{r['scale_factor']:.3f} & {r['train_test_gap']:.3f} \\\\"
            )
            table_rows.append(row)

        return "\n".join(table_rows)

def _json_default(obj):
    """Convert NumPy scalars and arrays to built-in types for ``json.dump``."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main():
    """Run all transfer experiments, save them as JSON and print a LaTeX table.

    Returns:
        The results dict produced by ``run_all_experiments``.
    """
    runner = MultipleRunTransferExperiments()

    print("="*60)
    print("MULTIPLE RUN TRANSFER EXPERIMENTS")
    # Report the runner's actual configuration instead of fixed numbers.
    print(f"Running each model pair {len(runner.seeds)} times with k={runner.k_dims}")
    print("="*60)

    results = runner.run_all_experiments()

    output_file = Path('results/transfer_results_5runs.json')
    output_file.parent.mkdir(exist_ok=True, parents=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        # The fallback keeps a finished run from being lost at save time when
        # a metric is a NumPy scalar (e.g. float32) rather than a Python float.
        json.dump(results, f, indent=2, default=_json_default)
    print(f"\nSaved results to {output_file}")

    print("\n" + "="*60)
    print("LATEX TABLE FORMAT:")
    print("="*60)
    print(runner.format_table_output(results))
    print("="*60)

    return results

# Run the experiments only when this file is executed as a script.
if __name__ == "__main__":
    main()