"""
synthetic_experiments.py

Train the project's AWML world model (or a small fallback MLP when it cannot be
imported) on synthetic tasks:
- classification: blobs / spirals
- regression: climate yield
- simple PDE field (planned; not yet selectable via --dataset)

This script prefers package-relative imports so it works when the project is imported
as a package (e.g., Code.experiments.synthetic_experiments).
"""

# Optional eager import of the project's dataset package (<pkg>.mydatasets).
# Nothing else in this script appears to use these names, so a missing or
# incomplete package must not abort the whole script at import time: on
# failure all three names are set to None and a warning is printed instead.
import importlib
pkg = "Code"  # package directory name in this repo; adjust if you rename the folder
try:
    datasets = importlib.import_module(f"{pkg}.mydatasets")
    # getattr raises AttributeError if the package does not expose the submodule.
    climate_agriculture = getattr(datasets, "climate_agriculture")
    healthcare_sparse = getattr(datasets, "healthcare_sparse")
except (ImportError, AttributeError) as e:
    print(f"Warning: failed to import {pkg}.mydatasets:", e)
    datasets = climate_agriculture = healthcare_sparse = None


import os
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import sys

# Prefer package-relative dataset imports when available.
# Only ImportError triggers the fallback: a genuine bug inside the package's
# own dataset modules should surface with its real traceback rather than be
# masked by a switch to whatever top-level `datasets` package happens to be
# importable (which may be an unrelated distribution of the same name).
try:
    # Works when imported as Code.experiments.synthetic_experiments.
    from ..datasets import synthetic_classification as synth_datasets
    from ..datasets import climate_agriculture as climate_datasets
except ImportError:
    try:
        # Fallback to a top-level import (useful during interactive runs).
        from datasets import synthetic_classification as synth_datasets
        from datasets import climate_agriculture as climate_datasets
    except Exception as e:
        # Best-effort: keep the module importable; callers must check for None.
        print("Warning: failed to import uploaded datasets package:", e)
        synth_datasets = None
        climate_datasets = None

# Try importing the user's AWMLWorldModel. AWMLModel stays None when it is
# unavailable, in which case main() builds a FallbackModel instead.
AWMLModel = None
try:
    # Package-relative import (Code.experiments.synthetic_experiments).
    # Only ImportError falls through: a model module that exists but fails
    # for another reason should not be silently replaced by the fallback.
    from ..models.world_model import AWMLWorldModel as AWMLModel
except ImportError:
    try:
        # Fallback absolute import (script / interactive runs).
        from models.world_model import AWMLWorldModel as AWMLModel
    except Exception as e:
        # Say why the fallback model will be used instead of failing silently.
        print("Warning: AWMLWorldModel unavailable, using FallbackModel:", e)
        AWMLModel = None

# Fallback simple model
class FallbackModel(nn.Module):
    """Small MLP used when AWMLWorldModel is not available.

    Architecture: in_dim -> 128 -> 64 -> out_dim, with GELU between the
    linear layers. The layers live in ``self.net`` (an ``nn.Sequential``).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        hidden_widths = [in_dim, 128, 64]
        layers = []
        # Linear + GELU for every hidden transition, created in input order.
        for fan_in, fan_out in zip(hidden_widths, hidden_widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.GELU()]
        # Final projection has no activation.
        layers.append(nn.Linear(hidden_widths[-1], out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x, y=None):
        # `y` is accepted for call-signature compatibility only; it is ignored.
        return self.net(x)

def build_dataloader_from_tensors(x, y=None, batch_size=32, shuffle=True):
    """Wrap ``x`` (and ``y`` when given) in a TensorDataset and return a DataLoader.

    Each batch is ``[x_batch]`` when ``y`` is None, otherwise ``[x_batch, y_batch]``.
    """
    tensors = (x,) if y is None else (x, y)
    return DataLoader(TensorDataset(*tensors), batch_size=batch_size, shuffle=shuffle)

def _load_dataset(args):
    """Generate the synthetic dataset selected by ``args.dataset``.

    Args:
        args: parsed CLI namespace; reads ``dataset``, ``n_samples`` and ``seed``.

    Returns:
        ``(X, y, task, out_dim)``. ``task`` is ``"classification"`` or
        ``"regression"``; ``out_dim`` is the number of model outputs.

    Raises:
        RuntimeError: the dataset module needed for ``args.dataset`` failed to import.
        ValueError: ``args.dataset`` is not one of blobs/spirals/climate.
    """
    if args.dataset in ("blobs", "spirals"):
        # Both classification tasks come from the same module; guard it for either.
        if synth_datasets is None:
            raise RuntimeError("synthetic datasets not available")
        if args.dataset == "blobs":
            X, y = synth_datasets.generate_blobs(n_samples=args.n_samples, n_features=2, n_classes=2, seed=args.seed)
        else:
            X, y = synth_datasets.generate_spirals(n_samples=args.n_samples, noise=0.1, seed=args.seed)
        return X, y, "classification", 2
    if args.dataset == "climate":
        if climate_datasets is None:
            raise RuntimeError("climate datasets not available")
        X, y = climate_datasets.generate_climate_yield(n_samples=args.n_samples, seed=args.seed)
        return X, y, "regression", 1
    raise ValueError("Unknown dataset")


def _build_model(in_dim, out_dim):
    """Instantiate the user's AWMLWorldModel if possible, else a FallbackModel.

    Tries ``AWMLModel(obs_dim=in_dim)``, then ``AWMLModel(in_dim, out_dim)``.
    If AWMLModel is None or both calls raise, returns ``FallbackModel(in_dim, out_dim)``.
    """
    if AWMLModel is not None:
        # NOTE(review): the obs_dim-only form does not receive out_dim; confirm the
        # AWML model's output size matches the task's number of outputs.
        try:
            return AWMLModel(obs_dim=in_dim)
        except Exception:
            # Signature probe: fall through to the positional form below.
            pass
        try:
            return AWMLModel(in_dim, out_dim)
        except Exception as e:
            print("Warning: could not construct AWMLWorldModel, using FallbackModel:", e)
    return FallbackModel(in_dim, out_dim)


def _import_training_utils():
    """Return ``(train, set_seed, MetricsFn, plot_learning_curve)``.

    Package-relative imports are tried first. Absolute imports are used only when
    those raise ImportError, so a real bug inside the package's utils surfaces
    with its own traceback.
    """
    try:
        from ..utils.training_utils import train, set_seed
        from ..utils.metrics import MetricsFn
        from ..utils.visualization import plot_learning_curve
    except ImportError:
        from utils.training_utils import train, set_seed
        from utils.metrics import MetricsFn
        from utils.visualization import plot_learning_curve
    return train, set_seed, MetricsFn, plot_learning_curve


def main(args):
    """Train a model on one synthetic task and save its history and learning curve.

    Args:
        args: namespace with ``dataset``, ``n_samples``, ``batch_size``, ``epochs``,
            ``lr``, ``seed`` and optionally ``device`` (defaults to cuda if available).

    Writes ``learning_curve_<dataset>.png`` and ``history_<dataset>.json`` to
    ``<this file's directory>/../results/synthetic``.

    Raises:
        RuntimeError: the required dataset module is unavailable.
        ValueError: unknown dataset, or fewer than 2 samples to split.
    """
    import json

    device = args.device if getattr(args, "device", None) else ("cuda" if torch.cuda.is_available() else "cpu")
    out_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "results", "synthetic"))
    os.makedirs(out_dir, exist_ok=True)

    # Import (and fail) early, and seed BEFORE building the model so that
    # weight initialisation is reproducible as well as training.
    train, set_seed, MetricsFn, plot_learning_curve = _import_training_utils()
    set_seed(args.seed)

    X, y, task, out_dim = _load_dataset(args)
    # as_tensor accepts tensors as well as NumPy arrays / lists from the generators.
    X = torch.as_tensor(X).float()
    y = torch.as_tensor(y)
    if task == "classification":
        # CrossEntropyLoss expects a flat vector of integer class indices.
        y = y.reshape(-1).long()
    else:
        # Column vector to match a single-output regression head.
        y = y.reshape(-1, 1).float()

    n = X.shape[0]
    if n < 2:
        raise ValueError(f"need at least 2 samples for a train/val split, got {n}")
    # Seeded shuffle before the 80/20 split: a generator that emits samples
    # grouped by class or sorted by target would otherwise yield an
    # unrepresentative validation set.
    perm = torch.randperm(n, generator=torch.Generator().manual_seed(args.seed))
    X, y = X[perm], y[perm]
    split = int(0.8 * n)
    train_loader = build_dataloader_from_tensors(X[:split], y[:split], batch_size=args.batch_size)
    val_loader = build_dataloader_from_tensors(X[split:], y[split:], batch_size=args.batch_size, shuffle=False)

    model = _build_model(X.shape[1], out_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss() if task == "classification" else nn.MSELoss()

    history = train(model, train_loader, val_loader, optimizer, criterion, device=device, epochs=args.epochs, out_dir=out_dir, metrics_fn=MetricsFn(task=task), verbose=True)

    plot_learning_curve(history, os.path.join(out_dir, f"learning_curve_{args.dataset}.png"), title=f"Synthetic {args.dataset}")

    # NOTE(review): assumes `history` holds only JSON-serialisable values
    # (plain numbers/lists, not tensors); confirm against train().
    with open(os.path.join(out_dir, f"history_{args.dataset}.json"), "w", encoding="utf-8") as fh:
        json.dump(history, fh)

    print("Done. Results in:", out_dir)

if __name__ == '__main__':
    # Command-line entry point: parse options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', choices=['blobs','spirals','climate'], default='blobs')
    # Remaining options are plain typed scalars, registered in the same order.
    for option_flag, option_type, option_default in (
        ('--n_samples', int, 500),
        ('--batch_size', int, 64),
        ('--epochs', int, 40),
        ('--lr', float, 1e-3),
        ('--seed', int, 0),
        ('--device', str, None),
    ):
        parser.add_argument(option_flag, type=option_type, default=option_default)
    main(parser.parse_args())
