# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------

import os
import logging
import torch

from omegaconf import OmegaConf

from ..models.clip import CustomizedCLIP, CustomizedTokenizer
from ..models.prior_model import PriorDiffusionModel
from ..models.decoder_model import Text2ImProgressiveModel
from ..models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel


SAMPLING_CONF = {
    "default": {
        "prior_sm": "25",
        "prior_n_samples": 1,
        "prior_cf_scale": 4.0,
        "decoder_sm": "50",
        "decoder_cf_scale": 8.0,
        "sr_sm": "7",
    },
    "fast": {
        "prior_sm": "25",
        "prior_n_samples": 1,
        "prior_cf_scale": 4.0,
        "decoder_sm": "25",
        "decoder_cf_scale": 8.0,
        "sr_sm": "7",
    },
}

CKPT_PATH = {
    "prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
    "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
    "sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
}


class BaseSampler:
    _PRIOR_CLASS = PriorDiffusionModel
    _DECODER_CLASS = Text2ImProgressiveModel
    _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel

    def __init__(
        self,
        root_dir: str,
        sampling_type: str = "fast",
    ):
        self._root_dir = root_dir

        sampling_type = SAMPLING_CONF[sampling_type]
        self._prior_sm = sampling_type["prior_sm"]
        self._prior_n_samples = sampling_type["prior_n_samples"]
        self._prior_cf_scale = sampling_type["prior_cf_scale"]

        assert self._prior_n_samples == 1

        self._decoder_sm = sampling_type["decoder_sm"]
        self._decoder_cf_scale = sampling_type["decoder_cf_scale"]

        self._sr_sm = sampling_type["sr_sm"]

    def __repr__(self):
        line = ""
        line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
        line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
        line += f"SR(64->256), sampling method: {self._sr_sm}"

        return line

    def load_clip(self, clip_path: str):
        clip = CustomizedCLIP.load_from_checkpoint(
            os.path.join(self._root_dir, clip_path)
        )
        clip = torch.jit.script(clip)
        if torch.cuda.is_available():
            clip.cuda()
        clip.eval()

        self._clip = clip
        self._tokenizer = CustomizedTokenizer()

    def load_prior(
        self,
        ckpt_path: str,
        clip_stat_path: str,
    ):
        logging.info(f"Loading prior: {ckpt_path}")

        config = OmegaConf.load("configs/karlo/prior_1B_vit_l.yaml")
        clip_mean, clip_std = torch.load(
            os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
        )

        prior = self._PRIOR_CLASS.load_from_checkpoint(
            config,
            self._tokenizer,
            clip_mean,
            clip_std,
            os.path.join(self._root_dir, ckpt_path),
            strict=True,
        )
        if torch.cuda.is_available():
            prior.cuda()
        prior.eval()
        logging.info("done.")

        self._prior = prior

    def load_decoder(self, ckpt_path: str):
        logging.info(f"Loading decoder: {ckpt_path}")

        config = OmegaConf.load("configs/karlo/decoder_900M_vit_l.yaml")
        decoder = self._DECODER_CLASS.load_from_checkpoint(
            config,
            self._tokenizer,
            os.path.join(self._root_dir, ckpt_path),
            strict=True,
        )
        if torch.cuda.is_available():
            decoder.cuda()
        decoder.eval()
        logging.info("done.")

        self._decoder = decoder

    def load_sr_64_256(self, ckpt_path: str):
        logging.info(f"Loading SR(64->256): {ckpt_path}")

        config = OmegaConf.load("configs/karlo/improved_sr_64_256_1.4B.yaml")
        sr = self._SR256_CLASS.load_from_checkpoint(
            config, os.path.join(self._root_dir, ckpt_path), strict=True
        )
        if torch.cuda.is_available():
            sr.cuda()
        sr.eval()
        logging.info("done.")

        self._sr_64_256 = sr
