import src.constants as cst
import numpy as np
import pandas as pd
import tqdm
from mprf.data_preprocessing.rolling_norms import welford

from pprint import pprint

class BTCDataBuilder:
    """Load pre-processed BTCUSDT order-book splits and expose feature/label frames.

    On construction the builder reads three pickled DataFrames (train, val,
    test) from ``./data`` and splits each one in two. The labels are the last
    ``_N_LABEL_COLUMNS`` columns and the features are all remaining columns.
    It then selects the feature columns and label column(s) requested by the
    constructor arguments. Results are exposed through the
    ``get_samples_{x,y}_{train,val,test}`` accessors.

    Args:
        feature_type: feature set to build; only ``cst.Features.basic`` is
            implemented.
        horizon: prediction horizon, one of 1, 2, 3, 5, 10. Selects the label
            column via ``cst.HORIZONS_MAPPINGS``, except for DEEPLOBATT, which
            keeps every label column.
        window: stored on the instance; not used by this class.
        train_val_split: stored on the instance; not used by this class.
        chosen_model: model identifier compared against ``cst.Models.DEEPLOBATT``.
        normalization_type: stored on the instance; not used by this class.
        levels: ``-1`` keeps the first 40 feature columns; ``1..5`` keeps the
            first ``levels * 4`` feature columns.
        rw: selects the input files. ``-1`` loads the ``chunknorm`` pickles and
            any other value loads the ``rw<rw>`` pickles. This is presumably the
            rolling-normalisation window size; confirm against the
            preprocessing scripts.
    """

    # Trailing label columns in every pickle, one per horizon [1, 2, 3, 5, 10].
    _N_LABEL_COLUMNS = 5

    def __init__(
        self,
        feature_type,
        horizon=10,
        window=100,
        train_val_split=None,
        chosen_model=None,
        normalization_type=cst.NormalizationType.Z_SCORE,
        levels=-1,
        rw=20
    ):

        assert horizon in (1, 2, 3, 5, 10), f"unsupported horizon {horizon!r}"

        self.feature_type = feature_type
        self.train_val_split = train_val_split
        self.chosen_model = chosen_model
        self.normalization_type = normalization_type
        self.horizon = horizon
        self.window = window
        self.levels = levels
        self.rw = rw

        # Legacy attributes kept for callers that read them; never filled here.
        self.data, self.samples_x, self.samples_y = None, None, None

        # KEY call: loads the pickles and builds every samples_* frame.
        self.__prepare_dataset()

    def _load_split(self, path):
        """Read one pickled split and return ``(df, features, labels)``.

        ``labels`` are the last ``_N_LABEL_COLUMNS`` columns of ``df`` and
        ``features`` are all remaining columns.
        """
        # pd.read_pickle can execute arbitrary code: only point it at trusted files.
        df = pd.read_pickle(path)
        n_labels = self._N_LABEL_COLUMNS
        return df, df.iloc[:, :-n_labels], df.iloc[:, -n_labels:]

    def __read_dataset(self):
        """Load the train/val/test pickles from ``./data``.

        File names are ``btcusdt_<split>_<tag>_a0.0001.pkl``. ``<tag>`` is
        ``chunknorm`` when ``self.rw == -1`` and ``rw<self.rw>`` otherwise.
        This sets ``{train,val,test}_df`` (the full frames), ``*_x`` (the
        feature columns) and ``*_y`` (the label columns).
        """
        norm_tag = 'chunknorm' if self.rw == -1 else f'rw{self.rw}'
        # Every split lives in the same ./data directory.
        f_train, f_val, f_test = (
            f'./data/btcusdt_{split}_{norm_tag}_a0.0001.pkl'
            for split in ('train', 'val', 'test')
        )

        print(f'Reading dataset {f_train}')
        self.train_df, self.train_x, self.train_y = self._load_split(f_train)
        self.val_df, self.val_x, self.val_y = self._load_split(f_val)
        self.test_df, self.test_x, self.test_y = self._load_split(f_test)

        print(f"train df: {len(self.train_x)}, val df: {len(self.val_x)}, test df: {len(self.test_x)} samples")

    def __prepareX(self):
        """Select the feature columns for every split.

        Only ``cst.Features.basic`` is implemented. It keeps the leading
        order-book columns of each split, 4 columns per level:

        - ``levels == -1`` keeps the first 40 columns (all levels in the file,
          presumably 10; confirm against the preprocessing output);
        - ``1 <= levels <= 5`` keeps the first ``levels * 4`` columns.

        Raises:
            ValueError: if ``feature_type`` is not ``cst.Features.basic`` or
                ``levels`` is outside the accepted values.
        """
        if self.feature_type != cst.Features.basic:
            # Previously exit() was called here. With no argument that exits
            # with status 0 and hides the misconfiguration from callers/CI.
            raise ValueError(
                f"Unsupported feature type {self.feature_type!r}: "
                "only cst.Features.basic is implemented"
            )

        print('basic dataset chosen')
        if self.levels == -1:
            print('all LOB levels chosen')
            num_features = 40
        elif 1 <= self.levels <= 5:
            print(f"{self.levels} LOB levels chosen")
            num_features = self.levels * 4
        else:
            # Fail loudly instead of falling through with num_features undefined.
            raise ValueError(
                f"Invalid number of lob levels chosen ({self.levels!r}), "
                "must be -1 or 1<=x<=5"
            )

        self.samples_x_train = self.train_x.iloc[:, :num_features]
        self.samples_x_val = self.val_x.iloc[:, :num_features]
        self.samples_x_test = self.test_x.iloc[:, :num_features]

        print("Number of features: ", len(self.samples_x_train.columns))
        print("Created X dataframes")

    def __prepareY(self):
        """Select the label column(s) for every split and shift them down by 1.

        For ``cst.Models.DEEPLOBATT`` all label columns are kept. Any other
        model keeps the single column at
        ``cst.HORIZONS_MAPPINGS[self.horizon]``.
        """
        print("getting labels")
        if self.chosen_model == cst.Models.DEEPLOBATT:
            # Keep every horizon's label column.
            y_train, y_val, y_test = self.train_y, self.val_y, self.test_y
        else:
            col = cst.HORIZONS_MAPPINGS[self.horizon]
            y_train = self.train_y.iloc[:, col]
            y_val = self.val_y.iloc[:, col]
            y_test = self.test_y.iloc[:, col]

        # Subtract out of place. With `-=` the DEEPLOBATT branch mutated
        # train_y/val_y/test_y too, because samples_y_* aliased them.
        self.samples_y_train = y_train - 1
        self.samples_y_val = y_val - 1
        self.samples_y_test = y_test - 1

        print("Created Y dataframes")

    def __prepare_dataset(self):
        """Load the splits, then build the feature and label frames."""
        self.__read_dataset()
        self.__prepareX()
        self.__prepareY()

        print("Prepared Train, Validation, and Test Datasets for BTC")
        print()

    # Accessors. `first_half_split` is accepted for interface compatibility
    # and is ignored by every getter.

    def get_data(self, first_half_split=1):
        # NOTE(review): nothing in this class assigns `raw_df`, so this raises
        # AttributeError unless a caller sets it externally. Confirm intent.
        return self.raw_df

    def get_samples_x_train(self, first_half_split=1):
        """Return the training feature frame."""
        return self.samples_x_train

    def get_samples_x_val(self, first_half_split=1):
        """Return the validation feature frame."""
        return self.samples_x_val

    def get_samples_x_test(self, first_half_split=1):
        """Return the test feature frame."""
        return self.samples_x_test

    def get_samples_y_train(self, first_half_split=1):
        """Return the training labels (already shifted down by 1)."""
        return self.samples_y_train

    def get_samples_y_val(self, first_half_split=1):
        """Return the validation labels (already shifted down by 1)."""
        return self.samples_y_val

    def get_samples_y_test(self, first_half_split=1):
        """Return the test labels (already shifted down by 1)."""
        return self.samples_y_test


