from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, List
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler
from pathlib import Path
from typing import Union

from .nn import build_npe_model

@dataclass
class NPE_TrainConfig:
    """Hyperparameters consumed by the NPE training loop.

    Attributes:
        lr: Learning rate handed to the Adam optimizer.
        batch_size: Mini-batch size used by the data loaders.
        val_frac: Fraction of the dataset held out for validation.
        stop_after_epochs: Number of consecutive epochs without a new best
            validation loss after which training stops early.
        max_epochs: Upper bound on the number of training epochs.
        log_every: Print the epoch metrics every this many epochs; a falsy
            value (e.g. ``0``) disables printing.
    """

    lr: float = 0.0005
    batch_size: int = 256
    val_frac: float = 0.1
    stop_after_epochs: int = 20
    max_epochs: int = 10000
    log_every: int = 50


class NPETrainer:
    """Training loop with early stopping for a conditional density model.

    The model must expose ``loss(theta, condition=x)`` returning a tensor of
    losses. Subclasses customise training by overriding the hook methods
    ``setup``, ``train_step``, ``validate_batch``, ``on_epoch_end`` and
    ``finalize_history``.
    """

    def __init__(self, model: nn.Module, config: "NPE_TrainConfig", device: str = "cpu"):
        """Move ``model`` to ``device`` and remember the training config."""
        self.model = model.to(device)
        self.config = config
        self.device = device
        self.history: Dict[str, List[float]] = {}

    def setup(self, thetas: torch.Tensor, xs: torch.Tensor, **kwargs) -> None:
        """Hook called once at the start of ``train``; does nothing by default."""
        pass

    def train_step(self, theta_b: torch.Tensor, x_b: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Return the mean batch loss and the metrics to log for this batch."""
        loss = self.model.loss(theta_b, condition=x_b).mean()
        return loss, {"train_loss": loss.item()}

    def validate_batch(self, theta_b: torch.Tensor, x_b: torch.Tensor) -> float:
        """Return the *summed* loss of one validation batch.

        ``train`` divides the accumulated sums by the number of validation
        samples, so overrides must also return a sum, not a mean.
        """
        return self.model.loss(theta_b, condition=x_b).sum().item()

    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        """Hook called after every epoch with that epoch's averaged metrics."""
        pass

    def finalize_history(self, history: Dict[str, Any]) -> Dict[str, Any]:
        """Hook to post-process the history dict before ``train`` returns it."""
        return history

    def _build_loaders(
        self,
        thetas: torch.Tensor,
        xs: torch.Tensor,
        seed: Optional[int],
    ) -> Tuple[DataLoader, DataLoader, int]:
        """Split the data and return (train_loader, val_loader, n_val)."""
        dataset = TensorDataset(thetas, xs)
        n = len(dataset)
        n_train = int((1 - self.config.val_frac) * n)
        n_val = n - n_train
        if n_train <= 0 or n_val <= 0:
            raise ValueError(
                f"Cannot split {n} samples with val_frac={self.config.val_frac}: "
                f"got {n_train} training and {n_val} validation samples; "
                "both splits must be non-empty."
            )

        # `is not None` so that seed=0 is honoured as a real seed.
        g = torch.Generator()
        if seed is not None:
            g.manual_seed(seed)
        perm = torch.randperm(n, generator=g)
        train_idx = perm[:n_train].tolist()
        val_idx = perm[n_train:].tolist()

        # With an explicit seed the shuffling order is reproducible too;
        # without one the sampler keeps drawing from the global RNG.
        train_sampler = SubsetRandomSampler(
            train_idx, generator=g if seed is not None else None
        )
        # Clamp the batch size so drop_last=True can never discard the whole
        # training split when it is smaller than the configured batch size.
        train_loader = DataLoader(
            dataset,
            sampler=train_sampler,
            batch_size=min(self.config.batch_size, n_train),
            drop_last=True,
        )
        # Validation visits every held-out sample in a fixed order: no
        # drop_last, so the loss is averaged over exactly n_val samples.
        val_loader = DataLoader(
            dataset,
            sampler=val_idx,
            batch_size=self.config.batch_size,
            drop_last=False,
        )
        return train_loader, val_loader, n_val

    def _train_one_epoch(
        self, loader: DataLoader, opt: torch.optim.Optimizer
    ) -> Dict[str, float]:
        """Run one optimisation pass and return per-batch-averaged metrics."""
        self.model.train()
        totals: Dict[str, float] = {}
        n_batches = 0

        for theta_b, x_b in loader:
            theta_b, x_b = theta_b.to(self.device), x_b.to(self.device)

            opt.zero_grad()
            loss, metrics = self.train_step(theta_b, x_b)
            loss.backward()
            opt.step()

            for key, value in metrics.items():
                totals[key] = totals.get(key, 0.0) + value
            n_batches += 1

        return {key: total / n_batches for key, total in totals.items()}

    def _evaluate(self, loader: DataLoader, n_samples: int) -> float:
        """Return the validation loss averaged over ``n_samples`` samples."""
        self.model.eval()
        with torch.no_grad():
            total = sum(
                self.validate_batch(t.to(self.device), x.to(self.device))
                for t, x in loader
            )
        return total / n_samples

    def train(
        self,
        thetas: torch.Tensor,
        xs: torch.Tensor,
        seed: Optional[int] = None,
        **kwargs,
    ) -> Tuple[nn.Module, Dict[str, Any]]:
        """Train the model with Adam and early stopping on validation loss.

        Args:
            thetas: Parameters, paired row-wise with ``xs``.
            xs: Observations used as the conditioning input.
            seed: Seed for the train/validation split and the training
                shuffle. ``None`` leaves the split generator unseeded.
            **kwargs: Forwarded unchanged to ``setup``.

        Returns:
            The model with the best-validation weights restored, and a
            history dict holding one list per metric plus ``best_epoch``,
            ``best_val_loss`` and ``final_epoch`` (``-1`` if no epoch ran).

        Raises:
            ValueError: If the split leaves the training or validation set
                empty.
        """
        self.setup(thetas, xs, **kwargs)
        train_loader, val_loader, n_val = self._build_loaders(thetas, xs, seed)

        opt = Adam(self.model.parameters(), lr=self.config.lr)
        best_state = None
        best_val = float("inf")
        best_epoch = 0
        patience = 0
        final_epoch = -1  # stays -1 when max_epochs == 0

        history: Dict[str, Any] = {}

        for epoch in range(self.config.max_epochs):
            final_epoch = epoch

            # === Train ===
            epoch_metrics = self._train_one_epoch(train_loader, opt)
            for key, value in epoch_metrics.items():
                history.setdefault(key, []).append(value)

            # === Validate ===
            val_loss = self._evaluate(val_loader, n_val)
            history.setdefault("val_loss", []).append(val_loss)
            epoch_metrics["val_loss"] = val_loss

            if self.config.log_every and epoch % self.config.log_every == 0:
                parts = [f"{k}={v:.4f}" for k, v in epoch_metrics.items()]
                print(f"Epoch {epoch}: " + ", ".join(parts))
            self.on_epoch_end(epoch, epoch_metrics)

            if val_loss < best_val:
                best_val = val_loss
                best_epoch = epoch
                # Snapshot on the CPU so the copy is independent of the
                # live parameters that keep being updated.
                best_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
                patience = 0
            else:
                patience += 1
                if patience >= self.config.stop_after_epochs:
                    print(f"Early stopping at epoch {epoch}")
                    break

        # No snapshot exists if no epoch ran or the validation loss never
        # improved on +inf (e.g. it was NaN); keep the current weights then.
        if best_state is not None:
            self.model.load_state_dict(best_state)

        history["best_epoch"] = best_epoch
        history["best_val_loss"] = best_val
        history["final_epoch"] = final_epoch
        history = self.finalize_history(history)

        return self.model, history

class StandardNPETrainer(NPETrainer):
    """Trainer that uses every ``NPETrainer`` hook unchanged."""

def sample_npe_posterior(
    model: torch.nn.Module,
    x_obs: torch.Tensor,
    n_samples: int,
    device: str = "cpu"
) -> torch.Tensor:
    """Draw ``n_samples`` samples from ``model`` conditioned on ``x_obs``.

    A 1-D ``x_obs`` is given a leading axis of length one before it is moved
    to ``device``. ``model.sample`` is called with the shape ``(n_samples,)``
    under ``torch.no_grad()``. Axis 1 of the result is squeezed (a no-op
    unless that axis has length one) and the tensor is returned on the CPU.
    """
    condition = x_obs if x_obs.dim() != 1 else x_obs.unsqueeze(0)
    condition = condition.to(device)

    with torch.no_grad():
        draws = model.sample((n_samples,), condition=condition)

    draws = draws.squeeze(1)
    return draws.cpu()

def train_NPE_estimator(
    model,
    thetas: torch.Tensor,
    xs: torch.Tensor,
    config: Optional["NPE_TrainConfig"] = None,
    device: str = "cpu",
    seed: Optional[int] = None,
):
    """Train ``model`` with a ``StandardNPETrainer`` and return its result.

    Args:
        model: Model to train; passed to ``StandardNPETrainer``.
        thetas: Parameters, paired row-wise with ``xs``.
        xs: Observations used as the conditioning input.
        config: Training hyperparameters. ``None`` (the default) creates a
            fresh ``NPE_TrainConfig()`` for this call.
        device: Device the trainer moves the model and batches to.
        seed: Optional seed forwarded to ``StandardNPETrainer.train``.

    Returns:
        Whatever ``StandardNPETrainer.train`` returns (the trained model and
        the training history).
    """
    # A default of `NPE_TrainConfig()` in the signature would be built once
    # at definition time and shared (and mutable) across every call.
    if config is None:
        config = NPE_TrainConfig()
    trainer = StandardNPETrainer(model, config, device)
    return trainer.train(thetas, xs, seed=seed)

def load_npe_model(
    model_path: Union[str, Path],
    theta_sample: torch.Tensor,
    x_sample: torch.Tensor,
    device: str = "cpu",
) -> Tuple[nn.Module, Dict[str, Any]]:
    """Rebuild a saved NPE model from a checkpoint file.

    The checkpoint must contain a ``"config"`` mapping and a
    ``"model_state_dict"``. The model is rebuilt with ``build_npe_model``
    from the stored config, its weights are loaded, and it is moved to
    ``device`` and switched to eval mode.

    Args:
        model_path: Location of the checkpoint file.
        theta_sample: Forwarded to ``build_npe_model`` as ``theta_sample``.
        x_sample: Forwarded to ``build_npe_model`` as ``x_sample``.
        device: Device the restored model is moved to.

    Returns:
        The restored model and the config mapping read from the checkpoint.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(str(model_path), map_location="cpu")
    cfg = checkpoint["config"]

    # Config entries passed through as-is; a missing key becomes None.
    passthrough_keys = (
        "n_obs",
        "dim",
        "embedding_dim",
        "num_transforms",
        "hidden_features",
        "embedding_hidden",
        "embedding_layers",
    )
    build_kwargs = {key: cfg.get(key) for key in passthrough_keys}
    # Only the embedding type has a fallback value.
    build_kwargs["embedding_type"] = cfg.get("embedding_type", "fc")

    model = build_npe_model(
        theta_sample=theta_sample,
        x_sample=x_sample,
        **build_kwargs,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()
    return model, cfg