import argparse
from dataclasses import (
    asdict,
    dataclass,
)
import functools
import random
from textwrap import dedent, indent
import json
from pathlib import Path

# from toolz import curry
from typing import Dict, List, Optional, Sequence, Tuple, Union

import toml
import voluptuous
from voluptuous import (
    Any,
    ExactSequence,
    MultipleInvalid,
    Object,
    Required,
    Schema,
)
from transformers import CLIPTokenizer

from . import train_util
from .train_util import (
    DreamBoothSubset,
    FineTuningSubset,
    ControlNetSubset,
    DreamBoothDataset,
    FineTuningDataset,
    ControlNetDataset,
    DatasetGroup,
)
from .utils import setup_logging

# Run the project's logging setup first, then create this module's logger.
setup_logging()
import logging

logger = logging.getLogger(__name__)


def add_config_arguments(parser: argparse.ArgumentParser):
    """Register the ``--dataset_config`` option on ``parser``.

    The option is parsed into a ``pathlib.Path`` and defaults to ``None``
    when it is not given on the command line.
    """
    option_settings = {
        "type": Path,
        "default": None,
        "help": "config file for detail settings / 詳細な設定用の設定ファイル",
    }
    parser.add_argument("--dataset_config", **option_settings)


# TODO: inherit Params class in Subset, Dataset


@dataclass
class BaseSubsetParams:
    """Default values for the options shared by every subset type.

    ``BlueprintGenerator.generate_params_by_fallbacks`` instantiates the
    params classes with no arguments and uses the resulting values as the
    last-resort fallback when no config level provides an option, so each
    default here must be a plain value of the annotated type.
    """

    image_dir: Optional[str] = None
    num_repeats: int = 1
    shuffle_caption: bool = False
    # Previously ``(",",)``: a stray trailing comma turned the default into a
    # 1-tuple instead of the separator string.
    caption_separator: str = ","
    keep_tokens: int = 0
    # Previously ``(None,)``: same stray-comma defect, which produced a truthy
    # 1-tuple where "no separator" (None) was intended.
    keep_tokens_separator: Optional[str] = None
    secondary_separator: Optional[str] = None
    enable_wildcard: bool = False
    color_aug: bool = False
    flip_aug: bool = False
    face_crop_aug_range: Optional[Tuple[float, float]] = None
    random_crop: bool = False
    caption_prefix: Optional[str] = None
    caption_suffix: Optional[str] = None
    caption_dropout_rate: float = 0.0
    caption_dropout_every_n_epochs: int = 0
    caption_tag_dropout_rate: float = 0.0
    token_warmup_min: int = 1
    token_warmup_step: float = 0
    custom_attributes: Optional[Dict[str, Any]] = None


@dataclass
class DreamBoothSubsetParams(BaseSubsetParams):
    """Subset defaults used for DreamBooth-style subsets.

    Adds the DreamBooth-specific options on top of ``BaseSubsetParams``.
    ``BlueprintGenerator.generate`` picks this class when no subset of a
    dataset has a ``metadata_file`` key (and the dataset is not ControlNet).
    """

    is_reg: bool = False
    class_tokens: Optional[str] = None
    caption_extension: str = ".caption"
    cache_info: bool = False
    alpha_mask: bool = False


@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
    """Subset defaults used for fine-tuning-style subsets.

    ``BlueprintGenerator.generate`` picks this class when a dataset is neither
    ControlNet-style nor DreamBooth-style, i.e. some subset carries a
    ``metadata_file`` key.
    """

    metadata_file: Optional[str] = None
    alpha_mask: bool = False


@dataclass
class ControlNetSubsetParams(BaseSubsetParams):
    """Subset defaults used for ControlNet-style subsets.

    ``BlueprintGenerator.generate`` picks this class when every subset of a
    dataset has a ``conditioning_data_dir`` key.
    """

    # NOTE(review): annotated as ``str`` but defaults to None; the user-config
    # schema marks this key as Required for ControlNet subsets.
    conditioning_data_dir: str = None
    caption_extension: str = ".caption"
    cache_info: bool = False


@dataclass
class BaseDatasetParams:
    """Default values for the options shared by every dataset type.

    Instantiated without arguments by
    ``BlueprintGenerator.generate_params_by_fallbacks`` to obtain the
    last-resort fallback values.
    """

    # The config schema accepts a scalar or a two-item sequence for
    # ``resolution`` and converts both to a 2-tuple.
    resolution: Optional[Tuple[int, int]] = None
    network_multiplier: float = 1.0
    debug_dataset: bool = False


@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
    """Dataset defaults for DreamBooth-style datasets.

    Same bucket/batch options as the other dataset params classes, plus
    ``prior_loss_weight``.
    """

    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False
    prior_loss_weight: float = 1.0


@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
    """Dataset defaults (batch size and bucketing options) for fine-tuning-style datasets."""

    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False


@dataclass
class ControlNetDatasetParams(BaseDatasetParams):
    """Dataset defaults (batch size and bucketing options) for ControlNet-style datasets."""

    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False


@dataclass
class SubsetBlueprint:
    """Resolved parameters for one subset, ready to be expanded into a subset object."""

    # NOTE(review): BlueprintGenerator.generate also stores ControlNetSubsetParams
    # here, although the Union below does not list it.
    params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]


@dataclass
class DatasetBlueprint:
    """Resolved parameters for one dataset together with its subset blueprints.

    ``generate_dataset_group_by_blueprint`` checks ``is_controlnet`` first and
    ``is_dreambooth`` second to choose the dataset/subset classes; when both
    are False the fine-tuning classes are used.
    """

    is_dreambooth: bool
    is_controlnet: bool
    # NOTE(review): BlueprintGenerator.generate also stores ControlNetDatasetParams
    # here, although the Union below does not list it.
    params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
    subsets: Sequence[SubsetBlueprint]


@dataclass
class DatasetGroupBlueprint:
    """Ordered collection of dataset blueprints that together form one dataset group."""

    datasets: Sequence[DatasetBlueprint]


@dataclass
class Blueprint:
    """Top-level result of ``BlueprintGenerator.generate``; wraps the dataset group blueprint."""

    dataset_group: DatasetGroupBlueprint


class ConfigSanitizer:
    """Build voluptuous schemas for the dataset config and validate input with them.

    Schema fragments are kept as class-level dicts and merged per mode in
    ``__init__``:

    * ``*_ASCENDABLE_SCHEMA`` options are accepted on a subset and are also
      merged into the dataset-level and ``general`` schemas.
    * ``*_DISTINCT_SCHEMA`` options are only merged into the subset schemas.

    Two validators are produced: ``user_config_validator`` for the user config
    dict (``general`` / ``datasets``) and ``argparse_config_validator`` for the
    parsed command-line namespace.
    """

    # @curry
    @staticmethod
    def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
        """Validate that ``value`` is a sequence of exactly two ``klass`` items and return it as a tuple."""
        Schema(ExactSequence([klass, klass]))(value)
        return tuple(value)

    # @curry
    @staticmethod
    def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
        """Accept a ``klass`` scalar or a two-item ``klass`` sequence and return a 2-tuple.

        A scalar ``v`` becomes ``(v, v)``; a two-item sequence is converted to a
        tuple. Anything else makes the first schema call raise.
        """
        Schema(Any(klass, ExactSequence([klass, klass])))(value)
        try:
            Schema(klass)(value)
        except voluptuous.Invalid:
            # Not a scalar, so (given the check above) it must be the two-item form.
            # Only validation failures are caught here; the previous bare ``except``
            # also swallowed KeyboardInterrupt/SystemExit.
            return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
        return (value, value)

    # subset schema
    SUBSET_ASCENDABLE_SCHEMA = {
        "color_aug": bool,
        "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
        "flip_aug": bool,
        "num_repeats": int,
        "random_crop": bool,
        "shuffle_caption": bool,
        "keep_tokens": int,
        "keep_tokens_separator": str,
        "secondary_separator": str,
        "caption_separator": str,
        "enable_wildcard": bool,
        "token_warmup_min": int,
        "token_warmup_step": Any(float, int),
        "caption_prefix": str,
        "caption_suffix": str,
        "custom_attributes": dict,
    }
    # DO means DropOut
    DO_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_dropout_every_n_epochs": int,
        "caption_dropout_rate": Any(float, int),
        "caption_tag_dropout_rate": Any(float, int),
    }
    # DB means DreamBooth
    DB_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "class_tokens": str,
        "cache_info": bool,
    }
    DB_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        "is_reg": bool,
        "alpha_mask": bool,
    }
    # FT means FineTuning
    FT_SUBSET_DISTINCT_SCHEMA = {
        Required("metadata_file"): str,
        "image_dir": str,
        "alpha_mask": bool,
    }
    # CN means ControlNet
    CN_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "cache_info": bool,
    }
    CN_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        Required("conditioning_data_dir"): str,
    }

    # datasets schema
    DATASET_ASCENDABLE_SCHEMA = {
        "batch_size": int,
        "bucket_no_upscale": bool,
        "bucket_reso_steps": int,
        "enable_bucket": bool,
        "max_bucket_reso": int,
        "min_bucket_reso": int,
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
        "network_multiplier": float,
    }

    # options handled by argparse but not handled by user config
    ARGPARSE_SPECIFIC_SCHEMA = {
        "debug_dataset": bool,
        "max_token_length": Any(None, int),
        "prior_loss_weight": Any(float, int),
    }
    # for handling default None value of argparse
    ARGPARSE_NULLABLE_OPTNAMES = [
        "face_crop_aug_range",
        "resolution",
    ]
    # prepare map because option name may differ among argparse and user config
    ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
        "train_batch_size": "batch_size",
        "dataset_repeats": "num_repeats",
    }

    def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
        """Assemble the subset, dataset, general and argparse schemas for the enabled modes.

        Args:
            support_dreambooth: accept DreamBooth-style datasets.
            support_finetuning: accept fine-tuning-style datasets.
            support_controlnet: accept ControlNet-style datasets.
            support_dropout: include the caption dropout options in the schemas.

        Raises:
            AssertionError: if none of the three modes is enabled.
        """
        assert support_dreambooth or support_finetuning or support_controlnet, (
            "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
            + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
        )

        self.db_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_DISTINCT_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.ft_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.FT_SUBSET_DISTINCT_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.cn_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.CN_SUBSET_DISTINCT_SCHEMA,
            self.CN_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.db_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.db_subset_schema]},
        )

        self.ft_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.ft_subset_schema]},
        )

        self.cn_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.CN_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.cn_subset_schema]},
        )

        if support_dreambooth and support_finetuning:

            def validate_flex_dataset(dataset_config: dict):
                """Pick the dataset schema from the keys present in the subsets, then validate."""
                subsets_config = dataset_config.get("subsets", [])

                if support_controlnet and all("conditioning_data_dir" in subset for subset in subsets_config):
                    return Schema(self.cn_dataset_schema)(dataset_config)
                # check dataset meets FT style
                # NOTE: all FT subsets should have "metadata_file"
                elif all("metadata_file" in subset for subset in subsets_config):
                    return Schema(self.ft_dataset_schema)(dataset_config)
                # check dataset meets DB style
                # NOTE: all DB subsets should have no "metadata_file"
                elif all("metadata_file" not in subset for subset in subsets_config):
                    return Schema(self.db_dataset_schema)(dataset_config)
                else:
                    raise voluptuous.Invalid(
                        "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
                    )

            self.dataset_schema = validate_flex_dataset
        elif support_dreambooth:
            if support_controlnet:
                self.dataset_schema = self.cn_dataset_schema
            else:
                self.dataset_schema = self.db_dataset_schema
        elif support_finetuning:
            self.dataset_schema = self.ft_dataset_schema
        elif support_controlnet:
            self.dataset_schema = self.cn_dataset_schema

        self.general_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
            self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.user_config_validator = Schema(
            {
                "general": self.general_schema,
                "datasets": [self.dataset_schema],
            }
        )

        # The argparse schema is the general schema plus the argparse-only options;
        # nullable options additionally accept None, and argparse-side option names
        # reuse the validator of their config-side counterpart.
        self.argparse_schema = self.__merge_dict(
            self.general_schema,
            self.ARGPARSE_SPECIFIC_SCHEMA,
            {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
            {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
        )

        self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)

    def sanitize_user_config(self, user_config: dict) -> dict:
        """Validate ``user_config`` and return the validator's output.

        Raises:
            MultipleInvalid: if the config does not match the schema (an error
                is logged before re-raising).
        """
        try:
            return self.user_config_validator(user_config)
        except MultipleInvalid:
            # TODO: make the message shown on validation errors easier to understand
            logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
            raise

    # NOTE: In nature, argument parser result is not needed to be sanitize
    #   However this will help us to detect program bug
    def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
        """Validate the parsed command-line namespace and return the validator's output.

        Raises:
            MultipleInvalid: if the namespace does not match the schema (an
                error is logged before re-raising).
        """
        try:
            return self.argparse_config_validator(argparse_namespace)
        except MultipleInvalid:
            # XXX: this should be a bug
            logger.error(
                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
            )
            raise

    # NOTE: value would be overwritten by latter dict if there is already the same key
    @staticmethod
    def __merge_dict(*dict_list: dict) -> dict:
        """Return a new dict holding the entries of all ``dict_list`` items; later dicts win."""
        merged = {}
        for schema in dict_list:
            # merged |= schema
            for k, v in schema.items():
                merged[k] = v
        return merged


class BlueprintGenerator:
    """Turn a user config plus an argparse namespace into a ``Blueprint``.

    Each parameter is resolved by looking it up in a priority-ordered list of
    dict-like sources and falling back to the params dataclass default.
    """

    # Maps a params field name to the config option name when the two differ.
    BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}

    def __init__(self, sanitizer: ConfigSanitizer):
        self.sanitizer = sanitizer

    # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
    def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
        """Validate both inputs and build the blueprint tree.

        Lookup priority for subset params is subset config, dataset config,
        general config, argparse values, then ``runtime_params``; dataset
        params use the same chain without the subset level.
        """
        clean_config = self.sanitizer.sanitize_user_config(user_config)
        clean_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)

        # convert argparse namespace to dict like config
        # NOTE: it is ok to have extra entries in dict
        renames = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
        argparse_config = {}
        for optname, value in vars(clean_namespace).items():
            argparse_config[renames.get(optname, optname)] = value

        general_config = clean_config.get("general", {})

        dataset_blueprints = []
        for dataset_config in clean_config.get("datasets", []):
            subsets = dataset_config.get("subsets", [])

            # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
            is_dreambooth = not any("metadata_file" in subset for subset in subsets)
            is_controlnet = all("conditioning_data_dir" in subset for subset in subsets)

            if is_controlnet:
                subset_params_klass, dataset_params_klass = ControlNetSubsetParams, ControlNetDatasetParams
            elif is_dreambooth:
                subset_params_klass, dataset_params_klass = DreamBoothSubsetParams, DreamBoothDatasetParams
            else:
                subset_params_klass, dataset_params_klass = FineTuningSubsetParams, FineTuningDatasetParams

            subset_blueprints = [
                SubsetBlueprint(
                    self.generate_params_by_fallbacks(
                        subset_params_klass,
                        [subset_config, dataset_config, general_config, argparse_config, runtime_params],
                    )
                )
                for subset_config in subsets
            ]

            dataset_params = self.generate_params_by_fallbacks(
                dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
            )
            dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, dataset_params, subset_blueprints))

        return Blueprint(DatasetGroupBlueprint(dataset_blueprints))

    @staticmethod
    def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
        """Instantiate ``param_klass`` with every field resolved through ``fallbacks``.

        The defaults of a no-argument ``param_klass()`` are used for fields that
        no fallback provides.
        """
        defaults = asdict(param_klass())
        resolved = {}
        for field_name, field_default in defaults.items():
            config_key = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME.get(field_name, field_name)
            resolved[field_name] = BlueprintGenerator.search_value(config_key, fallbacks, field_default)

        return param_klass(**resolved)

    @staticmethod
    def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
        """Return the first non-None value stored under ``key`` in ``fallbacks``, else ``default_value``."""
        candidates = (source.get(key) for source in fallbacks)
        return next((found for found in candidates if found is not None), default_value)


def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
    """Instantiate the datasets described by the blueprint and wrap them in a ``DatasetGroup``.

    For every dataset blueprint the matching subset/dataset classes are chosen
    (ControlNet first, then DreamBooth, otherwise fine tuning) and constructed
    from the blueprint params. A summary of all datasets and subsets is logged,
    then each dataset builds its buckets and receives one shared random seed.
    """
    datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

    for blueprint in dataset_group_blueprint.datasets:
        if blueprint.is_controlnet:
            subset_klass, dataset_klass = ControlNetSubset, ControlNetDataset
        elif blueprint.is_dreambooth:
            subset_klass, dataset_klass = DreamBoothSubset, DreamBoothDataset
        else:
            subset_klass, dataset_klass = FineTuningSubset, FineTuningDataset

        built_subsets = []
        for subset_blueprint in blueprint.subsets:
            built_subsets.append(subset_klass(**asdict(subset_blueprint.params)))
        datasets.append(dataset_klass(subsets=built_subsets, **asdict(blueprint.params)))

    # print info
    # The summary is collected piece by piece and joined once before logging.
    pieces = []
    for i, dataset in enumerate(datasets):
        is_dreambooth = isinstance(dataset, DreamBoothDataset)
        is_controlnet = isinstance(dataset, ControlNetDataset)

        dataset_header = dedent(
            f"""\
      [Dataset {i}]
        batch_size: {dataset.batch_size}
        resolution: {(dataset.width, dataset.height)}
        enable_bucket: {dataset.enable_bucket}
        network_multiplier: {dataset.network_multiplier}
    """
        )
        pieces.append(dataset_header)

        if not dataset.enable_bucket:
            pieces.append("\n")
        else:
            bucket_details = dedent(
                f"""\
        min_bucket_reso: {dataset.min_bucket_reso}
        max_bucket_reso: {dataset.max_bucket_reso}
        bucket_reso_steps: {dataset.bucket_reso_steps}
        bucket_no_upscale: {dataset.bucket_no_upscale}
      \n"""
            )
            pieces.append(indent(bucket_details, "  "))

        for j, subset in enumerate(dataset.subsets):
            subset_details = dedent(
                f"""\
        [Subset {j} of Dataset {i}]
          image_dir: "{subset.image_dir}"
          image_count: {subset.img_count}
          num_repeats: {subset.num_repeats}
          shuffle_caption: {subset.shuffle_caption}
          keep_tokens: {subset.keep_tokens}
          keep_tokens_separator: {subset.keep_tokens_separator}
          caption_separator: {subset.caption_separator}
          secondary_separator: {subset.secondary_separator}
          enable_wildcard: {subset.enable_wildcard}
          caption_dropout_rate: {subset.caption_dropout_rate}
          caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
          caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
          caption_prefix: {subset.caption_prefix}
          caption_suffix: {subset.caption_suffix}
          color_aug: {subset.color_aug}
          flip_aug: {subset.flip_aug}
          face_crop_aug_range: {subset.face_crop_aug_range}
          random_crop: {subset.random_crop}
          token_warmup_min: {subset.token_warmup_min}
          token_warmup_step: {subset.token_warmup_step}
          alpha_mask: {subset.alpha_mask}
          custom_attributes: {subset.custom_attributes}
      """
            )
            pieces.append(indent(subset_details, "  "))

            if is_dreambooth:
                dreambooth_details = dedent(
                    f"""\
          is_reg: {subset.is_reg}
          class_tokens: {subset.class_tokens}
          caption_extension: {subset.caption_extension}
        \n"""
                )
                pieces.append(indent(dreambooth_details, "    "))
            elif not is_controlnet:
                finetuning_details = dedent(
                    f"""\
          metadata_file: {subset.metadata_file}
        \n"""
                )
                pieces.append(indent(finetuning_details, "    "))

    info = "".join(pieces)
    logger.info(f"{info}")

    # make buckets first because it determines the length of dataset
    # and set the same seed for all datasets
    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
    for i, dataset in enumerate(datasets):
        logger.info(f"[Dataset {i}]")
        dataset.make_buckets()
        dataset.set_seed(seed)

    return DatasetGroup(datasets)


def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
    """Build DreamBooth subset configs from ``<repeats>_<class tokens>`` sub-directories.

    Every sub-directory of ``train_data_dir`` (``is_reg=False``) and of
    ``reg_data_dir`` (``is_reg=True``) whose name starts with an integer
    repeat count of at least 1 yields one subset config dict. A base
    directory that is ``None`` or not a directory contributes nothing.
    """

    def parse_dir_name(dir_name: str) -> Tuple[int, str]:
        # Text before the first "_" is the repeat count, the remainder is the caption.
        repeats_text, _, caption = dir_name.partition("_")
        try:
            return int(repeats_text), caption
        except ValueError:
            logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir_name}")
            return 0, ""

    def collect(root: Optional[str], is_reg: bool):
        if root is None:
            return []

        root_path = Path(root)
        if not root_path.is_dir():
            return []

        found = []
        for child in root_path.iterdir():
            if not child.is_dir():
                continue

            repeats, tokens = parse_dir_name(child.name)
            if repeats >= 1:
                found.append({"image_dir": str(child), "num_repeats": repeats, "is_reg": is_reg, "class_tokens": tokens})

        return found

    return collect(train_data_dir, False) + collect(reg_data_dir, True)


def generate_controlnet_subsets_config_by_subdirs(
    train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
):
    """Build the single ControlNet subset config for ``train_data_dir``.

    Returns a one-element list holding the subset config (``num_repeats`` is
    fixed to 1), or an empty list when ``train_data_dir`` is ``None`` or is
    not an existing directory.
    """
    if train_data_dir is None:
        return []
    if not Path(train_data_dir).is_dir():
        return []

    return [
        {
            "image_dir": train_data_dir,
            "conditioning_data_dir": conditioning_data_dir,
            "caption_extension": caption_extension,
            "num_repeats": 1,
        }
    ]


def load_user_config(file: str) -> dict:
    """Load a dataset config file and return its parsed content.

    The format is chosen from the file name (case-insensitive): ``.json`` is
    parsed with ``json``, ``.toml`` with ``toml``.

    Args:
        file: path of the config file.

    Returns:
        The parsed config.

    Raises:
        ValueError: if ``file`` is not an existing file, or its extension is
            neither ``.json`` nor ``.toml``.
        Exception: any parser or I/O error is logged and re-raised unchanged.
    """
    config_path = Path(file)
    if not config_path.is_file():
        raise ValueError(f"file not found / ファイルが見つかりません: {config_path}")

    lowered_name = config_path.name.lower()
    if lowered_name.endswith(".json"):
        try:
            # Read as UTF-8 explicitly: the platform default encoding (e.g. on
            # Windows) would garble or reject non-ASCII text in the config.
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        except Exception:
            logger.error(
                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {config_path}"
            )
            raise
    elif lowered_name.endswith(".toml"):
        try:
            config = toml.load(config_path)
        except Exception:
            logger.error(
                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {config_path}"
            )
            raise
    else:
        raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {config_path}")

    return config


# for config test
if __name__ == "__main__":
    # First pass: read the mode flags and the config path; everything else is
    # left in ``remaining_argv`` for the dataset/training argument parser.
    mode_parser = argparse.ArgumentParser()
    for mode_flag in ("--support_dreambooth", "--support_finetuning", "--support_controlnet", "--support_dropout"):
        mode_parser.add_argument(mode_flag, action="store_true")
    mode_parser.add_argument("dataset_config")
    config_args, remaining_argv = mode_parser.parse_known_args()

    # Second pass: parse the remaining arguments with the project's dataset and training options.
    train_parser = argparse.ArgumentParser()
    train_util.add_dataset_arguments(
        train_parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
    )
    train_util.add_training_arguments(train_parser, config_args.support_dreambooth)
    argparse_namespace = train_parser.parse_args(remaining_argv)
    train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)

    for line in ("[argparse_namespace]", f"{vars(argparse_namespace)}"):
        logger.info(line)

    user_config = load_user_config(config_args.dataset_config)

    for line in ("", "[user_config]", f"{user_config}"):
        logger.info(line)

    sanitizer = ConfigSanitizer(
        config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
    )
    sanitized_user_config = sanitizer.sanitize_user_config(user_config)

    for line in ("", "[sanitized_user_config]", f"{sanitized_user_config}"):
        logger.info(line)

    blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

    for line in ("", "[blueprint]", f"{blueprint}"):
        logger.info(line)
