import os
import sys
import numpy as np
import torch
import argparse
import json
from pathlib import Path
from typing import Dict, List, Optional

import dataset
from agent_to_agent_flow import (
    load_model, run_agent_interaction, 
    find_convergence_step, compute_angle_between_objectives
)
from plot_utils import (
    create_axa_convergence_dynamics,
    create_objective_alignment_study, create_theoretical_validation,
    convert_tensors_to_lists
)


def generate_targeted_w_star(w_star_reference: torch.Tensor, target_angle_deg: float, 
                           angle_tolerance: float = 5.0, device: str = 'cpu') -> torch.Tensor:
    """
    Generate a unit-norm w_star vector at a prescribed angle to a reference vector.

    The vector is built in closed form as

        w = cos(theta) * r_hat + sin(theta) * o_hat,

    where r_hat is the normalised reference and o_hat is a random unit vector
    orthogonal to r_hat (obtained by Gram-Schmidt). Because r_hat and o_hat are
    orthonormal, the angle between w and the reference is exactly theta (up to
    floating-point error) for every theta in [0°, 180°]. No rejection sampling
    is needed.

    Args:
        w_star_reference: Reference vector to compute the angle against. It is
            flattened to 1-D, so shapes (d,) and (d, 1) are both accepted.
        target_angle_deg: Target angle in degrees (0=aligned, 90=orthogonal,
            180=opposite). Must lie in [0, 180].
        angle_tolerance: Tolerance in degrees. A warning is printed if the
            achieved angle deviates from the target by more than this.
        device: Device to generate the tensor on.

    Returns:
        1-D unit-norm tensor on ``device`` with the same dtype as
        ``w_star_reference``.

    Raises:
        ValueError: If ``target_angle_deg`` is outside [0, 180] or the
            reference vector has zero norm.
    """
    if not 0.0 <= target_angle_deg <= 180.0:
        raise ValueError(f"target_angle_deg must be in [0, 180], got {target_angle_deg}")

    ref = w_star_reference.reshape(-1).to(device)
    ref_length = torch.linalg.norm(ref)
    if ref_length.item() == 0.0:
        raise ValueError("w_star_reference must have non-zero norm")
    w_ref_norm = ref / ref_length
    d = w_ref_norm.shape[0]
    target_angle_rad = float(np.radians(target_angle_deg))

    if d == 1:
        # In one dimension the only reachable angles are 0° and 180°.
        w_candidate = w_ref_norm if target_angle_deg <= 90.0 else -w_ref_norm
    else:
        # Draw a random direction and remove its component along the reference.
        # Retry in the (measure-zero) case that the draw is numerically parallel.
        for _ in range(100):
            random_vec = torch.randn(d, device=device, dtype=w_ref_norm.dtype)
            orthogonal_component = random_vec - torch.dot(random_vec, w_ref_norm) * w_ref_norm
            orthogonal_length = torch.linalg.norm(orthogonal_component)
            if orthogonal_length.item() > 1e-6:
                break
        else:
            raise RuntimeError("Failed to sample a direction orthogonal to w_star_reference")
        orthogonal_component = orthogonal_component / orthogonal_length

        w_candidate = (float(np.cos(target_angle_rad)) * w_ref_norm
                       + float(np.sin(target_angle_rad)) * orthogonal_component)

    w_candidate = w_candidate / torch.linalg.norm(w_candidate)

    # Sanity check: measure the achieved angle directly (clamp guards acos
    # against round-off pushing the cosine slightly outside [-1, 1]).
    cos_sim = torch.clamp(torch.dot(w_candidate, w_ref_norm), -1.0, 1.0)
    actual_angle = float(np.degrees(torch.arccos(cos_sim).item()))
    if abs(actual_angle - target_angle_deg) > angle_tolerance:
        print(f"Warning: Could not generate exact angle {target_angle_deg}°, got {actual_angle:.1f}°")
    return w_candidate


def generate_targeted_dataset(d: int, n: int, lr: float, w_star_target: torch.Tensor, device: str):
    """
    Generate a noiseless linear dataset whose target vector is a given w_star.

    Args:
        d: Input dimension (rows of X).
        n: Number of samples (columns of X).
        lr: Learning rate forwarded to ``dataset.gradient_descent``.
        w_star_target: Target optimal vector, shape (d,) or (d, 1). It is moved
            to X's device and dtype before use.
        device: Device for tensors.

    Returns:
        Tuple ``(X, y, w_traj, g_traj, max_iter, w_star)`` where X is (d, n),
        y = X.T @ w_star is (n, 1), ``w_traj``/``g_traj`` are the stacked and
        squeezed sequences returned by ``dataset.gradient_descent``, and
        ``w_star`` is squeezed. The layout is intended to match
        ``dataset.data_generation`` (NOTE(review): confirm against that
        function).
    """
    # Columns x_i have i.i.d. N(0, 1/d) entries, so E||x_i||^2 = 1.
    X = torch.randn(d, n, device=device) / np.sqrt(d)

    # Force column-vector shape and align device/dtype with X. Otherwise the
    # matmul fails when the target lives on another device (e.g. CPU vs CUDA)
    # or has a different dtype.
    w_star = w_star_target.reshape(-1, 1).to(device=X.device, dtype=X.dtype)
    y = X.T @ w_star

    w, g, max_iter = dataset.gradient_descent(
        X, y, torch.zeros(d, 1, device=X.device, dtype=X.dtype), w_star, lr
    )

    return X, y, torch.stack(w, dim=-1).squeeze(), torch.stack(g, dim=-1).squeeze(), max_iter, w_star.squeeze()


def run_convergence_only_experiments_random(model_path: str, max_samples: Optional[int] = None) -> dict:
    """
    Find one convergence example per angle category by random rejection sampling.

    Agent 1's dataset is drawn once and held fixed. Agent 2 datasets are drawn
    repeatedly from ``dataset.data_generation`` and the angle between the two
    objectives is measured. The first draw that falls in a still-needed
    category is run through ``run_agent_interaction`` and kept. The categories
    are 'aligned' [0°, 10°], 'orthogonal' (80°, 100°] and 'opposite'
    (160°, 180°].

    Without a cap, this loop only ends once all three categories are found.
    If w_star is drawn isotropically (not visible here; depends on
    ``dataset.data_generation``), near-aligned and near-opposite draws become
    extremely rare as d grows. ``max_samples`` bounds the search.

    Args:
        model_path: Path to trained model.
        max_samples: Maximum number of Agent 2 datasets to draw. ``None`` (the
            default) samples until all three categories are found.

    Returns:
        Dictionary whose 'convergence_examples' entry maps each found category
        to its sample data. Categories not found before the cap are absent.
    """
    print("Running convergence-only experiments (random sampling)...")
    
    # Load two independent copies of the same checkpoint, one per agent.
    model_agent1, d, T, learning_rate, device = load_model(model_path)
    model_agent2, _, _, _, _ = load_model(model_path)
    
    results = {
        'convergence_examples': {},
    }
    
    # Generate fixed Agent 1 dataset
    X1, y1, w_gt1, gd_gt1, max_step1, w_star1 = dataset.data_generation(d=d, n=T, lr=learning_rate)
    X1, y1, w_gt1, gd_gt1, w_star1 = X1.to(device), y1.to(device), w_gt1.to(device), gd_gt1.to(device), w_star1.to(device)
    
    print(f"Agent 1 fixed: w_star = {w_star1[:3]}..., convergence steps = {max_step1}")
    
    # Track which cases we still need
    needed_cases = {'aligned','orthogonal', 'opposite'}
    found_cases = {}
    
    sample_count = 0
    
    while needed_cases and (max_samples is None or sample_count < max_samples):
        sample_count += 1
        # Generate Agent 2 dataset
        X2, y2, w_gt2, gd_gt2, max_step2, w_star2 = dataset.data_generation(d=d, n=T, lr=learning_rate)
        X2, y2, w_gt2, gd_gt2, w_star2 = X2.to(device), y2.to(device), w_gt2.to(device), gd_gt2.to(device), w_star2.to(device)
        
        angle_deg = compute_angle_between_objectives(w_star1, w_star2)
        
        # Only categories we still need are considered; a draw in an
        # already-found range is skipped.
        category = None
        if 'aligned' in needed_cases and 0 <= angle_deg <= 10:
            category = 'aligned'
        elif 'orthogonal' in needed_cases and 80 < angle_deg <= 100:
            category = 'orthogonal'  
        elif 'opposite' in needed_cases and 160 < angle_deg <= 180:
            category = 'opposite'
            
        if category:
            print(f"Sample {sample_count}: Found {category} case at {angle_deg:.1f}°, processing...")
            
            interaction_data = run_agent_interaction(
                model_agent1, model_agent2, 
                X1, y1, w_star1, max_step1,
                X2, y2, w_star2, max_step2,
                d, learning_rate, device
            )
            
            sample_data = {
                'angle_deg': angle_deg,
                'agent1_distances': interaction_data['agent1_distances'],
                'agent2_distances': interaction_data['agent2_distances'], 
                'theoretical_results': interaction_data['theoretical_results'],
                'convergence_step': find_convergence_step(
                    interaction_data['g1_predictions'],
                    interaction_data['g2_predictions']
                ),
                'sample_count': sample_count
            }
            
            found_cases[category] = sample_data
            needed_cases.remove(category)
            print(f"✓ Found {category} example at {angle_deg:.1f}°")
            print(f"  Remaining cases needed: {needed_cases}")
        else:
            already_found = list(found_cases.keys())
            found_info = f", found: {already_found}" if already_found else ""
            print(f"Sample {sample_count}: Angle {angle_deg:.1f}° (not needed{found_info})")
    
    if needed_cases:
        print(f"Warning: stopped after {sample_count} samples without finding: {sorted(needed_cases)}")
    
    results['convergence_examples'] = found_cases
    
    print(f"\nCompleted random convergence-only experiments:")
    print(f"  Total samples processed: {sample_count}")
    print(f"  Found examples: {list(found_cases.keys())}")
    for category, data in found_cases.items():
        print(f"    {category}: {data['angle_deg']:.1f}° (sample #{data['sample_count']})")
    
    return results


def run_convergence_only_experiments(model_path: str) -> dict:
    """
    Build one convergence example per angle category using targeted sampling.

    Agent 1's dataset comes from ``dataset.data_generation`` and is held fixed.
    For each category, Agent 2's w_star is constructed at a chosen target angle
    to Agent 1's w_star with ``generate_targeted_w_star``. A matching dataset
    is built with ``generate_targeted_dataset``, and the two agents are run
    through ``run_agent_interaction``.

    Args:
        model_path: Path to trained model.

    Returns:
        Dictionary whose 'convergence_examples' entry maps 'aligned',
        'orthogonal' and 'opposite' to their sample data.
    """
    print("Running convergence-only experiments (targeted sampling)...")
    
    # Two independent copies of the same checkpoint, one per agent.
    model_agent1, d, T, learning_rate, device = load_model(model_path)
    model_agent2, _, _, _, _ = load_model(model_path)
    
    results = {
        'convergence_examples': {},
    }
    
    # Generate fixed Agent 1 dataset
    X1, y1, w_gt1, gd_gt1, max_step1, w_star1 = dataset.data_generation(d=d, n=T, lr=learning_rate)
    X1, y1, w_gt1, gd_gt1, w_star1 = X1.to(device), y1.to(device), w_gt1.to(device), gd_gt1.to(device), w_star1.to(device)
    
    print(f"Agent 1 fixed: w_star = {w_star1[:3]}..., convergence steps = {max_step1}")
    
    # Target the midpoint of each category's range, so the example sits well
    # inside it rather than on a boundary.
    target_cases = {
        'aligned': 5.0,       # Target 5° (within 0-10° range)
        'orthogonal': 90.0,   # Target 90° (within 80-100° range)
        'opposite': 170.0     # Target 170° (within 160-180° range)
    }
    
    found_cases = {}
    
    for category, target_angle in target_cases.items():
        print(f"\nGenerating {category} case (target: {target_angle}°)...")
        
        w_star2 = generate_targeted_w_star(
            w_star1, target_angle, 
            angle_tolerance=5.0, device=device
        )
        
        X2, y2, _, _, max_step2, _ = generate_targeted_dataset(
            d, T, learning_rate, w_star2, device
        )
        
        actual_angle = compute_angle_between_objectives(w_star1, w_star2)
        print(f"✓ Generated {category} case at {actual_angle:.1f}° (target: {target_angle}°)")
        
        interaction_data = run_agent_interaction(
            model_agent1, model_agent2, 
            X1, y1, w_star1, max_step1,
            X2, y2, w_star2, max_step2,
            d, learning_rate, device
        )
        
        sample_data = {
            'angle_deg': actual_angle,
            'agent1_distances': interaction_data['agent1_distances'],
            'agent2_distances': interaction_data['agent2_distances'], 
            'theoretical_results': interaction_data['theoretical_results'],
            'convergence_step': find_convergence_step(
                interaction_data['g1_predictions'],
                interaction_data['g2_predictions']
            ),
            'sample_count': len(found_cases) + 1
        }
        
        found_cases[category] = sample_data
        print(f"✓ Processed {category} example successfully")
    
    results['convergence_examples'] = found_cases
    
    print(f"\nCompleted convergence-only experiments:")
    print(f"  Total samples processed: {len(found_cases)} (targeted)")
    print(f"  Found examples: {list(found_cases.keys())}")
    for category, data in found_cases.items():
        print(f"    {category}: {data['angle_deg']:.1f}°")
    
    return results


def _corollary_eigen_params(theoretical_results: dict) -> Dict[str, float]:
    """
    Compute the corollary's eigenvalue parameters from an interaction's matrices.

    With S_W = theoretical_results['S_W'], S_U = theoretical_results['S_Gamma']
    and S = theoretical_results['S']:

        α_U = √λ_min(S_W S⁻² S_W),  β_U = √λ_max(S_W S⁻² S_W)
        α_W = √λ_min(S_U S⁻² S_U),  β_W = √λ_max(S_U S⁻² S_U)

    S⁻¹ is the Moore-Penrose pseudo-inverse, so a singular S does not raise.
    Eigenvalues are clamped at 0 before the square root. Assuming S_W, S_U and
    S are symmetric (TODO confirm against run_agent_interaction), these
    matrices are PSD in exact arithmetic. Round-off can still produce tiny
    negative eigenvalues, which would otherwise become NaN.
    """
    S_W = theoretical_results['S_W']
    S_U = theoretical_results['S_Gamma']
    S_inv = torch.pinverse(theoretical_results['S'])
    S_inv_squared = S_inv @ S_inv

    def _sqrt_extreme_eigs(M: torch.Tensor):
        eigs = torch.linalg.eigvals(M).real.clamp(min=0.0)
        return torch.sqrt(torch.min(eigs)).item(), torch.sqrt(torch.max(eigs)).item()

    alpha_U, beta_U = _sqrt_extreme_eigs(S_W @ S_inv_squared @ S_W)
    alpha_W, beta_W = _sqrt_extreme_eigs(S_U @ S_inv_squared @ S_U)
    return {'alpha_U': alpha_U, 'beta_U': beta_U, 'alpha_W': alpha_W, 'beta_W': beta_W}


def _angle_category(angle_deg: float) -> Optional[str]:
    """Map an objective angle to 'aligned', 'orthogonal', 'opposite', or None."""
    if 0 <= angle_deg <= 10:
        return 'aligned'
    if 80 < angle_deg <= 100:
        return 'orthogonal'
    if 160 < angle_deg <= 180:
        return 'opposite'
    return None


def run_continuous_angle_experiments(model_path: str, n_samples: int = 50) -> dict:
    """
    Run agent interactions for randomly drawn Agent 2 objectives across all angles.

    Agent 1's dataset is drawn once and held fixed. Each of the ``n_samples``
    iterations draws a fresh Agent 2 dataset, runs ``run_agent_interaction``
    (no filtering), and records:

    - an 'alignment_study' entry: angle, empirical and theoretical plateau
      errors, the corollary's α/β eigenvalue parameters, and both w_star norms;
    - the empirical and theoretical plateau errors under
      'theoretical_validation';
    - for each angle category, the sample with the smallest
      min(final agent distance), stored under 'convergence_examples'.

    Args:
        model_path: Path to trained model.
        n_samples: Total number of experimental samples to collect.

    Returns:
        Dictionary with keys 'convergence_examples', 'alignment_study' and
        'theoretical_validation'.
    """
    print("Running  experiments with continuous angle sampling...")
    
    # Two independent copies of the same checkpoint, one per agent.
    model_agent1, d, T, learning_rate, device = load_model(model_path)
    model_agent2, _, _, _, _ = load_model(model_path)
    
    results = {
        'convergence_examples': {},
        'alignment_study': [],
        'theoretical_validation': {'agent1_empirical': [], 'agent1_theoretical': [],
                                   'agent2_empirical': [], 'agent2_theoretical': []}
    }
    validation = results['theoretical_validation']
    
    # Generate fixed Agent 1 dataset
    X1, y1, w_gt1, gd_gt1, max_step1, w_star1 = dataset.data_generation(d=d, n=T, lr=learning_rate)
    X1, y1, w_gt1, gd_gt1, w_star1 = X1.to(device), y1.to(device), w_gt1.to(device), gd_gt1.to(device), w_star1.to(device)
    
    print(f"Collecting {n_samples} samples across all angles...")
    print(f"Agent 1 fixed: w_star = {w_star1[:3]}..., convergence steps = {max_step1}")
    
    # Best representative per category (lower is better) for the individual plots.
    best_examples = {'aligned': None, 'orthogonal': None, 'opposite': None}
    best_distances = {'aligned': float('inf'), 'orthogonal': float('inf'), 'opposite': float('inf')}
    
    collected_samples = 0
    while collected_samples < n_samples:
        X2, y2, w_gt2, gd_gt2, max_step2, w_star2 = dataset.data_generation(d=d, n=T, lr=learning_rate)
        X2, y2, w_gt2, gd_gt2, w_star2 = X2.to(device), y2.to(device), w_gt2.to(device), gd_gt2.to(device), w_star2.to(device)
        
        angle_deg = compute_angle_between_objectives(w_star1, w_star2)
        
        interaction_data = run_agent_interaction(
            model_agent1, model_agent2, 
            X1, y1, w_star1, max_step1,
            X2, y2, w_star2, max_step2,
            d, learning_rate, device
        )
        theoretical_results = interaction_data['theoretical_results']
        final_dist1 = interaction_data['agent1_distances'][-1]
        final_dist2 = interaction_data['agent2_distances'][-1]
        theory1 = theoretical_results['plateau_error_agent1_boxed']
        theory2 = theoretical_results['plateau_error_agent2_boxed']
        
        sample_data = {
            'angle_deg': angle_deg,
            'agent1_distances': interaction_data['agent1_distances'],
            'agent2_distances': interaction_data['agent2_distances'], 
            'theoretical_results': theoretical_results,
            'convergence_step': find_convergence_step(
                interaction_data['g1_predictions'],
                interaction_data['g2_predictions']
            ),
            'sample_count': collected_samples + 1
        }
        
        results['alignment_study'].append({
            'angle_deg': angle_deg,
            'plateau_agent1': final_dist1,
            'plateau_agent2': final_dist2,
            'theoretical_agent1': theory1,
            'theoretical_agent2': theory2,
            **_corollary_eigen_params(theoretical_results),
            'w_star_norm': torch.norm(interaction_data['w_star1']).item(),
            'u_star_norm': torch.norm(interaction_data['w_star2']).item()
        })
        
        validation['agent1_empirical'].append(final_dist1)
        validation['agent1_theoretical'].append(theory1)
        validation['agent2_empirical'].append(final_dist2)
        validation['agent2_theoretical'].append(theory2)
        
        category = _angle_category(angle_deg)
        if category:
            # Quality metric: the better of the two agents' final distances.
            min_final_dist = min(final_dist1, final_dist2)
            if min_final_dist < best_distances[category]:
                best_distances[category] = min_final_dist
                best_examples[category] = sample_data
                print(f"Sample {collected_samples + 1}: New best {category} example at {angle_deg:.1f}° (quality: {min_final_dist:.4f})")
            else:
                print(f"Sample {collected_samples + 1}: {category} example at {angle_deg:.1f}°")
        else:
            print(f"Sample {collected_samples + 1}: Angle {angle_deg:.1f}°")
        
        collected_samples += 1
    
    for category, example in best_examples.items():
        if example is not None:
            results['convergence_examples'][category] = example
    
    print(f"  Total samples: {collected_samples}")
    angles = [s['angle_deg'] for s in results['alignment_study']]
    # min()/max() of an empty list would raise when n_samples <= 0.
    if angles:
        print(f"  Angle range: {min(angles):.1f}° - {max(angles):.1f}°")
        print(f"  Mean angle: {np.mean(angles):.1f}° ± {np.std(angles):.1f}°")
    print(f"  Best examples found: {list(k for k, v in best_examples.items() if v is not None)}")
    
    return results


def generate_all_publication_figures(results: dict, output_dir: str):
    """
    Generate every figure that ``results`` has data for into ``output_dir``.

    Produces:
      1. one convergence-dynamics figure per entry in 'convergence_examples';
      2. the objective-alignment study, if 'alignment_study' is non-empty;
      3. one theoretical-validation figure per agent (W and U), if
         'theoretical_validation' has Agent 1 empirical data.

    Missing sections are skipped, so the results of the convergence-only
    runners (which have only 'convergence_examples') are accepted.

    Args:
        results: Experiment results dictionary.
        output_dir: Directory to save figures in (created if missing).
    """
    import matplotlib.pyplot as plt  # imported once; closed after each figure to free memory

    os.makedirs(output_dir, exist_ok=True)
    print(f"\nGenerating  figures in {output_dir}...")
    
    # 1. AxA Convergence Examples (best representatives from each category)
    for category, data in results.get('convergence_examples', {}).items():
        print(f"Creating {category} convergence example...")
        create_axa_convergence_dynamics(
            data['agent1_distances'],
            data['agent2_distances'], 
            data['angle_deg'],
            data['theoretical_results'],
            data['convergence_step'],
            save_path=os.path.join(output_dir, f"{category}_convergence")
        )
        plt.close()
    
    # 2. Objective Alignment Study (continuous relationship)
    alignment_study = results.get('alignment_study')
    if alignment_study:
        print("Creating objective alignment study...")
        create_objective_alignment_study(
            alignment_study,
            save_path=os.path.join(output_dir, "objective_alignment_study")
        )
        plt.close()
    
    # 3. Theoretical Validation Plots (Agent 1 is "W", Agent 2 is "U")
    validation = results.get('theoretical_validation') or {}
    if validation.get('agent1_empirical'):
        print("Creating theoretical validation plots...")
        for agent_key, agent_name in (("agent1", "W"), ("agent2", "U")):
            create_theoretical_validation(
                validation[f"{agent_key}_empirical"],
                validation[f"{agent_key}_theoretical"],
                agent_name=agent_name,
                save_path=os.path.join(output_dir, f"theoretical_validation_{agent_key}")
            )
            plt.close()
    
    print("All  figures generated successfully!")


def _save_convergence_plots(convergence_examples: dict, output_dir: str) -> None:
    """Create one convergence-dynamics figure per category in ``output_dir``."""
    import matplotlib.pyplot as plt  # imported once; each figure is closed to free memory

    os.makedirs(output_dir, exist_ok=True)
    print(f"\nGenerating convergence plots in {output_dir}...")

    for category, data in convergence_examples.items():
        print(f"Creating {category} convergence plot...")
        create_axa_convergence_dynamics(
            data['agent1_distances'],
            data['agent2_distances'], 
            data['angle_deg'],
            data['theoretical_results'],
            data['convergence_step'],
            save_path=os.path.join(output_dir, f"{category}_convergence")
        )
        plt.close()

    print("Convergence plots generated successfully!")


def _print_summary_statistics(alignment_study: list) -> None:
    """Print angle/plateau ranges and their Pearson correlation for a non-empty study."""
    angles = [s['angle_deg'] for s in alignment_study]
    plateaus1 = [s['plateau_agent1'] for s in alignment_study]
    plateaus2 = [s['plateau_agent2'] for s in alignment_study]

    print(f"\nSummary Statistics:")
    print(f"  Angles: {min(angles):.1f}° to {max(angles):.1f}° (mean: {np.mean(angles):.1f}°)")
    print(f"  Agent 1 plateaus: {min(plateaus1):.4f} to {max(plateaus1):.4f}")
    print(f"  Agent 2 plateaus: {min(plateaus2):.4f} to {max(plateaus2):.4f}")

    # A correlation needs at least two points; np.corrcoef would return NaN
    # (with a RuntimeWarning) otherwise.
    if len(angles) >= 2:
        correlation1 = np.corrcoef(angles, plateaus1)[0, 1]
        correlation2 = np.corrcoef(angles, plateaus2)[0, 1]
        print(f"  Angle-plateau correlation: Agent 1: {correlation1:.3f}, Agent 2: {correlation2:.3f}")
    else:
        print("  Angle-plateau correlation: n/a (need at least 2 samples)")


def main():
    """
    Parse CLI arguments and run figure generation.

    With ``--convergence_only``, run the targeted or random convergence-only
    experiment and save the three convergence plots. Otherwise run the full
    continuous-angle experiment, generate all figures, save the results as
    JSON, and print summary statistics.
    """
    parser = argparse.ArgumentParser(description="Generate figures from agent-to-agent analysis")
    parser.add_argument("--model_path", type=str,
                       default="experiments/exp_20250909_081624_lr0.005_adam_ep100_d10_T20_nds100_bs512_neval10/model_latest.pth",
                       help="Path to trained model checkpoint")
    parser.add_argument("--output_dir", type=str,
                       default="../latex-paper/figs/publication_improved",
                       help="Directory to save publication figures")
    parser.add_argument("--n_samples", type=int, default=1000,
                       help="Total number of samples to collect (only used in full mode)")
    parser.add_argument("--convergence_only", action="store_true",
                       help="Run optimized mode to generate only convergence plots for 3 cases")
    parser.add_argument("--sampling_method", type=str, default="targeted", 
                       choices=["targeted", "random"],
                       help="Sampling method: targeted (fast) or random (original)")
    
    args = parser.parse_args()
    
    if args.convergence_only:
        print(f"Running convergence-only mode ({args.sampling_method} sampling)...")
        if args.sampling_method == "targeted":
            results = run_convergence_only_experiments(args.model_path)
        else:  # random
            results = run_convergence_only_experiments_random(args.model_path)
        _save_convergence_plots(results['convergence_examples'], args.output_dir)
        return

    print("Running full mode with continuous angle experiments...")
    results = run_continuous_angle_experiments(args.model_path, args.n_samples)
    generate_all_publication_figures(results, args.output_dir)

    # Save experimental data for future use (only in full mode)
    results_file = os.path.join(args.output_dir, "experimental_results.json")
    json_results = convert_tensors_to_lists(results)
    os.makedirs(args.output_dir, exist_ok=True)
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(json_results, f, indent=2)
    print(f"\nExperimental results saved to: {results_file}")

    if results.get('alignment_study'):
        _print_summary_statistics(results['alignment_study'])


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()


