import random
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import numpy as np
import PIL.Image as Image
import logging

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torchvision.utils import make_grid


def seed_everything(seed=1234):
    """Seed every RNG this module relies on and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy (seed reduced modulo 2**32, the range
    ``np.random.seed`` accepts), and torch on CPU and all CUDA devices. It then
    disables cuDNN autotuning in favour of deterministic kernels.
    """
    random.seed(seed)
    np.random.seed(seed % (2**32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


"""
=============== Tensor Related ===============
"""


def slerp(val, low, high, eps=1e-6):
    """
    Spherical linear interpolation between corresponding rows of ``low`` and ``high``.

    Adapted from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/4

    :param val: Interpolation weight; 0 returns ``low`` and 1 returns ``high``. A Python
        float, or a tensor broadcastable against the per-row angle of shape (N,).
    :param low: 2-D tensor of shape (N, D).
    :param high: 2-D tensor of shape (N, D); row i is interpolated with row i of ``low``.
    :param eps: Rows whose ``|sin(omega)|`` is below this are treated as (anti-)parallel.
        They use plain linear interpolation instead of dividing 0 by 0.
    :return: Tensor of shape (N, D).
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    # Rounding can push the cosine slightly outside [-1, 1], where acos returns NaN.
    cos_omega = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    # For (anti-)parallel rows sin(omega) ~ 0, which would give 0/0 = NaN.
    # Those rows fall back to linear interpolation; the rest keep the slerp weights.
    parallel = so.abs() < eps
    ones = torch.ones_like(so)
    safe_so = torch.where(parallel, ones, so)
    w_low = torch.where(
        parallel, (1.0 - val) * ones, torch.sin((1.0 - val) * omega) / safe_so
    )
    w_high = torch.where(parallel, val * ones, torch.sin(val * omega) / safe_so)
    return w_low.unsqueeze(1) * low + w_high.unsqueeze(1) * high


def slerp_tensor(val, low, high):
    """
    Spherically interpolate between two same-shaped tensors (used in negative prompt inversion).

    All dimensions after the first are flattened, so each leading-index slice is treated
    as one vector by ``slerp``. The result is reshaped back to ``low.shape``.
    """
    original_shape = low.shape
    flat_low, flat_high = low.flatten(1), high.flatten(1)
    return slerp(val, flat_low, flat_high).reshape(original_shape)


"""
=============== Diffusion Related ===============
"""


def init_latent(latent, model, height, width, generator, batch_size):
    """Return a single starting latent and a batch-expanded copy on the model's device.

    When ``latent`` is None, a (1, C, height // 8, width // 8) Gaussian sample is drawn
    with ``generator``. C is ``model.unet.in_channels``.

    Returns ``(latent, latents)``: ``latent`` is unchanged or freshly sampled, and
    ``latents`` is its ``expand`` to ``batch_size``, moved to ``model.device``.
    """
    channels = model.unet.in_channels
    latent_h, latent_w = height // 8, width // 8
    if latent is None:
        latent = torch.randn((1, channels, latent_h, latent_w), generator=generator)
    batched = latent.expand(batch_size, channels, latent_h, latent_w)
    return latent, batched.to(model.device)


"""
=============== Image Related ===============
"""


@torch.no_grad()
def latent2image(model, latents, return_type="np"):
    """Decode latents with ``model.decode`` and optionally convert to uint8 images.

    Latents are first multiplied by 1 / 0.18215, undoing the scale that ``image2latent``
    applies. With ``return_type == "np"``, the decoded sample is mapped from [-1, 1] to
    [0, 1], clamped, and returned as a (N, H, W, C) uint8 numpy array. Any other
    ``return_type`` returns the raw decoded tensor.
    """
    rescaled = 1 / 0.18215 * latents.detach()
    decoded = model.decode(rescaled)["sample"]
    if return_type != "np":
        return decoded
    unit_range = (decoded / 2 + 0.5).clamp(0, 1)
    channels_last = unit_range.cpu().permute(0, 2, 3, 1).numpy()
    return (channels_last * 255).astype(np.uint8)


@torch.no_grad()
def image2latent(model, image):
    """Encode an image into scaled latent space with ``model.encode``.

    :param model: Object exposing ``encode(x)["latent_dist"].mean`` and ``device``.
    :param image: Either a 4-D ``torch.Tensor``, which is treated as precomputed latents
        and returned unchanged, or an (H, W, C) image with values in [0, 255] (numpy
        array, PIL image, or other array-like).
    :return: ``latent_dist.mean * 0.18215`` for images; the input itself for 4-D tensors.
    """
    if isinstance(image, torch.Tensor) and image.dim() == 4:
        return image
    # PIL images and other array-likes must become numpy arrays before from_numpy.
    # Non-4-D tensors are left alone and still fail below, as before.
    if not isinstance(image, (np.ndarray, torch.Tensor)):
        image = np.array(image)
    # Map [0, 255] -> [-1, 1], then HWC -> 1CHW on the model's device.
    image = torch.from_numpy(image).float() / 127.5 - 1
    image = image.permute(2, 0, 1).unsqueeze(0).to(model.device)
    latents = model.encode(image)["latent_dist"].mean
    return latents * 0.18215


def load_512(image_path, left=0, right=0, top=0, bottom=0):
    """Load an image, crop the given margins, center-crop to a square and resize to 512x512.

    :param image_path: Path (``str`` or ``os.PathLike``) to an image file, or an
        already-loaded (H, W, C) array. Files keep at most their first 3 channels.
    :param left, right, top, bottom: Pixels cropped from each side. They are clamped so
        that at least one row and one column survive the crop.
    :return: numpy array of shape (512, 512, C).
    """
    import os

    if isinstance(image_path, (str, os.PathLike)):
        # Context manager closes the file handle; np.array() has fully read the pixels.
        with Image.open(image_path) as img:
            image = np.array(img)[:, :, :3]
    else:
        image = image_path
    h, w, _ = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # Vertical margins mirror the horizontal ones: clamp top to the height alone.
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top : h - bottom, left : w - right]
    h, w = image.shape[:2]
    # Center-crop the longer side so the image is square before resizing.
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset : offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset : offset + w]
    return np.array(Image.fromarray(image).resize((512, 512)))


def get_word_inds(text: str, word_place: "int | str | list[int]", tokenizer) -> np.ndarray:
    """Map selected word(s) of ``text`` to their token positions under ``tokenizer``.

    Words are defined by splitting ``text`` on single spaces. ``word_place`` selects
    words in one of three ways:

    - a ``str`` selects every word equal to it;
    - an ``int`` selects that single word position;
    - any other value is used as-is and must support ``len()`` and ``in``
      (e.g. a list of word positions).

    Each token id from ``tokenizer.encode(text)`` is decoded on its own and stripped of
    '#' characters (presumably WordPiece-style continuation markers). The first and last
    tokens are dropped (presumably BOS/EOS specials; TODO confirm for the tokenizer in
    use). Decoded lengths are accumulated until they cover the current word, and then the
    next word begins.

    Returns an ``np.ndarray`` of token positions in the full encoded sequence, or an
    empty array if nothing matched.

    NOTE(review): if tokenization does not line up with the space-split words,
    ``split_text[ptr]`` can run past the last word and raise IndexError.
    """
    split_text = text.split(" ")
    if type(word_place) is str:
        # Every word position whose text equals the query string.
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif type(word_place) is int:
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        # Decode token-by-token; [1:-1] drops the first and last encoded tokens.
        words_encode = [
            tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)
        ][1:-1]
        # cur_len: characters of the current word covered so far; ptr: current word index.
        cur_len, ptr = 0, 0

        for i in range(len(words_encode)):
            cur_len += len(words_encode[i])
            if ptr in word_place:
                # +1 re-adds the dropped first token so the index is into the full sequence.
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                # Current word fully covered; move on to the next word.
                ptr += 1
                cur_len = 0
    return np.array(out)


def update_alpha_time_word(alpha, bounds, prompt_ind, word_inds=None):
    """Set a step window to 1 (and everything else to 0) for one prompt's word slots.

    :param alpha: Tensor indexed as ``alpha[step, prompt, word]``; ``alpha.shape[0]`` is
        the number of steps and ``alpha.shape[2]`` the number of word slots.
        Modified in place.
    :param bounds: Either a single number ``end``, meaning the window ``(0, end)``, or a
        ``(start, end)`` pair. Values are fractions of ``alpha.shape[0]``, truncated to
        ints.
    :param prompt_ind: Index along dimension 1 to update.
    :param word_inds: Word-slot indices to update; defaults to all slots.
    :return: ``alpha`` (the same object, after the in-place update).
    """
    # Accept ints too (e.g. ``1``); previously only float was recognised, so an int
    # crashed on ``bounds[0]``.
    if isinstance(bounds, (int, float)):
        bounds = 0, bounds
    num_steps = alpha.shape[0]
    start, end = int(bounds[0] * num_steps), int(bounds[1] * num_steps)
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[:start, prompt_ind, word_inds] = 0
    alpha[start:end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha


def get_time_words_attention_alpha(
    prompts, num_steps, cross_replace_steps, tokenizer, max_num_words=77
):
    """Build per-step, per-word cross-attention replacement weights for prompts[1:].

    :param prompts: List of prompts; each of ``prompts[1:]`` gets one slice of the output.
    :param num_steps: The output has ``num_steps + 1`` step rows.
    :param cross_replace_steps: Either bounds accepted by ``update_alpha_time_word``
        (a single ``end`` or a ``(start, end)`` pair, as fractions of the steps), or a
        dict mapping words to such bounds. An optional ``"default_"`` key applies to all
        words and defaults to ``(0.0, 1.0)``. The caller's dict is not modified.
    :param tokenizer: Passed to ``get_word_inds`` to locate each word's token indices.
    :param max_num_words: Number of token slots.
    :return: Tensor of shape ``(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)``.
    """
    if isinstance(cross_replace_steps, dict):
        # Copy so that filling in "default_" below does not leak into the caller's dict.
        cross_replace_steps = dict(cross_replace_steps)
    else:
        cross_replace_steps = {"default_": cross_replace_steps}
    cross_replace_steps.setdefault("default_", (0.0, 1.0))

    num_edits = len(prompts) - 1
    alpha_time_words = torch.zeros(num_steps + 1, num_edits, max_num_words)
    for i in range(num_edits):
        alpha_time_words = update_alpha_time_word(
            alpha_time_words, cross_replace_steps["default_"], i
        )
    # Word-specific bounds override the default window for the matching tokens.
    for key, item in cross_replace_steps.items():
        if key == "default_":
            continue
        for i, prompt in enumerate(prompts[1:]):
            ind = get_word_inds(prompt, key, tokenizer)
            if len(ind) > 0:
                alpha_time_words = update_alpha_time_word(
                    alpha_time_words, item, i, ind
                )
    return alpha_time_words.reshape(num_steps + 1, num_edits, 1, 1, max_num_words)


def txt_draw(text, target_size=(512, 512)):
    """Render ``text`` onto a blank 1x1-inch, 300-dpi figure and return it as an RGB array.

    :param text: Text to draw. It is wrapped and anchored near the top-left corner.
    :param target_size: ``(width, height)`` passed to PIL ``resize``; any 2-element
        sequence is accepted.
    :return: numpy array of shape (height, width, 3).
    """
    fig = plt.figure(dpi=300, figsize=(1, 1))
    try:
        plt.text(
            -0.1,
            1.1,
            text,
            fontsize=3.5,
            wrap=True,
            verticalalignment="top",
            horizontalalignment="left",
        )
        plt.axis("off")

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        # buffer_rgba() is already (height, width, 4) RGBA. This replaces the
        # tostring_argb / np.fromstring / ndarray.tostring round-trip, which relied on
        # removed APIs and used a swapped (w, h) shape.
        rgba = np.array(canvas.buffer_rgba(), dtype=np.uint8)
    finally:
        # Close only the figure created here, even if drawing fails.
        plt.close(fig)

    # LANCZOS is the filter the removed Image.ANTIALIAS alias pointed to.
    image = Image.fromarray(rgba).resize(tuple(target_size), Image.LANCZOS)
    return np.asarray(image)[:, :, :3]


def load_img(image_path, device):
    """Load an image file as a (1, 3, H, W) float tensor on ``device``.

    The image is converted to RGB, resized with ``T.Resize(512)`` and converted with
    ``T.ToTensor()``.
    """
    # The context manager closes the file promptly. convert() returns a new, fully
    # loaded image, so it remains usable after the file is closed.
    with Image.open(image_path) as img:
        image_pil = T.Resize(512)(img.convert("RGB"))
    return T.ToTensor()(image_pil).unsqueeze(0).to(device)

def load_masked_image(image_pil, mask=None, mask_path=None):
    """Multiply an image by a mask element-wise and return the result as a PIL image.

    :param image_pil: PIL image. It is converted via ``np.array`` and permuted HWC -> CHW
        (assumed RGB, 512x512 in the original usage; TODO confirm).
    :param mask: Tensor broadcastable against the (C, H, W) image, e.g. (H, W) or
        (1, H, W). A 4-D mask with a leading dimension of 1 is squeezed to 3-D. Values
        may be in [0, 1], or in 0..255, in which case they are rescaled.
    :param mask_path: Used only when ``mask`` is None; loaded with ``torch.load``.
    :raises ValueError: If neither ``mask`` nor ``mask_path`` is given.
    :return: The masked image as a PIL image.
    """
    if mask is None:
        if mask_path is None:
            raise ValueError("load_masked_image requires either `mask` or `mask_path`.")
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        mask = torch.load(mask_path)
    # Rescale 0..255 masks regardless of whether they were passed in or loaded.
    if mask.max().item() > 1.0:
        mask = mask.float() / 255.0
    if mask.ndim == 4 and mask.shape[0] == 1:
        mask = mask[0]

    # PIL (HWC, 0..255) -> float tensor (CHW, 0..1).
    image_tensor = torch.tensor(np.array(image_pil)).permute(2, 0, 1).float() / 255.0
    masked_image_tensor = image_tensor * mask.cpu()
    # Back to uint8 HWC for PIL.
    masked_uint8 = torch.clamp(masked_image_tensor * 255, 0, 255).type(torch.uint8)
    return Image.fromarray(masked_uint8.permute(1, 2, 0).numpy())


def image_grid(imgs, rows, cols):
    """Paste ``rows * cols`` PIL images into one RGB grid, filled row by row.

    Every tile is assumed to have the size of ``imgs[0]``.
    """
    assert len(imgs) == rows * cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * tile_w, rows * tile_h))
    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas


def show_grid(tensor_imgs):
    """Tile a batch of image tensors with ``make_grid`` and display it with matplotlib."""

    def _display(images):
        # Accept a single tensor or a list of tensors; each gets its own subplot column.
        image_list = images if isinstance(images, list) else [images]
        _, axes = plt.subplots(ncols=len(image_list), squeeze=False)
        for col, tensor in enumerate(image_list):
            pil_image = F.to_pil_image(tensor.detach())
            ax = axes[0, col]
            ax.imshow(np.asarray(pil_image))
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    _display(make_grid(tensor_imgs))


def check_gradients(module, module_name=""):
    """Print one line per parameter saying whether it has ``requires_grad`` set.

    The wording ("has been updated") reflects trainability only. It does not inspect
    gradients or parameter values.
    """
    for param_name, param in module.named_parameters():
        status = "has been updated." if param.requires_grad else "has NOT been updated."
        print(f"{module_name}.{param_name} {status}")


def count_updated_params(module, specify_key=None, verbose=False):
    """Return the fraction of parameter elements that have ``requires_grad`` set.

    :param module: Module whose ``named_parameters()`` are inspected.
    :param specify_key: If given, only parameters whose name contains it are counted.
    :param verbose: If True, print the names of trainable and frozen parameters.
    :return: trainable elements / total elements, or ``0`` when nothing matched.
    """
    total = 0
    trainable = 0
    trainable_names = []
    frozen_names = []

    for name, param in module.named_parameters():
        if specify_key is not None and specify_key not in name:
            continue
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
            bucket = trainable_names
        else:
            bucket = frozen_names
        if verbose:
            bucket.append(name)

    if verbose:
        print(f'These Params will be updated:\n {trainable_names}')
        print(f'These Params wont be updated:\n {frozen_names}')

    # Guard against division by zero when no parameter matched.
    return trainable / total if total != 0 else 0


def check_zero_initialization(model, check_key, logger=None):
    """Report whether every parameter whose name contains ``check_key`` is all zeros.

    :param model: Module whose ``named_parameters()`` are inspected.
    :param check_key: Substring selecting which parameters to check.
    :param logger: Optional logger; messages go to ``logger.info``, otherwise ``print``.
    """
    message_function = logger.info if logger else print

    # Test element-wise: a sum-based check would accept tensors like [1, -1].
    non_zero_params = [
        name
        for name, param in model.named_parameters()
        if check_key in name and torch.any(param != 0)
    ]

    if non_zero_params:
        message_function("The following parameters are not zero-initialized:")
        for param_name in non_zero_params:
            message_function(param_name)
    else:
        message_function(f"The '{check_key}' related parameters are zero-initialized.")


def get_module_params(module, only_trainable=False, specify_key=None):
    """Count a module's parameters, in millions.

    :param module: Module to inspect; ``None`` counts as having no parameters.
    :param only_trainable: If True, return only the trainable count.
    :param specify_key: If given, only parameters whose name contains it are counted.
    :return: ``(total_M, trainable_M)``, or just ``trainable_M`` when ``only_trainable``.
    """
    if module is None:
        num_params = num_trainable_params = 0.0
    else:
        if specify_key is None:
            selected = list(module.parameters())
        else:
            selected = [p for n, p in module.named_parameters() if specify_key in n]
        num_params = sum(p.numel() for p in selected) / 1e6
        num_trainable_params = sum(p.numel() for p in selected if p.requires_grad) / 1e6

    return num_trainable_params if only_trainable else (num_params, num_trainable_params)


def compare_model_params(state_dict1, state_dict2, specify_key=''):
    '''
    Print which entries whose key contains ``specify_key`` differ between two state dicts.

    Prints each differing key, then keys present in only one of the dicts, then the
    element counts (in millions) of differing and compared entries.

    # Usage example
        state_dict1 = model1.state_dict()
        state_dict2 = model2.state_dict()
        compare_model_params(state_dict1, state_dict2, specify_key="attn2")
    '''
    differing = 0
    compared = 0

    keys_in_first = [k for k in state_dict1.keys() if specify_key in k]
    for key in keys_in_first:
        if key not in state_dict2:
            print(f"Key {key} in 1st state_dict not found in 2nd state_dict")
            continue
        first, second = state_dict1[key], state_dict2[key]
        size = torch.numel(first)
        compared += size
        if not torch.equal(first, second):
            differing += size
            print(f"Parameters differ for key: {key}")

    for key in (k for k in state_dict2.keys() if specify_key in k):
        if key not in state_dict1:
            print(f"Key {key} in 2nd state_dict not found in 1st state_dict")

    print(f"Total differing parameters count: {differing / 1e6} M")
    print(f"Total parameters count with '{specify_key}': {compared / 1e6} M")


def print_highlighted_block_log(
    title, message, title_color="\033[1;34m", text_color="\033[1;37m", logger=None
):
    """
    Prints a highlighted log block with a title and custom colors.

    Every emitted row (borders and content) has the same visible width, so the right
    border lines up. Colors are reset at the end of each row so they do not bleed into
    later output or log-record prefixes.

    :param title: The title of the log block.
    :param message: The message to be printed within the block. A string is split on
        newlines; a dict is rendered as one ``{key: value}`` line per item.
    :param title_color: ANSI color code for the title. Default is bold blue.
    :param text_color: ANSI color code for the text. Default is bold white.
    :param logger: Optional logger; each row goes to ``logger.info`` instead of ``print``.
    """
    message_function = logger.info if logger else print
    reset = "\033[0m"

    # Box-drawing characters for styling
    top_left = "\u2554"
    top_right = "\u2557"
    bottom_left = "\u255A"
    bottom_right = "\u255D"
    horizontal = "\u2550"
    vertical = "\u2551"

    if isinstance(message, dict):
        message = "\n".join(f"{{{k}: {v}}}" for k, v in message.items())
    lines = message.split("\n")

    # Inner text width is the longest of the title and message lines. Each content row
    # is "║ " + text + " ║", i.e. inner_width + 4 visible characters, matching the borders.
    inner_width = max(len(title), max(len(line) for line in lines))
    max_width = inner_width + 4

    def _row(color, text):
        # Centered text, color reset before the closing border.
        return f"{vertical} {color}{text.center(inner_width)}{reset} {vertical}"

    message_function(f"{top_left}{horizontal * (max_width - 2)}{top_right}")
    message_function(_row(title_color, title))
    for line in lines:
        message_function(_row(text_color, line))
    message_function(f"{bottom_left}{horizontal * (max_width - 2)}{bottom_right}")


# The background is set with 40 plus the number of the color, and the foreground with 30
# (ANSI SGR codes: foreground = 30 + index, background = 40 + index).
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)

# These are the escape sequences needed to get colored output.
RESET_SEQ = "\033[0m"  # reset all attributes
COLOR_SEQ = "\033[1;%dm"  # bold + color; fill %d with 30 + color index
BOLD_SEQ = "\033[1m"

# Level name -> color index used by ColoredFormatter (CRITICAL shares WARNING's yellow).
COLORS = {
    "WARNING": YELLOW,
    "INFO": GREEN,
    "DEBUG": BLUE,
    "CRITICAL": YELLOW,
    "ERROR": RED,
}

def formatter_message(message, use_color=True):
    """Expand the ``$RESET`` / ``$BOLD`` placeholders in a log format string.

    With ``use_color`` the placeholders become ANSI escape sequences; otherwise they are
    removed so the output is plain text.
    """
    reset, bold = (RESET_SEQ, BOLD_SEQ) if use_color else ("", "")
    return message.replace("$RESET", reset).replace("$BOLD", bold)

class ColoredFormatter(logging.Formatter):
    """``logging.Formatter`` that wraps known level names in ANSI color escapes.

    Colors come from the module-level ``COLORS`` mapping. Level names not in that
    mapping, and all records when ``use_color`` is False, are formatted unchanged.
    """

    def __init__(self, msg, use_color=True):
        logging.Formatter.__init__(self, msg)
        self.use_color = use_color

    def format(self, record):
        levelname = record.levelname
        if not (self.use_color and levelname in COLORS):
            return logging.Formatter.format(self, record)
        # The same LogRecord is passed to every handler of a logger. Swap in the colored
        # name only for this format call and restore it afterwards; otherwise other
        # handlers (e.g. a plain file handler) would also receive the ANSI codes.
        record.levelname = (
            COLOR_SEQ % (30 + COLORS[levelname]) + levelname + RESET_SEQ
        )
        try:
            return logging.Formatter.format(self, record)
        finally:
            record.levelname = levelname


# Logger subclass at DEBUG level with a single colored console handler.
class ColoredLogger(logging.Logger):
    """Logger preconfigured with a ``StreamHandler`` that uses ``ColoredFormatter``.

    Each instance is created at level DEBUG and gets its own ``logging.StreamHandler``
    (stderr by default), formatted with ``COLOR_FORMAT``.
    """

    FORMAT = "[$BOLD%(name)-20s$RESET][%(levelname)-18s]  %(message)s ($BOLD%(filename)s$RESET:%(lineno)d)"
    COLOR_FORMAT = formatter_message(FORMAT, True)

    def __init__(self, name):
        super().__init__(name, logging.DEBUG)
        handler = logging.StreamHandler()
        handler.setFormatter(ColoredFormatter(self.COLOR_FORMAT))
        self.addHandler(handler)
