"""
Train an admissibility estimator for gps.
"""

import dataclasses
import json
import typing as tp

import numpy as np
import simple_parsing.helpers
import sklearn.metrics
import torch
import upath
from loguru import logger
from simple_parsing import ArgumentParser
from simple_parsing.wrappers.field_wrapper import DashVariant
from torch.utils.data.dataloader import DataLoader
from upath import UPath

from generative_prediction_sets import utils
from generative_prediction_sets.data_loaders import (
  HiddenStateDataset,
  get_dataset_split,
  load_splits,
)
from generative_prediction_sets.predictors import MLP

# Loss name (as accepted by `--loss`) -> loss class; instantiated once per run.
# NOTE(review): BCELoss expects predictions in [0, 1]; presumably the MLP's
# output is already a probability when "bce" is chosen — confirm.
LOSS_FNS: dict[str, type[torch.nn.Module]] = dict(
  mse=torch.nn.MSELoss,
  bce=torch.nn.BCELoss,
)


@simple_parsing.helpers.serialization.encode.register
def encode_upath(upath: UPath) -> str:
  """Serialize a UPath as the POSIX string of its resolved form.

  Registered with simple_parsing's `encode` dispatcher so that UPath fields of
  a `Serializable` config are written out (e.g. by `dump_yaml`) as strings.
  """
  resolved_path = upath.resolve()
  return resolved_path.as_posix()


# Reverse direction: rebuild a UPath by calling `UPath(...)` on the stored value.
simple_parsing.helpers.serialization.register_decoding_fn(UPath, UPath)


@dataclasses.dataclass
class TrainerConfig(simple_parsing.helpers.Serializable):
  """Configuration for `trainer`, parsed from the command line by simple_parsing.

  When `output_dir` is set, the script's entry point dumps this config as
  `config.yaml` into that directory after training.
  """

  # Experiment directory of the HiddenStateDataset to train on (required).
  dataset_path: upath.UPath = dataclasses.field(metadata={"required": True})
  # Seed passed to utils.seed_everything before the data is loaded and split.
  seed: int = 42
  # Optional splits file; its `train` indices restrict the dataset before the
  # train/test split. When omitted, the full dataset is used.
  splits: tp.Optional[upath.UPath] = None
  # Mini-batch size used by the training DataLoader.
  batch_size: int = 32
  # Hidden size forwarded to MLP(hidden_dims=...).
  hidden_dims: int = 256
  # Layer count forwarded to MLP(num_layers=...).
  num_layers: int = 4
  # Learning rate of the Adam optimizer.
  lr: float = 0.001
  # Training loss; must be a key of LOSS_FNS ("mse" or "bce").
  loss: str = dataclasses.field(
    default="mse", metadata={"choices": list(LOSS_FNS.keys())}
  )
  # Number of passes over the training split.
  num_epochs: int = 15
  # Destination for model.ckpt, metrics.json and config.yaml. Must not exist
  # yet (training refuses to overwrite it); when omitted, nothing is saved.
  output_dir: upath.UPath | None = None


def trainer(
  seed,
  dataset_path: upath.UPath,
  batch_size,
  hidden_dims,
  num_layers,
  lr,
  loss,
  num_epochs,
  output_dir=None,
  split_indices=None,
  test_size=0.2,
):
  """Train an MLP that predicts `p_map` from `hidden_states`, then evaluate it.

  The dataset is loaded from `dataset_path`, optionally restricted to
  `split_indices`, and split into train/test sets. After training, RMSE, MAE,
  R^2 and explained variance on the test set are logged and, when `output_dir`
  is given, saved as `metrics.json` next to the `model.ckpt` checkpoint.

  Args:
    seed: Seed passed to `utils.seed_everything` before loading/splitting.
    dataset_path: Experiment directory of the `HiddenStateDataset`.
    batch_size: Batch size for both training and evaluation.
    hidden_dims: Hidden size forwarded to `MLP`.
    num_layers: Layer count forwarded to `MLP`.
    lr: Adam learning rate.
    loss: Name of the loss; must be a key of `LOSS_FNS`.
    num_epochs: Number of passes over the training split.
    output_dir: Optional directory for the checkpoint and metrics. It must not
      exist yet, to avoid accidental overwrites.
    split_indices: Optional indices forwarded to `get_dataset_split`.
    test_size: Fraction of samples held out for evaluation.

  Raises:
    ValueError: If `output_dir` already exists.
    KeyError: If `loss` is not a key of `LOSS_FNS`.
  """
  output_path = upath.UPath(output_dir) if output_dir is not None else None
  if output_path is not None and output_path.exists():
    logger.error(
      "Output directory already exists. To avoid accidental overwrites, you have to delete it manually.\n"
      "Use the following command:\n"
      f"gsutil -m rm -rf {output_path.as_posix()}"
    )
    raise ValueError(f"Output directory {output_dir} already exists")

  # Resolve the loss before the (potentially slow) dataset load so an unknown
  # name fails immediately.
  criterion = LOSS_FNS[loss]()

  utils.seed_everything(seed)
  dataset = HiddenStateDataset(experiment_dir=dataset_path.as_posix()).load()
  dataset = dataset.select_columns(
    ["hidden_states", "p_map", "dataset_idx", "sample_idx"]
  )
  logger.info("Loaded dataset with {} samples", len(dataset))
  if split_indices is not None:
    logger.info("Using provided splits")
    dataset = get_dataset_split(dataset, split_indices)
    logger.info("Dataset size after applying split: {}", len(dataset))

  train_test = dataset.train_test_split(test_size=test_size)
  train_ds = train_test["train"]
  test_ds = train_test["test"]
  train_loader = DataLoader(
    train_ds.with_format("torch"), batch_size=batch_size, shuffle=True
  )

  model = MLP(
    input_dims=len(train_ds[0]["hidden_states"]),
    hidden_dims=hidden_dims,
    num_layers=num_layers,
  )
  logger.info(f"Model: {model.config}")
  logger.info(f"Optimizer: Adam(lr={lr})")
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)

  for epoch in range(num_epochs):
    epoch_losses = []
    for batch in train_loader:
      optimizer.zero_grad()
      # reshape(-1), not squeeze(): squeeze() collapses a final batch of one
      # sample to a 0-d tensor, which BCELoss rejects (shape mismatch with the
      # 1-d target). Assumes the MLP emits a single value per sample.
      y_pred = model(batch["hidden_states"]).reshape(-1)
      batch_loss = criterion(y_pred, batch["p_map"].to(y_pred.dtype))
      batch_loss.backward()
      optimizer.step()
      epoch_losses.append(batch_loss.item())
    logger.info(f"Epoch {epoch} | Loss: {np.mean(epoch_losses)}")

  model = model.eval()
  predictions = []
  with torch.inference_mode():
    test_loader = DataLoader(test_ds.with_format("torch"), batch_size=batch_size)
    for batch in test_loader:
      # 1-d per batch so np.concatenate works even for a one-sample batch.
      predictions.append(model(batch["hidden_states"]).reshape(-1).numpy())

  y_pred = np.concatenate(predictions)
  y_test = test_ds["p_map"]
  scorers = {
    # mean_squared_error returns the MSE; take its square root so the value
    # matches the "rmse" key.
    "rmse": lambda t, p: np.sqrt(sklearn.metrics.mean_squared_error(t, p)),
    "mae": sklearn.metrics.mean_absolute_error,
    "r2": sklearn.metrics.r2_score,
    "explained_variance": sklearn.metrics.explained_variance_score,
  }
  metrics = {name: float(score(y_test, y_pred)) for name, score in scorers.items()}
  logger.info(json.dumps(metrics, indent=2))

  if output_path is not None:
    model.save_checkpoint((output_path / "model.ckpt").resolve().as_posix())
    with (output_path / "metrics.json").open("w") as f:
      json.dump(metrics, f)
    logger.info(f"Saved checkpoint to {output_dir}")


if __name__ == "__main__":
  # Build the CLI from TrainerConfig (dash-style option names).
  parser = ArgumentParser(
    description="Trainer for Hidden State Dataset",
    add_option_string_dash_variants=DashVariant.DASH,
  )
  parser.add_arguments(TrainerConfig, dest="config")
  config: TrainerConfig = parser.parse_args().config

  # Optionally restrict training to the `train` indices of a splits file.
  split_indices = None
  if config.splits is None:
    logger.warning("No splits provided. Using full dataset.")
  else:
    logger.info("Loading splits from: {}", config.splits)
    try:
      split_indices = load_splits(config.splits).train
    except Exception as e:
      logger.error(f"Failed to load splits: {e}")
      # Bare `raise` re-raises with the original traceback untouched.
      raise

  trainer(
    seed=config.seed,
    dataset_path=config.dataset_path,
    batch_size=config.batch_size,
    hidden_dims=config.hidden_dims,
    num_layers=config.num_layers,
    lr=config.lr,
    loss=config.loss,
    num_epochs=config.num_epochs,
    output_dir=config.output_dir,
    split_indices=split_indices,
  )

  # Written only after trainer() returns: trainer() rejects an output_dir
  # that already exists, so this must not be created beforehand.
  if config.output_dir is not None:
    with (config.output_dir / "config.yaml").open("w") as f:
      config.dump_yaml(f)
    logger.info(f"Saved `config.yaml` to {config.output_dir}")
