# -*- coding: utf-8 -*-
"""detr_demo.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_demo.ipynb

# Object Detection with DETR - a minimal implementation

In this notebook we show a demo of DETR (Detection Transformer), with slight differences from the baseline model in the paper.

We show how to define the model, load pretrained weights and visualize bounding box and class predictions.

Let's start with some common imports.
"""

# Commented out IPython magic to ensure Python compatibility.
from PIL import Image
import matplotlib.pyplot as plt
# %config InlineBackend.figure_format = 'retina'
import os

import torch
from torch import nn
from torchvision.models import resnet50
import json
import torchvision.transforms as T
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import DetrForObjectDetection

# Inference only: turn autograd off globally so that no forward pass run by
# this module records gradients.
torch.set_grad_enabled(False)

"""## DETR
Here is a minimal implementation of DETR:
"""

class DETRdemo(nn.Module):
    """
    Compact demo implementation of DETR (Detection Transformer).

    Differences with respect to the DETR described in the paper:
    * the positional encodings are learned parameters (instead of sine)
    * the positional encodings are added once at the transformer input
      (instead of inside the attention layers)
    * the box head is a single linear layer (instead of an MLP)
    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.

    ``forward`` expands the positional encodings and the object queries
    along the batch dimension, so a batch of equally sized images can be
    processed in one call.
    """

    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()

        # ResNet-50 feature extractor; its classification layer is not used
        backbone = resnet50()
        del backbone.fc
        self.backbone = backbone

        # 1x1 convolution mapping the 2048 backbone channels to hidden_dim
        self.conv = nn.Conv2d(2048, hidden_dim, kernel_size=1)

        # stock PyTorch encoder/decoder transformer
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=nheads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
        )

        # prediction heads; the class head has one output more than
        # num_classes (baseline DETR uses a 3-layer MLP for the boxes)
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # learned object queries: 100 output slots per image
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # learned spatial encodings for up to 50 rows and 50 columns, each
        # contributing half of the hidden dimension
        half_dim = hidden_dim // 2
        self.row_embed = nn.Parameter(torch.rand(50, half_dim))
        self.col_embed = nn.Parameter(torch.rand(50, half_dim))

    def forward(self, inputs):
        """
        Run the detector on a batch of images.

        Returns a dict with the class logits ('pred_logits'), the sigmoid
        box outputs ('pred_boxes') and the decoder output they are computed
        from ('hidden'), each with the batch as first dimension.
        """
        # run the ResNet-50 stages in order, stopping before avg-pool
        feats = inputs
        for stage in ("conv1", "bn1", "relu", "maxpool",
                      "layer1", "layer2", "layer3", "layer4"):
            feats = getattr(self.backbone, stage)(feats)

        # reduce 2048 feature planes to hidden_dim for the transformer
        projected = self.conv(feats)
        batch_size = projected.shape[0]
        height, width = projected.shape[-2:]

        # positional encoding of every feature-map cell: the column
        # embedding concatenated with the row embedding
        col_part = self.col_embed[:width].unsqueeze(0).repeat(height, 1, 1)
        row_part = self.row_embed[:height].unsqueeze(1).repeat(1, width, 1)
        pos = torch.cat([col_part, row_part], dim=-1).flatten(0, 1).unsqueeze(1)

        # share the encodings and the queries across the batch
        pos = pos.expand(-1, batch_size, -1)
        queries = self.query_pos.unsqueeze(1).expand(-1, batch_size, -1)

        # the transformer takes sequence-first input; put batch first again
        # on the way out
        src = pos + 0.1 * projected.flatten(2).permute(2, 0, 1)
        hidden = self.transformer(src, queries).transpose(0, 1)

        return {
            'pred_logits': self.linear_class(hidden),
            'pred_boxes': self.linear_bbox(hidden).sigmoid(),
            'hidden': hidden,
        }


# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    Args:
        x: tensor whose last dimension has size 4 and holds
            ``(cx, cy, w, h)``. Any number of leading dimensions is
            accepted, e.g. ``[4]``, ``[N, 4]`` or ``[B, N, 4]``.

    Returns:
        Tensor of the same shape as ``x`` holding
        ``(x_min, y_min, x_max, y_max)`` in its last dimension.
    """
    # Split and re-stack along the last dimension so the conversion does not
    # depend on how many leading (batch) dimensions the caller uses.
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h]
    return torch.stack(corners, dim=-1)

def rescale_bboxes(out_bbox, shape):
    """Convert (cx, cy, w, h) boxes to corner form scaled by an image size.

    Args:
        out_bbox: tensor of boxes in ``(cx, cy, w, h)`` form. The values are
            multiplied by the image width/height, so they are presumably
            relative coordinates in [0, 1] -- confirm against the caller.
        shape: sequence whose last two entries are ``(height, width)``, for
            example the shape of the image tensor given to the model.

    Returns:
        Tensor of ``(x_min, y_min, x_max, y_max)`` boxes with x values
        multiplied by the width and y values by the height.
    """
    img_h, img_w = shape[-2:]

    b = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale factors on the same device as the boxes; a CPU-only
    # tensor here would make the multiplication fail for boxes on a GPU.
    scale = torch.tensor([img_w, img_h, img_w, img_h],
                         dtype=torch.float32, device=b.device)
    return b * scale


class ImageDataset(Dataset):
    """Dataset turning image files into fixed-size tensors for detection.

    Every item is a tuple ``(image_tensor, image_path, image)``:

    * ``image_tensor``: the image resized to 800x800, converted with
      ``ToTensor`` and normalized with the mean/std given below
    * ``image_path``: the entry of ``image_paths`` the item was loaded from
    * ``image``: the same resized tensor before normalization

    All images are resized to one size so that a DataLoader can stack them
    into a batch.

    Args:
        image_folder: directory that the entries of ``image_paths`` are
            relative to.
        image_paths: sequence of image file names inside ``image_folder``.
    """

    def __init__(self, image_folder, image_paths):
        self.image_folder = image_folder
        self.image_paths = image_paths

        # Kept as a public attribute for backward compatibility. It is NOT
        # used by __getitem__, which resizes through PIL and also has to
        # return the tensor as it was before normalization.
        self.transform = transforms.Compose([
            transforms.Resize((800,800)), #TODO:check if this is the right way to resize
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Built once here instead of once per item.
        self._to_tensor = transforms.ToTensor()
        self._normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        # Report what was loaded; an empty list has no first entry to show.
        if len(self.image_paths) > 0:
            print(f"image_paths: {len(self.image_paths)} {self.image_paths[0]}")
        else:
            print("image_paths: 0")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load item ``idx``.

        Raises:
            ValueError: if the image file does not exist in ``image_folder``.
        """
        image_path = self.image_paths[idx]
        full_path = os.path.join(self.image_folder, image_path)
        if not os.path.exists(full_path):
            raise ValueError(f"image_path: {image_path} does not exist")

        # convert('RGB') returns an independent 3-channel image, so the file
        # can be closed right away and no channel expansion is needed later.
        with Image.open(full_path) as raw:
            rgb = raw.convert('RGB')

        # resize to 800x800, then convert to a tensor
        image = self._to_tensor(rgb.resize((800,800)))
        # Normalize is not in-place here, so `image` stays un-normalized
        image_tensor = self._normalize(image)
        return image_tensor, image_path, image

# The confidence threshold used to be available only through a module-level
# name `th`, which exists only once the `__main__` section of this file has
# assigned it. `detect` still falls back to that name, but callers should
# prefer the explicit `threshold` argument.
def detect(img, model, threshold=None):
    """Run ``model`` on a batch and keep the confident predictions.

    Args:
        img: batch passed unchanged to ``model``. Its last two dimensions
            are used as (height, width) when scaling the boxes.
        model: callable returning a mapping that holds the class logits
            under ``'logits'`` or ``'pred_logits'`` and the boxes under
            ``'pred_boxes'``.
        threshold: minimum class probability a prediction needs to be kept.
            When ``None`` (default) the module-level ``th`` is used.

    Returns:
        ``(prob, bbox)``: two lists with one entry per batch element. Each
        ``prob`` entry holds the class probabilities of the kept predictions
        and each ``bbox`` entry the matching boxes as returned by
        ``rescale_bboxes``.

    Raises:
        ValueError: if the model output has neither ``'logits'`` nor
            ``'pred_logits'``.
        NameError: if ``threshold`` is not given and no module-level ``th``
            has been defined.
    """
    # demo model only support by default images with aspect ratio between 0.5 and 2
    # if you want to use images with an aspect ratio outside this range
    # rescale your image so that the maximum size is at most 1333 for best results

    # propagate through the model
    outputs = model(img)

    # the two supported models name their logits differently
    if 'logits' in outputs:
        logits = outputs['logits']
    elif 'pred_logits' in outputs:
        logits = outputs['pred_logits']
    else:
        raise ValueError("No logits or pred_logits in model output")

    if threshold is None:
        try:
            threshold = th  # legacy module-level threshold
        except NameError:
            raise NameError(
                "detect() needs a confidence threshold: pass `threshold=` "
                "or define the module-level `th`") from None

    # the last class column is dropped before thresholding; a prediction is
    # kept when its best remaining class reaches the threshold
    probas = logits.softmax(-1)[:, :, :-1]
    keep = probas.max(-1).values > threshold

    prob = []
    bbox = []
    # filter every batch element separately: the number of kept predictions
    # differs per image, so the results cannot be stacked
    for img_probas, img_boxes, img_keep in zip(probas, outputs['pred_boxes'], keep):
        prob.append(img_probas[img_keep])
        # scale the kept boxes to the size of the input batch
        bbox.append(rescale_bboxes(img_boxes[img_keep], img.shape))

    return prob, bbox

# COCO classes
# Class names looked up by the class index a model predicts (used in this
# file as `CLASSES[cl]`). 'N/A' marks indices that have no named class --
# presumably the gaps in the COCO category id numbering (TODO confirm).
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

def plot_results(pil_imgs, probs, boxes_ls, save_path, img_names):
    """Draw the predicted boxes on every image of a batch and save the plots.

    For each image of a 4-D batch, the image is shown with one rectangle and
    one "class: probability" label per prediction, and the figure is written
    to ``{save_path}/detected_{img_name}_th{th}.png``. ``th`` is read from
    the module-level name of the same name. Input that is not 4-D is ignored.

    Args:
        pil_imgs: batch of channels-first image tensors.
        probs: per-image class probabilities of the predictions.
        boxes_ls: per-image boxes as (xmin, ymin, xmax, ymax).
        save_path: directory the figures are written to.
        img_names: per-image names used in the output file names.
    """
    # colors for visualization, cycled over the predictions of an image
    palette = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
               [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

    # only batched input is handled
    if len(pil_imgs.shape) != 4:
        return

    for image, scores, boxes, name in zip(pil_imgs, probs, boxes_ls, img_names):
        print(f"pil_img shape: {image.shape}")
        print(f"prob shape: {scores.shape}")
        print(f"boxes shape: {boxes.shape}")

        plt.figure(figsize=(16,10))
        # imshow needs channels-last data
        image = image.permute(1,2,0)
        print(f"min: {image.min()}, max: {image.max()}")
        # map the value range [-1, 1] to [0, 1]
        # NOTE(review): ImageDataset in this file yields ToTensor output and
        # ImageNet-normalized tensors; neither looks like a [-1, 1] input --
        # confirm which images are meant to be passed here.
        image = (image + 1.0) / 2.0
        print(f"min: {image.min()}, max: {image.max()}")
        plt.imshow(image)

        axes = plt.gca()
        for score, (xmin, ymin, xmax, ymax), color in zip(scores, boxes.tolist(), palette * 100):
            rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                 fill=False, color=color, linewidth=3)
            axes.add_patch(rect)
            # label the box with its most probable class
            best = score.argmax()
            label = f'{CLASSES[best]}: {score[best]:0.2f}'
            axes.text(xmin, ymin, label, fontsize=15,
                      bbox=dict(facecolor='yellow', alpha=0.5))
        plt.axis('off')

        # save at save_path
        out_file = f"{save_path}/detected_{name}_th{th}.png"
        plt.savefig(out_file)
        print(f"Save image at {out_file}")
        plt.close()

def load_img_ls(path, file_ls, type):
    """Collect the unique image names referenced by a set of question files.

    Every ``{path}/{file}.json`` is parsed as JSON Lines (one JSON value per
    line, blank lines ignored); if that fails, the whole file is parsed as a
    single JSON document. The ``"image"`` value of every record that is a
    dict with such a key is collected.

    Args:
        path: directory containing the question files.
        file_ls: iterable of file names without the ``.json`` extension.
        type: not used by this function; kept so existing callers keep
            working.

    Returns:
        List of the distinct ``"image"`` values in the order they were first
        seen, so the result is the same on every run.

    Raises:
        OSError: if a file cannot be opened.
        json.JSONDecodeError: if a file is neither JSON Lines nor JSON.
    """
    image_ls = []
    for file in file_ls:
        coco_pope_file = f'{path}/{file}.json'
        with open(coco_pope_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        try:
            # JSON Lines: one record per non-blank line
            records = [json.loads(line) for line in lines if line.strip()]
        except json.JSONDecodeError:
            # e.g. a pretty-printed document spanning several lines
            records = json.loads("".join(lines))
        if not isinstance(records, list):
            records = [records]

        for record in records:
            # a JSON array written on a single line parses as one "line"
            # holding the whole list of records
            entries = record if isinstance(record, list) else [record]
            for q in entries:
                if isinstance(q, dict) and "image" in q:
                    image_ls.append(q["image"])

    # drop duplicates while keeping the first-seen order
    return list(dict.fromkeys(image_ls))

if "__main__" == __name__:
    # Script entry point: run the detector over a set of images and append,
    # per confidence threshold, one JSON line per image with the detected
    # class names and their scores.
    #
    # Example:
    #   python detr_demo_batch.py \
    #       --th 0.95 \
    #       --image_folder /data/linxi/workspace/POPE/data/val/val2014
    import argparse
    parser = argparse.ArgumentParser()
    # NOTE(review): --th is parsed but never read; the thresholds that are
    # evaluated are the fixed values of `th_ls` below.
    parser.add_argument("--th", type=float, default=0.95)
    parser.add_argument("--image_folder", type=str, default="/data/linxi/workspace/POPE/data/val/val2014")
    parser.add_argument("--question_path", type=str, default="/data/linxi/workspace/MME/eval_tool/MME_data")
    parser.add_argument("--type", type=str, default="color")
    # name (without .json) of the question file listing the images
    parser.add_argument("--file_ls", type=str, default="color")
    args = parser.parse_args()

    save_path = "/data/linxi/workspace/detr/plot_results"
    # make sure the output directory exists before any inference is run, so
    # a missing directory cannot fail the first write
    os.makedirs(save_path, exist_ok=True)

    # for MME every file of the image folder is used; otherwise the images
    # are the ones referenced by the question file
    if "MME" in args.image_folder:
        image_ls = os.listdir(args.image_folder)
    else:
        image_ls = load_img_ls(args.question_path, [args.file_ls], args.type)

    image_dataset = ImageDataset(args.image_folder, image_ls)
    dataloader = DataLoader(image_dataset, batch_size=4, shuffle=False)

    ## model
    ## alternative: the minimal DETRdemo defined in this file
    # detr = DETRdemo(num_classes=91)
    # state_dict = torch.hub.load_state_dict_from_url(
    #     url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
    #     map_location='cpu', check_hash=True)
    # detr.load_state_dict(state_dict)

    ## results saved in _th{th}, results are better than DETRdemo
    detr = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
    detr.eval()

    # pope use 0.5, others 0.95
    th_ls = [0.95, 0.5]
    # `th` has to stay a module-level name: detect() and plot_results() read
    # it as a global, so this loop must not be moved into a function.
    for th in th_ls:
        print(f"Eval th: {th}")
        out_file = f"{save_path}/scores_boxes_{args.type}_th{th}.json"
        for image_tensors, image_paths, images in dataloader:

            scores, boxes = detect(image_tensors, detr)

            # to visualize a batch:
            # plot_results(images, scores, boxes, save_path, image_paths)

            json_lines = []
            for image_path, score in zip(image_paths, scores):
                # best class of every kept prediction and its probability
                cl_ls = [p.argmax() for p in score]
                classes = [CLASSES[cl] for cl in cl_ls]
                scores_ls = [round(p[cl].item(), 4) for p, cl in zip(score, cl_ls)]
                results = {"classes": classes, "scores": scores_ls}
                json_lines.append(json.dumps({image_path: results}) + "\n")

            # append after every batch so finished batches are on disk even
            # if a later one fails
            with open(out_file, "a") as f:
                f.writelines(json_lines)
    
    # ## data
    # coco_path  = "./data/minival2017"
    # save_path = "../detr/plot_results"
    # image_files = [os.listdir(coco_path)[0]]

    # ## eval
    # for image_file in image_files:

    #     image_path = f"{coco_path}/{image_file}"
    #     im = Image.open(image_path)
        # img_name = "r"+image_file.split('/')[-1]

        # scores, boxes = detect(im, detr)

        # plot_results(im, scores, boxes, save_path, img_name)
