import collections
from collections import Counter

import src.constants as cst
import numpy as np
import tqdm

from pprint import pprint


class FIDataBuilder:
    """Load the FI text files and expose train/val/test feature and label arrays.

    The constructor reads one training file and three testing files with
    ``np.loadtxt``, splits the training matrix column-wise into a training and
    a validation part, and then derives the ``samples_x_*`` / ``samples_y_*``
    arrays, which are exposed through the ``get_samples_*`` accessors.

    Args:
        fi_data_dir: Root directory of the dataset. It is concatenated
            directly with the sub-directory name, so it must end with a path
            separator.
        feature_type: One of ``cst.Features.basic``, ``insens``, ``all`` or
            ``nonlob``; selects which feature rows are kept.
        horizon: Prediction horizon; must be one of 1, 2, 3, 5, 10.
        window: Stored on the instance; not used inside this class.
        train_val_split: Fraction (between 0 and 1) of the training columns
            kept for training; the remainder becomes the validation set.
            ``None`` keeps every column for training and leaves the
            validation set empty.
        chosen_model: When equal to ``cst.Models.DEEPLOBATT`` the last five
            rows are all returned as labels; otherwise only the row selected
            by ``cst.HORIZONS_MAPPINGS[horizon]`` is used.
        auction: Selects the ``Auction`` or ``NoAuction`` variant of the files.
        normalization_type: ``Z_SCORE`` and ``MINMAX`` select the matching
            files; any other value selects the ``DecPre`` files.
        levels: Only used with ``cst.Features.basic``: ``-1`` keeps the first
            40 rows, a value between 1 and 10 keeps ``levels * 4`` rows.

    Raises:
        ValueError: If ``train_val_split`` is outside ``[0, 1]``, if
            ``feature_type`` is not recognised, or if ``levels`` is invalid
            for the basic feature set.
    """

    # Number of leading rows covered by each feature group.
    _N_BASIC_ROWS = 40
    _N_INSENS_ROWS = 86
    _N_ALL_ROWS = 144
    # With the basic feature set each requested level contributes four rows.
    _ROWS_PER_LEVEL = 4
    _MAX_LEVELS = 10

    def __init__(
        self,
        fi_data_dir,
        feature_type,
        horizon=10,
        window=100,
        train_val_split=None,
        chosen_model=None,
        auction=False,
        normalization_type=cst.NormalizationType.Z_SCORE,
        levels=-1
    ):

        assert horizon in (1, 2, 3, 5, 10), "horizon must be one of 1, 2, 3, 5, 10"

        # Reject a bad ratio before any (slow) file loading takes place.
        if train_val_split is not None and not 0 <= train_val_split <= 1:
            raise ValueError(
                f"train_val_split must be between 0 and 1 (or None), got {train_val_split!r}"
            )

        self.feature_type = feature_type
        self.train_val_split = train_val_split
        self.chosen_model = chosen_model
        self.fi_data_dir = fi_data_dir
        self.auction = auction
        self.normalization_type = normalization_type
        self.horizon = horizon
        self.window = window
        self.levels = levels

        # Kept for backward compatibility; nothing in this class assigns them.
        self.data, self.samples_x, self.samples_y = None, None, None

        # KEY call, generates the dataset
        self.__prepare_dataset()

    def __read_dataset(self):
        """Read the FI files into ``train_df``, ``val_df`` and ``test_df``.

        Samples are laid out column-wise: the train/validation split and the
        concatenation of the test files both happen along axis 1.
        """
        auction = 'Auction' if self.auction else 'NoAuction'

        # The directory names spell the z-score tag 'Zscore' while the file
        # names spell it 'ZScore'; both spellings are reproduced as-is.
        if self.normalization_type == cst.NormalizationType.Z_SCORE:
            prefix, dir_tag, file_tag = '1.', 'Zscore', 'ZScore'
        elif self.normalization_type == cst.NormalizationType.MINMAX:
            prefix, dir_tag, file_tag = '2.', 'MinMax', 'MinMax'
        else:
            prefix, dir_tag, file_tag = '3.', 'DecPre', 'DecPre'

        base_dir = f"{self.fi_data_dir}{auction}/{prefix}{auction}_{dir_tag}"
        train_path = (
            f"{base_dir}/{auction}_{dir_tag}_Training"
            f"/Train_Dst_{auction}_{file_tag}_CF_7.txt"
        )
        test_paths = [
            f"{base_dir}/{auction}_{dir_tag}_Testing"
            f"/Test_Dst_{auction}_{file_tag}_CF_{i}.txt"
            for i in range(7, 10)
        ]

        # train / validation: a single file, split column-wise
        train_out = np.loadtxt(train_path)

        # None means "no validation hold-out": every column goes to training.
        split = 1.0 if self.train_val_split is None else self.train_val_split
        n_samples_train = int(np.floor(train_out.shape[1] * split))
        self.train_df = train_out[:, :n_samples_train]
        self.val_df = train_out[:, n_samples_train:]

        # test: three files, concatenated column-wise
        self.test_df = np.hstack([np.loadtxt(path) for path in test_paths])

    def __select_feature_rows(self):
        """Return the row slice matching ``feature_type`` (and ``levels``).

        Raises:
            ValueError: For an unknown feature type, or for a ``levels`` value
                other than -1 or 1..10 when the basic feature set is chosen.
        """
        if self.feature_type == cst.Features.basic:
            print('basic dataset chosen')
            if self.levels == -1:
                print('All LOB levels chosen')
                return slice(0, self._N_BASIC_ROWS)
            if 1 <= self.levels <= self._MAX_LEVELS:
                print(f"{self.levels} LOB levels chosen")
                return slice(0, self.levels * self._ROWS_PER_LEVEL)
            raise ValueError("Invalid number of lob levels chosen, must be 1<=x<=10")

        if self.feature_type == cst.Features.insens:
            print('basic + time insensitive dataset chosen')
            return slice(0, self._N_INSENS_ROWS)

        if self.feature_type == cst.Features.all:
            print('basic + time insensitive + sensitive dataset chosen')
            return slice(0, self._N_ALL_ROWS)

        if self.feature_type == cst.Features.nonlob:
            print('time insensitive + sensitive dataset chosen')
            return slice(self._N_BASIC_ROWS, self._N_ALL_ROWS)

        raise ValueError(f"no feature type chosen: {self.feature_type!r}")

    def __prepareX(self):
        """Build the ``samples_x_*`` arrays, one row per sample.

        basic: first 40 features, aka 10 levels of the LOB
        insens: first 86 features, includes spread & mid-price, price differences, price & volume means, accumulated differences
        all: all 144 features, including basic, time insensitive, and price & volume derivations
        nonlob: all features excluding basic features, 104 in total.
        """
        rows = self.__select_feature_rows()

        # The loaded matrices are (rows, samples); transpose to (samples, features).
        self.samples_x_train = self.train_df[rows, :].transpose()
        self.samples_x_val = self.val_df[rows, :].transpose()
        self.samples_x_test = self.test_df[rows, :].transpose()

        print("Size of train dataframe: ", self.samples_x_train.shape)
        print("Created X dataframes")

    def __prepareY(self):
        """Build the ``samples_y_*`` label arrays."""
        # the last five rows of each matrix contain the labels
        # they are based on the possible horizon values [1, 2, 3, 5, 10]

        print("getting labels")

        if self.chosen_model == cst.Models.DEEPLOBATT:
            # This model gets the labels of all five horizons at once.
            label_rows = slice(-5, None)
        else:
            label_rows = cst.HORIZONS_MAPPINGS[self.horizon]

        self.samples_y_train = self.train_df[label_rows, :].transpose()
        self.samples_y_val = self.val_df[label_rows, :].transpose()
        self.samples_y_test = self.test_df[label_rows, :].transpose()

        # Shift the labels down by one (presumably 1-based classes -> 0-based;
        # confirm against the dataset). NOTE: with slice/integer row indices
        # these arrays are views, so the in-place update also rewrites the
        # label rows of train_df / val_df / test_df.
        self.samples_y_train -= 1
        self.samples_y_val -= 1
        self.samples_y_test -= 1

        print("Created Y dataframes")

    def __prepare_dataset(self):
        """ Crucial call! Reads the files, then derives features and labels. """

        self.__read_dataset()

        self.__prepareX()
        self.__prepareY()

        print("Prepared Train Validation and Test Datasets for FI, normalization:", self.normalization_type)
        print()

    # The `first_half_split` parameter of the accessors below is unused; it is
    # kept so that existing callers passing it keep working.

    def get_samples_x_train(self, first_half_split=1):
        """Return the training features."""
        return self.samples_x_train

    def get_samples_x_val(self, first_half_split=1):
        """Return the validation features."""
        return self.samples_x_val

    def get_samples_x_test(self, first_half_split=1):
        """Return the test features."""
        return self.samples_x_test

    def get_samples_y_train(self, first_half_split=1):
        """Return the training labels."""
        return self.samples_y_train

    def get_samples_y_val(self, first_half_split=1):
        """Return the validation labels."""
        return self.samples_y_val

    def get_samples_y_test(self, first_half_split=1):
        """Return the test labels."""
        return self.samples_y_test
