"""
realworld_experiments.py

Runs simple real-world style experiments:
- climate -> regression (uses climate_agriculture.generate_climate_yield)
- healthcare sparse simulation (if available)
"""
# Resolve the dataset helpers dynamically so the package directory name is
# configured in exactly one place (``pkg``).
import importlib

pkg = "Code"  # package directory name in this repo; adjust if you rename the folder
try:
    datasets = importlib.import_module(f"{pkg}.datasets")
except ImportError as e:
    # A missing package must not abort importing this module: main() already
    # raises a RuntimeError when the dataset it needs is None.
    print("Warning: failed to import some datasets:", e)
    datasets = None
# getattr with a default leaves an individual helper as None when it is absent
# from the package (or when the package itself could not be imported).
climate_agriculture = getattr(datasets, "climate_agriculture", None)
healthcare_sparse = getattr(datasets, "healthcare_sparse", None)

import os, sys, argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Prefer the dataset helpers from Code.mydatasets when that package is importable.
try:
    from Code.mydatasets import climate_agriculture, healthcare_sparse
except Exception as e:
    print("Warning: failed to import some datasets:", e)
    # Keep any helper that an earlier import already bound instead of
    # overwriting it with None; only names that are still unbound become None.
    climate_agriculture = globals().get("climate_agriculture")
    healthcare_sparse = globals().get("healthcare_sparse")

# fallback simple model (regression)
class SimpleRegressor(nn.Module):
    """Small fully connected network that emits one value per input row.

    Layer widths are ``in_dim -> 128 -> 64 -> 1`` with a ReLU after each of
    the two hidden linear layers.
    """

    def __init__(self, in_dim):
        """Build the layer stack for inputs with ``in_dim`` features."""
        super().__init__()
        widths = (in_dim, 128, 64)
        stack = []
        # Hidden layers: each linear map is followed by a ReLU.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # Output head: a single linear unit with no activation.
        stack.append(nn.Linear(64, 1))
        self.net = nn.Sequential(*stack)

    def forward(self, x, y=None):
        """Return the network output for ``x``; ``y`` is accepted but unused."""
        predictions = self.net(x)
        return predictions

def build_loader(x, y, batch_size=64, *, shuffle=True):
    """Wrap paired tensors in a ``DataLoader``.

    Args:
        x: Feature tensor, indexed along its first dimension.
        y: Target tensor; ``TensorDataset`` requires its first dimension to
            match that of ``x``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples on every pass.
            Defaults to True, which is the behaviour this helper always had;
            pass False for evaluation data that should keep its order.

    Returns:
        A ``DataLoader`` over ``TensorDataset(x, y)``.
    """
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=shuffle)

def main(args):
    """Train ``SimpleRegressor`` on one synthetic dataset and save the results.

    The learning curve and the training history are written below
    ``../results/realworld`` relative to this file.

    Args:
        args: Namespace providing ``task``, ``n_samples``, ``batch_size``,
            ``epochs``, ``lr`` and ``seed``; ``device`` is optional and is
            picked automatically (CUDA if available, else CPU) when it is
            missing or empty.

    Raises:
        ValueError: If ``args.task`` is neither 'climate' nor 'healthcare'.
        RuntimeError: If the dataset module for the requested task is None.
    """
    import json

    device = args.device if getattr(args, "device", None) else ('cuda' if torch.cuda.is_available() else 'cpu')
    out_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "results", "realworld"))
    os.makedirs(out_dir, exist_ok=True)

    # Both tasks share the same regression pipeline; only the generator differs.
    if args.task == 'climate':
        dataset_module, generator_name = climate_agriculture, 'generate_climate_yield'
    elif args.task == 'healthcare':
        dataset_module, generator_name = healthcare_sparse, 'generate_healthcare_sparse'
    else:
        raise ValueError('Unknown task')
    if dataset_module is None:
        raise RuntimeError(f'{args.task} dataset module not available')
    generate = getattr(dataset_module, generator_name)

    # The relative import only works when this file is loaded as part of a
    # package; otherwise fall back to a top-level ``utils`` package. A failure
    # of the fallback itself still propagates to the caller.
    try:
        from ..utils.training_utils import train, set_seed
        from ..utils.metrics import MetricsFn
        from ..utils.visualization import plot_learning_curve
    except Exception:
        from utils.training_utils import train, set_seed
        from utils.metrics import MetricsFn
        from utils.visualization import plot_learning_curve

    # Seed before the data and the model are created so that weight
    # initialisation is covered by --seed as well as the training loop.
    # NOTE(review): assumes set_seed seeds torch's global RNG - confirm in
    # utils.training_utils.
    set_seed(args.seed)

    X, y = generate(n_samples=args.n_samples, seed=args.seed)
    X = X.float()
    y = y.view(-1, 1).float()
    model = SimpleRegressor(X.shape[1])
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # First 80% of the rows train the model, the remainder validates it.
    split = int(0.8 * X.shape[0])
    train_loader = build_loader(X[:split], y[:split], batch_size=args.batch_size)
    # Validation data keeps its order: shuffling it adds nothing and would
    # draw from the global RNG between training epochs.
    val_loader = DataLoader(TensorDataset(X[split:], y[split:]), batch_size=args.batch_size, shuffle=False)

    history = train(
        model,
        train_loader,
        val_loader,
        optimizer,
        criterion,
        device=device,
        epochs=args.epochs,
        out_dir=out_dir,
        metrics_fn=MetricsFn(task='regression'),
    )
    plot_learning_curve(history, os.path.join(out_dir, f"learning_curve_{args.task}.png"), title=f"Realworld {args.task}")
    with open(os.path.join(out_dir, 'history.json'), 'w', encoding='utf-8') as fh:
        json.dump(history, fh)
    print('Saved results to', out_dir)

if __name__ == '__main__':
    # Command-line entry point: collect the experiment settings and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', choices=['climate', 'healthcare'], default='climate')
    # The remaining options only differ in flag name, value type and default.
    for flag, value_type, fallback in (
        ('--n_samples', int, 500),
        ('--batch_size', int, 64),
        ('--epochs', int, 50),
        ('--lr', float, 1e-3),
        ('--seed', int, 0),
        ('--device', str, None),
    ):
        parser.add_argument(flag, type=value_type, default=fallback)
    args = parser.parse_args()
    main(args)