import src.constants as cst
import numpy as np
import pandas as pd
import tqdm

from pprint import pprint

class CHFDataBuilder:
    """Load a pickled CHF limit-order-book dataset and split it into
    train / validation / test feature (X) and label (Y) frames.

    All work happens eagerly in ``__init__``; afterwards the frames are
    exposed through the ``get_*`` accessors.

    Column layout assumed by the slicing below (per the feature-type notes in
    ``__prepareX``): columns [0, 20) are the basic LOB features (4 per level,
    5 levels), [20, 46) the remaining time-insensitive features, [46, 66) the
    time-sensitive price & volume derivations, and the last five columns hold
    the labels for the horizons (1, 2, 3, 5, 10).
    """

    _VALID_HORIZONS = (1, 2, 3, 5, 10)
    # Fraction of rows (in file order, no shuffling) forming the train+val
    # pool; the remaining rows form the test set.
    _TRAIN_VAL_FRACTION = 0.8
    # Rows dropped from the very end of the test set. Must stay > 0 because it
    # is used as a negative slice stop. NOTE(review): presumably dropped
    # because their longest-horizon (10) labels are incomplete -- confirm.
    _TEST_TAIL_DROP = 10
    _FEATURES_PER_LEVEL = 4
    _MAX_LOB_LEVELS = 5
    _N_LABEL_COLUMNS = 5

    def __init__(
        self,
        chf_data_dir,
        feature_type,
        horizon=10,
        window=100,
        train_val_split=None,
        chosen_model=None,
        normalization_type=cst.NormalizationType.Z_SCORE,
        levels=-1,
        alpha=0.00002
    ):
        """Store the configuration and build all splits immediately.

        Args:
            chf_data_dir: dataset path *without* extension; ``.pkl`` is
                appended verbatim.
            feature_type: ``cst.Features`` member selecting the X columns.
            horizon: prediction horizon, one of 1, 2, 3, 5, 10. Selects the
                label column unless ``chosen_model`` is DEEPLOBATT.
            window: stored on the instance; not used within this class.
            train_val_split: fraction of the train+val pool used for training
                (the rest is validation). Required despite the ``None`` default.
            chosen_model: ``cst.Models`` member; DEEPLOBATT keeps all five
                label columns instead of a single horizon.
            normalization_type: stored and echoed in a log line; no
                normalization is applied within this class.
            levels: number of LOB levels (1-5) for basic features, or -1 for all.
            alpha: stored on the instance; not used within this class.

        Raises:
            AssertionError: if ``horizon`` is not a supported value.
            ValueError: if ``train_val_split`` is missing, ``levels`` is out of
                range, or ``feature_type`` is not recognised.
        """
        assert horizon in self._VALID_HORIZONS, (
            f"horizon must be one of {self._VALID_HORIZONS}, got {horizon!r}"
        )

        self.feature_type = feature_type
        self.train_val_split = train_val_split
        self.chosen_model = chosen_model
        self.chf_data_dir = chf_data_dir
        self.normalization_type = normalization_type
        self.horizon = horizon
        self.window = window
        self.levels = levels
        self.alpha = alpha

        # Kept for backward compatibility; nothing in this class assigns them
        # after this point.
        self.data, self.samples_x, self.samples_y = None, None, None

        # KEY call: reads, splits and slices the dataset eagerly.
        self.__prepare_dataset()

    def __read_dataset(self):
        """Read the pickled frame and cut it into train / val / test splits.

        Row order is preserved. The first 80% of rows form the train+val pool,
        whose first ``train_val_split`` fraction is train and the rest
        validation. The remaining rows, minus the final ``_TEST_TAIL_DROP``
        rows, form the test set.

        Raises:
            ValueError: if ``train_val_split`` was not provided.
        """
        if self.train_val_split is None:
            raise ValueError(
                "train_val_split must be provided: fraction of the train/val "
                "pool used for training"
            )

        file_name = f"{self.chf_data_dir}.pkl"
        print(f'Reading dataset {file_name}')
        # NOTE(review): unpickling can execute arbitrary code -- only load
        # trusted files.
        self.raw_df = pd.read_pickle(file_name)

        train_val_end = int(self._TRAIN_VAL_FRACTION * len(self.raw_df))
        self.train_val_df = self.raw_df.iloc[:train_val_end]
        n_samples_train = int(np.floor(len(self.train_val_df) * self.train_val_split))

        self.train_df = self.train_val_df.iloc[:n_samples_train]
        self.val_df = self.train_val_df.iloc[n_samples_train:]
        self.test_df = self.raw_df.iloc[train_val_end:-self._TEST_TAIL_DROP]

        print(f"raw df: {len(self.raw_df)}, train df: {len(self.train_df)}, val df: {len(self.val_df)}, test df: {len(self.test_df)} samples")

    def __feature_slice(self):
        """Return the X column slice for ``self.feature_type`` (logging the choice).

        Raises:
            ValueError: on an out-of-range ``levels`` or unknown ``feature_type``.
        """
        if self.feature_type == cst.Features.basic:
            print('basic dataset chosen')
            if self.levels == -1:
                print('all LOB levels chosen')
                return slice(0, self._MAX_LOB_LEVELS * self._FEATURES_PER_LEVEL)
            if 1 <= self.levels <= self._MAX_LOB_LEVELS:
                print(f"{self.levels} LOB levels chosen")
                return slice(0, self.levels * self._FEATURES_PER_LEVEL)
            raise ValueError(
                f"Invalid number of lob levels chosen, must be 1<=x<=5 or -1 (got {self.levels!r})"
            )
        if self.feature_type == cst.Features.insens:
            print('basic + time insensitive dataset chosen')
            return slice(0, 46)
        if self.feature_type == cst.Features.all:
            print('basic + time insensitive + sensitive dataset chosen')
            return slice(0, 66)
        if self.feature_type == cst.Features.nonlob:
            print('time insensitive + sensitive features chosen')
            return slice(20, 66)
        raise ValueError(f"no feature type chosen (got {self.feature_type!r})")

    def __prepareX(self):
        """Select the feature columns of every split.

        basic: first 20 features, aka 5 levels of the LOB (4 per level; fewer
            when ``levels`` is 1-5)
        time insensitive: first 46 features, includes spread & mid-price, price differences, price & volume means, accumulated differences
        all: all 66 features, including basic, time insensitive, and price & volume derivations
        nonlob: all features excluding basic features, 46 in total.
        """
        columns = self.__feature_slice()
        self.samples_x_train = self.train_df.iloc[:, columns]
        self.samples_x_val = self.val_df.iloc[:, columns]
        self.samples_x_test = self.test_df.iloc[:, columns]

        print("Number of features: ", len(self.samples_x_train.columns))
        print(self.samples_x_train.columns)
        print("Created X dataframes")

    def __prepareY(self):
        """Select the label column(s) of every split and shift them down by one.

        The last five columns of each split hold the labels, one per horizon
        in (1, 2, 3, 5, 10). DEEPLOBATT keeps all five; other models keep the
        single column given by ``cst.HORIZONS_MAPPINGS[self.horizon]``
        (presumably a column position yielding a Series -- confirm).
        """
        print("getting labels")
        if self.chosen_model == cst.Models.DEEPLOBATT:
            labels = slice(-self._N_LABEL_COLUMNS, None)
        else:
            labels = cst.HORIZONS_MAPPINGS[self.horizon]

        # Out-of-place subtraction: ``-=`` on an iloc slice of raw_df can warn
        # (SettingWithCopyWarning) or write back into the parent frame.
        # NOTE(review): the -1 presumably maps 1-based stored classes to
        # 0-based ones -- confirm against the data.
        self.samples_y_train = self.train_df.iloc[:, labels] - 1
        self.samples_y_val = self.val_df.iloc[:, labels] - 1
        self.samples_y_test = self.test_df.iloc[:, labels] - 1

        print("Created Y dataframes")

    def __prepare_dataset(self):
        """Crucial call! Read the file, then build the X and Y splits."""
        self.__read_dataset()
        self.__prepareX()
        self.__prepareY()

        print("Prepared Train, Validation, and Test Datasets for CHF, normalization:", self.normalization_type)
        print()

    # The ``first_half_split`` argument of every accessor below is accepted
    # for interface compatibility and ignored.

    def get_data(self, first_half_split=1):
        """Return the full, unsplit frame read from disk."""
        return self.raw_df

    def get_samples_x_train(self, first_half_split=1):
        """Return the training feature frame."""
        return self.samples_x_train

    def get_samples_x_val(self, first_half_split=1):
        """Return the validation feature frame."""
        return self.samples_x_val

    def get_samples_x_test(self, first_half_split=1):
        """Return the test feature frame."""
        return self.samples_x_test

    def get_samples_y_train(self, first_half_split=1):
        """Return the training labels (already shifted by -1)."""
        return self.samples_y_train

    def get_samples_y_val(self, first_half_split=1):
        """Return the validation labels (already shifted by -1)."""
        return self.samples_y_val

    def get_samples_y_test(self, first_half_split=1):
        """Return the test labels (already shifted by -1)."""
        return self.samples_y_test


