import torch
from torch.nn import functional as F

import os
import cv2
import random
import PIL
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch.nn as nn
import numpy as np
from scipy.ndimage import binary_erosion, binary_dilation, label
from skimage.morphology import remove_small_objects
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from universeg import universeg
from segment_anything import sam_model_registry as per_sam_model_registry
from segment_anything import SamPredictor

# utils
def mask_fusion(high_res_masks, soft_pred, shape_, gamma, eps=1e-8):
    """Blend a SAM mask with a UniverSeg probability map.

    Both inputs are rescaled before mixing so they are on a comparable scale:

    - ``high_res_masks`` is divided by its standard deviation.
    - ``soft_pred`` is shifted so that 0.5 maps to 0, then divided by its
      standard deviation.

    ``soft_pred`` is then bilinearly resized to ``shape_``. The result is
    ``(1 - gamma) * high_res_masks[0] + gamma * soft_pred[0, 0]``.

    Args:
        high_res_masks: tensor whose first entry (``[0]``) is the SAM mask;
            presumably shape [1, H, W] with (H, W) == ``shape_`` -- TODO confirm.
        soft_pred: 4-D tensor [1, 1, h, w] of probabilities. Bilinear
            ``F.interpolate`` needs 4-D input; only ``[0, 0]`` is used.
        shape_: target spatial size (H, W) for ``soft_pred``.
        gamma: weight of ``soft_pred`` in the blend (0 -> SAM only,
            1 -> UniverSeg only).
        eps: added to each standard deviation so that constant inputs do not
            divide by zero and produce NaN. This matches the ``+ 1e-8`` guard
            used for the confidence map elsewhere in this file.

    Returns:
        numpy array of fused scores. The predictor in this file thresholds
        it with ``> 0``.
    """
    high_res_masks = high_res_masks / (high_res_masks.std() + eps)
    soft_pred = (soft_pred - 0.5) / (soft_pred.std() + eps)
    soft_pred = F.interpolate(soft_pred, size=shape_, mode="bilinear")
    mask_final = (1 - gamma) * high_res_masks[0] + gamma * soft_pred[0, 0]
    return mask_final.detach().cpu().numpy()

class UniSAM_predictor(nn.Module):
    """Few-shot segmenter that combines UniverSeg with SAM.

    Workflow, as implemented below:

    1. ``set_support_images`` registers the support set. It obtains a model
       from ``compute_logistic_regression`` on the SAM predictor and copies
       that model's ``coef_`` / ``intercept_`` into a 1x1 convolution
       (``conv1``, 256 -> 1 channels).
    2. ``predict`` builds a confidence map for the query image:
       ``sigmoid(conv1(SAM features))``, resized to the image size.
    3. It runs UniverSeg's ``forward_attention`` guided by that map.
    4. It derives box prompts from the UniverSeg mask and runs SAM on them,
       using an attention prior derived from the same confidence map.
    5. It fuses the SAM and UniverSeg outputs with ``mask_fusion``.

    Hyper-parameters:
        alpha: passed to ``forward_attention`` as ``alpha``.
        delta: scale applied to the rescaled confidence map before ``exp``
            when building SAM's ``attn_sim``.
        gamma: weight of the UniverSeg prediction in ``mask_fusion``.
        Context_size, pseudo_universeg: stored on the instance; not read by
            the methods of this class.

    All models are moved to CUDA, so a GPU is required.
    """

    def __init__(
        self,
        alpha,
        delta,
        gamma,
        Context_size,
        pseudo_universeg = 1.0,
        checkpoint = "/userhome/Desam/checkpoint/sam_vit_b_01ec64.pth",
    ) -> None:
        """Load UniverSeg and SAM on CUDA and create the (unfitted) ``conv1``.

        Args:
            alpha, delta, gamma, Context_size, pseudo_universeg: see the
                class docstring.
            checkpoint: path to a SAM ``vit_b`` checkpoint. If the path
                contains ``'medsam_vit_b'``, ``sam.MedSAM_norm`` is set to
                True to change SAM's normalization method.
        """
        super().__init__()
        # hyper parameter setting
        self.alpha = alpha
        self.delta = delta
        self.gamma = gamma
        self.Context_size = Context_size
        self.pseudo_universeg = pseudo_universeg

        # load universeg
        self.model = universeg(pretrained=True).cuda()

        # Load SAM
        sam_type, sam_ckpt = 'vit_b', checkpoint
        self.sam = per_sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
        if 'medsam_vit_b' in sam_ckpt:  # change the normalization method of SAM
            self.sam.MedSAM_norm = True
        self.sam.eval()
        self.predictor = SamPredictor(self.sam)
        print('sam.MedSAM_norm:', self.sam.MedSAM_norm)

        # 1x1 conv mapping the 256-channel SAM features to one confidence
        # channel; its parameters are filled in by set_support_images.
        self.conv1 = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1)

    def set_support_images(
        self,
        test_image_list,
        test_mask_list,
        ):
        """Register the support set and load its classifier into ``conv1``.

        Args:
            test_image_list: support images. They are normalized with
                ``norm_`` and stored as ``self.C_support``.
            test_mask_list: masks matching ``test_image_list``, stored as
                ``self.test_mask_list``.
        """
        self.test_mask_list = test_mask_list
        self.model_LG = compute_logistic_regression(test_image_list, test_mask_list, self.predictor)
        self.C_support = [norm_(i) for i in test_image_list]

        # Copy the classifier parameters into the 1x1 conv. ``copy_`` under
        # ``no_grad`` keeps the existing Parameters and works whichever
        # device conv1 is currently on (e.g. on a repeated call).
        weight = torch.from_numpy(self.model_LG.coef_.reshape(1, 256, 1, 1)).float()
        bias = torch.from_numpy(self.model_LG.intercept_).float()
        with torch.no_grad():
            self.conv1.weight.copy_(weight)
            self.conv1.bias.copy_(bias)
        self.conv1 = self.conv1.cuda()

    # compute the confidence map for test
    def Image2ConfidenceMap_conv(self, test_image, predictor, model_LG, verbose=False):
        """Compute a one-channel confidence map for ``test_image``.

        The image is encoded with ``predictor.set_image``; ``conv1`` +
        sigmoid is applied to ``predictor.features`` and the result is
        bilinearly resized to ``test_image.shape[2:]``.

        Args:
            test_image: tensor [1, 3, ori_size, ori_size]. ``.numpy()`` is
                called on it, so it must be a CPU tensor.
            predictor: SAM predictor exposing ``set_image`` and ``features``.
            model_LG: unused; kept for backward compatibility. Its parameters
                are already copied into ``conv1`` by ``set_support_images``.
            verbose: if True, print the feature and confidence-map shapes.

        Returns:
            Confidence map with values in (0, 1), one channel, spatial size
            equal to ``test_image.shape[2:]``.

        Side effects:
            Sets ``self.test_feat`` to ``predictor.features``.
        """
        # Image feature encoding (CHW -> HWC for set_image)
        predictor.set_image(test_image[0, :].permute(1, 2, 0).numpy())
        self.test_feat = predictor.features
        confidence_map = torch.sigmoid(self.conv1(self.test_feat))
        confidence_map = F.interpolate(confidence_map, size=test_image.shape[2:], mode="bilinear")
        if verbose:
            print('test_feat', self.test_feat.shape, type(self.test_feat))
            print('confidence_map', confidence_map.shape, type(confidence_map))
        return confidence_map

    # ------------------------------------------------------------------
    # private pipeline helpers shared by predict / predict_detailed
    # ------------------------------------------------------------------

    def _prepare_inputs(self, test_image, target_index):
        """Return the confidence map and (images, support_images, support_labels)."""
        semantic_confidence_map = self.Image2ConfidenceMap_conv(
            test_image, self.predictor, self.model_LG, verbose=False)
        C_test = norm_(test_image)
        structured = get_structured_data(
            C_test, self.C_support, self.test_mask_list, target_index=target_index)
        return semantic_confidence_map, structured

    def _universeg_soft_pred(self, structured, semantic_confidence_map, alpha):
        """Run UniverSeg ``forward_attention`` and return sigmoid probabilities."""
        images, support_images, support_labels = structured
        logits = self.model.forward_attention(
            images.float().cuda(),
            support_images.float().cuda(),
            support_labels.float().cuda(),
            semantic_confidence_map.float().cuda(),
            alpha=alpha,
        )
        return torch.sigmoid(logits)

    @staticmethod
    def _to_hard_numpy(soft_pred, size):
        """Round ``soft_pred`` to {0, 1}, bilinearly resize it to ``size``, return numpy.

        Because the resize is bilinear, fractional values can appear where
        foreground and background meet.
        """
        hard_pred = soft_pred.round().clip(0, 1)
        return F.interpolate(hard_pred.detach().cpu(), size=size, mode="bilinear").numpy()

    def _box_prompts(self):
        """Derive SAM box prompts from the recorded ``self.soft_pred``.

        Returns:
            (bbox, other_masks): the box of the main component and the extra
            component masks reported by ``process_mask_iter`` (``[]`` if none).
            Falls back to the box ``[0, 0, 1, 1]`` with no extra components
            when extraction fails (e.g. no component inside the mask).
        """
        try:
            pseudo_label = self.soft_pred.round().clip(0, 1)
            components = process_mask_iter(pseudo_label[0, 0, :].numpy())
            if isinstance(components, tuple):
                main_mask, other_masks = components[0], components[1]
            else:
                main_mask, other_masks = components, []
            bbox = np.array(compute_bounding_box(main_mask))
        except Exception:
            # No Component inside the mask
            bbox = np.array([0, 0, 1, 1])
            other_masks = []
        return bbox, other_masks

    def _attention_similarity(self, semantic_confidence_map):
        """Build SAM's ``attn_sim`` from the confidence map.

        Steps:

        1. Shift the map by 0.5 and divide it by its std.
        2. Resize to 64x64.
        3. Scale by ``delta`` and apply ``exp``.
        4. Flatten with ``reshape(1, 1, -1)``.
        """
        sim = (semantic_confidence_map - 0.5) / (semantic_confidence_map.std() + 1e-8)
        attn_sim = F.interpolate(sim.cuda(), size=(64, 64), mode="bilinear")
        attn_sim = attn_sim * self.delta
        return torch.exp(attn_sim.reshape(1, 1, -1))

    def _sam_predict_box(self, box, attn_sim):
        """Run SAM on one box prompt; return (masks, high_res_masks)."""
        masks, _, _, high_res_masks = self.predictor.predict(
            box=box,
            mask_input=None,
            multimask_output=False,
            attn_sim=attn_sim,  # Target-guided Attention
            target_embedding=None,  # Target-semantic Prompting
        )
        return masks, high_res_masks

    def _refine_with_sam(self, test_image, semantic_confidence_map, soft_pred):
        """Prompt SAM with boxes from ``self.soft_pred`` and fuse with UniverSeg.

        SAM is run on the main box and then once per extra component. The
        per-box outputs are combined as follows:

        - binary masks are OR-ed;
        - ``high_res_masks`` are combined by pixel-wise max.

        Side effects:
            Sets ``self.input_box`` (main box) and ``self.other_box``
            (boxes of the extra components).

        Returns:
            (mask_SAM, mask_Fuse): the SAM mask and the boolean fused mask.
        """
        self.input_box, other_masks = self._box_prompts()
        self.other_box = []
        attn_sim = self._attention_similarity(semantic_confidence_map)

        masks, high_res_masks = self._sam_predict_box(self.input_box, attn_sim)

        # Iterative mask generation for the remaining components
        if len(other_masks) > 0:
            mask_stack, high_res_stack = [masks], [high_res_masks]
            for component in other_masks:
                bbox = np.array(compute_bounding_box(component))
                self.other_box.append(bbox)
                masks_iter, high_res_iter = self._sam_predict_box(bbox, attn_sim)
                mask_stack.append(masks_iter)
                high_res_stack.append(high_res_iter)
            high_res_masks = torch.stack(high_res_stack).max(dim=0).values
            masks = np.logical_or.reduce(np.stack(mask_stack), axis=0)

        mask_SAM = masks[0, :]
        mask_final = mask_fusion(high_res_masks, soft_pred, test_image.shape[2:], self.gamma)
        return mask_SAM, mask_final > 0

    # ------------------------------------------------------------------
    # public entry points
    # ------------------------------------------------------------------

    def predict(self,
                test_image,
                target_index = 0,  # the selected channel for UniverSeg
               ):
        """Segment ``test_image`` with confidence-guided UniverSeg + SAM.

        Args:
            test_image: CPU tensor [1, 3, H, W].
            target_index: channel selected for UniverSeg (forwarded to
                ``get_structured_data``).

        Returns:
            A tuple ``(hard_pred_alpha, mask_SAM, mask_Fuse)``:

            - ``hard_pred_alpha``: UniverSeg hard mask (alpha = ``self.alpha``)
              as numpy, resized to (H, W).
            - ``mask_SAM``: the SAM mask.
            - ``mask_Fuse``: the boolean fused mask.

        Side effects:
            Sets ``self.test_feat``, ``self.soft_pred`` (resized, detached,
            CPU), ``self.input_box`` and ``self.other_box``.
        """
        size = test_image.shape[2:]
        semantic_confidence_map, structured = self._prepare_inputs(test_image, target_index)

        # run UniverSeg with Semantic confidence map
        soft_pred = self._universeg_soft_pred(structured, semantic_confidence_map, self.alpha)
        self.soft_pred = F.interpolate(soft_pred, size=size, mode="bilinear").cpu().detach()  # record
        hard_pred_alpha = self._to_hard_numpy(soft_pred, size)

        mask_SAM, mask_Fuse = self._refine_with_sam(test_image, semantic_confidence_map, soft_pred)
        return hard_pred_alpha, mask_SAM, mask_Fuse

    def predict_detailed(self,
                test_image,
                target_index = 0,  # the selected channel for UniverSeg
               ):
        """Like ``predict``, but also return the UniverSeg mask with alpha = 0.

        A first UniverSeg pass with ``alpha=0`` yields ``hard_pred_ori``.
        The rest of the pipeline is identical to ``predict``.

        Returns:
            A tuple ``(hard_pred_ori, hard_pred_alpha, mask_SAM, mask_Fuse)``.
        """
        size = test_image.shape[2:]
        semantic_confidence_map, structured = self._prepare_inputs(test_image, target_index)

        # Reference UniverSeg pass with alpha = 0
        soft_pred_ori = self._universeg_soft_pred(structured, semantic_confidence_map, 0)
        hard_pred_ori = self._to_hard_numpy(soft_pred_ori, size)

        # run UniverSeg with Semantic confidence map
        soft_pred = self._universeg_soft_pred(structured, semantic_confidence_map, self.alpha)
        self.soft_pred = F.interpolate(soft_pred, size=size, mode="bilinear").cpu().detach()  # record
        hard_pred_alpha = self._to_hard_numpy(soft_pred, size)

        mask_SAM, mask_Fuse = self._refine_with_sam(test_image, semantic_confidence_map, soft_pred)
        return hard_pred_ori, hard_pred_alpha, mask_SAM, mask_Fuse