"""
sota_comparison.py

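Lightweight benchmarking harness that trains a simple MLP baseline and, when it can be imported, the AWML world model on synthetic two-class blobs, then prints each model's final validation metrics.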
"""

# Load the repository's dataset modules through its package name.
import importlib
pkg = "Code"  # package directory name in this repo; adjust if you rename the folder
datasets = importlib.import_module(f"{pkg}.datasets")
climate_agriculture = getattr(datasets, "climate_agriculture")
healthcare_sparse   = getattr(datasets, "healthcare_sparse")


import os, sys, argparse, time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

try:
    from ..datasets import synthetic_classification as synth_datasets
except Exception:
    try:
        from datasets import synthetic_classification as synth_datasets
    except Exception as e:
        synth_datasets = None

class MLP(nn.Module):
    """Two-layer perceptron baseline: Linear -> ReLU -> Linear.

    Args:
        in_dim: number of input features.
        out_dim: number of outputs (e.g. class logits).
        hidden_dim: width of the hidden layer (default 256, the previously
            hard-coded value).
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 256):
        super().__init__()
        # Attribute name ``net`` and layer order are kept stable so existing
        # state_dict keys (net.0.*, net.2.*) still load.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x, y=None):
        """Return ``self.net(x)``.

        ``y`` is accepted and ignored. Presumably it exists so the model
        matches a trainer that may pass targets to ``forward``; confirm
        against ``utils.training_utils.train``.
        """
        return self.net(x)

def build_loader(X, y, batch_size=64, shuffle=True):
    """Wrap paired tensors ``X`` and ``y`` in a ``DataLoader``.

    Args:
        X: input tensor; its first dimension indexes samples.
        y: target tensor with the same first-dimension length as ``X``.
        batch_size: samples per batch.
        shuffle: reshuffle every epoch (default True, the previous behavior).
            Pass False for evaluation loaders that should iterate in order.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(x_batch, y_batch)`` pairs.
    """
    return DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=shuffle)

def train_once(model, X, y, epochs=30, lr=1e-3, batch_size=64, device='cpu', seed=0):
    """Train ``model`` on an 80/20 train/validation split of ``(X, y)``.

    The first 80% of rows, in the order given, form the training set and the
    remainder the validation set. Rows are NOT shuffled before splitting.

    Args:
        model: torch module to train; moved to ``device`` in place.
        X: input tensor whose first dimension indexes samples.
        y: targets. ``torch.long`` selects classification (cross-entropy);
            any other dtype selects regression (MSE).
        epochs: number of training epochs passed to the trainer.
        lr: Adam learning rate.
        batch_size: batch size for both loaders.
        device: torch device string.
        seed: value passed to ``set_seed`` right before training (default 0,
            the previously hard-coded value). The model is already constructed
            by the caller, so this does not affect its initial weights.

    Returns:
        ``(hist, last_metrics)``: the trainer's history mapping, and a dict
        mapping each non-empty ``val_*`` entry to its last recorded value.

    Raises:
        ValueError: if ``X`` has fewer than 2 rows, so one side of the split
            would be empty.
    """
    n_rows = X.shape[0]
    # NOTE(review): if the data source emits rows grouped by class, this
    # order-preserving split can yield an imbalanced validation set. Confirm
    # the ordering of generate_blobs.
    split = int(0.8 * n_rows)
    if not 0 < split < n_rows:
        raise ValueError(
            f"train_once needs at least 2 samples for an 80/20 split, got {n_rows}"
        )

    try:
        from ..utils.training_utils import train, set_seed
        from ..utils.metrics import MetricsFn
    except (ImportError, ValueError):
        # Fallback when the file is run as a script (no parent package), so
        # relative imports are unavailable.
        from utils.training_utils import train, set_seed
        from utils.metrics import MetricsFn

    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    is_classification = y.dtype == torch.long
    crit = nn.CrossEntropyLoss() if is_classification else nn.MSELoss()

    train_loader = build_loader(X[:split], y[:split], batch_size=batch_size)
    # Validation order does not need randomizing. A fixed order keeps
    # evaluation deterministic and avoids consuming the global RNG.
    val_loader = DataLoader(
        TensorDataset(X[split:], y[split:]), batch_size=batch_size, shuffle=False
    )

    set_seed(seed)
    out_dir = os.path.join(os.path.dirname(__file__), "..", "results", "sota")
    metrics_fn = MetricsFn(task='classification' if is_classification else 'regression')
    hist = train(
        model, train_loader, val_loader, opt, crit,
        device=device, epochs=epochs, out_dir=out_dir,
        metrics_fn=metrics_fn, verbose=False,
    )
    # Skip empty histories instead of failing on v[-1].
    last_metrics = {
        k: v[-1] for k, v in hist.items() if k.startswith('val_') and len(v) > 0
    }
    return hist, last_metrics

def main(args):
    """Benchmark the MLP baseline and, if importable, AWML on synthetic blobs.

    Generates a 2-feature, 2-class blob dataset and trains the MLP baseline.
    It then tries to import, build and train the AWML world model. Each model's
    final validation metrics are printed. The AWML part is best-effort: import,
    construction and training failures are reported, not raised.

    Args:
        args: parsed CLI namespace with ``n_samples``, ``seed``, ``epochs``,
            ``lr``, ``batch_size`` and ``device`` attributes.

    Raises:
        RuntimeError: if the synthetic dataset module could not be imported.
    """
    if synth_datasets is None:
        raise RuntimeError('synthetic datasets not available')
    n_classes = 2
    X, y = synth_datasets.generate_blobs(
        n_samples=args.n_samples, n_features=2, n_classes=n_classes, seed=args.seed
    )
    if y.dtype != torch.long:
        y = y.long()
    X = X.float()

    # Resolve the device once so both models train on the same one.
    device = args.device or ('cuda' if torch.cuda.is_available() else 'cpu')
    train_kwargs = dict(
        epochs=args.epochs, lr=args.lr, batch_size=args.batch_size, device=device
    )

    mlp = MLP(X.shape[1], n_classes)
    _, m_mlp = train_once(mlp, X, y, **train_kwargs)
    print("MLP metrics:", m_mlp)

    try:
        try:
            from ..models.world_model import AWMLWorldModel as UserAWML
        except (ImportError, ValueError):
            # Same script-mode fallback the file uses for datasets/utils.
            from models.world_model import AWMLWorldModel as UserAWML
        awml = UserAWML(obs_dim=X.shape[1])
    except Exception as e:
        print("AWML model not available or failed to construct:", e)
        return

    try:
        _, m_awml = train_once(awml, X, y, **train_kwargs)
    except Exception as e:
        # Report training errors separately from import/construction
        # problems; the benchmark stays best-effort.
        print("AWML training failed:", e)
        return
    print("AWML metrics:", m_awml)

if __name__ == '__main__':
    # Command-line entry point. Defaults match the benchmark's original
    # settings; --device falls back to CUDA when available, else CPU.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--n_samples', type=int, default=500,
                        help='number of synthetic samples to generate (default: %(default)s)')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='mini-batch size (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=30,
                        help='training epochs per model (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Adam learning rate (default: %(default)s)')
    parser.add_argument('--seed', type=int, default=0,
                        help='seed for synthetic data generation (default: %(default)s)')
    parser.add_argument('--device', default=None,
                        help="torch device, e.g. 'cpu' or 'cuda' (default: cuda if available, else cpu)")
    args = parser.parse_args()
    main(args)
