from abc import ABC, abstractmethod
import torch
from torch.distributions import Distribution
from typing import Optional, Tuple, Dict, Dict

class BaseTask(ABC):
    """Abstract base for simulation tasks: prior sampling, simulation, summaries.

    Subclasses must implement :meth:`sample_prior` and :meth:`simulate`. They
    may override :attr:`prior` (required for NPE-PFN) and
    :meth:`compute_summary_statistics` (the identity by default).
    """

    def __init__(self, cfg):
        """Store the task configuration.

        Args:
            cfg: Configuration object. Only its optional ``device`` attribute
                is read here; when absent, ``"cpu"`` is used.
        """
        self.cfg = cfg
        # Attribute lookup only: a plain dict cfg with a "device" key falls
        # back to "cpu". NOTE(review): confirm configs are attribute-style.
        self.device = getattr(cfg, "device", "cpu")

    @property
    def prior(self) -> Distribution:
        """Prior distribution over the task parameters.

        Raises:
            NotImplementedError: Always, in this base class. Subclasses used
                with NPE-PFN must override this property.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement the 'prior' property "
            "to use with NPE-PFN."
        )

    @abstractmethod
    def sample_prior(self, n_samples: int, seed: int) -> torch.Tensor:
        """Draw ``n_samples`` parameter vectors, seeded by ``seed``."""

    @abstractmethod
    def simulate(self, theta: torch.Tensor, seed: int) -> torch.Tensor:
        """Run the simulator on parameters ``theta``, seeded by ``seed``."""

    def compute_summary_statistics(self, x: torch.Tensor) -> torch.Tensor:
        """Map raw simulator output to summary statistics.

        The default is the identity: it returns ``x`` itself, not a copy.
        """
        return x

    def generate_train_data(
        self, n_samples: int, seed: int
    ) -> Dict[str, torch.Tensor]:
        """Sample parameters, simulate, and summarize a training set.

        Args:
            n_samples: Number of (theta, x) pairs to generate.
            seed: Seed for the prior draw; the simulator is seeded with
                ``seed + 1``.

        Returns:
            Dict with keys ``"thetas"``, ``"X"`` and ``"S"``. Tensors are
            returned as produced by the subclass and are not moved to CPU.
        """
        theta = self.sample_prior(n_samples, seed)
        # A different seed for the simulator than for the prior draw.
        x = self.simulate(theta, seed + 1)
        s = self.compute_summary_statistics(x)
        return {"thetas": theta, "X": x, "S": s}

    def generate_test_data(
        self, n_samples: int, seed: int, misspec_cfg=None
    ) -> Dict[str, torch.Tensor]:
        """Sample parameters and simulate a test set, returned on CPU.

        Args:
            n_samples: Number of test cases to generate.
            seed: Seed for the prior draw; the simulator is seeded with
                ``seed + 1``.
            misspec_cfg: Ignored by this base implementation, so observed data
                equals clean data. Subclasses that support misspecification
                should override this method.

        Returns:
            Dict with keys ``"thetas"``, ``"X_clean"``, ``"X_obs"`` and
            ``"S_obs"``, all moved to CPU. ``X_obs`` never shares storage
            with ``X_clean``.
        """
        theta = self.sample_prior(n_samples, seed)
        x_clean = self.simulate(theta, seed + 1)
        # Clone so X_obs has its own storage. ``.cpu()`` is a no-op for tensors
        # already on CPU, so without this both keys would reference the same
        # tensor there (but not on GPU). An in-place edit of X_obs would then
        # silently change the clean reference.
        x_obs = x_clean.clone()
        s_obs = self.compute_summary_statistics(x_obs)

        return {
            "thetas": theta.cpu(),
            "X_clean": x_clean.cpu(),
            "X_obs": x_obs.cpu(),
            "S_obs": s_obs.cpu(),
        }