import logging
import math
import os
import re
from pathlib import Path

import torch
import wandb
import yaml

# Module-level logger (used by save_model_parameters below).
logger = logging.getLogger(__name__)

def load_and_harmonize_hparams(args):
    """Load hyperparameters from YAML and merge selected command-line args into them.

    The YAML file ``get_project_root() / args.hparams`` is read with
    ``yaml.safe_load``. The values below are then copied from ``args`` into
    the resulting dict, overwriting anything the YAML defined for them:
    ``optimizer`` ({'lr', 'wd'}), ``legendre_polys``, ``harmonics_calculation``,
    ``presence_only_loss`` (only if ``args`` has it), ``patience``,
    ``regression``, ``max_epochs`` and ``min_radius``. ``max_radius`` is set
    to 360 only if the YAML did not define it.

    Args:
        args (Namespace): The command-line arguments. ``args.locationencoder_args``
            must be a dict; it is mutated in place.

    Returns:
        Namespace: The same ``args`` object, with the harmonized hyperparameter
        dict stored in ``args.locationencoder_args["hparams"]``.

    Raises:
        TypeError: If the YAML file's top level is not a mapping.
    """
    hparams_path = get_project_root() / args.hparams
    with open(hparams_path, encoding="utf-8") as f:
        hparams = yaml.safe_load(f)

    # An empty YAML document loads as None; treat it as "no hparams defined".
    if hparams is None:
        hparams = {}
    elif not isinstance(hparams, dict):
        raise TypeError(
            f"expected a mapping at the top level of {hparams_path}, "
            f"got {type(hparams).__name__}"
        )

    # "harmonize" hparams and args: command-line values take precedence.
    hparams['optimizer'] = {
        'lr': args.lr,
        'wd': args.wd
    }
    hparams['legendre_polys'] = args.locationencoder_args["legendre_polys"]
    hparams['harmonics_calculation'] = args.locationencoder_args["harmonics_calculation"]
    # copy presence_only_loss attribute if it exists in args
    if hasattr(args, "presence_only_loss"):
        hparams["presence_only_loss"] = args.presence_only_loss
    hparams['patience'] = args.patience
    hparams['regression'] = args.regression
    hparams["max_epochs"] = args.max_epochs
    hparams["min_radius"] = args.min_radius

    hparams = set_default_if_unset(hparams, "max_radius", 360)

    args.locationencoder_args["hparams"] = hparams

    return args

def get_project_root():
    """Return the project root: one level above the directory holding this module.

    The result is absolute but not normalized; its last component is ``..``.
    """
    module_dir = Path(os.path.abspath(__file__)).parent
    return module_dir / ".."

def get_output_root():
    """Return ``output`` as a relative Path (resolved against the current working directory)."""
    output_dir = Path("output")
    return output_dir

def parse_resultsdir(args):
    """Build the results directory for a run and create it on disk.

    The path is ``get_output_root() / args.results_dir`` followed by one level
    per configuration value (time embedding type, combined encoding name,
    number of epochs, space degree, time degree, combination type, dataset,
    task mode, positional embedding type). ``args.expname`` is appended as a
    final level when it is not None.

    Args:
        args (Namespace): parsed arguments; reads ``results_dir``,
            ``max_epochs``, ``dataset``, ``expname``, ``locationencoder_args``
            and ``datamodule_args``.

    Returns:
        Path: the results directory, which exists after this call.
    """
    enc = args.locationencoder_args
    rsdir = (get_output_root() /
        args.results_dir /
        enc["time_embedding_type"] /
        enc["combined_encoding_args"]["name"] /
        # f'layers_{enc["num_layers"]}' /
        # f'hidden_neurons_{enc["dim_hidden"]}' /
        (str(args.max_epochs) + "_epochs") /
        f'space_degree_{enc["legendre_polys"]}' /
        f'time_degree_{enc["time_encoding_args"]["input_dim"]}' /
        f'combination_{enc["combination_type"]}' /
        args.dataset /
        f'task_{args.datamodule_args["mode"]}' /
        # f'train_fraction_{args.datamodule_args["train_fraction"]}' /
        # f'subset_fraction_{args.datamodule_args["subset_fraction"]}' /
        enc["positional_embedding_type"]
    )
    # expname is an optional extra leaf level below the configuration path.
    if args.expname is not None:
        rsdir = rsdir / args.expname

    os.makedirs(rsdir, exist_ok=True)
    return rsdir


def _parse_val_loss(filename):
    """Return the float after the last ``val_loss=`` in *filename*, or None if absent, unparseable or NaN."""
    if "val_loss=" not in filename:
        return None
    raw = filename.rsplit("val_loss=", 1)[1]
    # Drop the extension and an optional "-v<N>" version suffix (e.g. "2.50-v1.ckpt" -> "2.50").
    raw = re.sub(r"(-v\d+)?\.ckpt$", "", raw)
    try:
        loss = float(raw)
    except ValueError:
        return None
    # NaN would make the "lowest loss" comparison meaningless, so it is not a candidate.
    return None if math.isnan(loss) else loss


def find_best_checkpoint(directory, pattern, verbose=False):
    """Search *directory* for checkpoints matching *pattern* and return the one with lowest val_loss.

    Only files ending in ``.ckpt`` whose name contains *pattern* (e.g.
    ``sphericalharmonics-siren``) are considered. The loss is read from the
    text after the last ``val_loss=`` in the file name, e.g.
    ``sphericalharmonics-siren-val_lossval_loss=6.69.ckpt`` -> 6.69; a trailing
    ``-v<N>`` version suffix is ignored. Matching files without a parseable
    (non-NaN) loss, such as ``...-last.ckpt``, are skipped instead of raising.

    Args:
        directory: directory to search (not recursive).
        pattern: substring a checkpoint file name must contain.
        verbose: if True, print the candidates and the chosen checkpoint.

    Returns:
        The path ``os.path.join(directory, best_file)``, or None if no usable
        checkpoint exists. Ties on the loss are broken by file name.
    """
    candidates = sorted(
        c for c in os.listdir(directory) if c.endswith(".ckpt") and pattern in c
    )
    if verbose and candidates:
        print(f"resuming from checkpoints in results-dir. Found candidates {' '.join(candidates)}")

    scored = []
    for name in candidates:
        loss = _parse_val_loss(name)
        if loss is None:
            if verbose:
                print(f"skipping {name}: no parseable val_loss in file name")
            continue
        scored.append((loss, name))

    if not scored:
        if verbose:
            print("no suitable checkpoint found. returning None")
        return None

    # (loss, name) tuples: lowest loss wins, file name breaks ties.
    _, resume_checkpoint = min(scored)
    if verbose:
        print(f"taking: {resume_checkpoint}")

    return os.path.join(directory, resume_checkpoint)


def set_default_if_unset(hparams, key, value):
    """Store *value* under *key* only when *key* is absent; existing entries (even None) are kept.

    *hparams* is modified in place and also returned for convenient reassignment.
    """
    hparams.setdefault(key, value)
    return hparams


def count_parameters(model):
    """Return the total number of scalar elements in *model*'s parameters with ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def save_model_parameters(trainer, locationencoder, wandb_logger):
    """
    Cast the model to float32, save its state_dict and log it as a wandb artifact.

    The weights are written to ``<trainer.logger.save_dir>/model_weights_only.pt``
    and uploaded as an artifact named ``model-params`` of type ``model``.
    Note: ``locationencoder.float()`` converts the module in place, so the
    caller's model is float32 afterwards.

    Args:
        trainer (pl.Trainer): The PyTorch Lightning trainer; its logger's
            ``save_dir`` decides where the weights file is written.
        locationencoder (torch.nn.Module): The trained model.
        wandb_logger: Object whose ``experiment`` receives the artifact via
            ``log_artifact`` (presumably a Lightning ``WandbLogger`` — confirm
            at the call site).
    """
    locationencoder.float()
    weights_file = f"{trainer.logger.save_dir}/model_weights_only.pt"
    weights = locationencoder.state_dict()
    torch.save(weights, weights_file)

    params_artifact = wandb.Artifact(name="model-params", type="model")
    params_artifact.add_file(weights_file)
    run = wandb_logger.experiment
    run.log_artifact(params_artifact)

    logger.info("Model converted to float32 and saved successfully.")
