import os
import json
import re

from pathlib import Path
from tqdm.autonotebook import tqdm

import torch

from shrp.models.def_net import CNN, CNN3
from shrp.models.def_AE import AE

def load_checkpoint(model_dir, epoch=25):
    """Loads a model from a checkpoint.

    Reads ``params.json`` from ``model_dir``, instantiates the architecture
    named by its ``model::type`` entry and restores the weights stored in
    ``checkpoint_<epoch zero-padded to 6 digits>/checkpoints``.

    Args:
        model_dir (str): Path to the model directory.
        epoch (int): Epoch to load.

    Returns:
        model (torch.nn.Module): The model.
        params (dict): Model hyperparameters.

    Raises:
        ValueError: If ``model::type`` is neither ``'CNN'`` nor ``'CNN3'``.
    """
    with open(os.path.join(model_dir, 'params.json')) as ifh:
        params = json.load(ifh)

    # Resolve the architecture first so an unsupported type fails before any
    # checkpoint file is touched.
    model_type = params['model::type']
    if model_type == 'CNN':
        model_cls = CNN
    elif model_type == 'CNN3':
        model_cls = CNN3
    else:
        raise ValueError('Unknown model')

    # Both architectures are built from the same three hyperparameters.
    model = model_cls(
        channels_in=params['model::channels_in'],
        nlin=params['model::nlin'],
        dropout=params['model::dropout'],
    )

    path = os.path.join(model_dir, f'checkpoint_{epoch:06d}/checkpoints')
    # map_location='cpu' lets checkpoints saved from a GPU be deserialized on
    # hosts without CUDA; load_state_dict copies the values into the model's
    # own parameters, so the resulting model is unaffected.
    model.load_state_dict(torch.load(path, map_location='cpu'))
    return model, params

def load_zoo(zoo_path, epoch=25, filter_models=None, verbose=False):
    """
    Load multiple models from a zoo directory.

    Every entry of ``zoo_path`` whose name contains ``'NN_tune_trainable'`` is
    treated as a model directory and loaded with ``load_checkpoint``.

    Args:
        zoo_path (str): The path to the zoo directory containing the model checkpoints.
        epoch (int, optional): The epoch number to load the models from. Defaults to 25.
        filter_models (function, optional): A function that takes a model and its hyperparameters (torch.nn.Module, dict) and returns a boolean, when False the model is filtered out. Defaults to None, in which case all models are kept.
        verbose (bool, optional): Whether to display a progress bar. Defaults to False.

    Returns:
        list: A list of tuples (torch.nn.Module, dict) containing the model and its hyperparameters.

    """
    model_paths = [
        os.path.join(zoo_path, entry)
        for entry in os.listdir(zoo_path)
        if 'NN_tune_trainable' in entry
    ]
    models = [
        load_checkpoint(model_path, epoch)
        for model_path in tqdm(model_paths, unit='model', disable=not verbose)
    ]

    # Test the argument itself: the builtin `filter` is never None, so checking
    # it would call a None predicate whenever no filter was supplied.
    if filter_models is not None:
        models = [entry for entry in models if filter_models(*entry)]

    return models

def load_hyperrep(ae_path, latent_dim, epoch=150):
    """
    Loads a trained hyper-representation autoencoder model with the specified latent dimension.

    This function scans the entries of ``ae_path`` whose name contains
    ``'AE_trainable'``, reads each one's ``params.json`` and loads the first
    model whose ``'ae:lat_dim'`` entry equals ``latent_dim``. The returned
    model is put in eval mode and carries its hyperparameters in ``.config``.

    Args:
        ae_path (str): The path to the directory containing the trained autoencoder models.
        latent_dim (int): The latent dimension of the autoencoder model to load.
        epoch (int, optional): The epoch number of the checkpoint to load. Defaults to 150.

    Returns:
        torch.nn.Module: The loaded autoencoder model with the specified latent dimension.

    Raises:
        ValueError: If a model with the specified latent dimension is not found in the given directory.
    """
    model_paths = [
        os.path.join(ae_path, entry)
        for entry in os.listdir(ae_path)
        if 'AE_trainable' in entry
    ]
    for model_path in model_paths:
        with open(os.path.join(model_path, 'params.json')) as ifh:
            params_dict = json.load(ifh)

        if params_dict['ae:lat_dim'] != latent_dim:
            continue

        ae = AE(params_dict)
        state_file = Path(model_path) / f'checkpoint_{epoch:06d}' / 'state.pt'
        # map_location='cpu' lets checkpoints saved from a GPU be deserialized
        # on hosts without CUDA; load_state_dict copies the values into the
        # model's own parameters.
        ae.load_state_dict(torch.load(state_file, map_location='cpu')['model'])
        ae.eval()
        ae.config = params_dict

        return ae

    raise ValueError('Hyper-representation model not found')


def load_hyperrep_edx(ae_path, epoch=150):
    """
    Given the path to a given auto-encoder, loads the checkpoint corresponding to the given epoch.

    The hyperparameters are read from ``params.json`` inside ``ae_path`` and
    the weights from the ``'model'`` entry of
    ``checkpoint_<epoch zero-padded to 6 digits>/state.pt``. The returned
    model is put in eval mode and carries its hyperparameters in ``.config``.

    Args:
        ae_path (str): The path to the directory of the hyperrepresentation model.
        epoch (int, optional): The epoch number to load the models from. Defaults to 150.

    Returns:
        torch.nn.Module: The loaded autoencoder model, restored from the requested epoch.
    """
    with open(os.path.join(ae_path, 'params.json')) as ifh:
        params_dict = json.load(ifh)

    ae = AE(params_dict)
    state_file = Path(ae_path) / 'checkpoint_{:06d}'.format(epoch) / 'state.pt'
    # map_location='cpu' lets checkpoints saved from a GPU be deserialized on
    # hosts without CUDA; load_state_dict copies the values into the model's
    # own parameters.
    ae.load_state_dict(torch.load(state_file, map_location='cpu')['model'])
    ae.eval()
    ae.config = params_dict

    return ae