# Copyright (c) OpenMMLab. All rights reserved.
import random

import torch
import inflect
from mmdet.registry import DATASETS
from utils.data.nuimage import NuImageDataset
import numpy as np
import cv2
import json
from collections import defaultdict
from mmdet.visualization import DetLocalVisualizer
from mmdet.visualization.palette import _get_adaptive_scales
import os
from compel import Compel
from utils.plot_utils import create_polygon_mask
from PIL import Image, ImageDraw

# DATASETS = Registry('datasets')

@DATASETS.register_module()
class COCODataset(NuImageDataset):
    """COCO variant of ``NuImageDataset`` with captions, masks and prompts.

    On top of the parent dataset this class loads the COCO caption
    annotations in ``__init__`` and, in ``__getitem__``, returns per-sample
    segmentation masks, a per-class box mask, a text prompt and a caption.
    """

    # COCO category names (80 entries). The integer values found in
    # ``gt_bboxes_labels`` are used directly as indices into this list.
    CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
         'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
         'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
         'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
         'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
         'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
         'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
         'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
         'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
         'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
         'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
         'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
         'scissors', 'teddy bear', 'hair drier', 'toothbrush']
    # Shared inflect engine; ``__getitem__`` uses it to spell object counts
    # as words ("two", "three", ...) in the "v1" prompt.
    p_eng = inflect.engine()
    def __init__(self, *, caption_root="data/coco/annotations", **kwargs):
        """Initialise the parent dataset and load the COCO captions.

        Args:
            caption_root: Directory containing ``captions_train2017.json``
                and ``captions_val2017.json``. The default is the location
                this class has always read from.
            **kwargs: Forwarded unchanged to ``NuImageDataset.__init__``.

        Raises:
            OSError: If the caption annotation file cannot be opened.
        """
        super().__init__(**kwargs)

        # NOTE(review): not read anywhere in this class; presumably consumed
        # by the parent's filtering logic -- confirm.
        self.filter_bbox_min_size = 32

        # Relative frequency of each category, measured as the number of
        # images listed for it in ``cat_img_map`` over the sum of those
        # numbers across all categories.
        category_counts = {
            category_id: len(self.cat_img_map[category_id])
            for category_id in self.cat_img_map
        }
        total_count = sum(category_counts.values())
        # Guard against a dataset whose categories all have zero images.
        self.category_prob = {
            cat: (cnt / total_count if total_count else 0.0)
            for cat, cnt in category_counts.items()
        }
        self.visualizer = None  # DetLocalVisualizer()

        # Text-encoder state; populated later (see ``build_prompts``).
        self.background_condition = None
        self.category_conditions = None
        self.n_prefix = None
        self.compel = None

        caption_file = 'captions_train2017.json' if self.split == "train" else 'captions_val2017.json'
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(os.path.join(caption_root, caption_file), 'r', encoding='utf-8') as f:
            self.captions_data = json.load(f)

        # image id -> list of caption strings, in annotation order.
        self.image_id_to_captions = defaultdict(list)
        for annotation in self.captions_data['annotations']:
            self.image_id_to_captions[annotation['image_id']].append(annotation['caption'])

    def get_captions_by_image_id(self, image_id):
        """Return the caption list stored for ``image_id``, or ``[]`` if none."""
        captions_by_id = self.image_id_to_captions
        # Membership test (rather than indexing) so that a miss does not
        # insert an empty entry into the defaultdict.
        if image_id in captions_by_id:
            return captions_by_id[image_id]
        return []

    def get_captions_by_image_filename(self, image_filename, image_id=None):
        """Return the captions of the image named ``image_filename``.

        Args:
            image_filename: Compared for equality against the ``file_name``
                field of the entries in ``self.captions_data['images']``.
            image_id: Fallback image id used when no entry has that file name.

        Returns:
            The caption list of the resolved image id, or ``[]`` when the
            file name is unknown and no fallback ``image_id`` was given.
        """
        # Build the file_name -> id index once; scanning the full image list
        # on every call is linear in the size of the annotation file.
        index = getattr(self, '_file_name_to_image_id', None)
        if index is None:
            index = {}
            for image in self.captions_data['images']:
                # First entry wins if a file name appears more than once.
                index.setdefault(image['file_name'], image['id'])
            self._file_name_to_image_id = index

        image_id = index.get(image_filename, image_id)
        if image_id is None:
            return []
        return self.get_captions_by_image_id(image_id)

    def build_prompts(self, pipe):
        """Pre-compute text conditionings for every class and the background.

        Args:
            pipe: Object exposing ``tokenizer`` and ``text_encoder``
                attributes; both are handed to ``Compel``.

        Side effects:
            Sets ``self.compel``, ``self.category_conditions`` (one entry per
            name in ``CLASSES``, same order) and ``self.background_condition``.
        """
        compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
        # Keep the encoder around: ``build_text_embedding`` calls
        # ``self.compel.build_conditioning_tensor`` for a per-sample global
        # description, and ``__init__`` leaves ``self.compel`` as None.
        self.compel = compel

        # NOTE(review): ``build_text_embedding`` / ``build_text_embedding_v1``
        # read ``self.global_condition`` and ``self.conditionings``, which
        # are not assigned here -- confirm they are provided elsewhere.
        self.category_conditions = [
            compel.build_conditioning_tensor(class_name)
            for class_name in self.CLASSES
        ]
        self.background_condition = compel.build_conditioning_tensor(
            "Generate a clean background."
        )


    def build_text_embedding(self, bbox_labels, spatial_masks, global_description=None):
        """Build the conditioning list and a per-location label map.

        Args:
            bbox_labels: Labels present in the sample; duplicates are dropped.
                Each label is used as an index into the last axis of
                ``spatial_masks`` (assumes integer class ids -- TODO confirm).
            spatial_masks: Indexed as ``spatial_masks[:, :, label]``; locations
                equal to 1 are treated as belonging to that label.
            global_description: Optional text; when truthy it is encoded with
                ``self.compel``, otherwise ``self.global_condition`` is used.

        Returns:
            ``(condition, attn_mask)`` where ``condition`` is the global
            conditioning followed by ``self.conditionings``, and ``attn_mask``
            is a float32 tensor of shape ``(FEAT_SIZE[0], FEAT_SIZE[1])``
            holding ``label + 1`` inside a label's region and 0 elsewhere
            (where regions overlap, the label visited last wins).
        """
        unique_labels = list(set(bbox_labels))

        global_condition = (
            self.compel.build_conditioning_tensor(global_description)
            if global_description
            else self.global_condition
        )
        condition = [global_condition] + self.conditionings

        feat_h, feat_w = self.FEAT_SIZE[0], self.FEAT_SIZE[1]
        attn_mask = torch.zeros((feat_h, feat_w), dtype=torch.float32)
        for label in unique_labels:
            # Index tensors of every location flagged for this label.
            hits = torch.where(spatial_masks[:, :, label] == 1)
            attn_mask[hits[0], hits[1]] = label + 1
        return condition, attn_mask
    
    def build_text_embedding_v1(self, bbox_labels, spatial_masks):
        """Write class conditionings into slots of the global conditioning.

        Starting at slot ``self.n_prefix``, one slot per distinct label is
        overwritten in a clone of ``self.global_condition`` with
        ``self.conditionings[label]``.

        Args:
            bbox_labels: Labels present in the sample; duplicates are dropped.
            spatial_masks: Indexed as ``spatial_masks[:, :, label]``; locations
                equal to 1 are treated as belonging to that label.

        Returns:
            ``(condition, attn_mask)``. ``attn_mask`` is float32 with shape
            ``(FEAT_SIZE[0], FEAT_SIZE[1], 77)``: every slot outside the class
            range is 1 everywhere, and each class slot is 1 only inside that
            class's region.

        Raises:
            NotImplementedError: If the class slots would run past slot 77.
        """
        # NOTE(review): 77 is presumably the text encoder's sequence length
        # -- confirm before changing encoders.
        seq_len = 77

        condition = self.global_condition.clone()

        unique_labels = list(set(bbox_labels))
        first_slot = self.n_prefix
        stop_slot = first_slot + len(unique_labels)
        if stop_slot > seq_len:
            raise NotImplementedError

        feat_h, feat_w = self.FEAT_SIZE[0], self.FEAT_SIZE[1]
        attn_mask = torch.zeros((feat_h, feat_w, seq_len), dtype=torch.float32)
        # Slots before and after the class range are visible everywhere.
        attn_mask[:, :, :first_slot] = 1.
        attn_mask[:, :, stop_slot:] = 1.

        for offset, label in enumerate(unique_labels):
            slot = first_slot + offset
            condition[:, slot] = self.conditionings[label]

            hits = torch.where(spatial_masks[:, :, label] == 1)
            attn_mask[hits[0], hits[1], slot] = 1.

        return condition, attn_mask

    def __getitem__(self, idx):
        """Assemble one example: image, masks, boxes, prompt and caption.

        If ``prepare_data`` returns ``None`` for ``idx``, another index is
        drawn with ``_rand_another`` until a sample is obtained.

        Returns:
            dict with keys ``pixel_values``, ``pil_image`` (de-normalised
            uint8 array), ``text`` (prompt), ``caption`` (first caption of the
            image, ``''`` if it has none), ``seg_mask`` (binary mask of all
            instances at 1/8 of the original size), ``instance_seg_mask``
            (one binary mask per instance at ``img_shape`` size),
            ``bbox_mask`` (``FEAT_SIZE[0] x FEAT_SIZE[1] x len(CLASSES)``),
            ``bboxes`` (divided by the padded image size), ``labels``,
            ``img_path``, ``height`` and ``width``.

        Raises:
            ValueError: If ``self.prompt_version`` is neither ``"v1"`` nor
                ``"v2"``.
        """
        data = self.prepare_data(idx)
        while data is None:
            idx = self._rand_another(idx)
            data = self.prepare_data(idx)

        label_ids = data['gt_bboxes_labels']
        bboxes = data['gt_bboxes']
        labels = [self.CLASSES[label_id].replace('-', ' ') for label_id in label_ids]

        # Undo the input normalisation to recover a displayable image.
        mean, std = data['img_norm_cfg']['mean'], data['img_norm_cfg']['std']
        pil_image = np.clip((data['img'] * std) + mean, 0, 255).astype(np.uint8)

        # pad_shape is the actual input shape even without padding.
        # Assumes pad_shape is (height, width, ...) and boxes are
        # (x1, y1, x2, y2) -- TODO confirm. The division is in place, so
        # ``data['gt_bboxes']`` is normalised as well.
        pad_shape = data['pad_shape']
        img_shape = torch.tensor([pad_shape[1], pad_shape[0], pad_shape[1], pad_shape[0]])
        bboxes /= img_shape

        # --- Rasterise instance shapes at the original image size ---------
        full_size = (data['width'], data['height'])
        resized_size = (data['img_shape'][1], data['img_shape'][0])
        total_mask = Image.new('RGB', full_size, (0, 0, 0))
        total_draw = ImageDraw.Draw(total_mask, 'RGB')

        instance_masks = []
        for instance in data['instances']:
            shape = instance['mask']
            mask_img = Image.new('L', full_size, 0)
            mask_draw = ImageDraw.Draw(mask_img, 'L')
            # Channels start at 1: the masks are thresholded with ``> 0``
            # below, so a zero channel would erase the object from it.
            obj_color = tuple(random.randint(1, 255) for _ in range(3))

            if isinstance(shape, list):
                # A list is treated as a collection of polygons.
                for polygon in shape:
                    mask_draw.polygon(polygon, fill=255)
                    total_draw.polygon(polygon, fill=obj_color)
            else:
                # Any other mask encoding falls back to the instance's box.
                mask_draw.rectangle(instance['bbox'], fill=255)
                total_draw.rectangle(instance['bbox'], fill=obj_color)

            instance_masks.append(mask_img.resize(resized_size))

        total_mask = total_mask.resize((data['width'] // 8, data['height'] // 8))

        # Binarise: any non-zero pixel counts as foreground.
        total_mask = (np.array(total_mask, dtype=np.float32) > 0).astype(np.float32)
        instance_masks = [(np.array(m, dtype=np.float32) > 0).astype(np.float32) for m in instance_masks]

        # --- Per-class box mask on the feature grid -----------------------
        feat_h, feat_w = self.FEAT_SIZE[0], self.FEAT_SIZE[1]
        feat_scale = torch.tensor([feat_w, feat_h, feat_w, feat_h])
        bbox_mask = torch.zeros(feat_h, feat_w, len(self.CLASSES)).float()
        bbox_labels = []
        for bbox, label_id in zip(bboxes, label_ids):
            x1, y1, x2, y2 = torch.round(bbox * feat_scale).int().tolist()
            bbox_mask[y1: y2, x1: x2, label_id] = 1.
            bbox_labels.append(self.CLASSES[label_id])

        # Object counts per class name, in order of first appearance so the
        # prompt does not depend on set iteration order.
        objs = {}
        for label in labels:
            objs[label] = objs.get(label, 0) + 1

        if self.prompt_version == "v1":
            counted = ''.join(
                f"{self.p_eng.number_to_words(count)} {name}, "
                for name, count in objs.items()
            )
            text = ('An image with ' + counted).rstrip().rstrip(',')
        elif self.prompt_version == "v2":
            text = 'Generate a clean background.'
        else:
            raise ValueError(f"Unsupported prompt_version: {self.prompt_version!r}")

        captions = self.get_captions_by_image_id(data['img_id'])
        if isinstance(captions, list):
            # Images without caption annotations yield an empty list.
            captions = captions[0] if captions else ''

        example = {}
        example["pixel_values"] = data['img'].data
        example["pil_image"] = pil_image
        example["text"] = text
        example["caption"] = captions
        example['seg_mask'] = total_mask
        example['instance_seg_mask'] = instance_masks
        example["bbox_mask"] = bbox_mask
        example["bboxes"] = bboxes
        example["labels"] = bbox_labels
        example["img_path"] = data["img_path"]
        example["height"] = data["height"]
        example["width"] = data["width"]
        return example
        