"""Dataset wrapper and dataloader factory for the stocks time-series data."""

from __future__ import annotations

from typing import Tuple, Dict

import torch
from torch.utils.data import Dataset, DataLoader

from ..collection import stocks as stocks_collection


class StocksDataset(Dataset):
    """Indexable view over paired ``X``/``Y`` tensors with per-item quantization.

    Every item fetched through ``__getitem__`` has its input snapped onto an
    evenly spaced grid between that item's own minimum and maximum value.
    The target is returned untouched.

    Args:
        X: Input tensor; indexed along its first dimension.
        Y: Target tensor; indexed along its first dimension in step with ``X``.
        modality_first: When ``False``, items are returned as ``(x, y)``.
            When ``True``, the last two axes of ``x`` are swapped and the
            result is returned as a list of the rows of the swapped tensor
            followed by the target (see ``__getitem__``).
    """

    # Number of equal steps between an item's min and max used for quantization.
    _QUANT_LEVELS = 25

    def __init__(self, X: torch.Tensor, Y: torch.Tensor, modality_first: bool = True):
        self.X = X
        self.Y = Y
        self.modality_first = modality_first

    def __len__(self):
        return len(self.X)

    @classmethod
    def _quantize(cls, x: torch.Tensor) -> torch.Tensor:
        """Round ``x`` onto ``_QUANT_LEVELS`` equal steps spanning its own range."""
        hi = torch.max(x)
        lo = torch.min(x)
        span = hi - lo
        if span == 0:
            # A constant input has zero range; rescaling would compute 0/0 and
            # fill the item with NaN. It is already on the grid, so hand back
            # an independent copy instead.
            return x.clone()
        levels = cls._QUANT_LEVELS
        scaled = torch.round((x - lo) * levels / span)
        return scaled * span / levels + lo

    def __getitem__(self, idx):
        """Return the quantized item(s) at ``idx``.

        Returns:
            ``(x, y)`` when ``modality_first`` is ``False``. Otherwise, for a
            2-D ``x`` a list holding each row of ``x`` transposed followed by
            ``y``; for any other rank (e.g. a slice index yielding a 3-D
            ``x``) a list with one such per-sample list for each ``(x, y)``
            pair.
        """
        x = self._quantize(self.X[idx])
        y = self.Y[idx]

        if not self.modality_first:
            return x, y

        if len(x.shape) == 2:
            # presumably x is (time, feature) here, making each row of the
            # transpose one feature's series — TODO confirm against the cache.
            return [*x.permute(1, 0), y]

        # Batched index: swap the last two axes of every sample, then build
        # one "rows + label" list per sample.
        swapped = x.permute(0, 2, 1)
        return [[*sample, label] for sample, label in zip(swapped, y)]


def create_dataloaders(
    dataset_name: str = "STOCKS-FandB",
    root=None,
    batch_size: int = 16,
    num_workers: int = 0,
    modality_first: bool = True,
    download: bool = True,
) -> Tuple[DataLoader, DataLoader, Dict[str, list]]:
    """Build train, validation and test dataloaders for a stocks dataset.

    The cached tensors are located via
    ``stocks_collection.download_and_prepare`` and split by position:
    ``[:val_split]`` for training, ``[val_split:test_split]`` for validation
    and ``[test_split:]`` for testing.

    Args:
        dataset_name: Forwarded to ``download_and_prepare``.
        root: Forwarded to ``download_and_prepare``.
        batch_size: Batch size used by every loader.
        num_workers: Worker count used by every loader.
        modality_first: Forwarded to ``StocksDataset``.
        download: Forwarded to ``download_and_prepare``.

    Returns:
        ``(train_loader, val_loader, test_loaders)``. The train loader is
        shuffled and drops the last incomplete batch; the others do neither.
        ``test_loaders["timeseries"]`` holds nine loaders built with noise
        levels 0.0 to 0.8 when the optional robustness module can be
        imported, and a single noise-free loader otherwise.
    """
    meta = stocks_collection.download_and_prepare(dataset_name=dataset_name, root=root, download=download)
    # NOTE(review): weights_only=False unpickles arbitrary Python objects.
    # Acceptable only while the cache file is one this project produced.
    cache = torch.load(meta["splits"]["cache"], weights_only=False)

    X = cache["X"]
    Y = cache["Y"]
    val_split = cache["val_split"]
    test_split = cache["test_split"]

    def _eval_loader(inputs, targets):
        # Deterministic loader shared by the validation and test splits.
        return DataLoader(
            StocksDataset(inputs, targets, modality_first),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
        )

    train_ds = StocksDataset(X[:val_split], Y[:val_split], modality_first)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    val_loader = _eval_loader(X[val_split:test_split], Y[val_split:test_split])

    X_test = X[test_split:]
    Y_test = Y[test_split:]

    # Only the optional import is guarded: a ModuleNotFoundError raised while
    # generating noise must surface instead of leaving a half-filled list
    # with a clean loader appended to it.
    try:
        from data_loading.robustness.timeseries_robust import add_timeseries_noise  # ignore type
    except ModuleNotFoundError:
        add_timeseries_noise = None

    test_loaders = {"timeseries": []}
    if add_timeseries_noise is None:
        test_loaders["timeseries"].append(_eval_loader(X_test, Y_test))
    else:
        for noise_level in range(9):
            # Fresh copy per level so the noise routine never sees (or
            # modifies) data from a previous level.
            X_robust = X_test.clone().cpu().numpy()
            X_robust = torch.tensor(add_timeseries_noise(X_robust, noise_level=noise_level / 10), dtype=torch.float32)
            test_loaders["timeseries"].append(_eval_loader(X_robust, Y_test))

    return train_loader, val_loader, test_loaders


# Names exported by ``from <this module> import *``.
__all__ = ["create_dataloaders", "StocksDataset"]
