import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional
from torch.distributions import Uniform, Independent
from .base import BaseTask

@dataclass
class OUPConfig:
    """Settings for ``OUPTask``: simulation size, dynamics and prior bounds.

    Field order matters because the generated ``__init__`` also accepts the
    fields positionally.
    """

    # --- simulation size ---
    n_timesteps: int = 25  # points per trajectory, including the initial y0
    n_trajectories: int = 100  # independent paths simulated per parameter vector

    # --- dynamics ---
    var: float = 0.1  # factor applied to the standard-normal noise in OUPTask.simulate
    T: float = 5.0  # time horizon; OUPTask derives dt = T / (n_timesteps + 1)
    y0: float = 10.0  # starting value shared by every trajectory

    # --- uniform prior bounds ---
    theta1_min: float = 0.0
    theta1_max: float = 2.0
    theta2_min: float = -2.0  # theta2 enters the dynamics as exp(theta2)
    theta2_max: float = 2.0

class OUPTask(BaseTask):
    """Ornstein-Uhlenbeck process (OUP) simulation task.

    A parameter vector ``theta = (theta1, theta2)`` drives an Euler
    discretisation in which each path is pulled towards the level
    ``exp(theta2)`` at rate ``theta1``. Every parameter vector yields
    ``n_trajectories`` independent paths of ``n_timesteps`` points that all
    start at ``y0``; they are summarised by ``_oup_moment_summary``.
    """

    def __init__(self, cfg):
        """Read the OUP settings from ``cfg`` and derive the Euler step size.

        Args:
            cfg: Config forwarded to ``BaseTask.__init__`` and read by
                attribute access for every ``OUPConfig`` field.
        """
        super().__init__(cfg)
        self.oup_cfg = OUPConfig(
            n_timesteps=cfg.n_timesteps,
            n_trajectories=cfg.n_trajectories,
            var=cfg.var,
            T=cfg.T,
            y0=cfg.y0,
            theta1_min=cfg.theta1_min,
            theta1_max=cfg.theta1_max,
            theta2_min=cfg.theta2_min,
            theta2_max=cfg.theta2_max,
        )
        # Step size of the Euler update used by ``simulate``.
        self.dt = self.oup_cfg.T / (self.oup_cfg.n_timesteps + 1)

    @property
    def prior(self):
        """Independent uniform prior over ``(theta1, theta2)`` using the config bounds.

        NOTE(review): the bound tensors are created on torch's default device,
        whereas ``sample_prior`` returns tensors on ``self.device``; confirm
        callers never evaluate this prior on device-resident samples.
        """
        c = self.oup_cfg
        low = torch.tensor([c.theta1_min, c.theta2_min])
        high = torch.tensor([c.theta1_max, c.theta2_max])
        return Independent(Uniform(low, high), 1)

    def sample_prior(self, n_samples: int, seed: int) -> torch.Tensor:
        """Draw ``n_samples`` parameter vectors uniformly from the prior box.

        A dedicated CPU generator seeded with ``seed`` is used, so draws are
        reproducible and do not touch the global RNG.

        Returns:
            Tensor of shape ``(n_samples, 2)`` on ``self.device``; column 0 is
            ``theta1`` and column 1 is ``theta2``.
        """
        gen = torch.Generator().manual_seed(seed)
        c = self.oup_cfg
        # theta1 is drawn before theta2 from the same generator; keep this order
        # so a given seed keeps producing the same parameters.
        theta1 = torch.rand(n_samples, generator=gen) * (c.theta1_max - c.theta1_min) + c.theta1_min
        theta2 = torch.rand(n_samples, generator=gen) * (c.theta2_max - c.theta2_min) + c.theta2_min
        return torch.stack([theta1, theta2], dim=1).to(self.device)

    def simulate(self, theta: torch.Tensor, seed: int, var: Optional[float] = None) -> torch.Tensor:
        """Simulate OUP trajectories for a batch of parameter vectors.

        Starting from ``y0``, the update
            y <- y + theta1 * (exp(theta2) - y) * dt + 0.5 * sqrt(dt) * var * z
        with ``z`` standard normal is applied ``n_timesteps - 1`` times.

        Args:
            theta: Parameters of shape ``(M, 2)``; column 0 is ``theta1`` and
                column 1 is ``theta2`` (log of the long-run level).
            seed: Seed of the generator that draws the noise.
            var: Noise factor overriding ``oup_cfg.var``. It multiplies the
                standard-normal draw directly, i.e. it scales the noise like a
                standard deviation, not a variance.

        Returns:
            Tensor of shape ``(M, n_trajectories, n_timesteps)`` on ``self.device``.
        """
        device = self.device
        theta = theta.to(device)
        M = theta.shape[0]  # batch size
        N = self.oup_cfg.n_trajectories
        n_steps = self.oup_cfg.n_timesteps

        sim_var = var if var is not None else self.oup_cfg.var

        gen = torch.Generator(device=device).manual_seed(seed)

        # (M, 1) columns so they broadcast against the (M, N) state.
        theta1 = theta[:, 0].unsqueeze(1)
        theta2 = torch.exp(theta[:, 1]).unsqueeze(1)

        Y = torch.zeros(M, N, n_steps, device=device)
        Y[:, :, 0] = self.oup_cfg.y0
        current_Y = Y[:, :, 0]

        sqrt_dt = self.dt ** 0.5

        for t in range(n_steps - 1):
            w = torch.randn(M, N, device=device, generator=gen) * sim_var

            mu = theta1 * (theta2 - current_Y) * self.dt  # mean-reverting drift
            sigma = 0.5 * sqrt_dt * w  # diffusion term

            current_Y = current_Y + mu + sigma
            Y[:, :, t + 1] = current_Y

        return Y

    def compute_summary_statistics(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``(M, 3)`` moment summary of ``x`` (see ``_oup_moment_summary``)."""
        return self._oup_moment_summary(x)

    def generate_test_data(self, n_samples: int, seed: int, misspec_cfg: Optional[dict] = None) -> dict:
        """Sample parameters, simulate observations and optionally contaminate them.

        Args:
            n_samples: Number of parameter vectors (test cases) to draw.
            seed: Base seed. The prior uses ``seed``, the clean simulation
                ``seed + 1``, the contaminating simulation ``seed + 2`` and the
                contamination mask ``seed + 999``.
            misspec_cfg: Optional misspecification settings. Only
                ``{"type": "contamination", ...}`` is acted on; any other value
                leaves the data clean. Recognised keys: ``contamination_eps``
                (falling back to ``eps``, default 0), the fraction of
                trajectories per sample to replace; ``var_contam``, the noise
                factor of the replacement trajectories (default
                ``oup_cfg.var``); ``theta_contam``, one fixed parameter vector
                for all replacement trajectories (default: each sample's own
                ``theta``).

        Returns:
            Dict of CPU tensors: ``thetas`` ``(n_samples, 2)``; ``X_clean`` and
            ``X_obs`` ``(n_samples, n_trajectories, n_timesteps)``; ``S_obs``,
            the summary statistics of ``X_obs``; and ``misspec_mask``
            ``(n_samples, n_trajectories)``, True where a trajectory of
            ``X_obs`` was replaced.

        Raises:
            ValueError: If the contamination fraction lies outside ``[0, 1]``.
        """
        theta = self.sample_prior(n_samples, seed)
        X_clean = self.simulate(theta, seed + 1)

        X_obs = X_clean.clone()
        n_traj = self.oup_cfg.n_trajectories
        misspec_mask = torch.zeros(n_samples, n_traj, dtype=torch.bool, device=self.device)

        if misspec_cfg is not None and misspec_cfg.get("type") == "contamination":
            eps = misspec_cfg.get("contamination_eps", misspec_cfg.get("eps", 0.0))
            if not 0.0 <= eps <= 1.0:
                # A negative eps would give a negative k, and perm[:k] would then
                # flag almost every trajectory even though none get replaced.
                raise ValueError(f"contamination fraction must be in [0, 1], got {eps!r}")
            var_contam = misspec_cfg.get("var_contam", self.oup_cfg.var)

            # Round off float representation error before truncating:
            # 0.29 * 100 == 28.999999999999996 would otherwise give k == 28.
            k = int(round(eps * n_traj, 6))

            if k > 0:
                misspec_mask = self._contamination_mask(n_samples, k, seed + 999)

                if "theta_contam" in misspec_cfg:
                    t_contam = torch.as_tensor(
                        misspec_cfg["theta_contam"], dtype=theta.dtype, device=self.device
                    )
                    theta_for_sim = t_contam.unsqueeze(0).expand(n_samples, -1)
                else:
                    theta_for_sim = theta

                X_contam_full = self.simulate(theta_for_sim, seed=seed + 2, var=var_contam)
                X_obs[misspec_mask] = X_contam_full[misspec_mask]

        S_obs = self.compute_summary_statistics(X_obs)

        return {
            "thetas": theta.cpu(),
            "X_clean": X_clean.cpu(),
            "X_obs": X_obs.cpu(),
            "S_obs": S_obs.cpu(),
            "misspec_mask": misspec_mask.cpu()
        }

    def _contamination_mask(self, n_samples: int, k: int, seed: int) -> torch.Tensor:
        """Return an ``(n_samples, n_trajectories)`` bool mask with ``k`` random True entries per row."""
        n_traj = self.oup_cfg.n_trajectories
        mask = torch.zeros(n_samples, n_traj, dtype=torch.bool, device=self.device)
        gen = torch.Generator(device=self.device).manual_seed(seed)
        for i in range(n_samples):
            perm = torch.randperm(n_traj, generator=gen, device=self.device)
            mask[i, perm[:k]] = True
        return mask

    def _oup_moment_summary(self, Y: torch.Tensor, K: Optional[int] = None, eps: float = 1e-8) -> torch.Tensor:
        """Summarise each batch entry by mean, variance and lag-1 autocorrelation.

        Statistics are pooled over all trajectories within the trailing window
        of the last ``K`` time points.

        Args:
            Y: Trajectories of shape ``(M, N, n)``; a 2-D ``(N, n)`` input is
                treated as a single batch entry.
            K: Window length, default ``max(5, n // 3)``. A window longer than
                ``n`` covers the whole series.
            eps: Added to the autocorrelation denominator so constant windows
                do not divide by zero.

        Returns:
            Tensor of shape ``(M, 3)`` in ``Y``'s dtype with columns
            ``(mean, population variance, lag-1 autocorrelation)``. A window of
            a single time point has no lag-1 pairs and yields NaN correlation.
        """
        if Y.dim() == 2:
            Y = Y.unsqueeze(0)

        n = Y.shape[-1]
        if K is None:
            K = max(5, n // 3)

        W = Y[:, :, -K:]  # trailing window, (M, N, min(K, n))

        mu = W.mean(dim=(1, 2))
        var = ((W - mu[:, None, None]) ** 2).mean(dim=(1, 2))  # biased (population) variance

        # Lag-1 pairs (W[..., t], W[..., t + 1]) pooled across trajectories.
        # flatten (not reshape(M, -1)) so an empty batch (M == 0) still works.
        x = W[:, :, :-1].flatten(start_dim=1)
        y = W[:, :, 1:].flatten(start_dim=1)
        xc = x - x.mean(dim=1, keepdim=True)
        yc = y - y.mean(dim=1, keepdim=True)
        cov = (xc * yc).mean(dim=1)
        vx = (xc ** 2).mean(dim=1)
        vy = (yc ** 2).mean(dim=1)
        rho1 = cov / (torch.sqrt(vx * vy) + eps)

        return torch.stack([mu, var, rho1], dim=1)