from __future__ import annotations
import argparse
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from .utils.seed import set_seed
from .utils.io import ensure_dir
from .gridworld.env import LongCorridor
from .data.dataset import SeqPredDataset
from .models.rnn import SeqPredGRU


def parse_args():
    """Parse the command-line options for a sequence-prediction training run.

    Every option is a single ``--flag VALUE`` pair with a typed default.
    ``--device`` defaults to ``"cuda"`` when ``torch.cuda.is_available()``
    reports a GPU, and to ``"cpu"`` otherwise.

    Returns:
        argparse.Namespace with one attribute per option (e.g. ``args.Lx``).
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, type, default) in the same order they appear in --help.
    option_specs = (
        ("--Lx", int, 48),
        ("--Ly", int, 5),
        ("--n_colors", int, 6),
        ("--obs_size", int, 5),
        ("--k", int, 3),
        ("--hidden", int, 128),
        ("--n_traj_train", int, 200),
        ("--n_traj_val", int, 30),
        ("--T", int, 40),
        ("--batch_size", int, 128),
        ("--epochs", int, 10),
        ("--lr", float, 2e-3),
        ("--device", str, default_device),
        ("--seed", int, 123),
        ("--outdir", str, "outputs"),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value)
    return parser.parse_args()


def main():
    """Train a ``SeqPredGRU`` predictor and checkpoint the best validation model.

    Reads all settings from the command line via ``parse_args``, builds a
    ``LongCorridor`` environment and train/validation ``SeqPredDataset``
    splits, then optimises the model with Adam on an MSE objective (gradient
    norm clipped to 1.0). After every epoch it evaluates on the validation
    split. Whenever the validation loss improves, it saves
    ``<outdir>/checkpoints/best_k{k}.pt``.

    The checkpoint dict contains:
    - ``"model"``: the model ``state_dict``;
    - ``"args"``: the parsed CLI arguments as a dict;
    - ``"epoch"``: the epoch the checkpoint came from;
    - ``"train_losses"`` / ``"val_losses"``: the per-epoch loss histories up
      to and including that epoch.

    Raises:
        ValueError: if the training loader would yield no batches, i.e. the
            training set holds fewer samples than ``--batch_size`` (the loader
            uses ``drop_last=True``).
    """
    args = parse_args()
    set_seed(args.seed)
    device = torch.device(args.device)

    env = LongCorridor(Lx=args.Lx, Ly=args.Ly, n_colors=args.n_colors, obs_size=args.obs_size, seed=args.seed)

    # The validation split is generated with a different seed (offset by 999)
    # than the training split.
    train_ds = SeqPredDataset(env, n_traj=args.n_traj_train, T=args.T, k=args.k, seed=args.seed)
    val_ds = SeqPredDataset(env, n_traj=args.n_traj_val, T=args.T, k=args.k, seed=args.seed + 999)

    # drop_last=True makes every training batch the same size. The plain mean
    # of per-batch losses used below therefore equals the per-sample mean over
    # the batches that were seen.
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, drop_last=False)
    if len(train_loader) == 0:
        # With drop_last=True, a training set smaller than one batch yields no
        # batches at all. Fail fast instead of dividing by zero after an epoch.
        raise ValueError(
            f"training set has {len(train_ds)} samples, fewer than "
            f"--batch_size={args.batch_size}; lower --batch_size or "
            "generate more training data"
        )

    model = SeqPredGRU(obs_dim=train_ds.obs_dim, feat_dim=train_ds.feat_dim, hidden=args.hidden, k=args.k).to(device)
    opt = optim.Adam(model.parameters(), lr=args.lr)
    crit = nn.MSELoss()

    def batch_loss(batch):
        """Move one ``(o0, feats, targets)`` batch to ``device``; return ``(loss, batch_size)``."""
        o0, feats, targets = (t.to(device) for t in batch)
        preds, _ = model(o0, feats)
        return crit(preds, targets), o0.size(0)

    def eval_loss(loader):
        """Return the sample-weighted mean loss over ``loader`` (0.0 if it is empty)."""
        model.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for batch in loader:
                loss, n = batch_loss(batch)
                # crit averages within a batch; re-weight by batch size so a
                # smaller final batch (drop_last=False) is not over-counted.
                total += loss.item() * n
                count += n
        return total / max(1, count)

    best_val = float("inf")
    outdir = ensure_dir(args.outdir)
    ckptdir = ensure_dir(Path(outdir) / "checkpoints")

    # Per-epoch loss histories, stored in every saved checkpoint.
    train_losses = []
    val_losses = []

    for ep in range(1, args.epochs + 1):
        model.train()  # eval_loss switches to eval mode; restore each epoch
        pbar = tqdm(train_loader, desc=f"Epoch {ep}/{args.epochs}")
        epoch_train_loss = 0.0
        batch_count = 0
        for batch in pbar:
            loss, _ = batch_loss(batch)
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            loss_value = loss.item()
            epoch_train_loss += loss_value
            batch_count += 1
            pbar.set_postfix(loss=f"{loss_value:.4f}")

        # batch_count > 0 is guaranteed by the len(train_loader) check above.
        avg_train_loss = epoch_train_loss / batch_count
        train_losses.append(avg_train_loss)

        val = eval_loss(val_loader)
        val_losses.append(val)
        print(f"Train loss: {avg_train_loss:.4f}, Val loss: {val:.4f}")
        if val < best_val:
            best_val = val
            torch.save({
                "model": model.state_dict(),
                "args": vars(args),
                "epoch": ep,
                "train_losses": train_losses,
                "val_losses": val_losses,
            }, Path(ckptdir) / f"best_k{args.k}.pt")
            print(f"Saved best checkpoint for k={args.k}")

# Script entry point: train only when this module is executed directly, not
# when it is imported. Because this file uses package-relative imports
# (``from .utils...``), it has to be launched as a module (``python -m ...``)
# rather than as a plain file path.
if __name__ == "__main__":
    main()
