import pandas as pd
import numpy as np
import torch
import gc
from torch.utils.data import Dataset, DataLoader, TensorDataset
import math

class LOB_Dataset(Dataset):
    """Sliding-window dataset pairing a lookback window with a later target.

    Item ``i`` consists of the window ``data[s : s + lookback, :]`` with
    ``s = i * sample_gap`` and the target ``targets[s + lookback + horizon - 1]``.
    Both are returned as ``torch.float32`` tensors.

    Args:
        data: 2-D array-like indexed as ``data[rows, :]`` (rows are time steps).
        targets: Sequence indexed by row; must be exactly ``horizon`` entries
            longer than ``data``.
        horizon: Offset between the end of a window and its target.
        lookback: Number of rows in each window.
        sample_gap: Stride between the starts of consecutive windows.

    Raises:
        ValueError: If ``sample_gap`` is smaller than 1.
        AssertionError: If ``len(data) + horizon != len(targets)``.
    """

    def __init__(self, data, targets, horizon, lookback, sample_gap=1):
        self.data = data
        self.targets = targets
        self.horizon = horizon
        self.lookback = lookback
        self.sample_gap = sample_gap
        print(f"data shape: {data.shape}, targets shape: {targets.shape}")
        if sample_gap < 1:
            raise ValueError(f"sample_gap must be >= 1, got {sample_gap}")
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        if len(data) + horizon != len(targets):
            raise AssertionError(
                f"expected len(data) + horizon == len(targets), "
                f"got {len(data)} + {horizon} != {len(targets)}"
            )

    @property
    def x_shape(self):
        """Shape of a single input window (first ``lookback`` rows of ``data``)."""
        return self.data[:self.lookback, :].shape

    def __len__(self):
        """Number of windows; 0 when ``data`` has fewer rows than ``lookback``."""
        n_starts = len(self.data) - self.lookback + 1
        if n_starts <= 0:
            # A negative value here would make ``len()`` raise ValueError.
            return 0
        return int(math.ceil(n_starts / self.sample_gap))

    @staticmethod
    def _to_float_tensor(value):
        """Copy ``value`` into a new float32 tensor."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(tensor) warns about copy-construction; copy explicitly.
            return value.detach().to(dtype=torch.float32, copy=True)
        return torch.tensor(value, dtype=torch.float32)

    def __getitem__(self, index):
        """Return ``(window, target)`` for ``index``; negative indices count from the end.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self)
        if index < 0:
            index += n_items
        if index < 0 or index >= n_items:
            raise IndexError(f"index out of range for dataset of length {n_items}")

        # use sample_gap to increase the ratio of effective samples
        start = index * self.sample_gap
        stop = start + self.lookback
        sample = self.data[start:stop, :]
        target = self.targets[stop + self.horizon - 1]
        return self._to_float_tensor(sample), self._to_float_tensor(target)


def compute_target(df, horizon, steps=1, rolling_norm=False, target_mu=None, target_std=None):
    """Compute ``horizon``-step simple returns of the mid price.

    The return at position ``t >= horizon`` is
    ``mid_price[t] / mid_price[t - horizon] - 1``; the first ``horizon``
    positions are left at zero.

    Args:
        df: Either a ``pandas.DataFrame`` with a ``'u2_Mid-Price_1'`` column, or
            a ``numpy.ndarray`` whose row 41 holds the mid price
            (features x time layout).
        horizon: Positive number of steps between the two prices of a return.
        steps: When greater than 1, the returned tensor holds sliding windows
            of ``steps`` consecutive returns (shape ``(n - steps + 1, steps)``).
        rolling_norm: When true, the returns are passed through
            ``normalize_zscore(..., rolling=True)`` first.
        target_mu: Forwarded to ``normalize_zscore`` when ``rolling_norm`` is set.
        target_std: Forwarded to ``normalize_zscore`` when ``rolling_norm`` is set.

    Returns:
        ``(multistep_ret, ret)`` or, when ``rolling_norm`` is true,
        ``(multistep_ret, ret, mu, std)``. ``multistep_ret`` is a torch tensor
        built from ``ret``.

    Raises:
        ValueError: If ``horizon`` is smaller than 1.
        TypeError: If ``df`` is neither a DataFrame nor an ndarray.
    """
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon}")

    if isinstance(df, pd.DataFrame):
        mid_price = np.asarray(df.loc[:, 'u2_Mid-Price_1'])
    elif isinstance(df, np.ndarray):
        # NOTE(review): row 41 is presumably the level-1 mid price of the
        # FI-2010 feature layout - confirm against the data description.
        mid_price = df[41, :].transpose()
    else:
        raise TypeError(
            f"df must be a pandas DataFrame or a numpy ndarray, got {type(df).__name__}"
        )

    # Returns are fractional: keep a floating input dtype, otherwise use
    # float64 so integer prices are not truncated on assignment.
    ret_dtype = mid_price.dtype if mid_price.dtype.kind == "f" else np.float64
    ret = np.zeros(mid_price.shape, dtype=ret_dtype)
    ret[horizon:] = (mid_price[horizon:] / mid_price[:-horizon]) - 1

    if rolling_norm:
        # NOTE(review): ``ret`` is a 1-D ndarray here, while the rolling branch
        # of normalize_zscore calls ``.all(axis=1)`` and ``.rolling`` on its
        # input - this path looks like it cannot work as written; confirm.
        ret, mu, std = normalize_zscore(ret, rolling=True, target_mu=target_mu, target_std=target_std)

    multistep_ret = torch.tensor(ret)
    if steps > 1:
        print(f"multistep_ret shape: {multistep_ret.shape}")
        # Sliding windows of ``steps`` consecutive returns along the time axis.
        multistep_ret = multistep_ret.unfold(dimension=0, size=steps, step=1)
        print(f"multistep_ret shape: {multistep_ret.shape}")

    if rolling_norm:
        return multistep_ret, ret, mu, std
    return multistep_ret, ret

def normalize_zscore(dataset, norm_windows=200, target_mu=None, target_std=None, rolling=True):
    """Z-score ``dataset`` with rolling-window or fixed statistics.

    Rolling mode (``rolling=True``) expects a ``pandas.DataFrame``:
      * rows containing a zero in any column are dropped;
      * each remaining row is normalised with the mean and standard deviation
        of the ``norm_windows + 1`` rows ending at that row;
      * the first ``norm_windows`` remaining rows, which have no complete
        window, are dropped from the result.
    ``target_mu`` and ``target_std`` are not used in this mode and are
    returned exactly as passed in (``None`` by default).

    Fixed mode (``rolling=False``) computes
    ``(dataset - target_mu) / (target_std + 1e-8)``, taking any statistic
    that is ``None`` from ``dataset.mean()`` / ``dataset.std()``.

    The input object is never modified.

    Args:
        dataset: Data to normalise (a DataFrame is required in rolling mode).
        norm_windows: Rolling mode only; the window spans ``norm_windows + 1`` rows.
        target_mu: Fixed mode only; mean to subtract.
        target_std: Fixed mode only; standard deviation to divide by.
        rolling: Selects rolling mode (default) or fixed mode.

    Returns:
        ``(normalized, target_mu, target_std)``.
    """
    if rolling:
        # Keep only rows in which every column is non-zero.
        nonzero_rows = dataset[(dataset != 0.).all(axis=1)]
        window = nonzero_rows.rolling(norm_windows + 1)
        # Build the result as a new frame instead of assigning into the
        # filtered copy: no chained-assignment warning, and no float values
        # written into integer columns.
        normalized = (nonzero_rows - window.mean()) / (window.std() + 1e-8)
        # Positions before ``norm_windows`` lack a full window (NaN statistics).
        dataset = normalized.iloc[norm_windows:]
    else:
        if target_mu is None:
            target_mu = dataset.mean()
        if target_std is None:
            target_std = dataset.std()
        dataset = (dataset - target_mu) / (target_std + 1e-8)

    return dataset, target_mu, target_std

def _zscore_columns(values):
    """Z-score ``values`` along axis 0 using its own mean and std."""
    return (values - values.mean(axis=0)) / (values.std(axis=0) + 1e-8)


def _stack_features_and_return(frame, num_features, horizon, steps):
    """Build ``(X, y)`` for one split: leading features plus the return column, and targets."""
    y, ret = compute_target(frame, horizon, steps=steps)
    if isinstance(frame, pd.DataFrame):
        features = frame.iloc[:, :num_features]
    else:
        # ndarray splits are laid out features x time.
        features = frame[:num_features, :].transpose()
    return np.column_stack((features, ret)), y


def _trim_split(X, y, horizon, steps):
    """Drop the first ``horizon`` rows and trim the tail of ``X`` so ``len(X) + horizon == len(y)``."""
    return X[horizon:-horizon - steps + 1], y[horizon:]


def _build_loader(X, y, horizon, lookback, sample_gap, batch_size, shuffle, generator, num_workers):
    """Wrap one split in a ``LOB_Dataset`` and a ``DataLoader`` (incomplete batches are dropped)."""
    split_dataset = LOB_Dataset(X, y, horizon, lookback, sample_gap=sample_gap)
    print(split_dataset.x_shape, batch_size)
    return DataLoader(split_dataset, batch_size=batch_size, shuffle=shuffle, generator=generator,
                      drop_last=True, num_workers=num_workers)


def load_data(dataset, num_features, horizon, lookback, batch_size, small=False, is_transformer=True, seed=1, rolling_norm=False, steps=1, sample_gap=1, num_workers=4):
    """Load a dataset and return ``(train_loader, val_loader, test_loader)``.

    Args:
        dataset: One of ``"FI"``, ``"CHF"`` or ``"synthetic"``.
        num_features: Number of leading feature rows/columns kept as model
            input (FI and CHF); the computed return is appended as one extra
            column.
        horizon: Prediction horizon used for the targets and for windowing.
        lookback: Window length passed to ``LOB_Dataset``.
        batch_size: Batch size of all three loaders.
        small: CHF only. Keep the first fifth of the rows and z-score them
            with the statistics of that subset.
        is_transformer: Accepted for interface compatibility; not used here.
        seed: Seed of the ``torch.Generator`` shared by the loaders.
        rolling_norm: CHF only. Selects the rolling-normalisation path.
        steps: Number of consecutive returns per target (FI and CHF).
        sample_gap: Stride between consecutive windows.
        num_workers: Worker processes per ``DataLoader``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles; all three drop the last incomplete batch.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Fail before any I/O. The original final guard
    # (``'FI' in dataset or 'CHF' or ...``) was always true, so unknown names
    # only surfaced later as an UnboundLocalError.
    if dataset not in ("FI", "CHF", "synthetic"):
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected 'FI', 'CHF' or 'synthetic'"
        )

    if dataset == "FI":
        # Directory names spell 'Zscore', file names spell 'ZScore'.
        fi_root = 'data/FI-2010/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore'
        train_file = fi_root + '/NoAuction_Zscore_Training/Train_Dst_NoAuction_ZScore_CF_7.txt'
        test_files = [
            fi_root + '/NoAuction_Zscore_Testing/Test_Dst_NoAuction_ZScore_CF_{}.txt'.format(fold)
            for fold in range(7, 10)
        ]

        out_df = np.loadtxt(train_file)

        # Columns are time steps: the first 80% train, the remainder validate.
        n_samples_train = int(np.floor(out_df.shape[1] * 0.8))
        train_df = out_df[:, :n_samples_train]
        val_df = out_df[:, n_samples_train:]
        test_df = np.hstack([np.loadtxt(path) for path in test_files])

        train_X, train_y = _stack_features_and_return(train_df, num_features, horizon, steps)
        val_X, val_y = _stack_features_and_return(val_df, num_features, horizon, steps)
        test_X, test_y = _stack_features_and_return(test_df, num_features, horizon, steps)

        train_X, train_y = _trim_split(train_X, train_y, horizon, steps)
        val_X, val_y = _trim_split(val_X, val_y, horizon, steps)
        test_X, test_y = _trim_split(test_X, test_y, horizon, steps)

    elif dataset == "CHF":
        # NOTE(review): both pickle paths are empty strings in this file and
        # must be filled in before this branch can run. read_pickle can
        # execute arbitrary code - only load trusted files.
        print('reading labels')
        raw_ver2_labels = pd.read_pickle("")
        print('reading whole dataframe')
        raw_df = pd.read_pickle("")
        print('done')
        raw_df[raw_ver2_labels.columns] = raw_ver2_labels
        del raw_ver2_labels

        if small:
            raw_df = raw_df.iloc[:len(raw_df)//5]
            means = raw_df.mean(axis=0)
            stds = raw_df.std(axis=0)
            raw_df = (raw_df - means) / stds
            print("normalized data to account for 20%")
        print(f'Length of CHF dataframe used {len(raw_df)}')

        feature_keys = ['u1_BidPrice1', 'u1_AskPrice1', 'u1_BidVolume1', 'u1_AskVolume1',
                        'u1_BidPrice2', 'u1_AskPrice2', 'u1_BidVolume2', 'u1_AskVolume2',
                        'u1_BidPrice3', 'u1_AskPrice3', 'u1_BidVolume3', 'u1_AskVolume3',
                        'u1_BidPrice4', 'u1_AskPrice4', 'u1_BidVolume4', 'u1_AskVolume4',
                        'u1_BidPrice5', 'u1_AskPrice5', 'u1_BidVolume5', 'u1_AskVolume5']

        # Chronological split: 80% train+val (itself split 80/20), rest test.
        split = 0.8
        train_val_df = raw_df.iloc[:int(split * len(raw_df))]
        n_samples_train = int(np.floor(len(train_val_df) * 0.8))

        train_df = train_val_df.iloc[:n_samples_train]
        val_df = train_val_df.iloc[n_samples_train:]
        test_df = raw_df.iloc[int(split * len(raw_df)):-10]
        if rolling_norm:
            # NOTE(review): after the three calls below the frames hold only
            # ``feature_keys``, which does not contain 'u2_Mid-Price_1', the
            # column compute_target reads - this path looks broken; confirm.
            # Also, normalize_zscore returns the statistics it was given in
            # rolling mode, so train_mu/train_std are None here and val/test
            # fall back to their own mean/std.
            train_df, train_mu, train_std = normalize_zscore(train_df[feature_keys], norm_windows=200, rolling=True)
            val_df, _, _ = normalize_zscore(val_df[feature_keys], target_mu=train_mu, target_std=train_std, norm_windows=200, rolling=False)
            test_df, _, _ = normalize_zscore(test_df[feature_keys], target_mu=train_mu, target_std=train_std, norm_windows=200, rolling=False)
            train_y, train_ret, train_mu, train_std = compute_target(train_df, horizon, steps=steps, rolling_norm=rolling_norm)
            val_y, val_ret, _, _ = compute_target(val_df, horizon, steps=steps, rolling_norm=rolling_norm, target_mu=train_mu, target_std=train_std)
            test_y, test_ret, _, _ = compute_target(test_df, horizon, steps=steps, rolling_norm=rolling_norm, target_mu=train_mu, target_std=train_std)
            train_X = np.column_stack((train_df.iloc[:, :num_features], train_ret))
            val_X = np.column_stack((val_df.iloc[:, :num_features], val_ret))
            test_X = np.column_stack((test_df.iloc[:, :num_features], test_ret))
        else:
            train_X, train_y = _stack_features_and_return(train_df, num_features, horizon, steps)
            val_X, val_y = _stack_features_and_return(val_df, num_features, horizon, steps)
            test_X, test_y = _stack_features_and_return(test_df, num_features, horizon, steps)
        del raw_df

        print(f"train_df shape: {train_df.shape}, val_df shape: {val_df.shape}, test_df shape: {test_df.shape}")
        print(f"train_y shape: {train_y.shape}, val_y shape: {val_y.shape}, test_y shape: {test_y.shape}")

        print(f"train_X shape: {train_X.shape}")
        train_X, train_y = _trim_split(train_X, train_y, horizon, steps)
        val_X, val_y = _trim_split(val_X, val_y, horizon, steps)
        test_X, test_y = _trim_split(test_X, test_y, horizon, steps)

    else:  # dataset == "synthetic"
        data_dir = "data/synthetic/"
        train_df = pd.read_csv(data_dir + "/synthetic_time_series_train2.csv", index_col=0)
        val_df = pd.read_csv(data_dir + "/synthetic_time_series_val2.csv", index_col=0)
        test_df = pd.read_csv(data_dir + "/synthetic_time_series_test2.csv", index_col=0)

        # Targets are the 'price' column; inputs drop the last ``horizon``
        # rows so that len(X) + horizon == len(y), as LOB_Dataset requires.
        # Every split (inputs and targets) is z-scored with its own statistics.
        train_y = _zscore_columns(train_df['price'].values)
        val_y = _zscore_columns(val_df['price'].values)
        test_y = _zscore_columns(test_df['price'].values)
        train_X = _zscore_columns(train_df.iloc[:-horizon].values)
        val_X = _zscore_columns(val_df.iloc[:-horizon].values)
        test_X = _zscore_columns(test_df.iloc[:-horizon].values)

    generator = torch.Generator()
    generator.manual_seed(seed)

    train_loader = _build_loader(train_X, train_y, horizon, lookback, sample_gap, batch_size,
                                 True, generator, num_workers)
    val_loader = _build_loader(val_X, val_y, horizon, lookback, sample_gap, batch_size,
                               False, generator, num_workers)
    test_loader = _build_loader(test_X, test_y, horizon, lookback, sample_gap, batch_size,
                                False, generator, num_workers)
    gc.collect()

    return train_loader, val_loader, test_loader