import numpy as np
import cv2
import torch
import tqdm
from torch.utils.data import DataLoader
from models import SDInference
import torch.nn.functional as F
from seg_dataset import TestSegDataset
from metrics import calculate_ap_with_confidence
from train import collection


def test(lora_path, base_model_path="/mnt/bn/stable-diffusion-v1-5"):
    """Evaluate the segmentation model on several datasets and print mIoU / AP.

    For each entry of the built-in dataset list (phrase_cut, refcoco,
    refcoco+, refcocog) this function:

    1. runs ``model.evaluate`` over the ``test`` split with ``batch_size=1``
       and bilinearly resizes every prediction to the spatial size of its
       ground-truth mask;
    2. scans binarization thresholds 0.40, 0.41, ..., 0.69 and computes the
       mean per-sample IoU at each threshold (ground truth is binarized at
       0.5);
    3. computes AP with ``calculate_ap_with_confidence`` on the
       un-thresholded predictions;
    4. prints the metrics and the threshold with the highest mean IoU.

    Tensors in each batch are moved with ``.cuda()``, so a CUDA device is
    required.

    Args:
        lora_path: Path to the trained LoRA weights, passed to ``SDInference``.
        base_model_path: Path to the base checkpoint, passed to
            ``SDInference`` as its first argument. Defaults to the location
            that used to be hard-coded.

    Returns:
        None. All results are printed to stdout.
    """
    model = SDInference(
        base_model_path,
        lora_path,
    )
    test_data_list = [
        [{"name": "phrase_cut", "type": "h5"}],
        [{"name": "refcoco", "type": "h5"}],
        [{"name": "refcoco+", "type": "h5"}],
        [{"name": "refcocog", "type": "h5"}],
    ]
    for data_type in test_data_list:
        dataset = TestSegDataset(data_type, dataset_type="test", unfold=True)
        dataloader = DataLoader(
            dataset,
            shuffle=False,
            persistent_workers=False,
            batch_size=1,
            num_workers=4,
            collate_fn=collection,
        )
        pred_list = []
        gt_list = []

        with tqdm.tqdm(dataloader, desc="test") as pbar:
            for batch in pbar:
                for k in batch:
                    if isinstance(batch[k], torch.Tensor):
                        batch[k] = batch[k].cuda()
                pred = model.evaluate(batch, keep_dim=True, num_samples=1, only_pred=True, step_enhance=True)
                mask = batch["mask"]
                pred = F.interpolate(pred, size=mask.shape[2:], mode="bilinear", align_corners=False)
                # Drop any autograd history before storing: the predictions
                # are kept for the whole dataset, and an attached graph would
                # pin every intermediate activation in memory.
                pred = pred.detach()
                for p, m in zip(pred, mask):
                    # clone() so the stored sample does not keep the whole
                    # batch tensor alive.
                    pred_list.append(p.clone())
                    gt_list.append(m.clone())

        if not pred_list:
            # Nothing to score: np.mean([]) would yield nan and the AP helper
            # would receive empty inputs.
            print(f"{data_type}: no samples, skipping")
            continue

        data = []  # (threshold, mean IoU rounded to 4 decimals)
        last_mean_iou = float("nan")
        for step in range(40, 70):
            thresh = step / 100
            iou_list = []
            for pred, gt in zip(pred_list, gt_list):
                pred_binary = pred > thresh
                gt_binary = gt > 0.5
                inter = torch.logical_and(pred_binary, gt_binary).sum()
                union = torch.logical_or(pred_binary, gt_binary).sum()
                # 1e-3 keeps the ratio finite when both masks are empty.
                iou_list.append((inter.float() / (union.float() + 1e-3)).item())
            last_mean_iou = np.mean(iou_list)
            data.append((thresh, round(last_mean_iou, 4)))

        # Index 0 of each stored sample is handed to the AP helper.
        # NOTE(review): presumably this selects the single channel - confirm.
        pred_arr = [p[0] for p in pred_list]
        gt_arr = [g[0] for g in gt_list]
        ap = calculate_ap_with_confidence(pred_arr, gt_arr)
        # NOTE(review): the "miou" printed here is the mean IoU at the LAST
        # scanned threshold (0.69), kept for output compatibility. The best
        # value is reported on the next line.
        print(f"miou: {last_mean_iou:.4f}, ap: {ap:.4f}")
        # max() returns the first maximal entry, matching a stable
        # descending sort followed by taking element 0.
        best_thresh, best_miou = max(data, key=lambda item: item[1])
        print(f"best thresh: {best_thresh}, miou={best_miou:.4f}")


def find_best_thresh(lora_path, only_show_best=False, dataset_type="test",
                     base_model_path="/mnt/bn/stable-diffusion-v1-5"):
    """Scan binarization thresholds on phrase_cut and print the mean IoU of each.

    Runs ``model.evaluate`` over the requested split of the phrase_cut
    dataset with ``batch_size=1``, bilinearly resizes every prediction to the
    spatial size of its ground-truth mask, then computes the mean per-sample
    IoU for thresholds 0.40, 0.41, ..., 0.69 (ground truth is binarized at
    0.5). AP is computed with ``calculate_ap_with_confidence`` on the
    un-thresholded predictions.

    Tensors in each batch are moved with ``.cuda()``, so a CUDA device is
    required.

    Args:
        lora_path: Path to the trained LoRA weights, passed to ``SDInference``.
        only_show_best: If True, print only the threshold with the highest
            mean IoU; otherwise print one line per scanned threshold.
        dataset_type: Split name forwarded to ``TestSegDataset`` and used as
            the progress-bar label.
        base_model_path: Path to the base checkpoint, passed to
            ``SDInference`` as its first argument. Defaults to the location
            that used to be hard-coded.

    Returns:
        None. All results are printed to stdout.
    """
    model = SDInference(
        base_model_path,
        lora_path,
    )

    dataset = TestSegDataset([{"name": "phrase_cut", "type": "h5"}], dataset_type=dataset_type, unfold=True)
    dataloader = DataLoader(
        dataset,
        shuffle=False,
        persistent_workers=False,
        batch_size=1,
        num_workers=0,
        collate_fn=collection,
    )
    pred_list = []
    gt_list = []

    with tqdm.tqdm(dataloader, desc=dataset_type) as pbar:
        for batch in pbar:
            for k in batch:
                if isinstance(batch[k], torch.Tensor):
                    batch[k] = batch[k].cuda()
            # NOTE(review): unlike test(), step_enhance is not passed here, so
            # evaluate()'s own default applies - confirm this is intended.
            pred = model.evaluate(batch, keep_dim=True, num_samples=1, only_pred=True)
            mask = batch["mask"]
            pred = F.interpolate(pred, size=mask.shape[2:], mode="bilinear", align_corners=False)
            # Drop any autograd history before storing: the predictions are
            # kept for the whole dataset, and an attached graph would pin
            # every intermediate activation in memory.
            pred = pred.detach()
            for p, m in zip(pred, mask):
                # clone() so the stored sample does not keep the whole batch
                # tensor alive.
                pred_list.append(p.clone())
                gt_list.append(m.clone())

    if not pred_list:
        # Nothing to score: np.mean([]) would yield nan and the AP helper
        # would receive empty inputs.
        print(f"{dataset_type}: no samples, nothing to evaluate")
        return

    data = []  # (threshold, mean IoU rounded to 4 decimals)
    last_mean_iou = float("nan")
    for step in range(40, 70):
        thresh = step / 100
        iou_list = []
        for pred, gt in zip(pred_list, gt_list):
            pred_binary = pred > thresh
            gt_binary = gt > 0.5
            inter = torch.logical_and(pred_binary, gt_binary).sum()
            union = torch.logical_or(pred_binary, gt_binary).sum()
            # 1e-3 keeps the ratio finite when both masks are empty.
            iou_list.append((inter.float() / (union.float() + 1e-3)).item())
        last_mean_iou = np.mean(iou_list)
        data.append((thresh, round(last_mean_iou, 4)))

    # Index 0 of each stored sample is handed to the AP helper.
    # NOTE(review): presumably this selects the single channel - confirm.
    pred_arr = [p[0] for p in pred_list]
    gt_arr = [g[0] for g in gt_list]
    ap = calculate_ap_with_confidence(pred_arr, gt_arr)
    # NOTE(review): the "miou" printed here is the mean IoU at the LAST
    # scanned threshold (0.69), kept for output compatibility.
    print(f"miou: {last_mean_iou:.4f}, ap: {ap:.4f}")
    if only_show_best:
        # max() returns the first maximal entry, matching a stable descending
        # sort followed by taking element 0.
        best_thresh, best_miou = max(data, key=lambda item: item[1])
        print(f"best thresh: {best_thresh}, miou={best_miou:.4f}")
    else:
        for thresh, miou in data:
            print(f"{thresh}, miou={miou:.4f}")


if __name__ == "__main__":
    # Script entry point: evaluate on every benchmark dataset.
    # "your_path" is a placeholder - set it to the trained LoRA weights.
    lora_path = "your_path"
    test(lora_path=lora_path)
