import argparse
import os
import sys
from typing import Tuple

import numpy as np
import torch
from tqdm import tqdm

# Ensure local patterns/utils is importable regardless of CWD.
sys.path.append(os.path.dirname(__file__))
from utils.model import Autoencoder, DeepSetAutoencoder


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def build_model_from_checkpoint(
    ckpt_path: str,
    device: torch.device,
    model_type: str | None = None,
) -> Tuple[torch.nn.Module, dict, dict | None, str]:
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    inferred_type = model_type
    if inferred_type is None:
        inferred_type = "deepset" if "pool" in ckpt else "autoencoder"

    if inferred_type == "deepset":
        data_dim = ckpt.get("data_dim", ckpt.get("args", {}).get("dim"))
        n_anchors = ckpt["n_anchors"]
        z_dim = ckpt.get("z_dim", 8)
        hidden_dim = ckpt.get("hidden_dim", 256)
        pool = ckpt.get("pool", "mean")
        model = DeepSetAutoencoder(
            in_dim=data_dim,
            n_points=n_anchors,
            hidden_dim=hidden_dim,
            z_dim=z_dim,
            pool=pool,
        ).to(device)
    elif inferred_type == "autoencoder":
        input_dim = ckpt["input_dim"]
        z_dim = ckpt.get("z_dim", 8)
        hidden_dim = ckpt.get("hidden_dim", 256)
        model = Autoencoder(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            z_dim=z_dim,
        ).to(device)
    else:
        raise ValueError(f"Unknown model_type: {inferred_type}")

    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    normalization_info = ckpt.get(
        "normalization_info", ckpt.get("args", {}).get("normalization_info")
    )
    return model, ckpt, normalization_info, inferred_type


# ---------------------------------------------------------------------------
# Encoding
# ---------------------------------------------------------------------------
def _encode_batch(
    model: torch.nn.Module,
    anchors_batch: torch.Tensor,
    model_type: str,
) -> torch.Tensor:
    # anchors_batch: (B, M, D)
    if model_type == "autoencoder":
        bsz = anchors_batch.shape[0]
        anchors_flat = anchors_batch.view(bsz, -1)
        return model.encoder(anchors_flat)
    return model.encoder(anchors_batch)


def get_z_over_time(
    model: torch.nn.Module,
    anchors: np.ndarray,  # (E, T, M, D)
    device: torch.device,
    model_type: str,
    batch_size: int | None = None,
) -> np.ndarray:
    """Encode every time step of every experiment into the latent space.

    Args:
        model: Loaded model (in eval mode), passed to ``_encode_batch``.
        anchors: Array of shape ``(E, T, M, D)``. It is cast to float32 one
            experiment at a time.
        device: Device the encoder runs on.
        model_type: Forwarded to ``_encode_batch``.
        batch_size: Maximum number of time steps per forward pass. ``None``
            (default) encodes all ``T`` steps of an experiment in one pass.
            Set it to bound device memory for long trajectories.

    Returns:
        Array of shape ``(E, T, z)``: the encoder outputs stacked over
        experiments.

    Raises:
        ValueError: if ``anchors`` is not 4-D or ``batch_size`` is not positive.
    """
    anchors = np.asarray(anchors)
    if anchors.ndim != 4:
        raise ValueError(
            f"anchors must have shape (E, T, M, D), got {anchors.shape}"
        )
    if batch_size is not None and batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_exp, n_time = anchors.shape[:2]
    z_per_exp = []

    with torch.no_grad():
        for e in tqdm(range(n_exp), desc="Iterating experiments for Z over time"):
            # Cast per experiment rather than materialising a float32 copy of
            # the entire (E, T, M, D) dataset up front.
            anchors_e = torch.from_numpy(
                np.ascontiguousarray(anchors[e], dtype=np.float32)
            )  # (T, M, D)
            if batch_size is None or n_time <= batch_size:
                z_e = _encode_batch(model, anchors_e.to(device), model_type)
            else:
                z_e = torch.cat(
                    [
                        _encode_batch(
                            model,
                            anchors_e[start : start + batch_size].to(device),
                            model_type,
                        ).cpu()
                        for start in range(0, n_time, batch_size)
                    ],
                    dim=0,
                )
            z_per_exp.append(z_e.cpu().numpy())  # (T, z)

    return np.stack(z_per_exp, axis=0)  # (E, T, z)


def generate_for_dataset(
    model: torch.nn.Module,
    model_type: str,
    normalization_info: dict,
    data_path: str,
    macro_path: str,
    output_path: str,
    device: torch.device,
) -> None:
    input_data = np.load(data_path, allow_pickle=True)
    print(f"\nData: {data_path}")
    print(f"input_data shape: {input_data.shape}")

    mean = np.asarray(normalization_info["mean"])
    std = np.asarray(normalization_info["std"])
    anchors_normalized = (input_data - mean) / std

    print("Encoding Z...")
    Z_E_T_D = get_z_over_time(
        model=model,
        anchors=anchors_normalized,
        device=device,
        model_type=model_type,
    )
    print(f"Z shape: {Z_E_T_D.shape}")

    macro_feat_E_T = np.load(macro_path).squeeze()
    macro_feat_E_T_1 = macro_feat_E_T[:, :, None]

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    np.savez(
        output_path,
        macro=macro_feat_E_T_1,
        Z=Z_E_T_D,
    )
    print(f"Saved to: {output_path}")


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # For each experiment sub-directory of --trained_dir, load its
    # best_model.pth and encode every dataset split found in --data_dir.
    # Results are written to <trained_dir>/<exp>/macro_input/*.npz.
    parser = argparse.ArgumentParser(
        description="Generate latent Z from a trained autoencoder checkpoint"
    )
    parser.add_argument(
        "--trained_dir",
        type=str,
        required=True,
        help=(
            "Directory with one sub-directory per experiment (see --exps), "
            "each containing a best_model.pth checkpoint"
        ),
    )
    parser.add_argument(
        "--exps",
        type=str,
        nargs="+",
        default=["exp1", "exp2", "exp3"],
        help="Experiment sub-directories of --trained_dir to process",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        choices=["autoencoder", "deepset"],
        help="Force model type (default: infer from checkpoint)",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="generate_dataset/data",
        help="Directory containing the trajectories*.npy and macro_feature*.npy inputs",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="GPU device index (-1 for CPU)",
    )

    args = parser.parse_args()

    resolved_device = (
        torch.device(f"cuda:{args.device}")
        if (args.device >= 0 and torch.cuda.is_available())
        else torch.device("cpu")
    )

    # (trajectory file, macro-feature file, output file name); identical for
    # every experiment, so built once.
    dataset_groups = [
        ("trajectories.npy", "macro_feature.npy", "macro_and_Z_over_time.npz"),
        (
            "trajectories_inDistribution.npy",
            "macro_feature_inDistribution.npy",
            "test_inDistribution.npz",
        ),
        (
            "trajectories_outDistribution_3gmm.npy",
            "macro_feature_outDistribution_3gmm.npy",
            "test_outDistribution_3gmm.npz",
        ),
        (
            "trajectories_outDistribution_400N.npy",
            "macro_feature_outDistribution_400N.npy",
            "test_outDistribution_400N.npz",
        ),
    ]

    for exp in args.exps:
        ckpt_path = os.path.join(args.trained_dir, exp, "best_model.pth")
        # Load model + normalization statistics stored with the checkpoint.
        model, _, normalization_info, model_type = build_model_from_checkpoint(
            ckpt_path, resolved_device, args.model_type
        )
        print(f"Loaded {model_type} checkpoint: {ckpt_path}")

        if normalization_info is None:
            raise ValueError("Missing normalization_info in checkpoint.")

        output_dir = os.path.join(os.path.dirname(ckpt_path), "macro_input")
        for traj_file, macro_file, out_npz in dataset_groups:
            generate_for_dataset(
                model=model,
                model_type=model_type,
                normalization_info=normalization_info,
                data_path=os.path.join(args.data_dir, traj_file),
                macro_path=os.path.join(args.data_dir, macro_file),
                output_path=os.path.join(output_dir, out_npz),
                device=resolved_device,
            )
