import os
import math
import numpy as np
import torch
import torch.nn.functional as F
import cv2
from PIL import Image
import pickle
import nltk
from nltk.tokenize import word_tokenize
from nltk import pos_tag
# Fetch the NLTK part-of-speech tagger model that `is_noun` relies on (via nltk.pos_tag).
# This runs once at import time; nltk.download may need network access the first time.
# NOTE(review): recent NLTK releases look up 'averaged_perceptron_tagger_eng' instead of
# this resource name -- confirm which one the installed NLTK version requires.
nltk.download('averaged_perceptron_tagger')


def preprocess(max_height=256, max_width=256, attn_masks=None):
    """Average a collection of attention maps into one per-token attention map.

    Each element of ``attn_masks`` is averaged over its first axis and a leading
    singleton axis is squeezed away. The result must be 3-D ``(tokens, h, w)``.
    Every map is then bilinearly resized to one common size and the resized
    maps are averaged. The common size is the largest of ``max_height`` /
    ``max_width`` and every map's own ``h`` / ``w``.

    Args:
        max_height: Lower bound on the output height.
        max_width: Lower bound on the output width.
        attn_masks: Iterable of tensors. Presumably one per attention layer, with
            the first axis being attention heads -- confirm against the caller.

    Returns:
        Tensor of shape ``(tokens, H, W)``.

    Raises:
        ValueError: If ``attn_masks`` is None or empty.
    """
    if attn_masks is None:
        raise ValueError("preprocess() requires attn_masks")

    # Reduce every map first so the common target size is known before any
    # resizing. The previous version grew the target inside the loop, so maps
    # seen earlier were resized to a smaller size and torch.stack failed.
    per_token_maps = [torch.mean(v, dim=0).squeeze(0) for v in attn_masks]
    if not per_token_maps:
        raise ValueError("preprocess() requires at least one attention map")

    target_h, target_w = max_height, max_width
    for m in per_token_maps:
        _, h, w = m.shape  # also enforces the expected 3-D (tokens, h, w) layout
        target_h = max(target_h, h)
        target_w = max(target_w, w)

    resized = [
        F.interpolate(
            m.unsqueeze(0),  # F.interpolate needs a batch axis: (1, tokens, h, w)
            size=(target_h, target_w),
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
        for m in per_token_maps
    ]
    return torch.stack(resized, dim=0).mean(dim=0)



def prompt2tokens(tokenizer, prompt):
    """Tokenize ``prompt`` and map each input id back to its vocabulary string.

    The prompt is padded and truncated to ``tokenizer.model_max_length``, so the
    result has one entry per position, including any special/padding ids. Ids
    are looked up in ``tokenizer.decoder``, an id -> token-string mapping.
    This presumably comes from a CLIP BPE tokenizer -- confirm against the caller.
    """
    encoding = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    id_to_token = tokenizer.decoder
    first_sequence = encoding.input_ids[0]
    return [id_to_token[token_id.item()] for token_id in first_sequence]

def is_noun(token):
    """Return True when NLTK tags ``token`` as a noun.

    The token is tagged on its own, without sentence context. Any tag beginning
    with 'N' counts as a noun.
    """
    _word, tag = nltk.pos_tag([token])[0]
    return tag[:1] == 'N'


def top_var_tokens(k, attn_map, tokenizer, prompt):
    """Return the ``k`` noun tokens of ``prompt`` whose attention maps vary most.

    Tokens from ``prompt2tokens`` are paired position-by-position with the
    leading axis of ``attn_map``. Processing rules:

    - The begin token is skipped.
    - Iteration stops at the first end token.
    - The BPE end-of-word marker ``</w>`` is stripped.
    - Only words that ``is_noun`` accepts are scored.

    A word's score is the variance of its flattened map. If a word occurs more
    than once, its highest variance is kept.

    Returns:
        dict mapping word -> variance (numpy scalar), ordered from highest to
        lowest variance, with at most ``max(k, 0)`` entries.
    """
    tokens = prompt2tokens(tokenizer, prompt)
    bos_token = tokenizer.bos_token
    eos_token = tokenizer.eos_token

    best_var = {}
    for token, token_attn_map in zip(tokens, attn_map):
        if token == bos_token:
            continue
        if token == eos_token:
            break
        word = token.replace('</w>', '')
        if not is_noun(word):
            continue
        var = np.var(token_attn_map.cpu().numpy().flatten())
        # One entry per word. The former bounded-dict bookkeeping let a repeated
        # word overwrite its own score or evict another word, shrinking the result.
        if word not in best_var or var > best_var[word]:
            best_var[word] = var

    # sorted() is stable, so among equal variances the earliest word wins.
    ranked = sorted(best_var.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:max(k, 0)])


def _minmax_to_255(token_attn_map):
    """Linearly rescale a map so its minimum becomes 0 and its maximum 255."""
    lo = torch.min(token_attn_map)
    span = torch.max(token_attn_map) - lo
    if span == 0:
        # Constant map: (x - lo) / 0 is NaN everywhere, and NaN -> uint8 is
        # undefined. The numerator is all zeros, so dividing by 1 yields zeros.
        span = torch.ones_like(span)
    return (token_attn_map - lo) / span * 255


def _keep_top_k(scores, maps, token, score, token_map, num):
    """Insert (token, score, token_map) into the running top-``num`` dicts in place."""
    if num <= 0:
        return
    if token in scores:
        # Repeated word: keep its best score. Blindly replacing could lower it, and
        # evicting a different word for it would shrink the set below ``num``.
        if score > scores[token]:
            scores[token] = score
            maps[token] = token_map
        return
    if len(scores) < num:
        scores[token] = score
        maps[token] = token_map
        return
    weakest = min(scores, key=scores.get)
    if score > scores[weakest]:
        del scores[weakest]
        del maps[weakest]
        scores[token] = score
        maps[token] = token_map


def _sorted_desc(scores):
    """Return ``scores`` as a new dict ordered from highest to lowest value."""
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))


def _save_map_with_mask(value, threshhold, path):
    """Save ``value`` stacked above its thresholded 0/255 mask as one RGB image.

    Returns the binary mask as a uint8 numpy array.
    """
    mask = torch.where(value > threshhold, 255, 0).cpu().numpy().astype(np.uint8)
    map_image = Image.fromarray(value.cpu().numpy().astype(np.uint8))
    mask_image = Image.fromarray(mask)
    grid = Image.new(
        'RGB',
        (max(map_image.width, mask_image.width), map_image.height + mask_image.height),
    )
    grid.paste(map_image, (0, 0))
    grid.paste(mask_image, (0, map_image.height))
    grid.save(path)
    return mask


def binary_mask_gen(prompt_index, num, attn_map, tokenizer, prompt, threshhold, save_mask):
    """Rank the prompt's noun tokens three ways and dump visualisations when the rankings differ.

    For every noun token (begin token skipped, stop at the end token, ``</w>``
    stripped), the paired map from ``attn_map`` is scored by its variance, its
    maximum and its sum. The top ``num`` words are kept for each score.

    If the three top-``num`` word sets are identical, nothing is written.
    Otherwise the following are written under
    ``test_data/paper_results/max_sum_var/`` (the directory is created if needed):

    - For every kept word, a PNG of its min-max normalised (0-255) map stacked
      above the map thresholded at ``threshhold``.
    - ``prompt_index`` appended once to ``index.pkl`` in that directory.

    Args:
        prompt_index: Identifier used in file names and recorded in index.pkl.
        num: How many words to keep per ranking.
        attn_map: Per-token maps, paired position-by-position with the prompt's
            tokens. Presumably the output of ``preprocess`` -- confirm.
        tokenizer: Tokenizer passed to ``prompt2tokens``; its bos/eos tokens
            delimit the prompt.
        prompt: Text prompt.
        threshhold: Threshold applied to the 0-255 normalised map.
        save_mask: Unused; kept for interface compatibility.

    Returns:
        ``(sorted_var_tokens, var_mask)``. ``sorted_var_tokens`` maps word ->
        variance tensor, highest first. ``var_mask`` is None when nothing was
        saved. Otherwise it is the uint8 mask of the LAST word in
        ``sorted_var_tokens``.
        NOTE(review): for num > 1 that is the lowest-variance kept word; confirm
        callers expect this rather than the top word's mask.
    """
    tokens = prompt2tokens(tokenizer, prompt)
    bos_token = tokenizer.bos_token
    eos_token = tokenizer.eos_token
    save_path = 'test_data/paper_results/max_sum_var/'

    top_k_var_tokens, var_token_maps = {}, {}
    top_k_max_tokens, max_token_maps = {}, {}
    top_k_sum_tokens, sum_token_maps = {}, {}

    for token, token_attn_map in zip(tokens, attn_map):
        if token == bos_token:
            continue
        if token == eos_token:
            break
        token = token.replace('</w>', '')
        if not is_noun(token):
            continue
        normalized = _minmax_to_255(token_attn_map)
        var = torch.var(token_attn_map.flatten())
        max_v = token_attn_map.max().item()
        sum_v = torch.sum(token_attn_map).item()
        _keep_top_k(top_k_var_tokens, var_token_maps, token, var, normalized, num)
        _keep_top_k(top_k_max_tokens, max_token_maps, token, max_v, normalized, num)
        _keep_top_k(top_k_sum_tokens, sum_token_maps, token, sum_v, normalized, num)

    sorted_var_tokens = _sorted_desc(top_k_var_tokens)
    sorted_max_tokens = _sorted_desc(top_k_max_tokens)
    sorted_sum_tokens = _sorted_desc(top_k_sum_tokens)

    var_mask = None
    # Key views compare as sets: identical selections mean there is nothing to save.
    if sorted_sum_tokens.keys() == sorted_max_tokens.keys() == sorted_var_tokens.keys():
        return sorted_var_tokens, var_mask

    os.makedirs(save_path, exist_ok=True)
    for key in sorted_var_tokens:
        var_mask = _save_map_with_mask(
            var_token_maps[key], threshhold,
            os.path.join(save_path, f'{prompt_index}_{key}_var.png'))
    for key in sorted_max_tokens:
        _save_map_with_mask(
            max_token_maps[key], threshhold,
            os.path.join(save_path, f'{prompt_index}_{key}_max.png'))
    for key in sorted_sum_tokens:
        _save_map_with_mask(
            sum_token_maps[key], threshhold,
            os.path.join(save_path, f'{prompt_index}_{key}_sum.png'))

    # Record this prompt once per call. This was previously indented inside the
    # sum loop, which appended prompt_index once per saved token.
    index_path = os.path.join(save_path, 'index.pkl')
    index_list = []
    if os.path.exists(index_path):
        with open(index_path, 'rb') as file:
            index_list = pickle.load(file)
    index_list.append(prompt_index)
    with open(index_path, 'wb') as file:
        pickle.dump(index_list, file)

    return sorted_var_tokens, var_mask


def optimize_binary_mask(binary_mask, min_area=1000, device='cuda'):
    """Remove small connected regions from a binary mask.

    The mask is moved to CPU and cast to uint8, and every non-zero pixel counts
    as foreground. OpenCV finds the 4-connected foreground components, and only
    components covering at least ``min_area`` pixels are kept.

    Args:
        binary_mask: torch tensor holding a single-channel 2-D mask.
        min_area: Minimum component size in pixels to keep. The default 1000 is
            the previously hard-coded value.
        device: Device for the returned tensor. The default ``'cuda'`` matches the
            previous behaviour; pass ``None`` to keep the input tensor's device.

    Returns:
        uint8 tensor of the same shape: 255 on kept components, 0 elsewhere.
    """
    target_device = binary_mask.device if device is None else device
    mask_np = binary_mask.to('cpu').numpy().astype(np.uint8)
    # One labelling pass plus per-label areas, instead of scanning the whole image
    # once per label.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask_np, connectivity=4)
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background
    output_mask = keep[labels].astype(np.uint8) * 255
    return torch.from_numpy(output_mask).to(target_device)

##########################################################################################
from typing import Dict, Optional, Tuple, Union


def dilate_mask(
    mask: Union[torch.Tensor, np.ndarray],  # [C, H, W] or [H, W]
    dilation: Union[int, Tuple[int, int]],
) -> Union[torch.Tensor, np.ndarray]:
    """Grow a mask along its last two (H, W) axes.

    Each set pixel spreads up to ``dilation[0]`` pixels vertically and
    ``dilation[1]`` pixels horizontally. Both sets of shifts are taken from the
    original ``mask``, so the footprint is a cross (plus sign), not a filled
    rectangle. An int ``dilation`` applies to both axes. The mask's dtype must
    support in-place ``|=`` (bool or integer).

    Returns:
        ``mask`` itself when both dilations are <= 0, otherwise a dilated copy.

    Raises:
        NotImplementedError: If the mask is neither 2-D nor 3-D.
    """
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    dil_h, dil_w = dilation

    if dil_h <= 0 and dil_w <= 0:
        return mask

    if isinstance(mask, torch.Tensor):
        ret = mask.clone()
    else:
        assert isinstance(mask, np.ndarray)
        ret = mask.copy()

    ndim = len(ret.shape)
    if ndim not in (2, 3):
        # numpy arrays have no .dim(); the old message crashed with AttributeError.
        raise NotImplementedError("Unknown mask dimension [%d]!!!" % ndim)

    # Ellipsis addressing handles [H, W] and [C, H, W] alike. The old 3-D branch
    # also did range(1, dilation + 1) on a tuple and raised TypeError.
    for i in range(1, dil_h + 1):
        ret[..., :-i, :] |= mask[..., i:, :]
        ret[..., i:, :] |= mask[..., :-i, :]
    for i in range(1, dil_w + 1):
        ret[..., :-i] |= mask[..., i:]
        ret[..., i:] |= mask[..., :-i]
    return ret





def downsample_mask(
    mask: torch.Tensor,
    min_res: Union[int, Tuple[int, int]] = 4,
    dilation: Union[int, Tuple[int, int]] = 1,
    threshold: float = 0.3,
    eps: float = 1e-3,
) -> Dict[Tuple[int, int], torch.Tensor]:
    """Build thresholded, dilated boolean masks at successively halved resolutions.

    Processing starts at the mask's own ``(H, W)``. Each level is:

    1. thresholded at ``min(threshold, level.max() - eps)``, so the level's
       maximum always survives (for eps > 0);
    2. dilated with ``dilate_mask``;
    3. stored under its ``(h, w)`` key.

    The next level is a bilinear resize of the current float level to
    ``(h // 2, w // 2)``. Halving stops once both sides are below ``min_res``,
    or as soon as either side reaches 0. A zero-sized resize is invalid and
    previously crashed for strongly non-square masks.

    Args:
        mask: 2-D tensor (cast to float internally).
        min_res: Int or ``(min_h, min_w)`` lower resolution bound.
        dilation: Passed to ``dilate_mask`` at every level.
        threshold: Upper bound on the binarisation threshold.
        eps: Margin below each level's maximum used when that maximum is
            under ``threshold``.

    Returns:
        dict mapping ``(h, w)`` -> boolean mask of that shape.
    """
    assert mask.dim() == 2
    H, W = mask.shape
    if isinstance(min_res, int):
        min_h = min_w = min_res
    else:
        min_h, min_w = min_res
    h, w = H, W

    masks = {}
    interpolated_mask = mask.view(1, 1, H, W).float()
    while True:
        t = min(threshold, interpolated_mask.max() - eps)
        sparsity_mask = interpolated_mask[0, 0] > t
        masks[(h, w)] = dilate_mask(sparsity_mask, dilation)
        h //= 2
        w //= 2
        if (h < min_h and w < min_w) or h == 0 or w == 0:
            break
        interpolated_mask = F.interpolate(
            interpolated_mask, (h, w), mode="bilinear", align_corners=False
        )
    return masks
