# pretrain.py
import os
import argparse
import random
from typing import Dict, Any, List, Tuple, Mapping, Optional
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import csv
from omegaconf import OmegaConf

from utils import prepare_obs, HIDDEN_DIM

# Simple MLP policy network
class MLPPolicy(nn.Module):
    """
    Small fully connected policy network, used for behavioral cloning and
    as a PPO actor.

    Input: flattened image concatenated with a 4-way one-hot direction
    Output: unnormalized logits, one per action (n_actions columns)
    """
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        # The layers are created in this exact order so the Sequential indices
        # (and therefore the state_dict keys "net.0.*", "net.2.*", "net.4.*")
        # stay the same.
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feature batch through the network and return the logits."""
        logits = self.net(x)
        return logits

def _is_dict_of_timeseries(x) -> bool:
    """
    Tell whether `x` is a mapping whose every value supports indexing.

    Returns True for an empty mapping as well (there is no value that
    fails the check), and False for anything that is not a Mapping.
    """
    if not isinstance(x, Mapping):
        return False
    for series in x.values():
        if not hasattr(series, "__getitem__"):
            return False
    return True

def _time_len_from_observations(obs_seq) -> int:
    """
    Return the number of timesteps stored in an observation sequence.

    Two layouts are supported:
    - a mapping of key -> indexable series: the length of the first
      series is returned (other keys are not checked against it);
    - anything else: `len(obs_seq)`.

    An empty mapping has no series to measure and yields 0.
    """
    if _is_dict_of_timeseries(obs_seq):
        # An empty mapping still counts as a "dict of timeseries"; without this
        # guard next(iter(...)) would raise StopIteration.
        if len(obs_seq) == 0:
            return 0
        first_key = next(iter(obs_seq))
        return len(obs_seq[first_key])
    return len(obs_seq)

def _get_obs_t(obs_seq, t: int) -> Dict[str, Any]:
    """
    Extract the observation at timestep `t`.

    For a mapping of key -> series, a new dict holding each series'
    t-th entry is built; any other sequence is simply indexed with `t`.
    """
    if not _is_dict_of_timeseries(obs_seq):
        return obs_seq[t]

    step_obs = {}
    for key, series in obs_seq.items():
        step_obs[key] = series[t]
    return step_obs

class BCDataset(Dataset):
    """
    Behavioral-cloning view of a Minari dataset: each usable timestep of
    each selected episode becomes one (observation, action) sample.
    """
    def __init__(self, minari_dataset, use_text: bool, allowed_indices: Optional[set] = None):
        self.use_text = use_text
        self.samples: List[Tuple[Dict[str, Any], int]] = []

        restrict = allowed_indices is not None
        for ep_idx, episode in enumerate(minari_dataset.iterate_episodes()):
            # ep_idx is the position in iteration order; it is what
            # allowed_indices is matched against.
            if restrict and ep_idx not in allowed_indices:
                continue

            observations = episode.observations
            actions = episode.actions

            # Pair action t with observation t, using at most n_obs - 1 steps
            # so the final observation is never used as an input.
            n_obs = _time_len_from_observations(observations)
            horizon = min(len(actions), n_obs - 1)
            if horizon <= 0:
                continue

            for step in range(horizon):
                step_obs = _get_obs_t(observations, step)
                # Timesteps without an "image" entry are dropped.
                if "image" in step_obs:
                    self.samples.append((step_obs, int(actions[step])))

        if not self.samples:
            raise RuntimeError(
                "No (obs, action) pairs found in Minari dataset. "
                "Check episode structure, required keys, and your train split CSV."
            )

    def __len__(self):
        """Number of stored (observation, action) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the (observation, action) pair stored at position `idx`."""
        observation, action = self.samples[idx]
        return observation, action

def collate_fn(batch):
    """
    Collate (observation, action) samples into a batch.

    Observations are left as a plain Python list (they are turned into
    tensors later); actions are packed into one int64 tensor.
    """
    observations, action_ids = zip(*batch)
    action_tensor = torch.as_tensor(action_ids, dtype=torch.long)
    return [*observations], action_tensor

def set_seed(seed: int):
    """
    Seed Python's `random`, NumPy's global RNG, and PyTorch (CPU, plus all
    CUDA devices when CUDA is available).

    Args:
        seed: any integer. `random` receives it unchanged; NumPy and PyTorch
            receive it reduced modulo 2**32.
    """
    random.seed(seed)
    # np.random.seed accepts 0 .. 2**32 - 1 inclusive, so the modulus must be
    # 2**32; reducing modulo 2**32 - 1 would fold the valid seed 2**32 - 1
    # onto 0. Python's % keeps the result non-negative for negative seeds.
    seed32 = seed % (2**32)
    np.random.seed(seed32)
    torch.manual_seed(seed32)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed32)

def split_indices(n: int, val_fraction: float = 0.1):
    """
    Shuffle the indices 0..n-1 with NumPy's global RNG and split them.

    Returns a `(train_indices, val_indices)` pair of arrays. When n > 1 at
    least one index always goes to validation, even if
    `n * val_fraction` rounds down to zero; when n <= 1 validation is empty.
    """
    order = np.arange(n)
    np.random.shuffle(order)

    if n > 1:
        n_val = max(1, int(n * val_fraction))
    else:
        n_val = 0

    if n_val <= 0:
        return order, np.array([], dtype=int)
    return order[n_val:], order[:n_val]

def _obs_to_feature(prepped: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Convert one prepared observation into a single feature row.

    Layout of the row:
    - every element of prepped["image"], flattened (e.g. (C,H,W) -> C*H*W)
    - a one-hot encoding of the direction (4 columns)

    Args:
        prepped: mapping with an "image" tensor and, optionally, a
            "direction" tensor holding one index. A missing direction is
            treated as 0; values outside 0..3 are clamped into that range.

    Returns:
        A float tensor of shape (1, image.numel() + 4) for a single-element
        direction.
    """
    img: torch.Tensor = prepped["image"]
    # reshape rather than view: view raises on non-contiguous tensors
    # (for example an image that was permuted), reshape copies when needed.
    x_img = img.reshape(1, -1).float()

    dir_tensor: Optional[torch.Tensor] = prepped.get("direction")
    if dir_tensor is None:
        dir_tensor = torch.zeros((1,), dtype=torch.long, device=img.device)
    # one_hot requires an int64 index tensor. Flattening turns a 0-d scalar
    # into shape (1,) so the one-hot result is (1, 4) and can be concatenated
    # with the (1, N) image row.
    dir_idx = dir_tensor.reshape(-1).long().clamp(0, 3)
    dir_oh = torch.nn.functional.one_hot(dir_idx, num_classes=4).float()

    return torch.cat([x_img, dir_oh], dim=1)

def _compute_input_dim(ds, device: torch.device, use_text: bool) -> int:
    """
    Infer the MLP input width from the dataset.

    Walks episodes and timesteps in order, takes the first observation that
    has an "image" entry, runs it through `prepare_obs` and
    `_obs_to_feature`, and returns the width of the resulting feature row.

    Raises:
        RuntimeError: if no observation with an "image" entry exists.
    """
    for episode in ds.iterate_episodes():
        observations = episode.observations
        for step in range(_time_len_from_observations(observations)):
            candidate = _get_obs_t(observations, step)
            if "image" in candidate:
                # The first usable observation is enough; stop probing here.
                prepared = prepare_obs(candidate, device=device, use_text=use_text)
                return _obs_to_feature(prepared).shape[1]
    raise RuntimeError("Could not infer input_dim: no observation with 'image' found.")

def batch_to_model_inputs(
    obs_list: List[Dict[str, Any]],
    device: torch.device,
    use_text: bool,
) -> torch.Tensor:
    """
    Turn a list of raw observations into one [B, input_dim] tensor for the MLP.

    Each observation is prepared with `prepare_obs`, converted to a feature
    row, and the rows are stacked along dim 0 in list order.
    """
    rows = [
        _obs_to_feature(prepare_obs(single_obs, device=device, use_text=use_text))
        for single_obs in obs_list
    ]
    return torch.cat(rows, dim=0)

def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    use_text: bool,
    label_smoothing: float,
    entropy_coef: float,
    grad_clip: float
) -> Dict[str, float]:
    """
    Run one training pass over `loader` and return averaged metrics.

    The objective is cross-entropy (with label smoothing) minus
    `entropy_coef` times the mean entropy of the predicted action
    distribution, so a positive coefficient rewards higher entropy.

    Args:
        model: network mapping feature rows to action logits.
        loader: yields `(obs_list, actions)` batches.
        optimizer: optimizer over `model`'s parameters.
        device: device the batch tensors are placed on.
        use_text: forwarded to `batch_to_model_inputs`.
        label_smoothing: smoothing factor for the cross-entropy loss.
        entropy_coef: weight of the entropy bonus.
        grad_clip: max gradient norm; clipping is skipped when falsy or <= 0.

    Returns:
        Dict with "loss", "ce_loss", "entropy", "acc" (averaged per sample)
        and "n" (number of samples seen). If the loader yields no
        samples, the averages are NaN and "n" is 0.
    """
    model.train()
    ce = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    total_loss = total_ce = total_ent = 0.0
    total_correct = total = 0
    for obs_list, actions in loader:
        actions = actions.to(device, non_blocking=True)
        X = batch_to_model_inputs(obs_list, device, use_text)
        logits = model(X)

        ce_loss = ce(logits, actions)
        logp = torch.log_softmax(logits, dim=-1)
        # Probabilities are recovered from the log-probabilities instead of
        # running a second softmax over the logits.
        ent = -(logp.exp() * logp).sum(dim=-1).mean()

        loss = ce_loss - entropy_coef * ent
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        if grad_clip and grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        with torch.no_grad():
            preds = logits.argmax(dim=-1)
            B = actions.size(0)
            total_correct += (preds == actions).sum().item()
            total += B
            # Batch means are weighted by batch size so a smaller last batch
            # does not skew the epoch average.
            total_loss += float(loss) * B
            total_ce += float(ce_loss) * B
            total_ent += float(ent) * B

    if total == 0:
        # Nothing was processed; report NaN rather than dividing by zero.
        nan = float("nan")
        return {"loss": nan, "ce_loss": nan, "entropy": nan, "acc": nan, "n": 0}

    return {
        "loss": total_loss / total,
        "ce_loss": total_ce / total,
        "entropy": total_ent / total,
        "acc": total_correct / total,
        "n": total
    }

@torch.no_grad()
def eval_epoch(model: nn.Module, loader: DataLoader, device: torch.device, use_text: bool) -> Dict[str, float]:
    """
    Evaluate `model` over `loader` without computing gradients.

    Uses plain cross-entropy (no label smoothing) and top-1 accuracy.

    Args:
        model: network mapping feature rows to action logits.
        loader: yields `(obs_list, actions)` batches.
        device: device the batch tensors are placed on.
        use_text: forwarded to `batch_to_model_inputs`.

    Returns:
        Dict with "loss" and "acc" (averaged per sample) and "n" (number of
        samples). If the loader yields no samples, "loss" and "acc" are NaN
        and "n" is 0.
    """
    model.eval()
    total_loss = 0.0
    total_correct = total = 0
    ce = nn.CrossEntropyLoss()
    for obs_list, actions in loader:
        actions = actions.to(device, non_blocking=True)
        X = batch_to_model_inputs(obs_list, device, use_text)
        logits = model(X)
        loss = ce(logits, actions)
        preds = logits.argmax(dim=-1)
        B = actions.size(0)
        total_correct += (preds == actions).sum().item()
        total += B
        # Weight the batch mean by batch size to get a per-sample average.
        total_loss += float(loss) * B

    if total == 0:
        # An empty validation split is legitimate (e.g. a one-sample dataset);
        # report NaN rather than dividing by zero.
        nan = float("nan")
        return {"loss": nan, "acc": nan, "n": 0}

    return {"loss": total_loss / total, "acc": total_correct / total, "n": total}

# Command-line entry point
def main():
    """
    Command-line entry point for behavioral-cloning pretraining.

    Steps: parse the CLI arguments, load `<config_dir>/<game_code>.yaml`,
    load the Minari dataset, optionally restrict it to the 'train' episodes
    listed in --split_csv, build the transition dataset and data loaders,
    train for --epochs epochs, and save the model's state_dict to
    `<save_dir>/pretrain_<game_code>.pt`.

    Raises:
        FileNotFoundError: if the config file or the split CSV is missing.
        ValueError: if no dataset id is available, or the split CSV has no
            usable 'train' rows.
    """
    # Function-scope import so this block does not rely on the file-level
    # import list being changed.
    from torch.utils.data import Subset

    parser = argparse.ArgumentParser()
    parser.add_argument("--game_code", type=str, required=True, help="Game code to load config (e.g., 'open', 'goto')")
    parser.add_argument("--config_dir", type=str, default="configs", help="Directory containing YAML config files")
    parser.add_argument("--dataset_id", type=str, default="", help="Override Minari dataset id from config")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-5)
    parser.add_argument("--entropy_coef", type=float, default=0.0)
    parser.add_argument("--label_smoothing", type=float, default=0.05)
    parser.add_argument("--grad_clip", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--save_dir", type=str, default="pretrained_models")
    parser.add_argument("--split_csv", type=str, default="", help="Path to CSV with columns episode_idx,split. If set, use only 'train' episodes.")
    args = parser.parse_args()

    config_path = os.path.join(args.config_dir, f"{args.game_code}.yaml")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    cfg = OmegaConf.load(config_path)

    # The CLI override wins; an empty string falls back to the config value.
    dataset_id = args.dataset_id or cfg.data.dataset_id
    use_text = cfg.model.use_text
    print(f"use_text (ignored for features beyond direction): {use_text}")
    if not dataset_id:
        raise ValueError(f"No Minari dataset_id found in config '{config_path}' or via --dataset_id argument.")

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Imported lazily so the module can be imported without minari installed.
    import minari
    ds = minari.load_dataset(dataset_id, download=True)

    allowed_indices: Optional[set] = None
    if args.split_csv:
        allowed_indices = _load_train_episode_indices(args.split_csv)
        print(f"Loaded {len(allowed_indices)} 'train' episode indices from {args.split_csv}")

    try:
        print(f"Dataset '{dataset_id}': episodes={ds.total_episodes}, steps={ds.total_steps}")
    except Exception as exc:
        # The summary is informational only; say why it is unavailable
        # instead of hiding the failure.
        print(f"Note: could not read dataset summary ({exc!r}); continuing.")

    env = None
    try:
        env = ds.recover_environment()
        n_actions = int(env.action_space.n)
    except Exception as exc:
        # Deliberate best-effort: any failure here falls back to 7 actions.
        print("Warning: recover_environment() failed; falling back to default n_actions=7.")
        print(f"  reason: {exc!r}")
        n_actions = 7
    finally:
        # The environment is only needed for its action-space size.
        if env is not None:
            env.close()

    full_ds = BCDataset(ds, use_text=use_text, allowed_indices=allowed_indices)

    input_dim = _compute_input_dim(ds, device=device, use_text=use_text)
    print(f"MLP input_dim={input_dim}, n_actions={n_actions}")

    model = MLPPolicy(input_dim=input_dim, output_dim=n_actions, hidden_dim=HIDDEN_DIM).to(device)
    idx_train, idx_val = split_indices(len(full_ds), val_fraction=0)

    # torch's Subset is a module-level class, so it can be pickled when
    # DataLoader workers are started with the "spawn" method; a class defined
    # inside this function could not be.
    train_ds = Subset(full_ds, idx_train.tolist())
    val_ds = Subset(full_ds, idx_val.tolist())
    has_val = len(val_ds) > 0

    pin = torch.cuda.is_available()
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin, collate_fn=collate_fn
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    print(f"Starting BC pretraining on {dataset_id} | batches: {len(train_loader)} | device: {device}")
    for epoch in range(1, args.epochs + 1):
        tr = train_epoch(
            model, train_loader, optimizer, device,
            use_text=use_text, label_smoothing=args.label_smoothing,
            entropy_coef=args.entropy_coef, grad_clip=args.grad_clip
        )
        if has_val:
            va = eval_epoch(model, val_loader, device, use_text=use_text)
            val_msg = f"val: loss={va['loss']:.4f} acc={va['acc']:.3f}"
        else:
            # With a single sample nothing is held out; skip evaluation.
            val_msg = "val: n/a (empty validation split)"
        print(f"[Epoch {epoch:02d}] train: loss={tr['loss']:.4f} ce={tr['ce_loss']:.4f} ent={tr['entropy']:.3f} acc={tr['acc']:.3f} | "
              f"{val_msg}")

    os.makedirs(args.save_dir, exist_ok=True)
    save_path = os.path.join(args.save_dir, f"pretrain_{args.game_code}.pt")
    torch.save(model.state_dict(), save_path)
    print(f"Saved pretrained weights to: {save_path}")


def _load_train_episode_indices(split_csv: str) -> set:
    """Return the non-negative episode_idx values whose split is 'train' in `split_csv`."""
    if not os.path.exists(split_csv):
        raise FileNotFoundError(f"Split CSV not found at: {split_csv}")

    train_indices = set()
    with open(split_csv, "r", newline="") as f:
        for row in csv.DictReader(f):
            if str(row.get("split", "")).strip().lower() != "train":
                continue
            try:
                idx = int(row.get("episode_idx", -1))
            except (ValueError, TypeError):
                # Rows with a missing or non-numeric index are skipped.
                continue
            if idx >= 0:
                train_indices.add(idx)

    if not train_indices:
        raise ValueError(f"--split_csv provided but no 'train' episode indices were found in {split_csv}")
    return train_indices

# Run the pretraining CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
