import torch
from dataclasses import dataclass, field
from typing import Optional, Union, Dict
from torch.distributions import Normal, Independent
from .base import BaseTask

@dataclass
class GaussianConfig:
    """Hyper-parameters of the Gaussian task.

    The prior is N(prior_loc, prior_scale^2 I) over a ``dim``-dimensional
    parameter, and every one of the ``n_obs`` observations is drawn from
    N(theta, sigma^2 I).
    """

    # Dimensionality of the parameter vector theta.
    dim: int = 10
    # Number of observations simulated per parameter draw.
    n_obs: int = 10
    # Mean of the (isotropic) Gaussian prior.
    prior_loc: float = 0.0
    # Standard deviation of the (isotropic) Gaussian prior.
    prior_scale: float = 1.0
    # Standard deviation of the observation noise.
    sigma: float = 0.1

class GaussianTask(BaseTask):
    """Conjugate Gaussian mean-inference task with an analytic posterior.

    Model (all covariances isotropic):
        Prior:      theta ~ N(prior_loc, prior_scale^2 I), theta in R^dim
        Likelihood: X_i | theta ~ N(theta, sigma^2 I),     i = 1..n_obs

    Every random draw uses a fresh ``torch.Generator`` seeded by the caller
    and created on ``self.device``.
    NOTE(review): ``self.device`` is not assigned in this class; it is
    presumably provided by ``BaseTask`` -- confirm.
    """

    def __init__(self, cfg):
        """Build ``self.gauss_cfg`` from ``cfg``.

        Each hyper-parameter is read with ``getattr`` so that attributes
        missing from ``cfg`` fall back to the ``GaussianConfig`` defaults.
        """
        super().__init__(cfg)
        self.gauss_cfg = GaussianConfig(
            dim=getattr(cfg, "dim", 10),
            n_obs=getattr(cfg, "n_obs", 10),
            prior_loc=getattr(cfg, "prior_loc", 0.0),
            prior_scale=getattr(cfg, "prior_scale", 1.0),
            sigma=getattr(cfg, "sigma", 0.1),
        )

    @property
    def prior(self):
        """Return the prior as an ``Independent(Normal, 1)`` over ``dim`` coordinates.

        The parameter tensors are created without a ``device`` argument, i.e.
        on torch's default device rather than on ``self.device``.
        """
        cfg = self.gauss_cfg
        # float() matters: torch.full infers an integer dtype from an integer
        # fill value (e.g. ``prior_loc: 0`` in a config file), and Normal
        # cannot sample from integer parameters.
        loc = torch.full((cfg.dim,), float(cfg.prior_loc))
        scale = torch.full((cfg.dim,), float(cfg.prior_scale))
        return Independent(Normal(loc, scale), 1)

    def sample_prior(self, n_samples: int, seed: int,
                     loc: Optional[float] = None,
                     scale: Optional[float] = None) -> torch.Tensor:
        """Draw ``n_samples`` parameters from an isotropic Gaussian.

        Args:
            n_samples: Number of parameter vectors to draw.
            seed: Seed for the generator used by this call.
            loc: Mean override; ``None`` uses the configured ``prior_loc``.
            scale: Std override; ``None`` uses the configured ``prior_scale``.

        Returns:
            Tensor of shape ``[n_samples, dim]`` on ``self.device``.
        """
        device = self.device
        gen = torch.Generator(device=device).manual_seed(seed)

        _loc = self.gauss_cfg.prior_loc if loc is None else loc
        _scale = self.gauss_cfg.prior_scale if scale is None else scale

        noise = torch.randn(
            n_samples, self.gauss_cfg.dim, generator=gen, device=device
        )
        return _loc + _scale * noise

    def simulate(self, theta: torch.Tensor, seed: int,
                 sigma: Optional[float] = None) -> torch.Tensor:
        """Simulate ``n_obs`` noisy observations for every parameter vector.

        Args:
            theta: Parameters of shape ``[M, D]``.
            seed: Seed for the generator used by this call.
            sigma: Noise std override; ``None`` uses the configured ``sigma``.

        Returns:
            Tensor of shape ``[M, n_obs, D]`` on ``self.device``.
        """
        device = self.device
        # The noise lives on self.device, so bring theta there as well (the
        # posterior methods do the same with their inputs).
        theta = theta.to(device)
        M, D = theta.shape
        n_obs = self.gauss_cfg.n_obs

        _sigma = self.gauss_cfg.sigma if sigma is None else sigma

        gen = torch.Generator(device=device).manual_seed(seed)
        eps = torch.randn(M, n_obs, D, device=device, generator=gen)

        # [M, 1, D] broadcasts against [M, n_obs, D].
        return theta.unsqueeze(1) + _sigma * eps

    def compute_summary_statistics(self, x: torch.Tensor) -> torch.Tensor:
        """Return the observations flattened to ``[x.shape[0], -1]``."""
        return x.reshape(x.shape[0], -1)

    def get_posterior_params(self, X_obs: torch.Tensor, sigma: Optional[float] = None) -> Dict[str, torch.Tensor]:
        """Compute the analytic posterior mean and std.

        For the model:
            Prior:      theta ~ N(mu_0, sigma_0^2 I)
            Likelihood: X_i | theta ~ N(theta, sigma^2 I)

        The posterior is:
            theta | X ~ N(mu_post, sigma_post^2 I)

        where:
            sigma_post^2 = (sigma_0^2 * sigma^2) / (sigma^2 + n * sigma_0^2)
            mu_post = (sigma^2 * mu_0 + n * sigma_0^2 * X_bar) / (sigma^2 + n * sigma_0^2)

        Args:
            X_obs: Observations of shape ``[M, n_obs, D]``, or ``[n_obs, D]``
                for a single instance.  ``n`` is taken from this tensor, not
                from the configuration.
            sigma: Likelihood std override; ``None`` uses the configured
                ``sigma``.

        Returns:
            ``{"mean": ..., "std": ...}`` where ``mean`` has shape ``[M, D]``
            (``[D]`` for a single instance) and ``std`` is a 0-dim tensor,
            both on ``self.device``.
        """
        single_instance = X_obs.dim() == 2
        if single_instance:
            X_obs = X_obs.unsqueeze(0)

        X_obs = X_obs.to(self.device)
        # Unpacking also rejects inputs that are not 3-D at this point.
        _, n, _ = X_obs.shape

        mu_0 = self.gauss_cfg.prior_loc
        sigma_0_sq = self.gauss_cfg.prior_scale ** 2
        _sigma = self.gauss_cfg.sigma if sigma is None else sigma
        sigma_sq = _sigma ** 2

        # Shared denominator of the posterior variance and mean.
        denom = sigma_sq + n * sigma_0_sq

        sigma_post_sq = (sigma_0_sq * sigma_sq) / denom
        sigma_post = torch.sqrt(torch.tensor(sigma_post_sq, device=self.device))

        X_bar = X_obs.mean(dim=1)
        mu_post = (sigma_sq * mu_0 + n * sigma_0_sq * X_bar) / denom

        if single_instance:
            mu_post = mu_post.squeeze(0)

        return {"mean": mu_post, "std": sigma_post}

    def sample_posterior(
        self,
        X_obs: torch.Tensor,
        n_samples: int,
        seed: int,
        sigma: Optional[float] = None
    ) -> torch.Tensor:
        """Draw samples from the analytic posterior.

        Args:
            X_obs: Observations of shape ``[M, n_obs, D]``, or ``[n_obs, D]``
                for a single instance.
            n_samples: Number of posterior draws per instance.
            seed: Seed for the generator used by this call.
            sigma: Likelihood std override forwarded to
                ``get_posterior_params``.

        Returns:
            Tensor of shape ``[M, n_samples, D]`` (``[n_samples, D]`` for a
            single instance) on ``self.device``.
        """
        device = self.device

        single_instance = X_obs.dim() == 2
        if single_instance:
            X_obs = X_obs.unsqueeze(0)  # [1, n_obs, D]

        X_obs = X_obs.to(device)
        # Unpacking also rejects inputs that are not 3-D at this point.
        M, _, D = X_obs.shape

        # X_obs is 3-D here, so the returned mean is always [M, D].
        params = self.get_posterior_params(X_obs, sigma=sigma)
        mu_post = params["mean"]
        sigma_post = params["std"]

        gen = torch.Generator(device=device).manual_seed(seed)
        eps = torch.randn(M, n_samples, D, device=device, generator=gen)

        # [M, 1, D] mean broadcasts over the n_samples axis.
        posterior_samples = mu_post.unsqueeze(1) + sigma_post * eps

        if single_instance:
            return posterior_samples.squeeze(0)

        return posterior_samples

    def generate_test_data(self, n_samples: int, seed: int, misspec_cfg: Optional[Dict] = None) -> Dict:
        """Generate a test set, optionally under a misspecified model.

        ``misspec_cfg["type"]`` selects the misspecification (any other value,
        or a missing/empty config, yields well-specified data):

        * ``"prior_location"``: parameters are drawn from a prior whose mean is
          the configured ``prior_loc`` plus ``prior_location_shift``
          (default 0.0).
        * ``"prior_scale"``: the prior std is multiplied by
          ``prior_scale_factor`` (default 1.0).
        * ``"likelihood_scale"``: the observation noise std is multiplied by
          ``likelihood_scale_factor`` (default 1.0).
        * ``"contamination"``: each observation is independently flagged with
          probability ``contamination_eps`` (default 0.0); every coordinate of
          a flagged observation is overwritten with ``+contamination_shift``
          or ``-contamination_shift`` (default 2.0), one random sign per
          observation.

        Seeds used: ``seed`` for the prior draw, ``seed + 1`` for the
        simulator, ``seed + 4`` for the contamination mask and ``seed + 5``
        for the contamination signs.

        Args:
            n_samples: Number of parameter/observation pairs.
            seed: Base seed.
            misspec_cfg: Misspecification settings as described above.

        Returns:
            Dict of CPU tensors: ``"thetas"`` ``[n_samples, dim]``,
            ``"X_clean"`` and ``"X_obs"`` ``[n_samples, n_obs, dim]``,
            ``"S_obs"`` (flattened ``X_obs``) and boolean ``"misspec_mask"``
            ``[n_samples, n_obs]`` marking contaminated observations.
        """
        cfg = self.gauss_cfg
        device = self.device
        m_cfg = misspec_cfg or {}
        m_type = m_cfg.get("type", "none")

        if m_type == "prior_location":
            shift = m_cfg.get("prior_location_shift", 0.0)
            # Shift relative to the configured prior mean, mirroring how the
            # scale factor below is applied relative to the configured std.
            # Passing the bare shift would silently discard a non-zero
            # prior_loc.
            theta = self.sample_prior(n_samples, seed, loc=cfg.prior_loc + shift)
        elif m_type == "prior_scale":
            factor = m_cfg.get("prior_scale_factor", 1.0)
            theta = self.sample_prior(n_samples, seed, scale=factor * cfg.prior_scale)
        else:
            theta = self.sample_prior(n_samples, seed)

        X_clean = self.simulate(theta, seed + 1)

        if m_type == "likelihood_scale":
            factor = m_cfg.get("likelihood_scale_factor", 1.0)
            # Same seed as X_clean: identical noise draws, only rescaled.
            X_obs = self.simulate(theta, seed + 1, sigma=factor * cfg.sigma)
        else:
            X_obs = X_clean.clone()

        misspec_mask = torch.zeros(n_samples, cfg.n_obs, dtype=torch.bool, device=device)

        if m_type == "contamination":
            eps = m_cfg.get("contamination_eps", 0.0)
            shift = m_cfg.get("contamination_shift", 2.0)

            mask_gen = torch.Generator(device=device).manual_seed(seed + 4)
            misspec_mask = torch.rand(n_samples, cfg.n_obs, generator=mask_gen, device=device) < eps

            if misspec_mask.any():
                num_contaminated = misspec_mask.sum().item()

                # One +/-1 sign per contaminated observation.
                sign_gen = torch.Generator(device=device).manual_seed(seed + 5)
                signs = torch.where(
                    torch.rand(num_contaminated, generator=sign_gen, device=device) < 0.5,
                    torch.tensor(1.0, device=device),
                    torch.tensor(-1.0, device=device)
                )

                # [num_contaminated, 1] broadcasts across the D coordinates of
                # each selected observation.
                outlier_values = signs.unsqueeze(-1) * shift

                X_obs[misspec_mask] = outlier_values

        S_obs = self.compute_summary_statistics(X_obs)

        return {
            "thetas": theta.cpu(),
            "X_clean": X_clean.cpu(),
            "X_obs": X_obs.cpu(),
            "S_obs": S_obs.cpu(),
            "misspec_mask": misspec_mask.cpu()
        }