import torch
import numpy as np
import copy
from typing import Tuple, List, Dict, Any, Optional

from typing_extensions import Literal
import flowtorch
import flowtorch.bijectors as bij
import flowtorch.distributions as dist
import zuko
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import (
    StandardScaler,
    RobustScaler,
    MinMaxScaler,
    QuantileTransformer,
)

from emm.algorithms.neural_rules import SimpleMixtureRules, GMMRemixer
from emm.utils.train_utils import TemperatureAnnealer


def flow_flowtorch_gen(**kwargs):
    """Return a zero-argument factory for FlowTorch spline bijectors.

    Args:
        **kwargs: Keyword arguments forwarded verbatim to ``bij.Spline`` each
            time the factory is invoked.

    Returns:
        A callable that, when called, yields ``(bij.Spline(**kwargs),
        "flowtorch")``. The string tag is what ``create_flow`` uses to decide
        how to wrap the object.
    """
    # The bijector is built lazily so that every call yields a fresh instance.
    return lambda: (bij.Spline(**kwargs), "flowtorch")


def flow_zuko_gen(name="zuko_GF", features=1, **kwargs):
    """Return a zero-argument factory for one of the supported Zuko flows.

    Args:
        name: Case-insensitive flow identifier; one of ``zuko_gf``,
            ``zuko_gmm``, ``zuko_nsf``, ``zuko_maf`` or ``zuko_naf``.
        features: Passed as the first positional argument of the flow class.
        **kwargs: Extra keyword arguments forwarded to the flow class. For
            ``zuko_nsf`` a default of ``bins=6`` is supplied when absent.

    Returns:
        A callable producing ``(flow_instance, "zuko")`` on every call.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    class_names = {
        "zuko_gf": "GF",
        "zuko_gmm": "GMM",
        "zuko_nsf": "NSF",
        "zuko_maf": "MAF",
        "zuko_naf": "NAF",
    }
    key = name.lower()
    if key not in class_names:
        raise ValueError(f"Unknown flow type: {name}")
    if key == "zuko_nsf":
        kwargs.setdefault("bins", 6)
    class_name = class_names[key]

    def build():
        # Resolve the class at call time, exactly when the flow is built.
        return getattr(zuko.flows, class_name)(features, **kwargs), "zuko"

    return build


def create_flow(
    flow_gen, base_dist, device, lr
) -> Tuple[torch.nn.Module, torch.optim.Optimizer]:
    """Build a flow model on ``device`` together with an Adam optimizer.

    Args:
        flow_gen: Zero-argument callable returning ``(object, backend_tag)``.
        base_dist: Base distribution; only used for non-"zuko" backends, where
            the generated object is wrapped in ``dist.Flow(base_dist, ...)``.
        device: Target device passed to ``.to(...)``.
        lr: Learning rate for ``torch.optim.Adam``.

    Returns:
        ``(flow, optimizer)`` where the optimizer tracks ``flow.parameters()``.
    """
    generated, backend = flow_gen()
    # Zuko factories already return a complete flow module; everything else is
    # treated as a bijector that still needs a base distribution.
    model = generated if backend == "zuko" else dist.Flow(base_dist, generated)
    model = model.to(device)
    return model, torch.optim.Adam(model.parameters(), lr=lr)


def get_log_prob(flow, X):
    """Evaluate ``log_prob(X)`` for either a FlowTorch or a Zuko model.

    A FlowTorch ``Flow`` exposes ``log_prob`` directly; any other model is
    called with no arguments first and ``log_prob`` is taken from the result.
    """
    is_flowtorch = isinstance(flow, flowtorch.distributions.flow.Flow)
    distribution = flow if is_flowtorch else flow()
    return distribution.log_prob(X)


def _setup_training_environment(config, device_str: str | torch.device):
    """Sets up device, seeds, and returns the torch device."""
    if isinstance(device_str, str):
        device = torch.device(device_str)
    else:
        device = device_str
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    return device


def _prepare_data(X, Y, device):
    """Move ``X`` and ``Y`` to ``device`` as float32 tensors.

    A one-dimensional ``Y`` is additionally given a trailing singleton
    dimension so it becomes a column.

    Returns:
        The converted ``(X, Y)`` pair.
    """
    X, Y = (tensor.to(device, dtype=torch.float32) for tensor in (X, Y))
    return X, (Y.unsqueeze(1) if Y.ndim == 1 else Y)


def _initialize_training_state(n_components: int, record_history_every: Optional[int]):
    """Create the initial bookkeeping for a training run.

    Returns:
        ``(disabled_components, history)``: a list of ``n_components`` False
        flags, and an empty list when ``record_history_every`` is given
        (``None`` otherwise).
    """
    history = None if record_history_every is None else []
    return [False for _ in range(n_components)], history


def _anneal_parameters(
    step: int,
    total_epochs: int,
    temp_anneal: Optional[TemperatureAnnealer],
    mixture_rules: SimpleMixtureRules,
    current_temp: float,
):
    """Apply the temperature schedule to every rule, if a schedule is given.

    Returns:
        ``current_temp`` unchanged when ``temp_anneal`` is ``None``; otherwise
        the temperature reported by ``temp_anneal.get_temperature(step,
        total_epochs)``, after pushing it to each rule in
        ``mixture_rules.rules`` via ``set_temperature``.
    """
    if temp_anneal is None:
        return current_temp

    annealed = temp_anneal.get_temperature(step, total_epochs)
    for rule in mixture_rules.rules:
        rule.set_temperature(annealed)
    return annealed


def _check_and_update_component_status(
    step: int,
    check_responsibility_every: int,
    X: torch.Tensor,
    mixture_rules: SimpleMixtureRules,
    min_responsibility_threshold: float,
    disabled_components: List[bool],
    verbose: bool,
) -> Tuple[List[bool], Optional[torch.Tensor]]:
    """Periodically disable components whose mean responsibility is too low.

    The check only runs on positive steps that are multiples of
    ``check_responsibility_every``. On those steps the rule probabilities for
    ``X`` are evaluated without gradients, restricted to the first
    ``mixture_rules.n_components`` columns and averaged over dim 0.

    Args:
        disabled_components: Mutated in place; entries are only ever flipped
            to True, never re-enabled.
        verbose: When True, a message is printed the first time a component is
            disabled.

    Returns:
        ``(disabled_components, mean_responsibilities)``; the second element
        is ``None`` on steps where no check was performed.
    """
    if step <= 0 or step % check_responsibility_every != 0:
        return disabled_components, None

    with torch.no_grad():
        full_rule_probs, _ = mixture_rules(X)
        # Columns past n_components are excluded from the check.
        interpretable_probs = full_rule_probs[:, : mixture_rules.n_components]
        mean_responsibilities = interpretable_probs.mean(dim=0)

        for i, resp in enumerate(mean_responsibilities):
            if not (resp < min_responsibility_threshold):
                continue
            newly_disabled = not disabled_components[i]
            if newly_disabled and verbose:
                print(
                    f"Step {step}: Disabling component {i+1} with responsibility {resp:.6f}"
                )
            disabled_components[i] = True
            mixture_rules.rules[i].disabled = True

    return disabled_components, mean_responsibilities


def _record_training_snapshot(
    history: Optional[List[Dict]],
    step: int,
    record_history_every: Optional[int],
    mixture_rules: SimpleMixtureRules,
    density_model_state: Any,
    density_mode: str,
    disabled_components: List[bool],
    current_temp: float,
    latest_total_loss: float,
    nll_loss: torch.Tensor,
    partition_loss: torch.Tensor,
    coverage_loss: torch.Tensor,
    kl_loss: Optional[torch.Tensor] = None,
    gmm_l1_loss: Optional[torch.Tensor] = None,
    and_layer_l1_loss: Optional[torch.Tensor] = None,
):
    """Append a snapshot of the training state to ``history`` when due.

    Nothing happens when ``history`` is ``None`` or ``step`` is not a multiple
    of ``record_history_every``. Loss entries are stored as Python floats;
    anything that is not a tensor is stored as NaN. ``kl_loss`` is only
    recorded in "flow" mode and ``gmm_l1_loss`` only in "gmm_remix" mode.

    The rule state dict is deep-copied; ``density_model_state`` is stored as
    passed, so the caller is responsible for handing in a detached copy.
    """
    if history is None or step % record_history_every != 0:
        return

    def _scalar(value, applicable=True):
        # Tensor -> float; missing, inapplicable or non-tensor -> NaN.
        return value.item() if applicable and torch.is_tensor(value) else np.nan

    history.append(
        {
            "step": step,
            "mixture_rules_state": copy.deepcopy(mixture_rules.state_dict()),
            "density_model_state": density_model_state,
            "density_mode": density_mode,
            "disabled_components": list(disabled_components),
            "current_temp": current_temp,
            "total_loss": latest_total_loss,
            "nll_loss": _scalar(nll_loss),
            "partition_loss": _scalar(partition_loss),
            "and_layer_l1_loss": _scalar(and_layer_l1_loss),
            "coverage_loss": _scalar(coverage_loss),
            "kl_loss": _scalar(kl_loss, density_mode == "flow"),
            "gmm_l1_loss": _scalar(gmm_l1_loss, density_mode == "gmm_remix"),
        }
    )


def _calculate_common_losses(
    rule_probs: torch.Tensor,
    config,
    n_components: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the partition and coverage regularizers.

    Args:
        rule_probs: Rule probabilities; reduced over dim 1 per row and over
            dim 0 across rows (presumably ``(n_samples, n_rules)`` - confirm
            against the caller).
        config: Must expose ``partition_weight`` and ``coverage_weight``.
        n_components: Sets the coverage threshold ``1 / (3 * n_components)``.
        device: Device for the zero placeholders.

    Returns:
        ``(partition_loss, coverage_loss)``. Each term is a zero scalar on
        ``device`` when its weight is not positive. The partition term is the
        mean over rows of the sum of squared probabilities; the coverage term
        sums the shortfall of each column mean below the threshold.
    """
    if config.partition_weight > 0:
        partition_loss = (rule_probs**2).sum(dim=1).mean()
    else:
        partition_loss = torch.tensor(0.0, device=device)

    if config.coverage_weight > 0:
        min_share = 1.0 / (n_components * 3.0)
        shortfall = torch.relu(min_share - rule_probs.mean(dim=0))
        coverage_loss = shortfall.sum()
    else:
        coverage_loss = torch.tensor(0.0, device=device)

    return partition_loss, coverage_loss


def _calculate_total_loss(
    nll_loss: torch.Tensor,
    partition_loss: torch.Tensor,
    coverage_loss: torch.Tensor,
    kl_loss: Optional[torch.Tensor],
    gmm_l1_loss: Optional[torch.Tensor],
    and_layer_l1_loss: Optional[torch.Tensor],
    config,
    gmm_remix_l1_weight: float,
    update_rules: bool,
    density_mode: str,
    gmm_div_pen: Optional[torch.Tensor] = None,
    gmm_div_weight: float = 0.0,
) -> torch.Tensor:
    """Combine the NLL with the weighted regularizers into the training loss.

    Terms and signs:
        * ``- partition_weight * partition_loss`` when the weight is positive.
        * ``+ coverage_weight * coverage_loss`` when the weight is positive.
        * ``+ and_layer_entropy * and_layer_l1_loss`` when the config defines
          a positive ``and_layer_entropy`` and the loss is given.
        * ``- kl_weight * kl_loss`` in "flow" mode when the weight is positive
          and the loss is given.
        * ``+ gmm_remix_l1_weight * gmm_l1_loss + gmm_div_pen * gmm_div_weight``
          in "gmm_remix" mode when the L1 weight is positive and the L1 loss
          is given.

    ``update_rules`` is accepted for interface compatibility and not used.

    Returns:
        The total loss. ``nll_loss`` itself is never modified; when no term
        applies the very same tensor object is returned.
    """
    # Every accumulation below is out-of-place on purpose: ``total_loss``
    # starts as an alias of ``nll_loss``, and an in-place ``-=`` / ``+=`` would
    # overwrite the caller's NLL tensor, corrupting the value that is logged
    # and recorded after this call.
    total_loss = nll_loss

    if config.partition_weight > 0:
        total_loss = total_loss - config.partition_weight * partition_loss
    if config.coverage_weight > 0:
        total_loss = total_loss + config.coverage_weight * coverage_loss
    if (
        hasattr(config, "and_layer_entropy")
        and config.and_layer_entropy > 0
        and and_layer_l1_loss is not None
    ):
        total_loss = total_loss + config.and_layer_entropy * and_layer_l1_loss
    if density_mode == "flow" and config.kl_weight > 0 and kl_loss is not None:
        total_loss = total_loss - config.kl_weight * kl_loss
    if (
        density_mode == "gmm_remix"
        and gmm_remix_l1_weight > 0
        and gmm_l1_loss is not None
    ):
        if gmm_div_pen is None:
            gmm_div_pen = torch.tensor(0.0, device=nll_loss.device)
        # NOTE(review): the diversity penalty is only applied together with
        # the L1 term, so it is dropped whenever the L1 weight is zero -
        # confirm that this coupling is intended.
        total_loss = (
            total_loss
            + gmm_remix_l1_weight * gmm_l1_loss
            + gmm_div_pen * gmm_div_weight
        )

    return total_loss


def _log_verbose_output(
    step: int,
    current_temp: float,
    latest_total_loss: float,
    nll_loss: torch.Tensor,
    partition_loss: torch.Tensor,
    coverage_loss: torch.Tensor,
    kl_loss: Optional[torch.Tensor],
    gmm_l1_loss: Optional[torch.Tensor],
    and_layer_l1_loss: Optional[torch.Tensor],
    config,
    gmm_remix_l1_weight: float,
    disabled_components: List[bool],
    mean_responsibilities: Optional[torch.Tensor],
    mixture_rules: SimpleMixtureRules,
    feature_names: Optional[List[str]],
    preprocessor,
    X: torch.Tensor,
    X_unscaled: Optional[torch.Tensor],
    data_limits: Optional[torch.Tensor],
    density_mode: str,
):
    """Print a human-readable progress report for the current step.

    The report lists the temperature, the total loss, the NLL and every
    penalty that is currently active, the number of disabled components and,
    when ``mean_responsibilities`` is available, the responsibilities of the
    active components (plus the background component if the rules use one).
    When ``feature_names`` is given, the current rules are rendered through
    ``mixture_rules.get_rules`` using the preprocessor's cut points and
    feature scaler.

    ``data_limits`` is accepted but not used here. All work happens under
    ``torch.no_grad()``.
    """
    with torch.no_grad():
        print(f"\nStep {step}")
        print(f"Temperature: {current_temp:.4f}")
        print(f"Loss: {latest_total_loss:.4f}")
        print("Components:")
        print(f"- NLL: {nll_loss.item():.4f}")

        # Each penalty line is shown only when its term is active.
        if config.partition_weight > 0:
            print(f"- Partition Penalty: {partition_loss.item():.4f}")
        if config.coverage_weight > 0:
            print(f"- Coverage Penalty: {coverage_loss.item():.4f}")
        if (
            getattr(config, "and_layer_entropy", 0) > 0
            and and_layer_l1_loss is not None
        ):
            print(f"- And Layer L1 Penalty: {and_layer_l1_loss.item():.4f}")
        if density_mode == "flow":
            if config.kl_weight > 0 and kl_loss is not None:
                print(f"- KL Penalty: {kl_loss.item():.4f}")
        if density_mode == "gmm_remix":
            if gmm_remix_l1_weight > 0 and gmm_l1_loss is not None:
                print(f"- GMM Remix L1 Penalty: {gmm_l1_loss.item():.4f}")
        print(
            f"- Disabled components: {sum(disabled_components)}/{len(disabled_components)}"
        )

        if mean_responsibilities is not None:
            active_comps = [
                idx for idx, is_off in enumerate(disabled_components) if not is_off
            ]
            if active_comps:
                resp_str = ", ".join(
                    f"{i+1}:{mean_responsibilities[i]:.4f}" for i in active_comps
                )
                print(f"- Active component responsibilities: {resp_str}")
            if mixture_rules.use_background_component:
                # The background component is read from the last column.
                full_probs, _ = mixture_rules(X)
                bg_resp = torch.mean(full_probs[:, -1]).item()
                print(f"- Background component responsibility: {bg_resp:.4f}")

        print("\nCurrent effective rules:")
        if feature_names is not None:
            model_device = next(mixture_rules.parameters()).device
            cut_points = preprocessor.cut_points_.to(model_device)
            rules_text = mixture_rules.get_rules(
                cut_points, X_unscaled, feature_names, preprocessor.scaler_x
            )
            print(rules_text)
        print("-" * 80)


def _calculate_final_metrics(
    total_parameters: int, n_samples: int, final_nll: float
) -> Tuple[float, float]:
    """Compute AIC and BIC from a per-sample NLL.

    ``2 * n_samples * final_nll`` serves as the deviance term in both
    criteria. Returns ``(nan, nan)`` when ``final_nll`` is NaN.
    """
    if np.isnan(final_nll):
        return np.nan, np.nan

    aic = 2 * total_parameters + 2 * n_samples * final_nll
    bic = total_parameters * np.log(n_samples) + 2 * n_samples * final_nll
    return aic, bic


def _create_details_dict(
    density_mode: str,
    final_nll: float,
    latest_total_loss: float,
    partition_loss: torch.Tensor,
    coverage_loss: torch.Tensor,
    kl_loss: Optional[torch.Tensor],
    gmm_l1_loss: Optional[torch.Tensor],
    and_layer_l1_loss: Optional[torch.Tensor],
    aic: float,
    bic: float,
    disabled_components: List[bool],
    total_parameters: int,
    history: Optional[List[Dict]],
    **kwargs,
) -> Dict[str, Any]:
    """Assemble the summary dictionary returned at the end of training.

    Tensor losses are converted to Python floats; non-tensor values become
    NaN. ``kl_loss`` is reported only in "flow" mode and ``gmm_l1_loss`` only
    in "gmm_remix" mode. Extra ``kwargs`` are merged in afterwards (and may
    override the standard keys); ``training_history`` is added last and only
    when ``history`` is not ``None``.
    """

    def _scalar(value, applicable=True):
        # Tensor -> float; missing, inapplicable or non-tensor -> NaN.
        return value.item() if applicable and torch.is_tensor(value) else np.nan

    details = {
        "density_mode": density_mode,
        "nll_loss": final_nll,
        "total_loss": latest_total_loss,
        "partition_loss": _scalar(partition_loss),
        "coverage_loss": _scalar(coverage_loss),
        "kl_loss": _scalar(kl_loss, density_mode == "flow"),
        "gmm_l1_loss": _scalar(gmm_l1_loss, density_mode == "gmm_remix"),
        "and_layer_l1_loss": _scalar(and_layer_l1_loss),
        "AIC": aic,
        "BIC": bic,
        "disabled_components": disabled_components,
        "total_parameters": total_parameters,
    }
    details.update(kwargs)
    if history is not None:
        details["training_history"] = history
    return details


def _finalize_training(
    details: Dict[str, Any],
    mixture_rules: SimpleMixtureRules,
    density_model: Any,
):
    """Move the trained models to the CPU and hand back ``details``.

    ``density_model`` may be a single module or a list of modules; ``None``
    entries inside a list are skipped. ``details`` is returned untouched.
    """
    cpu = torch.device("cpu")
    mixture_rules.to(cpu)

    if not isinstance(density_model, list):
        density_model.to(cpu)
        return details

    for flow in density_model:
        if flow is None:
            continue
        flow.to(cpu)
    return details


def _flow_data_cleaning(Y, flow_name: str, mode: Literal["drop", "clamp"] = "clamp"):
    """Restrict ``Y`` to the value range handled by the chosen Zuko flow.

    Ranges (``flow_name`` is matched case-insensitively):
        * ``zuko_nsf``: values kept within ``[-5, 5]``.
        * ``zuko_gf`` / ``zuko_naf``: "drop" keeps values within ``[-10, 10]``;
          "clamp" limits values to ``[-9.5, 9.5]``.
        * any other flow: ``Y`` is returned unchanged and ``mode`` is ignored.

    Args:
        Y: Target tensor.
        flow_name: Identifier of the flow the data will be fed to.
        mode: "clamp" clips out-of-range values element-wise. "drop" removes
            samples: for a tensor with more than one dimension, a row (index
            along dim 0) is removed when any of its entries is out of range,
            so the remaining dimensions are preserved.

    Returns:
        The cleaned tensor. In "drop" mode the number of rows can shrink, so
        any tensor paired row-by-row with ``Y`` has to be filtered by the
        caller as well.

    Raises:
        ValueError: If ``mode`` is neither "drop" nor "clamp" for a flow that
            requires cleaning.
    """
    name = flow_name.lower()
    if name == "zuko_nsf":
        drop_bound, clamp_bound = 5, 5
    elif name in ("zuko_gf", "zuko_naf"):
        drop_bound, clamp_bound = 10, 9.5
    else:
        return Y

    if mode == "clamp":
        return torch.clamp(Y, -clamp_bound, clamp_bound)
    if mode == "drop":
        in_range = Y.abs() <= drop_bound
        if in_range.ndim > 1:
            # Reduce to one flag per row; indexing with the element-wise mask
            # would flatten Y to 1-D and lose the feature dimension.
            in_range = in_range.flatten(start_dim=1).all(dim=1)
        return Y[in_range]
    raise ValueError("Invalid mode. Use 'drop' or 'clamp'.")


def _initialize_flows(
    config,
    mixture_rules: SimpleMixtureRules,
    Y: torch.Tensor,
    X: torch.Tensor,
    device: torch.device,
    lr_flow: float,
    pop_train_epochs: int,
    batchsize: int,
    verbose: bool,
):
    """Create one flow per mixture component and the shared base distribution.

    The flow family is taken from ``config.flow_gen`` (a ``(name, kwargs)``
    pair) when present; otherwise a FlowTorch spline with ``count_bins=4`` is
    used. For Zuko flows ``Y`` is passed through ``_flow_data_cleaning``. The
    optional ``"cleaning_mode"`` entry of the kwargs selects the cleaning mode
    (default "clamp") and is consumed here rather than forwarded to the flow.

    Args:
        config: Training configuration; ``flow_gen`` is optional.
        mixture_rules: Supplies ``n_components``, the number of flows to build.
        Y: Targets; ``Y.shape[1]`` is used as the flow's feature count.
        X: Inputs, returned unchanged.
        device: Device for the base distribution and the flows.
        lr_flow: Learning rate of the per-flow optimizers created by
            ``create_flow`` (the optimizers are currently discarded).
        pop_train_epochs: Must be 0; pre-training is not implemented.
        batchsize: Unused while pre-training is not implemented.
        verbose: Enables the final status message.

    Returns:
        ``(base_dist, component_flows, flow_gen, X, Y)`` with the possibly
        cleaned ``Y``.

    Raises:
        ValueError: If the configured flow name starts with neither "zuko"
            nor "flowtorch".
        NotImplementedError: If ``pop_train_epochs`` is positive.
    """
    print(f"pop train epochs: {pop_train_epochs}")
    y_features = Y.shape[1]
    if hasattr(config, "flow_gen"):
        flow_name = config.flow_gen[0]
        # Work on a copy so the config is not mutated, and strip the
        # cleaning option: it configures the data cleaning, not the flow, and
        # forwarding it would pass an unknown keyword to the flow constructor.
        flow_kwargs = dict(config.flow_gen[1])
        cleaning_mode = flow_kwargs.pop("cleaning_mode", "clamp")
        backend = flow_name.lower()
        if backend.startswith("zuko"):
            flow_gen = flow_zuko_gen(flow_name, features=y_features, **flow_kwargs)
            # NOTE(review): in "drop" mode the cleaned Y can have fewer rows
            # while X is returned unchanged - verify that callers keep the
            # two aligned.
            Y = _flow_data_cleaning(Y, flow_name, mode=cleaning_mode)
        elif backend.startswith("flowtorch"):
            flow_gen = flow_flowtorch_gen(**flow_kwargs)
        else:
            raise ValueError(f"Unknown flow type: {flow_name}")
    else:
        flow_gen = flow_flowtorch_gen(count_bins=4)

    # Independent standard normal over the target features.
    base_dist = torch.distributions.Independent(
        torch.distributions.Normal(
            torch.zeros(Y.shape[1]).to(device), torch.ones(Y.shape[1]).to(device)
        ),
        1,
    )

    # create_flow also returns an optimizer per flow; only the flows are kept
    # because pre-training (the sole consumer of those optimizers) is disabled.
    component_flows = [
        create_flow(flow_gen, base_dist, device, lr_flow)[0]
        for _ in range(mixture_rules.n_components)
    ]

    if pop_train_epochs > 0:
        # Intended to run pretrain_flows_soft over the component flows once
        # that path is supported.
        raise NotImplementedError(
            "Population flow pre-training is not implemented yet. "
        )

    if verbose:
        print("Flow pre-training completed")

    return (
        base_dist,
        component_flows,
        flow_gen,
        X,
        Y,
    )


def _initialize_gmm_remix(
    Y: torch.Tensor,
    n_gmm_components: int,
    gmm_reg_covar: float,
    gmm_max_iter: int,
    config,
    mixture_rules: SimpleMixtureRules,
    device: torch.device,
    verbose: bool,
) -> Tuple[GMMRemixer, torch.Tensor, GaussianMixture]:
    """Fit a global GMM on ``Y`` and build the ``GMMRemixer`` that reuses it.

    When ``config.component_scoring == "bic"``, candidate models with 2 up to
    ``n_gmm_components`` components are fitted and the count with the lowest
    BIC is selected. In every case the final component count is raised to at
    least 5 before the global model is fitted.

    Args:
        Y: Targets; moved to CPU and converted to NumPy. A 1-D input is
            reshaped into a single column.
        n_gmm_components: Requested (or, with BIC scoring, maximum) number of
            Gaussian components.
        gmm_reg_covar: Accepted but currently not forwarded to
            ``GaussianMixture``.
        gmm_max_iter: ``max_iter`` for every GMM fit.
        config: Provides ``component_scoring``, ``seed`` and optionally
            ``diagonal_gmm_init`` (default False).
        mixture_rules: Supplies ``n_components`` and
            ``use_background_component`` for the remixer.
        device: Device for the density tensor and the remixer.
        verbose: Enables progress messages.

    Returns:
        ``(gmm_remixer, gmm_component_densities_full, gmm)``. The densities
        are the exponential of ``gmm._estimate_log_prob(Y)`` (a private
        scikit-learn method; presumably one column per Gaussian component -
        confirm), stored as a float32 tensor.
    """

    def _fit_gmm(n_comp):
        # All fits share the same settings; only the component count varies.
        model = GaussianMixture(
            n_components=n_comp,
            covariance_type="full",
            max_iter=gmm_max_iter,
            init_params="k-means++",
            n_init=3,
            random_state=config.seed,
        )
        return model.fit(Y_np)

    Y_np = Y.cpu().numpy()
    if Y_np.ndim == 1:
        Y_np = Y_np.reshape(-1, 1)

    if config.component_scoring == "bic":
        if verbose:
            print("Finding optimal number of GMM components using BIC...")
        lowest_bic = np.inf
        upper = n_gmm_components
        for candidate in range(2, upper + 1):
            candidate_bic = _fit_gmm(candidate).bic(Y_np)
            if candidate_bic < lowest_bic:
                lowest_bic, n_gmm_components = candidate_bic, candidate

    # The component count is never allowed to drop below five.
    n_gmm_components = max(5, n_gmm_components)
    if verbose:
        print(f"Fitting global GMM with {n_gmm_components} Gaussian components...")
    gmm = _fit_gmm(n_gmm_components)
    if not gmm.converged_:
        print("Warning: Global GMM did not converge.")

    gmm_component_densities_full = torch.tensor(
        np.exp(gmm._estimate_log_prob(Y_np)), dtype=torch.float32, device=device
    )

    gmm_remixer = GMMRemixer(
        n_rules=mixture_rules.n_components,
        n_gmm_components=n_gmm_components,
        diagonal=getattr(config, "diagonal_gmm_init", False),
        use_background_component=mixture_rules.use_background_component,
    ).to(device)

    if verbose:
        print("Global GMM fitting and density calculation complete.")

    return gmm_remixer, gmm_component_densities_full, gmm


def scale_nll(nll, scaler_y):
    """Convert an NLL measured on scaled targets back to the original units.

    For an affine scaling ``z = a * y + b`` applied per feature, the density
    in the original space is ``p_y(y) = p_z(z) * prod(|a|)``, hence
    ``nll_y = nll_z - sum(log |a|)``.

    * ``StandardScaler`` / ``RobustScaler``: ``a = 1 / scale_``, so
      ``sum(log(scale_))`` is added.
    * ``MinMaxScaler``: ``a = (max - min) / data_range_`` where
      ``(min, max)`` is the scaler's ``feature_range``, so
      ``sum(log(data_range_) - log(max - min))`` is added. For the default
      range ``(0, 1)`` this reduces to ``sum(log(data_range_))``.
    * ``QuantileTransformer`` (non-affine) and unknown scalers: the NLL is
      returned without a correction.

    Args:
        nll: NLL in the scaled space.
        scaler_y: Fitted target scaler, or ``None``.

    Returns:
        The adjusted NLL; ``nll`` itself when ``scaler_y`` is ``None`` or
        ``nll`` is NaN.
    """
    if scaler_y is None or np.isnan(nll):
        return nll
    if isinstance(scaler_y, QuantileTransformer):
        return nll

    eps = 1e-9  # keeps the logarithm finite for zero-width features
    log_jacobian_term = 0.0
    if isinstance(scaler_y, (StandardScaler, RobustScaler)):
        # scale_ is None when the scaler was configured not to rescale.
        scale = getattr(scaler_y, "scale_", None)
        if scale is not None:
            log_jacobian_term = np.sum(np.log(scale + eps))
    elif isinstance(scaler_y, MinMaxScaler):
        data_range = getattr(scaler_y, "data_range_", None)
        if data_range is not None:
            # Account for a non-default target interval; for (0, 1) the
            # subtracted term is log(1) == 0.
            range_min, range_max = scaler_y.feature_range
            log_jacobian_term = np.sum(
                np.log(data_range + eps) - np.log(range_max - range_min)
            )
    return nll + log_jacobian_term


def pretrain_flows_soft(
    flows, optimizers, X, Y, mixture_rules, n_epochs=100, batch_size=128
):
    """Pre-train each flow on all data, weighted by its rule responsibility.

    Flows are trained one after another. For the flow at position ``j`` every
    mini-batch minimizes the responsibility-weighted NLL
    ``-sum(w_j * log_prob) / (sum(w_j) + 1e-8)``, where ``w_j`` is column
    ``j`` of the rule probabilities evaluated without gradients, so only the
    flow is updated.

    Args:
        flows: Sequence of flow models.
        optimizers: One optimizer per flow, in the same order.
        X: Inputs fed to ``mixture_rules``; also determines the device on
            which the shuffling permutation is created.
        Y: Targets, indexed with the same row indices as ``X``.
        mixture_rules: Callable returning ``(rule_probs, _)`` for a batch.
        n_epochs: Passes over the data per flow.
        batch_size: Mini-batch size.
    """
    device = X.device
    n_samples = X.shape[0]
    # enumerate yields the flow's own position directly; looking it up with
    # flows.index(flow) per batch was redundant work and would select the
    # wrong column if the same flow object appeared more than once.
    for j, (flow, optimizer) in enumerate(zip(flows, optimizers)):
        for _ in range(n_epochs):
            perm = torch.randperm(n_samples, device=device)
            for start in range(0, n_samples, batch_size):
                idx_batch = perm[start : start + batch_size]
                X_batch, Y_batch = X[idx_batch], Y[idx_batch]
                with torch.no_grad():
                    rule_probs, _ = mixture_rules(X_batch)
                w_j = rule_probs[:, j]
                optimizer.zero_grad()
                log_prob = get_log_prob(flow, Y_batch)
                loss = -torch.sum(w_j * log_prob) / (torch.sum(w_j) + 1e-8)
                loss.backward()
                optimizer.step()


def pretrain_flows(
    flows,
    optimizers,
    X,
    Y,
    assignments,
    n_epochs=100,
    batch_size=128,
    device=torch.device("cpu"),
):
    """Pre-train each flow on the targets hard-assigned to it.

    For every ``(flow, optimizer, indices)`` triple the rows ``Y[indices]``
    are selected; flows with no assigned rows are skipped. Each epoch performs
    a single optimizer step on the mean NLL of either the full selection or,
    when it holds more than ``batch_size`` rows, a random subset of
    ``batch_size`` rows.

    ``X`` and ``device`` are accepted but not used.
    """
    for flow, optimizer, indices in zip(flows, optimizers, assignments):
        n_assigned = len(indices)
        if n_assigned == 0:
            continue

        Y_component = Y[indices]
        needs_subsample = batch_size < n_assigned
        for _ in range(n_epochs):
            if needs_subsample:
                # A fresh random subset is drawn for every step.
                Y_batch = Y_component[torch.randperm(n_assigned)[:batch_size]]
            else:
                Y_batch = Y_component
            optimizer.zero_grad()
            nll = -get_log_prob(flow, Y_batch).mean()
            nll.backward()
            optimizer.step()
