import json
from collections import defaultdict
import os,sys
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import numpy as np
from adjustText import adjust_text  # pip install adjustText required

project_root = os.path.abspath(os.getcwd())
sys.path.append(project_root)
from others.prompt.visualization import visualize_class_probility

# Bundled font file used for every figure in this script.
font_path = os.path.join(project_root, "font", "times.ttf")  # Replace with your font path
# Register the file under the family name "Times New Roman". It is inserted at
# index 0 of the font manager's list, presumably so it takes precedence over a
# system font with the same name.
font_prop = fm.FontEntry(fname=font_path, name="Times New Roman")  # Set font name
fm.fontManager.ttflist.insert(0, font_prop)

# Global typography defaults (applied in this order).
mpl.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 14,
    'axes.labelsize': 16,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'legend.fontsize': 14,
    'axes.titlesize': 18,
})



# Annotation file with an 'images' list (split / gender / category_id per image).
file_path = 'data/cocogbv1/Ksplit_gender_category.json'
# Load data. JSON is UTF-8 by specification, so do not rely on the platform's
# default text encoding (which differs e.g. on Windows).
with open(file_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# COCO category id -> category name. The id space is non-contiguous
# (e.g. 12, 26, 29, 30 are absent), so always look up by key.
COCO_LABEL_NAMES = {
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    5: "airplane",
    6: "bus",
    7: "train",
    8: "truck",
    9: "boat",
    10: "traffic light",
    11: "fire hydrant",
    13: "stop sign",
    14: "parking meter",
    15: "bench",
    16: "bird",
    17: "cat",
    18: "dog",
    19: "horse",
    20: "sheep",
    21: "cow",
    22: "elephant",
    23: "bear",
    24: "zebra",
    25: "giraffe",
    27: "backpack",
    28: "umbrella",
    31: "handbag",
    32: "tie",
    33: "suitcase",
    34: "frisbee",
    35: "skis",
    36: "snowboard",
    37: "sports ball",
    38: "kite",
    39: "baseball bat",
    40: "baseball glove",
    41: "skateboard",
    42: "surfboard",
    43: "tennis racket",
    44: "bottle",
    46: "wine glass",
    47: "cup",
    48: "fork",
    49: "knife",
    50: "spoon",
    51: "bowl",
    52: "banana",
    53: "apple",
    54: "sandwich",
    55: "orange",
    56: "broccoli",
    57: "carrot",
    58: "hot dog",
    59: "pizza",
    60: "donut",
    61: "cake",
    62: "chair",
    63: "couch",
    64: "potted plant",
    65: "bed",
    67: "dining table",
    70: "toilet",
    72: "tv",
    73: "laptop",
    74: "mouse",
    75: "remote",
    76: "keyboard",
    77: "cell phone",
    78: "microwave",
    79: "oven",
    80: "toaster",
    81: "sink",
    82: "refrigerator",
    84: "book",
    85: "clock",
    86: "vase",
    87: "scissors",
    88: "teddy bear",
    89: "hair drier",
    90: "toothbrush",
}

# Context categories plotted on the x-axis, in display order. Names missing
# from COCO_LABEL_NAMES are skipped.
# An earlier, larger selection was:
#     "wine glass", "spoon", "bowl", "dining table", "bicycle", "backpack", "cell phone",
#     "tie", "handbag", "suitcase", "car", "umbrella", "horse", "bench", "cup", "chair",
#     "surfboard", "cake", "book", "tennis racket", "kite", "dog", "bottle", "frisbee",
#     "skis", "sink", "couch", "laptop", "fork", "tv", "knife", "bed", "sports ball",
#     "potted plant", "remote", "truck"
DEFAULT_SAVED_CATEGORIES = [
    "spoon", "bowl", "dining table", "bicycle", "backpack", "cell phone",
    "tie", "cup", "chair", "surfboard", "cake", "book", "kite", "frisbee",
    "sink", "couch", "laptop", "sports ball", "remote", "truck",
]


def plot_label_cocolabel_bars(label_counts, cocolabel_counts, cooccurrence_counts,
                              categories=None, output_stem='cooccurrence_bar_new'):
    """Plot per-gender PMI between the gender label and each context category.

    For every category in ``categories`` (in the given order, not sorted) two
    bars are drawn side by side: gender 0 (legend "Female") and gender 1
    (legend "Male"). Each bar height is the pointwise mutual information

        PMI(g, c) = log(count(g, c) * n / (count(g) * count(c)))

    where ``n`` is the sum of all values in ``label_counts``. A category that
    never co-occurs with a gender has undefined PMI (log 0) and is drawn as 0.

    The figure is written to ``<output_stem>.pdf`` (transparent background)
    and ``<output_stem>.png``, both at 800 dpi, and then closed.

    Args:
        label_counts: Mapping gender -> sample count (images, in this script).
            Must have a non-zero entry for any gender that has co-occurrences.
        cocolabel_counts: Mapping COCO category id -> occurrence count.
        cooccurrence_counts: Mapping gender -> {COCO category id -> count}.
            Not modified.
        categories: Category names to plot, in x-axis order. Defaults to
            ``DEFAULT_SAVED_CATEGORIES``; names not in ``COCO_LABEL_NAMES``
            are skipped.
        output_stem: Output path without file extension.

    Note:
        Calls ``plt.style.use('default')`` and sets several ``rcParams``,
        which changes Matplotlib's global state for the rest of the process.
    """
    if categories is None:
        categories = DEFAULT_SAVED_CATEGORIES

    # Total number of samples across all gender groups.
    n = sum(int(count) for count in label_counts.values())

    def compute_PMI(x_z, x, z):
        # Probabilities estimated as count / n: log(P(x,z) / (P(x) P(z))).
        return np.log(float(x_z) * n / (x * z))

    # Map names to ids so the plot follows the order of `categories`.
    name_to_id = {name: cat_id for cat_id, name in COCO_LABEL_NAMES.items()}
    filtered_cocolabels = [name_to_id[name] for name in categories if name in name_to_id]
    label_name_coco = [COCO_LABEL_NAMES[cat_id] for cat_id in filtered_cocolabels]

    def pmi_series(gender):
        # .get() so the caller's (possibly default)dict is never mutated and a
        # missing gender simply yields all-zero bars.
        cooc = cooccurrence_counts.get(gender, {})
        values = []
        for cat_id in filtered_cocolabels:
            if cat_id in cooc:
                values.append(compute_PMI(x_z=cooc[cat_id],
                                          x=label_counts[gender],
                                          z=cocolabel_counts[cat_id]))
            else:
                values.append(0)  # no co-occurrence: PMI undefined, drawn as 0
        return values

    y0 = pmi_series(0)
    y1 = pmi_series(1)

    x = np.arange(len(filtered_cocolabels))
    width = 0.35

    # Reset to Matplotlib defaults, then re-apply this figure's typography.
    plt.style.use('default')
    mpl.rcParams['font.family'] = 'Times New Roman'
    mpl.rcParams['font.size'] = 14
    mpl.rcParams['axes.labelsize'] = 16
    mpl.rcParams['xtick.labelsize'] = 14
    mpl.rcParams['ytick.labelsize'] = 14
    mpl.rcParams['legend.fontsize'] = 13
    mpl.rcParams['axes.grid'] = False
    mpl.rcParams['axes.spines.top'] = False    # Remove top border
    mpl.rcParams['axes.spines.right'] = False  # Remove right border
    # Soft coral (Female, gender 0) and soft blue (Male, gender 1).
    colors = ['#FFC3A0', '#A6C8E8']

    # Widen the figure with the number of categories; high DPI for print.
    fig, ax = plt.subplots(figsize=(max(12, len(filtered_cocolabels) * 0.6), 6), dpi=800)
    try:
        ax.bar(x - width/2, y0, width, label='Female', color=colors[0])
        ax.bar(x + width/2, y1, width, label='Male', color=colors[1])

        ax.set_xlabel('Context Categories', fontsize=27, fontweight='bold')
        ax.set_ylabel('PMI', fontsize=27, fontweight='bold')

        ax.set_xticks(x)
        ax.set_xticklabels(label_name_coco, rotation=45, ha='right', fontsize=22)

        # Legend centred above the axes, entries side by side.
        ax.legend(ncol=2, bbox_to_anchor=(0.5, 1.02), loc='lower center',
                  frameon=False, columnspacing=1,
                  fontsize=25,
                  labelspacing=1,   # vertical spacing between legend entries
                  handlelength=2,   # length of the legend colour patch
                  handletextpad=1,  # gap between patch and text
                  )

        # Y tick spacing of 0.2; adjust to the actual data range if needed.
        ax.yaxis.set_major_locator(plt.MultipleLocator(0.2))
        ax.tick_params(axis='y', labelsize=18)

        for spine in ax.spines.values():
            spine.set_linewidth(1.5)  # points

        # Light dashed grid on both axes.
        ax.grid(linestyle='--', alpha=0.3)

        fig.subplots_adjust(top=0.85, bottom=0.2)

        fig.savefig(f'{output_stem}.pdf', bbox_inches='tight', dpi=800, transparent=True)
        fig.savefig(f'{output_stem}.png', bbox_inches='tight', dpi=800)
    finally:
        # Release the (large, 800 dpi) figure even if drawing or saving fails.
        plt.close(fig)

# Extract images with split='train' and gender=0 or 1 from the images list
# Keep only training-split images whose gender label is 0 or 1
# (drawn as 'Female' and 'Male' respectively in the bar chart).
train_gender_cocoid_pairs = [
    {
        # 'cocoid': image.get('cocoid'),
        'category_id': image.get('category_id'),
        'gender': image.get('gender'),
    }
    for image in data['images']
    if image.get('split') == 'train' and image.get('gender') in (0, 1)
]

# Print number of extracted data
print(f"Extracted {len(train_gender_cocoid_pairs)} images with split='train' and gender=0 or 1")

# Count label occurrences and co-occurrences, then pass to plotting function.
label_counts = defaultdict(int)
cocolabel_counts = defaultdict(int)
cooccurrence_counts = defaultdict(lambda: defaultdict(int))
# `record` (not `data`) so the loaded JSON in `data` is not overwritten.
for record in train_gender_cocoid_pairs:
    gender = record['gender']  # int
    label_counts[gender] += 1
    # category_id is a string of space-separated ids; str.split() with no
    # argument splits on any whitespace and drops empty tokens. A missing
    # (None) category_id is treated as "no categories".
    # NOTE(review): an id repeated within one image is counted repeatedly,
    # which would let count(g, c) exceed count(g) - confirm ids are unique
    # per image.
    for token in (record['category_id'] or '').split():
        category = int(token)
        cocolabel_counts[category] += 1
        cooccurrence_counts[gender][category] += 1

plot_label_cocolabel_bars(label_counts, cocolabel_counts, cooccurrence_counts)






# Optional: Save results to file
# with open('train_gender_cocoid_data.json', 'w') as f:
#     json.dump(train_gender_cocoid_pairs, f, indent=2)

# If you just need to print or process
# for item in train_gender_cocoid_pairs[:10]:  # Print first 10 examples
#     print(f"Cocoid: {item['cocoid']}, Gender: {item['gender']}")



