#%%
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

# Load both annotation CSVs (they have no header row) and plot how many videos
# each label has across the combined train + validation data.
train_csv = "trainingSet.csv"
val_csv = "validationSet.csv"
csv_columns = ["video_path", "ann", "num1", "num2"]
df_train = pd.read_csv(train_csv, header=None, names=csv_columns)
df_val = pd.read_csv(val_csv, header=None, names=csv_columns)

df = pd.concat([df_train, df_val], ignore_index=True)

# most_common() orders by descending count; equal counts keep first-seen order.
label_counts = Counter(df["ann"])
sorted_counts = dict(label_counts.most_common())

# Grow the figure width with the number of labels so tick labels stay legible.
fig_width = max(20, len(sorted_counts) * 0.3)
plt.figure(figsize=(fig_width, 6))
plt.bar(list(sorted_counts), list(sorted_counts.values()))
plt.xticks(rotation=90, ha="right")
plt.xlabel("Action ann")
plt.ylabel("Number of Videos")
plt.title("Video Count per Action ann (Sorted, Train + Val)")
plt.tight_layout()
plt.show()

#%%
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

def _plot_sorted_counts(counts, title):
    """Draw a bar chart of label counts, most frequent label first.

    Args:
        counts: Counter mapping each label to its number of videos.
        title: Title shown above the chart.

    Returns:
        A dict of label -> count ordered by descending count (labels with
        equal counts keep the order in which they were first counted).
    """
    ordered = dict(counts.most_common())
    # Grow the figure width with the number of labels so ticks stay legible.
    plt.figure(figsize=(max(20, len(ordered) * 0.3), 6))
    plt.bar(list(ordered), list(ordered.values()))
    plt.xticks(rotation=90, ha="right")
    plt.xlabel("Action ann")
    plt.ylabel("Number of Videos")
    plt.title(title)
    plt.tight_layout()
    plt.show()
    return ordered


# Both CSVs have no header row, so the column names are supplied explicitly.
train_csv = "trainingSet.csv"
val_csv = "validationSet.csv"
df_train = pd.read_csv(train_csv, header=None, names=["video_path", "ann", "num1", "num2"])
df_val = pd.read_csv(val_csv, header=None, names=["video_path", "ann", "num1", "num2"])

# One chart per split, built by the shared helper instead of copy-pasted code.
train_counts = Counter(df_train["ann"])
sorted_train = _plot_sorted_counts(train_counts, "Train Video Count per Action ann (Sorted)")

val_counts = Counter(df_val["ann"])
sorted_val = _plot_sorted_counts(val_counts, "Val Video Count per Action ann (Sorted)")

# Report the rarest training label. min() raises ValueError on an empty
# sequence, so an empty training CSV is reported instead of crashing the cell.
if train_counts:
    min_label, min_count = min(train_counts.items(), key=lambda x: x[1])
    print(f"🔍 most mini'{min_label}', only {min_count}")
else:
    print("[WARNING] No training annotations found.")
#%%
import pandas as pd
import os
import json
import cv2
import random
from tqdm import tqdm
from collections import defaultdict, Counter

# Fix the RNG seed so the random per-class sampling in this cell is repeatable.
random.seed(42)

# Input annotation files and the output directory layout.
train_csv = "trainingSet.csv"
val_csv = "validationSet.csv"
output_dir = "image_500"
train_img_dir = os.path.join(output_dir, "training")
val_img_dir = os.path.join(output_dir, "validation")

# Create the output root plus one sub-directory per split.
for directory in (output_dir, train_img_dir, val_img_dir):
    os.makedirs(directory, exist_ok=True)

def clean_ann(ann: str) -> str:
    """Lower-case an annotation and merge the gendered singing/speaking labels.

    The adult female/male "singing" labels collapse to "singing" and the
    adult female/male "speaking" labels collapse to "speaking". Any other
    label is returned lower-cased and otherwise unchanged.
    """
    merged_labels = {
        "adult+female+singing": "singing",
        "adult+male+singing": "singing",
        "adult+female+speaking": "speaking",
        "adult+male+speaking": "speaking",
    }
    normalized = ann.lower()
    return merged_labels.get(normalized, normalized)


def extract_middle_frame(video_path, save_path):
    """Save the middle frame of a video as an image file.

    Args:
        video_path: Path of the video to read.
        save_path: Destination path of the image to write.

    Returns:
        True if a frame was read and successfully written to ``save_path``,
        otherwise False. Failures are reported on stdout, not raised.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"[ERROR] Cannot open video: {video_path}")
            return False
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Guard against a non-positive reported frame count so the seek
        # target is never negative.
        mid_frame = max(frame_count // 2, 0)
        cap.set(cv2.CAP_PROP_POS_FRAMES, mid_frame)
        ret, frame = cap.read()
    finally:
        # Release the capture on every path, including early returns.
        cap.release()

    if not ret or frame is None:
        print(f"[WARNING] Failed to extract frame from {video_path}")
        return False

    # imwrite reports failure through its return value; do not claim success
    # when nothing was written to disk.
    if not cv2.imwrite(save_path, frame):
        print(f"[WARNING] Failed to write frame to {save_path}")
        return False
    return True

def group_and_sample(df, prefix, max_per_class):
    """Group videos by cleaned annotation and randomly sample each group.

    Args:
        df: DataFrame with "ann" and "video_path" columns.
        prefix: Directory name joined in front of every video path.
        max_per_class: Upper bound on the number of videos kept per label.

    Returns:
        A list of dicts with keys "video_path", "image_rel_path",
        "image_out_path" and "annotation", one per sampled video.
    """
    videos_by_label = defaultdict(list)
    for raw_label, rel_video in zip(df["ann"], df["video_path"]):
        label = clean_ann(raw_label)
        full_video_path = os.path.join(prefix, rel_video)
        # The extracted frame keeps the video's relative path with a .jpg suffix.
        image_name = os.path.splitext(rel_video)[0] + ".jpg"
        videos_by_label[label].append((full_video_path, image_name))

    sampled = []
    for label, candidates in videos_by_label.items():
        # Keep the whole group when it is smaller than the per-class cap.
        quota = min(len(candidates), max_per_class)
        sampled.extend(
            {
                "video_path": full_video_path,
                "image_rel_path": os.path.join(prefix, image_name),
                "image_out_path": image_name,
                "annotation": label,
            }
            for full_video_path, image_name in random.sample(candidates, quota)
        )
    return sampled

def process_and_extract(samples, mode_dir, prefix_label):
    """Extract one frame per sample and collect a metadata record for each.

    Args:
        samples: Dicts with "video_path", "image_out_path" and "annotation".
        mode_dir: Directory under which the extracted images are written.
        prefix_label: Split name shown in the progress-bar description.

    Returns:
        A list of dicts with "image_path" (relative to the module-level
        ``output_dir``), "image_height", "image_width" and "annotation".
        Samples whose frame extraction fails are left out.
    """
    records = []
    progress = tqdm(samples, desc=f"Extracting {prefix_label}")
    for sample in progress:
        target = os.path.join(mode_dir, sample["image_out_path"])
        os.makedirs(os.path.dirname(target), exist_ok=True)
        if not extract_middle_frame(sample["video_path"], target):
            continue

        # Read the written image back to record its dimensions; an unreadable
        # file is recorded with a size of 0 x 0.
        image = cv2.imread(target)
        if image is None:
            height, width = 0, 0
        else:
            height, width = image.shape[:2]

        records.append({
            "image_path": os.path.relpath(target, output_dir),
            "image_height": height,
            "image_width": width,
            "annotation": sample["annotation"]
        })
    return records

# Run the pipeline: load annotations, sample per class, extract frames, and
# write one JSON index per split.
csv_columns = ["video_path", "ann", "num1", "num2"]
df_train = pd.read_csv(train_csv, header=None, names=csv_columns)
df_val = pd.read_csv(val_csv, header=None, names=csv_columns)

# Train is sampled before val; both draw from the same seeded RNG, so this
# order is part of what makes the selection repeatable.
train_samples = group_and_sample(df_train, "training", 500)
val_samples = group_and_sample(df_val, "validation", 100)

train_json = process_and_extract(train_samples, train_img_dir, "Train")
val_json = process_and_extract(val_samples, val_img_dir, "Val")

json_outputs = (
    ("train_output.json", train_json),
    ("val_output.json", val_json),
)
for json_name, split_records in json_outputs:
    with open(os.path.join(output_dir, json_name), "w") as f:
        json.dump(split_records, f, indent=2)

def print_min_max_info(records, split_name):
    """Print the least and most frequent annotation found in ``records``.

    Args:
        records: Iterable of dicts that each carry an "annotation" key.
        split_name: Split name; printed upper-cased as the line prefix.

    Prints a "No valid data." line and returns when ``records`` is empty.
    """
    label_counter = Counter(entry["annotation"] for entry in records)
    if not label_counter:
        print(f"[{split_name.upper()}] No valid data.")
        return

    # min/max return the first match, so ties go to the label counted first.
    min_label = min(label_counter, key=label_counter.get)
    max_label = max(label_counter, key=label_counter.get)
    min_count = label_counter[min_label]
    max_count = label_counter[max_label]
    print(f"\n[{split_name.upper()}] Minimum: {min_label} ({min_count})")
    print(f"[{split_name.upper()}] Maximum: {max_label} ({max_count})")

# Summarize the rarest and most common label of each split's extracted records.
for split_records, split_name in ((train_json, "train"), (val_json, "val")):
    print_min_max_info(split_records, split_name)

print("✅ Done. Output saved in:", output_dir)

#%%