"""
Grid-Search variant of train-pmap.py
"""

import datetime
import os

import numpy as np
import ray
import ray.train
import ray.tune
import torch
import upath
from dotenv import load_dotenv
from torch.utils.data import DataLoader

from generative_prediction_sets import utils
from generative_prediction_sets.data_loaders import (
  HiddenStateDataset,
  load_splits,
)
from generative_prediction_sets.predictors import MLP

# Pull variables from a .env file into the environment before the lookup below.
load_dotenv()
# Echo the configured value at import time; the subscript raises KeyError when
# GOOGLE_APPLICATION_CREDENTIALS is not set.
print(os.environ["GOOGLE_APPLICATION_CREDENTIALS"])


def get_timestamp():
  """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
  now = datetime.datetime.now()
  return f"{now:%Y%m%d_%H%M%S}"


# Loss classes selectable through ``config["loss"]``; train_model looks the
# class up by name and instantiates it without arguments.
LOSS_FNS = {
  "mse": torch.nn.MSELoss,
  "bce": torch.nn.BCELoss,
}


def train_model(train_ds, test_ds, config):
  """Train an MLP on ``train_ds`` and evaluate it on ``test_ds``.

  Args:
    train_ds: Dataset with ``hidden_states`` and ``p_map`` columns; it must
      support ``with_format("torch")`` and integer indexing.
    test_ds: Held-out dataset with the same columns.
    config: Mapping providing ``seed``, ``batch_size``, ``hidden_dims``,
      ``num_layers``, ``lr``, ``loss`` (a key of ``LOSS_FNS``) and
      ``num_epochs``.

  Returns:
    A ``(model, metrics)`` tuple. ``model`` is the trained MLP switched to eval
    mode; ``metrics`` maps ``rmse``, ``mae``, ``r2`` and ``explained_variance``
    to plain Python floats computed on ``test_ds``.
  """
  # Imported up front so a missing dependency fails before training starts.
  import sklearn.metrics

  utils.seed_everything(config["seed"])
  dataloader = DataLoader(
    train_ds.with_format("torch"), batch_size=config["batch_size"], shuffle=True
  )
  model = MLP(
    input_dims=len(train_ds[0]["hidden_states"]),
    hidden_dims=config["hidden_dims"],
    num_layers=config["num_layers"],
  )
  optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
  criterion = LOSS_FNS[config["loss"]]()

  for _ in range(config["num_epochs"]):
    for batch in dataloader:
      optimizer.zero_grad()
      # squeeze(-1) rather than squeeze(): a bare squeeze also drops the batch
      # axis when the last batch holds a single example.
      # Assumes the MLP emits one value per example (trailing dim of 1) — TODO confirm.
      y_pred = model(batch["hidden_states"]).squeeze(-1)
      loss = criterion(y_pred, batch["p_map"])
      loss.backward()
      optimizer.step()

  model = model.eval()
  test_results = []
  with torch.inference_mode():
    for batch in DataLoader(test_ds.with_format("torch"), batch_size=32):
      y_pred = model(batch["hidden_states"]).squeeze(-1)
      # np.concatenate rejects 0-d arrays, so keep every chunk at least 1-d.
      test_results.append(np.atleast_1d(y_pred.detach().numpy()))

  y_pred = np.concatenate(test_results)
  y_test = test_ds["p_map"]
  # mean_squared_error returns the *squared* error; take the root so the value
  # stored under "rmse" really is an RMSE.
  mse = sklearn.metrics.mean_squared_error(y_test, y_pred)
  metrics = {
    "rmse": float(np.sqrt(mse)),
    "mae": float(sklearn.metrics.mean_absolute_error(y_test, y_pred)),
    "r2": float(sklearn.metrics.r2_score(y_test, y_pred)),
    "explained_variance": float(
      sklearn.metrics.explained_variance_score(y_test, y_pred)
    ),
  }

  return model, metrics


def load_dataset(experiment_dir, split_indices=None):
  """Load the hidden-state dataset, optionally narrowed to a split.

  Only the ``hidden_states``, ``p_map`` and ``dataset_idx`` columns are kept.
  When ``split_indices`` is not ``None`` the result is passed through
  ``get_dataset_split`` with those indices; otherwise the whole dataset is
  returned.
  """
  from generative_prediction_sets.data_loaders import get_dataset_split

  kept_columns = ["hidden_states", "p_map", "dataset_idx"]
  loaded = HiddenStateDataset(experiment_dir=experiment_dir).load()
  trimmed = loaded.select_columns(kept_columns)
  if split_indices is None:
    return trimmed
  return get_dataset_split(trimmed, split_indices)


def train_with_cfg(config):
  """Load the data named in ``config``, hold out a test part, and train on it.

  Args:
    config: Mapping with ``dataset_path``, ``split_indices``, ``test_size`` and
      ``seed``, plus every key ``train_model`` reads.

  Returns:
    ``(dataset, model, metrics)``: the loaded (unsplit) dataset, the trained
    model, and the metrics computed on the held-out part.
  """
  dataset = load_dataset(config["dataset_path"], config["split_indices"])

  # Seed the split explicitly: it runs before train_model seeds anything, so an
  # unseeded split could differ between trials and make their metrics
  # incomparable.
  # Assumes a Hugging Face ``datasets.Dataset``, whose train_test_split accepts
  # ``seed`` — TODO confirm.
  train_test = dataset.train_test_split(
    test_size=config["test_size"], seed=config["seed"]
  )

  model, metrics = train_model(train_test["train"], train_test["test"], config)
  return dataset, model, metrics


def trainer(config):
  """Ray Tune trainable: run one trial for ``config`` and report its metrics."""
  *_, trial_metrics = train_with_cfg(config)
  ray.train.report(trial_metrics)


def _resolve_splits_path(experiment_dir: upath.UPath, splits=None) -> upath.UPath:
  """Return the path of the splits file to use for ``experiment_dir``.

  An explicit ``splits`` value wins. Otherwise the task name is read from
  ``<experiment_dir>/config.yaml`` and ``<task>.npz`` is looked up in the
  ``splits`` directory two levels above ``experiment_dir``.

  Raises:
    FileNotFoundError: If the experiment config or the derived splits file does
      not exist.
  """
  from loguru import logger

  if splits:
    return upath.UPath(splits)

  splits_dir = experiment_dir.parent.parent / "splits"
  cfg_path = experiment_dir / "config.yaml"
  if not cfg_path.exists():
    logger.error("Config file not found: {}", cfg_path.as_posix())
    raise FileNotFoundError(f"Config file not found: {cfg_path.as_posix()}")
  task_name = utils.load_yaml(cfg_path)["task"]
  logger.info("Trying to find splits dir: {}", splits_dir.as_posix())
  splits_path = splits_dir / f"{task_name}.npz"
  if not splits_path.exists():
    # Report the file that was actually checked, not just its parent directory.
    logger.error("Splits file not found: {}", splits_path.as_posix())
    raise FileNotFoundError(f"Splits file not found: {splits_path.as_posix()}")
  return splits_path


def main(output_dir: upath.UPath, experiment_dir: upath.UPath, splits=None):
  """Grid-search MLP hyperparameters, then retrain and save the best model.

  The split indices are located and loaded, a Ray Tune grid search is run on
  the train split (each trial holds out ``test_size`` of it), and the config
  with the lowest ``mae`` is retrained on the whole train split and evaluated
  on the test split. ``model.ckpt``, ``metrics.json`` and ``best_config.yaml``
  are written under ``output_dir / experiment_dir.name``.

  Args:
    output_dir: Root directory for Ray results and the final artifacts.
    experiment_dir: Directory holding the hidden-state dataset and, when
      ``splits`` is not given, a ``config.yaml`` with a ``task`` entry.
    splits: Optional path to the splits file; when falsy it is derived from
      ``experiment_dir``.

  Raises:
    FileNotFoundError: If the experiment config or the splits file is missing.
  """
  # Imported before any work so a missing package cannot surface only after the
  # grid search has finished or an output file has already been opened.
  import json

  import yaml
  from loguru import logger

  splits_path = _resolve_splits_path(experiment_dir, splits)

  try:
    logger.info("Loading splits from: {}", splits_path.as_posix())
    split_indices = load_splits(splits_path)
  except Exception as e:
    logger.error(f"Failed to load splits: {e}")
    raise

  dataset_path = experiment_dir.resolve().as_posix()
  logger.info("Caching dataset: {}", dataset_path)
  # Result intentionally unused; presumably this warms a cache for the trials
  # (per the log message above) — TODO confirm.
  HiddenStateDataset(experiment_dir=dataset_path).load()

  configs = {
    "num_epochs": ray.tune.grid_search([15, 20, 25]),
    "batch_size": 32,  # ray.tune.grid_search([32, 64, 128]),
    "lr": ray.tune.grid_search([1e-3]),
    "hidden_dims": ray.tune.grid_search([256, 512, 1024]),
    "num_layers": ray.tune.grid_search([1, 2, 4]),
    "dataset_path": dataset_path,
    "split_indices": split_indices.train,
    "loss": ray.tune.grid_search(["mse"]),
    "seed": 0,
    "test_size": 0.2,
  }

  experiment_output_dir = output_dir / experiment_dir.name
  ray_output_dir = experiment_output_dir / "grid-search"
  tuner = ray.tune.Tuner(
    trainer,
    param_space=configs,
    run_config=ray.train.RunConfig(
      name=experiment_dir.name,
      storage_path=ray_output_dir.as_posix(),
      # storage_filesystem=output_dir.fs,
    ),
  )

  logger.info(f"Starting experiment {experiment_dir.name}")
  results = tuner.fit()
  logger.info(f"Saved outputs to: {experiment_output_dir}")

  # ---------------------------------------------------------------------------- #
  #                         Train model with best config                         #
  # ---------------------------------------------------------------------------- #
  best_result = results.get_best_result(metric="mae", mode="min")
  best_cfg = best_result.config
  if best_cfg is None:
    logger.error("No best config found")
    return

  print(f"Best config: {best_cfg}")
  logger.info("Training with best config")

  # The final model uses the full train split and is scored on the test split,
  # which the grid-search trials never saw.
  train_ds = load_dataset(best_cfg["dataset_path"], split_indices.train)
  test_ds = load_dataset(best_cfg["dataset_path"], split_indices.test)

  assert len(train_ds) > 0, "Training dataset is empty"
  assert len(test_ds) > 0, "Test dataset is empty"
  model, metrics = train_model(train_ds, test_ds, best_cfg)

  print("Final Trained Model Metrics:")
  print(metrics)

  model_path = (experiment_output_dir / "model.ckpt").as_posix()
  logger.info(f"Saving model to {model_path}")
  model.save_checkpoint(model_path)

  with (experiment_output_dir / "metrics.json").open("w") as f:
    json.dump(metrics, f)

  # NOTE(review): best_cfg still carries "split_indices"; if that is a large
  # array the YAML dump will be bulky — confirm this is intended.
  with (experiment_output_dir / "best_config.yaml").open("w") as f:
    yaml.dump(best_cfg, f)


if __name__ == "__main__":
  # Script entry point: parse the CLI flags and hand resolved paths to main().
  import argparse

  cli = argparse.ArgumentParser()
  cli.add_argument("--output-dir", type=str, default="./train-runs")
  cli.add_argument("--dataset-path", type=str, required=True)
  cli.add_argument("--splits", type=str, default=None)
  known, extras = cli.parse_known_args()
  if extras:
    # Unrecognised flags are reported and otherwise ignored.
    print(f"Ignoring unknown args: {extras}")
    print("\n")

  resolved_output_dir = upath.UPath(known.output_dir).resolve()
  resolved_experiment_dir = upath.UPath(known.dataset_path).resolve()
  main(
    output_dir=resolved_output_dir,
    experiment_dir=resolved_experiment_dir,
    splits=known.splits,
  )
