# train_regression.py
import os
import time
import argparse
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

from sklearn.datasets import make_regression, make_friedman1
from sklearn.preprocessing import StandardScaler


# =====================
# Args
# =====================
def parse_args(argv=None):
    """Parse command-line options for one regression training run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the original behaviour
            of reading ``sys.argv``; passing a list is handy for tests and
            notebooks.

    Returns:
        argparse.Namespace with the fields ``exp_name``, ``dataset``,
        ``single_timescale``, ``seed``, ``epochs``, ``batch_size``,
        ``lr_body``, ``lr_head``, ``lambda_head`` and ``logdir``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Train an MLP regressor on a synthetic dataset, optionally with "
            "separate learning rates for the body and the head, and log "
            "metrics to TensorBoard."
        )
    )
    parser.add_argument(
        "--exp-name",
        type=str,
        required=True,
        help="Experiment label; becomes part of the TensorBoard run name.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["linear", "friedman", "housing"],
        default="linear",
        help=(
            "Synthetic dataset to train on. Note: 'housing' is a synthetic "
            "13-feature make_regression problem, not a real housing dataset."
        ),
    )
    parser.add_argument(
        "--single-timescale",
        action="store_true",
        help=(
            "Use one RMSprop learning rate (--lr-body) for every parameter "
            "and drop the L2 penalty on the head."
        ),
    )
    parser.add_argument("--seed", type=int, default=1, help="Random seed.")
    parser.add_argument(
        "--epochs", type=int, default=200, help="Number of training epochs."
    )
    parser.add_argument(
        "--batch-size", type=int, default=128, help="Training mini-batch size."
    )
    parser.add_argument(
        "--lr-body",
        type=float,
        default=1e-4,
        help=(
            "RMSprop learning rate for the body; with --single-timescale it "
            "is used for all parameters."
        ),
    )
    parser.add_argument(
        "--lr-head",
        type=float,
        default=1e-3,
        help="RMSprop learning rate for the head (ignored with --single-timescale).",
    )
    parser.add_argument(
        "--lambda-head",
        type=float,
        default=1e-4,
        help=(
            "Coefficient of the L2 penalty on the head parameters "
            "(ignored with --single-timescale)."
        ),
    )
    parser.add_argument(
        "--logdir",
        type=str,
        default="runs/regression",
        help="Root directory under which TensorBoard run folders are created.",
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)


# =====================
# Model
# =====================
class MLPRegressor(nn.Module):
    """One-hidden-layer MLP that maps feature vectors to scalar predictions.

    ``body`` is ``Linear(in_dim -> hidden_dim)`` followed by ``ReLU``;
    ``head`` is ``Linear(hidden_dim -> 1)``. The trailing size-1 output
    dimension is squeezed away, so a batch of shape ``(B, in_dim)`` yields
    predictions of shape ``(B,)``.

    Body and head are separate attributes so that callers can treat their
    parameters differently (e.g. per-group learning rates).
    """

    def __init__(self, in_dim, hidden_dim=256):
        super().__init__()
        # Creation order (body Linear first, then head) is kept so seeded
        # initialisation and state_dict keys ("body.0.*", "head.*") are stable.
        hidden_layers = [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
        self.body = nn.Sequential(*hidden_layers)
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        features = self.body(x)
        raw = self.head(features)
        # Drop the size-1 output axis: (..., 1) -> (...,).
        return raw.squeeze(-1)


# =====================
# Utils
# =====================
def set_seed(seed):
    """Seed every RNG the script draws from, so runs are repeatable.

    Seeds, in order: Python's ``random``, NumPy's global RNG, the torch CPU
    generator, and the torch generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


def grad_norm(params):
    """Return the global L2 norm of the gradients held by ``params``.

    The result is ``sqrt(sum_i ||g_i||_2 ** 2)``, i.e. the 2-norm of all
    gradients concatenated into one vector. Parameters whose ``grad`` is
    ``None`` are skipped; if no gradient is present at all, 0.0 is returned.

    Args:
        params: Iterable of parameters/tensors (e.g. ``module.parameters()``);
            a generator is fine, it is consumed once.

    Returns:
        The norm as a Python float.
    """
    # .detach() instead of the discouraged .data to read grads outside autograd.
    norms = [p.grad.detach().norm(2) for p in params if p.grad is not None]
    if not norms:
        return 0.0
    # Reduce on-device and sync to the host once, rather than calling .item()
    # per parameter. Moving to the first norm's device is a no-op when all
    # gradients already share a device, and keeps mixed-device inputs working.
    device = norms[0].device
    return torch.stack([n.to(device) for n in norms]).norm(2).item()


# =====================
# Data
# =====================
def get_dataset(name, n_samples=10000):
    """Build a synthetic regression problem and split it 80/20 into train/test.

    All generators use ``random_state=0``, so the data are identical across
    training seeds:

    * ``"linear"``   - ``make_regression``, 20 features, noise 0.1.
    * ``"friedman"`` - ``make_friedman1``, noise 1.0.
    * ``"housing"``  - NOT a real housing dataset: ``make_regression`` with
      13 features and noise 5.0 (presumably sized to mimic a housing table;
      confirm intent).

    Features and targets are standardized with ``StandardScaler``s fitted on
    the training split only, so no test-set statistics leak into training.
    Test targets are therefore expressed in training-standardized units.

    Args:
        name: One of ``"linear"``, ``"friedman"``, ``"housing"``.
        n_samples: Total number of samples before the split.

    Returns:
        ``(train_ds, test_ds, in_dim)``: two ``TensorDataset``s of float32
        ``(X, y)`` pairs with 1-D ``y``, and the number of input features.

    Raises:
        ValueError: if ``name`` is unknown, or if ``n_samples`` is too small to
            leave at least one sample in both splits.
    """
    if name == "linear":
        X, y = make_regression(
            n_samples=n_samples,
            n_features=20,
            noise=0.1,
            random_state=0,
        )
    elif name == "friedman":
        X, y = make_friedman1(
            n_samples=n_samples,
            noise=1.0,
            random_state=0,
        )
    elif name == "housing":
        X, y = make_regression(
            n_samples=n_samples,
            n_features=13,
            noise=5.0,
            random_state=0,
        )
    else:
        raise ValueError(
            f"unknown dataset {name!r}; expected 'linear', 'friedman' or 'housing'"
        )

    n_train = int(0.8 * len(X))
    if not 0 < n_train < len(X):
        raise ValueError(
            f"n_samples={n_samples} is too small for an 80/20 train/test split"
        )

    # Split BEFORE scaling: fitting the scalers on the full array would leak
    # the test set's mean/std into the training inputs and targets.
    X_train, X_test = X[:n_train], X[n_train:]
    y_train = y[:n_train].reshape(-1, 1)
    y_test = y[n_train:].reshape(-1, 1)

    scaler_x = StandardScaler().fit(X_train)
    scaler_y = StandardScaler().fit(y_train)

    def _as_tensor(arr):
        # float32 to match the model's default parameter dtype.
        return torch.tensor(arr, dtype=torch.float32)

    # ravel() (not squeeze()) so targets stay 1-D even for a single sample.
    train_ds = TensorDataset(
        _as_tensor(scaler_x.transform(X_train)),
        _as_tensor(scaler_y.transform(y_train).ravel()),
    )
    test_ds = TensorDataset(
        _as_tensor(scaler_x.transform(X_test)),
        _as_tensor(scaler_y.transform(y_test).ravel()),
    )

    return train_ds, test_ds, X.shape[1]


# =====================
# Main
# =====================
def _build_optimizer(model, args):
    """Create RMSprop with one LR for all params, or separate body/head LRs."""
    if args.single_timescale:
        return optim.RMSprop(model.parameters(), lr=args.lr_body)
    return optim.RMSprop(
        [
            {"params": model.body.parameters(), "lr": args.lr_body},
            {"params": model.head.parameters(), "lr": args.lr_head},
        ]
    )


def _train_epoch(model, loader, optimizer, args, device, writer, global_step):
    """Run one optimisation pass over ``loader``, logging per-step scalars.

    Returns:
        ``(epoch_loss, global_step)``: the sum of per-batch training losses
        and the step counter advanced by the number of batches.
    """
    model.train()
    epoch_loss = 0.0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        pred = model(x)
        mse = F.mse_loss(pred, y)

        if args.single_timescale:
            loss = mse
        else:
            # L2 penalty over every head parameter (its weight and its bias).
            l2_head = sum((p ** 2).sum() for p in model.head.parameters())
            loss = mse + args.lambda_head * l2_head

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # single host sync, reused below
        writer.add_scalar("train/loss", loss_value, global_step)
        writer.add_scalar("train/mse", mse.item(), global_step)
        # Gradients are still populated here: zero_grad() only runs at the
        # start of the next iteration.
        writer.add_scalar("grads/body_norm", grad_norm(model.body.parameters()), global_step)
        writer.add_scalar("grads/head_norm", grad_norm(model.head.parameters()), global_step)

        epoch_loss += loss_value
        global_step += 1

    return epoch_loss, global_step


def _evaluate(model, loader, device):
    """Return the summed squared error of ``model`` over every sample in ``loader``."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            total += F.mse_loss(pred, y, reduction="sum").item()
    return total


def main():
    """Script entry point: parse args, train the MLP regressor, log to TensorBoard.

    In the default two-timescale mode the body and head get separate RMSprop
    learning rates and the head carries an L2 penalty; with
    ``--single-timescale`` one learning rate (``--lr-body``) is used and no
    penalty is added. Per-step training scalars and per-epoch test MSE are
    written under ``<logdir>/<dataset>__<exp_name>__seed<seed>__<unix time>``.
    """
    args = parse_args()
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    run_name = f"{args.dataset}__{args.exp_name}__seed{args.seed}__{int(time.time())}"
    writer = SummaryWriter(os.path.join(args.logdir, run_name))

    try:
        # Data
        trainset, testset, in_dim = get_dataset(args.dataset)
        train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
        test_loader = DataLoader(testset, batch_size=512, shuffle=False)

        model = MLPRegressor(in_dim).to(device)
        optimizer = _build_optimizer(model, args)

        # Training
        global_step = 0
        start_time = time.time()

        for epoch in range(args.epochs):
            epoch_loss, global_step = _train_epoch(
                model, train_loader, optimizer, args, device, writer, global_step
            )

            mse_avg = _evaluate(model, test_loader, device) / len(testset)
            avg_loss = epoch_loss / len(train_loader)

            print(
                f"[{args.dataset.upper()}] Epoch {epoch:03d} | "
                f"TrainLoss {avg_loss:.4f} | TestMSE {mse_avg:.4f}"
            )

            writer.add_scalar("eval/mse", mse_avg, epoch)
            writer.add_scalar("charts/epoch_loss", avg_loss, epoch)
            writer.add_scalar("charts/elapsed_time", time.time() - start_time, epoch)
    finally:
        # Flush and close the event file even if training fails or is interrupted.
        writer.close()


if __name__ == "__main__":
    # Train only when run as a script; importing this module (e.g. to reuse
    # get_dataset or MLPRegressor) has no side effects beyond definitions.
    main()
