# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
# Apache-2.0 License
# By lllyasviel

import os
import cv2
import json
import random
import glob
import torch
import einops
import numpy as np
import datetime
import torchvision

from PIL import Image


def min_resize(x, m):
    """Resize an image so that its shorter side becomes ``m`` pixels.

    The other side is scaled by the same factor (truncated to an int), so the
    aspect ratio is approximately preserved. ``cv2.INTER_AREA`` is used when the
    result's longer side is smaller than the input's, otherwise
    ``cv2.INTER_LANCZOS4``.

    Args:
        x: Image array whose first two axes are height and width.
        m: Target length of the shorter side.

    Returns:
        The resized image array.
    """
    height, width = x.shape[0], x.shape[1]
    if height < width:
        out_h, out_w = m, int(float(m) / float(height) * float(width))
    else:
        out_h, out_w = int(float(m) / float(width) * float(height)), m

    shrinking = max(out_w, out_h) < max(height, width)
    method = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4
    # cv2.resize expects dsize as (width, height).
    return cv2.resize(x, (out_w, out_h), interpolation=method)


def d_resize(x, y):
    """Resize image ``x`` to the spatial size of the reference array ``y``.

    Only ``y.shape`` is used, and it must have exactly three entries
    (height, width, channels). ``cv2.INTER_AREA`` is chosen when the target's
    shorter side is smaller than ``x``'s shorter side, otherwise
    ``cv2.INTER_LANCZOS4``.
    """
    target_h, target_w, _ = y.shape
    if min(target_h, target_w) < min(x.shape[0], x.shape[1]):
        method = cv2.INTER_AREA
    else:
        method = cv2.INTER_LANCZOS4
    # cv2.resize expects dsize as (width, height).
    return cv2.resize(x, (target_w, target_h), interpolation=method)


def resize_and_center_crop(image, target_width, target_height):
    """Scale ``image`` to cover the target size, then crop the centre.

    The image is resized (Lanczos) by the smallest factor that makes both sides
    at least as large as the target, and the excess is trimmed evenly from both
    edges.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (height on axis 0).
        target_width: Output width in pixels.
        target_height: Output height in pixels.

    Returns:
        A NumPy array, or ``image`` itself if it already has the target size.
    """
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image

    source = Image.fromarray(image)
    src_w, src_h = source.size
    scale = max(target_width / src_w, target_height / src_h)
    new_w = int(round(src_w * scale))
    new_h = int(round(src_h * scale))
    scaled = source.resize((new_w, new_h), Image.LANCZOS)

    # Crop box (left, top, right, bottom): half of the excess on each side.
    box = (
        (new_w - target_width) / 2,
        (new_h - target_height) / 2,
        (new_w + target_width) / 2,
        (new_h + target_height) / 2,
    )
    return np.array(scaled.crop(box))


def resize_and_center_crop_pytorch(image, target_width, target_height):
    """Tensor counterpart of ``resize_and_center_crop`` for 4-D batches.

    ``image`` is unpacked as (B, C, H, W). It is bilinearly resized by the
    smallest factor that covers the target size, then centre-cropped.

    Returns:
        The cropped tensor (a slice of the resized tensor), or ``image`` itself
        when it already has the target height and width.
    """
    _, _, height, width = image.shape
    if height == target_height and width == target_width:
        return image

    scale = max(target_width / width, target_height / height)
    new_h = int(round(height * scale))
    new_w = int(round(width * scale))
    scaled = torch.nn.functional.interpolate(
        image, size=(new_h, new_w), mode='bilinear', align_corners=False
    )

    y0 = (new_h - target_height) // 2
    x0 = (new_w - target_width) // 2
    return scaled[:, :, y0:y0 + target_height, x0:x0 + target_width]


def resize_without_crop(image, target_width, target_height):
    """Resize ``image`` to exactly the target size with Lanczos, ignoring aspect ratio.

    Returns a NumPy array, or ``image`` unchanged if it already has the target
    height (axis 0) and width (axis 1).
    """
    already_sized = target_height == image.shape[0] and target_width == image.shape[1]
    if already_sized:
        return image

    resized = Image.fromarray(image).resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized)


def just_crop(image, w, h):
    """Centre-crop ``image`` to the aspect ratio ``w:h`` without resizing.

    The crop is as large as possible: its size is ``(w, h)`` scaled by the
    largest factor that still fits inside the image.

    Returns:
        A slice (view) of ``image``, or ``image`` itself if it is already
        ``h`` x ``w``.
    """
    if image.shape[0] == h and image.shape[1] == w:
        return image

    src_h, src_w = image.shape[:2]
    ratio = min(src_h / h, src_w / w)
    crop_w = int(round(w * ratio))
    crop_h = int(round(h * ratio))
    left = (src_w - crop_w) // 2
    top = (src_h - crop_h) // 2
    return image[top:top + crop_h, left:left + crop_w]


def write_to_json(data, file_path):
    """Write ``data`` as indented UTF-8 JSON to ``file_path`` via a temp file.

    The JSON is first written to ``file_path + ".tmp"`` and then moved over the
    destination with ``os.replace``, so an existing ``file_path`` is never left
    half-written. If serialization, the write, or the move fails, the temporary
    file is removed and the original exception is re-raised.
    """
    temp_file_path = file_path + ".tmp"
    try:
        with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
            json.dump(data, temp_file, indent=4)
        os.replace(temp_file_path, file_path)
    except BaseException:
        # Remove the partial temp file; a cleanup failure must not mask the
        # original error, which is what the caller needs to see.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass
        raise
    return


def read_from_json(file_path):
    """Parse and return the JSON document stored (UTF-8) at ``file_path``."""
    with open(file_path, 'rt', encoding='utf-8') as handle:
        return json.load(handle)


def get_active_parameters(m):
    """Return ``{name: parameter}`` for every parameter of ``m`` that requires grad."""
    active = {}
    for name, param in m.named_parameters():
        if param.requires_grad:
            active[name] = param
    return active


def cast_training_params(m, dtype=torch.float32):
    """Cast the trainable parameters of ``m`` to ``dtype`` in place.

    Only parameters with ``requires_grad`` are touched. Their ``.data`` is
    replaced by a cast copy, so the ``Parameter`` objects themselves stay the
    same objects.

    Returns:
        ``{name: parameter}`` for every parameter that was cast.
    """
    trainable = [(name, p) for name, p in m.named_parameters() if p.requires_grad]
    for _, param in trainable:
        param.data = param.to(dtype)
    return dict(trainable)


def separate_lora_AB(parameters, B_patterns=None):
    """Split a name->value mapping into (normal, B) dicts by key substring.

    A key goes to the B group if it contains any of ``B_patterns``
    (default ``['.lora_B.', '__zero__']``); every other key goes to the normal
    group. Insertion order is preserved within each group.

    Returns:
        ``(parameters_normal, parameters_B)``.
    """
    patterns = ['.lora_B.', '__zero__'] if B_patterns is None else B_patterns
    parameters_normal, parameters_B = {}, {}
    for name, value in parameters.items():
        is_b = any(pattern in name for pattern in patterns)
        (parameters_B if is_b else parameters_normal)[name] = value
    return parameters_normal, parameters_B


def set_attr_recursive(obj, attr, value):
    """Assign ``value`` to a dotted attribute path, e.g. ``'block.attn.proj'``.

    Every component but the last is resolved with ``getattr``; the last is set
    with ``setattr`` on the object reached that way.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for name in parents:
        target = getattr(target, name)
    setattr(target, leaf, value)
    return


def print_tensor_list_size(tensors):
    """Print the count, total memory (MB) and total element count of tensors.

    Args:
        tensors: A dict of tensors (its values are used) or any iterable of
            tensors, including one-shot iterables such as generators.
    """
    if isinstance(tensors, dict):
        tensors = tensors.values()
    # Materialize once so the input can be both summed over and counted with
    # len(); generators have no len() and can only be iterated once.
    tensors = list(tensors)

    total_size = sum(t.nelement() * t.element_size() for t in tensors)
    total_elements = sum(t.nelement() for t in tensors)

    total_size_MB = total_size / (1024 ** 2)
    total_elements_B = total_elements / 1e9

    print(f"Total number of tensors: {len(tensors)}")
    print(f"Total size of tensors: {total_size_MB:.2f} MB")
    print(f"Total number of parameters: {total_elements_B:.3f} billion")
    return


@torch.no_grad()
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
    """Pick each batch item (dim 0) from ``a`` or ``b``.

    Args:
        a: Tensor whose first dimension is the batch.
        b: Alternative tensor; defaults to zeros shaped like ``a``.
        probability_a: Chance of keeping ``a`` for an item when ``mask_a`` is
            not given (sampled on the CPU with ``torch.rand``).
        mask_a: Optional boolean tensor of length ``batch``; ``True`` keeps ``a``.

    Returns:
        ``torch.where(mask, a, b)`` with the mask broadcast over non-batch dims.
    """
    n = a.size(0)
    other = torch.zeros_like(a) if b is None else b

    if mask_a is None:
        mask_a = torch.rand(n) < probability_a

    broadcast_shape = (n,) + (1,) * (a.dim() - 1)
    keep_a = mask_a.to(a.device).reshape(broadcast_shape)
    return torch.where(keep_a, a, other)


@torch.no_grad()
def zero_module(module):
    """Fill every parameter of ``module`` with zeros in place and return the module."""
    for tensor in (param.detach() for param in module.parameters()):
        tensor.zero_()
    return module


@torch.no_grad()
def supress_lower_channels(m, k, alpha=0.01):
    """Scale the first ``k`` slices along dim 1 of ``m.weight`` by ``alpha``.

    The weight is modified on a clone and then written back as a fresh
    contiguous tensor via ``m.weight.data``.

    Returns:
        ``m`` (modified in place).

    Raises:
        AssertionError: if ``m.weight`` has fewer than ``k`` entries on dim 1.
    """
    weight = m.weight.data.clone()
    assert int(weight.shape[1]) >= k

    scaled_head = weight[:, :k] * alpha
    weight[:, :k] = scaled_head
    m.weight.data = weight.contiguous().clone()
    return m


def freeze_module(m):
    """Turn off gradients for ``m`` and wrap its ``forward`` in ``torch.no_grad``.

    The ``forward`` seen on the first call is stashed on
    ``m._forward_inside_frozen_module``; later calls keep that first value.
    Each call wraps the current ``m.forward`` again.

    Returns:
        ``m`` (modified in place).
    """
    already_stashed = hasattr(m, '_forward_inside_frozen_module')
    if not already_stashed:
        m._forward_inside_frozen_module = m.forward
    m.requires_grad_(False)
    no_grad_wrapper = torch.no_grad()
    m.forward = no_grad_wrapper(m.forward)
    return m


def get_latest_safetensors(folder_path):
    """Return the most recently modified ``*.safetensors`` file in ``folder_path``.

    "Most recent" means the largest ``os.path.getmtime``. The result is
    symlink-resolved and absolute.

    Raises:
        ValueError: if the folder contains no ``*.safetensors`` file.
    """
    candidates = glob.glob(os.path.join(folder_path, '*.safetensors'))
    if len(candidates) == 0:
        raise ValueError('No file to resume!')

    newest = max(candidates, key=os.path.getmtime)
    return os.path.abspath(os.path.realpath(newest))


def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    """Build a prompt from a random subset of comma-separated tags.

    ``tags_str`` is split on ``', '``. A target count is drawn uniformly from
    ``[min_length, max_length]`` (capped at the number of tags), and that many
    distinct tags are sampled in random order and re-joined with ``', '``.
    Uses the global ``random`` module state.
    """
    all_tags = tags_str.split(', ')
    wanted = random.randint(min_length, max_length)
    chosen = random.sample(all_tags, k=min(wanted, len(all_tags)))
    return ', '.join(chosen)


def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
    """Return ``n`` values from ``a`` to ``b`` placed at ``t ** gamma`` for t evenly in [0, 1].

    With ``round_to_int`` the values are rounded with ``np.round`` (half to
    even) and returned as Python ints; otherwise they are returned as floats.
    """
    t = np.linspace(0, 1, n) ** gamma
    values = a + (b - a) * t
    if not round_to_int:
        return values.tolist()
    return np.round(values).astype(int).tolist()


def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
    """Stratified sampling: one uniform draw from each of ``n`` equal sub-ranges.

    The range ``[inclusive, exclusive)`` is cut into ``n`` equal pieces and one
    value is drawn from each with ``np.random.uniform`` (global NumPy RNG), so
    results come out in ascending stratum order. With ``round_to_int`` the
    values are rounded (``np.round``) and returned as Python ints.
    """
    bounds = np.linspace(0, 1, n + 1)
    fractions = np.random.uniform(low=bounds[:-1], high=bounds[1:])
    samples = inclusive + (exclusive - inclusive) * fractions
    if round_to_int:
        samples = np.round(samples).astype(int)
    return samples.tolist()


def soft_append_bcthw(history, current, overlap=0):
    """Concatenate two 5-D clips along dim 2, cross-fading ``overlap`` steps.

    Over the overlap, ``history``'s last ``overlap`` steps fade out linearly
    (weight 1 -> 0) while ``current``'s first ``overlap`` steps fade in. With
    ``overlap <= 0`` this is a plain ``torch.cat`` along dim 2.

    Returns:
        The joined tensor, cast to ``history``'s dtype/device in the blended case.

    Raises:
        AssertionError: if either input is shorter than ``overlap`` on dim 2.
    """
    if overlap <= 0:
        return torch.cat([history, current], dim=2)

    assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
    assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"

    fade_out = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device)
    fade_out = fade_out.view(1, 1, -1, 1, 1)
    tail = history[:, :, -overlap:]
    head = current[:, :, :overlap]
    mixed = fade_out * tail + (1 - fade_out) * head

    pieces = [history[:, :, :-overlap], mixed, current[:, :, overlap:]]
    return torch.cat(pieces, dim=2).to(history)


def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
    """Encode a (B, C, T, H, W) tensor with values in [-1, 1] as a tiled H.264 video.

    The batch is laid out as a grid. The column count is the first of
    6, 5, 4, 3, 2 that divides B; otherwise all B clips go in one row. The
    parent directory is created if needed.

    Returns:
        The uint8 frame tensor that was written, shaped (T, grid_H, grid_W, C).
    """
    batch, _, _, _, _ = x.shape
    per_row = next((p for p in (6, 5, 4, 3, 2) if batch % p == 0), batch)

    out_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(out_dir, exist_ok=True)

    frames = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    frames = frames.detach().cpu().to(torch.uint8)
    frames = einops.rearrange(frames, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
    torchvision.io.write_video(output_filename, frames, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
    return frames


def save_bcthw_as_png(x, output_filename):
    """Write a (B, C, T, H, W) tensor in [-1, 1] as one PNG grid.

    Batch items are stacked vertically and frames horizontally. The parent
    directory is created if needed.

    Returns:
        ``output_filename``.
    """
    parent = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(parent, exist_ok=True)

    pixels = (torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5).detach().cpu().to(torch.uint8)
    grid = einops.rearrange(pixels, 'b c t h w -> c (b h) (t w)')
    torchvision.io.write_png(grid, output_filename)
    return output_filename


def save_bchw_as_png(x, output_filename):
    """Write a (B, C, H, W) tensor in [-1, 1] as one PNG with images side by side.

    The parent directory is created if needed.

    Returns:
        ``output_filename``.
    """
    parent = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(parent, exist_ok=True)

    pixels = (torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5).detach().cpu().to(torch.uint8)
    strip = einops.rearrange(pixels, 'b c h w -> c h (b w)')
    torchvision.io.write_png(strip, output_filename)
    return output_filename


def add_tensors_with_padding(tensor1, tensor2):
    """Add two tensors of equal rank but possibly different sizes.

    Each result dimension is the max of the two inputs' sizes. Both inputs are
    placed at the origin of a zero tensor of that shape, and the two padded
    tensors are summed.

    The padded buffers use the promoted dtype of the inputs
    (``torch.result_type``) and ``tensor1``'s device. This keeps the result
    consistent with the equal-shape path (``tensor1 + tensor2``) instead of
    forcing float32 on the CPU.

    Inputs of different rank are not supported: ``zip`` truncates the shape,
    and the slice assignment then raises.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    shape1 = tensor1.shape
    shape2 = tensor2.shape
    new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))

    dtype = torch.result_type(tensor1, tensor2)
    device = tensor1.device
    padded_tensor1 = torch.zeros(new_shape, dtype=dtype, device=device)
    padded_tensor2 = torch.zeros(new_shape, dtype=dtype, device=device)

    padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
    padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2

    return padded_tensor1 + padded_tensor2


def print_free_mem():
    """Release cached CUDA memory, then print free and total memory of GPU 0 in MB."""
    torch.cuda.empty_cache()
    bytes_per_mb = 1024 ** 2
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    print(f"Free memory: {free_bytes / bytes_per_mb:.2f} MB")
    print(f"Total memory: {total_bytes / bytes_per_mb:.2f} MB")
    return


def print_gpu_parameters(device, state_dict, log_count=1):
    """Print a one-line summary of ``state_dict``.

    The summary contains ``device``, the key count, and the first three
    flattened values of each of the first ``log_count`` entries.
    """
    sampled = {}
    for index, (name, tensor) in enumerate(state_dict.items()):
        if index >= log_count:
            break
        sampled[name] = tensor.flatten()[:3].tolist()

    summary = {"device": device, "keys_count": len(state_dict), "params": sampled}
    print(str(summary))
    return


def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
    """Render ``text`` in black on a white ``width`` x ``height`` RGB canvas.

    Words are wrapped greedily: a word joins the current line while the line's
    text bounding box stays within ``width``; otherwise it starts a new line
    (so a single word wider than the canvas still gets its own line). Lines are
    stacked from the top using the bounding-box height of "A", and drawing
    stops at the first line that would extend below ``height``.

    Args:
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        text: Text to render; empty or whitespace-only text yields a blank canvas.
        font_path: TrueType font file passed to ``ImageFont.truetype``.
        size: Font size passed to ``ImageFont.truetype``.

    Returns:
        ``np.array`` of the rendered PIL image.
    """
    from PIL import Image, ImageDraw, ImageFont

    txt = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(txt)
    font = ImageFont.truetype(font_path, size=size)

    words = text.split()
    # Check the word list rather than ``text == ''``: whitespace-only text
    # splits to [] and would otherwise raise IndexError on ``words[0]``.
    if not words:
        return np.array(txt)

    # Split text into lines that fit within the image width
    lines = []
    current_line = words[0]
    for word in words[1:]:
        line_with_word = f"{current_line} {word}"
        if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
            current_line = line_with_word
        else:
            lines.append(current_line)
            current_line = word
    lines.append(current_line)

    # Draw the text line by line
    y = 0
    line_height = draw.textbbox((0, 0), "A", font=font)[3]
    for line in lines:
        if y + line_height > height:
            break  # stop drawing if the next line will be outside the image
        draw.text((0, y), line, fill="black", font=font)
        y += line_height

    return np.array(txt)


def blue_mark(x):
    """Return a copy of ``x`` with channel 2 strongly sharpened.

    Channel 2's deviation from its 9x9 box blur is amplified 16x, and the
    result is clipped to [-1, 1]. Other channels are unchanged.
    """
    marked = x.copy()
    channel = marked[:, :, 2]
    smooth = cv2.blur(channel, (9, 9))
    detail = channel - smooth
    marked[:, :, 2] = (detail * 16.0 + smooth).clip(-1, 1)
    return marked


def green_mark(x):
    """Return a copy of ``x`` with channels 2 and 0 set to -1; channel 1 is kept."""
    marked = x.copy()
    for channel in (2, 0):
        marked[:, :, channel] = -1
    return marked


def frame_mark(x):
    """Return a copy of ``x`` with -1 bands on the top/bottom 64 rows and 1 on the left/right 8 columns.

    The column bands are written last, so they win at the corners.
    """
    marked = x.copy()
    for rows in (slice(None, 64), slice(-64, None)):
        marked[rows] = -1
    for cols in (slice(None, 8), slice(-8, None)):
        marked[:, cols] = 1
    return marked


@torch.inference_mode()
def pytorch2numpy(imgs):
    """Convert channel-first image tensors to channel-last ``uint8`` arrays.

    Each tensor has its first axis moved last, is mapped with
    ``v * 127.5 + 127.5`` (presumably inputs are in [-1, 1]), converted to
    float32 on the CPU, clipped to [0, 255] and truncated to ``np.uint8``.

    Returns:
        A list with one array per input tensor.
    """
    def _to_uint8(img):
        scaled = img.movedim(0, -1) * 127.5 + 127.5
        return scaled.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)

    return [_to_uint8(img) for img in imgs]


@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack channel-last arrays into one channel-first float tensor in [-1, 1].

    The arrays are stacked on a new leading axis, cast to float32, mapped with
    ``v / 127.5 - 1``, and the last axis is moved to position 1.
    """
    batch = torch.from_numpy(np.stack(imgs, axis=0)).float()
    normalized = batch / 127.5 - 1.0
    return normalized.movedim(-1, 1)


@torch.no_grad()
def duplicate_prefix_to_suffix(x, count, zero_out=False):
    """Append ``x[:count]`` (or zeros of that shape if ``zero_out``) to the end of ``x`` along dim 0."""
    prefix = x[:count]
    suffix = torch.zeros_like(prefix) if zero_out else prefix
    return torch.cat([x, suffix], dim=0)


def weighted_mse(a, b, weight):
    """Mean over all elements of ``weight * (a - b) ** 2``, with all inputs cast to float32."""
    squared_error = (a.float() - b.float()) ** 2
    return (weight.float() * squared_error).mean()


def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
    """Map ``x`` from [x_min, x_max] onto [y_min, y_max] with a power curve.

    ``x`` is normalized to ``t = (x - x_min) / (x_max - x_min)``, clamped to
    [0, 1], raised to ``sigma``, and used as the interpolation weight between
    ``y_min`` and ``y_max``. ``x_max == x_min`` divides by zero.
    """
    progress = max(0.0, min((x - x_min) / (x_max - x_min), 1.0))
    return y_min + progress ** sigma * (y_max - y_min)


def expand_to_dims(x, target_dims):
    """Return a view of ``x`` with trailing size-1 dims appended up to ``target_dims`` dims.

    If ``x`` already has ``target_dims`` or more dims, its shape is unchanged.
    """
    missing = max(0, target_dims - x.dim())
    return x.view(*x.shape, *((1,) * missing))


def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
    """Tile ``tensor`` along dim 0 so that it has ``batch_size`` rows.

    Returns ``None`` for ``None`` input and ``tensor`` itself if it already has
    ``batch_size`` rows. The whole tensor is tiled (``[a, b] -> [a, b, a, b]``),
    not each row repeated in place.

    Raises:
        ValueError: if ``batch_size`` is not a multiple of ``tensor.shape[0]``.
    """
    if tensor is None:
        return None

    first_dim = tensor.shape[0]
    if first_dim == batch_size:
        return tensor
    if batch_size % first_dim:
        raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")

    trailing_ones = [1] * (tensor.dim() - 1)
    return tensor.repeat(batch_size // first_dim, *trailing_ones)


def dim5(x):
    """Append trailing singleton dims to ``x`` until it is at least 5-D (via ``expand_to_dims``)."""
    return expand_to_dims(x, target_dims=5)


def dim4(x):
    """Append trailing singleton dims to ``x`` until it is at least 4-D (via ``expand_to_dims``)."""
    return expand_to_dims(x, target_dims=4)


def dim3(x):
    """Append trailing singleton dims to ``x`` until it is at least 3-D (via ``expand_to_dims``)."""
    return expand_to_dims(x, target_dims=3)


def crop_or_pad_yield_mask(x, length):
    """Fit a 3-D tensor to exactly ``length`` along dim 1 and report which positions are real.

    ``x`` is unpacked as (B, F, C). Longer inputs are truncated; shorter ones
    are zero-padded at the end.

    Returns:
        ``(y, mask)``: ``y`` is (B, length, C) and ``mask`` is a (B, length)
        bool tensor that is ``True`` where ``y`` holds data from ``x``.
    """
    batch, frames, channels = x.shape

    if frames >= length:
        valid = torch.ones((batch, length), dtype=torch.bool, device=x.device)
        return x[:, :length, :], valid

    padded = torch.zeros((batch, length, channels), dtype=x.dtype, device=x.device)
    padded[:, :frames, :] = x
    valid = torch.zeros((batch, length), dtype=torch.bool, device=x.device)
    valid[:, :frames] = True
    return padded, valid


def extend_dim(x, dim, minimal_length, zero_pad=False):
    """Pad ``x`` along ``dim`` until it has at least ``minimal_length`` entries.

    Args:
        x: Input tensor.
        dim: Axis to extend; negative values count from the end.
        minimal_length: Required minimum size of ``dim``.
        zero_pad: If true, pad with zeros; otherwise repeat the last slice
            along ``dim``.

    Returns:
        ``x`` itself if it is already long enough, else a new tensor. Note that
        with ``zero_pad=False`` an input that is empty along ``dim`` has no last
        slice to repeat and is returned still empty.
    """
    original_length = int(x.shape[dim])

    if original_length >= minimal_length:
        return x

    # Normalize negative dims: the slice-tuple below uses ``dim`` as the number
    # of leading axes, so e.g. ``dim=-1`` would build an index with too many
    # entries. ``x.shape[dim]`` above has already validated the range.
    dim = dim % x.dim()

    if zero_pad:
        padding_shape = list(x.shape)
        padding_shape[dim] = minimal_length - original_length
        padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
    else:
        idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (x.dim() - dim - 1)
        last_element = x[idx]
        padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)

    return torch.cat([x, padding], dim=dim)


def lazy_positional_encoding(t, repeats=None):
    """Sinusoidal 256-dim embedding of one or more timesteps (diffusers helper).

    ``diffusers`` is imported lazily on first use. A scalar ``t`` is treated as
    a one-element list.

    Returns:
        A tensor of shape (len(t), 256), or (len(t), repeats, 256) when
        ``repeats`` is given (an ``expand`` view, no copy).
    """
    from diffusers.models.embeddings import get_timestep_embedding

    timesteps = torch.tensor(t if isinstance(t, list) else [t])
    embedding = get_timestep_embedding(timesteps=timesteps, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)

    if repeats is not None:
        embedding = embedding[:, None, :].expand(-1, repeats, -1)
    return embedding


def state_dict_offset_merge(A, B, C=None):
    """Key-wise ``A + B`` (or ``(A + B) - C``) over the keys of ``A``.

    Values from ``B`` and ``C`` are converted with ``.to(A[key])`` so that they
    match ``A``'s dtype and device. Keys missing from ``B``/``C`` raise
    ``KeyError``.
    """
    def _merge(key):
        base = A[key]
        merged = base + B[key].to(base)
        if C is not None:
            merged = merged - C[key].to(base)
        return merged

    return {key: _merge(key) for key in A.keys()}


def state_dict_weighted_merge(state_dicts, weights):
    """Weighted average of state dicts, with weights normalized to sum to 1.

    Keys come from the first state dict; every other dict's values are cast to
    the first dict's dtype/device before accumulating. Input tensors are not
    modified.

    Raises:
        ValueError: if the counts differ or the weights sum to zero.
    """
    if len(state_dicts) != len(weights):
        raise ValueError("Number of state dictionaries must match number of weights")
    if not state_dicts:
        return {}

    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("Sum of weights cannot be zero")
    scales = [w / total_weight for w in weights]

    first, *others = state_dicts
    merged = {}
    for key in first.keys():
        # The product is a fresh tensor, so the in-place += below is safe.
        acc = first[key] * scales[0]
        for other, scale in zip(others, scales[1:]):
            acc += other[key].to(acc) * scale
        merged[key] = acc
    return merged


def group_files_by_folder(all_files):
    """Group paths by the name of their immediate parent directory.

    Only the parent's base name is compared, so ``a/x/1`` and ``b/x/2`` land in
    the same group. Groups appear in first-seen order, and paths keep their
    input order within a group.

    Returns:
        A list of lists of paths.
    """
    groups = {}
    for path in all_files:
        parent = os.path.basename(os.path.dirname(path))
        groups.setdefault(parent, []).append(path)
    return list(groups.values())


def generate_timestamp():
    """Return ``YYMMDD_HHMMSS_mmm_R`` built from the local time.

    ``mmm`` is the zero-padded millisecond and ``R`` is a random integer in
    [0, 9999] (not padded) from the global ``random`` state, to reduce
    collisions.
    """
    now = datetime.datetime.now()
    millis = int(now.microsecond / 1000)
    suffix = random.randint(0, 9999)
    return f"{now.strftime('%y%m%d_%H%M%S')}_{millis:03d}_{suffix}"


def write_PIL_image_with_png_info(image, metadata, path):
    """Save a PIL ``image`` as PNG at ``path``, embedding ``metadata`` as text chunks.

    Each key/value pair in ``metadata`` is added with ``PngInfo.add_text``.

    Returns:
        ``image`` unchanged.
    """
    from PIL.PngImagePlugin import PngInfo

    png_info = PngInfo()
    for entry in metadata.items():
        png_info.add_text(*entry)

    image.save(path, "PNG", pnginfo=png_info)
    return image


def torch_safe_save(content, path):
    """``torch.save`` ``content`` to ``path`` via a temp file, then move it into place.

    The object is first written to ``path + '_tmp'`` and then moved over
    ``path`` with ``os.replace``, so an existing file at ``path`` is never left
    half-written. If saving or moving fails, the temp file is removed and the
    original exception is re-raised.

    Returns:
        ``path``.
    """
    temp_path = path + '_tmp'
    try:
        torch.save(content, temp_path)
        os.replace(temp_path, path)
    except BaseException:
        # Remove the partial temp file; a cleanup failure must not mask the
        # original error.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
    return path


def move_optimizer_to_device(optimizer, device):
    """Move every tensor in ``optimizer.state`` to ``device`` in place.

    Non-tensor state entries are left untouched. Returns ``None``.
    """
    for per_param_state in optimizer.state.values():
        tensor_keys = [key for key, value in per_param_state.items() if isinstance(value, torch.Tensor)]
        for key in tensor_keys:
            per_param_state[key] = per_param_state[key].to(device)
