import numpy as np
import torch
from typing import List, Union, Optional
from openai import OpenAI
from pydantic import BaseModel, Field
import os


class GradientResponse(BaseModel):
    # Structured-output schema for one gradient prediction returned by the LLM.
    # Documented with `#` comments instead of a docstring so the JSON schema
    # pydantic generates for the API request is not altered.

    # Free-form scratchpad the model may use before committing to an answer.
    thinking: str = Field(
        ...,
        description="Use this field as scratchpad to write your thoughts and calculations for the gradient descent algorithm.",
    )

    # The predicted gradient, one float per feature.
    gradient_next: List[float] = Field(
        ...,
        description="The d-dimensional gradient vector ∇L = X(X^T w - y). Must have exactly d elements, where d is the number of features.",
    )

class GradientAgent:
    """LLM-backed oracle that predicts the linear-regression gradient X(X^T w - y).

    The dataset (X, y) is fixed at construction time and is re-sent, together
    with the current weights and their history, on every call.
    """

    def __init__(
        self,
        model_name: str,
        X: Union[List[List[float]], np.ndarray],
        y: Union[List[List[float]], np.ndarray],
        *,
        reasoning_effort: Optional[str] = "low",
        max_retries: int = 3,
    ):
        """Store the dataset, create the OpenAI client and build the system prompt.

        Args:
            model_name: OpenAI model identifier passed to the Responses API.
            X: Feature matrix of shape (d, n).
            y: Target values; flattened before being sent to the model.
            reasoning_effort: Value sent as ``reasoning={"effort": ...}``.
                Pass ``None`` to omit the ``reasoning`` argument entirely
                (e.g. for models that do not accept it).
            max_retries: Maximum number of API attempts per call (>= 1).

        Raises:
            ValueError: If ``max_retries`` < 1, ``X`` is not 2-D, or the
                ``OPENAI_API_KEY`` environment variable is not set.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be at least 1, got {max_retries}")

        self.model_name = model_name
        self.reasoning_effort = reasoning_effort
        self.max_retries = max_retries
        self.X = np.asarray(X)
        self.y = np.asarray(y)

        # Fail with a clear message instead of a cryptic tuple-unpacking error.
        if self.X.ndim != 2:
            raise ValueError(f"X must be a 2-D (d x n) matrix, got shape {self.X.shape}")

        # Initialize OpenAI client
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is required")

        self.client = OpenAI(api_key=api_key)

        # Get dimensions
        d, n = self.X.shape

        # NOTE: the loss is stated with a 1/2 factor so that it is consistent
        # with the formula ∇L = X(X^T w - y) the model is asked to apply; the
        # gradient of the un-halved squared norm would be 2·X(X^T w - y).
        self.system_prompt = f"""You are an expert optimization agent working on linear regression gradient descent.

        PROBLEM SETUP:
        - Input features X: {d}×{n} matrix (values provided in each request)
        - Target values y: {n}-dimensional vector (values provided in each request)
        - Current weight w: {d}-dimensional vector (what you'll receive)

        TASK: Calculate the gradient ∇L with respect to w, where L = (1/2)||X^T w - y||²

        FORMULA: ∇L = X(X^T w - y)
        - X^T w produces an {n}-dimensional vector (predictions)
        - X^T w - y produces an {n}-dimensional vector (residuals)  
        - X @ (residuals) produces a {d}-dimensional vector (gradient)

        CRITICAL: 
        1. Use the EXACT X and y matrices provided in each request
        2. Your output gradient must be exactly {d}-dimensional
        3. Do NOT make up dummy data - use the actual matrices given
        4. Perform the calculation step by step

        The user will provide w_current and the matrices X, y. Calculate and return the {d}-dimensional gradient vector, do not ask the user to validate what is to be done. The user will not be able to interact with you. Be highly precise and accurate on your computations, you will be evaluated on the distance with the ground truth gradient."""

        self.chat_messages = [{"role": "system", "content": self.system_prompt}]

    def __call__(self, w: Union[List[float], np.ndarray], w_history: Union[List[float], np.ndarray]) -> np.ndarray:
        """Ask the model for the gradient at ``w`` and return it as an array.

        Each attempt is validated (parsed, non-empty, exactly d finite values);
        an invalid reply or an API error triggers a retry, up to
        ``self.max_retries`` attempts in total.

        Args:
            w: Current weight vector (d values).
            w_history: Previously visited weights, included in the prompt as context.

        Returns:
            The predicted gradient as a float array with d elements.

        Raises:
            ValueError: If the last attempt still yields an unusable gradient.
            Exception: The underlying client error, re-raised unchanged, if the
                last attempt fails inside the API call.
        """
        w_list = np.asarray(w).tolist()
        w_history_list = np.asarray(w_history).tolist()
        d = self.X.shape[0]

        user_message = f"""Calculate the gradient for:
                    w_current = {w_list} (a {d}-dimensional vector)
                    X = {self.X.tolist()}   
                    y = {self.y.flatten().tolist()}
        

                    note that you also have access to the history of the predicted weights up to w_current:
                    w_history = {w_history_list} 
                    You will for sure do all the require computations without asking me to validate and return the result, that is return exactly {d} gradient values using the formula ∇L = X(X^T w - y). In your message response, it is mandatory to output a gradient vector."""

        request_kwargs = {
            "model": self.model_name,
            "input": self.chat_messages + [{"role": "user", "content": user_message}],
            "text_format": GradientResponse,
        }
        if self.reasoning_effort is not None:
            request_kwargs["reasoning"] = {"effort": self.reasoning_effort}

        for attempt in range(1, self.max_retries + 1):
            is_last_attempt = attempt == self.max_retries

            # Only the network call sits inside the try block, so validation
            # failures below are never misreported as API errors.
            try:
                response = self.client.responses.parse(**request_kwargs)
            except Exception as e:
                print(f"Attempt {attempt}: Error calling OpenAI API: {e}")
                if is_last_attempt:
                    print("All retry attempts failed.")
                    raise
                print("Retrying...")
                continue

            parsed = response.output_parsed
            gradient = None if parsed is None else parsed.gradient_next

            if parsed is None:
                log_msg = "OpenAI response parsing failed."
                final_msg = "OpenAI failed to parse response after all retries"
            elif gradient is None or len(gradient) == 0:
                log_msg = "OpenAI returned empty gradient."
                final_msg = "OpenAI returned empty gradient after all retries"
            elif len(gradient) != d:
                log_msg = f"OpenAI returned gradient with wrong size: {len(gradient)} (expected {d})"
                final_msg = f"OpenAI returned wrong gradient size after all retries: {len(gradient)} != {d}"
            else:
                gradient_array = np.asarray(gradient, dtype=float)
                if np.all(np.isfinite(gradient_array)):
                    return gradient_array
                # A NaN/inf component would silently corrupt every later
                # descent step, so treat it like any other invalid reply.
                log_msg = "OpenAI returned non-finite gradient values."
                final_msg = "OpenAI returned non-finite gradient values after all retries"

            print(f"Attempt {attempt}: {log_msg}")
            if is_last_attempt:
                raise ValueError(final_msg)
            print("Retrying...")


class OpenAICoTWrapper:
    """
    Wrapper class that mimics the model.CoT interface but uses OpenAI for gradient prediction.
    This allows seamless integration with existing agent_to_agent_flow.py code.
    """

    def __init__(self, model_name: str, X: torch.Tensor, y: torch.Tensor, d: int, device: str = 'cpu'):
        """
        Initialize OpenAI wrapper with dataset information.

        Args:
            model_name: OpenAI model name. NOTE(review): the underlying
                GradientAgent sends a ``reasoning`` option with each request,
                so a model that accepts it is presumably required — confirm.
            X: Input features tensor (d x n)
            y: Target values tensor (n,)
            d: Dimension of the problem
            device: Device on which returned gradient tensors are created
                (not used by OpenAI itself)
        """
        self.device = device
        self.d = d
        self.model_name = model_name

        self.agent = GradientAgent(
            model_name=model_name,
            X=self._to_numpy(X),
            y=self._to_numpy(y),
        )

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a NumPy array if it is a tensor, else unchanged."""
        if isinstance(value, torch.Tensor):
            # detach(): tensors that require grad cannot be exported directly;
            # cpu(): tensors on an accelerator must be copied to host memory.
            return value.detach().cpu().numpy()
        return value

    def forward(self, train_data: torch.Tensor) -> torch.Tensor:
        """
        Forward pass that mimics model.CoT.forward() interface.

        Args:
            train_data: Complex training data tensor from agent_to_agent_flow.py
                       Shape: (2*(d+1), n_data + n_history)
                       First d+1 rows: [X; y^T] (data)
                       Last d+1 rows: [w_history; bias_row] (weight history)

        Returns:
            torch.Tensor: Gradient prediction (d,) as float32 on ``self.device``
        """
        weight_rows = train_data[self.d + 1:, :]  # Shape: (d+1, n_data + n_history)

        # Last column (without the bias row) is the current iterate; every
        # earlier column is history. Both are handed to the agent as NumPy
        # arrays so it never has to deal with device-resident tensors.
        w_current_np = self._to_numpy(weight_rows[:-1, -1])
        w_history_np = self._to_numpy(weight_rows[:-1, :-1])

        if w_current_np.shape[0] != self.d:
            print(f"Warning: w shape mismatch. Expected {self.d}, got {w_current_np.shape[0]}")

        gradient_np = self.agent(w_current_np, w_history_np)

        if gradient_np.shape[0] != self.d:
            print(f"Warning: gradient shape mismatch. Expected {self.d}, got {gradient_np.shape[0]}")

        # Convert back to torch tensor on the correct device
        return torch.tensor(gradient_np, dtype=torch.float32, device=self.device)

    def __call__(self, train_data: torch.Tensor) -> torch.Tensor:
        """Allow calling the wrapper like a function, similar to PyTorch modules."""
        return self.forward(train_data)

    def to(self, device: str):
        """For compatibility with PyTorch model interface: record the target device."""
        self.device = device
        return self

    def eval(self):
        """For compatibility with PyTorch model interface: no-op that returns self."""
        return self


def create_openai_agent(model_name: str, X: torch.Tensor, y: torch.Tensor, d: int, device: str = 'cpu') -> OpenAICoTWrapper:
    """
    Build an OpenAICoTWrapper from the given dataset and settings.

    Args:
        model_name: OpenAI model name
        X: Input features tensor
        y: Target values tensor
        d: Problem dimension
        device: Device string

    Returns:
        OpenAICoTWrapper: Wrapper that can be used as drop-in replacement for model.CoT
    """
    # Every argument is forwarded unchanged, by keyword, to the wrapper.
    wrapper_kwargs = {
        "model_name": model_name,
        "X": X,
        "y": y,
        "d": d,
        "device": device,
    }
    return OpenAICoTWrapper(**wrapper_kwargs)
