import os
import json
import torch
import collections
from PIL import Image
import numpy as np
import tqdm
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
import sys

# NOTE(review): hard-coded GPU index; change to match the host this runs on.
device = 'cuda:5'

# Category table used to turn the panoptic model's `label_id` into a class name.
# NOTE(review): assumes this JSON lists categories in the same order as the
# model's label ids — confirm against model.config.id2label.
with open("evaluation_classification_capability/coco_panoptic_categories.json") as f:
    label_info = json.load(f)
class_names = [category["name"] for category in label_info]

# Directory of frames to analyse (first CLI argument). Every file whose name
# does not contain "mask" is an input frame; its score map is "<stem>_mask.png".
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} IMAGE_DIR")
root = sys.argv[1]
fns = sorted(fn for fn in os.listdir(root) if "mask" not in fn)
scores = [os.path.splitext(fn)[0] + "_mask.png" for fn in fns]

# (scene file name, episode_id) -> int(success) for every evaluated episode,
# read from the evaluation log two directory levels above `root`.
with open(os.path.join(root, "../../logs/evaluation_meta.json"), "r") as f:
    evaluation_meta_json = json.load(f)
evaluation_meta = dict()
for it in evaluation_meta_json:
    scene_id = os.path.split(it['scene_id'])[-1]  # keep only the last path component
    episode_id = it['episode_id']
    success = int(it['metrics']['success'])
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if (scene_id, episode_id) in evaluation_meta:
        raise ValueError(
            f"duplicate episode {(scene_id, episode_id)!r} in evaluation_meta.json"
        )
    evaluation_meta[(scene_id, episode_id)] = success

# Panoptic segmentation processor + model from the
# "facebook/mask2former-swin-large-coco-panoptic" checkpoint.
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-coco-panoptic")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-coco-panoptic")
model.to(device)

class Stats:
    """Running sum/count accumulator used to average per-sample metrics."""

    def __init__(self):
        # Number of values folded in so far, and their running sum.
        self.count, self.total = 0, 0.0

    def update(self, values):
        """Fold a sized batch of numbers into the running count and sum."""
        self.count = self.count + len(values)
        self.total = self.total + np.sum(values)

    def mean(self):
        """Return the running average, or 0.0 before anything was added."""
        if self.count <= 0:
            return 0.0
        return self.total / self.count


# ---------------------------------------------------------------------------
# For each frame whose panoptic segmentation contains the frame's goal
# category, measure how much of every goal segment is covered by the frame's
# score map (<stem>_mask.png, scaled by 1/255).
# ---------------------------------------------------------------------------
goal_token = Stats()            # pixel count of each goal segment
preserved_goal_token = Stats()  # summed score inside each goal segment
preserved_ratio = Stats()       # summed score / pixel count, per goal segment
# Goal index (the digit after "_goal:" in a file name) -> category name.
goal_names = ["bed", "chair", "potted plant", "toilet", "tv", "couch"]
coco_goal_names = ["bed", "chair", "potted plant", "toilet", "tv", "couch"]  # NOTE(review): unused
for i, fn in enumerate(tqdm.tqdm(fns)):
    # Parse the goal before spending a forward pass on the frame.
    pos = fn.find("_goal:")
    if pos < 0:
        raise ValueError(f"file name {fn!r} has no '_goal:<index>' tag")
    goal = int(fn[pos + 6:pos + 7])  # single-digit goal index
    goal_name = goal_names[goal]

    image = Image.open(os.path.join(root, fn)).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # target_sizes takes (height, width); PIL's Image.size is (width, height).
    result = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    # Ids (values in result['segmentation']) of segments labelled as the goal.
    goal_segment_ids = [
        seg['id'] for seg in result['segments_info']
        if class_names[seg['label_id']] == goal_name
    ]
    if not goal_segment_ids:
        # No instance of the goal category in this frame: nothing to measure.
        continue

    # NOTE(review): assumes an 8-bit score map with the same height/width as
    # the frame, so it can be indexed by the segmentation mask — confirm.
    with Image.open(os.path.join(root, scores[i])) as score_image:
        score = np.array(score_image).astype(np.float32) / 255.0
    # Copy the segmentation to host memory once per frame, not once per segment.
    segmentation = result['segmentation'].cpu().numpy()

    for seg_id in goal_segment_ids:
        mask = segmentation == seg_id
        goal_token_num = mask.sum()
        goal_token.update([goal_token_num])

        preserved = score[mask].sum()
        preserved_goal_token.update([preserved])
        preserved_ratio.update([preserved / goal_token_num if goal_token_num > 0 else 0.0])

    # NOTE(review): stops once *more than* 200 ratios exist (one frame can add
    # several), whereas the per-outcome analysis further down caps at 200.
    if preserved_ratio.count > 200:
        break

print("#Tokens:", goal_token.total)
print("#Preserved:", preserved_goal_token.total)
print('Preserve Ratio:', preserved_ratio.mean())

# `exit` is a site-module helper meant for the interactive shell; sys.exit is
# always available. Everything after this line is never executed.
sys.exit(0)
# ---------------------------------------------------------------------------
# Per-outcome variant (unreachable while the script exits above). Splits the
# preserved ratio by episode outcome: preserve_ratios[0] collects episodes
# with success == 0, preserve_ratios[1] those with success == 1, at most 200
# ratios each.
# ---------------------------------------------------------------------------
preserve_ratios = [Stats() for _ in range(2)]
goal_names = ["bed", "chair", "potted plant", "toilet", "tv", "couch"]
coco_goal_names = ["bed", "chair", "potted plant", "toilet", "tv", "couch"]  # NOTE(review): unused
for i, fn in enumerate(tqdm.tqdm(fns)):
    # Stop as soon as both outcome buckets hold 200 ratios.
    if all(s.count >= 200 for s in preserve_ratios):
        break

    pos = fn.find("_goal:")
    if pos < 0:
        raise ValueError(f"file name {fn!r} has no '_goal:<index>' tag")
    goal = int(fn[pos + 6:pos + 7])  # single-digit goal index

    # NOTE(review): assumes names of the form "<scene>_<episode>_<x>_goal:<g>..."
    # with exactly three underscores; any other shape raises ValueError here.
    scene_id, episode_id, _, __ = fn.split('_')
    suc = evaluation_meta.get((scene_id, episode_id))
    if suc is None or preserve_ratios[suc].count >= 200:
        # Episode missing from the evaluation log, or its bucket is already full.
        continue

    image = Image.open(os.path.join(root, fn)).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # target_sizes takes (height, width); PIL's Image.size is (width, height).
    result = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    goal_name = goal_names[goal]
    goal_segment_ids = [
        seg['id'] for seg in result['segments_info']
        if class_names[seg['label_id']] == goal_name
    ]
    if not goal_segment_ids:
        continue

    # NOTE(review): assumes an 8-bit score map matching the frame's height/width.
    with Image.open(os.path.join(root, scores[i])) as score_image:
        score = np.array(score_image).astype(np.float32) / 255.0
    segmentation = result['segmentation'].cpu().numpy()
    for seg_id in goal_segment_ids:
        mask = segmentation == seg_id
        goal_token_num = mask.sum()
        preserved = score[mask].sum()
        preserve_ratios[suc].update([preserved / goal_token_num if goal_token_num > 0 else 0.0])

print('Preserve Ratio:', [p.mean() for p in preserve_ratios])