import os
import math
import numpy as np
import torch
import torch.nn.functional as F
import cv2
from PIL import Image
import pickle
import nltk
from nltk.tokenize import word_tokenize
from nltk import pos_tag
# Fetch the POS-tagger model that is_noun() relies on (via nltk.pos_tag).
# NOTE(review): this runs at import time and may hit the network the first
# time the data is not cached locally.
nltk.download('averaged_perceptron_tagger')


def preprocess(max_height=256, max_width=256, attn_masks=None):
    """Average a collection of attention maps into a single map.

    Each element of ``attn_masks`` is averaged over its leading axis, and a
    leading singleton axis (if any) is squeezed away. This leaves a 3-D tensor
    ``(tokens, h, w)``. All reduced maps are bilinearly resized to one common
    spatial size and then averaged element-wise.

    Args:
        max_height: Lower bound for the output height. The output height is
            the largest of this value and every reduced map's ``h``.
        max_width: Lower bound for the output width, analogous to
            ``max_height``.
        attn_masks: Iterable of attention tensors. NOTE(review): an original
            comment suggests reduced maps look like (77, 64, 64), i.e.
            (prompt tokens, h, w). Confirm with callers.

    Returns:
        Tensor ``(tokens, H, W)`` holding the mean of the resized maps.

    Raises:
        TypeError: if ``attn_masks`` is None (not iterable).
        RuntimeError: if ``attn_masks`` is empty (``torch.stack`` of nothing).
    """
    # Reduce every map first so the target size is computed over ALL maps.
    # Previously the running maximum was updated inside the resize loop. An
    # early map could then be resized smaller than a later, larger one, and
    # torch.stack failed on the mismatched shapes.
    reduced = [torch.mean(v, dim=0).squeeze(0) for v in attn_masks]
    for v in reduced:
        _, h, w = v.shape
        max_height = max(h, max_height)
        max_width = max(w, max_width)

    resized = [
        F.interpolate(
            v.unsqueeze(0),  # interpolate expects a batch axis: (1, tokens, h, w)
            size=(max_height, max_width),
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)  # back to (tokens, H, W)
        for v in reduced
    ]
    return torch.mean(torch.stack(resized, dim=0), dim=0)



def prompt2tokens(tokenizer, prompt):
    """Return the token strings of ``prompt`` as the model sees them.

    The prompt is padded and truncated to ``tokenizer.model_max_length``.
    Each id of the first encoded sequence is then mapped back through
    ``tokenizer.decoder``, an id -> token-string lookup. Special and padding
    tokens, and BPE markers such as ``'</w>'``, are therefore preserved.
    """
    encoding = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    id_to_token = tokenizer.decoder
    first_sequence = encoding.input_ids[0]  # only the first sequence is used
    return [id_to_token[token_id.item()] for token_id in first_sequence]

def is_noun(token):
    """Return True if NLTK tags ``token``, taken on its own, as a noun.

    The token is tagged as a one-word sentence with ``nltk.pos_tag``. Any tag
    beginning with ``'N'`` (in NLTK's default Penn Treebank tagset: NN, NNS,
    NNP, NNPS) counts as a noun.
    """
    tagged = nltk.pos_tag([token])
    _, pos = tagged[0]
    return pos.startswith('N')


def top_var_tokens(k, attn_map, tokenizer, prompt):
    """Rank the prompt's noun tokens by the spatial variance of their attention.

    The prompt is tokenized with ``prompt2tokens`` and zipped with
    ``attn_map``. The BOS token is skipped and iteration stops at the EOS
    token. The BPE word-end marker ``'</w>'`` is stripped before the noun
    check (``is_noun``).

    Args:
        k: Maximum number of distinct tokens to return. ``k <= 0`` returns an
            empty dict.
        attn_map: Sequence of per-token attention tensors aligned with the
            tokenizer output positions. NOTE(review): presumably the
            ``(tokens, H, W)`` output of ``preprocess``. Confirm.
        tokenizer: Tokenizer accepted by ``prompt2tokens``; must expose
            ``bos_token`` and ``eos_token``.
        prompt: Text prompt.

    Returns:
        Dict ``{token: variance}`` with at most ``k`` distinct noun tokens,
        ordered by variance, highest first. Variance is numpy's population
        variance of the flattened map. If a token occurs more than once, its
        highest variance is used.
    """
    if k <= 0:
        return {}
    tokens = prompt2tokens(tokenizer, prompt)
    bos_token = tokenizer.bos_token
    eos_token = tokenizer.eos_token

    # Highest variance seen per distinct token, in first-seen order. The old
    # bounded top-k dict was keyed by token, so a repeated token could
    # overwrite its own entry or evict another one while only replacing
    # itself. The result then held fewer than k tokens.
    best_var = {}
    for token, token_attn_map in zip(tokens, attn_map):
        if token == bos_token:
            continue
        if token == eos_token:
            break
        token = token.replace('</w>', '')
        if not is_noun(token):
            continue
        # detach() so maps that still require grad can be converted to numpy.
        var = np.var(token_attn_map.detach().cpu().numpy().flatten())
        if token not in best_var or var > best_var[token]:
            best_var[token] = var

    # sorted() is stable, so ties keep first-seen order (earlier token wins).
    ranked = sorted(best_var.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:k])


def binary_mask_gen(prompt_index, num, attn_map, tokenizer, prompt, threshold, save_mask):
    """Binarize the attention map of the prompt's highest-variance noun token.

    Walks the tokenized prompt (``prompt2tokens``) alongside ``attn_map``. BOS
    is skipped, iteration stops at EOS, and non-nouns are ignored (checked
    after stripping the ``'</w>'`` word-end marker). Among the remaining
    tokens, the one whose map has the largest ``torch.var`` (unbiased)
    variance is selected. Its map is min-max normalised to ``[0, 255]`` and
    thresholded.

    Args:
        prompt_index: Unused; kept for call compatibility.
        num: Ignored. The previous implementation tracked the ``num``
            highest-variance nouns but only ever used the top one, so for
            ``num >= 1`` the result never depended on it.
        attn_map: Sequence of per-token attention tensors aligned with the
            tokenizer output positions.
        tokenizer: Tokenizer accepted by ``prompt2tokens``; must expose
            ``bos_token`` and ``eos_token``.
        prompt: Text prompt.
        threshold: Cut-off applied to the 0-255 normalised map.
        save_mask: Unused; kept for call compatibility.

    Returns:
        Integer tensor shaped like the selected token's map: 255 where the
        normalised attention exceeds ``threshold``, 0 elsewhere.

    Raises:
        ValueError: if no noun token appears before EOS.
    """
    tokens = prompt2tokens(tokenizer, prompt)
    bos_token = tokenizer.bos_token
    eos_token = tokenizer.eos_token

    # Track only the arg-max: the first occurrence with the strictly largest
    # variance wins. The old dict keyed by token let a repeated token
    # overwrite the map of its earlier, higher-variance occurrence.
    best_var = None
    best_map = None
    for token, token_attn_map in zip(tokens, attn_map):
        if token == bos_token:
            continue
        if token == eos_token:
            break
        if not is_noun(token.replace('</w>', '')):
            continue
        var = torch.var(token_attn_map.flatten())
        if best_var is None or var > best_var:
            best_var = var
            best_map = token_attn_map

    if best_map is None:
        raise ValueError("binary_mask_gen: no noun token found in prompt %r" % (prompt,))

    # Min-max normalise to [0, 255]. A constant map has zero range, so map it
    # to all zeros instead of producing NaNs from 0/0.
    lo = torch.min(best_map)
    span = torch.max(best_map) - lo
    if span > 0:
        normalized = (best_map - lo) / span * 255
    else:
        normalized = torch.zeros_like(best_map)
    return torch.where(normalized > threshold, 255, 0)


def optimize_binary_mask(binary_mask, min_area=1000, device=None):
    """Remove small connected components from a binary mask.

    Args:
        binary_mask: 2-D tensor whose non-zero pixels are foreground (e.g. the
            0/255 output of ``binary_mask_gen``). Values are cast to uint8
            before labelling.
        min_area: Components (4-connected) with fewer pixels than this are
            dropped. The default 1000 is the previously hard-coded value.
        device: Device for the returned tensor. ``None`` (default) returns it
            on ``binary_mask``'s own device, instead of always moving it to
            ``'cuda'``, which failed on CPU-only machines.

    Returns:
        uint8 tensor of the same shape: 255 on kept components, 0 elsewhere.
    """
    if device is None:
        device = binary_mask.device
    mask_np = binary_mask.detach().to('cpu').numpy().astype(np.uint8)
    # stats[:, CC_STAT_AREA] gives every component's pixel count in one pass,
    # instead of re-scanning the whole label image once per label.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask_np, connectivity=4)
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background
    output_mask = np.where(keep[labels], 255, 0).astype(np.uint8)
    return torch.from_numpy(output_mask).to(device)

##########################################################################################
from typing import Dict, Optional, Tuple, Union


def dilate_mask(
    mask: Union[torch.Tensor, np.ndarray],
    dilation: Union[int, Tuple[int, int]],
) -> Union[torch.Tensor, np.ndarray]:
    """Dilate a boolean mask along its last two (height, width) axes.

    A pixel of the result is set if the ORIGINAL mask is set within
    ``dilation[0]`` pixels along the height axis, or within ``dilation[1]``
    pixels along the width axis. The neighbourhood is therefore cross-shaped;
    the two axes are not compounded into a rectangle.

    Args:
        mask: ``[H, W]`` or ``[C, H, W]`` torch tensor or numpy array whose
            dtype supports in-place ``|=`` (bool or integer). For 3-D input,
            each channel is dilated independently.
        dilation: Radius along (height, width); an int applies to both axes.
            A non-positive radius leaves that axis unchanged.

    Returns:
        A new mask of the same type and shape; ``mask`` itself is not
        modified. If both radii are non-positive, ``mask`` is returned as-is
        (not a copy).

    Raises:
        NotImplementedError: if ``mask`` is neither 2-D nor 3-D.
    """
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    if dilation[0] <= 0 and dilation[1] <= 0:
        return mask

    if isinstance(mask, torch.Tensor):
        ret = mask.clone()
    else:
        assert isinstance(mask, np.ndarray)
        ret = mask.copy()

    if ret.ndim not in (2, 3):
        # `.ndim` exists on both torch tensors and numpy arrays; the old
        # `mask.dim()` raised AttributeError for numpy input.
        raise NotImplementedError("Unknown mask dimension [%d]!!!" % mask.ndim)

    # Ellipsis indexing addresses the trailing (H, W) axes for 2-D and 3-D
    # alike. The old 3-D branch used `range(1, dilation + 1)` on a tuple and
    # always raised TypeError.
    rows, cols = dilation
    for i in range(1, rows + 1):
        ret[..., :-i, :] |= mask[..., i:, :]
        ret[..., i:, :] |= mask[..., :-i, :]
    for i in range(1, cols + 1):
        ret[..., :-i] |= mask[..., i:]
        ret[..., i:] |= mask[..., :-i]
    return ret





def downsample_mask(
    mask: torch.Tensor,
    min_res: Union[int, Tuple[int, int]] = 4,
    dilation: Union[int, Tuple[int, int]] = 1,
    threshold: float = 0.3,
    eps: float = 1e-3,
) -> Dict[Tuple[int, int], torch.Tensor]:
    """Build a pyramid of thresholded, dilated boolean masks.

    Starting at the input resolution, the mask (cast to float) is repeatedly
    halved with bilinear interpolation. Each level is binarised with
    ``t = min(threshold, level.max() - eps)``, so pixels within ``eps`` of
    that level's maximum are always kept. An all-zero level therefore selects
    every pixel. The binarised level is then dilated with ``dilate_mask``.

    Args:
        mask: 2-D ``[H, W]`` tensor.
        min_res: Stop once BOTH height and width fall below this; an int
            applies to both axes. The pyramid also stops if either axis would
            shrink to 0, which previously made ``F.interpolate`` raise for
            elongated masks.
        dilation: Passed to ``dilate_mask`` at every level.
        threshold: Upper bound on the binarisation threshold.
        eps: Margin below each level's maximum used for the threshold.

    Returns:
        Dict mapping ``(h, w)`` to the boolean mask at that resolution,
        starting with ``(H, W)``.
    """
    assert mask.dim() == 2
    H, W = mask.shape
    if isinstance(min_res, int):
        min_h = min_w = min_res
    else:
        min_h, min_w = min_res
    h, w = H, W

    masks = {}
    # reshape (not view) so non-contiguous inputs, e.g. transposed masks, work.
    interpolated_mask = mask.reshape(1, 1, H, W).float()
    while True:
        t = min(threshold, interpolated_mask.max() - eps)
        sparsity_mask = interpolated_mask[0, 0] > t
        masks[(h, w)] = dilate_mask(sparsity_mask, dilation)
        h //= 2
        w //= 2
        # Also stop before an axis collapses to 0. Interpolating to a zero size
        # raises, which happened for strongly non-square masks.
        if (h < min_h and w < min_w) or h == 0 or w == 0:
            break
        interpolated_mask = F.interpolate(interpolated_mask, (h, w), mode="bilinear", align_corners=False)
    return masks
