import argparse
import os
import numpy as np
import torch
from tqdm import tqdm

from mucola.model import MuCoLA
import functools


def parse_args():
    """Build the CLI parser for latent-action extraction and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option registered below.
    """
    parser = argparse.ArgumentParser(description="Generate latent actions for LeRobot episodes")
    # (flag, type, default, required, help) -- listed in the order shown by --help.
    option_table = [
        ("--checkpoint", str, None, True, None),
        ("--repo_id", str, None, True, "single or comma-separated repo ids"),
        ("--root", str, None, True, None),
        ("--output_dir", str, None, True, None),
        ("--camera_key", str, None, False, None),
        ("--camera_index", int, None, False, None),
        ("--paired_camera_key", str, None, False, None),
        ("--paired_camera_index", int, None, False, None),
        ("--resolution", int, 256, False, None),
        ("--video_backend", str, None, False, None),
        ("--chunk_size", int, 64, False, None),
        ("--device", str, None, False, None),
    ]
    for flag, value_type, default, is_required, help_text in option_table:
        parser.add_argument(
            flag,
            type=value_type,
            default=default,
            required=is_required,
            help=help_text,
        )
    return parser.parse_args()


def load_episode_frames(dataset, meta, ep_idx, camera_key, resolution):
    """Load one episode's frames for ``camera_key`` as a square, resized clip.

    Frames are read one by one from ``dataset[i][camera_key]`` for
    ``i`` in ``[dataset_from_index, dataset_to_index)`` of episode ``ep_idx``.
    They are then processed as follows:

    - stacked, cast to float and divided by 255;
    - moved from channel-first to channel-last;
    - center-cropped to a square;
    - resized with bicubic interpolation to ``resolution`` x ``resolution`` if needed.

    Each frame is assumed to be (C, H, W); the permutes below imply this.

    NOTE(review): LeRobot video features are commonly decoded to float in
    [0, 1] already; if so, the ``/ 255.0`` rescales a second time. Confirm
    this matches the preprocessing the checkpoint was trained with.
    Bicubic resizing can overshoot the input range, and no clamp is applied.

    Returns:
        Tensor of shape (T, resolution, resolution, C).
    """
    first = meta.episodes["dataset_from_index"][ep_idx]
    last = meta.episodes["dataset_to_index"][ep_idx]
    clip = torch.stack([dataset[idx][camera_key] for idx in range(first, last)])
    clip = (clip.float() / 255.0).permute(0, 2, 3, 1)  # -> (T, H, W, C)

    _, height, width, _ = clip.shape
    if height != width:
        side = min(height, width)
        top = (height - side) // 2
        left = (width - side) // 2
        clip = clip[:, top:top + side, left:left + side]

    if clip.shape[1] != resolution:
        # interpolate() treats dim 0 as batch and dim 1 as channels, so move the
        # colour axis to the front; time then plays the role of "channels".
        channels_first = clip.permute(3, 0, 1, 2)
        resized = torch.nn.functional.interpolate(channels_first, resolution, mode="bicubic")
        clip = resized.permute(1, 2, 3, 0)
    return clip


@torch.no_grad()
def infer_latents(model: "MuCoLA", video: torch.Tensor, device: torch.device, chunk: int):
    """Run ``model.mucola.latent_actions`` over ``video`` in overlapping chunks.

    Each chunk after the first starts at the previous chunk's last frame, so
    consecutive chunks share one frame and the effective stride is
    ``chunk - 1``. Presumably this keeps the frame-pair transition at each
    chunk boundary covered; TODO confirm against MuCoLA.

    Args:
        model: object exposing ``.mucola.latent_actions(v)`` (returning a dict
            with "z_rep", "z_mu", "z_var") and ``.mucola.latent_dim``.
        video: frames indexed along dim 0. A batch dim is added before each
            model call. Presumably (T, H, W, C) as produced by
            ``load_episode_frames``; verify.
        device: device each chunk is moved to before the forward pass.
        chunk: maximum number of frames per forward pass; must be >= 2.

    Returns:
        dict with numpy arrays "z_rep", "z_mu", "z_var". Per-chunk outputs are
        concatenated along dim 0; "z_mu"/"z_var" are reshaped to
        (-1, latent_dim).

    Raises:
        ValueError: if ``chunk < 2``, since the one-frame overlap would stop the
            window from ever advancing (infinite loop). Also raised if
            ``video`` has no frames.
    """
    if chunk < 2:
        raise ValueError(
            f"chunk must be >= 2 (consecutive chunks overlap by one frame), got {chunk}"
        )
    num_frames = video.shape[0]
    if num_frames == 0:
        raise ValueError("video has no frames")

    latent_dim = model.mucola.latent_dim
    reps, mus, variances = [], [], []
    start = 0
    while start < num_frames:
        end = min(num_frames, start + chunk)
        window = video[start:end].unsqueeze(0).to(device)
        outs = model.mucola.latent_actions(window)
        reps.append(outs["z_rep"].squeeze(0).squeeze(2).cpu())
        mus.append(outs["z_mu"].reshape(-1, latent_dim).cpu())
        variances.append(outs["z_var"].reshape(-1, latent_dim).cpu())
        if end == num_frames:
            break
        # Re-use the last frame of this chunk as the first of the next one.
        start = end - 1
    return {
        "z_rep": torch.cat(reps, dim=0).numpy(),
        "z_mu": torch.cat(mus, dim=0).numpy(),
        "z_var": torch.cat(variances, dim=0).numpy(),
    }


def _warn(message):
    """Print a warning without corrupting an active tqdm progress bar."""
    tqdm.write(f"[warn] {message}")


def _camera_key_list(spec):
    """Expand a comma-separated camera spec into LeRobot feature keys, or None if unset/empty."""
    if not spec:
        return None
    return ["observation.images." + name.strip() for name in spec.split(",")]


def _item_or_none(items, index):
    """Return ``items[index]`` if ``items`` is a list long enough, else None."""
    if items is None or index >= len(items):
        return None
    return items[index]


def _load_model(checkpoint, device):
    """Load a MuCoLA model from ``checkpoint``, put it in eval mode and move it to ``device``.

    Tries ``MuCoLA.load_from_checkpoint`` first. If that raises, it rebuilds
    the model from the checkpoint's "hyper_parameters" and loads its
    "state_dict" non-strictly, warning about any missing/unexpected keys.
    """
    try:
        # Allow-list functools.partial for torch.load's weights_only mode
        # (presumably the checkpoint's hyper-parameters contain partials).
        # Older torch releases lack add_safe_globals, hence best-effort.
        torch.serialization.add_safe_globals([functools.partial])
    except AttributeError:
        pass
    try:
        model = MuCoLA.load_from_checkpoint(checkpoint)
    except Exception as err:
        _warn(f"MuCoLA.load_from_checkpoint failed ({err!r}); falling back to manual state_dict load")
        # SECURITY: weights_only=False unpickles arbitrary objects -- only load
        # checkpoints from a trusted source.
        ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
        model = MuCoLA(**ckpt.get("hyper_parameters", {}))
        result = model.load_state_dict(ckpt.get("state_dict", {}), strict=False)
        missing = getattr(result, "missing_keys", [])
        unexpected = getattr(result, "unexpected_keys", [])
        if missing or unexpected:
            _warn(f"non-strict checkpoint load: {len(missing)} missing, {len(unexpected)} unexpected keys")
    model.eval().to(device)
    return model


def _resolve_cameras(meta, camera_key, paired_camera_key, camera_index, paired_camera_index):
    """Choose the primary and optional paired camera keys for one dataset.

    Primary camera: the explicit key if given, else
    ``meta.camera_keys[camera_index]``, else ``meta.camera_keys[0]``.

    Paired camera: the explicit key if given, else
    ``meta.camera_keys[paired_camera_index]`` when that index is in range,
    else None.
    """
    if camera_key:
        cam = camera_key
    elif camera_index is not None:
        cam = meta.camera_keys[camera_index]
    else:
        cam = meta.camera_keys[0]

    paired = None
    if paired_camera_key:
        paired = paired_camera_key
    elif paired_camera_index is not None:
        if paired_camera_index < len(meta.camera_keys):
            paired = meta.camera_keys[paired_camera_index]
        else:
            _warn(
                f"--paired_camera_index {paired_camera_index} out of range for "
                f"{len(meta.camera_keys)} cameras; using a single view"
            )
    return cam, paired


def main():
    """Extract latent actions for every episode of one or more LeRobot datasets.

    For each repo id in ``--repo_id``, the dataset under ``<root>/<repo_id>``
    is opened and a camera is chosen. The choice is the per-repo
    ``--camera_key`` entry, else ``--camera_index``, else the first camera.
    ``infer_latents`` then runs on each episode and the results are saved to
    ``<output_dir>/<repo_id with '/'->'_'>_epNNNNN.npz``.

    When a paired camera is configured and its view succeeds, the file holds
    ``*_view1``/``*_view2`` arrays instead of plain "z_rep"/"z_mu"/"z_var".
    Episodes whose frames cannot be loaded are skipped with a warning.
    """
    args = parse_args()
    if args.chunk_size < 2:
        # infer_latents overlaps chunks by one frame; smaller sizes never advance.
        raise SystemExit(f"--chunk_size must be >= 2, got {args.chunk_size}")
    device = torch.device(args.device) if args.device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.output_dir, exist_ok=True)

    model = _load_model(args.checkpoint, device)
    latent_dim = model.mucola.latent_dim

    from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

    repo_ids = [r.strip() for r in args.repo_id.split(",")]
    # --camera_key is optional: guard against None just like --paired_camera_key.
    camera_keys = _camera_key_list(args.camera_key)
    paired_camera_keys = _camera_key_list(args.paired_camera_key)
    ds_kwargs = {"video_backend": args.video_backend} if args.video_backend else {}

    for i, repo_id in enumerate(repo_ids):
        repo_root = os.path.join(args.root, repo_id)
        dataset = LeRobotDataset(repo_id, root=repo_root, **ds_kwargs)
        meta = LeRobotDatasetMetadata(repo_id, root=repo_root)
        cam, paired = _resolve_cameras(
            meta,
            _item_or_none(camera_keys, i),
            _item_or_none(paired_camera_keys, i),
            args.camera_index,
            args.paired_camera_index,
        )
        file_prefix = repo_id.replace("/", "_")

        for ep in tqdm(range(meta.total_episodes), desc=f"{repo_id} episodes"):
            try:
                vid = load_episode_frames(dataset, meta, ep, cam, args.resolution)
            except Exception as err:
                # Best-effort: one unreadable episode must not abort the whole run.
                _warn(f"{repo_id} ep {ep}: cannot load '{cam}' frames ({err!r}); skipped")
                continue
            out = infer_latents(model, vid, device, args.chunk_size)
            data = {
                "z_rep": out["z_rep"],
                "z_mu": out["z_mu"],
                "z_var": out["z_var"],
                "latent_dim": latent_dim,
                "episode_idx": ep,
            }
            if paired is not None:
                try:
                    vid2 = load_episode_frames(dataset, meta, ep, paired, args.resolution)
                    out2 = infer_latents(model, vid2, device, args.chunk_size)
                except Exception as err:
                    _warn(f"{repo_id} ep {ep}: paired view '{paired}' failed ({err!r}); saving single view only")
                else:
                    data.update({
                        "z_rep_view1": data.pop("z_rep"),
                        "z_mu_view1": data.pop("z_mu"),
                        "z_var_view1": data.pop("z_var"),
                        "z_rep_view2": out2["z_rep"],
                        "z_mu_view2": out2["z_mu"],
                        "z_var_view2": out2["z_var"],
                    })
            np.savez(os.path.join(args.output_dir, f"{file_prefix}_ep{ep:05d}.npz"), **data)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
