import logging
import os
from functools import partial
from multiprocessing import cpu_count
from pathlib import Path
from typing import Iterator, Optional, cast

import numpy as np
import torch
import yaml
from datasets import (
    Dataset,
    concatenate_datasets,
    load_from_disk,
)
from datasets import (
    load_dataset as hf_load_dataset,
)
from datasets.features.features import Sequence
from torch.utils.data import IterableDataset

# Configure the root logger's record format: timestamp, logger name, level, message.
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# Module-level logger; note that it is named after this file's path (``__file__``).
logger = logging.getLogger(__file__)
# Emit records of level INFO and above from this logger.
logger.setLevel(logging.INFO)


def has_enough_observations(
    entry: dict, min_length: int = 0, max_missing_prop: float = 1.0
) -> bool:
    """Tell whether a series is long enough and not too sparse.

    Args:
        entry: Mapping whose ``"target"`` value is the series to inspect.
        min_length: Minimum number of values the series must contain.
        max_missing_prop: Largest tolerated fraction of NaN values.

    Returns:
        ``True`` when the series has at least ``min_length`` values and its
        NaN fraction does not exceed ``max_missing_prop``; ``False`` otherwise.
    """
    series = entry["target"]

    # Length is checked first so the NaN scan is skipped for short series.
    if len(series) < min_length:
        return False

    missing_fraction = np.isnan(series).mean()
    return bool(missing_fraction <= max_missing_prop)


def preprocess_and_save_dataset(
    config_path: Path,
):
    """Preprocess the datasets named in a YAML config and save them to disk.

    The config is read for the keys ``dataset_names_or_paths``,
    ``prediction_length``, ``context_length``, ``min_past`` and ``output_dir``.
    When ``num_samples`` is present and truthy, the non-empty preprocessed
    datasets are resampled to that total via
    ``concatenate_datasets_stratified`` before saving.

    Args:
        config_path: Path of the YAML configuration file.
    """
    with open(config_path) as fp:
        config = yaml.safe_load(fp)

    preprocess_keys = (
        "dataset_names_or_paths",
        "prediction_length",
        "context_length",
        "min_past",
    )
    datasets = preprocess_datasets(*[config[key] for key in preprocess_keys])

    if config.get("num_samples"):
        # Empty datasets cannot be sampled from, so leave them out.
        non_empty = [d for d in datasets if len(d) > 0]
        datasets = [concatenate_datasets_stratified(non_empty, config["num_samples"])]

    merged = concatenate_datasets(datasets)
    # Use all cores but one for writing, and never fewer than one.
    merged.save_to_disk(config["output_dir"], num_proc=max(cpu_count() - 1, 1))


def get_training_dataset(
    dataset_names_or_paths: list[str],
    probabilities: Optional[list[float]],
    prediction_length: int,
    context_length: int,
    min_past: int,
    preprocess: bool = True,
):
    """Build a ``ChromaDataset`` over the given datasets.

    Args:
        dataset_names_or_paths: Locations of the datasets to load.
        probabilities: Sampling weight of each dataset. When ``None``, every
            dataset gets the same weight.
        prediction_length: Number of values in each sampled target window.
        context_length: Number of values in each sampled context window.
        min_past: Minimum number of values preceding a sampled split point.
        preprocess: When true, run ``preprocess_datasets``; otherwise load
            each dataset as-is with ``load_dataset``.

    Returns:
        The ``ChromaDataset`` wrapping the loaded datasets.
    """
    if not preprocess:
        train_datasets = [load_dataset(name) for name in dataset_names_or_paths]
    else:
        train_datasets = preprocess_datasets(
            dataset_names_or_paths, prediction_length, context_length, min_past
        )

    if probabilities is None:
        # No weights supplied: spread the probability mass uniformly.
        probabilities = [1.0 / len(dataset_names_or_paths)] * len(
            dataset_names_or_paths
        )
    else:
        assert len(probabilities) == len(train_datasets)

    return ChromaDataset(
        subsets=train_datasets,
        probabilities=probabilities,
        context_length=context_length,
        prediction_length=prediction_length,
        min_past=min_past,
    )


def _sample_rows(dataset: Dataset, count: int) -> Dataset:
    """Return ``count`` randomly chosen rows of ``dataset``."""
    if count > len(dataset):
        if len(dataset) == 0:
            raise ValueError("Cannot sample rows from an empty dataset.")
        # More rows requested than available: sample with replacement.
        indices = np.random.randint(0, len(dataset), size=count)
        return dataset.select(indices)
    # Enough rows available: sample without replacement.
    return dataset.shuffle().select(range(count))


def concatenate_datasets_stratified(
    datasets: list[Dataset], num_samples: int
) -> Dataset:
    """Draw ``num_samples`` rows spread evenly across ``datasets``.

    Every dataset contributes ``num_samples // len(datasets)`` rows. Rows are
    drawn without replacement when the dataset is large enough and with
    replacement otherwise. Any remainder left by the integer division is drawn
    from the longest dataset using the same rule.

    Args:
        datasets: Datasets to sample from.
        num_samples: Total number of rows in the result.

    Returns:
        The concatenation of all sampled rows.

    Raises:
        ValueError: If ``datasets`` is empty, or if rows are requested from a
            dataset that has none.
    """
    if not datasets:
        raise ValueError("At least one dataset is required for stratified sampling.")

    samples_per_dataset, remaining_samples = divmod(num_samples, len(datasets))

    output_datasets = [
        _sample_rows(dataset, samples_per_dataset) for dataset in datasets
    ]

    if remaining_samples:
        # The leftover rows come from the longest dataset; size the draw by
        # that dataset's own length rather than by the last one iterated over.
        maximum_length_dataset = max(datasets, key=len)
        output_datasets.append(
            _sample_rows(maximum_length_dataset, remaining_samples)
        )

    return concatenate_datasets(output_datasets)


def preprocess_datasets(
    dataset_names_or_paths, prediction_length, context_length, min_past
):
    """Load each dataset and turn it into a single float64 ``target`` column.

    For every dataset, the first numeric sequence column is kept and renamed
    to ``"target"``. Series shorter than ``min_past + prediction_length`` or
    with more than 90% NaN values are dropped. Series longer than
    ``(context_length + prediction_length) * 5`` are split with
    ``break_up_long_series``, and the values are cast to ``float64``.

    Returns:
        A list with one preprocessed dataset per entry of
        ``dataset_names_or_paths``.

    Raises:
        IndexError: Re-raised after a warning is logged; the warning text
            attributes it to the dataset having no sequence column.
    """
    keep_entry = partial(
        has_enough_observations,
        min_length=min_past + prediction_length,
        max_missing_prop=0.9,
    )

    def cast_to_float64(batch):
        # Rebuild "target" from the temporarily renamed "prev_target" column.
        return {
            "target": [np.array(x).astype(np.float64) for x in batch["prev_target"]]
        }

    train_datasets = []

    for name in dataset_names_or_paths:
        dataset = load_dataset(name)
        sequence_columns = get_sequence_features(dataset)
        try:
            first_column_only = dataset.select_columns(sequence_columns[:1])
            as_target = first_column_only.rename_columns(
                {sequence_columns[0]: "target"}
            )
            filtered_dataset = as_target.filter(
                keep_entry, num_proc=max(cpu_count() - 1, 1)
            )

            windowed = break_up_long_series(
                filtered_dataset,
                max_length=(context_length + prediction_length) * 5,
            )
            staged = windowed.rename_columns({"target": "prev_target"})
            converted = staged.map(
                cast_to_float64,
                batched=True,
                num_proc=max(cpu_count() - 1, 1),
                batch_size=1000,
                remove_columns="prev_target",
            )
            train_datasets.append(converted)
        except IndexError:
            logger.warning(f"Dataset {name} does not have any sequence columns. ")
            raise

    return train_datasets


def load_dataset(name_or_path) -> Dataset:
    """Load a dataset previously saved to local disk.

    Args:
        name_or_path: Filesystem location of the saved dataset.

    Returns:
        The dataset read with ``load_from_disk``.

    Raises:
        FileNotFoundError: If ``name_or_path`` does not exist locally.
    """
    logger.info(f"Loading dataset: {name_or_path}")

    local_path = Path(name_or_path)
    if local_path.exists():
        dataset = cast(Dataset, load_from_disk(local_path))
        logger.info(f"Loaded dataset from local path: {local_path}")
        logger.info(repr(dataset))
        return dataset

    raise FileNotFoundError(f"Dataset {name_or_path} does not exist.")


class ChromaDataset(IterableDataset):
    """Endless stream of random (context, target) windows; see ChronosDataset.

    Each subset is indexed by integer position, and every item is expected to
    expose its series under the ``"target"`` key.
    """

    def __init__(
        self,
        subsets: list[Dataset],
        probabilities: list[float],
        context_length: int = 512,
        prediction_length: int = 64,
        min_past: Optional[int] = None,
    ) -> None:
        """Store the sampling configuration.

        Args:
            subsets: Datasets to draw series from.
            probabilities: Probability of picking each subset; must have the
                same length as ``subsets``.
            context_length: Number of values in every yielded context.
            prediction_length: Number of values in every yielded target.
            min_past: Smallest allowed split position inside a series. A
                falsy value (``None`` or ``0``) falls back to
                ``prediction_length``.
        """
        super().__init__()

        assert len(probabilities) == len(subsets)

        self.context_length = context_length
        self.prediction_length = prediction_length
        self.min_past = min_past if min_past else prediction_length
        self.subsets = subsets
        self.probabilities = probabilities

    def __iter__(self) -> Iterator:
        """Yield ``{"context", "target"}`` dicts of float32 tensors forever.

        All randomness comes from the global ``np.random`` state.
        """
        while True:
            # Pick a subset according to the configured probabilities.
            subset = self.subsets[
                np.random.choice(len(self.subsets), p=self.probabilities)
            ]
            if not len(subset):
                continue

            # Pick one series, then the position separating past from future.
            series = subset[np.random.randint(len(subset))]["target"]
            split = np.random.randint(
                self.min_past, len(series) - self.prediction_length + 1
            )

            window_start = split - self.context_length
            if window_start >= 0:
                context = series[window_start:split]
            else:
                # Not enough history: left-pad with NaN up to context_length.
                context = np.concatenate(
                    [np.full(-window_start, fill_value=np.nan), series[:split]]
                )

            # A context without a single observed value is skipped.
            if np.isnan(context).all():
                continue

            target = series[split : split + self.prediction_length]

            yield {
                "context": torch.tensor(context, dtype=torch.float32),
                "target": torch.tensor(target, dtype=torch.float32),
            }


def get_sequence_features(dataset: Dataset):
    """List the numeric sequence columns of ``dataset``.

    A column qualifies when its feature is a ``Sequence`` whose element dtype
    starts with ``"float"`` or ``"int"``. The columns ``"id"`` and
    ``"timestamp"`` are always ignored.

    Returns:
        Qualifying column names, in the order of ``dataset.features``.
    """
    ignored_columns = ("id", "timestamp")
    numeric_prefixes = ("float", "int")

    return [
        column
        for column, feature in dataset.features.items()
        if column not in ignored_columns
        and isinstance(feature, Sequence)
        and feature.feature.dtype.startswith(numeric_prefixes)
    ]


def break_up_long_series(
    dataset: Dataset, max_length: int, target_name: str = "target"
) -> Dataset:
    """Split series longer than ``max_length`` into overlapping windows.

    Series of at most ``max_length`` values pass through unchanged. Longer
    series are replaced by windows of exactly ``max_length`` values whose
    start positions advance by ``max_length // 2`` (at least 1). Only full
    windows are produced, so values after the last full window are dropped.

    Args:
        dataset: Dataset whose ``target_name`` column holds the series.
        max_length: Length of each produced window; must be positive.
        target_name: Name of the column to split.

    Returns:
        The result of mapping the splitting function over ``dataset``.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be a positive integer, got {max_length}.")

    # Clamp to 1: with max_length == 1 the halved stride would be 0 and the
    # window start would never advance.
    stride_length = max(max_length // 2, 1)

    def break_up(batch):
        # The map call below uses batch_size=1, so only the first (and only)
        # series of the batch is read.
        target = batch[target_name][0]

        if len(target) <= max_length:
            return batch

        last_start = len(target) - max_length
        windows = [
            target[start : start + max_length]
            for start in range(0, last_start + 1, stride_length)
        ]
        return {target_name: windows}

    return dataset.map(
        break_up, batched=True, batch_size=1, num_proc=max(cpu_count() - 1, 1)
    )


def download_chronos_subset(subset_name: str, data_dir: str = "data") -> None:
    """Fetch one subset of ``autogluon/chronos_datasets`` and save it locally.

    The ``train`` split of the subset is loaded and written to
    ``<data_dir>/<subset_name>``; the directory is created when missing.

    Args:
        subset_name: Name of the subset (configuration) to load.
        data_dir: Directory under which the subset folder is placed.
    """
    destination = os.path.join(data_dir, subset_name)
    os.makedirs(destination, exist_ok=True)

    subset = hf_load_dataset(
        "autogluon/chronos_datasets",
        subset_name,
        split="train",
        keep_in_memory=False,
    )
    subset.save_to_disk(destination)
