# -*- coding: utf-8 -*-

u"""
(c) Copyright 2023 Telefónica. All Rights Reserved.
The copyright to the software program(s) is property of Telefónica.
The program(s) may be used and or copied only with the express written consent of Telefónica or in accordance with the
terms and conditions stipulated in the agreement/contract under which the program(s) have been supplied."""

import json
import os

from collections import namedtuple
from typing import Dict, List, Tuple

from common.constants import CHECKPOINT_PATH, COMMON, CURRENT_MODEL_NAME, ML_UTILITY_OUT_PATH, \
    MASK_INCLUDE_TARGET, MASK_PARAMS,  MODEL_PARAMS, MODEL_TYPE, OPTIMIZER_PARAMS, SEED, \
    ML_UTILITY_TEST_PARAMS, TRANSFORMER, TRAIN_PARAMS

# Immutable record bundling a dataset's name, feature/target column descriptions and problem setup.
# Field order defines the positional constructor signature, so keep new fields at the end.
DatasetMetadata = namedtuple('DatasetMetadata', [
    'dataset_name',
    'categorical_features_idxs',
    'categorical_lengths',
    'categorical_weights',
    'continuous_features_idxs',
    'target_cols_idxs',
    'problem_type',
    'num_classes',
    'class_weights',
    'null_default_values',
])


def load_json(jsons: List[str]) -> List[Dict]:
    """Load json files

    Files are decoded explicitly as UTF-8 (the encoding required for JSON by RFC 8259) rather than with the
    platform's locale-dependent default, so non-ASCII content is read identically on every system.

    :param jsons: List[str], list of json paths to load
    :return: List[Dict] with the parsed content of each json, in the same order as ``jsons``
    :raises OSError: if one of the files cannot be opened
    :raises json.JSONDecodeError: if one of the files does not contain valid JSON
    """
    loaded = []
    for path in jsons:
        with open(path, encoding='utf-8') as handle:
            loaded.append(json.load(handle))
    return loaded


def load_config_json(path: str) -> Tuple[Dict, Dict, Dict, Dict, Dict]:
    """Load and preprocess the config json

    - The TRANSFORMER entry of the model parameters is extended with the model type and the current model name.
      It is then updated with the COMMON model section, so COMMON keys take precedence.
    - The checkpoint directory is created if missing, and the checkpoint path becomes
      ``<checkpoint dir>/<current model name>.ptk``.
    - A non-empty ML-utility output path must already exist. Any non-None output path becomes
      ``<out path>/<current model name>.csv``.

    The returned dictionaries are the ones loaded from the file, modified in place.

    :param path: str, path of the config json file
    :return: Tuple of dictionaries: train_params, optimizer_params, model_params, mask_params,
             supervised_with_target_exp_params
    :raises ValueError: if the config has no model parameters for the TRANSFORMER model type
    :raises FileNotFoundError: if a non-empty ML-utility output path does not exist
    """
    # Named ``config`` (not ``json``) so the ``json`` module is not shadowed.
    config = load_json([path])[0]
    train_params = config[TRAIN_PARAMS]
    optimizer_params = config[OPTIMIZER_PARAMS]

    model_type = TRANSFORMER
    model_params = config[MODEL_PARAMS].get(model_type)
    if model_params is None:
        # ValueError subclasses Exception, so existing ``except Exception`` handlers still catch it.
        raise ValueError('{} in config.json is not a valid value'.format(model_type))
    current_model_name = config[CURRENT_MODEL_NAME]
    model_params.update({MODEL_TYPE: model_type, CURRENT_MODEL_NAME: current_model_name})
    # Applied last: COMMON settings override the model-specific ones.
    model_params.update(config[MODEL_PARAMS][COMMON])
    # exist_ok avoids the race of a separate exists() check followed by makedirs().
    os.makedirs(model_params[CHECKPOINT_PATH], exist_ok=True)
    model_params[CHECKPOINT_PATH] = os.path.join(model_params[CHECKPOINT_PATH], current_model_name + '.ptk')

    mask_params = config[MASK_PARAMS]

    supervised_with_target_exp_params = config[ML_UTILITY_TEST_PARAMS]
    out_path = supervised_with_target_exp_params[ML_UTILITY_OUT_PATH]
    if out_path is not None and out_path != '' and not os.path.exists(out_path):
        raise FileNotFoundError('{} path does not exist!'.format(out_path))
    # NOTE(review): an empty string skips the existence check above, but it is still joined here.
    # This yields '<current model name>.csv' relative to the working directory — confirm this is intended.
    supervised_with_target_exp_params[ML_UTILITY_OUT_PATH] = (
        os.path.join(out_path, current_model_name + '.csv') if out_path is not None else None)

    return (train_params, optimizer_params, model_params, mask_params,
            supervised_with_target_exp_params)


def get_seed(autoencoder_flag: bool, supervised_flag: bool, train_params: Dict, mask_params: Dict,
             svt_exp_params: Dict, svwot_exp_params: Dict):
    """Pick the random seed for the requested run mode

    :param autoencoder_flag: bool, when True the seed is read from ``train_params`` and the other flags are ignored
    :param supervised_flag: bool, consulted only when ``autoencoder_flag`` is False.
                            When True, the seed comes from one of the supervised experiment dictionaries.
    :param train_params: Dict, must hold the SEED key when ``autoencoder_flag`` is True
    :param mask_params: Dict, its MASK_INCLUDE_TARGET value chooses between ``svt_exp_params`` (truthy)
                        and ``svwot_exp_params`` (falsy)
    :param svt_exp_params: Dict, seed source for the supervised case when the mask includes the target
    :param svwot_exp_params: Dict, seed source for the supervised case when the mask excludes the target
    :return: the selected seed value, or 0 when neither flag is set
    """
    if autoencoder_flag:
        return train_params[SEED]
    if not supervised_flag:
        return 0
    seed_source = svt_exp_params if mask_params[MASK_INCLUDE_TARGET] else svwot_exp_params
    return seed_source[SEED]
