# Copyright (c) Winci.
# Licensed under the Apache License, Version 2.0 (the "License");

import copy
import json
import logging
import os
import random
import time
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm

from src.transform import MuiltiCropDataset
from args import get_args
from method import get_method
import src.vision_transformer as vits
import src.resnet as resnet

from src.utils import (
    setup_logging,
    cosine_scheduler,
    build_optimizer,
    load_pretrained_im,
    load_pretrained_clue,
    restart_from_checkpoint,
    init_distributed_device,
    AverageMeter,
)

from eval_retrieval import (
    extract_features,
    retrieval_rank,

)


def random_seed(args):
    """Seed the Python, NumPy and PyTorch RNGs from ``args.seed``.

    Does nothing when ``args.seed`` is ``None``.
    """
    seed = args.seed
    if seed is None:
        return
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

def unwrap_model(model):
    """Return ``model.module`` when ``model`` has one, otherwise ``model`` itself.

    Used to reach the underlying network inside a DistributedDataParallel wrapper.
    """
    return getattr(model, 'module', model)

# Parameter-name prefixes (keys of the DDP-wrapped training model's state dict)
# whose tensors are not copied into the stand-alone retrieval backbone.
EVAL_EXCLUDED_PREFIXES = (
    'module.momentum_encoder',
    'module.momentum_projector',
    'module.projector',
    'module.predictor',
    'module.part_proto',
    'module.net_vlad',
    'module.clip',
)


def filter_eval_state_dict(state_dict):
    """Build the state dict that is loaded into the retrieval backbone.

    Keys under ``module.encoder.`` have that prefix stripped so they match the
    plain backbone's parameter names. Keys starting with any entry of
    ``EVAL_EXCLUDED_PREFIXES`` are dropped. All other keys are kept unchanged;
    they still carry the ``module.`` prefix and are tolerated only because the
    caller loads with ``strict=False``.

    Args:
        state_dict: mapping of parameter name -> tensor, e.g. the
            ``state_dict()`` of the DDP-wrapped training model.

    Returns:
        A new dict; the input mapping is not modified.
    """
    encoder_prefix = 'module.encoder.'
    filtered = {}
    for key, value in state_dict.items():
        if key.startswith(encoder_prefix):
            filtered[key[len(encoder_prefix):]] = value
        elif not key.startswith(EVAL_EXCLUDED_PREFIXES):
            filtered[key] = value
    return filtered


def main():
    """Distributed pre-training with a retrieval evaluation after every epoch.

    Each epoch trains one pass over the multi-crop training set, exports the
    encoder weights to ``temp_checkpoint.pth``, evaluates rank-1/5/10 retrieval
    on the ``train``/``test`` folders under ``args.data_path`` and writes
    ``checkpoint.pth.tar`` (plus ``best_checkpoint.pth.tar`` whenever rank-1
    improves) into ``args.dump_path``.
    """
    # Retrieval scores of the latest evaluation; written by val_retrieval.
    global rank1_temp, rank5_temp, rank10_temp

    args = get_args()
    random_seed(args)

    # fully initialize distributed device environment
    device, args = init_distributed_device(args)

    print(f"Rank {args.rank} running on device {args.device}")

    # Every rank runs this line, so a check-then-create would race; exist_ok avoids it.
    os.makedirs(args.dump_path, exist_ok=True)

    setup_logging(os.path.join(args.dump_path, 'out.log'), logging.INFO)

    # Only the first process of each node emits logging.info messages.
    if args.local_rank != 0:
        def log_pass(*_args, **_kwargs):
            pass
        logging.info = log_pass

    # ============ training data ============
    traindir = os.path.join(args.data_path, 'train')
    train_dataset = MuiltiCropDataset(
        traindir,
        args,
        return_index=False,
        json_path=args.text_path,
        qa_idx=args.qa_idx,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
    )
    logging.info(f"Building data done with {len(train_dataset)} images loaded.")

    # ============ retrieval evaluation data ============
    transform = transforms.Compose([
        transforms.Resize(256, interpolation=3),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform)
    dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "test"), transform=transform)

    # Non-shuffled samplers over the train and test splits used during evaluation.
    sampler_val_train = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
    sampler_val_test = torch.utils.data.DistributedSampler(dataset_val, shuffle=False)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_val_train,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val_test,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )
    logging.info(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

    # ============ model ============
    model = get_method(args)
    if args.ckpt_from_impre:
        load_pretrained_im(model, args.ckpt_from_impre)

    # synchronize batch norm layers (skipped for ViT architectures)
    if "vit" not in args.arch:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # copy model to GPU
    torch.cuda.set_device(device)
    model.cuda(device)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)

    logging.info(model)
    logging.info("Building model done.")

    use_netvlad = getattr(args, 'is_parts', None) == "netvlad"

    # Initialize NetVLAD centroids with K-means if using netvlad clustering
    if use_netvlad:
        logging.info("Initializing NetVLAD centroids with K-means...")
        init_netvlad_with_kmeans(model, train_loader, args, device)

    # ============ optimizers ============
    # Linear scaling: the configured lr refers to a global batch of 256.
    args.lr = args.lr * args.batch_size * args.world_size / 256

    unwrapped_model = unwrap_model(model)
    netvlad_optimizer = None
    fc_q_optimizer = None
    if use_netvlad and hasattr(unwrapped_model, 'net_vlad'):
        # NetVLAD and fc_q get their own optimizers / learning rates.
        freeze_netvlad = getattr(args, 'freeze_netvlad', False)

        netvlad_params = list(unwrapped_model.net_vlad.parameters())
        fc_q_params = list(unwrapped_model.fc_q.parameters())

        other_params = [
            param for name, param in model.named_parameters()
            if not name.startswith('module.net_vlad') and not name.startswith('module.fc_q')
        ]
        optimizer = build_optimizer(other_params, args)

        if freeze_netvlad:
            for param in netvlad_params:
                param.requires_grad = False
            logging.info("NetVLAD parameters frozen")
        else:
            netvlad_lr = getattr(args, 'netvlad_lr', args.lr * 2.0)
            netvlad_args = copy.deepcopy(args)
            netvlad_args.lr = netvlad_lr
            netvlad_optimizer = build_optimizer(netvlad_params, netvlad_args)
            logging.info(f"NetVLAD lr: {netvlad_lr:.6f}")

        fc_q_lr = getattr(args, 'fc_q_lr', args.lr * 0.2)
        fc_q_args = copy.deepcopy(args)
        fc_q_args.lr = fc_q_lr
        fc_q_optimizer = build_optimizer(fc_q_params, fc_q_args)
        logging.info(f"fc_q lr: {fc_q_lr:.6f}")
    else:
        optimizer = build_optimizer(model.parameters(), args)

    # ============ init schedulers ... ============
    args.lr_schedule = cosine_scheduler(
        args.lr,
        args.final_lr,
        args.epochs, len(train_loader),
        warmup_epochs=args.warmup_epochs,
    )

    # NetVLAD scheduler (if separate optimizer exists)
    if netvlad_optimizer is not None:
        netvlad_final_lr = getattr(args, 'netvlad_final_lr', args.final_lr)
        args.netvlad_lr_schedule = cosine_scheduler(
            netvlad_lr,
            netvlad_final_lr,
            args.epochs, len(train_loader),
            warmup_epochs=args.warmup_epochs,
        )

    # fc_q scheduler (if separate optimizer exists)
    if fc_q_optimizer is not None:
        fc_q_final_lr = getattr(args, 'fc_q_final_lr', args.final_lr)
        args.fc_q_lr_schedule = cosine_scheduler(
            fc_q_lr,
            fc_q_final_lr,
            args.epochs, len(train_loader),
            warmup_epochs=args.warmup_epochs,
        )

    # momentum parameter is increased to 1. during training with a cosine schedule
    args.momentum_schedule = cosine_scheduler(
        args.momentum, 1,
        args.epochs, len(train_loader),
    )

    logging.info(f"Building {args.optimizer} optimizer done.")

    # ============ optionally resume from a checkpoint ============
    to_restore = {"epoch": 0, "best_rank1": 0.0}
    restore_kwargs = {"state_dict": model, "optimizer": optimizer}
    if netvlad_optimizer is not None:
        restore_kwargs["netvlad_optimizer"] = netvlad_optimizer
    # fc_q_optimizer also exists when NetVLAD is frozen (netvlad_optimizer is
    # None), so it must be restored independently of netvlad_optimizer.
    if fc_q_optimizer is not None:
        restore_kwargs["fc_q_optimizer"] = fc_q_optimizer
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        **restore_kwargs,
    )
    start_epoch = to_restore["epoch"]
    best_rank1 = to_restore["best_rank1"]

    cudnn.benchmark = True

    scaler = torch.cuda.amp.GradScaler()

    rank1_temp = rank5_temp = rank10_temp = 0.0
    temp_path = os.path.join(args.dump_path, "temp_checkpoint.pth")

    for epoch in range(start_epoch, args.epochs):
        logging.info(f"============ Starting epoch {epoch} ... ============")

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # train the network for one epoch
        train(train_loader, model, scaler, optimizer, epoch, args, netvlad_optimizer, fc_q_optimizer)

        # ============ retrieval evaluation ============
        if args.arch.startswith('vit'):
            eval_model, _ = vits.__dict__[args.arch](patch_size=args.patch_size)
        else:
            eval_model, _ = resnet.__dict__[args.arch]()

        if args.local_rank == 0:
            torch.save(filter_eval_state_dict(model.state_dict()), temp_path)

        # All processes wait until rank 0 has written the exported weights.
        torch.distributed.barrier()

        checkpoint = torch.load(temp_path, map_location='cpu')
        msg = eval_model.load_state_dict(checkpoint, strict=False)
        logging.info(f"Load checkpoint message: {msg}")
        eval_model.cuda(device)
        eval_model = nn.parallel.DistributedDataParallel(eval_model, device_ids=[device], find_unused_parameters=True)
        cudnn.benchmark = True

        val_retrieval(args, eval_model, dataset_train, dataset_val, data_loader_train, data_loader_val)

        # ============ save checkpoints ============
        if args.local_rank == 0:
            logging.info(f"Rank1: {rank1_temp:.2f}, Rank5: {rank5_temp:.2f}, Rank10: {rank10_temp:.2f}")
            is_best = rank1_temp > best_rank1
            if is_best:
                best_rank1 = rank1_temp

            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                # Persisted so a resumed run keeps comparing against the true best
                # instead of restarting from 0.0 and overwriting the best checkpoint.
                "best_rank1": best_rank1,
            }
            if netvlad_optimizer is not None:
                save_dict["netvlad_optimizer"] = netvlad_optimizer.state_dict()
            if fc_q_optimizer is not None:
                save_dict["fc_q_optimizer"] = fc_q_optimizer.state_dict()

            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )

            if is_best:
                logging.info("Best rank1 found. Saving the model...")
                torch.save(
                    save_dict,
                    os.path.join(args.dump_path, "best_checkpoint.pth.tar"),
                )

        del eval_model
        torch.cuda.empty_cache()

    logging.info("Training done. Saving the final model ...")
        
def train(loader, model, scaler, optimizer, epoch, args, netvlad_optimizer=None, fc_q_optimizer=None):
    """Train ``model`` for one epoch with mixed precision.

    Args:
        loader: training DataLoader. Each batch is passed to ``model`` unchanged;
            its first element holds the image crops (the text part is ignored
            for bookkeeping).
        model: the (DDP-wrapped) method whose forward returns
            ``(loss_global, loss_part, loss_text)``.
        scaler: AMP gradient scaler shared by all optimizers (``main`` passes a
            ``torch.cuda.amp.GradScaler``).
        optimizer: optimizer for the main parameter set.
        epoch: zero-based epoch index, used to compute the global iteration.
        args: run configuration. ``is_parts`` / ``with_texts`` (when not None)
            add the part / text losses to the objective.
        netvlad_optimizer: optional separate optimizer for NetVLAD parameters.
        fc_q_optimizer: optional separate optimizer for ``fc_q`` parameters.

    Returns:
        Running average of the total loss over the epoch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_global = AverageMeter()
    losses_part = AverageMeter()
    losses_text = AverageMeter()
    model.train()

    # Loop-invariant configuration; missing options behave like None / False,
    # consistent with how ``is_parts`` is probed elsewhere in this file.
    is_parts = getattr(args, 'is_parts', None)
    use_part_loss = is_parts is not None
    use_text_loss = getattr(args, 'with_texts', None) is not None
    netvlad_frozen = is_parts == "netvlad" and getattr(args, 'freeze_netvlad', False)
    # Optional optimizers, in the order they are zeroed and stepped.
    extra_optimizers = [opt for opt in (netvlad_optimizer, fc_q_optimizer) if opt is not None]

    end = time.time()
    for it, samples in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # update learning rates / momentum for this global training iteration
        iters = len(loader) * epoch + it
        adjust_parameters(model, optimizer, args, iters, netvlad_optimizer, fc_q_optimizer)

        # ============ forward, backward and optimizer steps ============
        optimizer.zero_grad()
        for opt in extra_optimizers:
            opt.zero_grad()

        with torch.cuda.amp.autocast():
            loss_global, loss_part, loss_text = model(samples)
            loss = loss_global
            if use_part_loss:
                loss = loss + loss_part
            if use_text_loss:
                loss = loss + loss_text

        # One GradScaler serves every optimizer: step each one, then update once.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        for opt in extra_optimizers:
            scaler.step(opt)
        scaler.update()

        # ============ bookkeeping ============
        # samples[0] is the image part of the batch (text is ignored here);
        # the batch size is read from its first crop.
        batch_size = samples[0][0].size(0)
        losses.update(loss.item(), batch_size)
        losses_global.update(loss_global.item(), batch_size)
        losses_part.update(loss_part.item(), batch_size)
        losses_text.update(loss_text.item(), batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        if args.local_rank == 0 and it % 5 == 0:
            lr_info = f"Lr: {optimizer.param_groups[0]['lr']:.4f}"
            if netvlad_optimizer is not None:
                lr_info += f" NetVLAD_Lr: {netvlad_optimizer.param_groups[0]['lr']:.4f}"
            elif netvlad_frozen:
                lr_info += " NetVLAD_Lr: FROZEN"
            if fc_q_optimizer is not None:
                lr_info += f" fc_q_Lr: {fc_q_optimizer.param_groups[0]['lr']:.4f}"

            logging.info(
                "Epoch: [{0}][{1}]\t"
                "Time {batch_time.val:.4f} ({batch_time.avg:.4f})\t"
                "Data {data_time.val:.4f} ({data_time.avg:.4f})\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                "{lr_info}".format(
                    epoch,
                    it,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    lr_info=lr_info,
                )
            )
            logging.info(
                "Epoch: [{0}][{1}]\t"
                "Loss_assign {loss_assign.val:.4f} ({loss_assign.avg:.4f})\t"
                "Loss_part {loss_part.val:.4f} ({loss_part.avg:.4f})\t"
                "Loss_text {loss_text.val:.4f} ({loss_text.avg:.4f})".format(
                    epoch,
                    it,
                    loss_assign=losses_global,
                    loss_part=losses_part,
                    loss_text=losses_text
                )
            )
    return losses.avg


def val_retrieval(args, model, dataset_train, dataset_val, data_loader_train, data_loader_val):
    """Evaluate retrieval and publish rank-1/5/10 to every process.

    Features for both splits are extracted on every process
    (``extract_features`` already gathers them across processes). On global
    rank 0 they are L2-normalized and passed, together with the class labels
    from ``dataset.samples``, to ``retrieval_rank``. The three scores are then
    broadcast from rank 0 and stored in the module globals ``rank1_temp``,
    ``rank5_temp`` and ``rank10_temp`` on every process.

    Requires ``torch.distributed`` to be initialized.

    Args:
        args: run configuration (``rank``, ``device``, ``use_cuda``, ...).
        model: backbone used for feature extraction.
        dataset_train, dataset_val: datasets whose ``samples`` provide labels.
        data_loader_train, data_loader_val: loaders used for extraction.
    """
    global rank1_temp, rank5_temp, rank10_temp

    model.eval()
    with torch.no_grad():
        train_features = extract_features(model, data_loader_train, args)
        test_features = extract_features(model, data_loader_val, args)

    # One float64 tensor with an explicit dtype on every rank, so the broadcast
    # cannot hit a dtype mismatch (e.g. numpy float64 on rank 0 vs Python float
    # elsewhere), and a single collective carries all three scores.
    metrics = torch.zeros(3, dtype=torch.float64, device=args.device)

    if args.rank == 0:
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        test_features = nn.functional.normalize(test_features, dim=1, p=2)

        train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
        test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()

        if args.use_cuda:
            train_features = train_features.cuda()
            test_features = test_features.cuda()
            train_labels = train_labels.cuda()
            test_labels = test_labels.cuda()

        # ============ retrieval ... ============
        rank1, rank5, rank10 = retrieval_rank(train_features, train_labels,
                                              test_features, test_labels)
        metrics = torch.tensor([float(rank1), float(rank5), float(rank10)],
                               dtype=torch.float64, device=args.device)

    # Broadcast from rank 0 and update the globals on every process.
    torch.distributed.broadcast(metrics, src=0)
    rank1_temp, rank5_temp, rank10_temp = metrics.tolist()


def adjust_parameters(model, optimizer, args, iters, netvlad_optimizer=None, fc_q_optimizer=None):
    """Apply the scheduled values for global training iteration ``iters``.

    Every param group of ``optimizer`` gets its lr from ``args.lr_schedule``.
    ``netvlad_optimizer`` / ``fc_q_optimizer`` get theirs from
    ``args.netvlad_lr_schedule`` / ``args.fc_q_lr_schedule``, each only when
    the optimizer is given and the schedule attribute exists. Finally the
    unwrapped model's ``momentum`` is set from ``args.momentum_schedule``.
    """
    def apply_lr(opt, schedule):
        for group in opt.param_groups:
            group['lr'] = schedule[iters]

    apply_lr(optimizer, args.lr_schedule)

    optional_pairs = (
        (netvlad_optimizer, 'netvlad_lr_schedule'),
        (fc_q_optimizer, 'fc_q_lr_schedule'),
    )
    for opt, schedule_name in optional_pairs:
        if opt is not None and hasattr(args, schedule_name):
            apply_lr(opt, getattr(args, schedule_name))

    unwrap_model(model).momentum = args.momentum_schedule[iters]

class ReturnIndexDataset_withText(datasets.ImageFolder):
    """``ImageFolder`` that returns ``(image, index, text)`` for each sample.

    The text is looked up in a JSON object keyed by the image path relative to
    the dataset root, written with forward slashes (e.g. ``'class1/image1.jpg'``).

    Args:
        json_path: path to the JSON file with the text annotations.
        *args, **kwargs: forwarded to ``torchvision.datasets.ImageFolder``.
    """

    def __init__(self, json_path, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Load the text annotations (JSON is UTF-8 by specification).
        with open(json_path, 'r', encoding='utf-8') as f:
            self.text_data = json.load(f)
        # Normalized dataset root, so relative keys are computed consistently.
        self.root_path = Path(self.root).resolve()

    def __getitem__(self, idx):
        """Return ``(img, idx, text)``; raises ``KeyError`` if the image has no text entry."""
        img, _ = super().__getitem__(idx)
        # Absolute path of this image -> key relative to the dataset root.
        img_abs_path = Path(self.samples[idx][0]).resolve()
        key = img_abs_path.relative_to(self.root_path).as_posix()
        text = self.text_data[key]
        return img, idx, text

class ReturnIndexDataset(datasets.ImageFolder):
    """``ImageFolder`` whose items are ``(image, index)`` instead of ``(image, label)``.

    The class label is discarded and the sample's position in the dataset is
    returned in its place; labels remain available through ``self.samples``.
    """

    def __getitem__(self, idx):
        image = super().__getitem__(idx)[0]
        return image, idx


def init_netvlad_with_kmeans(model, train_loader, args, device):
    """Initialize NetVLAD centroids with K-means on encoder feature maps.

    The encoder runs without gradients over up to ``args.kmeans_max_samples``
    samples per process (default 50000). With ``torch.distributed`` initialized,
    the per-process features are gathered on global rank 0, which calls
    ``net_vlad.init_centroids_with_kmeans`` on a CPU copy. Then ``centroids`` and
    the ``conv`` weight/bias are broadcast from rank 0 to every process. Without
    distributed training the initialization runs on the local features.

    On return the model is in train mode and the encoder parameters have the
    ``requires_grad`` flags they had on entry.

    Args:
        model: the (possibly DDP-wrapped) method. The unwrapped model must expose
            ``encoder`` and ``ForwardWrapper``; ``net_vlad`` is required for any
            initialization to happen.
        train_loader: training DataLoader. Batches with more than one element
            are reduced to their first element (the images).
        args: run configuration; ``kmeans_max_samples`` is optional.
        device: CUDA device the samples are moved to.
    """
    import torch.distributed as dist

    logging.info("Starting NetVLAD K-means initialization...")

    distributed = dist.is_initialized()

    # Eval mode and a frozen encoder during feature extraction.
    model.eval()
    unwrapped_model = unwrap_model(model)

    # Remember the original flags so parameters that were frozen before this
    # call are still frozen afterwards.
    encoder_params = list(unwrapped_model.encoder.parameters())
    saved_requires_grad = [param.requires_grad for param in encoder_params]
    for param in encoder_params:
        param.requires_grad = False

    features_list = []
    max_samples = getattr(args, 'kmeans_max_samples', 50000)
    samples_collected = 0
    batches_processed = 0

    # Set sampler to epoch 0 for consistent sampling across runs
    if hasattr(train_loader.sampler, 'set_epoch'):
        train_loader.sampler.set_epoch(0)

    with torch.no_grad():
        for batch_idx, samples in enumerate(train_loader):
            if samples_collected >= max_samples:
                break

            # Handle text samples if present: keep only the image part.
            if len(samples) > 1:
                samples = samples[0]

            samples = [x.cuda(device, non_blocking=True) for x in samples]

            # Extract feature maps from the encoder (no projector).
            _, _, fp = unwrapped_model.ForwardWrapper(samples, unwrapped_model.encoder, use_projector=False)
            feature_map = fp[0]

            # Kept on the GPU for the collective ops below.
            # NOTE(review): with the default cap this may need a lot of GPU
            # memory (per rank, and world_size times that on rank 0) — confirm.
            features_list.append(feature_map)
            samples_collected += feature_map.shape[0]
            batches_processed += 1

            if batch_idx % 20 == 0:
                logging.info(f"Collected {samples_collected} samples for K-means initialization (batch {batch_idx}/{len(train_loader)})...")

        logging.info(f"K-means initialization: processed {batches_processed} batches, collected {samples_collected} samples total")

    if not features_list:
        logging.warning("No features collected for K-means initialization!")
    elif not hasattr(unwrapped_model, 'net_vlad'):
        logging.warning("Model does not have net_vlad attribute!")
    else:
        local_features = torch.cat(features_list, dim=0)
        logging.info(f"Local features shape for K-means: {local_features.shape}")
        net_vlad = unwrapped_model.net_vlad

        if distributed:
            # dst=0 is the *global* rank 0; selecting the receiver by local_rank
            # would make local rank 0 of other nodes pass a gather_list as a
            # non-destination rank, which torch.distributed rejects.
            # gather also assumes equal-sized features on every rank.
            if dist.get_rank() == 0:
                gathered_features = [torch.zeros_like(local_features) for _ in range(dist.get_world_size())]
                dist.gather(local_features, gathered_features, dst=0)
                all_features = torch.cat(gathered_features, dim=0)
                logging.info(f"Gathered features shape for K-means: {all_features.shape}")
                net_vlad.init_centroids_with_kmeans(all_features.cpu())
            else:
                dist.gather(local_features, dst=0)

            # Wait for rank 0, then share its initialization with every process.
            dist.barrier()
            dist.broadcast(net_vlad.centroids.data, src=0)
            dist.broadcast(net_vlad.conv.weight.data, src=0)
            if net_vlad.conv.bias is not None:
                dist.broadcast(net_vlad.conv.bias.data, src=0)
            logging.info("NetVLAD centroids broadcasted to all processes")
        else:
            # Single GPU case
            net_vlad.init_centroids_with_kmeans(local_features.cpu())

    # Restore the encoder's original requires_grad flags.
    for param, flag in zip(encoder_params, saved_requires_grad):
        param.requires_grad = flag

    model.train()

    # Synchronize all processes to ensure initialization is complete
    if distributed:
        dist.barrier()

    logging.info("NetVLAD K-means initialization completed.")


# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
