import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

# Seed every random number generator this script touches so that weight
# initialisation and DataLoader shuffling are repeatable between runs.
# These calls execute at import time, before any model is built.
seed = 42
random.seed(seed)        # Python's built-in RNG
np.random.seed(seed)     # NumPy's global RNG
torch.manual_seed(seed)  # PyTorch's default generator
torch.cuda.manual_seed(seed)  # PyTorch's CUDA generator


class FitMLP(nn.Module):
    """Small regression network: ``Linear -> SiLU -> Linear``.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 128):
        super().__init__()
        # The layers are instantiated input-to-output and stored under ``net``
        # so parameter names remain ``net.0.*`` and ``net.2.*``.
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the prediction."""
        prediction = self.net(x)
        return prediction


def rmse_mae(y_true: torch.Tensor, y_pred: torch.Tensor):
    """Compute root-mean-square and mean-absolute error over all elements.

    Args:
        y_true: Reference values.
        y_pred: Predicted values; subtracted from ``y_true`` element-wise.

    Returns:
        ``(rmse, mae)``. ``rmse`` is the NumPy square root of the mean squared
        error, ``mae`` is a plain Python float.
    """
    residual = y_true - y_pred
    mean_sq_err = residual.pow(2).mean().item()
    mean_abs_err = residual.abs().mean().item()
    return np.sqrt(mean_sq_err), mean_abs_err


def load_episode_actions(dataset, meta, ep_idx):
    """Return the actions of episode ``ep_idx`` as a float32 tensor.

    The frame range comes from ``meta.episodes["dataset_from_index"]`` and
    ``meta.episodes["dataset_to_index"]``. A single slice of
    ``dataset.hf_dataset`` is attempted first; if anything in that path
    raises, frames are fetched one by one through ``dataset[i]`` instead.

    Args:
        dataset: Object exposing ``hf_dataset`` and integer indexing.
        meta: Object whose ``episodes`` mapping holds the index columns.
        ep_idx: Episode index into those columns.

    Returns:
        Tensor with one row per frame of the episode.
    """
    start = meta.episodes["dataset_from_index"][ep_idx]
    stop = meta.episodes["dataset_to_index"][ep_idx]
    try:
        # Fast path: slice the backing table once and convert through NumPy.
        bulk = np.asarray(dataset.hf_dataset[start:stop]["action"], dtype=np.float32)
        return torch.from_numpy(bulk)
    except Exception:
        # Slow path: item-by-item access, accepting tensors or array-likes.
        per_frame = []
        for frame_idx in tqdm(range(start, stop), leave=False):
            action = dataset[frame_idx]["action"]
            per_frame.append(action if isinstance(action, torch.Tensor) else torch.tensor(action))
        return torch.stack(per_frame).float()


def collect_episode_pairs(repo_id: str, root: str, latents_dir: str, latent_keys):
    """Pair each episode's saved latents with its recorded actions.

    For every episode of the LeRobot dataset ``repo_id`` (opened from
    ``<root>/<repo_id>``) this looks for the file
    ``<latents_dir>/<repo_id with "/" replaced by "_">_ep<NNNNN>.npz``, reads
    the array stored under the first entry of ``latent_keys`` that exists in
    the file, keeps index 0 along its second axis, and truncates latents and
    actions to their common length.

    Episodes that cannot be used are skipped. A missing latent file is skipped
    silently; every other reason is reported on stdout.

    Args:
        repo_id: Dataset identifier; also used to build the latent file names.
        root: Directory that contains the dataset folder.
        latents_dir: Directory holding the per-episode ``.npz`` files.
        latent_keys: Candidate ``.npz`` keys in order of preference.

    Returns:
        ``(meta, pairs)`` where ``meta`` is the dataset metadata object and
        ``pairs`` is a list of ``(latents, actions)`` tensor tuples, one per
        usable episode, in episode order.
    """
    from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

    repo_root = os.path.join(root, repo_id)
    dataset = LeRobotDataset(repo_id, root=repo_root)
    meta = LeRobotDatasetMetadata(repo_id, root=repo_root)
    file_prefix = repo_id.replace("/", "_")

    pairs = []
    for ep in tqdm(range(meta.total_episodes), desc=f"Collecting pairs for {repo_id}"):
        fpath = os.path.join(latents_dir, f"{file_prefix}_ep{ep:05d}.npz")
        if not os.path.isfile(fpath):
            continue

        z = None
        try:
            # The archive object holds an open file handle; the context
            # manager releases it per episode instead of leaking one handle
            # for every file visited.
            with np.load(fpath) as archive:
                for key in latent_keys:
                    if key in archive.files:
                        z = archive[key]
                        break
        except Exception as exc:
            # Best effort: an unreadable file must not abort the whole scan.
            print(f"{repo_id}: ep{ep:05d} failed to read {fpath}: {exc}")
            continue
        if z is None:
            print(f"{repo_id}: ep{ep:05d} no latent key found")
            continue
        if z.ndim < 2:
            # The ``[:, 0]`` selection below needs at least two dimensions.
            print(f"{repo_id}: ep{ep:05d} latent array has ndim={z.ndim}, expected at least 2")
            continue
        # NOTE(review): presumably axis 0 is time and axis 1 enumerates several
        # latents per step, of which only the first is used - confirm.
        z = torch.from_numpy(z).float()[:, 0]

        try:
            a = load_episode_actions(dataset, meta, ep)
        except Exception as exc:
            print(f"{repo_id}: ep{ep:05d} failed to load actions: {exc}")
            continue

        # Latents and actions may differ in length; keep the overlapping prefix.
        T = min(z.shape[0], a.shape[0])
        if T <= 0:
            print(f"{repo_id}: ep{ep:05d} has no overlapping frames")
            continue
        pairs.append((z[:T], a[:T]))
    return meta, pairs


def make_splits(pairs, train_ratio=0.7):
    """Split episode pairs into train/test sets and concatenate each set.

    The first ``int(len(pairs) * train_ratio)`` pairs form the training set
    and the rest form the test set; no shuffling is done.

    Args:
        pairs: Sequence of ``(X, Y)`` tensor tuples, one per episode.
        train_ratio: Fraction of episodes assigned to the training set.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``, each the concatenation along
        dimension 0 of the matching tensors. A set with no episodes yields
        ``torch.empty(0)`` for both of its tensors.
    """
    split_at = int(len(pairs) * train_ratio)

    def _join(column, subset):
        # An empty subset has nothing to concatenate; return a placeholder.
        if not subset:
            return torch.empty(0)
        return torch.cat([pair[column] for pair in subset], dim=0)

    head, tail = pairs[:split_at], pairs[split_at:]
    return _join(0, head), _join(1, head), _join(0, tail), _join(1, tail)


def train_one_repo(
    repo_id: str,
    root: str,
    latents_dir: str,
    latent_keys,
    epochs: int,
    batch_size: int,
    lr: float,
    hidden_dim: int,
    device: torch.device,
    save_dir: str = None,
):
    """Fit a ``FitMLP`` from latents to actions for one dataset and report errors.

    Episode pairs are collected, split 70/30 by episode order, and an MLP is
    trained with Adam and MSE loss on the training part. Train and test
    RMSE/MAE are printed. When ``save_dir`` is given, the model weights are
    written to ``<save_dir>/<repo_id with "/" -> "_">_action_adapter.pth`` and
    a summary is appended to ``<save_dir>/action_adapter.txt``.

    Args:
        repo_id: Dataset identifier.
        root: Directory that contains the dataset folder.
        latents_dir: Directory holding the per-episode latent files.
        latent_keys: Candidate keys to read from each latent file.
        epochs: Number of passes over the training set.
        batch_size: Mini-batch size for the training loader.
        lr: Adam learning rate.
        hidden_dim: Hidden width of the MLP.
        device: Device used for training and evaluation.
        save_dir: Output directory, or ``None`` to skip writing any files.

    Returns:
        None. Returns early (after printing a message) when no episode pairs
        are found or when the training split is empty.
    """
    train_ratio = 0.7

    meta, pairs = collect_episode_pairs(repo_id, root, latents_dir, latent_keys)
    if not pairs:
        print(f"{repo_id}: no episode pairs found")
        return
    X_train, Y_train, X_test, Y_test = make_splits(pairs, train_ratio)
    shape_line = f"Shape: X_train={X_train.shape}, Y_train={Y_train.shape}, X_test={X_test.shape}, Y_test={Y_test.shape}"
    print(shape_line)
    if X_train.numel() == 0:
        # With very few episodes the training share rounds down to zero;
        # there is nothing to fit and ``shape[1]`` below would fail.
        print(f"{repo_id}: training split is empty, skipping")
        return

    input_dim = X_train.shape[1]
    output_dim = Y_train.shape[1]
    names = meta.features["action"].get("names", None)
    n_train_eps = int(len(pairs) * train_ratio)
    print(f"Task: {repo_id}, episodes: {len(pairs)}, train_eps: {n_train_eps}, test_eps: {len(pairs)-n_train_eps}")
    print(f"Action dim {output_dim} : {names}")

    model = FitMLP(input_dim=input_dim, output_dim=output_dim, hidden_dim=hidden_dim).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Move/cast the training data once; it is reused by the loader and by the
    # per-epoch full-set evaluation.
    X_train_dev = X_train.to(device).float()
    Y_train_dev = Y_train.to(device).float()
    Y_train_cpu = Y_train.cpu().float()
    loader = DataLoader(TensorDataset(X_train_dev, Y_train_dev), batch_size=batch_size, shuffle=True)

    pbar = tqdm(range(epochs), desc=repo_id)
    for _ in pbar:
        running_loss = 0.0
        n_batches = 0
        for xb, yb in loader:
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        # Full training-set metrics for the progress bar.
        with torch.no_grad():
            rmse_tr, mae_tr = rmse_mae(Y_train_cpu, model(X_train_dev).cpu())
        pbar.set_postfix(loss=running_loss / max(1, n_batches), rmse=rmse_tr, mae=mae_tr)

    with torch.no_grad():
        rmse_tr, mae_tr = rmse_mae(Y_train_cpu, model(X_train_dev).cpu())
        if X_test.numel() > 0:
            yhat_test = model(X_test.to(device).float()).cpu()
            rmse_te, mae_te = rmse_mae(Y_test.cpu().float(), yhat_test)
        else:
            # No held-out episodes: report NaN instead of failing.
            rmse_te, mae_te = float("nan"), float("nan")

    print(f"{repo_id} Train RMSE: {rmse_tr:.6f}, MAE: {mae_tr:.6f}")
    print(f"{repo_id} Test  RMSE: {rmse_te:.6f}, MAE: {mae_te:.6f}")

    # ``save_dir`` is optional; without it nothing is written to disk.
    if not save_dir:
        return
    os.makedirs(save_dir, exist_ok=True)
    weights_path = os.path.join(save_dir, f"{repo_id.replace('/', '_')}_action_adapter.pth")
    torch.save(model.state_dict(), weights_path)
    with open(os.path.join(save_dir, "action_adapter.txt"), "a") as f:
        f.write(f"{repo_id}\n")
        f.write(f"{shape_line}\n")
        f.write(f"Action dim {output_dim} : {names}\n")
        f.write(f"Train RMSE: {rmse_tr:.6f}, MAE: {mae_tr:.6f}\n")
        f.write(f"Test  RMSE: {rmse_te:.6f}, MAE: {mae_te:.6f}\n\n")


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``. ``repo_id``, ``latent_key`` and
        ``epochs`` are returned as raw strings (possibly comma-separated).
    """
    parser = argparse.ArgumentParser(description="Fit latent actions to real actions per dataset")
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--repo_id", dict(type=str, required=True, help="single or comma-separated repo ids")),
        ("--root", dict(type=str, required=True)),
        ("--latents_dir", dict(type=str, required=True)),
        ("--latent_key", dict(type=str, required=True, help="npz key or comma-separated preferences")),
        ("--epochs", dict(type=str, required=True, help="int or comma-separated per repo")),
        ("--batch_size", dict(type=int, default=256)),
        ("--lr", dict(type=float, default=1e-3)),
        ("--hidden_dim", dict(type=int, default=128)),
        ("--device", dict(type=str, default=None)),
        ("--save_dir", dict(type=str, default=None)),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


def main():
    """Parse the command line and fit one adapter per requested repo.

    ``--repo_id``, ``--latent_key`` and ``--epochs`` accept comma-separated
    lists; empty entries (e.g. from a trailing comma) are ignored. A single
    epoch value is applied to every repo, and epoch values beyond the number
    of repos are ignored.

    Raises:
        ValueError: If no repo id is given, if fewer epoch values than repos
            are given (other than exactly one), or if an epoch value is not
            an integer.
    """
    args = parse_args()
    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)

    repo_ids = [r.strip() for r in args.repo_id.split(",") if r.strip()]
    latent_keys = [k.strip() for k in args.latent_key.split(",") if k.strip()]
    epoch_list = [e.strip() for e in args.epochs.split(",") if e.strip()]
    if not repo_ids:
        raise ValueError("--repo_id must contain at least one repo id")

    if len(epoch_list) == 1:
        epoch_list = epoch_list * len(repo_ids)
    if len(epoch_list) < len(repo_ids):
        # zip() below would otherwise silently skip the repos that have no
        # matching epoch count.
        raise ValueError(
            f"--epochs has {len(epoch_list)} value(s) but --repo_id has {len(repo_ids)}; "
            "pass one value or one per repo"
        )
    epochs_per_repo = [int(e) for e in epoch_list[:len(repo_ids)]]

    for repo_id, epc in zip(repo_ids, epochs_per_repo):
        train_one_repo(repo_id, args.root, args.latents_dir, latent_keys, epc, args.batch_size, args.lr, args.hidden_dim, device, args.save_dir)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
