import torch
import torch.nn.functional as F
import numpy as np
import cv2

from PIL import Image
from thop import profile
# Per-channel mean and standard deviation (three values each). The functions in
# this file subtract the mean and divide by the std to normalize their inputs,
# and apply the inverse (times 255) to map model outputs back to the 0-255 range.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


class Cache(list):
    """A list whose ``append`` keeps the length bounded by ``max_size``.

    A ``max_size`` of zero or less disables the cache: ``append`` then
    discards its argument and the list stays empty.
    """

    def __init__(self, max_size=0):
        super().__init__()
        self.max_size = max_size

    def append(self, x):
        """Add ``x`` at the end, evicting the oldest item if the cache overflows."""
        capacity = self.max_size
        if capacity <= 0:
            # Caching is disabled; nothing is stored.
            return
        list.append(self, x)
        if len(self) > capacity:
            # At most one item (the oldest, at index 0) is evicted per append.
            del self[0]

def cal_l2_distance(keys_tensor, original_val, K=5, tau=1.0):
    """Return a distance-weighted average of the keys nearest to a query vector.

    Args:
        keys_tensor: candidate keys, shape [N, C].
        original_val: query vector, shape [C].
        K: number of nearest keys to average. Values larger than N are
            capped at N, so a small key set no longer makes ``torch.topk`` fail.
        tau: softmax temperature applied to the negated distances; smaller
            values concentrate the weight on the closest keys.

    Returns:
        Tensor of shape [C]: the softmax-weighted sum of the selected keys.
    """
    dist = torch.norm(keys_tensor - original_val, dim=-1)  # [N]

    # torch.topk raises when k exceeds the number of candidates, so cap it.
    k = min(K, dist.shape[0])
    topk_dist, topk_indices = torch.topk(dist, k=k, largest=False)
    topk_keys = keys_tensor[topk_indices]  # [k, C]

    # Closer keys (smaller distance) receive larger weights.
    weights = torch.softmax(-topk_dist / tau, dim=0)  # [k]
    weighted_avg_key = torch.sum(topk_keys * weights.unsqueeze(-1), dim=0)  # [C]

    return weighted_avg_key

def cal_cosine_sim(keys_tensor, original_val):
    """Average all keys, weighted by a softmax over their cosine similarity to a query.

    Args:
        keys_tensor: candidate keys, shape [N, C].
        original_val: query vector, shape [C].

    Returns:
        Tensor of shape [C]: the similarity-weighted sum of the (un-normalized) keys.
    """
    # L2-normalize both sides so the elementwise product sums to the cosine.
    query_unit = F.normalize(original_val[None], p=2, dim=-1)  # [1, C]
    keys_unit = F.normalize(keys_tensor, p=2, dim=-1)  # [N, C]
    similarity = (query_unit * keys_unit).sum(dim=-1)  # [N]

    # Softmax across the keys turns the similarities into mixing weights.
    mixing = torch.softmax(similarity, dim=0)  # [N]

    return (keys_tensor * mixing[..., None]).sum(dim=0)  # [C]


@torch.no_grad()
def run_one_image(img, tgt, model, device, panicl=False):
    """Run the model on a batch of stacked prompt/query pairs and decode the prediction.

    Args:
        img: array-like in NHWC layout; each item is a prompt image stacked on
            top of the query image along the height axis.
        tgt: array-like in NHWC layout with the matching stacked targets.
        model: network exposing ``patch_embed.num_patches``, ``seg_type``,
            ``unpatchify`` and a forward call returning five values.
        device: device the input tensors are moved to before the forward pass.
        panicl: flag forwarded unchanged as the last argument of the model
            call. Defaults to ``False`` so that callers passing only the first
            four arguments keep working.

    Returns:
        Tuple ``(output, latent)``: ``output`` is the lower half of the first
        reconstructed image in HWC layout, de-normalized and clipped to
        [0, 255]; ``latent`` is the fifth value returned by the model.
    """
    # NHWC -> NCHW
    x = torch.einsum('nhwc->nchw', torch.tensor(img))
    tgt = torch.einsum('nhwc->nchw', torch.tensor(tgt))

    # Mark the second half of the patch sequence as masked.
    num_patches = model.patch_embed.num_patches
    bool_masked_pos = torch.zeros(num_patches)
    bool_masked_pos[num_patches // 2:] = 1
    bool_masked_pos = bool_masked_pos.unsqueeze(dim=0)
    valid = torch.ones_like(tgt)

    if model.seg_type == 'instance':
        seg_type = torch.ones([valid.shape[0], 1])
    else:
        seg_type = torch.zeros([valid.shape[0], 1])

    # 0 when several prompts are batched, -1 for a single prompt; the meaning
    # of the value is defined by the model.
    feat_ensemble = 0 if len(x) > 1 else -1
    loss, y, mask, latent_target, latent = model(
        x.float().to(device),
        tgt.float().to(device),
        bool_masked_pos.to(device),
        valid.float().to(device),
        seg_type.to(device),
        feat_ensemble,
        panicl,
    )

    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # Keep only the bottom half (the query part) of the first batch item.
    output = y[0, y.shape[1] // 2:, :, :]
    output = torch.clip((output * imagenet_std + imagenet_mean) * 255, 0, 255)
    return output, latent


@torch.no_grad()
def run_one_image_datastore(img, tgt, model, device, datastore, args, target=True):
    """Run the model, blend its latents with datastore keys, then decode.

    The latents returned by the model are concatenated along the last axis.
    Each spatial position of the selected region is replaced by
    ``(1 - args.alpha) * latent + args.alpha * weighted_key``, where
    ``weighted_key`` comes from ``cal_l2_distance``. The blended latents are
    split back into 1024-wide chunks and passed to ``model.forward_decoder``.

    Args:
        img: array-like in NHWC layout (prompt stacked on query along height).
        tgt: array-like in NHWC layout with the matching stacked targets.
        model: network exposing ``patch_embed.num_patches``, ``seg_type``,
            ``forward_decoder``, ``patchify`` and ``unpatchify``.
        device: device used for the forward pass and the datastore keys.
        datastore: when ``args.indice_level`` is truthy, a mapping from the
            stringified flat patch index (``str(i * W + j)``) to a sequence of
            entries; otherwise a flat iterable of entries. Every entry is
            indexed with ``"key"`` to obtain a tensor.
        args: namespace providing ``indice_level`` and ``alpha``, plus ``temp``
            (per-patch mode) or ``k`` (global mode).
        target: if true only the lower half of the first latent map is
            blended; otherwise the whole first latent map is.

    Returns:
        The lower half of the first decoded image in HWC layout, de-normalized
        and clipped to [0, 255].
    """
    # NHWC -> NCHW
    x = torch.einsum('nhwc->nchw', torch.tensor(img))
    tgt = torch.einsum('nhwc->nchw', torch.tensor(tgt))

    # Mark the second half of the patch sequence as masked.
    num_patches = model.patch_embed.num_patches
    bool_masked_pos = torch.zeros(num_patches)
    bool_masked_pos[num_patches // 2:] = 1
    bool_masked_pos = bool_masked_pos.unsqueeze(dim=0)

    valid = torch.ones_like(tgt)

    if model.seg_type == 'instance':
        seg_type = torch.ones([valid.shape[0], 1])
    else:
        seg_type = torch.zeros([valid.shape[0], 1])

    feat_ensemble = 0 if len(x) > 1 else -1

    _, _, _, _, query_latent = model(
        x.float().to(device),
        tgt.float().to(device),
        bool_masked_pos.to(device),
        valid.float().to(device),
        seg_type.to(device),
        feat_ensemble
    )

    query_latent = torch.cat(query_latent, dim=-1)
    # Basic slicing returns a view, so the per-patch writes below already
    # update ``query_latent`` in place.
    if target:
        query_latent_target = query_latent[0, query_latent.shape[1] // 2:, :, :]
    else:
        query_latent_target = query_latent[0]

    H, W, _ = query_latent_target.shape
    if args.indice_level:
        # Per-patch datastore: each position has its own list of keys.
        for i_ in range(H):
            for j_ in range(W):
                patch_key = str(i_ * W + j_)
                if patch_key not in datastore:
                    continue
                entries = datastore[patch_key]
                if len(entries) == 0:
                    # torch.stack cannot stack an empty list; leave this patch as is.
                    continue

                keys_tensor = torch.stack([entry["key"] for entry in entries], dim=0).to(device)  # [N, C]
                original_val = query_latent_target[i_, j_, :]

                # Every stored key takes part (K = N), weighted with temperature args.temp.
                weighted_avg_key = cal_l2_distance(keys_tensor, original_val, keys_tensor.shape[0], args.temp)

                query_latent_target[i_, j_, :] = (1 - args.alpha) * original_val + args.alpha * weighted_avg_key
    else:
        # Global datastore: the key matrix is the same for every patch, so it is
        # built and moved to the device once instead of once per patch.
        keys_tensor = torch.stack([entry["key"] for entry in datastore], dim=0).to(device)  # [N, C]
        # Never ask for more neighbours than there are keys.
        k = min(args.k, keys_tensor.shape[0])
        # NOTE(review): unlike the per-patch branch, args.temp is not passed
        # here, so cal_l2_distance falls back to its default temperature.
        # Confirm this is intended.
        for i_ in range(H):
            for j_ in range(W):
                original_val = query_latent_target[i_, j_, :]
                weighted_avg_key = cal_l2_distance(keys_tensor, original_val, k)
                query_latent_target[i_, j_, :] = (1 - args.alpha) * original_val + args.alpha * weighted_avg_key

    # Explicit write-back of the blended region (a no-op for the view above,
    # kept so the data flow stays obvious).
    if target:
        query_latent[0, query_latent.shape[1] // 2:, :, :] = query_latent_target
    else:
        query_latent[0, :, :, :] = query_latent_target

    # Undo the concatenation: the decoder receives a list of 1024-wide chunks.
    x_list = list(torch.split(query_latent, 1024, dim=-1))

    pred = model.forward_decoder(x_list)
    y = model.patchify(pred)
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # Keep only the bottom half (the query part) of the first batch item.
    output = y[0, y.shape[1] // 2:, :, :]
    output = torch.clip((output * imagenet_std + imagenet_mean) * 255, 0, 255)
    return output


def inference_image(model, device, img_path, img2_paths, tgt2_paths, panicl=False):
    """Segment one query image using one or more prompt image/target pairs.

    Args:
        model: network passed through to ``run_one_image``.
        device: device passed through to ``run_one_image``.
        img_path: path of the query image.
        img2_paths: iterable of prompt image paths.
        tgt2_paths: iterable of prompt targets, paired with ``img2_paths``;
            each item is either a path or an already loaded ``PIL.Image.Image``.
        panicl: flag passed through to ``run_one_image``.

    Returns:
        Tuple ``(output_mask, output, raw_output, latent_target)``:
        ``output_mask`` is a uint8 array at the query image's original
        resolution, ``output`` is the same data as a PIL image, ``raw_output``
        is the tensor returned by ``run_one_image`` before resizing, and
        ``latent_target`` is the latent returned by ``run_one_image``.
    """
    res, hres = 448, 448

    # Close the file handle as soon as the pixel data has been converted.
    with Image.open(img_path) as query_file:
        query = query_file.convert("RGB")
    size = query.size  # PIL reports (width, height)
    image = np.array(query.resize((res, hres))) / 255.

    image_batch, target_batch = [], []
    for img2_path, tgt2_path in zip(img2_paths, tgt2_paths):
        with Image.open(img2_path) as prompt_file:
            img2 = prompt_file.convert("RGB")
        img2 = np.array(img2.resize((res, hres))) / 255.

        if isinstance(tgt2_path, Image.Image):
            tgt2 = tgt2_path.convert("RGB")
        else:
            with Image.open(tgt2_path) as target_file:
                tgt2 = target_file.convert("RGB")
        # Nearest-neighbour resampling so no new colour values are interpolated.
        tgt2 = np.array(tgt2.resize((res, hres), Image.NEAREST)) / 255.

        # The query target is not available, so the prompt target is repeated
        # in the bottom half as a placeholder.
        tgt = np.concatenate((tgt2, tgt2), axis=0)
        img = np.concatenate((img2, image), axis=0)

        assert img.shape == (2 * hres, res, 3), f'{img.shape}'
        # normalize by ImageNet mean and std
        img = (img - imagenet_mean) / imagenet_std

        assert tgt.shape == (2 * hres, res, 3), f'{tgt.shape}'
        # normalize by ImageNet mean and std
        tgt = (tgt - imagenet_mean) / imagenet_std

        image_batch.append(img)
        target_batch.append(tgt)

    img = np.stack(image_batch, axis=0)
    tgt = np.stack(target_batch, axis=0)

    # Run SegGPT on the image.
    # Fixed seed so any random masking is reproducible (comment out to make it change).
    torch.manual_seed(2)
    raw_output, latent_target = run_one_image(img, tgt, model, device, panicl)

    # Resize the HWC prediction back to the query's original (height, width).
    output = F.interpolate(
        raw_output[None, ...].permute(0, 3, 1, 2),
        size=[size[1], size[0]],
        mode='nearest',
    ).permute(0, 2, 3, 1)[0].numpy()
    output_mask = np.uint8(output.copy())
    output = Image.fromarray(output.astype(np.uint8))
    return output_mask, output, raw_output, latent_target

def inference_image_datastore(model, device, img_path, img2_paths, tgt2_paths, datastore, args, target):
    """Segment one query image with datastore-blended latents.

    Input preparation matches ``inference_image``; the forward pass is done by
    ``run_one_image_datastore``.

    Args:
        model: network passed through to ``run_one_image_datastore``.
        device: device passed through to ``run_one_image_datastore``.
        img_path: path of the query image.
        img2_paths: iterable of prompt image paths.
        tgt2_paths: iterable of prompt targets, paired with ``img2_paths``;
            each item is either a path or an already loaded ``PIL.Image.Image``.
        datastore: key store passed through to ``run_one_image_datastore``.
        args: options namespace passed through to ``run_one_image_datastore``.
        target: passed through as the ``target`` flag of
            ``run_one_image_datastore``.

    Returns:
        Tuple ``(output_mask, output, raw_output)``: ``output_mask`` is a uint8
        array at the query image's original resolution, ``output`` is the same
        data as a PIL image, and ``raw_output`` is the tensor returned by
        ``run_one_image_datastore`` before resizing.
    """
    res, hres = 448, 448

    # Close the file handle as soon as the pixel data has been converted.
    with Image.open(img_path) as query_file:
        query = query_file.convert("RGB")
    size = query.size  # PIL reports (width, height)
    image = np.array(query.resize((res, hres))) / 255.

    image_batch, target_batch = [], []
    for img2_path, tgt2_path in zip(img2_paths, tgt2_paths):
        with Image.open(img2_path) as prompt_file:
            img2 = prompt_file.convert("RGB")
        img2 = np.array(img2.resize((res, hres))) / 255.

        if isinstance(tgt2_path, Image.Image):
            tgt2 = tgt2_path.convert("RGB")
        else:
            with Image.open(tgt2_path) as target_file:
                tgt2 = target_file.convert("RGB")
        # Nearest-neighbour resampling so no new colour values are interpolated.
        tgt2 = np.array(tgt2.resize((res, hres), Image.NEAREST)) / 255.

        # The query target is not available, so the prompt target is repeated
        # in the bottom half as a placeholder.
        tgt = np.concatenate((tgt2, tgt2), axis=0)
        img = np.concatenate((img2, image), axis=0)

        assert img.shape == (2 * hres, res, 3), f'{img.shape}'
        # normalize by ImageNet mean and std
        img = (img - imagenet_mean) / imagenet_std

        assert tgt.shape == (2 * hres, res, 3), f'{tgt.shape}'
        # normalize by ImageNet mean and std
        tgt = (tgt - imagenet_mean) / imagenet_std

        image_batch.append(img)
        target_batch.append(tgt)

    img = np.stack(image_batch, axis=0)
    tgt = np.stack(target_batch, axis=0)

    # Run SegGPT on the image.
    # Fixed seed so any random masking is reproducible (comment out to make it change).
    torch.manual_seed(2)
    raw_output = run_one_image_datastore(img, tgt, model, device, datastore, args, target)

    # Resize the HWC prediction back to the query's original (height, width).
    output = F.interpolate(
        raw_output[None, ...].permute(0, 3, 1, 2),
        size=[size[1], size[0]],
        mode='nearest',
    ).permute(0, 2, 3, 1)[0].numpy()
    output_mask = np.uint8(output.copy())
    output = Image.fromarray(output.astype(np.uint8))
    return output_mask, output, raw_output


def inference_image_cache(model, device, img_path, img2_paths, tgt2_paths, target=False):
    """Run the model on a query image and return its concatenated latent features.

    Input preparation matches ``inference_image``; the decoded prediction is
    discarded and only the latent is kept.

    Args:
        model: network passed through to ``run_one_image``.
        device: device passed through to ``run_one_image``.
        img_path: path of the query image.
        img2_paths: iterable of prompt image paths.
        tgt2_paths: iterable of prompt targets, paired with ``img2_paths``;
            each item is either a path or an already loaded ``PIL.Image.Image``.
        target: if true return only the lower half (along the first spatial
            axis) of the first latent map; otherwise return the whole first
            latent map.

    Returns:
        The selected slice of the latents, concatenated along the last axis.
    """
    res, hres = 448, 448

    # Close the file handle as soon as the pixel data has been converted.
    with Image.open(img_path) as query_file:
        query = query_file.convert("RGB")
    image = np.array(query.resize((res, hres))) / 255.

    image_batch, target_batch = [], []
    for img2_path, tgt2_path in zip(img2_paths, tgt2_paths):
        with Image.open(img2_path) as prompt_file:
            img2 = prompt_file.convert("RGB")
        img2 = np.array(img2.resize((res, hres))) / 255.

        if isinstance(tgt2_path, Image.Image):
            tgt2 = tgt2_path.convert("RGB")
        else:
            with Image.open(tgt2_path) as target_file:
                tgt2 = target_file.convert("RGB")
        # Nearest-neighbour resampling so no new colour values are interpolated.
        tgt2 = np.array(tgt2.resize((res, hres), Image.NEAREST)) / 255.

        # The query target is not available, so the prompt target is repeated
        # in the bottom half as a placeholder.
        tgt = np.concatenate((tgt2, tgt2), axis=0)
        img = np.concatenate((img2, image), axis=0)

        assert img.shape == (2 * hres, res, 3), f'{img.shape}'
        # normalize by ImageNet mean and std
        img = (img - imagenet_mean) / imagenet_std

        assert tgt.shape == (2 * hres, res, 3), f'{tgt.shape}'
        # normalize by ImageNet mean and std
        tgt = (tgt - imagenet_mean) / imagenet_std

        image_batch.append(img)
        target_batch.append(tgt)

    img = np.stack(image_batch, axis=0)
    tgt = np.stack(target_batch, axis=0)

    # Run SegGPT on the image.
    # Fixed seed so any random masking is reproducible (comment out to make it change).
    torch.manual_seed(2)
    # run_one_image takes a fifth ``panicl`` argument; pass False explicitly,
    # matching the default of inference_image.
    _, latent = run_one_image(img, tgt, model, device, False)

    latent = torch.cat(latent, dim=-1)
    if target:
        latent_target = latent[0, latent.shape[1] // 2:, :, :]
    else:
        latent_target = latent[0]

    return latent_target


def inference_video(model, device, vid_path, num_frames, img2_paths, tgt2_paths, out_path):
    """Segment every frame of a video and write a mask-highlighted copy.

    The first prompt is ``img2_paths[0]`` with target ``tgt2_paths[0]``. When
    ``img2_paths`` is ``None`` the first video frame is read and used as the
    prompt image instead (that frame is consumed and not written to the
    output). Up to ``num_frames`` previously processed frames, together with
    their binarized predictions, are added as extra prompts for later frames.

    Args:
        model: network passed through to ``run_one_image``.
        device: device passed through to ``run_one_image``.
        vid_path: path of the input video.
        num_frames: size of the rolling prompt cache; zero or less disables it.
        img2_paths: sequence whose first item is the prompt image path, or ``None``.
        tgt2_paths: sequence whose first item is the prompt target path.
        out_path: path of the mp4 file to write.

    Raises:
        RuntimeError: if ``img2_paths`` is ``None`` and no frame can be read
            from the video to serve as the prompt.
    """
    res, hres = 448, 448

    cap = cv2.VideoCapture(vid_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height), True)

    # Release the capture and the writer even if processing fails part-way.
    try:
        if img2_paths is None:
            ret, frame = cap.read()
            if not ret:
                raise RuntimeError(f'could not read a prompt frame from {vid_path}')
            # OpenCV frames are BGR; reverse the channel axis to get RGB.
            img2 = Image.fromarray(frame[:, :, ::-1]).convert('RGB')
        else:
            with Image.open(img2_paths[0]) as prompt_file:
                img2 = prompt_file.convert("RGB")
        img2 = np.array(img2.resize((res, hres))) / 255.

        with Image.open(tgt2_paths[0]) as target_file:
            tgt2 = target_file.convert("RGB")
        # Nearest-neighbour resampling so no new colour values are interpolated.
        tgt2 = np.array(tgt2.resize((res, hres), Image.NEAREST)) / 255.

        frames_cache, target_cache = Cache(num_frames), Cache(num_frames)

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            image_batch, target_batch = [], []
            image = Image.fromarray(frame[:, :, ::-1]).convert('RGB')
            input_image = np.array(image)
            size = image.size  # PIL reports (width, height)
            image = np.array(image.resize((res, hres))) / 255.

            # One batch item per prompt: the fixed prompt plus the cached frames.
            for prompt, target in zip([img2] + frames_cache, [tgt2] + target_cache):
                # The query target is not available, so the prompt target is
                # repeated in the bottom half as a placeholder.
                tgt = np.concatenate((target, target), axis=0)
                img = np.concatenate((prompt, image), axis=0)

                assert img.shape == (2 * hres, res, 3), f'{img.shape}'
                # normalize by ImageNet mean and std
                img = (img - imagenet_mean) / imagenet_std

                assert tgt.shape == (2 * hres, res, 3), f'{tgt.shape}'
                # normalize by ImageNet mean and std
                tgt = (tgt - imagenet_mean) / imagenet_std

                image_batch.append(img)
                target_batch.append(tgt)

            img = np.stack(image_batch, axis=0)
            tgt = np.stack(target_batch, axis=0)

            torch.manual_seed(2)
            # run_one_image returns (output, latent) and takes a fifth
            # ``panicl`` argument; only the decoded output is needed here.
            output, _ = run_one_image(img, tgt, model, device, False)

            frames_cache.append(image)
            # Binarize the prediction (channel mean > 128) into a 0/1
            # three-channel target that can serve as a prompt for later frames.
            target_cache.append(
                output.mean(-1)
                .gt(128).float()
                .unsqueeze(-1).expand(-1, -1, 3)
                .numpy()
            )

            # Resize the HWC prediction back to the frame's (height, width).
            output = F.interpolate(
                output[None, ...].permute(0, 3, 1, 2),
                size=[size[1], size[0]],
                mode='nearest',
            ).permute(0, 2, 3, 1)[0].numpy()
            # Dim the frame to 40% and scale the rest by the prediction.
            output = input_image * (0.6 * output / 255 + 0.4)
            # Back to BGR for OpenCV.
            video_writer.write(np.ascontiguousarray(output.astype(np.uint8)[:, :, ::-1]))
    finally:
        cap.release()
        video_writer.release()
