import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import argparse
import json
import re
from typing import Optional, Tuple, Dict, List
import model
import dataset
from agent_to_agent_flow import (
    load_model, 
    run_agent_interaction,
    create_agent_model
)
from plot_utils import (
    COLORS, AGENT_COLORS, plot_adversarial_result,
    create_adversarial_statistics_plots
)


def adversarial_construction_line_space(S_1: torch.Tensor, w_star_1: torch.Tensor, w_star_2: torch.Tensor, 
                                      n: int, eta: float, tau: float = 0.1, delta: float = 1e-8, 
                                      device: str = 'cpu') -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Implement Algorithm 1: Adversarial construction of X_2 from line-space span{S_1(w_1*-w_2*)}.

    The constructed geometry is
        S_2 = lambda * P_U + 0.1 * lambda * (I - P_U),
    where P_U projects onto v = S_1 (w_1* - w_2*) and
    lambda = min(1 / eta, 2 * lambda_max(S_1)).

    Args:
        S_1: Second moment matrix of agent 1 (shape: [d, d]); treated as symmetric
        w_star_1: Target weight for agent 1 (shape: [d])
        w_star_2: Target weight for agent 2 (shape: [d])
        n: Sample size used to scale X_2
        eta: Step size (must be positive)
        tau: Stability margin (currently unused; kept for interface compatibility)
        delta: Tolerance; if ||S_1 (w_1* - w_2*)|| <= delta the attack direction
            is treated as zero and the construction is rejected
        device: Unused; new tensors are created with S_1's device and dtype.
            Kept for interface compatibility.

    Returns:
        X_2: Square factor of shape [d, d] (NOT [d, n]) with X_2 @ X_2.T / n == S_2
        S_2: Constructed second moment matrix (shape: [d, d])

    Raises:
        ValueError: If eta is not positive, or if the attack direction is
            (numerically) zero or non-finite, e.g. when w_star_1 == w_star_2.
    """
    if eta <= 0:
        raise ValueError(f"eta must be positive, got {eta}")

    d = S_1.shape[0]

    m = w_star_1 - w_star_2
    v = S_1 @ m

    # Without this guard a zero direction yields P_U = 0/0 = NaN and a NaN S_2.
    v_norm = torch.norm(v).item()
    if not np.isfinite(v_norm) or v_norm <= delta:
        raise ValueError(
            "Degenerate attack direction: ||S_1 (w_star_1 - w_star_2)|| = "
            f"{v_norm:.3e} is not above the tolerance delta={delta:.3e}"
        )

    # Projector onto U = span{v}; normalising first avoids squaring a tiny norm.
    u = v / v_norm
    P_U = torch.outer(u, u)

    # S_1 is symmetric, so the Hermitian solver applies; it returns real
    # eigenvalues in ascending order, hence the last entry is the largest.
    lambda_max_S_1 = torch.linalg.eigvalsh(S_1)[-1].item()

    # Cap the top eigenvalue of S_2 so it stays reasonable relative to eta and S_1.
    target_max_eigenval = min(1.0 / eta, 2 * lambda_max_S_1)
    epsilon = 0.1 * target_max_eigenval

    # Match S_1's dtype/device so the sum below cannot hit a device mismatch.
    I = torch.eye(d, dtype=S_1.dtype, device=S_1.device)
    S_2 = target_max_eigenval * P_U + epsilon * (I - P_U)

    try:
        L = torch.linalg.cholesky(S_2)
    except RuntimeError:
        # torch's linear-algebra errors derive from RuntimeError. Fall back to an
        # eigendecomposition with eigenvalues clamped to stay positive.
        eigenvals, eigenvecs = torch.linalg.eigh(S_2)
        eigenvals = torch.clamp(eigenvals, min=1e-10)
        L = eigenvecs @ torch.diag(torch.sqrt(eigenvals))

    # A plain Python float keeps the result a tensor with L's dtype.
    X_2 = float(np.sqrt(n)) * L

    return X_2, S_2


def generate_adversarial_agent_data(X_1: torch.Tensor, y_1: torch.Tensor, w_star_1: torch.Tensor, 
                                  eta: float, R: Optional[torch.Tensor] = None, 
                                  device: str = 'cpu') -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build agent 2's adversarial dataset from agent 1's data.

    Agent 2's target is w_star_2 = R @ w_star_1. Its data matrix comes from
    adversarial_construction_line_space applied to S_1 = X_1 @ X_1.T, and its
    labels are the noiseless responses X_2.T @ w_star_2.

    Args:
        X_1: Agent 1's data matrix (shape: [d, n])
        y_1: Agent 1's target vector; accepted but not used by this function
        w_star_1: Agent 1's optimal weight (shape: [d])
        eta: Learning rate forwarded to the construction
        R: Transformation matrix defining w_star_2 (default: identity).
            NOTE: with the identity the two targets coincide, so the attack
            direction S_1 (w_star_1 - w_star_2) is zero; pass an explicit R.
        device: Device for the default R, also forwarded to the construction

    Returns:
        X_2: Adversarial data matrix for agent 2 (as returned by the construction)
        y_2: Target column vector for agent 2 (X_2.T @ w_star_2)
        w_star_2: Optimal weight for agent 2
    """
    dim, num_samples = X_1.shape

    # Unnormalised second moment of the victim's data.
    victim_second_moment = torch.matmul(X_1, X_1.T)

    transform = torch.eye(dim, device=device) if R is None else R
    w_star_2 = torch.matmul(transform, w_star_1)

    # The construction also returns S_2, which is not needed here.
    X_2, _ = adversarial_construction_line_space(
        victim_second_moment, w_star_1, w_star_2, num_samples, eta, device=device
    )

    target_column = w_star_2.unsqueeze(1)
    y_2 = torch.matmul(X_2.T, target_column)

    return X_2, y_2, w_star_2


def create_agents_for_experiment(agent_type: str, model_path: str = None, openai_model: str = "gpt-5-mini",
                                X1: Optional[torch.Tensor] = None, y1: Optional[torch.Tensor] = None,
                                X2: Optional[torch.Tensor] = None, y2: Optional[torch.Tensor] = None,
                                d: Optional[int] = None, learning_rate: Optional[float] = None, device: str = 'cpu'):
    """
    Build a homogeneous pair of agents (both of the same type) for one experiment.

    Args:
        agent_type: "model" (checkpoint-based CoT agents) or "openai"
        model_path: Checkpoint path; required when agent_type="model"
        openai_model: OpenAI model name; used when agent_type="openai"
        X1, y1: Agent 1's data; required when agent_type="openai"
        X2, y2: Agent 2's data; required when agent_type="openai"
        d: Problem dimension; required when agent_type="openai"
        learning_rate: Learning rate; required when agent_type="openai"
        device: Device string passed to the loader / agent factory

    Returns:
        tuple: (agent1, agent2, learning_rate, device). For "model" agents the
        learning rate and device are the ones reported by load_model; for
        "openai" agents the learning rate is the argument passed in and the
        device is the one reported by create_agent_model.

    Raises:
        ValueError: On an unknown agent_type or missing required arguments.
    """
    if agent_type not in ("model", "openai"):
        raise ValueError(f"Unknown agent_type: {agent_type}. Use 'model' or 'openai'.")

    if agent_type == "model":
        if model_path is None:
            raise ValueError("model_path is required for CoT model agents")

        # Load the checkpoint twice so the two agents are separate objects.
        first_agent, _, _, checkpoint_lr, device = load_model(model_path, device)
        second_agent, _, _, _, _ = load_model(model_path, device)
        return first_agent, second_agent, checkpoint_lr, device

    # agent_type == "openai" from here on.
    required_inputs = (X1, y1, X2, y2, d, learning_rate)
    if any(item is None for item in required_inputs):
        raise ValueError("X1, y1, X2, y2, d, and learning_rate are required for OpenAI agents")

    agents = []
    for features, targets in ((X1, y1), (X2, y2)):
        # A fresh config dict per agent, so the two agents never share one.
        agent_config = {"type": "openai", "model_name": openai_model}
        agent, _, _, _, device = create_agent_model(agent_config, X=features, y=targets, d=d, device=device)
        agents.append(agent)

    return agents[0], agents[1], learning_rate, device


def run_adversarial_attack_experiment(agent_type: str = "model", model_path: str = None, 
                                    openai_model: str = "gpt-5-mini", num_trials: int = 5, 
                                    R_type: str = 'opposite', d: int = None, learning_rate: float = None):
    """
    Run independent adversarial-attack trials with a homogeneous agent pair.

    Each trial draws a fresh dataset for agent 1 (the victim), builds agent 2's
    adversarial dataset, runs the agent interaction, records whether the attack
    succeeded and saves a per-trial plot under "<model_dir>/adversarial_plots".
    A trial that raises is reported and skipped.

    Args:
        agent_type: Either "model" (for CoT models) or "openai" (for OpenAI agents)
        model_path: Path to the trained model (required if agent_type="model");
            d, T, learning rate and device are then taken from the checkpoint
        openai_model: OpenAI model name (used if agent_type="openai")
        num_trials: Number of trials to run
        R_type: Type of transformation ('opposite', 'orthogonal', 'scaled')
        d: Problem dimension (required if agent_type="openai")
        learning_rate: Learning rate (required if agent_type="openai")

    Returns:
        List of per-trial dicts with keys 'trial' (zero-based), 'R_type',
        'w_star_1', 'w_star_2', 'interaction_result' and 'attack_success'.

    Raises:
        ValueError: On an unknown agent_type, missing required arguments, or an
            unknown R_type (the latter is raised inside the trial loop).
    """
    if agent_type not in ("model", "openai"):
        raise ValueError(f"Unknown agent_type: {agent_type}. Use 'model' or 'openai'.")

    uses_openai = agent_type == "openai"
    if uses_openai:
        if d is None or learning_rate is None:
            raise ValueError("d and learning_rate must be provided when agent_type='openai'")
        T, device = 20, 'cpu'
        model_dir = f"adversarial_openai_experiments_{openai_model.replace('-', '_')}"
    else:
        if model_path is None:
            raise ValueError("model_path is required when agent_type='model'")
        # Only the checkpoint's hyper-parameters are needed at this point.
        _, d, T, learning_rate, device = load_model(model_path)
        model_dir = os.path.dirname(model_path)

    print(f"Running adversarial attack experiment with {num_trials} trials")
    print(f"Agent type: {agent_type}")
    print(f"OpenAI model: {openai_model}" if uses_openai else f"Model path: {model_path}")
    print(f"Parameters: d={d}, T={T}, lr={learning_rate}")
    print(f"Transformation type: {R_type}")

    plots_dir = os.path.join(model_dir, "adversarial_plots")
    os.makedirs(plots_dir, exist_ok=True)

    file_prefix = "openai_" if uses_openai else ""
    results = []

    for trial_number in range(1, num_trials + 1):
        print(f"\n--- Trial {trial_number}/{num_trials} ---")

        # Victim (agent 1) dataset.
        X_1, y_1, _, _, max_step_1, w_star_1 = dataset.data_generation(d=d, n=T, lr=learning_rate)
        X_1, y_1, w_star_1 = X_1.to(device), y_1.to(device), w_star_1.to(device)

        # Transformation mapping the victim's target to the adversary's target.
        if R_type == 'scaled':
            R = 0.5 * torch.eye(d, device=device)
        elif R_type == 'orthogonal':
            # Q factor of a random Gaussian matrix is a random orthogonal matrix.
            R = torch.linalg.qr(torch.randn(d, d, device=device))[0]
        elif R_type == 'opposite':
            R = -torch.eye(d, device=device)
        else:
            raise ValueError(f"Unknown R_type: {R_type}")

        try:
            X_2, y_2, w_star_2 = generate_adversarial_agent_data(
                X_1, y_1, w_star_1, learning_rate, R, device
            )

            victim_agent, adversary_agent, _, _ = create_agents_for_experiment(
                agent_type, model_path, openai_model, X_1, y_1, X_2, y_2, d, learning_rate, device
            )

            # Both agents use the same step budget.
            interaction_result = run_agent_interaction(
                victim_agent, adversary_agent,
                X_1, y_1, w_star_1, max_step_1,
                X_2, y_2, w_star_2, max_step_1,
                d, learning_rate, device
            )

            attack_succeeded = check_attack_success(interaction_result, agent_type=agent_type)
            # Recorded before plotting, so a plotting failure does not drop the result.
            results.append({
                'trial': trial_number - 1,
                'R_type': R_type,
                'w_star_1': w_star_1,
                'w_star_2': w_star_2,
                'interaction_result': interaction_result,
                'attack_success': attack_succeeded
            })

            plot_name = f"{file_prefix}adversarial_attack_trial_{trial_number}_{R_type}.png"
            save_path = os.path.join(plots_dir, plot_name)
            plot_adversarial_result(interaction_result, trial_number, R_type, save_path, agent_type)

            print(f"Attack success: {attack_succeeded}")

        except Exception as e:
            # Best effort: report the failure and move on to the next trial.
            print(f"Trial {trial_number} failed: {e}")
            continue

    completed = len(results)
    successful_attacks = sum(1 for record in results if record['attack_success'])
    print("\n=== Adversarial Attack Summary ===")
    print(f"Successful attacks: {successful_attacks}/{completed}")
    if completed > 0:
        print(f"Success rate: {successful_attacks/completed*100:.1f}%")
    else:
        print("No successful trials completed.")

    return results


def save_statistical_data(results_by_type: Dict, save_path: str, agent_type: str, learning_rate: float, max_steps: int):
    """
    Save adversarial attack data (both statistical summaries and all raw trial data) to JSON file for later replotting.

    The payload has three top-level keys: 'results_by_type' (converted to plain
    Python containers), 'parameters' and 'metadata'.

    Args:
        results_by_type: Results grouped by R matrix type, including means, stds, and all individual trial data.
            NumPy arrays/scalars and torch tensors anywhere inside nested
            dicts, lists or tuples are converted to JSON-serializable values.
        save_path: Path to save the JSON file; missing parent directories are
            created, and a bare filename (no directory part) is accepted
        agent_type: Type of agents used
        learning_rate: Learning rate used
        max_steps: Maximum steps used
    """
    # Local import keeps the block self-contained without touching file-level imports.
    from datetime import datetime

    def convert_numpy_to_lists(obj):
        """Recursively replace array-like leaves with plain Python values."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalars such as np.float32 / np.int64 are not JSON serializable.
            return obj.item()
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        if isinstance(obj, dict):
            return {k: convert_numpy_to_lists(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [convert_numpy_to_lists(item) for item in obj]
        return obj

    serializable_data = {
        'results_by_type': convert_numpy_to_lists(results_by_type),
        'parameters': {
            'agent_type': agent_type,
            'learning_rate': learning_rate,
            'max_steps': max_steps
        },
        'metadata': {
            'saved_at': datetime.now().isoformat(),
            'file_version': '1.0',
            'description': 'Adversarial attack statistical data for replotting'
        }
    }

    # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(serializable_data, f, indent=2)

    print(f"Statistical data saved: {save_path}")


def check_attack_success(interaction_result: Dict, convergence_threshold: float = 1e-2, 
                        asymmetry_threshold: float = 5e-2, agent_type: str = "model") -> bool:
    """
    Decide whether an attack succeeded, i.e. the agents converged asymmetrically.

    The attack counts as successful when both hold for the final recorded distances:
    - Agent 2 (adversarial) ends strictly below the convergence threshold
    - Agent 1 (victim) ends strictly above the asymmetry threshold

    Args:
        interaction_result: Result dict from run_agent_interaction; must provide
            non-empty 'agent1_distances' and 'agent2_distances' sequences
        convergence_threshold: Cutoff for agent 2 counting as "converged".
            Ignored for agent_type="openai", where a fixed 1e-1 is used.
        asymmetry_threshold: Absolute floor agent 1's final distance must
            exceed (it is not compared against agent 2's distance)
        agent_type: Type of agents used ("model" or "openai")

    Returns:
        bool: True if attack was successful (asymmetric convergence achieved)
    """
    # OpenAI agents get a looser, fixed convergence cutoff.
    convergence_cutoff = 1e-1 if agent_type == "openai" else convergence_threshold

    victim_final = interaction_result['agent1_distances'][-1]
    adversary_final = interaction_result['agent2_distances'][-1]

    adversary_converged = adversary_final < convergence_cutoff
    victim_stuck = victim_final > asymmetry_threshold

    return adversary_converged and victim_stuck



def _pad_or_truncate(distances, length):
    """Return a copy of `distances` with exactly `length` entries, padding with the last value if short."""
    fitted = list(distances[:length])
    if len(fitted) < length:
        # An empty trajectory raises IndexError here; the caller treats that as a failed trial.
        fitted.extend([fitted[-1]] * (length - len(fitted)))
    return fitted


def generate_statistical_plots(agent_type: str = "model", model_path: str = None, 
                              openai_model: str = "gpt-5-mini", num_trials: int = 20, 
                              max_steps: int = 200, d: int = None, learning_rate: float = None):
    """
    Generate statistical plots with homogeneous agent pairs.

    For each transformation type ('orthogonal', 'scaled', 'opposite') this runs
    num_trials attack trials, aggregates the per-step distance trajectories of
    both agents into mean / std curves, plots them, and saves the aggregated
    data as JSON under "<model_dir>/statistical_plots". A trial that raises is
    reported and skipped.

    For "model" agents every trajectory is padded with its last value, or
    truncated, to exactly max_steps entries. Each agent's trajectory is fitted
    independently, so differing lengths cannot produce ragged arrays. For
    "openai" agents trajectories are used as returned.

    Args:
        agent_type: Either "model" (for CoT models) or "openai" (for OpenAI agents)
        model_path: Path to the trained model (required if agent_type="model");
            d, T, learning rate and device are then taken from the checkpoint
        openai_model: OpenAI model name (used if agent_type="openai")
        num_trials: Number of trials to average over
        max_steps: Maximum number of interaction steps to plot
        d: Problem dimension (required if agent_type="openai")
        learning_rate: Learning rate (required if agent_type="openai")

    Returns:
        Dict keyed by R type. Each value holds 'agent1_mean', 'agent1_std',
        'agent2_mean', 'agent2_std', 'agent1_all_trials', 'agent2_all_trials',
        'success_rate' and 'num_trials'. Types with no completed trial are absent.

    Raises:
        ValueError: On an unknown agent_type or missing required arguments.
    """
    if agent_type == "model":
        if model_path is None:
            raise ValueError("model_path is required when agent_type='model'")
        # Only the checkpoint's hyper-parameters are needed at this point.
        _, d, T, learning_rate, device = load_model(model_path)
        model_dir = os.path.dirname(model_path)
    elif agent_type == "openai":
        if d is None or learning_rate is None:
            raise ValueError("d and learning_rate must be provided when agent_type='openai'")
        T = 20
        device = 'cpu'
        model_dir = f"statistical_openai_experiments_{openai_model.replace('-', '_')}"
    else:
        raise ValueError(f"Unknown agent_type: {agent_type}. Use 'model' or 'openai'.")

    print(f"Generating statistical plots with {num_trials} trials...")
    print(f"Agent type: {agent_type}")
    if agent_type == "model":
        print(f"Model path: {model_path}")
    else:
        print(f"OpenAI model: {openai_model}")
    print(f"Parameters: d={d}, T={T}, lr={learning_rate}")

    plots_dir = os.path.join(model_dir, "statistical_plots")
    os.makedirs(plots_dir, exist_ok=True)

    R_types = ['orthogonal', 'scaled', 'opposite']
    results_by_type = {}

    for R_type in R_types:
        print(f"\n--- Processing R_type: {R_type} ---")

        # Distance trajectories of every completed trial.
        all_agent1_distances = []
        all_agent2_distances = []
        success_count = 0

        for trial in range(num_trials):
            if (trial + 1) % 5 == 0:
                print(f"  Trial {trial + 1}/{num_trials}")

            try:
                # Victim (agent 1) dataset.
                X_1, y_1, _, _, max_step_1, w_star_1 = dataset.data_generation(d=d, n=T, lr=learning_rate)
                X_1, y_1, w_star_1 = X_1.to(device), y_1.to(device), w_star_1.to(device)

                # Transformation mapping the victim's target to the adversary's target.
                if R_type == 'opposite':
                    R = -torch.eye(d, device=device)
                elif R_type == 'orthogonal':
                    Q, _ = torch.linalg.qr(torch.randn(d, d, device=device))
                    R = Q
                elif R_type == 'scaled':
                    R = 0.5 * torch.eye(d, device=device)

                X_2, y_2, w_star_2 = generate_adversarial_agent_data(
                    X_1, y_1, w_star_1, learning_rate, R, device
                )

                # Both agents use the same step budget.
                max_step_2 = max_step_1

                model_agent1, model_agent2, _, _ = create_agents_for_experiment(
                    agent_type, model_path, openai_model, X_1, y_1, X_2, y_2, d, learning_rate, device
                )

                interaction_result = run_agent_interaction(
                    model_agent1, model_agent2,
                    X_1, y_1, w_star_1, max_step_1,
                    X_2, y_2, w_star_2, max_step_2,
                    d, learning_rate, device
                )

                agent1_distances = interaction_result['agent1_distances']
                agent2_distances = interaction_result['agent2_distances']

                if agent_type != "openai":
                    # Fit each trajectory on its own length; the copies leave
                    # interaction_result itself untouched.
                    agent1_distances = _pad_or_truncate(agent1_distances, max_steps)
                    agent2_distances = _pad_or_truncate(agent2_distances, max_steps)

                all_agent1_distances.append(agent1_distances)
                all_agent2_distances.append(agent2_distances)

                if check_attack_success(interaction_result, agent_type=agent_type):
                    success_count += 1

            except Exception as e:
                # Best effort: report the failure and move on to the next trial.
                print(f"    Trial {trial + 1} failed: {e}")
                continue

        completed = len(all_agent1_distances)
        if completed > 0:
            agent1_distances_array = np.array(all_agent1_distances)
            agent2_distances_array = np.array(all_agent2_distances)

            # Per-step statistics across trials (axis 0 indexes trials).
            results_by_type[R_type] = {
                'agent1_mean': np.mean(agent1_distances_array, axis=0),
                'agent1_std': np.std(agent1_distances_array, axis=0),
                'agent2_mean': np.mean(agent2_distances_array, axis=0),
                'agent2_std': np.std(agent2_distances_array, axis=0),
                'agent1_all_trials': all_agent1_distances,
                'agent2_all_trials': all_agent2_distances,
                'success_rate': success_count / completed,
                'num_trials': completed
            }

            print(f"  Completed: {completed} trials, Success rate: {success_count/completed*100:.1f}%")
        else:
            print(f"  No successful trials for {R_type}")

    create_adversarial_statistics_plots(results_by_type, max_steps, plots_dir, learning_rate, agent_type)

    agent_tag = "openai_" if agent_type == "openai" else ""
    stats_data_path = os.path.join(plots_dir, f"{agent_tag}adversarial_statistics_lr{learning_rate}_data.json")
    save_statistical_data(results_by_type, stats_data_path, agent_type, learning_rate, max_steps)

    return results_by_type




def main():
    """
    Command-line entry point.

    Parses the arguments, then either generates the statistical plots (the
    default) or, with --no-statistical, runs individual attack trials for a
    single R type. A short summary is printed in both cases.

    Raises:
        ValueError: If --agent_type openai is given without --d and --learning_rate.
    """
    parser = argparse.ArgumentParser(description="Adversarial Multi-Agent Attack with Homogeneous Agent Pairs")

    # (flags, options) pairs, registered in order so --help keeps its layout.
    argument_specs = [
        # Agent type selection
        (("--agent_type",), dict(type=str, default="model",
                                 choices=["model", "openai"],
                                 help="Type of agents (both agents will be this type)")),
        # Model-specific arguments
        (("--model_path",), dict(type=str,
                                 default="experiments/exp_20250909_081624_lr0.005_adam_ep100_d10_T20_nds100_bs512_neval10/model_latest.pth",
                                 help="Path to the saved model checkpoint (required if agent_type='model')")),
        # OpenAI-specific arguments
        (("--openai_model",), dict(type=str, default="gpt-5-mini",
                                   help="OpenAI model name (used if agent_type='openai')")),
        (("--d",), dict(type=int, default=None,
                        help="Problem dimension (required if agent_type='openai')")),
        (("--learning_rate",), dict(type=float, default=None,
                                    help="Learning rate (required if agent_type='openai')")),
        # Experiment parameters
        (("--num_trials",), dict(type=int, default=100,
                                 help="Number of adversarial attack trials to run")),
        (("--R_type",), dict(type=str, default="opposite",
                             choices=["opposite", "orthogonal", "scaled"],
                             help="Type of transformation matrix R")),
        (("--no-statistical",), dict(action="store_true",
                                     help="Skip statistical plots (run individual trials only)")),
        (("--max_steps",), dict(type=int, default=400,
                                help="Maximum steps for statistical plots")),
    ]
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()

    # OpenAI agents have no checkpoint to read d / learning rate from.
    openai_args_missing = args.d is None or args.learning_rate is None
    if args.agent_type == "openai" and openai_args_missing:
        raise ValueError("--d and --learning_rate must be provided when agent_type='openai'")

    # Keyword arguments common to both experiment drivers.
    shared_kwargs = dict(
        agent_type=args.agent_type,
        model_path=args.model_path,
        openai_model=args.openai_model,
        num_trials=args.num_trials,
        d=args.d,
        learning_rate=args.learning_rate,
    )

    if args.no_statistical:
        print("Running adversarial attack experiments...")
        results = run_adversarial_attack_experiment(R_type=args.R_type, **shared_kwargs)

        total_trials = len(results)
        successful_attacks = sum(1 for record in results if record['attack_success'])
        print("\n=== Final Results ===")
        print(f"Total trials: {total_trials}")
        print(f"Successful attacks: {successful_attacks}")
        if total_trials > 0:
            print(f"Success rate: {successful_attacks/total_trials*100:.1f}%")
        else:
            print("No trials completed successfully.")
        return

    print("Generating statistical plots ...")
    results = generate_statistical_plots(max_steps=args.max_steps, **shared_kwargs)

    print("\n=== Statistical Analysis Complete ===")
    for R_type, result in results.items():
        if not result:
            continue
        print(f"{R_type.title()}: {result['num_trials']} trials, "
              f"Success rate: {result['success_rate']:.1%}")


def load_and_replot_statistical_data(data_path: str, output_dir: str = None):
    """
    Regenerate the statistics plots from a previously saved JSON data file.

    The file must contain 'results_by_type' and 'parameters' (with
    'agent_type', 'learning_rate' and 'max_steps'). Every JSON list inside the
    results, including the per-trial lists, is turned into a NumPy array before
    plotting.

    Args:
        data_path: Path to the saved JSON statistical data file
        output_dir: Directory for the regenerated plots; defaults to the
            directory containing data_path
    """
    from plot_utils import create_adversarial_statistics_plots

    def restore_arrays(node):
        """Recursively turn JSON lists into NumPy arrays, descending through dicts only."""
        if isinstance(node, list):
            return np.array(node)
        if isinstance(node, dict):
            return {key: restore_arrays(value) for key, value in node.items()}
        return node

    with open(data_path, 'r') as handle:
        saved = json.load(handle)

    raw_results = saved['results_by_type']
    parameters = saved['parameters']
    agent_type = parameters['agent_type']
    learning_rate = parameters['learning_rate']
    max_steps = parameters['max_steps']

    results_by_type = restore_arrays(raw_results)

    # Default to writing the plots next to the data file.
    target_dir = os.path.dirname(data_path) if output_dir is None else output_dir

    create_adversarial_statistics_plots(results_by_type, max_steps, target_dir, learning_rate, agent_type)
    print(f"Statistical plots regenerated in: {target_dir}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
