import torch
import torch.nn as nn
from typing import Optional, Union, Tuple, Dict
import numpy as np

class TTAPosterior:
    """Posterior sampler that can refine the observation summary before sampling.

    The wrapped ``npe_model`` is moved to ``device`` and put in eval mode on
    construction. It must expose ``sample(sample_shape, condition=...)`` and,
    unless ``s_obs`` is always supplied by the caller, an ``_embedding_net``
    attribute (either directly or on ``npe_model.net``).

    ``tta_adapter`` must expose ``adapt(s_init, x_obs_np, **kwargs)`` returning
    a mapping whose ``'best_s'`` entry is a tensor.
    """

    def __init__(self, npe_model: nn.Module, tta_adapter, device: str = "cpu", bypass_embedding: bool = False):
        self.npe_model = npe_model
        self.tta_adapter = tta_adapter
        self.device = device
        # When True, the embedding net is temporarily swapped for an identity
        # while sampling, so the summary is passed to the model unchanged.
        self.bypass_embedding = bypass_embedding
        self.npe_model.to(device)
        self.npe_model.eval()

    def sample(
        self,
        x_obs: Union[torch.Tensor, np.ndarray],
        n_samples: int,
        adapt: bool = True,
        return_info: bool = False,
        s_obs: Optional[Union[torch.Tensor, np.ndarray]] = None,
        **kwargs
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
        """Draw posterior samples for one observation.

        Args:
            x_obs: The observation, as a tensor or NumPy array.
            n_samples: Number of samples (wrapped into a 1-tuple sample shape).
            adapt: If True, run ``tta_adapter.adapt`` on the initial summary
                and sample with the returned ``'best_s'``.
            return_info: If True, also return the adapter result (``None``
                when ``adapt`` is False).
            s_obs: Optional precomputed summary, 1D or of shape ``(1, H)``.
                When omitted, the summary is computed with the embedding net.
            **kwargs: Forwarded to ``tta_adapter.adapt``.

        Returns:
            The samples on CPU, or ``(samples, adapt_result)`` if
            ``return_info`` is True.

        Raises:
            ValueError: If ``s_obs`` has more than one non-singleton leading
                dimension.
            AttributeError: If ``s_obs`` is omitted and no ``_embedding_net``
                can be found on the model.
        """
        x_obs_np, x_batch = self._prepare_obs(x_obs)

        if s_obs is not None:
            if isinstance(s_obs, np.ndarray):
                s_init = torch.from_numpy(s_obs).float()
            else:
                s_init = s_obs
            if s_init.ndim == 2 and s_init.shape[0] == 1:
                s_init = s_init.squeeze(0)
            elif s_init.ndim > 1:
                raise ValueError("s_obs must be a 1D summary vector or shape (1, H).")
            s_init = s_init.to(self.device)
        else:
            # Resolve the embedding net the same way _sample_with_summary does,
            # so models that keep it on ``npe_model.net`` are supported too.
            owner = self._get_embedding_owner()
            if owner is None:
                raise AttributeError(
                    "npe_model has no '_embedding_net' (checked the model and its "
                    "'net' attribute); pass s_obs explicitly."
                )
            with torch.no_grad():
                s_init = owner._embedding_net(x_batch).squeeze(0)

        adapt_result = None
        if adapt:
            adapt_result = self.tta_adapter.adapt(s_init, x_obs_np, **kwargs)
            s_adapted = adapt_result['best_s'].to(self.device)
        else:
            s_adapted = s_init

        samples = self._sample_with_summary(s_adapted, n_samples)

        if return_info:
            return samples, adapt_result
        return samples

    def set_config(self, config):
        """Replace the adapter's ``config`` attribute."""
        self.tta_adapter.config = config

    def _prepare_obs(self, x_obs):
        """Return ``(x_obs as NumPy, x_obs as a batched tensor on self.device)``.

        A 2D input gets one leading dimension, a 1D input gets two, and any
        other rank is passed through unchanged.
        """
        if isinstance(x_obs, np.ndarray):
            x_obs_np = x_obs
            x_obs_torch = torch.from_numpy(x_obs).float().to(self.device)
        else:
            # detach() first: Tensor.numpy() refuses tensors that require grad.
            x_obs_np = x_obs.detach().cpu().numpy()
            x_obs_torch = x_obs.to(self.device)

        if x_obs_torch.ndim == 2:
            x_batch = x_obs_torch.unsqueeze(0)
        elif x_obs_torch.ndim == 1:
            x_batch = x_obs_torch.unsqueeze(0).unsqueeze(0)
        else:
            x_batch = x_obs_torch

        return x_obs_np, x_batch

    def _get_embedding_owner(self):
        """Return the object holding ``_embedding_net``, or None if not found.

        Checks the model itself first, then ``npe_model.net``.
        """
        if hasattr(self.npe_model, "_embedding_net"):
            return self.npe_model
        if hasattr(self.npe_model, "net") and hasattr(self.npe_model.net, "_embedding_net"):
            return self.npe_model.net
        return None

    def _sample_with_summary(self, s: torch.Tensor, n_samples: int) -> torch.Tensor:
        """Sample from the model conditioned on summary ``s``; result is on CPU.

        With ``bypass_embedding`` set, the embedding net is replaced by
        ``nn.Identity`` for the duration of the call and always restored.

        Raises:
            AttributeError: If ``bypass_embedding`` is set but no
                ``_embedding_net`` can be found on the model.
        """
        condition = s.unsqueeze(0)
        sample_shape = (n_samples,) if isinstance(n_samples, int) else n_samples

        owner = None
        if self.bypass_embedding:
            owner = self._get_embedding_owner()
            if owner is None:
                raise AttributeError(
                    "bypass_embedding=True but npe_model has no '_embedding_net' "
                    "(checked the model and its 'net' attribute)."
                )

        with torch.no_grad():
            if owner is not None:
                original = owner._embedding_net
                try:
                    owner._embedding_net = nn.Identity()
                    samples = self.npe_model.sample(sample_shape, condition=condition)
                finally:
                    # Restore even if sampling raised, so the model is not
                    # left with the identity embedding.
                    owner._embedding_net = original
            else:
                samples = self.npe_model.sample(sample_shape, condition=condition)

        # A 3D result has its second dimension squeezed out.
        return samples.squeeze(1).cpu() if samples.ndim == 3 else samples.cpu()


class NPEPFNPosterior:
    """Posterior sampler around an NPE-PFN estimator that consumes summaries.

    The estimator must expose ``sample(sample_shape, x=...)``. Summaries are
    always supplied by the caller via ``s_obs``. An optional ``adapter``
    exposing ``adapt(s_init, x_obs_np, **kwargs)`` (returning a mapping with a
    tensor under ``'best_s'``) can refine the summary, and an optional
    ``embed_adapter_output`` callable can post-process the adapted summary.
    """

    def __init__(self, npe_pfn_estimator, adapter=None, device: str = "cpu", embed_adapter_output=None):
        self.estimator = npe_pfn_estimator
        self.adapter = adapter
        self.device = device
        self.embed_adapter_output = embed_adapter_output

    def sample(
        self,
        x_obs: Union[torch.Tensor, np.ndarray],
        n_samples: int,
        adapt: bool = False,
        return_info: bool = False,
        s_obs: Optional[Union[torch.Tensor, np.ndarray]] = None,
        **kwargs
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
        """Draw samples from the estimator conditioned on a summary.

        Args:
            x_obs: The observation; only passed (as NumPy) to the adapter.
            n_samples: Number of samples requested from the estimator.
            adapt: If True and an adapter was configured, refine ``s_obs``
                with the adapter first. Ignored when ``adapter`` is None.
            return_info: If True, also return an info dict: the adapter
                result when adaptation ran, otherwise an empty dict.
            s_obs: Required summary (array, tensor, or anything
                ``torch.tensor`` accepts); cast to float32.
            **kwargs: Forwarded to ``adapter.adapt``.

        Returns:
            The estimator's samples, or ``(samples, info)`` if
            ``return_info`` is True.

        Raises:
            ValueError: If ``s_obs`` is None.
        """
        info = {}

        if s_obs is None:
            raise ValueError("s_obs is required for NPE-PFN (no embedding net to compute summaries).")
        if isinstance(s_obs, np.ndarray):
            s_init = torch.from_numpy(s_obs).float()
        elif torch.is_tensor(s_obs):
            s_init = s_obs.float()
        else:
            s_init = torch.tensor(s_obs).float()

        if isinstance(x_obs, torch.Tensor):
            # detach() first: Tensor.numpy() refuses tensors that require grad.
            x_obs_np = x_obs.detach().cpu().numpy()
        else:
            x_obs_np = x_obs

        if self.adapter is not None and adapt:
            adapt_result = self.adapter.adapt(s_init, x_obs_np, **kwargs)
            s_for_pfn = adapt_result['best_s'].float()
            # info is the adapter's own result dict (not a copy).
            info = adapt_result

            if self.embed_adapter_output is not None:
                with torch.no_grad():
                    s_for_pfn = self.embed_adapter_output(s_for_pfn.unsqueeze(0).to(self.device)).squeeze(0).cpu()
                    # Report the post-processed summary, i.e. what the
                    # estimator is actually conditioned on.
                    info['best_s'] = s_for_pfn
        else:
            s_for_pfn = s_init

        # The estimator is given a 2D input; add a leading dim to a 1D summary.
        if s_for_pfn.ndim == 1:
            s_for_pfn = s_for_pfn.unsqueeze(0)

        # sample from NPE-PFN
        samples = self.estimator.sample((n_samples,), x=s_for_pfn)

        if return_info:
            return samples, info
        return samples