import numpy as np
from PIL import Image
import os
import cv2
import json
import random
import torch
from tqdm import tqdm
import rootutils

# Locate the project root by searching from this file for the ".project-root"
# indicator, and (pythonpath=True) put it on the import path so project-local
# modules resolve when this script is run directly.
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)


def classify_iou(
    iou_path,
    min_gap=10,
    max_gap=100,
    bins=((0.3, 0.4), (0.4, 0.6), (0.6, 0.8)),
):
    """Bucket frame pairs from a saved IoU matrix by their IoU value.

    Only pairs ``(i, j)`` with ``i + min_gap <= j <= i + max_gap`` (and always
    ``j > i``, i.e. strictly above the diagonal) are considered. Each bucket is
    a half-open interval ``[lower, upper)`` and collects the pairs whose IoU
    falls inside it.

    Args:
        iou_path: Path to a tensor saved with ``torch.save``; it is indexed as
            a 2-D matrix ``iou[i, j]``.
        min_gap: Smallest allowed ``j - i`` (clamped to at least 1).
        max_gap: Largest allowed ``j - i``.
        bins: ``(lower, upper)`` IoU intervals, one output list per interval.
            The default yields the ``(low, medium, high)`` buckets.

    Returns:
        A tuple with one list per bin. Each list holds ``(i, j, iou)`` tuples
        (Python ``int``, ``int``, ``float``) ordered by ``i`` then ``j``.
    """
    # Load onto the CPU so the index/mask tensors built below share its device.
    iou = torch.load(iou_path, map_location="cpu", weights_only=True)
    num_rows, num_cols = iou.shape[0], iou.shape[1]

    # Build only the diagonal band that is kept, instead of enumerating every
    # upper-triangle index with torch.triu_indices (O(N^2) memory) and then
    # discarding most of them. Boolean-mask selection flattens in row-major
    # order, so pairs come out ordered by i then j -- the same order a
    # filtered triu_indices produces, which keeps callers that random.sample
    # from these lists reproducible.
    lo_gap = max(min_gap, 1)
    offsets = torch.arange(lo_gap, max(lo_gap, max_gap + 1))
    rows = torch.arange(num_rows).unsqueeze(1).expand(-1, offsets.numel())
    cols = rows + offsets.unsqueeze(0)
    in_bounds = cols < num_cols
    rows, cols = rows[in_bounds], cols[in_bounds]

    values = iou[rows, cols]

    buckets = []
    for lower, upper in bins:
        keep = (values >= lower) & (values < upper)
        # tolist() yields the same Python ints/floats as per-element .item(),
        # in a single pass instead of one call per element.
        buckets.append(
            list(
                zip(rows[keep].tolist(), cols[keep].tolist(), values[keep].tolist())
            )
        )
    return tuple(buckets)


def _list_scans(scans_path):
    """Return the sorted names of the sub-directories of *scans_path*."""
    return sorted(
        name
        for name in os.listdir(scans_path)
        if os.path.isdir(os.path.join(scans_path, name))
    )


def _list_frame_ids(depth_path):
    """Return the sorted integer ids parsed from the file names in *depth_path*.

    The part of each file name before its first "." must be an integer
    (e.g. ``"120.png"`` -> ``120``); anything else raises ``ValueError``.
    """
    return sorted(int(name.split(".")[0]) for name in os.listdir(depth_path))


def _sample_target_ids(items, i, j, num_targets):
    """Pick *num_targets* random frames strictly between frames *i* and *j*.

    Returns those frames together with *i* and *j*, sorted ascending.

    Raises:
        ValueError: if *i* or *j* is not in *items*, or fewer than
            *num_targets* frames lie between them.
    """
    all_frames_between = items[items.index(i) + 1 : items.index(j)]
    if len(all_frames_between) < num_targets:
        raise ValueError("not enough frames between i and j")
    return sorted(random.sample(all_frames_between, num_targets) + [i, j])


def main(
    *,
    scans_path="data/scannet/val",
    output_path="data/scannet/val_pair.json",
    seed=42,
    samples_per_bin=(1, 2, 3),
    num_targets=4,
):
    """Sample context/target frame sets for every scan and write them as JSON.

    For each scan directory under *scans_path* (sorted by name), the pairs
    returned by ``classify_iou`` on ``<scan>/iou.pt`` are split into low,
    medium and high IoU buckets. ``samples_per_bin`` pairs are drawn from each
    bucket, and for every drawn pair ``num_targets`` frames between its two
    frames are sampled from ``<scan>/depth``.

    Args:
        scans_path: Directory containing one sub-directory per scan.
        output_path: Where the resulting list of dicts is written.
        seed: Seed applied to ``random`` at the start of every scan.
        samples_per_bin: Number of pairs drawn from the (low, medium, high)
            buckets, in that order.
        num_targets: Number of in-between frames sampled per pair.

    Raises:
        ValueError: if a bucket has fewer pairs than requested, or a pair has
            fewer than *num_targets* frames between its endpoints.
    """
    result = []
    for scan in _list_scans(scans_path):
        scan_path = os.path.join(scans_path, scan)
        items = _list_frame_ids(os.path.join(scan_path, "depth"))
        bins = classify_iou(os.path.join(scan_path, "iou.pt"))

        # Re-seed per scan so each scan's picks are reproducible on their own.
        # All context pairs are drawn (low, then medium, then high) before any
        # target frames; keep this call order to reproduce existing outputs.
        random.seed(seed)
        chosen = []
        for label, pairs, count in zip(("low", "medium", "high"), bins, samples_per_bin):
            if len(pairs) < count:
                raise ValueError(
                    f"scan {scan}: need {count} {label}-IoU pair(s), found {len(pairs)}"
                )
            chosen.extend(random.sample(pairs, count))

        for i, j, v in chosen:
            # NOTE(review): i/j are row/column indices of the IoU matrix but are
            # looked up in `items` as frame ids. This assumes the matrix is
            # indexed by frame id (e.g. contiguous 0..N-1 depth frames) --
            # confirm against the code that writes iou.pt.
            result.append(
                {
                    "scan": scan,
                    "context_ids": [i, j],
                    "target_ids": _sample_target_ids(items, i, j, num_targets),
                    "iou": v,
                }
            )

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4)


# Run the sampling only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
