import datetime
import logging
import os
import os.path
import os.path as osp
import time
from collections import OrderedDict

import PIL
import torch
from accelerate.logging import get_logger
from accelerate.state import PartialState
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.transforms import ToTensor
from torchvision.utils import make_grid

# Comma-separated list of undesired image traits. It is not referenced by any
# function visible in this part of the file.
# NOTE(review): presumably passed as the negative prompt when sampling images
# -- confirm against the callers that import it.
NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'


# ----------- file/logger util ----------
def get_time_str():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)


def mkdir_and_rename(path):
    """Create a directory at ``path``, archiving whatever already sits there.

    If ``path`` exists it is renamed to ``<path>_archived_<timestamp>`` (the
    timestamp comes from ``get_time_str``) and a notice is printed; afterwards
    a new directory is created at ``path``.

    Args:
        path (str): Folder path.
    """
    if not osp.exists(path):
        # Nothing to archive: just create the directory.
        os.makedirs(path, exist_ok=True)
        return

    archived_name = ''.join((path, '_archived_', get_time_str()))
    print(f'Path already exists. Rename it to {archived_name}', flush=True)
    os.rename(path, archived_name)
    os.makedirs(path, exist_ok=True)


def make_exp_dirs(opt):
    """Create the output directories listed in ``opt['path']``.

    The root directory (``experiments_root`` when ``opt['is_train']`` is
    truthy, otherwise ``results_root``) is created through
    ``mkdir_and_rename``. Every remaining entry of ``opt['path']`` is created
    with ``os.makedirs`` unless its key contains one of the markers below.

    Args:
        opt (dict): Options with the keys ``path`` (dict) and ``is_train``.
    """
    remaining = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(remaining.pop(root_key))

    # Keys containing any of these substrings are skipped, not created.
    skip_markers = ('strict_load', 'pretrain_network', 'resume', 'param_key',
                    'lora_path')
    for key, path in remaining.items():
        if any(marker in key for marker in skip_markers):
            continue
        os.makedirs(path, exist_ok=True)


def copy_opt_file(opt_file, experiments_root):
    """Copy the option file into ``experiments_root`` and prepend a header.

    The copy keeps the base name of ``opt_file``. A comment header holding the
    generation time and the command line (``sys.argv`` joined by spaces) is
    inserted before the original content.

    Args:
        opt_file (str): Path of the option file to copy.
        experiments_root (str): Existing directory receiving the copy.
    """
    import sys
    import time
    from shutil import copyfile

    command = ' '.join(sys.argv)
    target = osp.join(experiments_root, osp.basename(opt_file))
    copyfile(opt_file, target)

    with open(target, 'r+') as f:
        original_lines = f.readlines()
        header = f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {command}\n\n'
        # Rewind and overwrite from the start; the new content is strictly
        # longer than the old one, so no truncation is required.
        f.seek(0)
        f.writelines([header, *original_lines])


def set_path_logger(accelerator, root_path, config_path, opt, is_train=True):
    """Fill in the output paths of ``opt``, create them and set up logging.

    ``opt`` is modified in place: ``opt['is_train']`` is set and
    ``opt['path']`` receives the root, ``log`` and ``visualization`` entries
    (plus ``models`` when training). Outputs live under
    ``<root_path>/experiments/<name>`` for training and
    ``<root_path>/results/<name>`` otherwise.

    Args:
        accelerator: Object providing ``is_main_process`` and
            ``wait_for_everyone()``.
        root_path (str): Base directory for all outputs.
        config_path (str): Option file copied into the output root.
        opt (dict): Options; must contain ``name`` and a ``path`` dict.
        is_train (bool): Select the training or testing layout. Default: True.
    """
    opt['is_train'] = is_train

    if is_train:
        root_key, parent_dir, log_prefix = ('experiments_root', 'experiments',
                                            'train')
    else:
        root_key, parent_dir, log_prefix = 'results_root', 'results', 'test'

    output_root = osp.join(root_path, parent_dir, opt['name'])
    opt['path'][root_key] = output_root
    if is_train:
        opt['path']['models'] = osp.join(output_root, 'models')
    opt['path']['log'] = output_root
    opt['path']['visualization'] = osp.join(output_root, 'visualization')

    # Only the main process touches shared files. Copying the option file on
    # every process would make all of them rewrite the same file at once.
    if accelerator.is_main_process:
        make_exp_dirs(opt)
        copy_opt_file(config_path, output_root)

    # Other processes must not open their log file before the folder exists.
    accelerator.wait_for_everyone()

    log_file = osp.join(opt['path']['log'],
                        f"{log_prefix}_{opt['name']}_{get_time_str()}.log")
    set_logger(log_file)


def set_logger(log_file=None):
    """Configure the root logger with a stream handler and optional file.

    Called on every process so each one logs the configuration for debugging.

    Args:
        log_file (str | None): Path of the log file, opened in ``'w'`` mode.
            When ``None`` only the stream handler is installed. Default: None.
    """
    format_str = '%(asctime)s %(levelname)s: %(message)s'
    log_level = logging.INFO
    handlers = []

    # ``logging.FileHandler(None)`` raises TypeError, so the declared default
    # must skip the file handler instead of constructing it.
    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        handlers.append(file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    handlers.append(stream_handler)

    logging.basicConfig(handlers=handlers, level=log_level)


def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Nested dicts are rendered recursively between ``key:[`` and ``]`` with one
    extra indent level (two spaces per level). Keys and values are converted
    with ``str``, so non-string keys (e.g. integers) are supported.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing. It always starts with a newline.
    """
    indent = ' ' * (indent_level * 2)
    # Collect the pieces and join once instead of repeated concatenation.
    parts = ['\n']
    for key, value in opt.items():
        if isinstance(value, dict):
            parts.append(f'{indent}{key!s}:[')
            parts.append(dict2str(value, indent_level + 1))
            parts.append(f'{indent}]\n')
        else:
            parts.append(f'{indent}{key!s}: {value!s}\n')
    return ''.join(parts)


class MessageLogger():
    """Message logger for printing.

    Args:
        opt (dict): Config. The following keys are read:
            name (str): Exp name.
            logger (dict): Contains 'print_freq', stored as the logger
                interval.
            train (dict): Contains 'total_iter' (int) for total iters.
        start_iter (int): Start iter. Default: 1.
    """
    def __init__(self, opt, start_iter=1):
        self.exp_name = opt['name']
        self.interval = opt['logger']['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.start_time = time.time()
        self.logger = get_logger('mixofshow', log_level='INFO')

    def reset_start_time(self):
        """Restart the clock used for the ETA estimate."""
        self.start_time = time.time()

    def __call__(self, log_vars):
        """Format logging message and emit it at INFO level.

        Args:
            log_vars (dict): It contains the following keys:
                iter (int): Current iter.
                lrs (list): List for learning rates.

                Every other entry is printed as ``key: value`` in scientific
                notation. Note that 'iter' and 'lrs' are popped from the
                dict passed in.
        """
        # iter, learning rates
        current_iter = log_vars.pop('iter')
        lrs = log_vars.pop('lrs')

        lr_text = ''.join(f'{lr:.3e},' for lr in lrs)
        message = (
            f'[{self.exp_name[:5]}..][Iter:{current_iter:8,d}, lr:({lr_text})] '
        )

        # time and estimated time
        total_time = time.time() - self.start_time
        # At least one finished iteration, so a call made right at
        # ``start_iter - 1`` does not divide by zero.
        finished_iters = max(current_iter - self.start_iter + 1, 1)
        time_sec_avg = total_time / finished_iters
        # Clamp at zero: at the final iteration the raw value is negative and
        # would be rendered as "-1 day, 23:59:59".
        remaining_iters = max(self.max_iters - current_iter - 1, 0)
        eta_sec = time_sec_avg * remaining_iters
        eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
        message += f'[eta: {eta_str}] '

        # other items, especially losses
        message += ''.join(f'{k}: {v:.4e} ' for k, v in log_vars.items())

        self.logger.info(message)


def reduce_loss_dict(accelerator, loss_dict):
    """Reduce a loss dict across processes and convert it to plain floats.

    The loss values are stacked, passed through ``accelerator.reduce`` and
    divided by the number of processes reported by ``PartialState``.

    Args:
        accelerator: Object providing ``reduce(tensor)``.
        loss_dict (OrderedDict): Loss dict mapping names to tensors.

    Returns:
        OrderedDict: The same names mapped to Python floats.
    """
    with torch.no_grad():
        names = list(loss_dict.keys())
        stacked = torch.stack(list(loss_dict.values()), 0)
        stacked = accelerator.reduce(stacked)

        # Divide by the process count to turn the reduced values into means.
        stacked /= PartialState().num_processes

        return OrderedDict(
            (name, value.mean().item()) for name, value in zip(names, stacked))


def pil_imwrite(img, file_path, auto_mkdir=True):
    """Save a PIL image to ``file_path``.

    Args:
        img (PIL.Image.Image): Image to be written; anything else fails the
            assertion below.
        file_path (str): Image file path.
        auto_mkdir (bool): Whether to create the parent folder of
            ``file_path`` first. Default: True.
    """
    assert isinstance(
        img, PIL.Image.Image), 'model should return a list of PIL images'
    if auto_mkdir:
        parent_dir = osp.abspath(osp.dirname(file_path))
        os.makedirs(parent_dir, exist_ok=True)
    img.save(file_path)


def draw_prompt(text, height, width, font_size=45):
    """Render ``text`` in black on a white RGB canvas, wrapped to the width.

    The font is loaded from ``arial.ttf`` located next to this file. A line
    break is inserted before every ``chars_per_line``-th character (including
    the very first one), and a space that would start a line is dropped.

    Args:
        text (str): Text to draw.
        height (int): Canvas height.
        width (int): Canvas width.
        font_size (int): Font size passed to the font loader. Default: 45.

    Returns:
        PIL.Image.Image: The rendered canvas.
    """
    canvas = Image.new('RGB', (width, height), (255, 255, 255))
    pen = ImageDraw.Draw(canvas)
    font_file = osp.join(osp.dirname(osp.abspath(__file__)), 'arial.ttf')
    font = ImageFont.truetype(font_file, font_size)

    # Grow the prefix until its measured extent plus a 10% margin reaches the
    # opposite 10% margin, or the whole text fits on one line.
    # NOTE(review): ``[0][0]`` is presumably the prefix width in pixels.
    margin = 0.1 * width
    chars_per_line = 0
    while True:
        prefix_extent = font.font.getsize(text[:chars_per_line])[0][0]
        if not (prefix_extent + margin < width - margin
                and chars_per_line < len(text)):
            break
        chars_per_line += 1

    pieces = []
    for position, char in enumerate(text):
        if position % chars_per_line == 0:
            pieces.append('\n')
            if char == ' ':
                char = ''  # drop the space that would start the new line
        pieces.append(char)
    wrapped_text = ''.join(pieces)

    pen.text([int(0.1 * width), int(0.3 * height)],
             wrapped_text,
             font=font,
             fill='black')
    return canvas


def compose_visualize(dir_path):
    """Compose all images of ``dir_path`` into one grid image.

    File names (without extension) must have the form
    ``<prompt>---<sample_args>---<index>---<suffix>``, where the part of
    ``sample_args`` before the first ``-`` parses as a float score. Files are
    processed in sorted order; the first time a prompt is seen, a rendered
    prompt tile (see ``draw_prompt``) is inserted before its image. The grid
    is saved next to ``dir_path`` as ``<average score>---<suffix>.jpg``.

    An empty directory is skipped with a warning instead of failing with a
    division by zero.

    Args:
        dir_path (str): Directory containing only the images to compose.
    """
    file_list = sorted(os.listdir(dir_path))
    if not file_list:
        logging.getLogger(__name__).warning(
            'No images found in %s, skip composing the visualization.',
            dir_path)
        return

    to_tensor = ToTensor()
    img_list = []
    info_dict = {'prompts': set(), 'sample_args': set(), 'suffix': set()}
    clip_total = 0
    for filename in file_list:
        stem = osp.splitext(osp.basename(filename))[0]
        prompt, sample_args, _index, suffix = stem.split('---')

        clip_total += float(sample_args.split('-')[0])

        # Convert inside the context manager so the file handle is released
        # as soon as the pixel data has been read.
        with Image.open(osp.join(dir_path, filename)) as pil_img:
            img = to_tensor(pil_img)
        height, width = img.shape[1:]

        if prompt not in info_dict['prompts']:
            img_list.append(
                to_tensor(
                    draw_prompt(prompt,
                                height=height,
                                width=width,
                                font_size=45)))
        info_dict['prompts'].add(prompt)
        info_dict['sample_args'].add(sample_args)
        info_dict['suffix'].add(suffix)

        img_list.append(img)

    # NOTE: files are not required to share the same sample args or suffix;
    # when several suffixes are present the one used below is arbitrary.
    avg_clip = clip_total / len(file_list)
    # One row per prompt, assuming every prompt has the same number of images.
    grid = make_grid(img_list, nrow=len(img_list) // len(info_dict['prompts']))
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(
        'cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    save_name = f"{str(avg_clip)}---{info_dict['suffix'].pop()}.jpg"
    im.save(osp.join(osp.dirname(dir_path), save_name))