import os
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

import pytorch_lightning as pl
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from tqdm import tqdm

# Constants
# Default hyper-parameters for the experiment. main() exposes the learning
# rate, batch size and epoch count as CLI flags and forwards the remaining
# values unchanged to run_experiment().
LEARNING_RATE = 0.0001  # default for --learning_rate (Adam lr)
BATCH_SIZE = 64 * 2  # default for --batch_size
EPOCHS = 15  # default for --epochs
WEIGHT_DECAY = 0  # passed to Adam as weight_decay
GAMMA = None  # StepLR gamma; only read when SCHEDULER_FLAG is truthy
STEP_SIZE = None  # StepLR step_size; only read when SCHEDULER_FLAG is truthy
SCHEDULER_FLAG = False  # when truthy, train() creates a StepLR scheduler
SCHEDULER_STEP_FLAG = False  # when truthy, train() steps the scheduler once per epoch
LOSS_TYPE = 'MSE'  # 'MSE' or 'L1'; selects the training criterion in train()
LOSS_CALCULATION_MP_FLAG = True  # only appears in the output file-name prefix; train() does not read it


class LOB_Dataset(Dataset):
    """Sliding-window dataset over a feature matrix and an aligned target vector.

    Item ``i`` pairs the window ``data[i : i + lookback, :]`` with the scalar
    ``targets[i + lookback]``; both are returned as ``float32`` tensors.

    Args:
        data: 2-D array indexed as ``data[row_slice, :]``; rows are consecutive
            observations and columns are features.
        targets: 1-D sequence indexed by the same row positions as ``data``.
        horizon: Prediction horizon. It is only used to shorten the dataset so
            that trailing rows are never served.
        lookback: Number of consecutive rows in each input window.
    """

    def __init__(self, data, targets, horizon, lookback):
        self.data = data
        self.targets = targets
        self.horizon = horizon
        self.lookback = lookback

    def __len__(self):
        # Clamp at zero: a series shorter than lookback + horizon + 1 would
        # otherwise produce a negative length, which makes ``len()`` raise an
        # opaque "__len__() should return >= 0" error.
        return max(0, len(self.targets) - self.lookback - self.horizon - 1)

    def __getitem__(self, index):
        window_end = index + self.lookback
        sample = torch.tensor(self.data[index:window_end, :], dtype=torch.float32)
        # The target is read at the first row *after* the window.
        target = torch.tensor(self.targets[window_end], dtype=torch.float32)
        return sample, target


def compute_target(df, horizon):
    """Compute the trailing ``horizon``-row simple return of the mid-price.

    Entry ``t`` (for ``t >= horizon``) is
    ``mid_price[t] / mid_price[t - horizon] - 1``. The first ``horizon``
    entries have no earlier price to compare against and keep the placeholder
    value 1.

    NOTE(review): this is a backward-looking return, whereas the "FI" branch of
    ``load_data`` builds a forward-looking one — confirm which alignment is
    intended.

    Args:
        df: DataFrame containing the column ``'u2_Mid-Price_1'``.
        horizon: Positive number of rows between the two prices.

    Returns:
        A float NumPy array with one entry per row of ``df``.

    Raises:
        ValueError: If ``horizon`` is smaller than 1.
    """
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon}")

    # Work in floating point explicitly: allocating the output with the
    # price dtype would silently truncate returns for integer-valued prices.
    mid_price = df.loc[:, 'u2_Mid-Price_1'].to_numpy(dtype=float)
    ret = np.ones(mid_price.shape, dtype=float)
    ret[horizon:] = (mid_price[horizon:] / mid_price[:-horizon]) - 1
    return ret


def _forward_return(mid_price, horizon):
    """Return the ``horizon``-step forward simple return of a 1-D price array.

    Entry ``t`` holds ``mid_price[t + horizon] / mid_price[t] - 1``; the last
    ``horizon`` entries have no future price and keep the placeholder value 1.
    """
    ret = np.ones_like(mid_price)
    ret[:-horizon] = (mid_price[horizon:] / mid_price[:-horizon]) - 1
    return ret


def load_data(dataset, data_dir, num_features, horizon, lookback, batch_size):
    """Build shuffled train/validation/test ``DataLoader`` objects for a dataset.

    Args:
        dataset: ``"FI"`` (text files laid out under ``data_dir``) or ``"CHF"``
            (a single pickled DataFrame whose path is ``data_dir``).
        data_dir: Root directory of the FI files, or the pickle path for CHF.
        num_features: Number of leading feature rows (FI) or columns (CHF)
            used as model input.
        horizon: Return horizon, in rows, used to build the regression target.
        lookback: Window length handed to ``LOB_Dataset``.
        batch_size: Batch size of all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``; all three shuffle and
        share one generator seeded with 42.

    Raises:
        ValueError: If ``dataset`` is neither ``"FI"`` nor ``"CHF"``.
    """
    if dataset == "FI":
        auction = 'NoAuction'
        base_dir = f"{data_dir}/{auction}/1.{auction}_Zscore"

        # Fold-7 training file. Columns are samples and rows are features;
        # the first 80% of the columns train the model, the rest validate it.
        train_file = f"{base_dir}/{auction}_Zscore_Training/Train_Dst_{auction}_Zscore_CF_7.txt"
        out_df = np.loadtxt(train_file)

        n_samples_train = int(np.floor(out_df.shape[1] * 0.8))
        train_df = out_df[:, :n_samples_train]
        val_df = out_df[:, n_samples_train:]

        # Test files for folds 7-9, concatenated along the sample axis.
        # The directory name is spelled "Zscore" but the file names "ZScore".
        test_dir = f"{base_dir}/{auction}_Zscore_Testing"
        test_files = [
            f"{test_dir}/Test_Dst_{auction}_ZScore_CF_{i}.txt"
            for i in range(7, 10)
        ]
        test_df = np.hstack([np.loadtxt(path) for path in test_files])

        # NOTE(review): this drops as many leading test columns as there are
        # training columns, although the test data comes from separate files.
        # If the test files hold fewer samples than the training split, the
        # test set ends up empty — confirm this slice is intended.
        test_df = test_df[:, train_df.shape[1]:]

        train_X = train_df[:num_features, :].transpose()
        val_X = val_df[:num_features, :].transpose()
        test_X = test_df[:num_features, :].transpose()

        # Label: forward return of row 41, which the code treats as the
        # mid-price.
        # NOTE(review): the file names indicate Z-score normalised data;
        # ratios of normalised prices may not be meaningful returns — confirm.
        train_y = _forward_return(train_df[41, :], horizon)
        val_y = _forward_return(val_df[41, :], horizon)
        test_y = _forward_return(test_df[41, :], horizon)

    elif dataset == "CHF":
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        raw_df = pd.read_pickle(data_dir)
        split = 0.8
        n_train_val = int(split * len(raw_df))
        train_val_df = raw_df.iloc[:n_train_val]
        n_samples_train = int(np.floor(len(train_val_df) * 0.8))

        train_df = train_val_df.iloc[:n_samples_train]
        val_df = train_val_df.iloc[n_samples_train:]
        # The last 10 rows of the frame are excluded from the test split.
        test_df = raw_df.iloc[n_train_val:-10]

        # Feature matrices were previously never built for this branch, so it
        # failed with a NameError when the datasets were constructed.
        # NOTE(review): assumes the first ``num_features`` columns of the
        # pickle are the model inputs (mirroring the FI branch) — confirm
        # against the pickle's schema.
        train_X = train_df.iloc[:, :num_features].to_numpy(dtype=np.float64)
        val_X = val_df.iloc[:, :num_features].to_numpy(dtype=np.float64)
        test_X = test_df.iloc[:, :num_features].to_numpy(dtype=np.float64)

        train_y = compute_target(train_df, horizon)
        val_y = compute_target(val_df, horizon)
        test_y = compute_target(test_df, horizon)

    else:
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'FI' or 'CHF'")

    # One seeded generator drives the shuffling of all three loaders.
    generator = torch.Generator()
    generator.manual_seed(42)

    train_dataset = LOB_Dataset(train_X, train_y, horizon, lookback)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=generator)

    val_dataset = LOB_Dataset(val_X, val_y, horizon, lookback)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, generator=generator)

    test_dataset = LOB_Dataset(test_X, test_y, horizon, lookback)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, generator=generator)

    return train_loader, val_loader, test_loader


class CNNLSTM(pl.LightningModule):
    """Convolutional front-end followed by an LSTM and a two-layer MLP head.

    The input ``[batch, seq_len, num_features]`` gets a channel axis, passes
    through one 2-D and three 1-D convolutions (each with batch norm and
    PReLU), and the resulting ``[batch, 32, length]`` tensor is fed to a
    batch-first LSTM, i.e. the 32 channels act as the LSTM's sequence axis and
    ``length`` as its feature size. The last layer's final hidden state goes
    through ``fc1 -> dropout -> PReLU -> fc2``.

    Args:
        num_features: Number of feature columns per snapshot. The first
            convolution uses a width-42 kernel with 2 columns of padding per
            side, so at least 38 features are required.
        num_classes: Size of the output layer.
        batch_size: Stored on the module; not needed for shape inference.
        seq_len: Number of snapshots per input window.
        hidden_size: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        hidden_mlp: Width of the hidden fully connected layer.
        p_dropout: Dropout probability applied after ``fc1``.
    """

    def __init__(self, num_features, num_classes, batch_size, seq_len, hidden_size, num_layers, hidden_mlp, p_dropout):
        super().__init__()

        self.num_features = num_features
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.num_layers = num_layers  # 1
        self.hidden_size = hidden_size  # 32
        self.hidden_mlp = hidden_mlp  # 32/64

        self.seq_len = seq_len  # number of snapshots (100)

        # Convolution 1: 5 snapshots x 42 feature columns per kernel, with the
        # feature axis zero-padded by 2 on each side.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(5, 42), padding=(0, 2))
        self.bn1 = nn.BatchNorm2d(16)
        self.prelu1 = nn.PReLU()

        # Convolution 2
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=16, kernel_size=(5,))
        self.bn2 = nn.BatchNorm1d(16)
        self.prelu2 = nn.PReLU()

        # Convolution 3
        self.conv3 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=(5,))
        self.bn3 = nn.BatchNorm1d(32)
        self.prelu3 = nn.PReLU()

        # Convolution 4
        self.conv4 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=(5,))
        self.bn4 = nn.BatchNorm1d(32)
        self.prelu4 = nn.PReLU()

        # The LSTM feature size equals the last dimension of the conv output,
        # which depends on both seq_len and num_features.
        self.lstm_input = self.get_lstm_input_size(num_features, seq_len)

        # lstm layers
        self.lstm = nn.LSTM(
            input_size=self.lstm_input,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )

        # fully connected
        self.fc1 = nn.Linear(hidden_size, hidden_mlp)
        self.dropout = nn.Dropout(p=p_dropout)
        self.prelu = nn.PReLU()

        self.fc2 = nn.Linear(hidden_mlp, self.num_classes)  # out layer

    def get_lstm_input_size(self, num_features, seq_len):
        """Infer the LSTM feature size by pushing a dummy input through the convs.

        The probe runs in eval mode so the all-ones dummy batch does not
        update the BatchNorm running statistics; the previous training/eval
        mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # [batch, channel, seq_len, num_features]; a batch of one is
                # enough because the output length is independent of batch size.
                sample_in = torch.ones(1, 1, seq_len, num_features)
                sample_out = self.convolution_forward(sample_in)
        finally:
            self.train(was_training)

        return sample_out.shape[-1]

    def forward(self, x):
        """Map ``x`` of shape ``[batch, seq_len, num_features]`` to ``[batch, num_classes]``."""
        # Adding the channel dimension
        x = x[:, None, :]  # x.shape = [batch_size, 1, seq_len, num_features]

        out = self.convolution_forward(x)

        # lstm: hn has shape [num_layers, batch, hidden_size]
        _, (hn, _) = self.lstm(out)

        # Keep only the last layer's final hidden state. Flattening all layers
        # with view(-1, hidden_size) would mix layers into the batch axis
        # whenever num_layers > 1; for a single layer the result is identical.
        hn = hn[-1]

        out = self.fc1(hn)
        out = self.dropout(out)
        out = self.prelu(out)

        out = self.fc2(out)

        return out

    def convolution_forward(self, x):
        """Apply the four conv/batch-norm/PReLU stages.

        Args:
            x: Tensor of shape ``[batch, 1, seq_len, num_features]``.

        Returns:
            Tensor of shape ``[batch, 32, length]``.
        """
        # Convolution 1 -> [batch, 16, seq_len - 4, num_features - 37]
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.prelu1(out)
        # Merge the two spatial axes so the remaining stages are 1-D convs.
        out = out.reshape(out.shape[0], out.shape[1], -1)

        # Convolution 2
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.prelu2(out)

        # Convolution 3
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.prelu3(out)

        # Convolution 4
        out = self.conv4(out)
        out = self.bn4(out)
        out = self.prelu4(out)

        return out


def evaluate_predictions(y_true, y_pred):
    """Score predictions against targets.

    Both arguments are tensors; they are detached, moved to the CPU and
    converted to NumPy before scoring.

    Returns:
        ``(mse, corr, r2)``: mean squared error, the correlation coefficient
        between the flattened arrays, and the R2 score.
    """
    truth, preds = (t.detach().cpu().numpy() for t in (y_true, y_pred))

    mse = mean_squared_error(truth, preds)
    # Off-diagonal entry of the 2x2 correlation matrix.
    corr_matrix = np.corrcoef(truth.flatten(), preds.flatten())
    corr = corr_matrix[0, 1]
    r2 = r2_score(truth, preds)
    return mse, corr, r2


def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and report MSE, correlation and R2.

    The model is moved to CUDA when available (CPU otherwise) and switched to
    eval mode. Predictions are gathered without gradient tracking.

    Args:
        model: Callable module mapping a batch of inputs to predictions.
        test_loader: Iterable yielding ``(inputs, targets)`` tensor batches.

    Returns:
        ``(mse, corr, r2)`` as computed by ``evaluate_predictions``. Earlier
        versions returned ``None``; callers that ignore the result are
        unaffected.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs.to(device))
            # Accumulate on the CPU so the whole test set is never held in
            # accelerator memory; the metrics are computed on the CPU anyway.
            all_predictions.append(outputs.cpu())
            all_targets.append(targets.cpu())

    all_predictions = torch.cat(all_predictions, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    mse, corr, r2 = evaluate_predictions(all_targets, all_predictions)
    print(f"Test - MSE: {mse}, Correlation: {corr}, R2: {r2}")
    return mse, corr, r2


def _evaluate_epoch(model, loader, criterion, device):
    """Run ``model`` over ``loader`` without gradients and summarise the results.

    Returns:
        ``(avg_loss, mse, r2, corr)`` where ``avg_loss`` is the mean of the
        per-batch criterion values and the other three are computed over all
        samples of the loader.
    """
    total_loss = 0.0
    all_targets, all_outputs = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            targets = targets.unsqueeze(1)
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item()
            all_targets.extend(targets.cpu().numpy())
            all_outputs.extend(outputs.cpu().numpy())

    avg_loss = total_loss / len(loader)
    mse = mean_squared_error(all_targets, all_outputs)
    r2 = r2_score(all_targets, all_outputs)
    corr, _ = pearsonr(np.array(all_targets).flatten(), np.array(all_outputs).flatten())
    return avg_loss, mse, r2, corr


def train(train_loader, val_loader, num_features, lr, decay, num_epochs, scheduler_flag=None, scheduler_step_flag=None, step_size=None, gamma=None, loss_type='MSE', loss_calculation_mp_flag=True, file_name_prefix=None):
    """Train a ``CNNLSTM`` regressor and plot the per-epoch losses.

    After every epoch the model is evaluated (in eval mode, without gradients)
    on the validation loader and then on the training loader, and MSE, R2 and
    Pearson correlation are printed for both.

    Args:
        train_loader: Loader of ``(inputs, targets)`` batches used for fitting.
        val_loader: Loader used for the per-epoch validation metrics.
        num_features: Number of input features per snapshot.
        lr: Adam learning rate.
        decay: Adam weight decay.
        num_epochs: Number of training epochs.
        scheduler_flag: When truthy, create a ``StepLR`` scheduler from
            ``step_size`` and ``gamma``.
        scheduler_step_flag: When truthy (and a scheduler exists), step the
            scheduler once per epoch.
        step_size: ``StepLR`` step size.
        gamma: ``StepLR`` decay factor.
        loss_type: ``'MSE'`` or ``'L1'``.
        loss_calculation_mp_flag: Accepted for interface compatibility; not
            used by this function.
        file_name_prefix: When given, the loss plot is saved to
            ``"<prefix>_loss.png"``.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``loss_type`` is not ``'MSE'`` or ``'L1'``.
    """
    if loss_type == 'MSE':
        criterion = torch.nn.MSELoss()
    elif loss_type == 'L1':
        criterion = torch.nn.L1Loss()
    else:
        raise ValueError(f"Unsupported loss_type {loss_type!r}; expected 'MSE' or 'L1'")

    model = CNNLSTM(
            num_features=num_features,
            num_classes=1,
            batch_size=1,
            # Window length is read off the first training sample.
            seq_len=train_loader.dataset[0][0].shape[0],
            hidden_size=32,
            num_layers=1,
            hidden_mlp=128,
            p_dropout=0,
        )

    train_losses, valid_losses = [], []

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)

    # Only build a scheduler when requested; stepping is guarded below so that
    # scheduler_step_flag without scheduler_flag cannot hit an undefined name.
    scheduler = None
    if scheduler_flag:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    epochs = num_epochs
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        model.train()
        total_loss = 0
        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}"):
            # Targets arrive as [batch]; the model emits [batch, 1].
            targets = targets.unsqueeze(1)
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if scheduler_step_flag and scheduler is not None:
            scheduler.step()
        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        model.eval()
        # Validation first, then the training set, both in eval mode. The
        # validation loss is now actually accumulated (it used to stay at 0).
        avg_valid_loss, val_mse, val_r2, val_corr = _evaluate_epoch(model, val_loader, criterion, device)
        valid_losses.append(avg_valid_loss)

        _, trn_mse, trn_r2, trn_corr = _evaluate_epoch(model, train_loader, criterion, device)

        print(f"Epoch {epoch + 1}, Training Loss: {avg_train_loss}, Validation Loss: {avg_valid_loss}")
        print(f"Training - MSE: {trn_mse}, R2: {trn_r2}, Corr: {trn_corr}")
        print(f"Validation - MSE: {val_mse}, R2: {val_r2}, Corr: {val_corr}")

    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(valid_losses, label='Validation Loss')
    plt.title('Training and Validation Loss per Epoch')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    if file_name_prefix:
        plt.savefig(f"{file_name_prefix}_loss.png")
    plt.show()
    return model


def run_experiment(train_loader, val_loader, test_loader, num_features, PRED_LEN, LEARNING_RATE, WEIGHT_DECAY, BATCH_SIZE, EPOCHS, LOSS_TYPE, LOSS_CALCULATION_MP_FLAG, SCHEDULER_FLAG=False, SCHEDULER_STEP_FLAG=False, STEP_SIZE=None, GAMMA=None):
    """Seed the RNGs, train a model, evaluate it on the test set and save its weights.

    The hyper-parameters are encoded into a prefix that names both the loss
    plot written by ``train`` and the ``model_weights_<prefix>.pth`` file
    written here. ``PRED_LEN`` and ``BATCH_SIZE`` only contribute to that
    prefix.
    """
    # Fixed seed for torch (CPU and CUDA) and NumPy.
    rng_seed = 1
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
        seed_fn(rng_seed)

    run_tag = f'pred{PRED_LEN}_lr{LEARNING_RATE}_wd{WEIGHT_DECAY}_bs{BATCH_SIZE}_ep{EPOCHS}_loss{LOSS_TYPE}_mp{LOSS_CALCULATION_MP_FLAG}'
    if SCHEDULER_FLAG:
        # Scheduler settings are appended only when a scheduler is in use.
        run_tag = run_tag + f'_sch{SCHEDULER_FLAG}_ss{STEP_SIZE}_g{GAMMA}'

    train_options = dict(
        num_features=num_features,
        lr=LEARNING_RATE,
        decay=WEIGHT_DECAY,
        num_epochs=EPOCHS,
        scheduler_flag=SCHEDULER_FLAG,
        scheduler_step_flag=SCHEDULER_STEP_FLAG,
        step_size=STEP_SIZE,
        gamma=GAMMA,
        loss_type=LOSS_TYPE,
        loss_calculation_mp_flag=LOSS_CALCULATION_MP_FLAG,
        file_name_prefix=run_tag,
    )
    trained_model = train(train_loader, val_loader, **train_options)

    test(trained_model, test_loader)

    weights_path = f'model_weights_{run_tag}.pth'
    torch.save(trained_model.state_dict(), weights_path)


def main():
    """Parse the command line, build the data loaders and run one experiment.

    The learning rate, batch size and epoch count default to the module-level
    constants; the remaining hyper-parameters are taken from those constants
    directly.
    """
    parser = argparse.ArgumentParser(description="Run LOB prediction experiment")
    # Restrict to the datasets load_data() knows, so a typo is rejected by
    # argparse instead of failing later inside the loader.
    parser.add_argument("--dataset", type=str, required=True, choices=("FI", "CHF"), help="Dataset to use (FI or CHF)")
    # For CHF the value is handed to pandas.read_pickle, i.e. it is a file path.
    parser.add_argument("--data_dir", type=str, required=True, help="FI: root directory of the data; CHF: path to the pickled DataFrame")
    parser.add_argument("--num_features", type=int, default=40, help="Number of features to use")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Batch size")
    parser.add_argument("--lookback", type=int, default=100, help="Lookback period")
    parser.add_argument("--horizon", type=int, default=1, help="Prediction horizon")
    parser.add_argument("--learning_rate", type=float, default=LEARNING_RATE, help="Learning rate")
    parser.add_argument("--epochs", type=int, default=EPOCHS, help="Number of epochs")
    args = parser.parse_args()

    train_loader, val_loader, test_loader = load_data(args.dataset, args.data_dir, args.num_features, args.horizon, args.lookback, args.batch_size)
    run_experiment(train_loader, val_loader, test_loader, args.num_features, args.horizon, args.learning_rate, WEIGHT_DECAY, args.batch_size, args.epochs, LOSS_TYPE, LOSS_CALCULATION_MP_FLAG, SCHEDULER_FLAG, SCHEDULER_STEP_FLAG, STEP_SIZE, GAMMA)


# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()