from dataclasses import dataclass, fields
from typing import Dict, Optional, Tuple

import torch

from .base import BaseTask

@dataclass
class SIRConfig:
    """Settings consumed by ``SIRSimulator`` (all fields have defaults)."""

    # Scale of the square-root noise term driving R0.
    sigma: float = 0.05
    # Mean-reversion rate pulling R0 back towards beta / gamma.
    eta: float = 0.05
    # Number of recorded days, i.e. the length of every trajectory.
    T: int = 365
    # Population size; simulated fractions are multiplied by N on output.
    N: int = 100000
    # Initial (susceptible, infected, recovered) population fractions.
    sir_0: Tuple[float, float, float] = (0.999, 0.001, 0.0)
    # Requested Euler-Maruyama step size, in days.
    dt: float = 0.1
    # Default number of stochastic trajectories per parameter set.
    n_trajectories: int = 100


class SIRSimulator:
    """Stochastic SIR epidemic simulator with a time-varying reproduction number.

    Population fractions ``s``, ``i``, ``r`` are integrated with an
    Euler-Maruyama scheme. The reproduction number follows the mean-reverting
    square-root diffusion

        dR0 = eta * (R0_bar - R0) dt + sigma * sqrt(|R0|) dW,  R0_bar = beta / gamma,

    and the effective transmission rate at every step is ``gamma * R0``.
    """

    # Labels of the six statistics returned by ``summarise``, in column order.
    x_names = ["Mean", "Median", "Max", "Max Day", "Half Day", "Autocor"]
    # Labels of the two columns returned by ``sample_prior``.
    theta_names = [r"$\beta$", r"$\gamma$"]

    def __init__(self, config: "SIRConfig", device: str = 'cpu'):
        """Store the configuration and derive the sub-daily integration grid.

        Args:
            config: Simulation settings (``sigma``, ``eta``, ``T``, ``N``,
                ``sir_0``, ``dt``, ``n_trajectories``).
            device: Torch device on which all tensors are created.

        Raises:
            ValueError: If ``config.dt`` is not in ``(0, 1]``.
        """
        if not 0.0 < config.dt <= 1.0:
            raise ValueError(f"config.dt must be in (0, 1], got {config.dt!r}")
        self.config = config
        self.device = device
        # Round instead of truncating, then integrate with the step size that
        # tiles one day exactly: with int(), dt=0.3 gave 3 steps covering only
        # 0.9 days per recorded "day". For the default dt=0.1 this is unchanged
        # (10 steps of 0.1).
        self.steps_per_day = round(1.0 / config.dt)
        self.dt = 1.0 / self.steps_per_day

    def sample_prior(self, n: int, seed: Optional[int] = None) -> torch.Tensor:
        """Draw ``n`` (beta, gamma) pairs uniformly from 0 <= gamma <= beta < 0.5.

        Sampling is exact and needs no rejection step: ``beta = sqrt(u1) / 2``
        has density ``8 * beta`` on ``[0, 0.5)`` and ``gamma | beta`` is uniform
        on ``[0, beta)``, so the joint density is the constant 8 on the triangle.

        Args:
            n: Number of samples.
            seed: Optional seed for a dedicated generator (reproducible draws).

        Returns:
            Tensor of shape ``(n, 2)`` with columns ``(beta, gamma)``.
        """
        gen = torch.Generator(device=self.device).manual_seed(seed) if seed is not None else None

        u1 = torch.rand(n, device=self.device, generator=gen)
        u2 = torch.rand(n, device=self.device, generator=gen)

        beta = torch.sqrt(u1) / 2  # inverse-CDF sample of density 8 * beta
        gamma = u2 * beta          # uniform on [0, beta)

        return torch.stack([beta, gamma], dim=1)

    def simulate(
        self,
        theta: torch.Tensor,
        n_trajectories: Optional[int] = None,
        seed: Optional[int] = None
    ) -> torch.Tensor:
        """Simulate several stochastic epidemics for every parameter set.

        Args:
            theta: Parameters of shape ``(batch, 2)`` with columns ``(beta, gamma)``.
            n_trajectories: Trajectories per parameter set; defaults to
                ``config.n_trajectories``.
            seed: Optional seed for the Brownian increments.

        Returns:
            Tensor of shape ``(batch, n_trajectories, T)`` holding ``N * i``
            recorded at the start of each day (index 0 is the initial state).
        """
        gen = torch.Generator(device=self.device).manual_seed(seed) if seed is not None else None

        cfg = self.config
        n_traj = n_trajectories if n_trajectories is not None else cfg.n_trajectories
        batch_size = theta.shape[0]
        total_sims = batch_size * n_traj

        theta = theta.to(self.device)
        # Row b * n_traj + j is trajectory j of parameter set b; the final
        # view() below relies on this ordering.
        theta_expanded = theta.repeat_interleave(n_traj, dim=0)
        beta = theta_expanded[:, 0]
        gamma = theta_expanded[:, 1]

        # NOTE(review): gamma == 0 makes R0_bar infinite (nan if beta == 0 too).
        R0_bar = beta / gamma

        s = torch.full((total_sims,), cfg.sir_0[0], device=self.device)
        i = torch.full((total_sims,), cfg.sir_0[1], device=self.device)
        # r never feeds back into s or i; it is tracked only for completeness.
        r = torch.full((total_sims,), cfg.sir_0[2], device=self.device)
        R0 = R0_bar.clone()

        infections = torch.zeros((total_sims, cfg.T), device=self.device)

        dt = self.dt
        sqrt_dt = dt ** 0.5

        for day in range(cfg.T):
            infections[:, day] = i  # slice assignment copies the values

            for _ in range(self.steps_per_day):
                beta_eff = gamma * R0

                # All increments use the state at the start of the sub-step.
                ds = -beta_eff * s * i * dt
                di = (beta_eff * s * i - gamma * i) * dt
                dr = gamma * i * dt
                dR0_det = cfg.eta * (R0_bar - R0) * dt

                dW = torch.randn(total_sims, device=self.device, generator=gen) * sqrt_dt
                dR0_stoch = cfg.sigma * torch.sqrt(torch.abs(R0)) * dW

                # Clamp to keep fractions in [0, 1] and R0 in a sane range.
                s = torch.clamp(s + ds, 0, 1)
                i = torch.clamp(i + di, 0, 1)
                r = torch.clamp(r + dr, 0, 1)
                R0 = torch.clamp(R0 + dR0_det + dR0_stoch, 0.01, 100)

        return (infections * cfg.N).view(batch_size, n_traj, cfg.T)

    def misspecify(self, x: torch.Tensor, multiplier: float = 0.95) -> torch.Tensor:
        """Mimic weekend under-reporting followed by a Monday catch-up.

        Within each week, the values at days 1 and 2 (mod 7), treated as
        Saturday and Sunday, are scaled by ``multiplier``. The removed amount is
        added to day 3 (mod 7), treated as Monday. Only weeks whose Monday lies
        inside the series are altered, so the total along the last axis is
        unchanged (up to floating-point rounding).

        Args:
            x: Tensor whose last axis is time in days; leading axes are arbitrary.
            multiplier: Fraction of each weekend value that is reported on the day.

        Returns:
            A modified copy of ``x``; the input is not mutated.
        """
        x = x.clone()
        # Use the actual series length so shorter/longer series are handled
        # correctly; for simulate() output this equals config.T.
        n_days = x.shape[-1]

        sat_idx = list(range(1, n_days, 7))
        sun_idx = list(range(2, n_days, 7))
        mon_idx = list(range(3, n_days, 7))

        # Drop trailing weekends whose Monday falls past the end of the series.
        n_weekends = min(len(sat_idx), len(sun_idx), len(mon_idx))
        sat_idx = sat_idx[:n_weekends]
        sun_idx = sun_idx[:n_weekends]
        mon_idx = mon_idx[:n_weekends]

        sat_orig = x[..., sat_idx]
        sun_orig = x[..., sun_idx]

        sat_new = sat_orig * multiplier
        sun_new = sun_orig * multiplier

        missed_cases = (sat_orig - sat_new) + (sun_orig - sun_new)

        x[..., sat_idx] = sat_new
        x[..., sun_idx] = sun_new
        x[..., mon_idx] = x[..., mon_idx] + missed_cases

        return x

    def summarise(self, x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
        """Compute six summary statistics along the last (time) axis of ``x``.

        Columns, in ``x_names`` order:
            0. log of the mean,
            1. log of the median (``torch.median`` takes the lower middle value
               for even lengths),
            2. log of the maximum,
            3. log(1 + index of the maximum),
            4. log(1 + first index where the cumulative share exceeds 0.5),
            5. lag-1 autocorrelation of the raw (unclamped) series.

        Args:
            x: Tensor whose last axis is time; leading axes are preserved.
            eps: Floor applied before logs and cumulative shares, and added to
                denominators to avoid division by zero.

        Returns:
            Tensor of shape ``x.shape[:-1] + (6,)``.
        """
        x_safe = torch.clamp(x, min=eps)

        mean_inf = torch.log(x_safe.mean(dim=-1))
        median_inf = torch.log(x_safe.median(dim=-1).values)
        max_inf = torch.log(x_safe.max(dim=-1).values)
        max_day = torch.log(x_safe.argmax(dim=-1).float() + 1)

        cumsum = torch.cumsum(x_safe, dim=-1)
        prop = cumsum / (cumsum[..., -1:] + eps)
        # argmax over a 0/1 mask gives the first index where it is 1 (0 if none).
        half_day = torch.log((prop > 0.5).float().argmax(dim=-1).float() + 1)

        # Lag-1 Pearson correlation between x[t] and x[t + 1].
        x1 = x[..., :-1]
        x2 = x[..., 1:]
        x1_centered = x1 - x1.mean(dim=-1, keepdim=True)
        x2_centered = x2 - x2.mean(dim=-1, keepdim=True)
        numerator = (x1_centered * x2_centered).sum(dim=-1)
        denominator = torch.sqrt((x1_centered**2).sum(dim=-1) * (x2_centered**2).sum(dim=-1))
        autocorr = numerator / (denominator + eps)

        return torch.stack([mean_inf, median_inf, max_inf, max_day, half_day, autocorr], dim=-1)

class SIRTask(BaseTask):
    """``BaseTask`` adapter exposing an ``SIRSimulator``."""

    def __init__(self, cfg):
        """Build the SIR configuration and simulator from ``cfg``.

        Every ``SIRConfig`` field that ``cfg`` exposes as an attribute overrides
        the dataclass default; absent fields keep ``SIRConfig``'s own defaults,
        so the defaults are defined in exactly one place.
        """
        super().__init__(cfg)
        overrides = {
            f.name: getattr(cfg, f.name)
            for f in fields(SIRConfig)
            if hasattr(cfg, f.name)
        }
        self.sir_config = SIRConfig(**overrides)
        # NOTE(review): self.device is not assigned in this class; presumably
        # BaseTask.__init__ provides it -- confirm.
        self.simulator = SIRSimulator(self.sir_config, device=self.device)

    def sample_prior(self, n_samples: int, seed: int) -> torch.Tensor:
        """Return ``n_samples`` (beta, gamma) prior draws of shape ``(n_samples, 2)``."""
        return self.simulator.sample_prior(n_samples, seed)

    def simulate(self, theta: torch.Tensor, seed: int) -> torch.Tensor:
        """Simulate ``config.n_trajectories`` trajectories per row of ``theta``."""
        return self.simulator.simulate(theta, seed=seed)

    def compute_summary_statistics(self, x: torch.Tensor) -> torch.Tensor:
        """Average per-trajectory summary statistics over the trajectory axis.

        Args:
            x: Simulator output; assumed ``(batch, n_trajectories, T)`` as
                returned by ``simulate`` -- TODO confirm for other callers.

        Returns:
            Tensor of shape ``(batch, 6)`` for that input layout.
        """
        S_per_traj = self.simulator.summarise(x)
        return S_per_traj.mean(dim=1)

    def generate_test_data(self, n_samples: int, seed: int, misspec_cfg: Optional[Dict] = None) -> Dict:
        """Simulate a test set, optionally contaminating some trajectories.

        Args:
            n_samples: Number of parameter draws.
            seed: Base seed. The prior uses ``seed``, the simulator ``seed + 1``
                and the contamination mask ``seed + 10000``.
            misspec_cfg: Optional settings. ``type`` selects the misspecification;
                only ``"contamination"`` is acted on here. For contamination,
                ``contamination_eps`` is the fraction of each sample's
                trajectories to corrupt and ``multiplier`` is forwarded to
                ``SIRSimulator.misspecify``.

        Returns:
            Dict of CPU tensors: ``thetas``, ``X_clean``, ``X_obs``, ``S_obs`` and
            the boolean ``misspec_mask`` of shape ``(n_samples, n_trajectories)``.

        Raises:
            ValueError: If ``contamination_eps`` is outside ``[0, 1]``.
        """
        misspec_cfg = misspec_cfg or {}
        m_type = misspec_cfg.get("type", "none")
        n_traj = self.sir_config.n_trajectories

        theta = self.sample_prior(n_samples, seed)
        X_clean = self.simulate(theta, seed + 1)

        X_obs = X_clean.clone()
        misspec_mask = torch.zeros(n_samples, n_traj, dtype=torch.bool, device=self.device)

        # NOTE(review): any other ``type`` value (including typos) silently
        # yields clean data -- confirm that is intended.
        if m_type == "contamination":
            eps = misspec_cfg.get("contamination_eps", 0.0)
            multiplier = misspec_cfg.get("multiplier", 0.95)
            if not 0.0 <= eps <= 1.0:
                raise ValueError(f"contamination_eps must be in [0, 1], got {eps!r}")

            # round(), not int(): 0.29 * 100 == 28.999999999999996 would
            # otherwise contaminate 28 trajectories instead of 29.
            k = round(eps * n_traj)

            # A fresh random subset of k trajectories per sample, drawn from a
            # dedicated generator so the mask is reproducible from ``seed``.
            mask_gen = torch.Generator(device=self.device).manual_seed(seed + 10000)
            for row in range(n_samples):
                perm = torch.randperm(n_traj, generator=mask_gen, device=self.device)
                misspec_mask[row, perm[:k]] = True

            if k > 0:
                # misspecify() acts only along the time axis, so all selected
                # (sample, trajectory) series can be corrupted in one call.
                X_obs[misspec_mask] = self.simulator.misspecify(
                    X_clean[misspec_mask],
                    multiplier=multiplier,
                )

        S_obs = self.compute_summary_statistics(X_obs)

        return {
            "thetas": theta.cpu(),
            "X_clean": X_clean.cpu(),
            "X_obs": X_obs.cpu(),
            "S_obs": S_obs.cpu(),
            "misspec_mask": misspec_mask.cpu()
        }