# -*- coding: utf-8 -*-

u"""
(c) Copyright 2023 Telefónica. All Rights Reserved.
The copyright to the software program(s) is property of Telefónica.
The program(s) may be used and or copied only with the express written consent of Telefónica or in accordance with the
terms and conditions stipulated in the agreement/contract under which the program(s) have been supplied."""

from typing import Dict, List, Tuple

import numpy as np
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, mean_squared_error, r2_score
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
import torch

from common.constants import ACCURACY, CLASSIFICATION, MACRO_F1, MODEL, MSE, PARAMS, WEIGHTED_F1
from experiments.xgboost_best_params import xgboost_params


def data_to_supervised_model_format(x_train: torch.Tensor,
                                    y_train: torch.Tensor,
                                    x_test_list: List[torch.Tensor],
                                    y_test_list: List[torch.Tensor],
                                    cont_num: int,
                                    problem_type: str = CLASSIFICATION) -> Tuple:
    """Pre-process torch data into numpy arrays for scikit-learn style models (RandomForest, XGBoost).

    The first ``cont_num`` feature columns are kept unchanged; every remaining column is
    one-hot encoded. For classification problems the target is label-encoded.

    :param x_train: torch.Tensor, train features (first ``cont_num`` columns continuous, rest categorical)
    :param y_train: torch.Tensor, train target
    :param x_test_list: List[torch.Tensor], test feature sets with the same column layout as ``x_train``
    :param y_test_list: List[torch.Tensor], test targets, aligned with ``x_test_list``
    :param cont_num: int, number of leading continuous feature columns
    :param problem_type: str, CLASSIFICATION or REGRESSION
    :return: Tuple (train_x_tot, train_y_tot, test_x_tot, test_y_tot); the train entries are
        np.ndarray, the test entries are lists of np.ndarray (one per test set)
    :raises ValueError: for classification, if a test target contains a label absent from ``y_train``
        (the LabelEncoder is fitted on the train target only)
    """
    # PyTorch --> Numpy: convert each tensor once, then slice the numpy array
    train_x = x_train.cpu().numpy()
    train_x_cont = train_x[:, :cont_num]
    train_x_cat = train_x[:, cont_num:]
    train_y = y_train.cpu().numpy()

    test_x = [arr.cpu().numpy() for arr in x_test_list]
    test_x_cont = [arr[:, :cont_num] for arr in test_x]
    test_x_cat = [arr[:, cont_num:] for arr in test_x]
    test_y = [arr.cpu().numpy() for arr in y_test_list]

    # OneHotEncoding categorical features (only if there are columns beyond the continuous ones)
    if train_x.shape[1] > cont_num:
        # `sparse` was renamed to `sparse_output` in scikit-learn 1.2 and removed in 1.4;
        # try the new name first and fall back for older installations.
        try:
            tmp_ohe = OneHotEncoder(sparse_output=False)
        except TypeError:
            tmp_ohe = OneHotEncoder(sparse=False)
        # Fit on train + all test sets so every category seen anywhere gets a column
        # and transform() never meets an unknown category.
        tmp_ohe.fit(np.concatenate([train_x_cat] + test_x_cat, axis=0))
        tmp_train_x_cat = tmp_ohe.transform(train_x_cat)
        tmp_test_x_cat = [tmp_ohe.transform(arr) for arr in test_x_cat]
        train_x_tot = np.concatenate((train_x_cont, tmp_train_x_cat), axis=1)
        test_x_tot = [np.concatenate([arr_cont, arr_cat], axis=1)
                      for arr_cont, arr_cat in zip(test_x_cont, tmp_test_x_cat)]
    else:
        train_x_tot = train_x_cont
        test_x_tot = test_x_cont

    # Label encoder target. Compare by value: `is` would fail for equal but distinct strings.
    if problem_type == CLASSIFICATION:
        le = LabelEncoder()
        le.fit(train_y)
        train_y_tot = le.transform(train_y)
        test_y_tot = [le.transform(arr) for arr in test_y]
    else:
        train_y_tot = train_y
        test_y_tot = test_y
    return train_x_tot, train_y_tot, test_x_tot, test_y_tot


def train_xgboost(x_train: np.array, y_train: np.array, dataset_name: str):
    """Fit the tuned XGBoost configuration registered for a dataset.

    The estimator class and its keyword arguments are read from
    ``xgboost_params[dataset_name]`` under the MODEL and PARAMS keys.

    :param x_train: np.array, train features
    :param y_train: np.array, train target
    :param dataset_name: str, key looked up in ``xgboost_params``
    :return: the fitted model, or None (after printing a message) when no
        configuration is registered for ``dataset_name``
    """
    config = xgboost_params.get(dataset_name)
    if config is not None:
        estimator_cls = config[MODEL]
        hyper_params = config[PARAMS]
        estimator = estimator_cls(**hyper_params)
        estimator.fit(x_train, y_train)
        return estimator

    print(f'No tuned XGBoost for the {dataset_name} dataset')
    return None


def train_rf(x_train: np.array, y_train: np.array, problem_type: str = CLASSIFICATION, *,
             n_estimators: int = 200, random_state=None):
    """Train a Random Forest model.

    :param x_train: np.array, train features
    :param y_train: np.array, train target
    :param problem_type: str, CLASSIFICATION selects RandomForestClassifier; any other
        value (e.g. REGRESSION) selects RandomForestRegressor
    :param n_estimators: int, number of trees in the forest (default 200)
    :param random_state: seed forwarded to the forest; the default None keeps
        scikit-learn's non-deterministic behaviour, pass an int for reproducible runs
    :return: a fitted Random Forest instance
    """
    forest_cls = RandomForestClassifier if problem_type == CLASSIFICATION else RandomForestRegressor
    # n_jobs=-1: build the trees in parallel on all available cores
    rfc = forest_cls(n_estimators=n_estimators, n_jobs=-1, random_state=random_state)
    rfc.fit(x_train, y_train)
    return rfc


def test_ml(ml_model, x_test: np.array, y_test: np.array, problem_type: str=CLASSIFICATION) -> Dict:
    """Evaluate a fitted ML model (Random Forest or XGBoost) on a test set and print its scores.

    :param ml_model: fitted model exposing ``predict``
    :param x_test: np.array, test features
    :param y_test: np.array, test targets
    :param problem_type: str, CLASSIFICATION or REGRESSION (any value other than
        CLASSIFICATION is evaluated as regression)
    :return: Dict; for classification the keys ACCURACY, MACRO_F1 and WEIGHTED_F1
        (raw fractions, printed as percentages), otherwise the single key MSE
    """
    y_pred = ml_model.predict(x_test)

    if problem_type != CLASSIFICATION:
        # Flatten both sides so (n,) and (n, 1) shaped arrays compare element-wise
        mse = mean_squared_error(y_test.reshape(-1), y_pred.reshape(-1))
        print('\tMSE: {:.5f}'.format(mse))
        return {MSE: mse}

    scores = {
        ACCURACY: accuracy_score(y_test, y_pred),
        MACRO_F1: f1_score(y_test, y_pred, average='macro'),
        WEIGHTED_F1: f1_score(y_test, y_pred, average='weighted'),
    }
    print(confusion_matrix(y_test, y_pred))
    print('\tMicro F1 (Accuracy): {:.3f}'.format(100 * scores[ACCURACY]))
    print('\tMacro F1: {:.3f}'.format(100 * scores[MACRO_F1]))
    print('\tWeighted F1: {:.3f}'.format(100 * scores[WEIGHTED_F1]))
    print()
    return scores
