import argparse
from dataclasses import dataclass, asdict, field
import hydra
from omegaconf import DictConfig
import random
import torch
import os
import os.path as osp
from data.sampler import multi_output_gp_prior_sampler, multi_task_gp_prior_sampler
from data.function_preprocessing import make_range_tensor
from utils.paths import get_split_dataset_path, get_log_filepath
from utils.data import set_all_seeds
from utils.dataclasses import SamplerConfig
from utils.data import group_ele_count, save_dataset
from utils.config import SPLITS
from typing import Optional, List
from tqdm import tqdm
from utils.log import get_logger, log_fn


# Name -> sampler function registry; generate_data dispatches on these keys.
SAMPLER_DICT = {
    "multi_task_gp_prior_sampler": multi_task_gp_prior_sampler,
    "multi_output_gp_prior_sampler": multi_output_gp_prior_sampler,
}
SAMPLER_LIST = [*SAMPLER_DICT]
# One (uniform) weight per registered sampler.
SAMPLER_WEIGHTS = [1] * len(SAMPLER_LIST)  # [1, 1]
KERNEL_LIST = ["rbf", "matern32", "matern52"]
# One (uniform) weight per kernel name.
KERNEL_WEIGHTS = [1] * len(KERNEL_LIST)  # [1, 1, 1]
# NOTE(review): the constants below are not referenced in this file; they
# presumably mirror defaults of SamplerConfig fields — confirm before relying on them.
LENGTH_SCALE = [0.1, 2.0]
STD_RANGE = [0.1, 1.0]
MIN_RANK = 1
MAX_RANK = None
P_ISO = 0.5


@dataclass
class StaticDataConfig:
    """Static settings for one GP dataset-generation run.

    Attributes:
        x_dim: Input dimensionality; used to build the x-range tensor and
            forwarded to the sampler as ``x_dim``.
        y_dim: Output dimensionality; forwarded to the sampler as ``num_tasks``.
        split: Dataset split name; must be a member of ``SPLITS``.
        grid: Forwarded to the sampler's ``grid`` argument.
        num_datapoints: Number of points per generated dataset.
        num_datasets: Number of datasets to generate in one run.
        filename: HDF5 file stem; defaults to ``gp_x{x_dim}_y{y_dim}``.

    Raises:
        ValueError: If ``split`` is not in ``SPLITS``.
    """

    x_dim: int
    y_dim: int
    split: str = "train"
    grid: bool = False
    num_datapoints: int = 300
    num_datasets: int = 100000
    filename: Optional[str] = None

    def __post_init__(self):
        # Explicit raise instead of ``assert`` so validation survives ``python -O``.
        if self.split not in SPLITS:
            raise ValueError(f"split `{self.split}` is not supported.")
        if self.filename is None:
            self.filename = f"gp_x{self.x_dim}_y{self.y_dim}"

    def to_dict(self):
        """Return the config as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)


@hydra.main(version_base=None, config_path="configs", config_name="generate_data.yaml")
def main(config: DictConfig):
    """Hydra entry point: build configs, set up logging, then generate datasets.

    Reads ``config.experiment`` (seed, expid, device, resume, subfolder),
    ``config.sampler`` (-> ``SamplerConfig``) and ``config.generate``
    (-> ``StaticDataConfig``), resolves the output directory
    ``<split path>[/<subfolder>]/x_dim_<x>/y_dim_<y>`` and delegates to
    :func:`generate_data`.
    """
    seed = config.experiment.seed
    expid = config.experiment.expid
    device = config.experiment.device
    resume = config.experiment.resume

    torch.set_default_dtype(torch.float32)
    torch.set_default_device(device)

    sampler_config = SamplerConfig(**config.sampler)
    static_data_config = StaticDataConfig(**config.generate)

    # ==== Setup logging ====
    group_name = "gp_data_generation"
    log_filename = get_log_filepath(group_name=group_name, expid=expid)
    logger = get_logger(file_name=log_filename, mode="w")
    log = log_fn(logger)

    log(f"log_filename:\t{log_filename}")

    # ==== Setup path to save datasets ====
    path = get_split_dataset_path(split=static_data_config.split)
    subfolder = config.experiment.subfolder
    if subfolder is not None:
        path = osp.join(path, subfolder)
    path = osp.join(
        path, f"x_dim_{static_data_config.x_dim}", f"y_dim_{static_data_config.y_dim}"
    )

    generate_data(
        path=path,
        seed=seed,
        device=device,
        resume=resume,
        sampler_config=sampler_config,
        static_data_config=static_data_config,
        log=log,
    )


def generate_data(
    path: str,
    seed: int,
    device: str,
    resume: bool,
    sampler_config: SamplerConfig,
    static_data_config: StaticDataConfig,
    log: callable = print,
):
    """Sample GP datasets and append them to ``<path>/<filename>.hdf5``.

    For each dataset, a sampler name is drawn with ``random.choices`` from
    ``sampler_config.sampler_list`` (weighted by ``sampler_weights``), the
    matching function in ``SAMPLER_DICT`` is called under ``torch.no_grad()``,
    and the resulting ``(x, y)`` pair is written as group ``dataset_{i}``.

    Args:
        path: Directory holding the HDF5 file; created if it does not exist.
        seed: Base random seed. When resuming, it is offset by the starting
            index so the resumed run does not replay the first run's stream.
        device: Forwarded to the sampler functions.
        resume: If True and the file already exists, continue numbering from
            the count returned by ``group_ele_count``.
        sampler_config: Sampler hyper-parameters.
        static_data_config: Dataset size / count / filename settings.
        log: Callable used for progress/config logging.

    Raises:
        ValueError: If a name in ``sampler_config.sampler_list`` is not a key
            of ``SAMPLER_DICT``.
        FileExistsError: If the target file exists and ``resume`` is False.
    """
    log(f"seed:\t{seed}")
    log(f"StaticDataConfig:\n{static_data_config.to_dict()}")
    log(f"SamplerConfig:\n{sampler_config.to_dict()}")

    # Validate against the registry actually used for dispatch below;
    # checking globals() would accept any module-level name (e.g. "main").
    for sampler in sampler_config.sampler_list:
        if sampler not in SAMPLER_DICT:
            raise ValueError(f"sampler `{sampler}` is not supported.")

    datapath = osp.join(path, f"{static_data_config.filename}.hdf5")
    if osp.exists(datapath):
        if not resume:
            raise FileExistsError(f"File {datapath} already exists.")
        # Continue group numbering after the entries already in the file.
        epoch = group_ele_count(datapath)
    else:
        epoch = 0
        os.makedirs(path, exist_ok=True)

    # Offset the seed by the resume point: reseeding with the same value would
    # replay the first run's random stream and write duplicate datasets
    # (assuming set_all_seeds seeds every RNG the samplers use). For a fresh
    # run epoch == 0, so this is exactly ``seed``.
    set_all_seeds(seed + epoch)
    if epoch:
        log(f"effective seed:\t{seed + epoch}")

    log(f"datapath:\t{datapath}")
    log(f"epoch:\t{epoch}")

    x_range_t = make_range_tensor(
        sampler_config.x_range,
        num_dim=static_data_config.x_dim,
    )
    for i in tqdm(
        range(epoch, epoch + static_data_config.num_datasets),
        desc=f"Generating {static_data_config.num_datasets} datasets...",
        miniters=500,
    ):
        sampler = random.choices(
            population=sampler_config.sampler_list,
            weights=sampler_config.sampler_weights,
            k=1,
        )[0]
        sampler_func = SAMPLER_DICT[sampler]

        with torch.no_grad():
            x, y = sampler_func(
                x_range=x_range_t,
                x_dim=static_data_config.x_dim,
                num_datapoints=static_data_config.num_datapoints,
                num_tasks=static_data_config.y_dim,
                data_kernel_type_list=sampler_config.data_kernel_type_list,
                sample_kernel_weights=sampler_config.sample_kernel_weights,
                lengthscale_range=sampler_config.lengthscale_range,
                std_range=sampler_config.std_range,
                min_rank=sampler_config.min_rank,
                max_rank=sampler_config.max_rank,
                p_iso=sampler_config.p_iso,
                grid=static_data_config.grid,
                standardize=sampler_config.standardize,
                jitter=sampler_config.jitter,
                max_tries=sampler_config.max_tries,
                device=device,
            )

        save_dataset(
            datapath,
            grp_name=f"dataset_{i}",
            inputs=x,
            targets=y,
        )

        # Periodically release cached CUDA allocator memory.
        if i % 20 == 0:
            torch.cuda.empty_cache()


if __name__ == "__main__":
    # main() is wrapped by @hydra.main, which supplies its `config` argument.
    main()