import torch
from torch.utils.data import DataLoader
import torchvision.transforms as T
from tqdm import tqdm
import argparse
import time


from vpr_model import VPRModel
from utils.validation import get_validation_recalls
# Dataloader
from dataloaders.val.NordlandDataset import NordlandDataset
from dataloaders.val.NordlandGSVDataset import NordlandGSVDataset
from dataloaders.val.MapillaryDataset import MSLS
from dataloaders.val.MapillaryTestDataset import MSLSTest
from dataloaders.val.PittsburghDataset import PittsburghDataset
from dataloaders.val.SPEDDataset import SPEDDataset
from dataloaders.val.AmstertimeDataset import AmstertimeDataset
from dataloaders.val.StLuciaDataset import StLuciaDataset
from dataloaders.val.Tokyo247Dataset import Tokyo247Dataset
from dataloaders.val.PittsburghDataset import PittsburghDataset
from dataloaders.val.EynshamDataset import EynshamDataset
from dataloaders.val.SVOXDataset import SVOXDataset

# Names accepted by --val_datasets (also its default). Order is preserved
# because it determines the order in which the datasets are evaluated.
VAL_DATASETS = [
    'MSLS',
    'MSLS_Test',
    'pitts30k_test',
    'pitts250k_test',
    'Nordland',
    'Nordland_GSV',
    'SPED',
    'Amstertime',
    'StLucia',
    'Tokyo247',
    'Eynsham',
    'SVOX_night',
    'SVOX_overcast',
    'SVOX_rain',
    'SVOX_snow',
    'SVOX_sun',
]


def input_transform(image_size=None):
    """Build the preprocessing pipeline applied to every evaluation image.

    Args:
        image_size: Size forwarded to ``T.Resize``. Any falsy value
            (``None``, ``0``, empty sequence) skips the resize step.

    Returns:
        A ``T.Compose`` that optionally resizes with bilinear interpolation,
        then converts to a tensor and normalizes each channel.
    """
    # Channel statistics used for normalization (the usual ImageNet values).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    if image_size:
        # Resize has to run first, before the tensor conversion.
        resize = T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR)
        steps.insert(0, resize)
    return T.Compose(steps)

def get_val_dataset(dataset_name, image_size=None):
    """Build a validation dataset and return it with its evaluation metadata.

    The name is matched case-insensitively. ``'nordland'`` must match exactly;
    every other dataset is selected by substring. ``'msls_test'`` is checked
    before ``'msls'`` because the latter is a substring of the former.

    Args:
        dataset_name: Dataset name, e.g. one of ``VAL_DATASETS``.
        image_size: Optional size forwarded to ``input_transform``; a falsy
            value disables resizing.

    Returns:
        Tuple ``(dataset, num_references, num_queries, ground_truth)``; the
        last three are read from the attributes of the dataset object.

    Raises:
        ValueError: If ``dataset_name`` does not match any known dataset.
    """
    key = dataset_name.lower()

    # Resolve the dataset class first so that an unknown name fails fast,
    # with a descriptive message, before any transform is built.
    extra_kwargs = {}
    if key == 'nordland':
        dataset_cls = NordlandDataset
    elif 'nordland_gsv' in key:
        dataset_cls = NordlandGSVDataset
    elif 'msls_test' in key:
        # NOTE(review): the original kept a commented-out `path_return=True`
        # option for both MSLS datasets; presumably it makes batches carry
        # image paths - confirm against the dataset classes before enabling.
        dataset_cls = MSLSTest
    elif 'msls' in key:
        dataset_cls = MSLS
    elif 'pitts' in key:
        # The lower-cased name selects the Pittsburgh variant.
        dataset_cls = PittsburghDataset
        extra_kwargs['which_ds'] = key
    elif 'sped' in key:
        dataset_cls = SPEDDataset
    elif 'amstertime' in key:
        dataset_cls = AmstertimeDataset
    elif 'stlucia' in key:
        dataset_cls = StLuciaDataset
    elif 'tokyo247' in key:
        dataset_cls = Tokyo247Dataset
    elif 'eynsham' in key:
        dataset_cls = EynshamDataset
    elif 'svox' in key:
        # The lower-cased name selects the SVOX condition.
        dataset_cls = SVOXDataset
        extra_kwargs['which_ds'] = key
    else:
        raise ValueError(f'Unknown validation dataset: {dataset_name!r}')

    transform = input_transform(image_size=image_size)
    ds = dataset_cls(input_transform=transform, **extra_kwargs)

    return ds, ds.num_references, ds.num_queries, ds.ground_truth

def get_descriptors(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader`` and stack the outputs.

    Args:
        model: Callable mapping an image batch tensor to a descriptor tensor.
        dataloader: Iterable yielding ``(imgs, labels)`` or
            ``(imgs, labels, img_paths)`` batches.
        device: Device (string or ``torch.device``) the images are moved to
            before the forward pass.

    Returns:
        The descriptors of all batches concatenated on the CPU. If any batch
        carried image paths, a tuple ``(descriptors, flat_img_paths)`` is
        returned instead.

    Raises:
        ValueError: If a batch has neither two nor three elements.
    """
    descriptors = []
    img_paths = []

    # The fp16 autocast context targets CUDA. Enable it only when inference
    # actually runs on a CUDA device, so CPU-only hosts do not get a
    # "CUDA is not available" warning from a context that would be a no-op.
    use_amp = torch.device(device).type == 'cuda'

    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
        for batch in tqdm(dataloader, 'Calculating descritptors...'):
            if len(batch) == 2:
                imgs, _ = batch
            elif len(batch) == 3:
                imgs, _, batch_paths = batch
                img_paths.append(batch_paths)
            else:
                # Previously this fell through and either hit a NameError or
                # silently re-used the images of the preceding batch.
                raise ValueError(
                    f'Expected batches of 2 or 3 elements, got {len(batch)}'
                )
            # Move each result to the CPU right away so descriptors do not
            # accumulate on the device.
            descriptors.append(model(imgs.to(device)).cpu())

    stacked = torch.cat(descriptors)
    if img_paths:
        # Flatten the per-batch path collections into a single list.
        return stacked, [path for batch_paths in img_paths for path in batch_paths]
    return stacked
    
# import statistics as stats
# def get_descriptors(model, dataloader, device, *, warmup_steps=100, drop_first_k=10):
#     """
#     Timed variant:
#     - prints the timing statistics
#     - return value matches the original: (torch.cat(descriptors), [flat_img_paths]) or torch.cat(descriptors)
#     - the timed region covers only model(imgs_device)
#     """
#     model.eval().to(device)

#     times_ms = []
#     descriptors = []
#     img_paths_all = []

#     # ---- (Simple) warmup: fetch the first batch and run the same input several times ----
#     try:
#         first_batch = next(iter(dataloader))
#         if len(first_batch) == 2:
#             imgs0, _ = first_batch
#         else:
#             imgs0, _, _ = first_batch
#         imgs0 = imgs0.to(device, non_blocking=True)
#         torch.cuda.synchronize()
#         with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
#             for _ in range(warmup_steps):
#                 _ = model(imgs0)
#         torch.cuda.synchronize()
#     except StopIteration:
#         print("[get_descriptors] 빈 dataloader 입니다.")
#         return torch.empty(0)

#     # ---- Timing loop ----
#     starter = torch.cuda.Event(enable_timing=True)
#     ender   = torch.cuda.Event(enable_timing=True)

#     with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
#         for batch in tqdm(dataloader, desc='Calculating descriptors (timed)...'):
#             # Keep the original unpack logic
#             if len(batch) == 2:
#                 imgs, labels = batch
#             elif len(batch) == 3:
#                 imgs, labels, img_path = batch
#                 img_paths_all.append(img_path)

#             # Perform the H2D copy outside the timed region
#             imgs_device = imgs.to(device, non_blocking=True)
#             torch.cuda.synchronize()

#             # ===== model() only timing =====
#             starter.record()
#             output = model(imgs_device)
#             ender.record()
#             torch.cuda.synchronize()
#             times_ms.append(starter.elapsed_time(ender))
#             # =================================

#             # Collect the results after timing (D2H copy excluded)
#             descriptors.append(output.detach().cpu())

#     # ----- Compute statistics & print -----
#     vals = times_ms[drop_first_k:] if len(times_ms) > drop_first_k else times_ms
#     vals_sorted = sorted(vals)

#     def pct(vs, q):
#         if not vs: return None
#         idx = max(0, min(len(vs)-1, int(round((q/100.0)*len(vs)-1))))
#         return vs[idx]

#     mean_ms   = (sum(vals)/len(vals)) if vals else float('nan')
#     median_ms = stats.median(vals)     if vals else float('nan')
#     p95_ms    = pct(vals_sorted, 95)   if vals_sorted else float('nan')
#     p99_ms    = pct(vals_sorted, 99)   if vals_sorted else float('nan')
#     std_ms    = stats.pstdev(vals)     if len(vals) > 1 else 0.0
#     max_mem   = torch.cuda.max_memory_allocated() / (1024*1024)

#     print("\n[Inference timing — model() only]")
#     print(f"- warmup_steps: {warmup_steps}, drop_first_k: {drop_first_k}")
#     print(f"- count: {len(vals)} (총 {len(times_ms)} 중 앞 {min(drop_first_k,len(times_ms))}개 제외)")
#     print(f"- mean:  {mean_ms:.4f} ms")
#     print(f"- median:{median_ms:.4f} ms")
#     print(f"- p95:   {p95_ms:.4f} ms")
#     print(f"- p99:   {p99_ms:.4f} ms")
#     print(f"- std:   {std_ms:.4f} ms")
#     print(f"- max CUDA memory: {max_mem:.2f} MB\n")

#     # ----- Return value matches the original function -----
#     if len(descriptors) == 0:
#         cat = torch.empty(0)
#     else:
#         cat = torch.cat(descriptors)

#     if len(img_paths_all) > 0:
#         flat_img_paths = [p for sub in img_paths_all for p in sub]
#         return cat, flat_img_paths
#     else:
#         return cat





def load_model(ckpt_path, device='cuda'):
    """Instantiate the VPR model, load its weights and prepare it for inference.

    Args:
        ckpt_path: Path to the checkpoint file. A file named
            ``dino_salad.ckpt`` is read as a bare state dict; any other file
            is expected to hold the weights under the ``'state_dict'`` key.
        device: Device the model is moved to after loading. Defaults to
            ``'cuda'``, which is what the function always used before.

    Returns:
        The model in eval mode on ``device``.

    Raises:
        KeyError: If a checkpoint not named ``dino_salad.ckpt`` has no
            ``'state_dict'`` entry.
    """
    import os

    model = VPRModel(

        # DINOv2 with VFM-Adapter
        backbone_arch='dinov2_vitb14_da',
        backbone_config={
            'hidden_dim': 48,
            'return_token': True,
            'norm_layer': True,
        },

        # Alternative backbone kept from the original file:
        # backbone_arch='dinov2_vitl14_da',
        # backbone_config={
        #     'hidden_dim': 48,
        #     'return_token': True,
        #     'norm_layer': True,
        # },

        agg_arch='DA',
        agg_config={
            # The original comment lists "768 1024"; presumably 1024 belongs
            # to the ViT-L backbone above - confirm before switching.
            'in_channels': 768,
            'num_layers': 2
        },

    )

    # SECURITY: torch.load unpickles the file - only load trusted checkpoints.
    # map_location='cpu' keeps loading independent of the GPU index the
    # checkpoint was saved from; the model is moved to `device` afterwards.
    checkpoint = torch.load(ckpt_path, map_location='cpu')

    if os.path.basename(ckpt_path) == 'dino_salad.ckpt':
        # Pre-trained release: the file itself is the state dict.
        state_dict = checkpoint
    else:
        # Training checkpoint: weights are nested under 'state_dict'.
        state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)

    model = model.eval()
    model = model.to(device)
    print(f"Loaded model from {ckpt_path} Successfully!")
    return model

def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as the function always did.

    Returns:
        ``argparse.Namespace`` with ``ckpt_path``, ``val_datasets``,
        ``batch_size`` and ``image_size``. ``image_size`` is a tuple of two
        ints, or ``None`` when ``--image_size`` is given without values.

    Raises:
        ValueError: If more than two image-size values are given, or a value
            cannot be converted to ``int``.
    """
    parser = argparse.ArgumentParser(
        description="Eval VPR model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # Model parameters
    parser.add_argument("--ckpt_path", type=str, default='/home/dkim/VPR/salad/VPR_Project_model_result/RESULT_ddf5sd3_ffn_howlayer_2_bbhiddendim48_agg_layers2_agg_hiddendim768_numqueries64_kernel3_batch128_seed39/checkpoints/dinov2_vitb14_ddf_(35)_R1[0.9561]_R5[0.9904].ckpt', help="Path to the checkpoint")

    # Datasets parameters
    parser.add_argument(
        '--val_datasets',
        nargs='+',
        default=VAL_DATASETS,
        help='Validation datasets to use',
        choices=VAL_DATASETS,
    )
    parser.add_argument('--image_size', nargs='*', default=(322, 322), help='Image size (int, tuple or None)')
    parser.add_argument('--batch_size', type=int, default=512, help='Batch size')

    args = parser.parse_args(argv)

    # Normalize the image size to a (height, width)-style pair of ints.
    raw_size = args.image_size
    if raw_size:
        if len(raw_size) > 2:
            raise ValueError('Invalid image size, must be int, tuple or None')
        # A single value describes a square image.
        first, second = (raw_size[0], raw_size[0]) if len(raw_size) == 1 else raw_size
        args.image_size = (int(first), int(second))
    else:
        # A bare `--image_size` used to leave an empty list behind; expose
        # the `None` the help text promises (both are falsy downstream).
        args.image_size = None

    return args


if __name__ == '__main__':

    # Enable cuDNN benchmark mode for the whole run.
    torch.backends.cudnn.benchmark = True

    args = parse_args()
    model = load_model(args.ckpt_path)

    # Evaluate the same model on each requested dataset in turn.
    for dataset_name in args.val_datasets:
        dataset, n_refs, _n_queries, gt = get_val_dataset(dataset_name, args.image_size)
        loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=16,
            pin_memory=True,
        )
        print(f'Evaluating on {dataset_name}')

        # get_descriptors returns a tensor, or a (tensor, paths) tuple when
        # the dataset also yields image paths.
        feats = get_descriptors(model, loader, 'cuda')
        if isinstance(feats, tuple):
            feats, all_paths = feats
            ref_paths, query_paths = all_paths[:n_refs], all_paths[n_refs:]
        else:
            ref_paths = query_paths = None

        print(f'Descriptor dimension {feats.shape[1]}')

        # The first `num_references` rows are treated as references and the
        # remaining rows as queries.
        ref_feats, query_feats = feats[:n_refs], feats[n_refs:]

        print('query size:', query_feats.shape[0])
        print('reference size:', ref_feats.shape[0])

        # Prediction export is switched off; the original kept this hint:
        # testing = isinstance(val_dataset, MSLSTest)
        testing = None

        preds = get_validation_recalls(
            r_list=ref_feats,
            q_list=query_feats,
            r_paths=ref_paths,
            q_paths=query_paths,
            # k_values=[1, 5, 10, 15, 20, 25],
            k_values=[1, 5, 10],
            gt=gt,
            print_results=True,
            dataset_name=dataset_name,
            faiss_gpu=False,
            testing=testing,
        )

        if testing:
            dataset.save_predictions(preds, args.ckpt_path + '.' + model.agg_arch + '.preds.txt')

        # Drop the descriptor tensor reference before the next dataset.
        del feats
        print('========> DONE!\n\n')

# python eval.py --ckpt_path 'ckpt/dino_salad.ckpt' --val_datasets pitts30k_test pitts250k_test MSLS Nordland --image_size 322 --batch_size 256

# CUDA_VISIBLE_DEVICES=0 python eval.py --ckpt_path '/database/dkim/DA2VPR_test/VPR_Project/DA2_dalayer_2_bbhiddendim48_layers2_aghiddendim768_numqueries64_kernel3_batch128_seed172/checkpoints/dinov2_vitb14_da_(29)_R1[0.9575]_R5[0.9900].ckpt' --val_datasets MSLS --image_size 322 --batch_size 1