import os
import argparse
from typing import Tuple

import numpy as np
import torch
from tqdm import tqdm

from models import SetEncoderND
from utils.model import Autoencoder, MLPDecoder


# ---------------------------------------------------------------------------
# Input data
# ---------------------------------------------------------------------------

def _should_record(time_step: int) -> bool:
    return True


def _load_trajectories_from_npz(npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
    data = np.load(npz_path, allow_pickle=True)
    trajectories = data["trajectories"]
    types = data["types"]
    data.close()
    return trajectories, types


def load_anchors_from_npz(
    data_path: str,
    time_step=None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Load anchor sets and particle types from the NPZ file used in training.

    Args:
      data_path: path to an NPZ archive with 'trajectories' (E, T, M, D) and
                 'types' (E, M).
      time_step: None (every step accepted by _should_record), a single int, or
                 an iterable of ints. Negative values count from the end.

    Returns:
      anchors_all: (E, S, M, D) where S = number of selected time steps
      types:       (E, M) array of particle type labels (expected to be 1/2;
                   not validated here)
      time_idx:    (S,) numpy array of the resolved time indices (0-based)

    Raises:
      ValueError: trajectories are not 4-D, types has the wrong shape, or the
                  time-step selection is empty.
      IndexError: a requested time step is outside [-T, T).
    """
    anchors, types = _load_trajectories_from_npz(data_path)  # (E, T, M, D), (E, M)
    if anchors.ndim != 4:
        raise ValueError(f"trajectories must be 4-D (E, T, M, D), got shape {anchors.shape}")
    E, T, M, D = anchors.shape
    if types.shape != (E, M):
        raise ValueError(f"types must have shape {(E, M)}, got {types.shape}")

    if time_step is None:
        time_step = [i for i in range(T) if _should_record(i)]
        if not time_step:
            raise ValueError("No time steps selected by _should_record().")

    if isinstance(time_step, (int, np.integer)):
        requested = np.array([int(time_step)], dtype=int)
    else:
        requested = np.array(list(time_step), dtype=int)
    if requested.size == 0:
        # An empty selection would silently produce (E, 0, M, D) anchors and only
        # fail later with an unrelated error; fail early instead.
        raise ValueError("time_step selection is empty.")

    # Negative indices count from the end, as in normal Python indexing.
    time_idx = np.where(requested < 0, requested + T, requested)
    out_of_range = (time_idx < 0) | (time_idx >= T)
    if out_of_range.any():
        first = int(np.argmax(out_of_range))  # report the first offending entry
        raise IndexError(
            f"time_step {requested[first]} (resolved to {time_idx[first]}) is out of range [0, {T-1}]"
        )

    anchors_all = anchors[:, time_idx, :, :]  # (E, S, M, D)
    return anchors_all, types, time_idx


def _reshape_norm_values(values: np.ndarray, ndim: int) -> np.ndarray:
    flat = np.asarray(values).reshape(-1)
    shape = (1,) * (ndim - 1) + (flat.shape[0],)
    return flat.reshape(shape)


def normalize_anchors(
    anchors: np.ndarray,
    normalization_info: dict,
) -> np.ndarray:
    """
    Min-max scale ``anchors`` into [-1, 1] along the last axis.

    Args:
      anchors:            array whose last axis matches the stored statistics.
      normalization_info: mapping with "min" and "max" entries (flattened and
                          broadcast along the last axis of ``anchors``).

    Returns:
      2 * (anchors - min) / (max - min) - 1. A zero range is replaced by 1 so
      that constant features map to -1 instead of dividing by zero.
    """
    rank = anchors.ndim
    lo = _reshape_norm_values(normalization_info["min"], rank)
    hi = _reshape_norm_values(normalization_info["max"], rank)
    span = hi - lo
    span[span == 0] = 1.0  # guard against division by zero for constant features
    return 2.0 * (anchors - lo) / span - 1.0


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

class TypeSplitDeepSetAutoencoder(torch.nn.Module):
    """
    Autoencoder that encodes type-1 and type-2 particles separately with one
    shared DeepSet encoder, then decodes the concatenated codes with an MLP.

    The decoder reconstructs all ``n_points`` particles and receives no type
    information. The submodule attribute names ``encoder`` and ``decoder`` form
    the state_dict keys, so they must stay stable for saved checkpoints to load.
    """

    def __init__(
        self,
        in_dim: int,
        n_points: int,
        hidden_dim: int = 256,
        z_dim: int = 16,
        pool: str = "mean",
    ):
        super().__init__()
        self.in_dim, self.n_points, self.z_dim = in_dim, n_points, z_dim
        # One encoder shared across both particle types (weights are reused).
        self.encoder = SetEncoderND(in_dim=in_dim, hidden_dim=hidden_dim, z_dim=z_dim, pool=pool)
        # Decoder input is the two per-type codes concatenated (2 * z_dim).
        self.decoder = MLPDecoder(z_dim * 2, hidden_dim, n_points * in_dim)

    def forward(self, x: torch.Tensor, types: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode each particle type separately and reconstruct all particles.

        Args:
          x:     (B, n_points, in_dim) particle coordinates.
          types: (B, n_points) labels; entries equal to 1 or 2 select the two groups.

        Returns:
          (z, x_recon) with z of shape (B, 2 * z_dim) and x_recon of shape
          (B, n_points, in_dim).

        Raises:
          ValueError: if ``types`` does not match x's first two dims, or a sample
                      lacks particles of either type.
        """
        batch_and_points = x.shape[:2]
        if types.dim() != 2 or types.shape[:2] != batch_and_points:
            raise ValueError(f"types must have shape {x.shape[:2]}, got {tuple(types.shape)}")

        type_masks = [types == label for label in (1, 2)]
        if any(bool((~mask.any(dim=1)).any()) for mask in type_masks):
            raise ValueError("Each sample must contain at least one particle of each type.")

        codes = [self.encoder(x, mask=mask) for mask in type_masks]
        z = torch.cat(codes, dim=-1)

        batch = x.shape[0]
        x_recon = self.decoder(z).view(batch, self.n_points, self.in_dim)
        return z, x_recon


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

def build_model_from_checkpoint(
    ckpt_path: str,
    device: torch.device,
    model_type: str | None = None,
) -> Tuple[torch.nn.Module, dict, dict, str]:
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    inferred_type = model_type
    if inferred_type is None:
        inferred_type = "deepset" if "pool" in ckpt else "autoencoder"

    if inferred_type == "deepset":
        data_dim = ckpt.get("data_dim", ckpt.get("args", {}).get("dim"))
        n_anchors = ckpt["n_anchors"]
        z_dim = ckpt.get("z_dim", 8)
        hidden_dim = ckpt.get("hidden_dim", 256)
        pool = ckpt.get("pool", "mean")
        model = TypeSplitDeepSetAutoencoder(
            in_dim=data_dim,
            n_points=n_anchors,
            hidden_dim=hidden_dim,
            z_dim=z_dim,
            pool=pool,
        ).to(device)
    elif inferred_type == "autoencoder":
        input_dim = ckpt["input_dim"]
        z_dim = ckpt.get("z_dim", 8)
        hidden_dim = ckpt.get("hidden_dim", 256)
        model = Autoencoder(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            z_dim=z_dim,
        ).to(device)
    else:
        raise ValueError(f"Unknown model_type: {inferred_type}")

    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    normalization_info = ckpt.get(
        "normalization_info", ckpt.get("args", {}).get("normalization_info", None)
    )
    return model, ckpt, normalization_info, inferred_type


# ---------------------------------------------------------------------------
# Encoding
# ---------------------------------------------------------------------------

def _encode_batch_autoencoder(
    model: torch.nn.Module,
    anchors_batch: torch.Tensor,
) -> torch.Tensor:
    bsz = anchors_batch.shape[0]
    anchors_flat = anchors_batch.view(bsz, -1)
    return model.encoder(anchors_flat)


@torch.no_grad()
def get_z_over_time_autoencoder(
    model: torch.nn.Module,
    anchors: np.ndarray,  # (E, T, M, D)
    device: torch.device,
) -> np.ndarray:
    """
    Encode every time step of every experiment with the flat autoencoder.

    The whole anchor array is converted to float32 up front. Each experiment's
    (T, M, D) slice is then moved to ``device`` and encoded as one batch of T
    samples, with gradients disabled.

    Returns:
      Per-experiment encodings stacked along a new leading axis:
      (E, T, <encoder output dim>).
    """
    frames_all = torch.from_numpy(anchors).float()
    n_exp, _, _, _ = frames_all.shape  # also enforces a 4-D input
    per_experiment = []
    for exp_idx in tqdm(range(n_exp), desc="Iterating experiments for Z over time"):
        frames = frames_all[exp_idx].contiguous().to(device)
        per_experiment.append(_encode_batch_autoencoder(model, frames).cpu().numpy())
    return np.stack(per_experiment, axis=0)


def get_concat_z_over_time_deepset(
    model: TypeSplitDeepSetAutoencoder,
    anchors: np.ndarray,  # (E, T, M, D)
    types: np.ndarray,    # (E, M)
    device: torch.device,
) -> np.ndarray:
    """
    Encode each experiment's time steps with the shared DeepSet encoder,
    once per particle type, and concatenate the two codes.

    The type mask of an experiment is the same for all T time steps, so it is
    broadcast over the time axis. Only ``model.encoder`` is used; the decoder
    is not called.

    Returns:
      (E, T, 2 * z_dim) array: [type-1 code, type-2 code] per time step.

    Raises:
      ValueError: if an experiment lacks particles of type 1 or type 2.
    """
    E, T, M, D = anchors.shape
    outputs = []

    with torch.no_grad():
        for exp in tqdm(range(E), desc="Iterating experiments for Z over time"):
            np_masks = [types[exp] == label for label in (1, 2)]
            if not all(np.any(m) for m in np_masks):
                raise ValueError(f"Experiment {exp} must have both particle types.")

            frames = torch.from_numpy(anchors[exp]).float().to(device).contiguous()  # (T, M, D)
            codes = []
            for np_mask in np_masks:
                # (M,) -> (T, M): same particle selection at every time step.
                time_mask = torch.from_numpy(np_mask).to(device).unsqueeze(0).expand(T, M)
                codes.append(model.encoder(frames, mask=time_mask))
            outputs.append(torch.cat(codes, dim=-1).cpu().numpy())  # (T, 2*z_dim)

    return np.stack(outputs, axis=0)


# ---------------------------------------------------------------------------
# Main generation
# ---------------------------------------------------------------------------

def _load_macro_feature(
    macro_path: str,
    n_traj: int,
    time_idx: np.ndarray,
) -> np.ndarray:
    """Load macro features (n_traj_total, T, 2); keep the first n_traj trajectories at time_idx."""
    # NOTE(security): allow_pickle=True can run code from a crafted file; only load trusted data.
    macro_feature = np.load(macro_path, allow_pickle=True)
    if macro_feature.ndim != 3 or macro_feature.shape[2] != 2:
        raise ValueError(
            f"macro_feature must have shape (n_traj, T, 2), got {macro_feature.shape}"
        )
    if macro_feature.shape[0] < n_traj:
        raise ValueError(
            f"macro_feature has {macro_feature.shape[0]} trajectories, but data has {n_traj}"
        )
    # np.max on an empty selection would raise; an empty selection needs no bound check.
    if len(time_idx) > 0 and macro_feature.shape[1] <= np.max(time_idx):
        raise ValueError("macro_feature has fewer timesteps than requested.")
    return macro_feature[:n_traj, time_idx, :]


def generate_for_dataset(
    model: torch.nn.Module,
    model_type: str,
    normalization_info: dict,
    data_path: str,
    macro_path: str,
    output_path: str,
    time_step: list[int] | None,
    device: torch.device,
) -> None:
    """
    Encode a dataset's anchors and save them next to the matching macro features.

    Args:
      model:              loaded model ("autoencoder" or "deepset").
      model_type:         selects the encoding routine.
      normalization_info: min/max statistics used to scale anchors to [-1, 1].
      data_path:          NPZ with 'trajectories' and 'types'.
      macro_path:         .npy macro features of shape (n_traj, T, 2).
      output_path:        destination for np.savez_compressed (arrays
                          'macro_feature' and 'Z'); parent dirs are created.
      time_step:          time steps to keep (None = default selection).
      device:             torch device used for encoding.

    Raises:
      ValueError: on shape mismatches or an unknown model_type.
    """
    anchors, types, time_idx = load_anchors_from_npz(
        data_path=data_path,
        time_step=time_step,
    )
    print(f"Input anchors shape: {anchors.shape}, types shape: {types.shape}")

    anchors_normalized = normalize_anchors(anchors, normalization_info)

    # Macro features are aligned with the anchors on trajectory and time axes.
    macro_feature = _load_macro_feature(macro_path, anchors.shape[0], time_idx)

    if model_type == "autoencoder":
        Z = get_z_over_time_autoencoder(
            model=model,
            anchors=anchors_normalized,
            device=device,
        )
    elif model_type == "deepset":
        Z = get_concat_z_over_time_deepset(
            model=model,
            anchors=anchors_normalized,
            types=types,
            device=device,
        )
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    out_dir = os.path.dirname(output_path)
    if out_dir:  # dirname is "" for a bare filename and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    np.savez_compressed(
        output_path,
        macro_feature=macro_feature,
        Z=Z,
    )
    print(f"Saved macro/Z arrays to {output_path}")


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main() -> None:
    """
    Command-line entry point.

    Resolves the checkpoint path, rebuilds the model, and writes the
    macro-feature / latent-code archive under
    ``<base_path>/autoencoder_z_Z<z_dim>/<output_name>``.
    """
    parser = argparse.ArgumentParser(description="Generate autoencoder Z over time for mix data")

    script_dir = os.path.dirname(os.path.abspath(__file__))
    dataset_dir = os.path.join(script_dir, "generate_data", "dataset")
    default_data_path = os.path.join(dataset_dir, "trajectories_diffN_test.npz")
    default_macro_path = os.path.join(dataset_dir, "macro_feature_diffN_test.npy")
    out_name_default = "macro_and_Z_diffN_test.npz"

    parser.add_argument(
        "--base_path",
        type=str,
        required=True,
        help="Directory to checkpoint (e.g., trained_autoencoder/exp1/...).",
    )
    parser.add_argument("--z_dim", type=int, default=8, help="latent dimension (per type for deepset)")
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        choices=["autoencoder", "deepset"],
        help="Force model type (default: infer from checkpoint)",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default=None,
        help="Override checkpoint path (default based on base_path/model_type/z_dim)",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default=default_data_path,
        help="NPZ path with 'trajectories' and 'types'.",
    )
    parser.add_argument(
        "--macro_path",
        type=str,
        default=default_macro_path,
        help="Path to macro_feature.npy (shape: n_traj, T, 2).",
    )
    parser.add_argument(
        "--time_step",
        type=int,
        nargs="+",
        default=None,
        help="Optional time steps; default matches training selection.",
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default=out_name_default,
        help="Output NPZ filename.",
    )
    parser.add_argument("--device", type=int, default=0)

    args = parser.parse_args()

    # A negative --device, or no CUDA available, falls back to CPU.
    use_cuda = args.device >= 0 and torch.cuda.is_available()
    device = torch.device(f"cuda:{args.device}" if use_cuda else "cpu")

    ckpt_path = args.ckpt_path
    if ckpt_path is None:
        # The default checkpoint name depends on the model type, so it must be given.
        if args.model_type is None:
            raise ValueError("--model_type is required when --ckpt_path is not provided.")
        prefix = "best_autoencoder" if args.model_type == "autoencoder" else "best_deepset_autoencoder"
        ckpt_path = os.path.join(args.base_path, f"{prefix}_Z{args.z_dim}.pth")

    model, _, normalization_info, model_type = build_model_from_checkpoint(
        ckpt_path, device, args.model_type
    )
    if normalization_info is None:
        raise ValueError("Checkpoint is missing normalization_info.")
    print(f"Loaded {model_type} checkpoint: {ckpt_path}")

    output_root = os.path.join(args.base_path, f"autoencoder_z_Z{args.z_dim}")
    os.makedirs(output_root, exist_ok=True)

    generate_for_dataset(
        model=model,
        model_type=model_type,
        normalization_info=normalization_info,
        data_path=args.data_path,
        macro_path=args.macro_path,
        output_path=os.path.join(output_root, args.output_name),
        time_step=args.time_step,
        device=device,
    )


if __name__ == "__main__":
    main()
