# lightning_sweep.py
# --------------------------------------------------
# Hyperparameter Sweep (Lightning version) + StandardScaler
# --------------------------------------------------
import os, csv, random, argparse, pathlib, json, time, pickle
from functools import partial
from collections import Counter
from typing import Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import multiprocessing as mp
mp.set_start_method("spawn", force=True)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_from_disk
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import StandardScaler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from tqdm.auto import tqdm


# ---------- 0) helpers --------------------------------------------------------
def remove_outlier(dataset, hand_crafted=False):
    if not hand_crafted:
        print("Removing outliers from dataset")
        bad_idx_trn, bad_idx_val, bad_idx_tst = 23082, 473, 7880
        trn_idx_to_select = list(range(bad_idx_trn)) + list(range(bad_idx_trn+1,len(dataset["train"])))
        val_idx_to_select = list(range(bad_idx_val)) + list(range(bad_idx_val+1,len(dataset["validation"])))
        tst_idx_to_select = list(range(bad_idx_tst)) + list(range(bad_idx_tst+1,len(dataset["test"])))
    else:
        print("Removing outliers from hand-crafted dataset")
        bad_idx_trn, bad_idx_val, bad_idx_tst = [3010, 9693, 16524, 22151], [449], [1158]
        trn_idx_to_select = list(sorted(set(range(len(dataset["train"]))) - set(bad_idx_trn)))
        val_idx_to_select = list(sorted(set(range(len(dataset["validation"]))) - set(bad_idx_val)))
        tst_idx_to_select = list(sorted(set(range(len(dataset["test"]))) - set(bad_idx_tst)))

    dataset["train"]      = dataset["train"].select(trn_idx_to_select)
    dataset["validation"] = dataset["validation"].select(val_idx_to_select)
    dataset["test"]       = dataset["test"].select(tst_idx_to_select)
    print(f"selected {len(dataset['train'])} train samples, {len(dataset['validation'])} validation samples, {len(dataset['test'])} test samples")
    return dataset


class ScenarioDataset(Dataset):
    """Wrap HF dataset to compute per-sample features for band scenarios."""
    def __init__(self, hf_ds, scenario="avg", label_key="label_idx", scaler: Optional[StandardScaler] = None):
        assert scenario in ("concat", "avg", "g", "r")
        self.ds, self.scenario, self.lkey = hf_ds, scenario, label_key
        self.scaler = scaler

    def __len__(self):  return len(self.ds)

    def __getitem__(self, idx):
        rec   = self.ds[idx]
        if "embeddings_g" in rec:
            emb_g = np.squeeze(np.array(rec["embeddings_g"], dtype=np.float32))
            emb_r = np.squeeze(np.array(rec["embeddings_r"], dtype=np.float32))
            if emb_g.ndim > 1:
                avg_g, avg_r = emb_g.mean(0), emb_r.mean(0)
            else:
                avg_g, avg_r = emb_g, emb_r
        else:
            emb_g = np.squeeze(np.array(rec["g_embedding"], dtype=np.float32))
            emb_r = np.squeeze(np.array(rec["r_embedding"], dtype=np.float32))
            avg_g, avg_r = emb_g, emb_r
        try:
            if   self.scenario == "concat": x_np = np.concatenate([avg_g, avg_r], 0)
            elif self.scenario == "avg":    x_np = 0.5 * (avg_g + avg_r)
            elif self.scenario == "g":      x_np = avg_g
            else:                           x_np = avg_r
        except:
            print(f"Error in __getitem__: {idx}")
            print(f"rec: {rec}")
            print(f"shapes: emb_g: {emb_g.shape}, emb_r: {emb_r.shape}, avg_g: {avg_g.shape}, avg_r: {avg_r.shape}")
            raise

        if self.scaler is not None:
            x_np = self.scaler.transform(x_np.reshape(1, -1)).astype(np.float32).ravel()

        return torch.from_numpy(x_np), torch.tensor(rec[self.lkey], dtype=torch.long)


class MLP(nn.Module):
    """Simple feed-forward classifier."""
    def __init__(self, in_dim, hidden_dims, num_classes, dropout=0.0):
        super().__init__()
        layers, last = [], in_dim
        for h in hidden_dims:
            layers += [nn.Linear(last, h), nn.ReLU(inplace=True)]
            if dropout: layers.append(nn.Dropout(dropout))
            last = h
        layers.append(nn.Linear(last, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):  return self.net(x)


# ---------- 1) Lightning modules ---------------------------------------------
class LitMLP(pl.LightningModule):
    def __init__(self, in_dim, hidden_dims, num_classes, lr, dropout=0.0, class_weights=None):
        super().__init__()
        self.save_hyperparameters()             # -> logs all hparams automatically
        self.model     = MLP(in_dim, hidden_dims, num_classes, dropout)
        if class_weights is not None:
            # keep on model so it moves with .to(device)
            self.register_buffer("class_weights", class_weights)
            self.criterion = nn.CrossEntropyLoss(weight=self.class_weights)
        else:
            self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    # — training / validation / test steps —
    def _shared_step(self, batch):
        x, y   = batch
        logits = self(x)
        loss   = self.criterion(logits, y)
        acc    = (logits.argmax(1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, _):
        loss, acc = self._shared_step(batch)
        # log *epoch*-level metrics so the CSV has train_loss_epoch / train_acc_epoch
        self.log("train_loss_epoch", loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log("train_acc_epoch",  acc,  on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def validation_step(self, batch, _):
        loss, acc = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_acc",  acc,  prog_bar=True, sync_dist=True)

    def test_step(self, batch, _):
        loss, acc = self._shared_step(batch)
        self.log_dict({"test_loss": loss, "test_acc": acc}, prog_bar=False)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


class ScenarioDataModule(pl.LightningDataModule):
    def __init__(self, train_ds, val_ds, test_ds, batch_size, scenario,
                 num_workers=8, val_batch=128, scaler: Optional[StandardScaler] = None):
        super().__init__()
        self.train_ds, self.val_ds, self.test_ds = train_ds, val_ds, test_ds
        self.batch_size, self.val_batch = batch_size, val_batch
        self.scenario, self.num_workers = scenario, num_workers
        self.scaler = scaler

    def train_dataloader(self):
        return DataLoader(
            ScenarioDataset(self.train_ds, self.scenario, scaler=self.scaler),
            batch_size=self.batch_size, shuffle=True,
            num_workers=self.num_workers, pin_memory=True, persistent_workers=True
        )

    def val_dataloader(self):
        return DataLoader(
            ScenarioDataset(self.val_ds, self.scenario, scaler=self.scaler),
            batch_size=self.val_batch, shuffle=False,
            num_workers=self.num_workers//2 or 2, pin_memory=True, persistent_workers=True
        )

    def test_dataloader(self):
        return DataLoader(
            ScenarioDataset(self.test_ds, self.scenario, scaler=self.scaler),
            batch_size=self.val_batch, shuffle=False,
            num_workers=self.num_workers//2 or 2, pin_memory=True, persistent_workers=True
        )


def fit_standardizer(train_hf_ds, scenario: str, batch_size: int = 8192, max_samples: Optional[int] = None) -> StandardScaler:
    """
    Fit a StandardScaler on training features only (for the chosen scenario).
    Iterates in batches to avoid loading everything into memory.
    """
    scaler = StandardScaler()
    loader = DataLoader(
        ScenarioDataset(train_hf_ds, scenario),  # no scaler during fitting
        batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True
    )

    n_seen = 0
    for xb, _ in tqdm(loader, desc="Fitting StandardScaler (partial_fit)"):
        x = xb.numpy()
        if max_samples is not None:
            remain = max_samples - n_seen
            if remain <= 0:
                break
            x = x[:remain]
        scaler.partial_fit(x)
        n_seen += len(x)

    return scaler


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--scenario", type=str, default="avg")
    parser.add_argument("--hidden_layers", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--patience", type=int, default=3)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--out_dir", type=str, default="sweep_results")
    parser.add_argument("--input_embs", type=str, default="/xxx/csdr1_raw4_catflags_filtered_embs_hand_crafted_trn_val_tst_bandgr")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--hand_crafted", type=int, default=0)
    parser.add_argument("--standardize", type=int, default=1, help="Use StandardScaler (fit on train only)")
    parser.add_argument("--std_max_samples", type=int, default=0, help="Cap samples when fitting scaler (0=all)")
    args = parser.parse_args()

    # print all args' values
    print(f"scenario: {args.scenario}")
    print(f"hidden_layers: {args.hidden_layers}")
    print(f"batch_size: {args.batch_size}")
    print(f"lr: {args.lr}")
    print(f"dropout: {args.dropout}")
    print(f"patience: {args.patience}")
    print(f"epochs: {args.epochs}")
    print(f"seed: {args.seed}")
    print(f"input_embs: {args.input_embs}")
    print(f"out_dir: {args.out_dir}")
    print(f"standardize: {args.standardize}")

    # ---------- Globals & reproducibility ----------
    SEED = args.seed
    pl.seed_everything(SEED, workers=True)
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

    DATASET_PATH = args.input_embs
    ds           = remove_outlier(load_from_disk(DATASET_PATH), args.hand_crafted)
    train_ds, val_ds, test_ds = ds["train"], ds["validation"], ds["test"]

    # label remap
    orig_labels    = sorted(set(train_ds["class_str"]), key=lambda s: int(s))
    label2idx      = {lab: i for i, lab in enumerate(orig_labels)}
    class_name_map = {
        "1":  "EW",
        "2":  "EA",
        "4":  "RRab",
        "5":  "RRc",
        "6":  "RRd",
        "8":  "RS CVn",
        "13": "LPV"
    }
    def add_label(example, mapping):
        try:
            result = {"label_idx": mapping[example["class_str"]]}
        except:
            print(f"example['class_str']: {example['class_str']}, {type(example['class_str'])};")
            print(f"mapping: {mapping}")
            raise
        return result
    text_labels = [ class_name_map[c] for c in orig_labels ]
    print(f"label2idx: {label2idx}")
    train_ds      = train_ds.map(partial(add_label, mapping=label2idx), num_proc=4)
    val_ds        = val_ds.map  (partial(add_label, mapping=label2idx), num_proc=2)
    test_ds       = test_ds.map (partial(add_label, mapping=label2idx), num_proc=2)
    num_classes   = len(orig_labels)

    print("Train class distribution →", Counter(train_ds["label_idx"]))

    # Inverse-frequency class weights (normalized to mean=1)
    cls_counts  = np.bincount(train_ds["label_idx"], minlength=num_classes)
    inv_freq    = 1.0 / cls_counts
    weights     = torch.tensor(inv_freq / inv_freq.mean(), dtype=torch.float32)

    # ---------- Sweep definition ----------
    SCENARIO, PATIENCE, EPOCHS = args.scenario, args.patience, args.epochs
    if args.hidden_layers == 3:
        HIDDEN_DIMS = [1024, 512, 256]
    elif args.hidden_layers == 2:
        HIDDEN_DIMS = [512, 256]
    else:
        raise ValueError(f"Invalid number of hidden layers: {args.hidden_layers}")

    base_name = pathlib.Path(args.input_embs).stem
    experiments = [
        {"name": f"{base_name}_bs{args.batch_size}_lr{args.lr}_do{args.dropout}_{args.scenario}_s{args.seed}",
         "batch_size": args.batch_size, "lr": args.lr, "dropout": args.dropout},
    ]

    out_dir = args.out_dir
    os.makedirs(out_dir, exist_ok=True)

    for exp in experiments:
        name = exp["name"]
        print(f"\n=== Experiment: {name} ===")

        first_x, _ = ScenarioDataset(train_ds, SCENARIO)[0]

        scaler = None
        if args.standardize:
            max_samples = args.std_max_samples if args.std_max_samples > 0 else None
            scaler = fit_standardizer(train_ds, SCENARIO, batch_size=8192, max_samples=max_samples)

        dm = ScenarioDataModule(
            train_ds, val_ds, test_ds,
            batch_size=exp["batch_size"], scenario=SCENARIO,
            num_workers=8, val_batch=128, scaler=scaler
        )

        lit_model = LitMLP(
            in_dim       = first_x.numel(),
            hidden_dims  = HIDDEN_DIMS,
            num_classes  = num_classes,
            lr           = exp["lr"],
            dropout      = exp["dropout"],
            class_weights= weights
        )

        logger   = CSVLogger(out_dir, name=name, flush_logs_every_n_steps=50)
        ckpt_cb  = ModelCheckpoint(monitor="val_loss", mode="min",
                                   filename="{epoch}-{val_loss:.4f}", save_weights_only=True)
        early_cb = EarlyStopping(monitor="val_loss", mode="min", patience=PATIENCE)

        trainer  = pl.Trainer(
            max_epochs=EPOCHS,
            callbacks=[ckpt_cb, early_cb],
            logger=logger,
            deterministic=True,            # full reproducibility
            log_every_n_steps=50,
            accelerator="auto", devices="auto",
        )

        if args.standardize and trainer.is_global_zero:
            os.makedirs(logger.log_dir, exist_ok=True)
            with open(os.path.join(logger.log_dir, "standard_scaler.pkl"), "wb") as f:
                pickle.dump(scaler, f)

        trainer.fit(lit_model, dm)

        best_ckpt = ckpt_cb.best_model_path
        single_gpu_tester = pl.Trainer(accelerator="gpu", devices=1, logger=False)
        single_gpu_tester.test(lit_model, datamodule=dm, ckpt_path=best_ckpt)

        # ---------- post-processing metrics/plots ----------
        log_csv = os.path.join(logger.log_dir, "metrics.csv")
        if trainer.is_global_zero:
            df = pd.read_csv(log_csv)

            loss_curves = df.pivot_table(index="epoch", values=["train_loss_epoch","val_loss"])
            acc_curves  = df.pivot_table(index="epoch", values=["train_acc_epoch","val_acc"])

            plt.figure(figsize=(10,4)); plt.plot(loss_curves); plt.title(f"Loss ({name})")
            plt.xlabel("epoch"); plt.ylabel("loss"); plt.legend(loss_curves.columns); plt.tight_layout()
            plt.savefig(os.path.join(logger.log_dir, "loss.png")); plt.close()

            plt.figure(figsize=(10,4)); plt.plot(acc_curves); plt.title(f"Accuracy ({name})")
            plt.xlabel("epoch"); plt.ylabel("accuracy"); plt.legend(acc_curves.columns); plt.tight_layout()
            plt.savefig(os.path.join(logger.log_dir, "acc.png")); plt.close()

            lit_model.eval()
            preds, labels = [], []
            with torch.no_grad():
                for x, y in dm.test_dataloader():
                    logits = lit_model(x.to(lit_model.device))
                    preds.extend(logits.argmax(1).cpu().numpy()); labels.extend(y.numpy())

            pd.DataFrame(classification_report(labels, preds, target_names=text_labels,
                                               output_dict=True)).T.to_csv(
                os.path.join(logger.log_dir, "metrics_report.csv"))

            cm  = confusion_matrix(labels, preds, labels=list(range(num_classes)))
            cmn = cm.astype(float) / cm.sum(1, keepdims=True)
            with open(os.path.join(logger.log_dir, "confusion_data.pkl"), "wb") as f:
                pickle.dump(cmn, f)
            plt.figure(figsize=(8,6))
            sns.heatmap(cmn, annot=True, fmt=".3f", cmap="viridis",
                        xticklabels=text_labels, yticklabels=text_labels)
            plt.xlabel("Pred"); plt.ylabel("True")
            plt.xticks(rotation=45, ha="right"); plt.yticks(rotation=0)
            plt.tight_layout()
            plt.savefig(os.path.join(logger.log_dir, "confusion.pdf"))
            plt.close()

    print("\nSweep complete. See results in", out_dir)
