import os
from os import path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

import numpy as np
import cv2
import ot
import math
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt

from segment_anything import sam_model_registry, SamPredictor
from segment_anything import SamAutomaticMaskGenerator
from dinov2.models import vision_transformer as vits
import dinov2.utils.utils as dinov2_utils
from dinov2.data.transforms import MaybeToTensor, make_normalize_transform

from matcher.k_means import kmeans_pp

from segment_anything.utils.amg import (
    batch_iterator, 
)

from scipy.sparse import csgraph


class Matcher:
    """Reference-guided segmenter that prompts SAM with similarity-derived points.

    ``predict`` pools a target embedding from the first reference image's SAM
    features under its mask. It compares that embedding (cosine similarity)
    with the target image's SAM features to pick one positive and one
    negative point prompt. It then queries the predictor three times:
    a guided first prediction, a mask-input refinement, and a box
    refinement.

    Several constructor arguments (``num_centers``, ``use_box``,
    ``use_points_or_centers``, ``sample_range``, ``max_sample_iterations``,
    ``alpha``/``beta``/``exp``, ``score_filter_cfg``, ``num_merging_mask``)
    are stored but are not read by any method of this class.
    """

    def __init__(
            self,
            encoder,
            generator=None,
            input_size=1024,
            num_centers=8,
            use_box=False,
            use_points_or_centers=True,
            sample_range=(4, 6),
            max_sample_iterations=30,
            alpha=1.,
            beta=0.,
            exp=0.,
            score_filter_cfg=None,
            num_merging_mask=10,
            device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    ):
        """Store the models and configuration.

        Args:
            encoder: image encoder; only its ``patch_size`` attribute is read here.
            generator: SAM predictor used by ``predict`` (stored as ``self.predictor``).
            input_size: expected target image size, either an int (square) or
                an ``(h, w)`` tuple; ``set_target`` asserts against it.
            score_filter_cfg: required; an ``AssertionError`` is raised if None.
            device: device on which ``predict`` returns its mask.
            Remaining arguments are stored as attributes unchanged.
        """
        # models
        self.encoder = encoder
        self.predictor = generator
        self.rps = None

        if not isinstance(input_size, tuple):
            input_size = (input_size, input_size)
        self.input_size = input_size

        # Encoder input resolution and the side length of its patch grid.
        img_size = 518
        feat_size = img_size // self.encoder.patch_size
        self.encoder_img_size = img_size
        self.encoder_feat_size = feat_size

        # transforms for image encoder (ImageNet mean/std normalisation)
        self.encoder_transform = transforms.Compose([
            MaybeToTensor(),
            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR),
            make_normalize_transform(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

        # Per-episode state: filled by set_target / set_reference, reset by clear().
        self.tar_img = None
        self.tar_img_np = None

        self.ref_imgs = None
        self.ref_masks = None
        self.ref_imgs_np = None
        self.ref_masks_np = None
        self.ref_masks_pool = None  # kept for compatibility; never assigned by this class
        self.nshot = None

        self.num_centers = num_centers
        self.use_box = use_box
        self.use_points_or_centers = use_points_or_centers
        self.sample_range = sample_range
        self.max_sample_iterations = max_sample_iterations

        self.alpha, self.beta, self.exp = alpha, beta, exp
        assert score_filter_cfg is not None
        self.score_filter_cfg = score_filter_cfg
        self.num_merging_mask = num_merging_mask

        self.device = device
        self.count = 0  # used as the filename prefix of saved figures

        # Visualisation: two fixed colours followed by 400 random ones,
        # indexed by mask index (show_anns) or label (show_points).
        self.vis = False
        self.label_colors = np.array([[0, 0, 255], [0, 255, 0]])
        self.label_colors = np.concatenate([self.label_colors, np.random.randint(0, 255, (400, 3))])

    def set_reference(self, imgs, masks):
        """Store the reference images and masks used by ``predict``.

        Args:
            imgs: tensor whose first two dims are flattened into
                ``(bs, 3, h, w)``; presumably ``(1, nshot, 3, h, w)`` with
                values in [0, 1] (TODO confirm with callers).
            masks: 4-D tensor ``(1, nshot, h, w)``. If it is entirely zero,
                a 14x14 square at the centre is used instead. The caller's
                tensor is not modified.

        ``predict`` only uses the first image/mask pair.
        """

        def reference_masks_verification(masks):
            # An all-empty mask leaves predict with no foreground pixels to
            # pool, so fall back to a small square at the image centre.
            if masks.sum() == 0:
                masks = masks.clone()  # don't mutate the caller's tensor
                _, _, sh, sw = masks.shape
                masks[..., (sh // 2 - 7):(sh // 2 + 7), (sw // 2 - 7):(sw // 2 + 7)] = 1
            return masks

        imgs = imgs.flatten(0, 1)  # bs, 3, h, w

        # process reference masks
        masks = reference_masks_verification(masks)
        masks = masks.permute(1, 0, 2, 3)  # ns, 1, h, w
        nshot = masks.shape[0]

        self.ref_imgs = imgs
        self.nshot = nshot
        self.ref_masks = masks
        # uint8 HWC copies, the layout handed to the SAM predictor in predict().
        self.ref_imgs_np = imgs.mul(255).byte().permute(0, 2, 3, 1).cpu().numpy()
        self.ref_masks_np = masks.mul(255).byte().permute(0, 2, 3, 1).cpu().numpy()

    def set_target(self, img):
        """Store the target image and its uint8 HWC numpy copy.

        Args:
            img: tensor whose last two dims must equal ``self.input_size``;
                presumably ``(1, 3, H, W)`` in [0, 1] (TODO confirm).
        """
        img_h, img_w = img.shape[-2:]
        assert img_h == self.input_size[0] and img_w == self.input_size[1]

        # transform query to numpy as input of sam
        img_np = img.mul(255).byte()
        img_np = img_np.squeeze(0).permute(1, 2, 0).cpu().numpy()

        self.tar_img = img
        self.tar_img_np = img_np

    def show_anns(self, masks, img, sfx='sup', coords=None, labels=None):
        """Save a 400x400 overlay of ``masks`` (and optional prompt points) when ``vis`` is on.

        Args:
            masks: numpy array of shape ``(N, H, W)``; pixels > 0.5 are tinted.
            img: ``(3, H, W)`` tensor, presumably RGB in [0, 1].
            coords / labels: optional prompt points in input-image pixel
                coordinates, rescaled to the 400x400 output.
        """
        if not self.vis:
            return
        if len(masks) == 0:
            return

        img = img.permute(1, 2, 0).cpu().numpy()
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) * 255
        img = cv2.resize(img, (400, 400))

        masks = F.interpolate(torch.from_numpy(masks).unsqueeze(0).float(), (img.shape[0], img.shape[1]))
        masks = masks[0].cpu().numpy()

        for idx, ann in enumerate(masks):
            img[ann > 0.5] += self.label_colors[idx] * 0.4

        if coords is not None:
            in_h, in_w = self.input_size
            for coord, label in zip(coords, labels):
                color = (0, 0, 255) if label == 1 else (0, 0, 0)
                center = (int(coord[0] / in_w * 400), int(coord[1] / in_h * 400))
                cv2.circle(img, center, 5, color, -1)

        cv2.imwrite(f'output/figs/{self.count}_{sfx}.png', img)

    def show_heatmap(self, simmap, img, sfx='sim'):
        """Save ``simmap[0]`` as a 400x400 VIRIDIS heatmap when ``vis`` is on.

        ``simmap[0]`` is presumably a 2-D array in [0, 1] (it is scaled by 255
        into uint8). ``img`` is currently unused.
        """
        if not self.vis:
            return

        heatmap = cv2.applyColorMap(np.uint8(255 * simmap[0]), cv2.COLORMAP_VIRIDIS)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        heatmap = cv2.resize(heatmap, (400, 400), interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(f'output/figs/{self.count}_{sfx}.png', heatmap)

    def show_points(self, img, coords, labels, color=None, sfx='points'):
        """Save ``img`` with prompt points drawn on it when ``vis`` is on.

        Each point uses ``color`` if given, otherwise ``self.label_colors[label]``.
        Points are drawn at full resolution; the image is then resized to 400x400.
        """
        if not self.vis:
            return
        img = img.permute(1, 2, 0).cpu().numpy()
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) * 255

        for coord, label in zip(coords, labels):
            color_now = tuple(self.label_colors[label].tolist()) if color is None else color
            cv2.circle(img, (int(coord[0]), int(coord[1])), 10, color_now, -1)

        img = cv2.resize(img, (400, 400))
        cv2.imwrite(f'output/figs/{self.count}_{sfx}.png', img)

    def predict(self):
        """Segment the target image guided by the first reference image/mask.

        Requires ``set_reference`` and ``set_target`` to have been called.

        Returns:
            Float tensor ``(1, H, W)`` on ``self.device`` holding the
            highest-scoring SAM mask.
        """
        # NOTE(review): relies on a modified SamPredictor whose set_image also
        # takes a mask (and returns it preprocessed) and whose predict accepts
        # attn_sim / target_embedding -- confirm against the predictor in use.
        ref_mask = self.predictor.set_image(self.ref_imgs_np[0], self.ref_masks_np[0])
        ref_feat = self.predictor.features.squeeze().permute(1, 2, 0)  # (h, w, C)

        ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
        ref_mask = ref_mask.squeeze()[0]

        # Target feature extraction: mean reference feature under the mask.
        target_feat = ref_feat[ref_mask > 0]
        target_embedding = target_feat.mean(0).unsqueeze(0)                           # (1, C)
        target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)  # unit norm
        target_embedding = target_embedding.unsqueeze(0)                              # (1, 1, C)

        self.predictor.set_image(self.tar_img_np)
        test_feat = self.predictor.features.squeeze()

        # Cosine similarity between the embedding and every target location.
        C, h, w = test_feat.shape
        test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
        test_feat = test_feat.reshape(C, h * w)
        sim = target_feat @ test_feat

        sim = sim.reshape(1, 1, h, w)
        sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
        sim = self.predictor.model.postprocess_masks(
            sim,
            input_size=self.input_size,
            original_size=self.predictor.original_size).squeeze()

        # Positive-negative location prior
        topk_xy_i, topk_label_i, last_xy_i, last_label_i = self.point_selection(sim, topk=1)
        topk_xy = np.concatenate([topk_xy_i, last_xy_i], axis=0)
        topk_label = np.concatenate([topk_label_i, last_label_i], axis=0)

        # Obtain the target guidance for cross-attention layers
        sim = (sim - sim.mean()) / torch.std(sim)
        sim = F.interpolate(sim.unsqueeze(0).unsqueeze(0), size=(64, 64), mode="bilinear")
        attn_sim = sim.sigmoid_().unsqueeze(0).flatten(3)

        # First-step prediction
        masks, scores, logits, _ = self.predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=False,
            attn_sim=attn_sim,  # Target-guided Attention
            target_embedding=target_embedding  # Target-semantic Prompting
        )
        best_idx = 0

        # Cascaded Post-refinement-1: feed the first logits back as a mask prompt.
        masks, scores, logits, _ = self.predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            mask_input=logits[best_idx: best_idx + 1, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        # Cascaded Post-refinement-2: add the bounding box of the best mask.
        # An empty mask has no box (min() of an empty array raises), so keep
        # the refinement-1 result in that case.
        ys, xs = np.nonzero(masks[best_idx])
        if xs.size > 0:
            input_box = np.array([xs.min(), ys.min(), xs.max(), ys.max()])
            masks, scores, logits, _ = self.predictor.predict(
                point_coords=topk_xy,
                point_labels=topk_label,
                box=input_box[None, :],
                mask_input=logits[best_idx: best_idx + 1, :, :],
                multimask_output=True)
            best_idx = np.argmax(scores)

        final_mask = masks[best_idx]
        # Honour the configured device instead of assuming CUDA.
        return torch.tensor(final_mask, device=self.device).unsqueeze(0).float()

    def point_selection(self, mask_sim, topk=1):
        """Pick the ``topk`` most and least similar locations as point prompts.

        Args:
            mask_sim: 2-D tensor of similarity scores, shape ``(rows, cols)``.
            topk: number of positive and of negative points.

        Returns:
            ``(topk_xy, topk_label, last_xy, last_label)``: ``(topk, 2)`` numpy
            arrays of ``(x, y) == (col, row)`` coordinates, with label arrays
            of ones (positives) and zeros (negatives).
        """
        _, n_cols = mask_sim.shape
        flat_sim = mask_sim.flatten(0)

        def flat_to_xy(flat_idx):
            # Flat indices -> (col, row) pairs, shape (k, 2), as numpy.
            row = (flat_idx // n_cols).unsqueeze(0)
            col = flat_idx - row * n_cols
            return torch.cat((col, row), dim=0).permute(1, 0).cpu().numpy()

        # Top-k point selection (positives)
        topk_xy = flat_to_xy(flat_sim.topk(topk)[1])
        topk_label = np.array([1] * topk)

        # Top-last point selection (negatives)
        last_xy = flat_to_xy(flat_sim.topk(topk, largest=False)[1])
        last_label = np.array([0] * topk)

        return topk_xy, topk_label, last_xy, last_label

    def clear(self):
        """Reset all per-episode target and reference state."""
        self.tar_img = None
        self.tar_img_np = None

        self.ref_imgs = None
        self.ref_masks = None
        self.ref_imgs_np = None
        self.ref_masks_np = None
        self.ref_masks_pool = None
        self.nshot = None


def build_matcher_oss(args):
    """Construct a ``Matcher`` from parsed command-line options.

    Builds the DINOv2 ViT selected by ``args.dinov2_size``. It loads its
    "teacher" weights from ``args.dinov2_weights``, switches it to eval mode
    and moves it to ``args.device``. It also builds the SAM model selected by
    ``args.sam_size`` from ``args.sam_weights``, moves it to ``args.device``
    and wraps it in a ``SamPredictor``. All of this is passed to ``Matcher``
    together with the score-filter settings and remaining options read
    from ``args``.
    """
    # Image encoder (DINOv2).
    encoder_factory = vits.__dict__[args.dinov2_size]
    image_encoder = encoder_factory(
        img_size=518,
        patch_size=14,
        init_values=1e-5,
        ffn_layer='mlp',
        block_chunks=0,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    dinov2_utils.load_pretrained_weights(image_encoder, args.dinov2_weights, "teacher")
    image_encoder.eval()
    image_encoder.to(device=args.device)

    # Mask generator (SAM behind a predictor).
    sam_model = sam_model_registry[args.sam_size](checkpoint=args.sam_weights)
    sam_model.to(device=args.device)
    sam_predictor = SamPredictor(sam_model)

    filter_cfg = dict(
        emd=args.emd_filter,
        purity=args.purity_filter,
        coverage=args.coverage_filter,
        score_filter=args.use_score_filter,
        score=args.deep_score_filter,
        score_norm=args.deep_score_norm_filter,
        topk_scores_threshold=args.topk_scores_threshold,
    )

    matcher_options = dict(
        num_centers=args.num_centers,
        use_box=args.use_box,
        use_points_or_centers=args.use_points_or_centers,
        sample_range=args.sample_range,
        max_sample_iterations=args.max_sample_iterations,
        alpha=args.alpha,
        beta=args.beta,
        exp=args.exp,
        score_filter_cfg=filter_cfg,
        num_merging_mask=args.num_merging_mask,
        device=args.device,
    )
    return Matcher(encoder=image_encoder, generator=sam_predictor, **matcher_options)
