import math
import os
import pandas as pd
import pathlib
import torch
import torchcde

from . import common


# Absolute directory of this source file (symlinks resolved); root for the data paths used in this module.
_here = pathlib.Path(__file__).resolve().parents[0]


# LOBSTER limit order book data; the message/orderbook file layout is documented at
# https://lobsterdata.com/info/DataStructure.php


# Placeholder prices LOBSTER writes for an empty price level (raw integer units, before rescaling).
_LOBSTER_EMPTY_ASK = 9999999999
_LOBSTER_EMPTY_BID = -9999999999


def _load_lobster_day(message_path, orderbook_path):
    """Read one trading day of LOBSTER data and return its best-quote features.

    Returns:
        ``(times, features)``: float64 tensors of shape ``(N,)`` and ``(N, 2)``. ``times`` is strictly
        increasing; ``features[:, 0]`` is the midpoint of the best ask and bid and ``features[:, 1]`` is the
        log of their spread (prices rescaled by ``1e6``).

    Raises:
        ValueError: if the two files have different row counts, the timestamps are not sorted, or fewer than
            two usable samples remain (linear interpolation needs at least two).
    """
    # LOBSTER CSVs have no header row; without header=None pandas would consume the first event of the day
    # as column names.
    messages = pd.read_csv(message_path, header=None)
    orderbook = pd.read_csv(orderbook_path, header=None)
    if len(messages) != len(orderbook):
        # Rows are paired one-to-one; zip() would silently truncate on a mismatch.
        raise ValueError(f"{message_path} has {len(messages)} rows but {orderbook_path} has {len(orderbook)}.")

    timestamps = torch.tensor(messages.iloc[:, 0].to_numpy(), dtype=torch.float64)
    # Orderbook column 0 is the best ask price and column 2 the best bid price.
    raw_ask = torch.tensor(orderbook.iloc[:, 0].to_numpy(), dtype=torch.float64)
    raw_bid = torch.tensor(orderbook.iloc[:, 2].to_numpy(), dtype=torch.float64)

    # Per the LOBSTER docs, orderbook row k is the book *after* message k, so among events sharing a
    # timestamp the last row is the book at that instant. One row per timestamp also gives the strictly
    # increasing times that interpolation requires.
    keep = torch.ones_like(timestamps, dtype=torch.bool)
    keep[:-1] = timestamps[1:] != timestamps[:-1]

    ask = raw_ask / 1e6
    bid = raw_bid / 1e6
    spread = ask - bid
    # Skip empty best levels (placeholder prices) and locked/crossed books, whose spread has no real log.
    keep &= (raw_ask != _LOBSTER_EMPTY_ASK) & (raw_bid != _LOBSTER_EMPTY_BID) & (spread > 0)

    times = timestamps[keep]
    if times.numel() < 2:
        raise ValueError(f"{message_path} has fewer than two usable samples.")
    if not bool((times[1:] > times[:-1]).all()):
        raise ValueError(f"Timestamps in {message_path} are not sorted.")
    midpoint = (ask[keep] + bid[keep]) * 0.5
    logspread = spread[keep].log()
    return times, torch.stack([midpoint, logspread], dim=-1)


def _resample_windows(times, features, samples_per_day, length, step):
    """Resample one day onto a uniform grid and cut it into overlapping windows.

    Returns a float32 tensor of shape ``(num_windows, length, channels)``.
    """
    # We resample on to a uniform grid.
    # Frankly, this is just as a convenience, so as to be able to write slightly simpler code (as our other
    # problems are regularly sampled). It's perfectly possible to do NSDEs directly on irregularly sampled
    # data as well. (Although it does require knowing a bit of CDE theory.)
    grid = torch.linspace(times[0].item(), times[-1].item(), samples_per_day, dtype=times.dtype)
    resampled = torchcde.LinearInterpolation(features, times).evaluate(grid).to(torch.float32)
    # unfold gives (num_windows, channels, length); put the time axis before the channel axis.
    return resampled.unfold(dimension=0, size=length, step=step).transpose(1, 2)


def stocks_data(batch_size):
    """Load the LOBSTER GOOGL limit order book dataset as fixed-length windows.

    On the first call the raw ``*message_10.csv`` / ``*orderbook_10.csv`` pairs under ``data/GOOGL`` are
    parsed. Each day is resampled onto ``samples_per_day`` uniformly spaced points and cut into windows of
    ``length`` samples starting every ``step`` samples. The result is cached as ``x.pt`` / ``y.pt`` under
    ``processed_data/stocks``, and later calls load that cache. Delete the cache to rebuild it.

    Args:
        batch_size: forwarded to ``common.dataloader``.

    Returns:
        ``(t, dataloader, input_channels, label_channels)``:
        - ``t`` is ``torch.linspace(0, length - 1, length)``;
        - ``dataloader`` is built by ``common.dataloader`` over ``TensorDataset(x, y)``, where ``x`` has
          shape ``(num_windows, length, 2)`` and ``y`` has zero columns;
        - ``input_channels`` is 2 (midpoint and log spread);
        - ``label_channels`` is 0.

    Raises:
        FileNotFoundError: if there is no cache and the raw data directory is missing or contains no
            message files.
        ValueError: if a day's raw files are inconsistent (see ``_load_lobster_day``).
    """
    raw_data = _here / 'data' / 'GOOGL'
    processed_data = _here / 'processed_data' / 'stocks'

    samples_per_day = 40000
    length = 100
    step = 35

    try:
        x = torch.load(processed_data / 'x.pt')
        y = torch.load(processed_data / 'y.pt')
    except FileNotFoundError:
        if not raw_data.exists():
            raise FileNotFoundError("Stocks data not available. This must be obtained from https://lobsterdata.com/.") from None
        suffix = 'message_10.csv'
        windows = []
        for filename in sorted(os.listdir(raw_data)):
            if not filename.endswith(suffix):
                continue
            orderbook_name = filename[:-len(suffix)] + 'orderbook_10.csv'
            times, features = _load_lobster_day(raw_data / filename, raw_data / orderbook_name)
            windows.append(_resample_windows(times, features, samples_per_day, length, step))
        if not windows:
            raise FileNotFoundError(f"No '*{suffix}' files found in {raw_data}.")
        x = torch.cat(windows)
        y = torch.empty(x.size(0), 0, dtype=x.dtype)  # unconditional: no labels

        processed_data.mkdir(parents=True, exist_ok=True)
        torch.save(x, processed_data / 'x.pt')
        torch.save(y, processed_data / 'y.pt')

    dataset = torch.utils.data.TensorDataset(x, y)
    dataloader = common.dataloader(dataset, batch_size)

    t = torch.linspace(0, length - 1, length)
    input_channels = 2  # midpoint and spread
    label_channels = 0  # unconditional GAN, not conditional
    return t, dataloader, input_channels, label_channels
