from common_imports import *
from common_use_functions import contents_of_folder, path_join, load_json, join_string, str_first_part_split_from_r
from constant import np_file_extension, model_config_file_keyword, model_state_file_keyword

"""
All the following functions are customized versions for network training preparation (mainly for the classification tasks performed
with CNN and Inception models); simpler versions for general training preparation
(e.g. getting the optimizer or the loss function) are provided separately in "pytorch_training_preparation".

Note: This module is generally paired with the module "experim_neural_network" (the customized functions for pytorch training experiment).
"""

"""
Function for the dataset loading
"""
def load_dataset(dataset_path, valid_set_name, display=True):
    """
    Load every numpy file found in a folder and report whether a validation set is among them.

    dataset_path : path of the folder containing the dataset files.
    valid_set_name : substring identifying the validation-set file name.
    display : if True, print the loaded dictionary, the shape of each array, and whether
              a validation set was found.

    Returns (loaded_datasets_dict, is_valid_set):
    - loaded_datasets_dict maps each file name (without its final extension) to the array
      returned by np.load;
    - is_valid_set is True when at least one *loaded* file name contains valid_set_name.
    """
    # Existence of the validation set
    is_valid_set = False

    # Load data sets
    loaded_datasets_dict = {}
    for file in contents_of_folder(dataset_path):
        # Match the extension as a suffix: a plain substring test would also accept
        # names such as "x.npy.bak", which np.load cannot read.
        if not file.endswith(np_file_extension):
            continue
        # Strip only the final extension so dotted names ("train.v1.npy", "train.v2.npy")
        # keep distinct keys instead of silently overwriting each other.
        current_file_name = file.rsplit('.', 1)[0]
        loaded_datasets_dict[current_file_name] = np.load(path_join(dataset_path, file))
        # Only a file that was actually loaded can count as the validation set.
        if valid_set_name in file:
            is_valid_set = True
    if display:
        print(loaded_datasets_dict)
        for key in loaded_datasets_dict:
            print(key,':',loaded_datasets_dict[key].shape)
        print('Existence of the validation set :', str(is_valid_set))
    return loaded_datasets_dict, is_valid_set

"""
Pytorch Data Loader Generation
"""
def create_dataloader(X, y, batch_size, shuffle=False, type_conversion=True):
    """
    Wrap numpy inputs and labels in a pytorch DataLoader (simple version).

    X, y : numpy arrays holding the inputs and the labels.
    batch_size : number of samples per batch.
    shuffle : whether the DataLoader reshuffles the data every epoch.
    type_conversion : if True, inputs are cast to float32 and labels to int64
                      (only suitable for classification); otherwise the numpy
                      dtypes are kept as they are.
    """
    features = torch.from_numpy(X)
    targets = torch.from_numpy(y)
    if type_conversion:
        # float inputs / long class indices, as expected by classification losses
        features, targets = features.float(), targets.long()
    paired = TensorDataset(features, targets)
    return DataLoader(paired, batch_size=batch_size, shuffle=shuffle)

def create_dataloader_different_tasks(X, y, batch_size, shuffle=False, type_conversion='classification'):
    """
    Create a pytorch DataLoader from numpy arrays, with a task-dependent type conversion.

    X : numpy array of input features.
    y : numpy array of labels / targets.
    batch_size : number of samples per batch.
    shuffle : whether the data is reshuffled every epoch.
    type_conversion : 'classification' -> inputs float32, labels int64;
                      'regression'     -> inputs float32, labels float32;
                      any other value  -> no conversion (numpy dtypes are kept).
    """
    features = torch.from_numpy(X)
    targets = torch.from_numpy(y)
    if type_conversion in ('classification', 'regression'):
        features = features.float()
        # class indices must be integers; regression targets stay floating point
        targets = targets.long() if type_conversion == 'classification' else targets.float()
    loader = DataLoader(TensorDataset(features, targets), batch_size=batch_size, shuffle=shuffle)
    return loader

"""
Functions to load the pytorch models
"""
def load_model_by_config(config_folder_path, data_set_infos, needed_configs, model_class, map_location=None):
    """
    Load a pytorch model from the configuration and weights stored in a folder.

    The file holding the state (weights) must contain model_state_file_keyword in its
    name, and the file holding the configuration must contain model_config_file_keyword.

    The model class is instantiated with positional arguments in this order: the values
    of data_set_infos, then the configuration values selected by needed_configs.

    data_set_infos : dictionary of general information about the dataset (e.g. depth of
        the input image, image size, number of outputs/classes). Its values are passed in
        the dictionary's iteration order, so insert them in the order the model class expects.
    needed_configs : names of the entries to read from the configuration file, in the
        order the model class expects.
    model_class : callable (usually the model class) returning an object with load_state_dict.
    map_location : forwarded to torch.load, e.g. 'cpu' to load GPU-saved weights on a
        CPU-only machine. None keeps torch.load's default behavior.

    Returns the instantiated model with its state dictionary loaded.
    Raises FileNotFoundError if the folder has no state file or no configuration file.
    """
    state_path = None
    config_path = None
    for file in contents_of_folder(config_folder_path):
        current_file_path = path_join(config_folder_path, file)
        # A name containing both keywords is treated as the state file; if several
        # files match the same keyword, the last one listed wins.
        if model_state_file_keyword in file:
            state_path = current_file_path
        elif model_config_file_keyword in file:
            config_path = current_file_path
    if state_path is None or config_path is None:
        raise FileNotFoundError(
            "Expected a state file (name containing '{}') and a configuration file "
            "(name containing '{}') in {}".format(
                model_state_file_keyword, model_config_file_keyword, config_folder_path
            )
        )
    # Only the JSON read needs the open file; the model is built after it is closed.
    with open(config_path, 'r') as fp:
        config_data = load_json(fp)
    all_model_init_params = list(data_set_infos.values())
    all_model_init_params.extend(config_data[needed_config] for needed_config in needed_configs)
    loaded_model = model_class(*all_model_init_params)
    loaded_model.load_state_dict(torch.load(state_path, map_location=map_location))
    return loaded_model

def load_models(all_models_folder_path, model_type, data_set_infos, needed_configs, model_class, direct_name=False):
    """
    Load every model stored in the sub-folders of all_models_folder_path (version for
    models saved as state dictionary + configuration file, one sub-folder per model).

    Each sub-folder is loaded with load_model_by_config using data_set_infos,
    needed_configs and model_class.

    direct_name : if True, the sub-folder name is used as-is as the model's key
                  (e.g. 'CNN_0'); otherwise the key is join_string([model_type, folder_name])
                  (e.g. folder '0' with model_type 'CNN').

    Returns a dictionary mapping each model key to its loaded model.
    """
    models_by_name = {}
    for folder_name in contents_of_folder(all_models_folder_path):
        folder_path = path_join(all_models_folder_path, folder_name)
        model = load_model_by_config(folder_path, data_set_infos, needed_configs, model_class)
        key = folder_name if direct_name else join_string([model_type, folder_name])
        models_by_name[key] = model
    return models_by_name

def load_models_pt(all_models_folder_path, map_location=None):
    """
    Load every model stored in the provided folder (this version is used when whole
    models, not only their state dictionaries, were saved to .pt files).

    In this version, all_models_folder_path should contain:
    - all_models_folder_path
        - Unique_id_model.pt
        ...
        - Unique_id_model.pt

    For example:
    - all_models_folder_path
        - CNN_0.pt
        ...
        - Inception_4.pt

    map_location : forwarded to torch.load, e.g. 'cpu' to load models saved on a GPU on
                   a machine without CUDA. None keeps torch.load's default behavior.

    Returns a dictionary mapping each model name, derived from the file name by
    str_first_part_split_from_r (e.g. 'CNN_0' for 'CNN_0.pt'), to the loaded model.
    """
    loaded_models = {}
    for model_file_name in contents_of_folder(all_models_folder_path):
        model_file_path = path_join(all_models_folder_path, model_file_name)
        # NOTE(review): torch.load unpickles arbitrary objects, so only use it on trusted
        # files. Recent torch releases default to weights_only=True, which rejects whole
        # pickled models. Confirm the behavior of the installed version.
        current_loaded_model = torch.load(model_file_path, map_location=map_location)
        model_code_name = str_first_part_split_from_r(model_file_name)
        loaded_models[model_code_name] = current_loaded_model
    return loaded_models
