# -*- coding: utf-8 -*-

u"""
(c) Copyright 2023 Telefónica. All Rights Reserved.
The copyright to the software program(s) is property of Telefónica.
The program(s) may be used and or copied only with the express written consent of Telefónica or in accordance with the
terms and conditions stipulated in the agreement/contract under which the program(s) have been supplied."""

import json
import numpy as np
import os
import pandas as pd
import torch

from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, QuantileTransformer, StandardScaler
from typing import Dict, Optional, Tuple

from common.constants import CALIFORNIA_HOUSING, CAT, CHURN, CLASSIFICATION, CPU, CONF_DIR, DEFAULT, FREQ, \
    FREQUENCY, INDEX, NAME, NCAT, ONEHOT, PROBLEM_TYPE, QUAL, QUAN, REGRESSION, STRUCT, TARGET, TRAIN_SIZE_GEN, \
    TYPE, WEIGHTS
from common.utils import DatasetMetadata


def check_metadata(m: Dict) -> bool:
    """
    Basic checks of the metadata

    Checks that the column indexes in ``m[STRUCT]`` are unique, that a target column is defined and that the type of
    the (first) target column matches the problem type (qualitative for classification, quantitative otherwise).
    The checks are raised explicitly (not with ``assert``) so they still run under ``python -O``.

    :param m: Dictionary with the metadata
    :return: True if every check passes
    :raises AssertionError: If any check fails
    """
    # Check column features
    ms = m[STRUCT]
    n_keys = len(ms)
    # NOTE: keys of a dict are always unique, so this check cannot fail for a dict (duplicated keys in a JSON file
    # are already collapsed by `json.load`); kept as a cheap sanity check
    if len(set(ms.keys())) != n_keys:
        raise AssertionError("Duplicated names")
    if len({v[INDEX] for v in ms.values()}) != n_keys:
        raise AssertionError("Duplicated indexes")

    # Check problem type
    target_types = [v[TYPE] for v in ms.values() if v.get(TARGET, False)]
    if not target_types:
        raise AssertionError("No target column defined")
    target_type = target_types[0]
    if m[PROBLEM_TYPE] == CLASSIFICATION:
        if target_type != QUAL:
            raise AssertionError("Wrong target type for a {} problem".format(CLASSIFICATION))
    elif target_type != QUAN:
        raise AssertionError("Wrong target type for a {} problem".format(REGRESSION))

    return True


def read_metadata(path: str) -> Dict:
    """
    Read JSON file that contains the metadata
    :param path: Path to the JSON metadata file
    :return: The parsed JSON content
    """
    # JSON text is UTF-8 by specification; do not depend on the platform's default locale encoding
    with open(path, encoding="utf-8") as json_file:
        data = json.load(json_file)
    return data


def print_dataset_metadata(m: DatasetMetadata) -> None:
    """
    Print a `DatasetMetadata` object in a readable form: the class name, then one indented `field = value` line per
    field
    :param m: DatasetMetadata object (its fields are read through `_asdict()`)
    :return: None
    """
    header = type(m).__name__
    print(header)
    fields = m._asdict()
    for field_name in fields:
        print("  {} = {}".format(field_name, fields[field_name]))


def build_metadata_object(metadata: Dict, device: str, include_target: bool = False) -> DatasetMetadata:
    """
    Build expected object with the metadata that are strictly necessary to train

    Columns of ``metadata[STRUCT]`` are visited in increasing ``INDEX`` order, so every index list and the default
    values follow the column order. Target columns always contribute to ``target_cols_idxs`` (and, when
    qualitative, to ``num_classes``/``class_weights``); they are also treated as features only if ``include_target``.

    :param metadata: Dictionary with the whole information of the processed dataset
    :param device: Device (`CPU` for CPU, `CUDA` for GPU)
    :param include_target: Whether include target column or not
    :return: Object `DatasetMetadata` with the necessary metadata as attributes
    """
    categorical_features_idxs = []
    categorical_lengths = []
    categorical_weights = []
    continuous_features_idxs = []
    target_cols_idxs = []
    num_classes = -1
    class_weights = None
    default_values = []

    # Sort by index to make sure the index lists are correct (sorting is stable, so ties keep dictionary order)
    for v in sorted(metadata[STRUCT].values(), key=lambda col: col[INDEX]):
        if v.get(TARGET, False):
            target_cols_idxs.append(v[INDEX])
            if v[TYPE] == QUAL:
                num_classes = v[NCAT]
                class_weights = torch.tensor(list(v[WEIGHTS].values())).to(device)
            if not include_target:
                continue
        default_values.append(v[DEFAULT])
        if v[TYPE] == QUAL:
            categorical_features_idxs.append(v[INDEX])
            categorical_lengths.append(v[NCAT])
            categorical_weights.append(torch.tensor(list(v[WEIGHTS].values())).to(device))
        else:
            continuous_features_idxs.append(v[INDEX])

    # Build the defaults tensor once instead of re-allocating it with `torch.cat` on every column
    null_default_values = torch.tensor(default_values, dtype=torch.float, device=device)

    return DatasetMetadata(
        dataset_name=metadata[NAME],
        problem_type=metadata[PROBLEM_TYPE],
        categorical_features_idxs=categorical_features_idxs,
        categorical_lengths=categorical_lengths,
        categorical_weights=categorical_weights,
        continuous_features_idxs=continuous_features_idxs,
        target_cols_idxs=target_cols_idxs,
        num_classes=num_classes,
        class_weights=class_weights,
        null_default_values=null_default_values
    )


def process_data(df: pd.DataFrame, metadata: Dict, seed: Optional[int] = None, device: str = CPU,
                 include_target: bool = False, cat_weights_method: str = "Uniform") -> Optional[
                 Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, DatasetMetadata]]:
    """
    Process input data (Pandas DataFrame) as desired for AI Lab

    Qualitative columns are ordinal-encoded (a block of ``ONEHOT`` columns is collapsed into a single integer column).
    Quantitative feature columns are mapped with a normal-output `QuantileTransformer`. A quantitative target is
    standardized with `StandardScaler`. Columns are then reordered (a quantitative target goes first, a qualitative
    target goes last). The per-column metadata is updated accordingly, and the data are split into train/test tensors.

    The caller's ``metadata`` dictionary is left untouched: the per-column entries are copied before being updated.

    :param df: Pandas DataFrame with the data to be processed; its columns are addressed by the ``INDEX`` values
        stored in ``metadata[STRUCT]``
    :param metadata: Dictionary with the metadata set by the user
    :param seed: Seed for random number generator (used by the train/test split)
    :param device: Device (`CPU` for CPU, `CUDA` for GPU)
    :param include_target: Whether include target column or not
    :param cat_weights_method: str, the method to use for computing the categorical weights (`FREQUENCY` gives
        "balanced" weights; any other value gives uniform weights)
    :return: Tuple ``(x_train, y_train, x_test, y_test, dataset_metadata)`` with the feature and target tensors of
        each split and the `DatasetMetadata` built from the updated metadata, or None if the metadata checks fail
    """

    if not check_metadata(metadata):
        return None
    # Copy every column entry: the updates below (INDEX remapping, ONEHOT removal, ...) must not leak into the
    # caller's metadata, otherwise a second call on the same dictionary would read the wrong columns
    m = {c: dict(v) for c, v in metadata[STRUCT].items()}

    col_target = [c for c, v in m.items() if v.get(TARGET, False)]
    target = col_target[0]
    target_is_quan = m[target][TYPE] == QUAN

    #############################################
    # Process qualitative columns
    #############################################
    enc = OrdinalEncoder()
    qual_frames = []
    for c, v in m.items():
        if v[TYPE] != QUAL:
            continue
        n_onehot = v.get(ONEHOT)
        if n_onehot:
            # Collapse the one-hot block into one integer code: each row's values weighted by their column
            # position within the block (i.e. the position of the active column for a proper one-hot row)
            ohe_arr = df[list(range(v[INDEX], v[INDEX] + n_onehot))].values
            arr = (ohe_arr * np.arange(ohe_arr.shape[1])[np.newaxis, :]).sum(axis=1)
        else:
            arr = enc.fit_transform(np.array(df[v[INDEX]]).reshape(-1, 1)).flatten()
        qual_frames.append(pd.DataFrame(data=arr, columns=[c]))

        # Update metadata
        cat = np.unique(arr)
        m[c][NCAT] = cat.size
        m[c][CAT] = list(cat)
        m[c][FREQ] = pd.Series(arr).value_counts().to_dict()
        # Default code is NCAT, i.e. one past the last observed category code
        m[c][DEFAULT] = cat.size
        if cat_weights_method == FREQUENCY:
            weights = compute_class_weight(class_weight='balanced', classes=cat, y=arr)
        else:
            weights = np.ones(cat.size)
        m[c][WEIGHTS] = dict(zip(cat, weights))

    df_qual = pd.concat(qual_frames, axis=1) if qual_frames else None
    qual_cols = list(df_qual.columns) if df_qual is not None else []

    #############################################
    # Process quantitative columns
    #############################################
    col_quan = [c for c, v in m.items() if v[TYPE] == QUAN and c != target]
    idx_quan = [m[c][INDEX] for c in col_quan]
    df_quan = None
    # Datasets may have no quantitative feature at all; skip the transformer in that case
    if col_quan:
        pt = QuantileTransformer(n_quantiles=20, random_state=0, output_distribution='normal')
        df_quan = pd.DataFrame(data=pt.fit_transform(df[idx_quan].to_numpy()).astype('float32'), columns=col_quan)

    # Complete metadata and reorder columns: a quantitative target goes first, a qualitative one goes last
    col_not_target = [c for c in col_quan + qual_cols if c != target]
    col_final = col_target + col_not_target if target_is_quan else col_not_target + col_target
    frames = []
    if target_is_quan:
        ss = StandardScaler()
        frames.append(pd.DataFrame(data=ss.fit_transform(df[[m[target][INDEX]]].to_numpy()).astype('float32'),
                                   columns=col_target))
    frames.extend(frame for frame in (df_quan, df_qual) if frame is not None)
    df_final = pd.concat(frames, axis=1).reindex(columns=col_final)

    # Update metadata
    quan_default = float('inf')
    for i, c in enumerate(col_final):
        m[c][INDEX] = i
        if m[c][TYPE] == QUAN:
            col_min = df_final[c].min()
            if col_min < quan_default:
                quan_default = col_min
        m[c].pop(ONEHOT, None)
    # All quantitative columns share one default: 1.5 times the smallest value over them (presumably meant to lie
    # below every observed value, which holds when that minimum is negative -- TODO confirm intent)
    for c in col_final:
        if m[c][TYPE] == QUAN:
            m[c][DEFAULT] = 1.5 * quan_default

    #############################################
    # Split data as needed
    #############################################
    y_dtype = torch.long if metadata[PROBLEM_TYPE] == CLASSIFICATION else torch.float
    df_train, df_test = train_test_split(df_final, test_size=1 - metadata[TRAIN_SIZE_GEN], random_state=seed)
    x_train_torch = torch.tensor(df_train.drop(col_target, axis=1).values, dtype=torch.float)
    y_train_torch = torch.tensor(df_train[col_target].values, dtype=y_dtype)
    x_test_torch = torch.tensor(df_test.drop(col_target, axis=1).values, dtype=torch.float)
    y_test_torch = torch.tensor(df_test[col_target].values, dtype=y_dtype)

    # The updated column metadata replaces STRUCT; spreading it at the top level would mix column names with the
    # dataset-level keys (NAME, PROBLEM_TYPE, ...)
    return (x_train_torch, y_train_torch, x_test_torch, y_test_torch,
            build_metadata_object(metadata={**metadata, STRUCT: m}, device=device, include_target=include_target))


def load_california_housing(data_dir: str, seed: Optional[int] = None, device: str = CPU,
                            include_target: bool = False, cat_weights_method: str = "Uniform") -> Optional[
                            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, DatasetMetadata]]:
    """
    Load the California Housing dataset and process it with `process_data`

    The metadata are read from ``CONF_DIR/california.json`` and the data from
    ``<data_dir>/CALIFORNIA_HOUSING/data.csv`` (first CSV column used as the row index).

    :param data_dir: Directory where data are located
    :param seed: Seed for random number generator
    :param device: Device (`CPU` for CPU, `CUDA` for GPU)
    :param include_target: Whether include target column or not
    :param cat_weights_method: Method used to compute the categorical weights, forwarded to `process_data`
    :return: Output of `process_data`: tuple ``(x_train, y_train, x_test, y_test, dataset_metadata)``
    """
    print('Loading CALIFORNIA HOUSING dataset...')
    data_info = read_metadata(os.path.join(CONF_DIR, 'california.json'))

    data = pd.read_csv(os.path.join(data_dir, "CALIFORNIA_HOUSING", "data.csv"), index_col=0)
    # Drop the CSV column labels: `process_data` addresses columns by their integer position (metadata INDEX)
    df = pd.DataFrame(data=data.values)

    return process_data(df=df, metadata=data_info, seed=seed, device=device, include_target=include_target,
                        cat_weights_method=cat_weights_method)


def load_churn(data_dir: str, seed: Optional[int] = None, device: str = CPU,
               include_target: bool = False, cat_weights_method: str = "Uniform") -> Optional[
               Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, DatasetMetadata]]:
    """
    Load the Churn dataset and process it with `process_data`

    The metadata are read from ``CONF_DIR/churn.json`` and the data from ``<data_dir>/CHURN/processed.csv``
    (first CSV column used as the row index).

    :param data_dir: Directory where data are located
    :param seed: Seed for random number generator
    :param device: Device (`CPU` for CPU, `CUDA` for GPU)
    :param include_target: Whether include target column or not
    :param cat_weights_method: Method used to compute the categorical weights, forwarded to `process_data`
    :return: Output of `process_data`: tuple ``(x_train, y_train, x_test, y_test, dataset_metadata)``
    """
    print('Loading CHURN dataset...')
    data_info = read_metadata(os.path.join(CONF_DIR, 'churn.json'))

    data = pd.read_csv(os.path.join(data_dir, "CHURN", "processed.csv"), index_col=0)
    # Drop the CSV column labels: `process_data` addresses columns by their integer position (metadata INDEX)
    df = pd.DataFrame(data=data.values)

    return process_data(df=df, metadata=data_info, seed=seed, device=device, include_target=include_target,
                        cat_weights_method=cat_weights_method)


# Registry mapping each dataset identifier to the function that loads and processes that dataset
dataset_load_methods = {
    CALIFORNIA_HOUSING: load_california_housing,
    CHURN: load_churn,
}
