#  Copyright (c) 2024, Salesforce, Inc.
#  SPDX-License-Identifier: Apache-2
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import itertools
import os
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Any, Generator
from typing import Any, Callable, Generator, Optional


import datasets
import numpy as np
import pandas as pd
from datasets import Features, Sequence, Value

from uni2ts.common.env import env
from uni2ts.data.dataset import TimeSeriesDataset

from ._base import LOTSADatasetBuilder

# Names of the ERA5 variables: three single-level entries followed by every
# (variable, level) combination, ordered variable-major, formatted as
# "<variable>_<level>".
ERA5_VARIABLES = [
    "2m_temperature",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    *(
        f"{var}_{level}"
        for var in (
            "geopotential",
            "relative_humidity",
            "specific_humidity",
            "temperature",
            "u_component_of_wind",
            "v_component_of_wind",
        )
        for level in (50, 250, 500, 600, 700, 850, 925)
    ),
]

class Synthetic10DatasetBuilder(LOTSADatasetBuilder):
    """LOTSA dataset builder listing the synthetic datasets ``data0`` to ``data9``.

    Both registries are ``defaultdict`` instances, so every dataset name that
    is looked up resolves to ``TimeSeriesDataset``.
    """

    dataset_list = [f"data{i}" for i in range(10)]
    # Missing keys fall back to TimeSeriesDataset (wrapped in an argument-less
    # partial for the loader registry).
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    uniform = True

    def build_dataset(
        self,
        file: Path,
        dataset_type: str,
        offset: Optional[int] = None,
        date_offset: Optional[pd.Timestamp] = None,
        freq: str = "D",
    ) -> None:
        """Convert a CSV file into a Hugging Face dataset and save it to disk.

        Args:
            file: Path of the CSV file. Its first column is used as the index
                and parsed as dates.
            dataset_type: Layout of the CSV; one of ``'long'``, ``'wide'`` or
                ``'wide_multivariate'``.
            offset: Forwarded to the dataframe converter. Mutually exclusive
                with ``date_offset``.
            date_offset: Forwarded to the dataframe converter. Mutually
                exclusive with ``offset``.
            freq: Frequency string forwarded to the dataframe converter.

        Raises:
            ValueError: If both ``offset`` and ``date_offset`` are given, or
                if ``dataset_type`` is not one of the recognised options.
        """
        # Explicit raise instead of ``assert`` so the check is not stripped
        # when Python runs with -O.
        if offset is not None and date_offset is not None:
            raise ValueError(
                "One or neither offset and date_offset must be specified, but not both. "
                f"Got offset: {offset}, date_offset: {date_offset}"
            )

        df = pd.read_csv(file, index_col=0, parse_dates=True)

        # NOTE(review): the `_from_*_dataframe` helpers are not defined in the
        # visible part of this module - confirm they are in scope here.
        if dataset_type == "long":
            _from_dataframe = _from_long_dataframe
        elif dataset_type == "wide":
            _from_dataframe = _from_wide_dataframe
        elif dataset_type == "wide_multivariate":
            _from_dataframe = _from_wide_dataframe_multivariate
        else:
            raise ValueError(
                f"Unrecognized dataset_type, {dataset_type}."
                " Valid options are 'long', 'wide', and 'wide_multivariate'."
            )

        example_gen_func, features = _from_dataframe(
            df, freq=freq, offset=offset, date_offset=date_offset
        )
        hf_dataset = datasets.Dataset.from_generator(
            example_gen_func, features=features
        )
        # NOTE(review): `self.dataset` and `self.storage_path` are presumably
        # provided by the base class or the instance - verify.
        hf_dataset.info.dataset_name = self.dataset
        hf_dataset.save_to_disk(self.storage_path / self.dataset)

class Synthetic20DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing ``data0`` through ``data19``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}" for i in range(20)]


class Synthetic30DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing ``data0`` through ``data29``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}" for i in range(30)]

class Synthetic40DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing ``data0`` through ``data39``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}" for i in range(40)]

class Synthetic50DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing ``data0`` through ``data49``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}" for i in range(50)]

class Synthetic60DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing ``data0`` through ``data59``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}" for i in range(60)]

class Synthetic70DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing ``data0`` through ``data69``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}" for i in range(70)]

class Synthetic80DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing ``data0`` through ``data79``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}" for i in range(80)]

class Synthetic90DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing ``data0`` through ``data89``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}" for i in range(90)]

class Synthetic100DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing ``data0`` through ``data99``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}" for i in range(100)]

# order 5 below

class Synthetic100_Lag5DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing ``data0_o5`` through ``data99_o5``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}_o5" for i in range(100)]

class Synthetic50_Lag5DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing ``data0_o5`` through ``data49_o5``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}_o5" for i in range(50)]

class SyntheticTest_Lag5DatasetBuilder(Synthetic10DatasetBuilder):
    """Synthetic builder variant listing the single dataset ``test_data0_o5``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = ["test_data0_o5"]

class StrongDatasetBuilder(Synthetic10DatasetBuilder):
    """Builder variant listing ``strong_data0`` through ``strong_data49``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"strong_data{i}" for i in range(50)]

# Seasonal

class Seasonal100_Lag5DatasetBuilder(Synthetic10DatasetBuilder):
    """Builder variant listing ``data0_o5_season`` through ``data99_o5_season``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}_o5_season" for i in range(100)]

class SeasonalTest_Lag5DatasetBuilder(Synthetic10DatasetBuilder):
    """Builder variant listing the single dataset ``test_data0_o5_season``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = ["test_data0_o5_season"]

class Seasonal50_Lag5DatasetBuilder(Synthetic10DatasetBuilder):
    """Builder variant listing ``data0_o5_season`` through ``data49_o5_season``."""

    uniform = True
    # Class-local registries: any looked-up dataset name falls back to
    # TimeSeriesDataset (an argument-less partial of it for the loader).
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_list = [f"data{i}_o5_season" for i in range(50)]

class SeasonalityMixedDatasetBuilder(Synthetic10DatasetBuilder):
    """Builder variant listing the 100 names ``data{k}_o{q}_season_{d}``.

    ``q`` ranges over 4-5 (outermost), ``d`` over 1-5 and ``k`` over 0-9
    (innermost), which fixes the order of ``dataset_list``.
    """

    # Built with a comprehension so the loop variables do not leak into the
    # class namespace as stray class attributes.
    dataset_list = [
        f"data{k}_o{q}_season_{d}"
        for q in [4, 5]
        for d in [1, 2, 3, 4, 5]
        for k in range(10)
    ]

    # Any looked-up dataset name falls back to TimeSeriesDataset (an
    # argument-less partial of it for the loader registry).
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    uniform = True

# mixed lag
class SyntheticMixedDatasetBuilder(Synthetic10DatasetBuilder):
    """Builder variant listing the 120 names ``data{j}_o{i}_{k}``.

    ``i`` ranges over 2-4 (outermost), ``j`` over 0-9 and ``k`` over 1-4
    (innermost), which fixes the order of ``dataset_list``.
    """

    # Built with a comprehension so the loop variables do not leak into the
    # class namespace as stray class attributes.
    dataset_list = [
        f"data{j}_o{i}_{k}"
        for i in range(2, 5)
        for j in range(10)
        for k in range(1, 5)
    ]

    # Any looked-up dataset name falls back to TimeSeriesDataset (an
    # argument-less partial of it for the loader registry).
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    uniform = True

# mixed lag
class SyntheticMixed2DatasetBuilder(Synthetic10DatasetBuilder):
    """Builder variant listing the 40 names ``data{k}_o{q}_{d}``.

    ``q`` ranges over 4-5 (outermost), ``d`` over 4-5 and ``k`` over 0-9
    (innermost), which fixes the order of ``dataset_list``.
    """

    # Built with a comprehension so the loop variables do not leak into the
    # class namespace as stray class attributes.
    dataset_list = [
        f"data{k}_o{q}_{d}"
        for q in [4, 5]
        for d in [4, 5]
        for k in range(10)
    ]

    # Any looked-up dataset name falls back to TimeSeriesDataset (an
    # argument-less partial of it for the loader registry).
    dataset_type_map = defaultdict(lambda: TimeSeriesDataset)
    dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset))
    uniform = True