
import numpy as np
from collections import Counter
import gc

import src.constants as cst
from src.config import Configuration
from src.data_preprocessing.CHF.CHFDataBuilder import CHFDataBuilder
from src.data_preprocessing.FI.FIDataBuilder import FIDataBuilder
from src.data_preprocessing.mprf.MPRFDataBuilder import MPRFDataBuilder

# DATASETS
from src.data_preprocessing.FI.FIDataset import FIDataset
from src.data_preprocessing.CHF.CHFDataset import CHFDataset
from src.data_preprocessing.mprf.MPRFDataset import MPRFDataset


from src.data_preprocessing.DataModule import DataModule


def prepare_data_fi(config: Configuration):
    """Build the train/validation/test data module for the FI dataset.

    Args:
        config: Run configuration. Supplies the chosen features and model,
            the LOB levels, the train/validation split and the
            hyper-parameters read here (horizon, number of snapshots,
            batch size and the train-shuffle flag).

    Returns:
        A ``DataModule`` wrapping the three ``FIDataset`` splits.
    """
    hyper_params = config.HYPER_PARAMETERS
    num_snapshots = hyper_params[cst.LearningHyperParameter.NUM_SNAPSHOTS]

    fi_databuilder = FIDataBuilder(
        cst.DATA_SOURCE + cst.DATASET_FI,
        feature_type=config.CHOSEN_FEATURES,
        horizon=hyper_params[cst.LearningHyperParameter.FI_HORIZON],
        window=num_snapshots,
        train_val_split=config.TRAIN_SPLIT_VAL,
        chosen_model=config.CHOSEN_MODEL,
        levels=config.LOB_LEVELS,
    )

    train_set = FIDataset(
        x=fi_databuilder.get_samples_x_train(),
        y=fi_databuilder.get_samples_y_train(),
        chosen_model=config.CHOSEN_MODEL,
        num_snapshots=num_snapshots,
    )

    val_set = FIDataset(
        x=fi_databuilder.get_samples_x_val(),
        y=fi_databuilder.get_samples_y_val(),
        chosen_model=config.CHOSEN_MODEL,
        num_snapshots=num_snapshots,
    )

    test_set = FIDataset(
        x=fi_databuilder.get_samples_x_test(),
        y=fi_databuilder.get_samples_y_test(),
        chosen_model=config.CHOSEN_MODEL,
        num_snapshots=num_snapshots,
    )

    # Blank separator line on stdout; kept so console output is unchanged.
    print()

    fi_dm = DataModule(
        train_set, val_set, test_set,
        hyper_params[cst.LearningHyperParameter.BATCH_SIZE],
        hyper_params[cst.LearningHyperParameter.IS_SHUFFLE_TRAIN_SET],
    )

    # Drop the local references to the builder and the splits, then force a
    # garbage-collection pass before handing the data module back.
    del fi_databuilder
    del train_set
    del val_set
    del test_set
    gc.collect()

    return fi_dm


def prepare_data_chf(config: Configuration):
    """Build the train/validation/test data module for the CHF dataset.

    For each split the label distribution is printed to stdout: the raw
    ``Counter`` of labels followed by the per-class fractions.

    Args:
        config: Run configuration. Supplies the chosen features and model,
            the LOB levels, ``ALPHA``, the train/validation split and the
            hyper-parameters read here (horizon, number of snapshots,
            batch size and the train-shuffle flag).

    Returns:
        A ``DataModule`` wrapping the three ``CHFDataset`` splits.
    """
    hyper_params = config.HYPER_PARAMETERS
    num_snapshots = hyper_params[cst.LearningHyperParameter.NUM_SNAPSHOTS]

    chf_databuilder = CHFDataBuilder(
        cst.DATASET_CHF,
        feature_type=config.CHOSEN_FEATURES,
        horizon=hyper_params[cst.LearningHyperParameter.FI_HORIZON],
        window=num_snapshots,
        train_val_split=config.TRAIN_SPLIT_VAL,
        chosen_model=config.CHOSEN_MODEL,
        levels=config.LOB_LEVELS,
        alpha=config.ALPHA,
    )

    def report_balance(title, labels):
        """Print the class counts of ``labels`` and their relative shares."""
        # Count once and reuse the result for both the raw counts and the
        # fractions, instead of rebuilding the Counter for each.
        counts = Counter(labels)
        shares = np.array(list(counts.values())) / sum(counts.values())
        print(title, counts, shares)

    train_set = CHFDataset(
        x=chf_databuilder.get_samples_x_train(),
        y=chf_databuilder.get_samples_y_train(),
        chosen_model=config.CHOSEN_MODEL,
        num_snapshots=num_snapshots,
    )
    report_balance("TRAIN balance", chf_databuilder.get_samples_y_train())

    val_set = CHFDataset(
        x=chf_databuilder.get_samples_x_val(),
        y=chf_databuilder.get_samples_y_val(),
        chosen_model=config.CHOSEN_MODEL,
        num_snapshots=num_snapshots,
    )
    report_balance("VAL balance", chf_databuilder.get_samples_y_val())

    test_set = CHFDataset(
        x=chf_databuilder.get_samples_x_test(),
        y=chf_databuilder.get_samples_y_test(),
        chosen_model=config.CHOSEN_MODEL,
        num_snapshots=num_snapshots,
    )
    report_balance("TEST balance", chf_databuilder.get_samples_y_test())
    print()

    chf_dm = DataModule(
        train_set, val_set, test_set,
        hyper_params[cst.LearningHyperParameter.BATCH_SIZE],
        hyper_params[cst.LearningHyperParameter.IS_SHUFFLE_TRAIN_SET],
    )

    # Drop the local references to the builder and the splits, then force a
    # garbage-collection pass before handing the data module back.
    del chf_databuilder
    del train_set
    del val_set
    del test_set
    gc.collect()

    return chf_dm

def prepare_data_regression(config: Configuration):
    """Build the train/validation/test data module for the regression task.

    The data directory is selected from ``config.CHOSEN_DATASET`` (FI or
    CHF) and passed to ``MPRFDataBuilder``; the resulting splits are wrapped
    in ``MPRFDataset`` objects.

    Args:
        config: Run configuration. Supplies the dataset family, the chosen
            features and model, the train/validation split and the
            hyper-parameters read here (horizon, number of snapshots,
            batch size and the train-shuffle flag).

    Returns:
        A ``DataModule`` wrapping the three ``MPRFDataset`` splits.

    Raises:
        ValueError: If ``config.CHOSEN_DATASET`` is neither the FI nor the
            CHF dataset family.
    """
    if config.CHOSEN_DATASET == cst.DatasetFamily.FI:
        data_dir = cst.DATASET_FI
    elif config.CHOSEN_DATASET == cst.DatasetFamily.CHF:
        # Same constant the CHF data builder is given elsewhere in this
        # module (the previous spelling ``Dataset_CHF`` did not match it).
        data_dir = cst.DATASET_CHF
    else:
        # Fail early with a clear message instead of hitting an unbound
        # ``data_dir`` further down.
        raise ValueError(
            f"Unsupported dataset family for regression: {config.CHOSEN_DATASET!r}"
        )

    hyper_params = config.HYPER_PARAMETERS
    horizon = hyper_params[cst.LearningHyperParameter.FI_HORIZON]
    num_snapshots = hyper_params[cst.LearningHyperParameter.NUM_SNAPSHOTS]

    rf_databuilder = MPRFDataBuilder(
        config.CHOSEN_DATASET,
        data_dir,
        config.CHOSEN_FEATURES,
        horizon,
        train_val_split=config.TRAIN_SPLIT_VAL,
    )

    train_set = MPRFDataset(
        x=rf_databuilder.samples_x_train,
        y=rf_databuilder.samples_y_train,
        chosen_model=config.CHOSEN_MODEL,
        horizon=horizon,
        num_snapshots=num_snapshots,
    )

    val_set = MPRFDataset(
        x=rf_databuilder.samples_x_val,
        y=rf_databuilder.samples_y_val,
        chosen_model=config.CHOSEN_MODEL,
        horizon=horizon,
        num_snapshots=num_snapshots,
    )

    test_set = MPRFDataset(
        x=rf_databuilder.samples_x_test,
        y=rf_databuilder.samples_y_test,
        chosen_model=config.CHOSEN_MODEL,
        horizon=horizon,
        num_snapshots=num_snapshots,
    )

    rf_dm = DataModule(
        train_set, val_set, test_set,
        hyper_params[cst.LearningHyperParameter.BATCH_SIZE],
        hyper_params[cst.LearningHyperParameter.IS_SHUFFLE_TRAIN_SET],
    )

    # Drop the local references to the builder and the splits, then force a
    # garbage-collection pass before handing the data module back.
    del rf_databuilder
    del train_set
    del val_set
    del test_set
    gc.collect()

    return rf_dm


    

def pick_dataset(config: Configuration):
    """Dispatch to the data-preparation routine for the configured dataset.

    Args:
        config: Run configuration; ``config.CHOSEN_DATASET`` selects the
            dataset family.

    Returns:
        The result of ``prepare_data_fi`` for the FI family, the result of
        ``prepare_data_chf`` for the CHF family, and ``None`` for any other
        value.
    """
    family = config.CHOSEN_DATASET

    if family == cst.DatasetFamily.FI:
        return prepare_data_fi(config)

    if family == cst.DatasetFamily.CHF:
        return prepare_data_chf(config)

    # No preparation routine is registered for other families.
    return None

        
