import torch
import torch.nn as nn
import numpy as np
import os
import re
import model
from typing import Optional, Tuple, Dict, List
from openai_generate import create_openai_agent


def compute_theoretical_errors_theorem1(X1, y1, X2, y2, w_star1, w_star2, eta, device):
    """
    Implement Theorem 1 to compute theoretical plateau errors.

    In the math notation:
    - w (W) corresponds to agent 1
    - γ (Gamma) corresponds to agent 2
    - S_W = X1 @ X1.T (agent 1's second moment matrix)
    - S_Γ = X2 @ X2.T (agent 2's second moment matrix)
    - w* = w_star1, γ* = w_star2
    - Δ = γ* - w* = w_star2 - w_star1

    The plateau errors are the Euclidean norms of
    ``pinv(S_W + S_Γ) @ S_Γ @ Δ`` (agent 1) and
    ``pinv(S_W + S_Γ) @ S_W @ Δ`` (agent 2).

    Args:
        X1: Agent 1's feature matrix (corresponds to W in math); used as
            ``X1 @ X1.T``, so features are expected along the first axis.
        y1: Agent 1's targets. Unused here; kept for interface compatibility.
        X2: Agent 2's feature matrix (corresponds to Γ in math).
        y2: Agent 2's targets. Unused here; kept for interface compatibility.
        w_star1: Agent 1's optimal weight (corresponds to w* in math).
        w_star2: Agent 2's optimal weight (corresponds to γ* in math).
        eta: Learning rate (step size). Unused: the result does not depend on it.
        device: PyTorch device. Unused: results live wherever the inputs live.

    Returns:
        dict: ``plateau_error_agent1_boxed`` and ``plateau_error_agent2_boxed``
        as Python floats, plus the intermediate tensors ``S_Gamma``, ``S_W``,
        ``S`` and ``Delta``.
    """
    # Per-agent second moment matrices and their sum.
    S_1 = X1 @ X1.T
    S_2 = X2 @ X2.T
    S = S_1 + S_2

    # Gap between the two agents' optimal weights.
    Delta = w_star2 - w_star1

    # Pseudo-inverse rather than inverse, so a singular S does not raise.
    S_inv = torch.pinverse(S)

    # Each agent's error vector is weighted by the *other* agent's matrix.
    z_1 = S_inv @ S_2 @ Delta
    z_2 = S_inv @ S_1 @ Delta

    return {
        'plateau_error_agent1_boxed': torch.norm(z_1).item(),
        'plateau_error_agent2_boxed': torch.norm(z_2).item(),
        'S_Gamma': S_2,
        'S_W': S_1,
        'S': S,
        'Delta': Delta,
    }


def extract_learning_rate_from_path(model_path: str) -> float:
    """
    Extract learning rate from model path.

    Expected format: *_lr{value}_* where ``value`` is a non-negative decimal
    such as ``0.01``, ``.5`` or ``3``, optionally followed by an exponent
    (``1e-05``, ``5E-4``). The exponent form matters because Python renders
    small floats that way, so paths built from ``f"_lr{lr}_"`` contain it.

    Args:
        model_path: Path or file name containing the ``_lr{value}_`` token.

    Returns:
        The learning rate parsed from the first matching token.

    Raises:
        ValueError: If no ``_lr{value}_`` token is present in ``model_path``.
    """
    match = re.search(r'_lr([0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?)_', model_path)
    if match is None:
        raise ValueError(f"Could not extract learning rate from path: {model_path}")
    return float(match.group(1))


def load_model(model_path: str, device: Optional[str] = None) -> tuple:
    """
    Restore a CoT model from a checkpoint file.

    The learning rate and T are parsed from the file path, while the problem
    dimension d is derived from the first dimension of the 'layer.W' weight
    stored in the checkpoint (``shape[0] // 2 - 1``).

    Args:
        model_path: Checkpoint location. Must contain a ``_lr{value}_`` token;
            a ``_T{int}_`` token is optional and T falls back to 10 without it.
        device: Target device string. When None, 'cuda' is chosen if
            available, otherwise 'cpu'.

    Returns:
        tuple: (model, d, T, learning_rate, device)

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        KeyError: If the checkpoint has no 'model_state_dict' entry, or the
            state dict has no 'layer.W' entry.
    """
    target_device = device
    if target_device is None:
        target_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found at: {model_path}")

    # Parsed before the checkpoint is read, so a badly named file fails fast.
    learning_rate = extract_learning_rate_from_path(model_path)

    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=target_device)
    if 'model_state_dict' not in checkpoint:
        raise KeyError("Checkpoint does not contain 'model_state_dict'")
    weights = checkpoint['model_state_dict']

    if 'layer.W' not in weights:
        raise KeyError("Model state dict does not contain expected 'layer.W' parameter")

    # Recover the problem dimension from the stored weight matrix.
    leading_dim = weights['layer.W'].shape[0]
    d = leading_dim // 2 - 1

    # T is not stored in the checkpoint; read it from the path when present.
    t_token = re.search(r'_T([0-9]+)_', model_path)
    if t_token:
        T = int(t_token.group(1))
    else:
        T = 10

    # Rebuild the network, restore its weights, and switch to eval mode.
    net = model.CoT(data_num=T, d=d)
    net.load_state_dict(weights)
    net.to(target_device)
    net.eval()

    print(f"Model loaded: d={d}, T={T}, lr={learning_rate}, device={target_device}")

    return net, d, T, learning_rate, target_device


def create_openai_model(openai_model_name: str, X: torch.Tensor, y: torch.Tensor, d: int, device: str = 'cpu') -> tuple:
    """
    Build an OpenAI-backed agent and return it in the same 5-tuple layout
    that load_model uses.

    All arguments are forwarded unchanged to ``create_openai_agent``.

    Args:
        openai_model_name: Name of OpenAI model (e.g., "gpt-4o-mini")
        X: Input features tensor
        y: Target values tensor
        d: Problem dimension
        device: Device string

    Returns:
        tuple: (openai_wrapper, d, None, None, device). The T and
               learning_rate slots are always None for OpenAI models.
    """
    agent = create_openai_agent(openai_model_name, X, y, d, device)
    print(f"OpenAI model created: model={openai_model_name}, d={d}, device={device}")

    # Placeholders keep the tuple shape aligned with load_model's return.
    T, learning_rate = None, None
    return agent, d, T, learning_rate, device


def create_agent_model(agent_config: Dict, X: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, 
                      d: Optional[int] = None, device: str = 'cpu') -> tuple:
    """
    Build an agent from a config dict, dispatching on its "type" entry.

    Args:
        agent_config: Agent description. Supported forms:
                     {"type": "model", "path": "path/to/model.pth"}
                     {"type": "openai", "model_name": "gpt-4o-mini"}
        X: Input features tensor (only needed for "openai")
        y: Target values tensor (only needed for "openai")
        d: Problem dimension (only needed for "openai")
        device: Device string

    Returns:
        tuple: (agent, d, T, learning_rate, device)
               T and learning_rate are None for OpenAI agents.

    Raises:
        ValueError: If the type is neither "model" nor "openai", or if the
            type is "openai" and any of X, y, d is None.
    """
    kind = agent_config["type"]

    # Checkpoint-backed agent: everything else comes from the file itself.
    if kind == "model":
        return load_model(agent_config["path"], device)

    if kind != "openai":
        raise ValueError(f"Unknown agent type: {agent_config['type']}. Use 'model' or 'openai'.")

    # OpenAI agents have no checkpoint, so the data must be supplied here.
    if any(arg is None for arg in (X, y, d)):
        raise ValueError("X, y, and d are required for OpenAI models")

    return create_openai_model(agent_config["model_name"], X, y, d, device)


def _build_agent_input(Z, w_history_tensor, d, device):
    """Assemble the block-structured input matrix passed to one agent.

    The top ``d + 1`` rows carry the agent's data ``Z`` in the leading columns;
    the bottom ``d + 1`` rows carry the weight history, with a row of ones
    appended, in the trailing columns. All remaining entries are zero.
    """
    n_data = Z.shape[1]
    n_hist = w_history_tensor.shape[1]
    bias_row = torch.ones((1, n_hist), device=device)
    ww = torch.cat((w_history_tensor, bias_row), dim=0)

    train_data = torch.zeros((2 * (d + 1), n_data + n_hist), device=device)
    train_data[:d + 1, :n_data] = Z
    train_data[d + 1:, n_data:] = ww
    return train_data


def run_agent_interaction(model_agent1, model_agent2, X1, y1, w_star1, max_step1, X2, y2, w_star2, max_step2, d, learning_rate, device):
    """
    Run the interaction between two agents and return the results.

    The agents take turns on a single shared weight vector that starts at
    zero. On each turn the agent is called on an input built from its own data
    and the shared weight history, its output is treated as a gradient, and
    the shared weight is updated as ``w <- w - learning_rate * g``. Agent 1
    always moves first within a step.

    Args:
        model_agent1, model_agent2: Callables that map the assembled input
            matrix to a gradient prediction. If either one has a
            ``model_name`` attribute the run is treated as OpenAI-backed:
            it is capped at 40 steps and progress is printed every step.
        X1, y1: Agent 1's data, stacked as ``torch.cat((X1, y1.T))``.
            Assumes X1 is (d, n) and y1 is (n, 1) - TODO confirm with caller.
        w_star1: Agent 1's optimal weight, used for the distance log.
        max_step1, max_step2: Per-agent step budgets; the loop runs for
            ``max(max_step1, max_step2) + 400`` steps unless capped.
        X2, y2, w_star2: Same as above, for agent 2.
        d: Problem dimension.
        learning_rate: Numeric step size for the shared update. Must not be
            None (the OpenAI factory in this file returns None for it).
        device: PyTorch device for the tensors created here.

    Returns:
        dict: Dictionary containing interaction results:
            'agent1_distances' / 'agent2_distances': per-step distance of the
                shared weight to that agent's w_star, taken right after the
                agent's own update;
            'g1_predictions' / 'g2_predictions': per-step gradient predictions;
            'max_step1', 'max_step2', 'w_star1', 'w_star2': echoed inputs;
            'angle_deg': angle between w_star1 and w_star2 in degrees;
            'theoretical_results': output of compute_theoretical_errors_theorem1;
            'sample_count': always 0.
    """
    shared_w_history = []

    # Track distances from respective optimal weights (w_star) for each agent
    agent1_distances = []
    agent2_distances = []

    # Track gradient predictions for both agents
    g1_predictions = []
    g2_predictions = []

    # Initialize current weight
    w_current = torch.zeros(d, 1, device=device)

    using_openai = hasattr(model_agent1, 'model_name') or hasattr(model_agent2, 'model_name')

    max_steps_for_interaction = max(max_step1, max_step2) + 400
    if using_openai:
        # Cap the run at 40 steps: every step costs two OpenAI API calls.
        max_steps_for_interaction = min(40, max_steps_for_interaction)
        # Report the cap actually applied (the message used to say 100).
        print(f"Using OpenAI agents - limiting interaction to {max_steps_for_interaction} steps for faster execution")

    # Run interaction loop
    print(f"Starting agent interaction loop for {max_steps_for_interaction} steps...")

    # Adjust print frequency based on total steps
    print_frequency = 1 if using_openai else 1000

    # The stacked data matrices never change, so build them once up front.
    Z1 = torch.cat((X1, y1.T))
    Z2 = torch.cat((X2, y2.T))

    # Per-agent bundle, in the order the agents take their turns.
    turns = (
        (1, model_agent1, Z1, w_star1, g1_predictions, agent1_distances),
        (2, model_agent2, Z2, w_star2, g2_predictions, agent2_distances),
    )

    for step in range(max_steps_for_interaction):
        verbose = step % print_frequency == 0
        if verbose:
            print(f"  Interaction step {step}/{max_steps_for_interaction}")

        for agent_id, agent, Z, w_star, g_log, dist_log in turns:
            # The history is empty only on the very first turn; the agent
            # then sees the initial zero weight instead.
            if shared_w_history:
                w_history_tensor = torch.cat(shared_w_history, dim=1)
            else:
                w_history_tensor = w_current

            train_data = _build_agent_input(Z, w_history_tensor, d, device)

            if verbose:
                print(f"    Agent {agent_id} predicting gradient at step {step}...")
            with torch.no_grad():
                g_pred = agent(train_data)
            if verbose:
                print(f"    Agent {agent_id} gradient predicted: {g_pred.shape}")

            # Store gradient prediction, then take the shared update step
            g_log.append(g_pred.clone())
            w_current = w_current - learning_rate * g_pred.unsqueeze(1)
            shared_w_history.append(w_current.clone())

            # Distance from this agent's optimal weight (w_star)
            dist_log.append(torch.norm(w_current.squeeze() - w_star).item())

    # Compute angle between optimal weights (w_star)
    cosine_sim = torch.dot(w_star1, w_star2) / (torch.norm(w_star1) * torch.norm(w_star2))
    # Clamp so rounding error cannot push the value outside acos's domain.
    cosine_sim = torch.clamp(cosine_sim, -1.0, 1.0)
    angle_deg = np.degrees(torch.acos(cosine_sim).item())

    # Compute theoretical errors using Theorem 1
    theoretical_results = compute_theoretical_errors_theorem1(
        X1, y1, X2, y2, w_star1, w_star2, learning_rate, device
    )

    return {
        'agent1_distances': agent1_distances,
        'agent2_distances': agent2_distances,
        'g1_predictions': g1_predictions,
        'g2_predictions': g2_predictions,
        'max_step1': max_step1,
        'max_step2': max_step2,
        'w_star1': w_star1,
        'w_star2': w_star2,
        'angle_deg': angle_deg,
        'theoretical_results': theoretical_results,
        'sample_count': 0
    }


def find_convergence_step(g1_predictions, g2_predictions, threshold=1e-4):
    """
    Locate the first step at which both agents' gradients stop changing.

    Consecutive predictions are compared per agent. The first index ``step``
    at which both norms of the change from ``step - 1`` to ``step`` fall
    strictly below ``threshold`` is reported as ``step + 1``.

    Args:
        g1_predictions: Sequence of agent 1's gradient tensors, one per step.
        g2_predictions: Sequence of agent 2's gradient tensors, one per step.
        threshold: Upper bound (exclusive) on the change norm.

    Returns:
        ``step + 1`` for the first qualifying step, or None when either
        sequence is empty, agent 1 has fewer than two entries, or no step
        qualifies. Only the common prefix of the two sequences is scanned.
    """
    if not g1_predictions or not g2_predictions:
        return None
    if len(g1_predictions) < 2:
        return None

    # zip stops at the shortest input, i.e. after min(len1, len2) - 1 pairs.
    consecutive = zip(g1_predictions, g1_predictions[1:], g2_predictions, g2_predictions[1:])
    for step, (g1_prev, g1_cur, g2_prev, g2_cur) in enumerate(consecutive, start=1):
        change_1 = torch.norm(g1_cur - g1_prev).item()
        change_2 = torch.norm(g2_cur - g2_prev).item()
        if max(change_1, change_2) < threshold:
            return step + 1

    return None


def compute_angle_between_objectives(w_star1, w_star2):
    """
    Return the angle, in degrees, between two objective vectors.

    The angle is obtained from the arccosine of the cosine similarity of the
    two 1-D tensors.
    """
    inner = torch.dot(w_star1, w_star2)
    norm_product = torch.norm(w_star1) * torch.norm(w_star2)

    # Clamp so rounding error cannot push the value outside acos's domain.
    cosine = torch.clamp(inner / norm_product, -1.0, 1.0)

    angle_rad = torch.acos(cosine).item()
    return np.degrees(angle_rad)
