import src.constants as cst
import numpy as np
import pandas as pd
import tqdm

from pprint import pprint

class CHFDataBuilder:
    """Load the pickled CHF feature table and split it into train/val/test sets.

    On construction the builder reads one pickle file from ``chf_data_dir``,
    splits its rows chronologically (by position) into train, validation and
    test frames, and derives the feature (X) and label (Y) frames for each.
    The results are exposed through the ``get_*`` accessors.
    """

    def __init__(
        self,
        chf_data_dir,
        feature_type,
        horizon=10,
        window=100,
        train_val_split=None,
        chosen_model=None,
        normalization_type=cst.NormalizationType.Z_SCORE,
    ):
        """Store the configuration and immediately build every dataset.

        Args:
            chf_data_dir: Directory that contains the pickled feature file.
            feature_type: One of ``cst.Features`` (basic, insens, all, nonlob);
                selects which feature columns become X.
            horizon: Prediction horizon; must be one of 1, 2, 3, 5 or 10.
            window: Stored on the instance; not used by this class itself.
            train_val_split: Fraction of the first 80% of the rows that goes
                to the training set (the remainder is validation). Required,
                despite the ``None`` default.
            chosen_model: Model identifier; ``cst.Models.DEEPLOBATT`` receives
                the labels of all five horizons instead of a single one.
            normalization_type: Selects which normalized file is loaded.

        Raises:
            AssertionError: If ``horizon`` is not a supported value.
            TypeError: If ``train_val_split`` is ``None``.
            ValueError: If ``feature_type`` is not a known ``cst.Features``.
        """
        assert horizon in (1, 2, 3, 5, 10), \
            "horizon must be one of (1, 2, 3, 5, 10), got {!r}".format(horizon)

        self.feature_type = feature_type
        self.train_val_split = train_val_split
        self.chosen_model = chosen_model
        self.chf_data_dir = chf_data_dir
        self.normalization_type = normalization_type
        self.horizon = horizon
        self.window = window

        # Kept for compatibility with code that reads these attributes; this
        # class never assigns them anything other than None.
        self.data, self.samples_x, self.samples_y = None, None, None

        # KEY call, generates the dataset
        self.__prepare_dataset()

    def __read_dataset(self):
        """Read the pickled dataset and split it into train/val/test frames.

        The first 80% of the rows form the train+validation pool, which is
        divided according to ``train_val_split``; the remaining rows (minus
        the last 10) form the test set.
        """
        # Local import: keeps this class self-contained without touching the
        # module-level import block.
        import os

        # Fail before loading a potentially large file, with a message that
        # names the culprit (the bare multiplication below would raise the
        # same exception type, just far less readably).
        if self.train_val_split is None:
            raise TypeError(
                "train_val_split must be a number (fraction of the train+val "
                "rows used for training), got None"
            )

        if self.normalization_type == cst.NormalizationType.Z_SCORE:
            normalization = 'zscore'
        elif self.normalization_type == cst.NormalizationType.MINMAX:
            normalization = 'minmax'
        else:
            normalization = None

        file_name = 'INE_sc_feature_u1tou6_{}_0.00002'.format(normalization) + '.pkl'
        # os.path.join instead of string concatenation so the directory works
        # with or without a trailing separator.
        file_path = os.path.join(self.chf_data_dir, file_name)

        # NOTE: unpickling executes arbitrary code; only load trusted files.
        self.raw_df = pd.read_pickle(file_path)

        holdout_start = int(0.8 * len(self.raw_df))
        self.train_val_df = self.raw_df.iloc[:holdout_start]
        n_samples_train = int(np.floor(len(self.train_val_df) * self.train_val_split))

        self.train_df = self.train_val_df.iloc[:n_samples_train]
        self.val_df = self.train_val_df.iloc[n_samples_train:]
        # The last 10 rows are dropped from the test set -- presumably because
        # the largest horizon is 10 and those rows lack complete labels;
        # TODO confirm against the feature-generation code.
        self.test_df = self.raw_df.iloc[holdout_start:-10]

        print(f"raw df: {len(self.raw_df)}, train df: {len(self.train_df)}, val df: {len(self.val_df)}, test df: {len(self.test_df)} samples")

    def __prepareX(self):
        """Select the feature columns for the configured ``feature_type``.

        basic: first 20 features, aka 5 levels of the LOB
        time insensitive: first 46 features, includes spread & mid-price,
            price differences, price & volume means, accumulated differences
        all: all 66 features, including basic, time insensitive, and price &
            volume derivations
        nonlob: all features excluding basic features, 46 in total.

        Raises:
            ValueError: If ``feature_type`` matches none of the above.
        """
        if self.feature_type == cst.Features.basic:
            description, columns = 'basic dataset chosen', slice(None, 20)
        elif self.feature_type == cst.Features.insens:
            description, columns = 'basic + time insensitive dataset chosen', slice(None, 46)
        elif self.feature_type == cst.Features.all:
            description, columns = 'basic + time insensitive + sensitive dataset chosen', slice(None, 66)
        elif self.feature_type == cst.Features.nonlob:
            description, columns = 'time insensitive + sensitive features chosen', slice(20, 66)
        else:
            # Raise instead of calling exit(): terminating the interpreter
            # (with a success status) from library code hides the error.
            raise ValueError(
                "no feature type chosen: unsupported feature_type {!r}".format(self.feature_type)
            )

        print(description)
        self.samples_x_train = self.train_df.iloc[:, columns]
        self.samples_x_val = self.val_df.iloc[:, columns]
        self.samples_x_test = self.test_df.iloc[:, columns]
        print("Number of features: ", len(self.samples_x_train.columns))

        print("Created X dataframes")

    def __prepareY(self):
        """Extract the label columns and shift them down by one.

        The last five columns hold the labels, one per supported horizon
        [1, 2, 3, 5, 10]. ``cst.Models.DEEPLOBATT`` gets all five columns;
        every other model gets the single column that
        ``cst.HORIZONS_MAPPINGS`` assigns to ``self.horizon``.
        """
        print("getting labels")
        if self.chosen_model == cst.Models.DEEPLOBATT:
            label_columns = slice(-5, None)
        else:
            label_columns = cst.HORIZONS_MAPPINGS[self.horizon]

        # Build new objects with `- 1` rather than using `-= 1`: an in-place
        # update on an .iloc selection can write through to self.raw_df
        # (depending on the pandas version / copy-on-write mode) and raises
        # SettingWithCopyWarning. The shift presumably makes the class labels
        # zero-based -- TODO confirm against the label encoding.
        self.samples_y_train = self.train_df.iloc[:, label_columns] - 1
        self.samples_y_val = self.val_df.iloc[:, label_columns] - 1
        self.samples_y_test = self.test_df.iloc[:, label_columns] - 1

        print("Created Y dataframes")

    def __prepare_dataset(self):
        """Crucial call: read the file, then build the X and Y frames."""
        self.__read_dataset()

        self.__prepareX()
        self.__prepareY()

        print("Prepared Train, Validation, and Test Datasets for CHF, normalization:", self.normalization_type)
        print()

    # Accessors. `first_half_split` is accepted by every getter but ignored;
    # it is kept so existing call sites keep working.

    def get_data(self, first_half_split=1):
        """Return the full, unsplit DataFrame as loaded from disk."""
        return self.raw_df

    def get_samples_x_train(self, first_half_split=1):
        """Return the training features."""
        return self.samples_x_train

    def get_samples_x_val(self, first_half_split=1):
        """Return the validation features."""
        return self.samples_x_val

    def get_samples_x_test(self, first_half_split=1):
        """Return the test features."""
        return self.samples_x_test

    def get_samples_y_train(self, first_half_split=1):
        """Return the training labels (already shifted by -1)."""
        return self.samples_y_train

    def get_samples_y_val(self, first_half_split=1):
        """Return the validation labels (already shifted by -1)."""
        return self.samples_y_val

    def get_samples_y_test(self, first_half_split=1):
        """Return the test labels (already shifted by -1)."""
        return self.samples_y_test


