import os
from os import path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

import numpy as np
import cv2
import ot
import math
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt

from segment_anything import sam_model_registry, SamPredictor
from segment_anything import SamAutomaticMaskGenerator
from dinov2.models import vision_transformer as vits
import dinov2.utils.utils as dinov2_utils
from dinov2.data.transforms import MaybeToTensor, make_normalize_transform

from matcher.k_means import kmeans_pp

from segment_anything.utils.amg import (
    batch_iterator, 
)

from scipy.sparse import csgraph


class Matcher:
    def __init__(
            self,
            encoder,
            generator=None,
            input_size=1024,
            num_centers=8,
            use_box=False,
            use_points_or_centers=True,
            sample_range=(4, 6),
            max_sample_iterations=30,
            alpha=1.,
            beta=0.,
            exp=0.,
            score_filter_cfg=None,
            num_merging_mask=10,
            device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    ):
        """Store the models and configuration and reset the per-episode state.

        Args:
            encoder: Image encoder. Only its ``patch_size`` attribute is read
                here, to derive the encoder feature-map size.
            generator: Predictor object, stored as ``self.predictor``.
            input_size: Expected target image size. A tuple or list is taken
                as ``(height, width)``; any other value is used for both.
            num_centers, use_box, use_points_or_centers, sample_range,
            max_sample_iterations, alpha, beta, exp, num_merging_mask:
                Stored unchanged on the instance.
            score_filter_cfg: Required configuration object; must not be
                ``None``.
            device: Device stored as ``self.device``.

        Raises:
            AssertionError: If ``score_filter_cfg`` is ``None``.
        """
        # models
        self.encoder = encoder
        self.predictor = generator
        self.rps = None

        # Lists are accepted alongside tuples: a bare list used to be
        # duplicated into a pair of lists, which can never equal an image size.
        if isinstance(input_size, (tuple, list)):
            input_size = tuple(input_size)
        else:
            input_size = (input_size, input_size)
        self.input_size = input_size

        img_size = 518
        feat_size = img_size // self.encoder.patch_size

        self.encoder_img_size = img_size
        self.encoder_feat_size = feat_size

        # transforms for image encoder
        self.encoder_transform = transforms.Compose([
            MaybeToTensor(),
            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR),
            make_normalize_transform(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

        # per-episode target state, filled by set_target()
        self.tar_img = None
        self.tar_img_np = None

        # per-episode reference state, filled by set_reference(); every
        # attribute that set_reference() writes is declared here so the
        # instance never has a partially defined state.
        self.ref_imgs = None
        self.ref_masks = None
        self.ref_imgs_np = None
        self.ref_masks_np = None
        self.ref_masks_pool = None
        self.nshot = None

        self.num_centers = num_centers
        self.use_box = use_box
        self.use_points_or_centers = use_points_or_centers
        self.sample_range = sample_range
        self.max_sample_iterations = max_sample_iterations

        self.alpha, self.beta, self.exp = alpha, beta, exp
        assert score_filter_cfg is not None, "score_filter_cfg must be provided"
        self.score_filter_cfg = score_filter_cfg
        self.num_merging_mask = num_merging_mask

        self.device = device
        # used in the file names written by the show_* helpers
        self.count = 0

        # visualisation is off by default; the show_* helpers return early
        self.vis = False
        # two fixed colours followed by 400 random (unseeded) colours
        self.label_colors = np.array([[0, 0, 255], [0, 255, 0]])
        self.label_colors = np.concatenate([self.label_colors, np.random.randint(0, 255, (400, 3))])

    def set_reference(self, imgs, masks):

        def reference_masks_verification(masks):
            if masks.sum() == 0:
                _, _, sh, sw = masks.shape
                masks[..., (sh // 2 - 7):(sh // 2 + 7), (sw // 2 - 7):(sw // 2 + 7)] = 1
            return masks

        imgs = imgs.flatten(0, 1)  # bs, 3, h, w

        # process reference masks
        masks = reference_masks_verification(masks)
        masks = masks.permute(1, 0, 2, 3)  # ns, 1, h, w
        nshot = masks.shape[0]

        self.ref_imgs = imgs
        self.nshot = nshot
        self.ref_masks = masks
        print(imgs.shape, masks.shape)
        self.ref_imgs_np = imgs.mul(255).byte().permute(0, 2, 3, 1).cpu().numpy()
        self.ref_masks_np = masks.mul(255).byte().permute(0, 2, 3, 1).cpu().numpy()

    def set_target(self, img):

        img_h, img_w = img.shape[-2:]
        assert img_h == self.input_size[0] and img_w == self.input_size[1]

        # transform query to numpy as input of sam
        img_np = img.mul(255).byte()
        img_np = img_np.squeeze(0).permute(1, 2, 0).cpu().numpy()

        self.tar_img = img
        self.tar_img_np = img_np

    def show_anns(self, masks, img, sfx='sup', coords=None, labels=None):
        """Write a 400x400 image with tinted mask overlays and optional points.

        Returns immediately when ``self.vis`` is off, and returns without
        writing anything when ``masks`` is empty.

        Args:
            masks: NumPy array of masks stacked along the first axis.
            img: Channel-first image tensor.
                Assumes values in [0, 1], since they are scaled by 255 —
                TODO confirm.
            sfx: Suffix used in the output file name.
            coords: Optional point coordinates, rescaled from a 1024-pixel
                frame to the 400-pixel canvas before drawing.
            labels: Per-point labels, iterated together with ``coords``.
        """
        if self.vis == False:
            return

        canvas = cv2.cvtColor(img.permute(1, 2, 0).cpu().numpy(), cv2.COLOR_RGB2BGR) * 255
        canvas = cv2.resize(canvas, (400, 400))

        if len(masks) == 0:
            return

        # bring the masks to the canvas resolution
        canvas_size = (canvas.shape[0], canvas.shape[1])
        mask_stack = torch.from_numpy(masks).unsqueeze(0).float()
        resized_masks = F.interpolate(mask_stack, canvas_size)[0].cpu().numpy()

        # tint each mask region with the palette colour of its index
        for mask_idx, resized in enumerate(resized_masks):
            canvas[resized > 0.5] += self.label_colors[mask_idx] * 0.4

        if coords is not None:
            for coord, label in zip(coords, labels):
                if label == 1:
                    color = (0, 0, 255)
                else:
                    color = (0, 0, 0)
                center = (int(coord[0] / 1024 *400), int(coord[1]/1024*400))
                cv2.circle(canvas, center, 5, color, -1)

        cv2.imwrite(f'output/figs/{self.count}_{sfx}.png', canvas)

    def show_heatmap(self, simmap, img, sfx='sim'):
        """Write the first entry of a similarity map as a 400x400 Viridis heatmap.

        Returns immediately when ``self.vis`` is off.

        Args:
            simmap: Indexable whose first element is the 2-D map to render.
                Assumes values in [0, 1], since they are scaled by 255 and
                cast to uint8 — TODO confirm.
            img: Accepted but not used by this method.
            sfx: Suffix used in the output file name.
        """
        if self.vis == False:
            return

        scaled = np.uint8(255 * simmap[0])
        colored = cv2.applyColorMap(scaled, cv2.COLORMAP_VIRIDIS)
        # NOTE(review): the map is converted to RGB right before imwrite —
        # confirm that the resulting channel order on disk is intended.
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
        colored = cv2.resize(colored, (400, 400), interpolation=cv2.INTER_NEAREST)

        cv2.imwrite(f'output/figs/{self.count}_{sfx}.png', colored)

    def show_points(self, img, coords, labels, color=None, sfx='points'):
        """Draw points on the image at full resolution and write it at 400x400.

        Returns immediately when ``self.vis`` is off.

        Args:
            img: Channel-first image tensor.
                Assumes values in [0, 1], since they are scaled by 255 —
                TODO confirm.
            coords: Point coordinates; the first two entries of each are
                used as the circle centre.
            labels: Per-point labels, used as indices into
                ``self.label_colors`` when ``color`` is not given.
            color: Optional single colour applied to every point.
            sfx: Suffix used in the output file name.
        """
        if self.vis == False:
            return

        canvas = cv2.cvtColor(img.permute(1, 2, 0).cpu().numpy(), cv2.COLOR_RGB2BGR) * 255

        for coord, label in zip(coords, labels):
            # fall back to the palette colour of the label when no
            # explicit colour was requested
            color_now = color if color is not None else tuple(self.label_colors[label].tolist())
            center = (int(coord[0]), int(coord[1]))
            cv2.circle(canvas, center, 10, color_now, -1)

        canvas = cv2.resize(canvas, (400, 400))
        cv2.imwrite(f'output/figs/{self.count}_{sfx}.png', canvas)

    def predict(self):
        """Segment the target image using the first stored reference image and mask.

        The method
          1. encodes the reference image, averages the predictor features
             under the reference mask and normalises the result into a
             prototype;
          2. fits ``Mask_Weights`` for 1000 optimisation steps so that the
             weighted sum of the predictor's three mask outputs reproduces
             the reference mask (dice + focal loss);
          3. encodes the target image, prompts the predictor with the
             location most similar to the prototype and combines the three
             outputs with the fitted weights;
          4. refines the result twice, each time prompting with the bounding
             box and the logits of the previous mask.

        ``set_reference`` and ``set_target`` must have been called first.

        NOTE(review): relies on a predictor whose ``set_image`` accepts a
        mask and whose ``predict`` returns four values (the last being
        full-resolution logits) — confirm against the predictor in use.

        Returns:
            Float tensor on ``self.device`` holding the selected mask with a
            leading singleton dimension. It is all zeros when the first-step
            mask or the first refinement is empty.
        """
        ref_mask = self.predictor.set_image(self.ref_imgs_np[0], self.ref_masks_np[0])
        ref_feat = self.predictor.features.squeeze().permute(1, 2, 0)

        gt_mask = (ref_mask > 0).float()[:, 0, :, :].flatten(1)
        ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
        ref_mask = ref_mask.squeeze()[0]

        # Target feature extraction: mean feature under the reference mask
        target_feat = ref_feat[ref_mask > 0]
        target_embedding = target_feat.mean(0).unsqueeze(0)
        target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)

        # Cosine similarity
        h, w, C = ref_feat.shape
        # target_feat is already normalised above; this second pass is kept
        # so the numerical results stay exactly as before.
        target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)
        ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True)
        ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w)
        sim = target_feat @ ref_feat

        sim = sim.reshape(1, 1, h, w)
        sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
        sim = self.predictor.model.postprocess_masks(
                        sim,
                        input_size=self.input_size,
                        original_size=self.predictor.original_size).squeeze()

        # Positive-negative location prior
        topk_xy, topk_label = self.point_selection(sim, topk=1)

        print('======> Start Training')
        # Learnable mask weights, placed on the configured device instead of
        # unconditionally on CUDA so CPU-only runs work as well.
        mask_weights = Mask_Weights().to(self.device)
        mask_weights.train()

        optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=1e-3, eps=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)

        for train_idx in range(1000):

            # Run the decoder
            # NOTE(review): the prompts do not change between iterations; if
            # the predictor is deterministic this call could be hoisted out
            # of the loop — confirm before changing.
            masks, scores, logits, logits_high = self.predictor.predict(
                point_coords=topk_xy,
                point_labels=topk_label,
                multimask_output=True)
            logits_high = logits_high.flatten(1)

            # Weighted sum three-scale masks; the first weight is derived so
            # that the three weights sum to one.
            weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
            logits_high = logits_high * weights
            logits_high = logits_high.sum(0).unsqueeze(0)

            dice_loss = self.calculate_dice_loss(logits_high, gt_mask)
            focal_loss = self.calculate_sigmoid_focal_loss(logits_high, gt_mask)
            loss = dice_loss + focal_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            if train_idx % 200 == 0:
                print('Train Epoch: {:} / {:}'.format(train_idx, 1000))
                current_lr = scheduler.get_last_lr()[0]
                print('LR: {:.6f}, Dice_Loss: {:.4f}, Focal_Loss: {:.4f}'.format(current_lr, dice_loss.item(), focal_loss.item()))

        mask_weights.eval()
        weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
        weights_np = weights.detach().cpu().numpy()
        print('======> Mask weights:\n', weights_np)

        self.predictor.set_image(self.tar_img_np)
        test_feat = self.predictor.features.squeeze()

        # Cosine similarity
        C, h, w = test_feat.shape
        test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
        test_feat = test_feat.reshape(C, h * w)
        sim = target_feat @ test_feat

        sim = sim.reshape(1, 1, h, w)
        sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
        sim = self.predictor.model.postprocess_masks(
                        sim,
                        input_size=self.input_size,
                        original_size=self.predictor.original_size).squeeze()

        # Positive location prior
        topk_xy, topk_label = self.point_selection(sim, topk=1)

        # First-step prediction
        masks, scores, logits, logits_high = self.predictor.predict(
                    point_coords=topk_xy,
                    point_labels=topk_label,
                    multimask_output=True)

        # Weighted sum three-scale masks
        logits_high = logits_high * weights.unsqueeze(-1)
        logit_high = logits_high.sum(0)
        mask = (logit_high > 0).detach().cpu().numpy()

        logits = logits * weights_np[..., None]
        logit = logits.sum(0)

        # Cascaded Post-refinement-1
        input_box = self._mask_to_box(mask)
        if input_box is None:
            # Nothing was segmented: return an empty mask of the actual
            # mask size rather than a hard-coded 1024x1024 one.
            return torch.zeros((1, *mask.shape), device=self.device)
        masks, scores, logits, _ = self.predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logit[None, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        # Cascaded Post-refinement-2
        input_box = self._mask_to_box(masks[best_idx])
        if input_box is None:
            # The first refinement came back empty; taking min()/max() of an
            # empty index array would raise, so report an empty mask instead.
            return torch.zeros((1, *masks[best_idx].shape), device=self.device)
        masks, scores, logits, _ = self.predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logits[best_idx: best_idx + 1, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        final_mask = masks[best_idx]
        return torch.tensor(final_mask, device=self.device).unsqueeze(0).float()

    @staticmethod
    def _mask_to_box(mask):
        """Return the tight ``[x_min, y_min, x_max, y_max]`` box of a 2-D mask, or ``None`` if it has no non-zero entry."""
        y, x = np.nonzero(mask)
        if x.shape[0] == 0:
            return None
        return np.array([x.min(), y.min(), x.max(), y.max()])


    def point_selection(self, mask_sim, topk=1):
        # Top-1 point selection
        w, h = mask_sim.shape
        topk_xy = mask_sim.flatten(0).topk(topk)[1]
        topk_x = (topk_xy // h).unsqueeze(0)
        topk_y = (topk_xy - topk_x * h)
        topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0)
        topk_label = np.array([1] * topk)
        topk_xy = topk_xy.cpu().numpy()
        
        return topk_xy, topk_label


    def calculate_dice_loss(self, inputs, targets, num_masks = 1):
        """
        Compute the DICE loss, similar to generalized IOU for masks
        Args:
            inputs: A float tensor of arbitrary shape.
                    The predictions for each example.
            targets: A float tensor with the same shape as inputs. Stores the binary
                    classification label for each element in inputs
                    (0 for the negative class and 1 for the positive class).
        """
        inputs = inputs.sigmoid()
        inputs = inputs.flatten(1)
        numerator = 2 * (inputs * targets).sum(-1)
        denominator = inputs.sum(-1) + targets.sum(-1)
        loss = 1 - (numerator + 1) / (denominator + 1)
        return loss.sum() / num_masks


    def calculate_sigmoid_focal_loss(self, inputs, targets, num_masks = 1, alpha: float = 0.25, gamma: float = 2):
        """
        Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
        Args:
            inputs: A float tensor of arbitrary shape.
                    The predictions for each example.
            targets: A float tensor with the same shape as inputs. Stores the binary
                    classification label for each element in inputs
                    (0 for the negative class and 1 for the positive class).
            alpha: (optional) Weighting factor in range (0,1) to balance
                    positive vs negative examples. Default = -1 (no weighting).
            gamma: Exponent of the modulating factor (1 - p_t) to
                balance easy vs hard examples.
        Returns:
            Loss tensor
        """
        prob = inputs.sigmoid()
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        p_t = prob * targets + (1 - prob) * (1 - targets)
        loss = ce_loss * ((1 - p_t) ** gamma)

        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * loss

        return loss.mean(1).sum() / num_masks


    def clear(self):

        self.tar_img = None
        self.tar_img_np = None

        self.ref_imgs = None
        self.ref_masks_pool = None
        self.nshot = None

        # self.encoder_img_size = None
        # self.encoder_feat_size = None


class Mask_Weights(nn.Module):
    """Holds the learnable weights used to blend mask outputs.

    Only a ``(2, 1)`` parameter is stored, with every entry initialised to
    one third.
    """

    def __init__(self):
        super().__init__()
        initial_weights = torch.ones(2, 1) / 3
        # nn.Parameter marks the tensor as trainable
        self.weights = nn.Parameter(initial_weights)


def build_matcher_oss(args):
    """Build a ``Matcher`` from a DINOv2 encoder and a frozen SAM predictor.

    Args:
        args: Namespace-like object supplying the model sizes and weight
            paths (``dinov2_size``, ``dinov2_weights``, ``sam_size``,
            ``sam_weights``), the ``device``, the score-filter options and
            the remaining ``Matcher`` settings read below.

    Returns:
        A ``Matcher`` whose encoder is the DINOv2 model and whose generator
        is a ``SamPredictor`` wrapping the SAM model.
    """
    # DINOv2, Image Encoder
    encoder_factory = vars(vits)[args.dinov2_size]
    dinov2 = encoder_factory(
        img_size=518,
        patch_size=14,
        init_values=1e-5,
        ffn_layer='mlp',
        block_chunks=0,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    dinov2_utils.load_pretrained_weights(dinov2, args.dinov2_weights, "teacher")
    dinov2.eval()
    dinov2.to(device=args.device)

    # SAM, with every parameter frozen
    sam_factory = sam_model_registry[args.sam_size]
    sam = sam_factory(checkpoint=args.sam_weights)
    sam.to(device=args.device)
    for param in sam.parameters():
        param.requires_grad = False
    predictor = SamPredictor(sam)

    score_filter_cfg = dict(
        emd=args.emd_filter,
        purity=args.purity_filter,
        coverage=args.coverage_filter,
        score_filter=args.use_score_filter,
        score=args.deep_score_filter,
        score_norm=args.deep_score_norm_filter,
        topk_scores_threshold=args.topk_scores_threshold,
    )

    matcher = Matcher(
        encoder=dinov2,
        generator=predictor,
        num_centers=args.num_centers,
        use_box=args.use_box,
        use_points_or_centers=args.use_points_or_centers,
        sample_range=args.sample_range,
        max_sample_iterations=args.max_sample_iterations,
        alpha=args.alpha,
        beta=args.beta,
        exp=args.exp,
        score_filter_cfg=score_filter_cfg,
        num_merging_mask=args.num_merging_mask,
        device=args.device,
    )
    return matcher
