#!/usr/bin/env python3
"""
Models for basic RandOpt for toy experiments
"""

import torch
import torch.nn as nn


def positional_encoding(x, pos_encoding_dim=16, device=None):
    """Sinusoidally encode every element of ``x`` and flatten the result per row.

    Each scalar value ``v`` becomes ``pos_encoding_dim`` features. Feature ``k``
    is ``sin(v / 10000 ** (2 * (k // 2) / pos_encoding_dim))`` for even ``k``
    and the matching ``cos`` for odd ``k``.

    Args:
        x: tensor of shape [batch_size, seq_len] or [seq_len].
        pos_encoding_dim: number of features produced per element of ``x``.
        device: device on which the frequency table is built; defaults to
            ``x.device``.

    Returns:
        Tensor of shape [batch_size, seq_len * pos_encoding_dim], or
        [seq_len * pos_encoding_dim] when ``x`` is 1-D.
    """
    target_device = x.device if device is None else device

    # Treat an unbatched sequence as a batch of one; undone before returning.
    unbatched = x.dim() == 1
    rows = x.unsqueeze(0) if unbatched else x
    n_rows, n_cols = rows.shape

    # Feature index k -> exponent 2*floor(k/2)/dim, so sin/cos pairs share a frequency.
    k = torch.arange(pos_encoding_dim, device=target_device, dtype=torch.float32)
    exponent = torch.div(k, 2, rounding_mode="floor") * 2 / pos_encoding_dim
    freqs = 10000 ** exponent

    # Broadcast [rows, cols, 1] against [1, 1, dim] -> [rows, cols, dim].
    angles = rows[..., None] / freqs[None, None, :]

    is_even = k % 2 == 0
    encoded = torch.where(is_even, angles.sin(), angles.cos())
    flat = encoded.reshape(n_rows, n_cols * pos_encoding_dim)

    if unbatched:
        return flat[0]
    return flat


class Net(nn.Module):
    """MLP that predicts the next value of a sequence from a fixed-size context.

    Layers: ``Linear(ctx_sz + pos_encoding_dim, width)``, then
    ``max(depth - 2, 0)`` repetitions of ``ReLU -> Linear(width, width)``, then
    ``ReLU -> Linear(width, dim_out)``. As a consequence ``depth=1`` and
    ``depth=2`` build the same two-linear-layer network.

    NOTE(review): the first layer expects ``ctx_sz + pos_encoding_dim`` input
    features, but ``forward`` feeds ``ctx`` straight into it and never applies a
    positional encoding. With the documented ``[batch_size, ctx_sz]`` input,
    ``forward``/``AR_rollout`` therefore only work when ``pos_encoding_dim == 0``
    -- confirm the intended design with the callers.
    """

    def __init__(self, width, depth, dim_in, dim_out, device, pos_encoding_dim):
        super().__init__()
        self.width, self.depth, self.dim_out = width, depth, dim_out
        self.device, self.pos_encoding_dim = device, pos_encoding_dim
        # dim_in counts the context values plus one extra input slot; only the
        # context width is kept here.
        self.ctx_sz = dim_in - 1

        # Layer creation order is kept fixed so that parameter initialisation
        # under a given torch seed stays reproducible.
        adjusted_dim_in = self.ctx_sz + pos_encoding_dim
        layers = [nn.Linear(adjusted_dim_in, width, device=device)]
        for _ in range(depth - 2):
            layers.extend([nn.ReLU(), nn.Linear(width, width, device=device)])
        layers.extend([nn.ReLU(), nn.Linear(width, dim_out, device=device)])
        self.layers = nn.ModuleList(layers)

    def forward(self, ctx):
        """Predict y_next given ctx.

        Args:
            ctx: tensor of shape [batch_size, ctx_sz] or [ctx_sz] (raw values).

        Returns:
            Tensor of shape [batch_size, dim_out]. A trailing size-1 dimension is
            squeezed away (so [batch_size] when dim_out == 1), and the batch
            dimension is removed again for 1-D input.
        """
        # Promote a single unbatched context to a batch of one.
        was_1d = ctx.dim() == 1
        if was_1d:
            ctx = ctx.unsqueeze(0)

        h = ctx
        for layer in self.layers:
            h = layer(h)

        if was_1d:
            h = h.squeeze(0)

        # Drop the singleton output dimension for scalar prediction.
        return h.squeeze(-1)

    def compute_loss(self, ctx, y):
        """Return the mean-squared error between predictions and targets.

        Args:
            ctx: tensor of shape [batch_size, ctx_sz]
            y: tensor of shape [batch_size, 1] or [batch_size]

        Returns:
            Scalar loss tensor.
        """
        y_pred = self.forward(ctx)  # [batch_size] when dim_out == 1
        return nn.MSELoss()(y_pred, y.squeeze(-1))

    def init_weights(self):
        """Xavier-uniform initialise every Linear weight and zero its bias."""
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def perturb_weights(self, seed, sigma):
        """Add i.i.d. Gaussian noise N(0, sigma^2) to every parameter in place.

        The same ``seed`` reproduces the same noise. Note that this reseeds
        torch's *global* RNG via ``torch.manual_seed`` as a side effect.
        """
        torch.manual_seed(seed)
        for p in self.parameters():
            # .data bypasses autograd tracking for the in-place update.
            p.data.add_(torch.randn_like(p.data) * sigma)

    def AR_rollout(self, ctx, T):
        """Roll the model out autoregressively for T steps.

        Each prediction is appended to the context window and the oldest value
        is dropped, so the window width stays constant. This assumes
        ``dim_out == 1``, so that a prediction is one column.

        Args:
            ctx: tensor of shape [batch_size, ctx_sz]
            T: number of steps to roll out

        Returns:
            y_preds: tensor of shape [batch_size, T]; an empty [batch_size, 0]
            tensor when T <= 0.
        """
        if T <= 0:
            # torch.stack rejects an empty list, so build the empty result directly.
            return ctx.new_empty((ctx.shape[0], 0))

        y_preds = []
        for _ in range(T):
            y_pred = self.forward(ctx)  # [batch_size]
            # Slide the window: append the newest prediction, drop the oldest value.
            ctx = torch.cat([ctx, y_pred.unsqueeze(-1)], dim=1)
            ctx = ctx[:, 1:]
            y_preds.append(y_pred)

        return torch.stack(y_preds, dim=1)  # [batch_size, T]