import os
import json
import argparse
import math
from networkx import all_simple_edge_paths, overall_reciprocity
import numpy as np
from sympy import total_degree
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image, ImageFile
from pycocotools import mask as mask_utils
from region_utils import show_image
from datasets import load_dataset
import cv2

import logging
# Module-level logger for this script; basicConfig sets the root logger level
# to INFO so the progress messages emitted below are not filtered out.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

from io import BytesIO
import base64
def load_image_from_base64(image):
    """Decode a base64-encoded payload and open it as a PIL image.

    Args:
        image: Base64-encoded image data, as accepted by ``base64.b64decode``.

    Returns:
        The object returned by ``Image.open`` on the decoded bytes.
    """
    raw_bytes = base64.b64decode(image)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)

def visualize(data: list | dict, title: str="Bincount Visualization", save_path: str="playground/visualize.png", 
              figtype: str="hist", bin_size: int=10, distribution_shift: str="linear"):
    if figtype == "bar":
        if isinstance(data, list):
            data = todict(data)
        aggregated_data = {}
        for value, count in data.items():
            if distribution_shift == "log":
                bin_key = round(math.log(value+1, 2))
            elif distribution_shift == "lg":
                bin_key = round(math.log(value, 10))
            elif distribution_shift == "linear":
                bin_key = round(value / bin_size) * bin_size
            elif distribution_shift == "sqrt":
                bin_key = (round(math.sqrt(value) / bin_size) * bin_size) ** 2
        aggregated_data[bin_key] += count

        values = list(aggregated_data.keys())
        counts = list(aggregated_data.values())

        # Create the bar chart
        plt.figure(figsize=(8, 6))
        plt.bar(values, counts, color='skyblue')
        plt.xlabel('Values')
        plt.ylabel('Counts')
        plt.title(title + f"({distribution_shift})")
        plt.xticks(values)  # To ensure each value has a label

        # Save as a PNG file
        plt.savefig(save_path, format='png')
        plt.close()  # Close the plot to free up memory
    elif figtype == "hist":
        if isinstance(data, dict):
            data = tolist(data)
        if distribution_shift == "log":
            data_expanded = np.log2(np.array(data)+1).tolist()
        elif distribution_shift == "lg":
            data_expanded = np.log10(np.array(data)).tolist()
        elif distribution_shift == "sqrt":
            data_expanded = np.sqrt(data).tolist()
        elif distribution_shift == "linear":
            data_expanded = data
        plt.figure(figsize=(8, 6))
        plt.hist(data_expanded, bins=bin_size, color='skyblue', edgecolor='black')
        plt.xlabel('Values')
        plt.ylabel('Counts')
        plt.title(title + f"({distribution_shift})")
        plt.savefig(save_path, format='png')
        plt.close()
    else:
        raise ValueError(f"Invalid figtype: {figtype}")

def log_stats(masks:list, stats:dict, detail:bool=False):
    """Fold the masks of one image into the running ``stats`` dict (in place).

    Args:
        masks: Mask records of one image, or ``None`` when the image has no
            mask file. With ``detail`` each record is read for its ``"area"``
            and ``"segmentation"`` entries.
        stats: Accumulator; ``"total"``, ``"missing"``, ``"empty"`` and
            ``"nmasks_distribution"`` are updated, plus ``"areas"``,
            ``"coverage"`` and ``"overlap"`` when ``detail`` is set.
        detail: Also decode the masks and record area/coverage/overlap
            fractions relative to the image size.
    """
    stats["total"] += 1

    if masks is None:
        stats["missing"] += 1
        return

    count = len(masks)
    if count == 0:
        stats["empty"] += 1
        return

    histogram = stats["nmasks_distribution"]
    histogram[count] = histogram.get(count, 0) + 1

    if not detail:
        return

    areas = [record["area"] for record in masks]
    segmentations = [record["segmentation"] for record in masks]
    # Decoded masks are unpacked below as (height, width, n_masks).
    decoded = mask_utils.decode(segmentations).astype(bool)
    height, width, _ = decoded.shape
    n_pixels = height * width

    stats["areas"].extend([area / n_pixels for area in areas])

    # Pixels covered by at least one mask.
    union_pixels = decoded.any(axis=-1).sum()
    summed_area = sum(areas)
    stats["coverage"].append(union_pixels / n_pixels)
    # Area counted more than once, i.e. where masks overlap each other.
    stats["overlap"].append((summed_area - union_pixels) / n_pixels)

def _collect_image_files(data_path, dedup):
    """Return the image file names referenced by ``data_path``."""
    if data_path is None:
        raise ValueError("data_path is required to collect stats")
    if data_path.endswith('.json'):
        with open(data_path, 'r') as f:
            list_data_dict = json.load(f)
    elif data_path.endswith('.jsonl'):
        with open(data_path, 'r') as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing on them.
            list_data_dict = [json.loads(l) for l in f if l.strip()]
    elif data_path.endswith('.tsv'):
        import pandas as pd
        questions = pd.read_table(os.path.expanduser(data_path))
        # targets = list(questions['image'])
        return [f"{q}.png" for q in questions['index']]
    elif data_path == "Lin-Chen/MMStar":
        questions = load_dataset("Lin-Chen/MMStar", "val")["val"]
        return [f"{q}.png" for q in questions['index']]
    else:
        raise ValueError(f"Unsupported data_path: {data_path}")

    all_image_files = [d['image'] for d in list_data_dict if 'image' in d]
    if dedup:
        all_image_files = list(set(all_image_files))
    return all_image_files


def get_stats(args, dedup=True):
    """Collect mask statistics for every image referenced by ``args.data_path``.

    For each image, the mask file ``<image stem>.json`` is looked up under
    ``args.mask_dir`` and folded into the statistics via ``log_stats``.

    Args:
        args: Namespace providing ``mask_dir``, ``data_path`` and ``detail``.
            ``data_path`` may be a ``.json``/``.jsonl`` file of records with
            an ``"image"`` entry, a ``.tsv`` file with an ``index`` column,
            or the literal ``"Lin-Chen/MMStar"``.
        dedup: Drop duplicate image names (``.json``/``.jsonl`` inputs only).

    Returns:
        ``(stats, nmasks)``: the statistics dict and a dict mapping each
        image that has a mask file to its number of masks.

    Raises:
        Exception: If ``args.mask_dir`` is empty.
        ValueError: If ``args.data_path`` is missing or of an unsupported kind.
    """
    if len(os.listdir(args.mask_dir)) == 0:
        raise Exception(f"No regions found at {args.mask_dir}")
    logger.info(f"Loading region masks from {args.mask_dir}")
    all_image_files = _collect_image_files(args.data_path, dedup)

    stats = {
        "total": 0,
        "missing": 0,
        "empty": 0,
        "nmasks_example": {},
        "nmasks_distribution": {},
    }
    if args.detail:
        stats.update({
            "areas": [],
            "coverage": [],
            "overlap": [],
        })
    nmasks = {}
    for f_img in tqdm(all_image_files, desc="collecting stats"):
        f_mask = os.path.splitext(f_img)[0] + ".json"
        mask_path = os.path.join(args.mask_dir, f_mask)
        if os.path.exists(mask_path):
            # Read-only access is enough, and the handle is closed right away.
            with open(mask_path, 'r') as f:
                masks = json.load(f)
            nmasks[f_img] = len(masks)
            # Remember the first image seen for each mask count.
            if len(masks) not in stats["nmasks_example"]:
                stats["nmasks_example"][len(masks)] = f_img
        else:
            masks = None
        log_stats(masks, stats, detail=args.detail)

    # The averages fall back to 0.0 when nothing was collected (e.g. every
    # mask file is missing or empty) rather than dividing by zero.
    distribution = stats["nmasks_distribution"]
    n_counted = sum(distribution.values())
    stats["avg_nmask"] = sum(k * v for k, v in distribution.items()) / n_counted if n_counted else 0.0
    if args.detail:
        stats["avg_area"] = np.mean(stats["areas"]) if stats["areas"] else 0.0
        stats["avg_coverage"] = sum(stats["coverage"]) / len(stats["coverage"]) if stats["coverage"] else 0.0
        stats["avg_overlap"] = sum(stats["overlap"]) / len(stats["overlap"]) if stats["overlap"] else 0.0
    return stats, nmasks

def visualize_selected(args, selected_imgs: dict):
    """Render the masks of selected images into ``<mask_dir>/visualize``.

    Args:
        args: Namespace providing ``img_dir`` and ``mask_dir``.
        selected_imgs: Maps a key (used as the output file stem) to an image
            file name relative to ``args.img_dir``.

    Raises:
        FileNotFoundError: If an image cannot be read by OpenCV.
    """
    os.makedirs(os.path.join(args.mask_dir, "visualize"), exist_ok=True)
    for nmasks, f_img in tqdm(selected_imgs.items()):
        f_mask = os.path.splitext(f_img)[0] + ".json"
        img_path = os.path.join(args.img_dir, f_img)
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the path instead of deep inside cvtColor.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Read-only access is enough, and the handle is closed right away.
        with open(os.path.join(args.mask_dir, f_mask), 'r') as f:
            masks = json.load(f)
        for mask in masks:
            # Replace the encoded segmentation with its decoded boolean array.
            mask["segmentation"] = mask_utils.decode(mask["segmentation"]).astype(bool)
        show_image(img, masks, os.path.join(args.mask_dir, f"visualize/{nmasks}.png"))

def todict(data:list, sort:bool=False):
    """Count occurrences of each item in ``data``.

    Args:
        data: Hashable items to count.
        sort: Count the items in sorted order, so the resulting dict's keys
            are inserted in ascending order.

    Returns:
        A dict mapping each distinct item to how often it occurs.
    """
    items = sorted(data) if sort else data
    counts = {}
    for item in items:
        if item in counts:
            counts[item] += 1
        else:
            counts[item] = 1
    return counts

def tolist(data:dict):
    """Expand a value -> count dict into a flat list.

    Each key is repeated ``count`` times, following the dict's iteration
    order; keys whose count is zero or negative contribute nothing.
    """
    return [value for value, count in data.items() for _ in range(count)]

if __name__ == '__main__':
    # CLI entry point: load cached stats (or compute and cache them), print a
    # summary and percentiles, and save distribution plots into --mask_dir.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--data_path",
        type=str,
        default=None,
        help="json file with image paths and annotations",
    )

    parser.add_argument(
        "--img_dir",
        type=str,
        default=None,
        help="Location of images",
    )

    parser.add_argument(
        "--mask_dir",
        type=str,
        default=None,
        help="Location of masks (sam or ground truth if given)",
    )

    parser.add_argument(
        "--detail",
        action="store_true",
        help="get detailed stats",
    )

    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing stats.json",
    )

    args = parser.parse_args()
    # Every path below is built from mask_dir; fail with a usage message
    # rather than a TypeError inside os.path.join.
    if args.mask_dir is None:
        parser.error("--mask_dir is required")

    stats_path = os.path.join(args.mask_dir, "stats.json")
    if os.path.exists(stats_path) and not args.overwrite:
        with open(stats_path, 'r') as f:
            stats = json.load(f)
        # JSON object keys are always strings; restore the integer mask counts.
        stats["nmasks_example"] = {int(k):v for k,v in stats["nmasks_example"].items()}
        stats["nmasks_distribution"] = {int(k):v for k,v in stats["nmasks_distribution"].items()}
    else:
        stats, nmasks = get_stats(args)
        # Context managers make sure both files are flushed and closed.
        with open(stats_path, 'w') as f:
            json.dump(stats, f)
        with open(os.path.join(args.mask_dir, "nmasks.json"), 'w') as f:
            json.dump(nmasks, f)

    # Pulled out of `stats` so the printed summary stays compact; the example
    # mapping is only needed by the optional selection code at the bottom.
    nmasks_example = stats.pop("nmasks_example")
    nmasks_distribution = stats.pop("nmasks_distribution")
    if args.detail:
        # A cached stats.json written without --detail lacks these entries.
        missing_keys = [k for k in ("areas", "coverage", "overlap") if k not in stats]
        if missing_keys:
            parser.error(
                f"{stats_path} has no detailed stats ({', '.join(missing_keys)}); "
                "rerun with --detail --overwrite"
            )
        areas = stats.pop("areas")
        coverage = stats.pop("coverage")
        overlap = stats.pop("overlap")
    print(json.dumps(stats, indent=4))

    # Expand the count histogram into a sorted list to read percentiles from.
    nmasks_sorted = []
    for k,v in sorted(nmasks_distribution.items()):
        nmasks_sorted.extend([k]*v)
    if nmasks_sorted:
        for percentiles in range(10, 100, 10):
            print(f"Percentile {percentiles} of nmasks: {nmasks_sorted[int(len(nmasks_sorted)*percentiles/100)]}")
    else:
        logger.warning("No masks found; skipping percentiles")

    visualize(nmasks_distribution, title="Number of Masks for each Image",
        save_path=os.path.join(args.mask_dir, "nmasks_distribution.png"),
        figtype="hist", bin_size=32, distribution_shift="log")
    if args.detail:
        visualize(areas, title="Area of Masks",
            save_path=os.path.join(args.mask_dir, "area_distribution.png"),
            figtype="hist", bin_size=32, distribution_shift="lg")
        visualize(coverage, title="Coverage of Masks",
            save_path=os.path.join(args.mask_dir, "coverage_distribution.png"),
            figtype="hist", bin_size=32, distribution_shift="linear")
        visualize(overlap, title="Overlap of Masks",
            save_path=os.path.join(args.mask_dir, "overlap_distribution.png"),
            figtype="hist", bin_size=32, distribution_shift="linear")

    # import pdb; pdb.set_trace()
    # get sorted values by int key
    # selected = {k:v for k,v in sorted(nmasks_example.items())}
    # visualize_selected(args, selected)
    # data_out = [{"image": i} for i in selected.values()]
    # output_path = "playground/data/LLaVA/LLaVA-Pretrain/blip_laion_cc_sbu_selected.json"
    # with open(output_path, "w") as f:
    #     json.dump(data_out, f, indent=2)

