import os
import pdb
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn.functional as F

from vggt.models.vggt import VGGT
from vggt.heads.dpt_head import DPTHead
from uw_model import TokenPrototypeModulator, LightAEstimator, TokenPrototypeModulatorGAT
from dataset import SingleImageDataset, SingleImageDepthDataset, RandomSequencePathDataset, TwoimageDepthDataset_SQUID1, MultiImageDepthDatasetV2

from visual_util import (save_multiframe_colored_pointcloud, estimate_beta_A_multiframe_rgb_tensor, 
                        compute_J_from_d_v2, get_A, estimate_A, estimate_A_darkchannel, compute_A_gt, 
                        scale_align, compute_depth_metrics, compute_si_rmse)
from vggt.utils.pose_enc import pose_encoding_to_extri_intri


def average_metrics(metrics_list):
    """Return the per-key arithmetic mean over a list of metric dicts.

    The key set is taken from the first entry, so every entry must provide
    at least those keys. An empty ``metrics_list`` yields an empty dict.
    """
    count = len(metrics_list)
    if not count:
        return {}

    first = metrics_list[0]
    return {
        name: sum(entry[name] for entry in metrics_list) / count
        for name in first
    }

def _build_squid_dataloader(num_images):
    """Create the SQUID evaluation DataLoader for one or two images per sample.

    Raises:
        ValueError: If ``num_images`` is neither 1 nor 2.
    """
    # NOTE(review): rgb_dir and depth_dir are not defined in this function;
    # they are expected to exist as module-level names -- confirm.
    if num_images == 2:
        test_dataset = TwoimageDepthDataset_SQUID1(rgb_dir, depth_dir)
    elif num_images == 1:
        test_dataset = TwoimageDepthDataset_SQUID1(rgb_dir, depth_dir, num_images=num_images)
    else:
        raise ValueError(
            f"Unsupported num_images for 'squid': {num_images!r} (expected 1 or 2)"
        )
    return DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8)


def _batch_identifier(img_paths, idx):
    """Return a display name for batch ``idx`` based on its last path entry.

    Falls back to ``"Index_<idx>"`` when the batch carries no usable path.
    """
    entry = img_paths[-1] if len(img_paths) > 0 else None
    # Collated path entries may arrive wrapped in a list or a tuple
    # (presumably one element per sample in the batch -- TODO confirm).
    if isinstance(entry, (list, tuple)) and len(entry) > 0:
        return os.path.basename(entry[0])
    if isinstance(entry, str):
        return os.path.basename(entry)
    return f"Index_{idx}"


def _resize_to(depth_map, target_size):
    """Bilinearly resize a 2-D depth map to ``target_size`` given as (H, W)."""
    resized = F.interpolate(
        depth_map.unsqueeze(0).unsqueeze(0),
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
    return resized.squeeze()


def _evaluate_frame(init_depth, refined_depth, gt_depth):
    """Score one frame's baseline and refined depth maps against ground truth.

    Returns:
        ``(init_metrics, refined_metrics)``, or ``None`` when
        ``compute_depth_metrics`` returns ``None`` for either prediction.
    """
    target_size = gt_depth.shape
    init_resized = _resize_to(init_depth, target_size)
    refined_resized = _resize_to(refined_depth, target_size)

    # The standard metrics are computed on scale-aligned predictions.
    init_aligned = scale_align(init_resized, gt_depth)
    refined_aligned = scale_align(refined_resized, gt_depth)

    init_metrics = compute_depth_metrics(init_aligned, gt_depth)
    if init_metrics is None:
        return None
    refined_metrics = compute_depth_metrics(refined_aligned, gt_depth)
    if refined_metrics is None:
        return None

    # si-RMSE is computed on the resized predictions before scale alignment.
    init_metrics['si-RMSE'] = compute_si_rmse(init_resized, gt_depth)
    refined_metrics['si-RMSE'] = compute_si_rmse(refined_resized, gt_depth)
    return init_metrics, refined_metrics


def validate_depth_model(model_ckpt_path, test_set, test_set1=None, num_images=2, save_dir="output_depth_eval"):
    """Evaluate baseline and token-modulated VGGT depth on the SQUID test set.

    For every batch the frozen VGGT depth head is run twice: once on the raw
    aggregator tokens (baseline) and once on tokens passed through the
    modulator, conditioned on the ``A_net`` output for each frame. Each
    frame's predictions are resized to the ground-truth resolution and
    scored, and the collected metrics are handed to ``save_results_to_file``.

    Args:
        model_ckpt_path: Checkpoint holding ``'modulator_state_dict'`` and
            ``'A_net_state_dict'``.
        test_set: Dataset name; only ``'squid'`` is supported.
        test_set1: Optional extra tag; only used in the summary file name.
        num_images: Images per sample, 1 or 2.
        save_dir: Directory that is created if missing. Nothing in this
            function writes into it.

    Raises:
        ValueError: If ``test_set`` or ``num_images`` is unsupported.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fail fast, before any weights are loaded, when no dataset is available.
    if test_set != 'squid':
        raise ValueError(f"Unsupported test_set: {test_set!r} (expected 'squid')")
    test_dataloader = _build_squid_dataloader(num_images)

    print("Loading model...")
    model = VGGT().to(device)
    # map_location keeps loading working on CPU-only hosts.
    # NOTE(review): _URL is not defined in the visible file -- confirm it
    # exists at module level.
    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL, map_location=device))
    model.eval()
    for p in model.parameters():
        p.requires_grad = False

    modulator = TokenPrototypeModulatorGAT(token_dim=2048, num_prototypes=24).to(device)
    A_net = LightAEstimator().to(device)

    checkpoint = torch.load(model_ckpt_path, map_location=device)
    modulator.load_state_dict(checkpoint['modulator_state_dict'])
    A_net.load_state_dict(checkpoint['A_net_state_dict'])
    print("Model weights loaded.")

    modulator.eval()
    A_net.eval()

    os.makedirs(save_dir, exist_ok=True)

    D_init_metrics_list, Depth_metrics_list = [], []
    img_paths_all = []
    # NOTE(review): one identifier is stored per batch while metrics are
    # stored per frame, so these lists can differ in length -- confirm that
    # save_results_to_file expects this.
    img_identifiers = []

    with torch.no_grad():
        for idx, (images, gt_depths, img_paths) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
            img_paths_all.extend(img_paths)
            img_identifiers.append(_batch_identifier(img_paths, idx))

            # Original author's shape notes: images [2, 3, H, W] -- TODO confirm.
            images = images.to(device).squeeze(0)
            gt_depths = gt_depths.to(device)
            num_frames = images.shape[0]
            batched_images = images.unsqueeze(0)

            A_pred = A_net(images)  # original note: [2, 3]
            # original note: each tokens entry is [1, 2, 25681, 2048]
            tokens, patch_start_idx = model.aggregator(batched_images)

            # Baseline depth from the unmodified tokens.
            D_init, _ = model.depth_head(tokens, images=batched_images, patch_start_idx=patch_start_idx)
            D_init_pred = D_init[0]  # original note: [2, H, W, 1]

            # Modulate each frame separately: slice that frame out of every
            # entry of `tokens`, run the modulator, and stack its outputs.
            mod_tokens_per_frame = []
            for i in range(num_frames):
                frame_tokens = [layer_tokens[:, i:i + 1] for layer_tokens in tokens]
                # original note: mod_tokens is a list of length 24
                mod_tokens, _, _, _ = modulator(frame_tokens, A_pred[i:i + 1])
                mod_tokens_per_frame.append(torch.cat(mod_tokens, dim=0))

            # Re-join the frames along dim 1 for each entry of `tokens` so the
            # depth head receives the same layout the aggregator produced.
            final_mod_tokens = [
                torch.cat([frame_stack[j:j + 1] for frame_stack in mod_tokens_per_frame], dim=1)
                for j in range(len(tokens))
            ]

            Depth_pred, _ = model.depth_head(final_mod_tokens, images=batched_images, patch_start_idx=patch_start_idx)
            Depth_pred = Depth_pred[0]  # original note: [2, H, W, 1]

            for i in range(num_frames):
                frame_metrics = _evaluate_frame(
                    D_init_pred[i, :, :, 0],
                    Depth_pred[i, :, :, 0],
                    gt_depths[0, i, 0],
                )
                if frame_metrics is None:
                    # Skip frames for which no metrics could be computed.
                    continue
                D_init_metrics, Depth_metrics = frame_metrics
                D_init_metrics_list.append(D_init_metrics)
                Depth_metrics_list.append(Depth_metrics)

    if test_set1 is not None:
        summary_name = f'results_summary_{test_set}_{test_set1}_numimage{num_images}.txt'
    else:
        summary_name = f'results_summary_{test_set}_numimage{num_images}.txt'

    # NOTE(review): save_results_to_file is not defined or imported in the
    # visible file -- confirm where it comes from.
    save_results_to_file(
        D_init_metrics_list,
        Depth_metrics_list,
        summary_name,
        img_paths_all,
        img_identifiers,
    )


# Run the validation: SQUID test set, two images per sample.
validate_depth_model(
    "checkpoints/model.pth",
    "squid",
    save_dir="output_depth_eval_squid",
    num_images=2,
)
