import torch
import numpy as np
import scipy.stats
from typing import Optional, Dict
from .base import BaseTask
import sys
sys.path.insert(0, './external/cryoSBI/src')

from cryo_sbi.wpa_simulator.cryo_em_simulator import CryoEmSimulator


class CryoEMTask(BaseTask):
    """SBI task backed by the cryoSBI ``CryoEmSimulator``.

    Parameters are discrete model indices in ``[0, num_models)``, stored as
    floats of shape ``(n, 1)``. Each observation is a stack of
    ``num_images_per_obs`` images simulated from a single model index.
    """

    def __init__(self, cfg):
        """Build the simulator from ``cfg``.

        Args:
            cfg: Task config. Optional attributes read here:
                ``sim_config_path`` (simulator JSON config path) and ``n_obs``
                (images per observation, default 100). ``self.device`` is
                presumably set by ``BaseTask.__init__`` — TODO confirm.
        """
        super().__init__(cfg)
        # NOTE(review): the default filename spells "cyro" rather than "cryo";
        # kept as-is because it must match the file on disk — confirm.
        config_path = getattr(cfg, "sim_config_path",
                              "./configs/cyro_em_sim_config.json")
        self.simulator = CryoEmSimulator(config_path, device=self.device)

        # Valid model indices are 0..max_index inclusive.
        self.num_models = self.simulator.max_index + 1
        self.num_images_per_obs = getattr(cfg, "n_obs", 100)  # N images per observation
        # Side length of the synthetic noise images; read from the simulator's
        # (private) config dict.
        self.num_pixels = self.simulator._config["N_PIXELS"]

    def sample_prior(self, n_samples: int, seed: int) -> torch.Tensor:
        """Draw model indices uniformly from ``[0, num_models)``.

        Args:
            n_samples: Number of draws.
            seed: Seed for a dedicated generator (global RNG is untouched).

        Returns:
            Float tensor of shape ``(n_samples, 1)`` on ``self.device``.
        """
        gen = torch.Generator(device=self.device).manual_seed(seed)
        theta = torch.randint(0, self.num_models, (n_samples,),
                              generator=gen, device=self.device)
        return theta.float().unsqueeze(-1)  # (n_samples, 1)

    def simulate(self, theta: torch.Tensor, seed: int) -> torch.Tensor:
        """Simulate one observation (a stack of images) per row of ``theta``.

        Args:
            theta: Tensor of shape ``(batch, 1)``. Each value is clamped to
                ``[0, num_models - 1]`` and rounded to a model index.
            seed: Accepted for interface compatibility but unused. The
                simulator is called without a seed, so outputs are not
                reproducible from ``seed``.

        Returns:
            Per-row simulator outputs stacked along a new leading dimension.
            Expected shape is ``(batch, N, H, W)``, assuming the simulator
            returns ``(N, H, W)`` per call — TODO confirm.
        """
        _ = seed  # NOTE: the simulator call below takes no seed.
        batch_size = theta.shape[0]
        if batch_size == 0:
            # torch.stack([]) raises, so return an empty batch instead. This
            # assumes square N_PIXELS images — TODO confirm.
            return torch.empty(0, self.num_images_per_obs, self.num_pixels,
                               self.num_pixels, device=self.device)

        sim_batch_size = min(self.num_images_per_obs, 32)
        all_observations = []
        for i in range(batch_size):
            theta_i = int(torch.clamp(theta[i, 0], 0, self.num_models - 1).round().item())
            # The simulator receives one float index per image, shape (N, 1).
            indices = torch.full((self.num_images_per_obs, 1), float(theta_i),
                                 device=self.device, dtype=torch.float32)
            images = self.simulator.simulate(
                num_sim=self.num_images_per_obs,
                indices=indices,
                return_parameters=False,
                batch_size=sim_batch_size,
            )
            all_observations.append(images)

        return torch.stack(all_observations)  # (batch, N, H, W)

    def compute_summary_statistics(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce image stacks to 6 summary statistics per observation.

        For each image, three per-pixel statistics are computed: skewness,
        excess kurtosis (scipy default ``fisher=True``), and intensity range
        (max - min). Each is then summarized by its mean and std across the
        images of an observation.

        Args:
            x: Tensor or array-like of shape ``(M, N, H, W)`` (a batch of
                observations) or ``(N, H, W)`` (a single observation).

        Returns:
            CPU float32 tensor of shape ``(M, 6)``, or ``(6,)`` for a single
            observation. Columns are ``[mean skew, mean kurt, mean range,
            std skew, std kurt, std range]``.

        Raises:
            ValueError: If ``x`` is not 3- or 4-dimensional.
        """
        if isinstance(x, torch.Tensor):
            # detach() so tensors that require grad can be converted to numpy.
            images = x.detach().cpu().numpy()
        else:
            images = np.asarray(x)

        if images.ndim not in (3, 4):
            raise ValueError(
                f"expected x of shape (N, H, W) or (M, N, H, W), got {images.shape}"
            )
        single_input = images.ndim == 3
        if single_input:
            images = images[np.newaxis, ...]

        M, N, H, W = images.shape
        # Explicit H * W (not -1) so zero-size batches reshape cleanly.
        flat = images.reshape(M, N, H * W)  # (M, N, H*W)

        per_image_skew = scipy.stats.skew(flat, axis=2)
        per_image_kurt = scipy.stats.kurtosis(flat, axis=2)
        per_image_range = flat.max(axis=2) - flat.min(axis=2)

        result = np.stack([
            per_image_skew.mean(axis=1),
            per_image_kurt.mean(axis=1),
            per_image_range.mean(axis=1),
            per_image_skew.std(axis=1),
            per_image_kurt.std(axis=1),
            per_image_range.std(axis=1),
        ], axis=1)  # (M, 6)

        if single_input:
            result = result[0]

        return torch.from_numpy(result).float()

    def _generate_pure_noise(self, num_images: int, seed: Optional[int] = None) -> torch.Tensor:
        """Return ``num_images`` standardized Gaussian-noise images on the CPU.

        Each image has shape ``(num_pixels, num_pixels)``. Each is shifted and
        scaled to zero mean and unit (sample) std.

        Args:
            num_images: Number of images to generate.
            seed: If given, a dedicated seeded generator is used; otherwise the
                global torch RNG is used.
        """
        if seed is not None:
            gen = torch.Generator().manual_seed(seed)
            noise = torch.randn(num_images, self.num_pixels, self.num_pixels, generator=gen)
        else:
            noise = torch.randn(num_images, self.num_pixels, self.num_pixels)
        mean = noise.mean(dim=[1, 2], keepdim=True)
        std = noise.std(dim=[1, 2], keepdim=True)
        return (noise - mean) / (std + 1e-8)  # epsilon guards against zero std

    def generate_test_data(self, n_samples: int, seed: int, misspec_cfg: Optional[Dict] = None) -> Dict:
        """Generate test observations, optionally contaminated with noise images.

        Args:
            n_samples: Number of observations.
            seed: Base seed. Derived seeds are used for the prior (``seed``),
                the contamination masks (``seed + 10000``), and the noise
                images (``seed + 20000 + i`` for observation ``i``).
            misspec_cfg: Optional dict. Its ``type`` key selects the
                misspecification; only ``"contamination"`` has an effect, and
                any other value leaves ``X_obs`` equal to ``X_clean``. For
                contamination, ``contamination_eps`` is the fraction of images
                per observation replaced by pure noise.

        Returns:
            Dict of CPU tensors: ``thetas``, ``X_clean``, ``X_obs``,
            ``S_obs``, and ``misspec_mask``. The mask is bool with shape
            ``(n_samples, num_images_per_obs)`` and is True where an image
            was replaced.

        Raises:
            ValueError: If ``contamination_eps`` is greater than 1.
        """
        misspec_cfg = misspec_cfg or {}
        m_type = misspec_cfg.get("type", "none")

        theta = self.sample_prior(n_samples, seed)
        X_clean = self.simulate(theta, seed + 1)

        X_obs = X_clean.clone()
        misspec_mask = torch.zeros(n_samples, self.num_images_per_obs, dtype=torch.bool)

        if m_type == "contamination":
            eps = float(misspec_cfg.get("contamination_eps", 0.0))
            if eps > 1.0:
                raise ValueError(f"contamination_eps must be in [0, 1], got {eps}")
            # Floor of eps * N. The tolerance absorbs float error that a bare
            # int() would truncate, e.g. int(0.29 * 100) == 28.
            k = int(eps * self.num_images_per_obs + 1e-9) if eps > 0 else 0

            if k > 0:
                mask_gen = torch.Generator().manual_seed(seed + 10000)
                for i in range(n_samples):
                    perm = torch.randperm(self.num_images_per_obs, generator=mask_gen)
                    misspec_mask[i, perm[:k]] = True
                    noise_images = self._generate_pure_noise(k, seed=seed + 20000 + i)
                    # The noise is built on the CPU. Indexed assignment needs the
                    # values on X_obs's device (and matching dtype).
                    X_obs[i, misspec_mask[i].to(X_obs.device)] = noise_images.to(
                        device=X_obs.device, dtype=X_obs.dtype)

        S_obs = self.compute_summary_statistics(X_obs)

        return {
            "thetas": theta.cpu(),
            "X_clean": X_clean.cpu(),
            "X_obs": X_obs.cpu(),
            "S_obs": S_obs.cpu(),
            "misspec_mask": misspec_mask.cpu()
        }