# pretrain.py 
import os
import argparse
import random
from typing import Dict, Any, List, Tuple, Mapping, Optional
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import csv
from omegaconf import OmegaConf

from utils import prepare_obs, BabyAI_BC

def _is_dict_of_timeseries(x) -> bool:
    return isinstance(x, Mapping) and all(hasattr(v, "__getitem__") for v in x.values())

def _time_len_from_observations(obs_seq) -> int:
    """Return the number of time steps in an episode's observations.

    For a dict of per-key time series (see ``_is_dict_of_timeseries``) this is
    the length of the shortest series, so every key can be indexed at each step
    ``t < length`` by ``_get_obs_t``; an empty dict has length 0. Any other
    container is measured with ``len(obs_seq)``.
    """
    if _is_dict_of_timeseries(obs_seq):
        # default=0: an empty mapping used to raise StopIteration via next(iter(...)).
        return min((len(series) for series in obs_seq.values()), default=0)
    return len(obs_seq)

def _get_obs_t(obs_seq, t: int) -> Dict[str, Any]:
    """Return the observation at time step ``t``.

    For a dict of per-key series, a new dict is built that maps each key to its
    element at index ``t``. Any other container is simply indexed with ``t``.
    """
    if not _is_dict_of_timeseries(obs_seq):
        return obs_seq[t]
    return {key: series[t] for key, series in obs_seq.items()}

class BCDataset(Dataset):
    """Behaviour-cloning dataset of flat ``(observation, action)`` transitions.

    Every episode yielded by ``minari_dataset.iterate_episodes()`` is unrolled
    into per-step pairs. Optionally, only episodes whose enumeration position
    is in ``allowed_indices`` are kept. Observation ``t`` is paired with action
    ``t``, for at most ``min(len(actions), n_observations - 1)`` steps. Steps
    whose observation lacks an ``"image"`` entry are dropped.

    Args:
        minari_dataset: object with ``iterate_episodes()``; each episode must
            expose ``observations`` and ``actions``.
        use_text: stored on the instance; not consulted while building samples.
        allowed_indices: episode positions to keep; ``None`` keeps every episode.

    Raises:
        RuntimeError: if no transitions remain after filtering.
    """

    def __init__(self, minari_dataset, use_text: bool, allowed_indices: Optional[set] = None):
        self.use_text = use_text
        self.samples: List[Tuple[Dict[str, Any], int]] = []

        for episode_no, episode in enumerate(minari_dataset.iterate_episodes()):
            if allowed_indices is not None and episode_no not in allowed_indices:
                continue

            observations, actions = episode.observations, episode.actions
            obs_len = _time_len_from_observations(observations)
            # The last observation is never paired with an action.
            n_steps = min(len(actions), obs_len - 1)

            # range() of a non-positive count is empty, so short episodes add nothing.
            for step in range(n_steps):
                obs_t = _get_obs_t(observations, step)
                action = int(actions[step])
                if "image" in obs_t:
                    self.samples.append((obs_t, action))

        if not self.samples:
            raise RuntimeError("No (obs, action) pairs found in Minari dataset. "
                               "Check episode structure, required keys, and your train split CSV.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        obs_t, action = self.samples[idx]
        return obs_t, action

def collate_fn(batch):
    """Collate ``(observation, action)`` pairs produced by ``BCDataset``.

    Observations are passed through untouched as a list; they are turned into
    tensors later by ``batch_to_model_inputs``. Actions become a single
    ``torch.long`` tensor.
    """
    observations, action_ids = zip(*batch)
    action_tensor = torch.as_tensor(action_ids, dtype=torch.long)
    return list(observations), action_tensor

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (plus all CUDA devices when present).

    Python's ``random`` receives ``seed`` unchanged. NumPy and PyTorch receive
    ``seed % (2**32 - 1)``, which is always non-negative and fits in 32 bits.
    """
    bounded = seed % (2**32 - 1)
    random.seed(seed)
    np.random.seed(bounded)
    torch.manual_seed(bounded)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(bounded)

def split_indices(n: int, val_fraction: float = 0):
    """Randomly partition ``range(n)`` into ``(train_idx, val_idx)`` arrays.

    The shuffle uses NumPy's global RNG. When ``n > 1`` the validation part
    holds ``max(1, int(n * val_fraction))`` indices, so at least one index is
    held out even for ``val_fraction=0``. When ``n <= 1`` the validation part is
    an empty integer array.
    """
    order = np.arange(n)
    np.random.shuffle(order)
    n_holdout = 0 if not n > 1 else max(1, int(n * val_fraction))
    if n_holdout <= 0:
        return order, np.array([], dtype=int)
    return order[n_holdout:], order[:n_holdout]

def _cat_token_rows(toks: List[torch.Tensor], pad_value: int = 0) -> torch.Tensor:
    """Concatenate per-sample token tensors along dim 0, right-padding the last dim to a common length.

    Rows already at the maximum length are used unchanged, so equal-length inputs
    give exactly ``torch.cat(toks, dim=0)``.
    """
    max_len = max((tok.shape[-1] for tok in toks), default=0)
    padded = [
        tok if tok.shape[-1] == max_len
        else nn.functional.pad(tok, (0, max_len - tok.shape[-1]), value=pad_value)
        for tok in toks
    ]
    return torch.cat(padded, dim=0)


def batch_to_model_inputs(
    obs_list: List[Dict[str, Any]],
    device: torch.device,
    use_text: bool,
) -> Dict[str, torch.Tensor]:
    """Turn a list of raw observations into one batched model-input dict.

    Each observation is converted individually by ``prepare_obs`` and the
    per-sample tensors are concatenated along dim 0.

    Args:
        obs_list: raw observations as stored by ``BCDataset``.
        device: device passed to ``prepare_obs`` and used for fallback tensors.
        use_text: when True, a ``"mission_tokens"`` entry is also produced.

    Returns:
        A dict with ``"image"`` and ``"direction"``, plus ``"mission_tokens"``
        when ``use_text`` is True.
    """
    imgs, dirs, toks = [], [], []
    for obs in obs_list:
        prepped = prepare_obs(obs, device=device, use_text=use_text)
        imgs.append(prepped["image"])

        dir_tensor = prepped.get("direction")
        if dir_tensor is None:
            dir_tensor = torch.zeros((1,), dtype=torch.long, device=device)
        dirs.append(dir_tensor)

        if use_text:
            tok_tensor = prepped.get("mission_tokens")
            if tok_tensor is None:
                # Empty (1, 0) row; _cat_token_rows pads it to the batch's token length.
                tok_tensor = torch.zeros((1, 0), dtype=torch.long, device=device)
            toks.append(tok_tensor)

    batch = {"image": torch.cat(imgs, dim=0), "direction": torch.cat(dirs, dim=0)}
    if use_text:
        # Plain torch.cat fails when rows differ in token length (including the
        # (1, 0) fallback), so shorter rows are right-padded with 0.
        # NOTE(review): assumes 0 is the padding id expected by the model — confirm.
        batch["mission_tokens"] = _cat_token_rows(toks)
    return batch

def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    use_text: bool,
    label_smoothing: float,
    entropy_coef: float,
    grad_clip: float,
) -> Dict[str, float]:
    """Run one behaviour-cloning training pass over ``loader``.

    Per batch, the optimized loss is the label-smoothed cross-entropy minus
    ``entropy_coef`` times the mean entropy of the predicted action distribution.
    Gradients are norm-clipped to ``grad_clip`` when it is positive.

    Returns:
        Sample-weighted means of ``loss``, ``ce_loss``, ``entropy`` and ``acc``,
        plus the sample count ``n``. An empty ``loader`` yields NaN means and
        ``n == 0`` instead of dividing by zero.
    """
    model.train()
    ce = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    total_loss = total_ce = total_ent = 0.0
    total_correct = total = 0

    for obs_list, actions in loader:
        actions = actions.to(device, non_blocking=True)
        batch = batch_to_model_inputs(obs_list, device, use_text)
        logits = model(batch)

        ce_loss = ce(logits, actions)
        logp = torch.log_softmax(logits, dim=-1)
        ent = -(logp.exp() * logp).sum(dim=-1).mean()
        loss = ce_loss - entropy_coef * ent

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        if grad_clip and grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        with torch.no_grad():
            B = actions.size(0)
            total += B
            total_correct += (logits.argmax(dim=-1) == actions).sum().item()
            total_loss += float(loss) * B
            total_ce += float(ce_loss) * B
            total_ent += float(ent) * B

    if total == 0:
        nan = float("nan")
        return {"loss": nan, "ce_loss": nan, "entropy": nan, "acc": nan, "n": 0}
    return {
        "loss": total_loss / total,
        "ce_loss": total_ce / total,
        "entropy": total_ent / total,
        "acc": total_correct / total,
        "n": total,
    }

@torch.no_grad()
def eval_epoch(model: nn.Module, loader: DataLoader, device: torch.device, use_text: bool) -> Dict[str, float]:
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Uses plain cross-entropy with no label smoothing, whereas ``train_epoch``
    applies its ``label_smoothing`` argument, so the two losses may differ in scale.

    Returns:
        Sample-weighted mean ``loss`` and ``acc`` plus the sample count ``n``.
        An empty ``loader`` yields NaN means and ``n == 0`` instead of dividing
        by zero.
    """
    model.eval()
    ce = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = total = 0

    for obs_list, actions in loader:
        actions = actions.to(device, non_blocking=True)
        batch = batch_to_model_inputs(obs_list, device, use_text)
        logits = model(batch)
        loss = ce(logits, actions)

        B = actions.size(0)
        total += B
        total_correct += (logits.argmax(dim=-1) == actions).sum().item()
        total_loss += float(loss) * B

    if total == 0:
        nan = float("nan")
        return {"loss": nan, "acc": nan, "n": 0}
    return {"loss": total_loss / total, "acc": total_correct / total, "n": total}

def _load_train_indices(split_csv: str) -> set:
    """Return the episode indices whose ``split`` column is ``train`` in ``split_csv``.

    The CSV is read with ``csv.DictReader``. It needs ``episode_idx`` and
    ``split`` columns; the split is compared case- and whitespace-insensitively.
    Rows whose ``episode_idx`` is not a non-negative integer are skipped.

    Raises:
        FileNotFoundError: if ``split_csv`` does not exist.
        ValueError: if no usable ``train`` rows are found.
    """
    if not os.path.exists(split_csv):
        raise FileNotFoundError(f"Split CSV not found at: {split_csv}")

    allowed: set = set()
    # utf-8-sig drops a leading BOM (e.g. spreadsheet exports) so the first header still reads "episode_idx".
    with open(split_csv, "r", newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            if str(row.get("split", "")).strip().lower() != "train":
                continue
            try:
                idx = int(row.get("episode_idx", -1))
            except (ValueError, TypeError):
                continue
            if idx >= 0:
                allowed.add(idx)

    if not allowed:
        raise ValueError(f"--split_csv provided but no 'train' episode indices were found in {split_csv}")
    return allowed


def _infer_n_actions(ds, default: int = 7) -> int:
    """Return ``action_space.n`` of the dataset's recovered environment, or ``default`` if that fails."""
    env = None
    try:
        env = ds.recover_environment()
        n_actions = env.action_space.n
    except Exception:
        print(f"Warning: recover_environment() failed; falling back to default n_actions={default}.")
        n_actions = default
    finally:
        if env is not None:
            # Only the action-space size is needed; release the env. Cleanup is best-effort.
            try:
                env.close()
            except Exception:
                pass
    return n_actions


def _infer_vocab_size(ds, min_vocab: int = 200) -> int:
    """Return ``max(min_vocab, largest "mission_tokens" id + 1)``, scanning every episode in ``ds``.

    All episodes are scanned, not only the ``train`` subset. If no tokens are
    found, the largest id defaults to 1.
    """
    max_tok_id = 1
    for ep in ds.iterate_episodes():
        obs_seq = ep.observations
        if _is_dict_of_timeseries(obs_seq) and "mission_tokens" not in obs_seq:
            # _get_obs_t gives every step exactly these keys, so no step can carry tokens.
            continue
        for t in range(_time_len_from_observations(obs_seq)):
            obs_t = _get_obs_t(obs_seq, t)
            if "mission_tokens" not in obs_t:
                continue
            arr = np.asarray(obs_t["mission_tokens"]).reshape(-1)
            if arr.size:
                max_tok_id = max(max_tok_id, int(arr.max()))
    return max(min_vocab, max_tok_id + 1)


def main():
    """Command-line entry point: behaviour-cloning pretraining of ``BabyAI_BC`` on a Minari dataset.

    Steps:
        1. Load ``<config_dir>/<game_code>.yaml``.
        2. Optionally restrict training to the ``train`` episodes listed in ``--split_csv``.
        3. Train for ``--epochs`` epochs.
        4. Save the final ``state_dict`` to ``<save_dir>/pretrain_<game_code>.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--game_code", type=str, required=True, help="Game code to load config (e.g., 'open', 'goto')")
    parser.add_argument("--config_dir", type=str, default="configs", help="Directory containing YAML config files")
    parser.add_argument("--dataset_id", type=str, default="", help="Override Minari dataset id from config")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-5)
    parser.add_argument("--entropy_coef", type=float, default=0)
    parser.add_argument("--label_smoothing", type=float, default=0.05)
    parser.add_argument("--grad_clip", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--save_dir", type=str, default="pretrained_models")
    parser.add_argument("--split_csv", type=str, default="", help="Path to CSV with columns episode_idx,split. If set, use only 'train' episodes.")
    args = parser.parse_args()

    config_path = os.path.join(args.config_dir, f"{args.game_code}.yaml")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    cfg = OmegaConf.load(config_path)

    dataset_id = args.dataset_id or cfg.data.dataset_id
    use_text = cfg.model.use_text
    print(f"use_text: {use_text}")
    if not dataset_id:
        raise ValueError(f"No Minari dataset_id found in config '{config_path}' or via --dataset_id argument.")

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Deferred import: argument and config errors above surface without importing minari.
    import minari
    ds = minari.load_dataset(dataset_id, download=True)

    allowed_indices: Optional[set] = None
    if args.split_csv:
        allowed_indices = _load_train_indices(args.split_csv)
        print(f"Loaded {len(allowed_indices)} 'train' episode indices from {args.split_csv}")

    try:
        print(f"Dataset '{dataset_id}': episodes={ds.total_episodes}, steps={ds.total_steps}")
    except Exception:
        pass  # best-effort summary line; failing to read it must not stop training

    n_actions = _infer_n_actions(ds)
    vocab_size = _infer_vocab_size(ds)
    model = BabyAI_BC(n_actions=n_actions, use_text=use_text, vocab_size=vocab_size).to(device)

    full_ds = BCDataset(ds, use_text=use_text, allowed_indices=allowed_indices)

    # NOTE(review): split_indices holds out one sample whenever len(full_ds) > 1,
    # even with val_fraction=0, and none when len(full_ds) == 1.
    idx_train, idx_val = split_indices(len(full_ds), val_fraction=0)
    train_ds = torch.utils.data.Subset(full_ds, idx_train)
    val_ds = torch.utils.data.Subset(full_ds, idx_val)

    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=pin, collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=pin, collate_fn=collate_fn)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Create the output directory before training so an unusable --save_dir fails immediately,
    # not after all epochs have run.
    os.makedirs(args.save_dir, exist_ok=True)
    save_path = os.path.join(args.save_dir, f"pretrain_{args.game_code}.pt")

    print(f"Starting BC pretraining on {dataset_id} | batches: {len(train_loader)} | device: {device}")
    for epoch in range(1, args.epochs + 1):
        tr = train_epoch(model, train_loader, optimizer, device, use_text=use_text, label_smoothing=args.label_smoothing, entropy_coef=args.entropy_coef, grad_clip=args.grad_clip)
        if len(val_ds) > 0:
            va = eval_epoch(model, val_loader, device, use_text=use_text)
            val_msg = f"val: loss={va['loss']:.4f} acc={va['acc']:.3f}"
        else:
            # No held-out samples: skip evaluation rather than average over nothing.
            val_msg = "val: n/a"
        print(f"[Epoch {epoch:02d}] train: loss={tr['loss']:.4f} ce={tr['ce_loss']:.4f} ent={tr['entropy']:.3f} acc={tr['acc']:.3f} | {val_msg}")

    torch.save(model.state_dict(), save_path)
    print(f"Saved pretrained weights to: {save_path}")


# Run the pretraining CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()