import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random

# --- 1. Transformer Model ---
class SimplifiedTransformer(nn.Module):
    """One-block transformer that regresses a scalar from the last token.

    Pipeline: self-attention -> residual + LayerNorm -> two-layer ReLU
    feedforward -> residual + LayerNorm -> linear read-out applied to the
    final sequence position only.

    Args:
        embed_dim: Size of every token vector (last axis of the input).
        num_heads: Head count handed to ``nn.MultiheadAttention``.
        ff_dim: Hidden width of the feedforward sub-network.
    """

    def __init__(self, embed_dim, num_heads=1, ff_dim=128):
        super().__init__()
        self.embed_dim = embed_dim

        # Attention sub-layer; batch_first=True means inputs are
        # (batch, seq, embed) rather than (seq, batch, embed).
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(embed_dim)

        # Position-wise MLP: embed_dim -> ff_dim -> embed_dim.
        expand = nn.Linear(embed_dim, ff_dim)
        contract = nn.Linear(ff_dim, embed_dim)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Read-out head mapping one token vector to one scalar.
        self.output_layer = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return the scalar prediction for the last token of each sequence.

        Args:
            x: Tensor laid out as (batch_size, seq_len, embed_dim).

        Returns:
            Tensor of shape (batch_size, 1).
        """
        # Attention block with residual connection (post-norm).
        attended, _ = self.attention(x, x, x)
        hidden = self.norm1(x + attended)

        # Feedforward block with residual connection (post-norm).
        hidden = self.norm2(hidden + self.ffn(hidden))

        # Only the final position (the query token) is read out.
        return self.output_layer(hidden[:, -1, :])

# --- 2. Task Functions ---

class CosineTask:
    """Cosine target ``f(x) = (L / c) * cos(c * <x, u>)`` used in the paper's experiments.

    ``c`` is a frequency drawn uniformly from [1, 5] at construction time and
    ``u`` is a random unit vector drawn on the first call, once the feature
    dimension is known. Because ``u`` has unit norm, the ``L / c`` prefactor
    keeps the gradient norm of ``f`` at most ``L``.

    Args:
        L: Scale of the function (upper bound on its gradient norm).
    """

    def __init__(self, L):
        self.L = L
        # Choose a random frequency for the cosine function
        self.c = random.uniform(1.0, 5.0)
        # Unit direction; drawn lazily in __call__ because the feature
        # dimension is not known until the first input arrives.
        self.direction = None

    def __call__(self, x):
        """Evaluate the task on ``x``, contracting over its last axis.

        Args:
            x: Tensor whose last axis holds the feature vector.

        Returns:
            Tensor with the shape of ``x`` minus its last axis.
        """
        if self.direction is None:
            raw = torch.randn(x.shape[-1])
            self.direction = raw / torch.norm(raw)

        # Match the input's device and, for floating inputs, its dtype, so
        # that e.g. float64 features do not meet a float32 operand in einsum.
        # The stored direction itself is left untouched.
        direction = self.direction.to(x.device)
        if x.is_floating_point():
            direction = direction.to(x.dtype)

        dot_product = torch.einsum('...d,d->...', x, direction)
        return (self.L / self.c) * torch.cos(self.c * dot_product)

class PolynomialTask:
    """An alternative nonlinear function: a random 2nd-degree polynomial.

    Evaluates ``f(x) = x^T A x + b^T x`` with Gaussian coefficients drawn once
    at construction.

    Args:
        L: Scaling factor for the coefficients. It is not enforced as a strict
            bound on the function's steepness.
        dim: Feature dimension; ``A`` is (dim, dim) and ``b`` is (dim,).
    """

    def __init__(self, L, dim):
        self.L = L
        # Random coefficients, scaled down with the dimension to control magnitude.
        self.A = torch.randn(dim, dim) * L / (dim)
        self.b = torch.randn(dim) * L / np.sqrt(dim)

    def __call__(self, x):
        """Evaluate the polynomial on ``x``, contracting over its last axis.

        Args:
            x: Tensor of shape (..., dim).

        Returns:
            Tensor with the shape of ``x`` minus its last axis.
        """
        A = self.A.to(x.device)
        b = self.b.to(x.device)
        # Follow the input's floating dtype as well as its device so that
        # e.g. float64 features do not meet float32 coefficients in einsum.
        if x.is_floating_point():
            A = A.to(x.dtype)
            b = b.to(x.dtype)

        # Calculate x^T * A * x
        quad_term = torch.einsum('...i,ij,...j->...', x, A, x)
        # Calculate b^T * x
        linear_term = torch.einsum('...i,i->...', x, b)

        return quad_term + linear_term

# --- 3. Data Generation ---
def generate_prompts(config, task_function):
    """
    Build one batch of in-context-learning prompts.

    Every prompt holds N labelled tokens ``(x_i, y_i)`` followed by a query
    token ``(x_query, 0)``; the label is appended as one extra feature column.

    Args:
        config: Mapping providing 'M' (batch size), 'N' (context length),
            'd' (feature dimension), 'use_continuous_features', and 'K'
            (feature-set size, read only when features are discrete).
        task_function: Callable mapping a (..., d) tensor to a (...) tensor.

    Returns:
        Tuple ``(prompts, y_query_true)`` of shapes (M, N + 1, d + 1) and (M, 1).
    """
    batch, n_context, dim = config['M'], config['N'], config['d']
    seq_len = n_context + 1  # context tokens plus the query

    if config['use_continuous_features']:
        # Independent Gaussian feature vector for every token.
        features = torch.randn(batch, seq_len, dim)
    else:
        # Tokens are sampled (with replacement) from a set of K Gaussian
        # vectors; a fresh set is drawn on every call.
        vocab_size = config['K']
        feature_set = torch.randn(vocab_size, dim)
        picks = torch.randint(0, vocab_size, (batch, seq_len))
        features = feature_set[picks]

    # Labels for the context tokens, then the held-out target for the query.
    context_labels = task_function(features[:, :-1, :]).unsqueeze(-1)
    y_query_true = task_function(features[:, -1, :]).unsqueeze(-1)

    # The query's label slot is zero-filled so the answer is not leaked.
    query_placeholder = torch.zeros(batch, 1, 1)
    label_column = torch.cat([context_labels, query_placeholder], dim=1)

    prompts = torch.cat([features, label_column], dim=-1)
    return prompts, y_query_true

# --- 4. Main Experiment ---
def _train_single_run(config, L):
    """Train one freshly initialised model on one freshly drawn task.

    Returns the list of per-epoch training losses (length ``config['T']``).
    """
    # Token width is d features plus one label column.
    model = SimplifiedTransformer(embed_dim=config['d'] + 1)
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    loss_fn = nn.MSELoss()

    if config['function_type'] == 'cosine':
        task = CosineTask(L=L)
    else:
        task = PolynomialTask(L=L, dim=config['d'])

    loss_history = []
    for _ in range(config['T']):
        # A new batch of prompts is sampled every epoch; the task stays fixed.
        prompts, y_query_true = generate_prompts(config, task)
        optimizer.zero_grad()
        y_query_pred = model(prompts)
        loss = loss_fn(y_query_pred, y_query_true)
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())
    return loss_history


def run_experiment(config):
    """
    Runs the full training loop for a given configuration and plots the loss.

    For each regime ('L_flat', 'L_sharp') one figure is produced. For every L
    in the regime, ``config['num_runs']`` independent trials are trained and
    the mean loss curve is drawn with a +/- one standard deviation band.

    Args:
        config: Mapping with keys 'L_flat', 'L_sharp', 'num_runs', 'T',
            'd', 'learning_rate', 'function_type', plus whatever
            ``generate_prompts`` reads.
    """
    print("Starting experiment with the following configuration:")
    print(config)

    regimes = {
        "Flat L-Regime": config['L_flat'],
        "Sharp L-Regime": config['L_sharp']
    }
    num_runs = config['num_runs']

    for regime_name, L_values in regimes.items():
        _fig, ax = plt.subplots(1, 1, figsize=(10, 6))
        # fig.suptitle(
        #     f"Convergence Dynamics: {regime_name} (Avg over {config['num_runs']} runs)", 
        #     fontsize=16
        # )
        print(f"\n--- Running Experiment for {regime_name} ---")

        colors = plt.cm.viridis(np.linspace(0, 1, len(L_values)))

        # Curves are drawn in reverse order of L. The index always refers to
        # the position in L_values, so each curve's label and colour belong
        # to the L it was actually trained with.
        for l_idx in reversed(range(len(L_values))):
            L = L_values[l_idx]
            print(f"  Running {num_runs} trials for L = {L}...")

            all_loss_histories = []
            for run in range(num_runs):
                print(f"    Trial [{run + 1}/{num_runs}]...", end='\r')
                all_loss_histories.append(_train_single_run(config, L))

            print(f"    All trials for L = {L} complete. Plotting results.   ")

            # Rows are trials, columns are epochs.
            loss_array = np.array(all_loss_histories)
            mean_loss = np.mean(loss_array, axis=0)
            std_loss = np.std(loss_array, axis=0)

            plot_color = colors[l_idx]

            ax.plot(mean_loss, label=f'L = {L}', color=plot_color)
            # NOTE: mean - std can be non-positive, which the log-scaled
            # y-axis set below cannot represent as a finite position.
            ax.fill_between(
                range(config['T']),
                mean_loss - std_loss,
                mean_loss + std_loss,
                alpha=0.2,
                color=plot_color
            )

        # ax.set_title(f"{config['function_type'].capitalize()} Task | Continuous Features: {config['use_continuous_features']}")
        ax.set_xlabel('Epoch',fontsize=24)
        ax.set_ylabel('Prediction Loss (MSE)',fontsize=24)
        ax.set_yscale('log')

        # Curves were added in reverse order; flip the entries so the legend
        # lists L in the same order as L_values.
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles[::-1], labels[::-1], title_fontsize=26, fontsize=24)

        ax.grid(True, which="both", ls="--", alpha=0.5)
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()


if __name__ == '__main__':
    # Data parameters: feature dimension, discrete feature-set size,
    # context length, and batch size.
    data_settings = dict(d=30, K=4, N=150, M=500)

    # Training parameters: epochs per trial, Adam step size, trials per L.
    training_settings = dict(T=1000, learning_rate=3e-4, num_runs=30)

    # L values to sweep, split into the two regimes (one figure each).
    regime_settings = dict(
        L_flat=[0.1, 0.2, 0.4],
        L_sharp=[1.0, 1.5, 2.0],
    )

    # Feature sampling mode and target function family.
    task_settings = dict(use_continuous_features=True, function_type="cosine")

    experiment_config = {
        **data_settings,
        **training_settings,
        **regime_settings,
        **task_settings,
    }

    run_experiment(experiment_config)

