import copy
import gc
import math
import os
import pickle
import shutil
import sys
import time
from dataclasses import dataclass
from typing import Optional

import debugpy
import torch
import torch.distributed as dist
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import get_constant_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, get_cosine_schedule_with_warmup

from sample4geo.dataset.dataset import VigorDatasetEval, VigorDatasetTrain
from sample4geo.evaluate.vigor_ddp import evaluate, calc_sim
from sample4geo.loss_ddp import InfoNCE
from sample4geo.model import TimmModel
from sample4geo.trainer_ddp import train
from sample4geo.transforms import get_transforms_train, get_transforms_val
from sample4geo.utils import setup_system, Logger

def sync_samples_across_processes(dataset, device, rank):
    """Copy rank 0's ``samples``, ``pairs`` and ``idx2pairs`` onto every rank.

    This is a collective operation: every rank of the default process group
    must call it. Rank 0's dataset is left untouched; on all other ranks the
    three attributes are replaced by the objects received from rank 0.

    Args:
        dataset: Object exposing the ``samples``, ``pairs`` and ``idx2pairs``
            attributes (they must be picklable).
        device: Device on which the serialized payload is staged for the
            broadcast.
        rank: Global rank of the calling process.

    Raises:
        Exception: Any error raised while reading, broadcasting or assigning
            the attributes is re-raised unchanged (rank 0 prints it first).
    """
    try:
        if rank == 0:
            payload = [dataset.samples, dataset.pairs, dataset.idx2pairs]
        else:
            # Placeholders: only the entries supplied by the source rank are sent.
            payload = [None, None, None]

        # The library routine pickles the objects and broadcasts size + data,
        # replacing the previous hand-rolled protocol that expanded the pickle
        # into one Python int per byte.
        dist.broadcast_object_list(payload, src=0, device=device)

        if rank != 0:
            dataset.samples, dataset.pairs, dataset.idx2pairs = payload

        if rank == 0:
            print(f"数据集同步完成，样本数量: {len(dataset.samples)}")

    except Exception as e:
        if rank == 0:
            print(f"数据集同步出错: {e}")
        # No barrier here: a rank that failed part-way through must not issue
        # a different collective than the one its peers may still be waiting
        # in, otherwise the job can hang instead of reporting the error.
        raise

@dataclass
class Configuration:
    """Hyper-parameters and runtime settings for the DDP training script.

    Every setting is an annotated dataclass field, so each one can be
    overridden through the constructor (``Configuration(seed=3)``) and shows
    up in ``repr`` / ``dataclasses.asdict``. ``main()`` additionally attaches
    ``rank``, ``world_size`` and ``gpu_ids`` to the instance at runtime.
    """

    # --- Model ---------------------------------------------------------
    # Model name handed to TimmModel in main().
    model: str = 'convnext_base.fb_in22k_ft_in1k_384'

    # Passed to TimmModel; main() also derives the satellite and ground
    # image sizes from it.
    img_size: int = 384

    # --- Training ------------------------------------------------------
    # main() creates a GradScaler when True, otherwise passes scaler=None.
    mixed_precision: bool = True
    # Number of training epochs looped over in main().
    epochs: int = 40
    # Batch size of the training DataLoader; also forwarded to
    # VigorDatasetTrain as shuffle_batch_size.
    batch_size: int = 30
    # Not read in this file; presumably consumed by train()/evaluate() - confirm.
    verbose: bool = True

    # When True main() calls dataset.shuffle(...) before training and after
    # every epoch, and turns off shuffling in the training DistributedSampler.
    custom_sampling: bool = True
    # When True main() prints a message and exits (no GPS sampling here).
    gps_sample: bool = False
    # When True main() builds the train-split eval loaders and computes a
    # similarity dict with calc_sim for dataset.shuffle().
    sim_sample: bool = True
    # Both forwarded unchanged to dataset.shuffle().
    neighbour_select: int = 64
    neighbour_range: int = 128

    # --- Evaluation ----------------------------------------------------
    # Batch size of every evaluation DataLoader.
    batch_size_eval: int = 30
    # Period (in epochs) of the similarity computation / checkpoint block.
    eval_every_n_epoch: int = 4
    # First epoch from which the similarity dict is computed and used;
    # defaults to the evaluation period.
    sim_sample_start_epoch: int = eval_every_n_epoch
    # Not read in this file; presumably consumed by the evaluation code - confirm.
    normalize_features: bool = True

    # --- Optimizer -----------------------------------------------------
    # When True, parameters whose name contains "bias" get weight_decay 0.0.
    # (Name is misspelled but kept because callers reference it.)
    decay_exclue_bias: bool = False
    # When True main() calls model.set_grad_checkpointing(True).
    grad_checkpointing: bool = False

    # --- Loss ----------------------------------------------------------
    # label_smoothing of the CrossEntropyLoss wrapped by InfoNCE.
    label_smoothing: float = 0.1

    # --- Learning rate -------------------------------------------------
    # Learning rate given to AdamW.
    lr: float = 0.001
    # "polynomial", "cosine" or "constant"; any other value disables the scheduler.
    scheduler: str = "cosine"
    # Warmup steps = len(train_dataloader) * warmup_epochs.
    warmup_epochs: int = 1
    # Final learning rate; only used by the polynomial scheduler.
    lr_end: float = 0.0001

    # --- Dataset -------------------------------------------------------
    # Forwarded to the train and eval datasets.
    same_area: bool = True

    # --- Augmentation --------------------------------------------------
    # Both forwarded to VigorDatasetTrain.
    prob_rotate: float = 0.75
    prob_flip: float = 0.5

    # --- Output --------------------------------------------------------
    # Root output directory; a run writes to <model_path>/<model>/<HHMMSS>.
    model_path: str = "./fuse/convnext"

    # When True main() evaluates on the test split before training.
    zero_shot: bool = False

    # --- Runtime -------------------------------------------------------
    # DataLoader worker count; 0 on Windows.
    num_workers: int = 0 if os.name == 'nt' else 4
    # Default device string; main() overwrites it with the rank's torch.device.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Both passed to setup_system().
    cudnn_benchmark: bool = True
    cudnn_deterministic: bool = False

    # --- Settings that previously lacked annotations ---------------------
    # They used to be plain class attributes (not dataclass fields). They are
    # declared last so the positional order of the fields above is unchanged.

    # Passed to setup_system() and to the training DistributedSampler.
    seed: int = 1
    # Not read in this file; presumably the gradient clip value used by
    # train() - confirm.
    clip_grad: float = 100.
    # Forwarded as data_folder to the train and eval datasets.
    data_folder: str = "data/playground.json"
    # Enters the ground image height computation and is passed to the transforms.
    ground_cutting: int = 0
    # Optional path loaded with torch.load() into the model before training.
    checkpoint_start: Optional[str] = None





# Module-level Configuration instance built with all defaults. main() creates
# its own Configuration and binds it to a local `config`, so this instance is
# not the one used for training; presumably kept for importers - confirm.
config = Configuration() 

def _broadcast_from_rank0(obj, device):
    """Return rank 0's ``obj`` on every rank (collective: all ranks must call)."""
    payload = [obj]
    # Only the entry supplied by the source rank is sent; other ranks' entries
    # are overwritten in place.
    dist.broadcast_object_list(payload, src=0, device=device)
    return payload[0]


def _build_eval_loader(config, split, img_type, transforms):
    """Create a VigorDatasetEval plus its unshuffled, rank-sharded DataLoader."""
    dataset = VigorDatasetEval(data_folder=config.data_folder,
                               split=split,
                               img_type=img_type,
                               same_area=config.same_area,
                               transforms=transforms)
    sampler = DistributedSampler(dataset, shuffle=False)
    loader = DataLoader(dataset,
                        batch_size=config.batch_size_eval,
                        num_workers=config.num_workers,
                        pin_memory=True,
                        sampler=sampler)
    return dataset, loader


def _build_optimizer(config, model):
    """Create AdamW, optionally exempting bias parameters from weight decay."""
    if not config.decay_exclue_bias:
        return torch.optim.AdamW(model.parameters(), lr=config.lr)

    # Match parameter names on the unwrapped module when the model is wrapped.
    model_without_ddp = model.module if hasattr(model, 'module') else model
    named_params = list(model_without_ddp.named_parameters())
    no_decay = ("bias", "LayerNorm.bias")

    def is_no_decay(name):
        return any(token in name for token in no_decay)

    param_groups = [
        {
            "params": [p for n, p in named_params if not is_no_decay(n)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in named_params if is_no_decay(n)],
            "weight_decay": 0.0,
        },
    ]
    return torch.optim.AdamW(param_groups, lr=config.lr)


def _build_scheduler(config, optimizer, train_steps, warmup_steps, verbose):
    """Create the LR scheduler named by ``config.scheduler`` (None if unknown)."""
    if config.scheduler == "polynomial":
        if verbose:
            print("\nScheduler: polynomial - max LR: {} - end LR: {}".format(config.lr, config.lr_end))
        return get_polynomial_decay_schedule_with_warmup(
            optimizer,
            num_training_steps=train_steps,
            lr_end=config.lr_end,
            power=1.5,
            num_warmup_steps=warmup_steps
        )

    if config.scheduler == "cosine":
        if verbose:
            print("\nScheduler: cosine - max LR: {}".format(config.lr))
        return get_cosine_schedule_with_warmup(
            optimizer,
            num_training_steps=train_steps,
            num_warmup_steps=warmup_steps
        )

    if config.scheduler == "constant":
        if verbose:
            print("\nScheduler: constant - max LR: {}".format(config.lr))
        return get_constant_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps
        )

    return None


def _compute_train_similarity(config, model, reference_dataloader, query_dataloader):
    """Run calc_sim on the train split and return its ``(r1, sim_dict)`` result."""
    return calc_sim(config=config,
                    model=model,
                    reference_dataloader=reference_dataloader,
                    query_dataloader=query_dataloader,
                    ranks=[1, 5, 10],
                    step_size=1000,
                    cleanup=True)


def _reshuffle_train_data(config, train_dataset, sim_dict):
    """Re-order the training dataset, then replicate rank 0's order everywhere."""
    train_dataset.shuffle(sim_dict,
                          neighbour_select=config.neighbour_select,
                          neighbour_range=config.neighbour_range)
    sync_samples_across_processes(train_dataset, config.device, config.rank)


def _save_weights(model, path):
    """Save the state dict of the unwrapped model to ``path``."""
    model_to_save = model.module if hasattr(model, 'module') else model
    torch.save(model_to_save.state_dict(), path)


def main():
    """Train the model with DistributedDataParallel over NCCL.

    Reads ``LOCAL_RANK`` from the environment, initializes the default
    process group, builds model, data loaders, loss, optimizer and scheduler
    from a fresh ``Configuration``, optionally runs a zero-shot evaluation,
    and then trains for ``config.epochs`` epochs. Rank 0 writes the log file,
    a copy of this script and the checkpoints below
    ``<model_path>/<model>/<HHMMSS>``.
    """
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)
    is_main = rank == 0

    config = Configuration()
    config.device = device
    config.rank = rank
    config.world_size = world_size
    config.gpu_ids = range(world_size)

    # ------------------------------------------------------------------
    # Output directory and logging (rank 0 only)
    # ------------------------------------------------------------------
    if is_main:
        model_path = "{}/{}/{}".format(config.model_path,
                                       config.model,
                                       time.strftime("%H%M%S"))
        os.makedirs(model_path, exist_ok=True)
        # Copy via the script's full path: the bare basename only resolves
        # when the working directory happens to be the script's directory.
        shutil.copyfile(os.path.abspath(__file__), "{}/train.py".format(model_path))
        sys.stdout = Logger(os.path.join(model_path, 'log.txt'))
    else:
        model_path = None

    setup_system(seed=config.seed,
                 cudnn_benchmark=config.cudnn_benchmark,
                 cudnn_deterministic=config.cudnn_deterministic)

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    if is_main:
        print("\nModel: {}".format(config.model))

    model = TimmModel(config.model,
                      pretrained=True,
                      img_size=config.img_size)

    data_config = model.get_config()
    if is_main:
        print(data_config)
    mean = data_config["mean"]
    std = data_config["std"]
    img_size = config.img_size

    image_size_sat = (img_size, img_size)
    new_width = img_size * 2
    new_hight = int(((1024 - 2 * config.ground_cutting) / 2048) * new_width)
    # NOTE(review): the ground size is (height, height), so new_width only
    # feeds the height computation - confirm a square ground image is intended.
    img_size_ground = (new_hight, new_hight)

    if config.grad_checkpointing:
        model.set_grad_checkpointing(True)

    if config.checkpoint_start is not None:
        if is_main:
            print("Start from:", config.checkpoint_start)
        model_state_dict = torch.load(config.checkpoint_start, map_location='cpu')
        model.load_state_dict(model_state_dict, strict=False)

    model = model.to(device)
    model = DDP(model, device_ids=[local_rank])

    if is_main:
        print("\nImage Size Sat:", image_size_sat)
        print("Image Size Ground:", img_size_ground)
        print("Mean: {}".format(mean))
        print("Std:  {}\n".format(std))

    # ------------------------------------------------------------------
    # Training data
    # ------------------------------------------------------------------
    sat_transforms_train, ground_transforms_train = get_transforms_train(image_size_sat,
                                                                         img_size_ground,
                                                                         mean=mean,
                                                                         std=std,
                                                                         ground_cutting=config.ground_cutting)

    train_dataset = VigorDatasetTrain(data_folder=config.data_folder,
                                      same_area=config.same_area,
                                      transforms_query=ground_transforms_train,
                                      transforms_reference=sat_transforms_train,
                                      prob_flip=config.prob_flip,
                                      prob_rotate=config.prob_rotate,
                                      shuffle_batch_size=config.batch_size)

    # With custom sampling the dataset orders itself, so the sampler must not
    # shuffle on top of it.
    train_sampler = DistributedSampler(
        train_dataset,
        shuffle=not config.custom_sampling,
        seed=config.seed
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        pin_memory=True,
        sampler=train_sampler
    )

    # ------------------------------------------------------------------
    # Evaluation data (test split)
    # ------------------------------------------------------------------
    sat_transforms_val, ground_transforms_val = get_transforms_val(image_size_sat,
                                                                   img_size_ground,
                                                                   mean=mean,
                                                                   std=std,
                                                                   ground_cutting=config.ground_cutting)

    reference_dataset_test, reference_dataloader_test = _build_eval_loader(
        config, "test", "reference", sat_transforms_val)
    query_dataset_test, query_dataloader_test = _build_eval_loader(
        config, "test", "query", ground_transforms_val)

    if is_main:
        print("Query Images Test:", len(query_dataset_test))
        print("Reference Images Test:", len(reference_dataset_test))

    # ------------------------------------------------------------------
    # GPS sampling (not available in this script)
    # ------------------------------------------------------------------
    if config.gps_sample:
        print("\nGPS Sample: True - GPS Sampling")
        # sys.exit instead of the interactive-only `exit` helper from `site`.
        sys.exit(0)
    elif is_main:
        print("\nGPS Sample: False - No GPS Sampling")

    # ------------------------------------------------------------------
    # Evaluation data on the train split (used to mine similar samples)
    # ------------------------------------------------------------------
    query_dataloader_train = None
    reference_dataloader_train = None
    if config.sim_sample:
        query_dataset_train, query_dataloader_train = _build_eval_loader(
            config, "train", "query", ground_transforms_val)
        reference_dataset_train, reference_dataloader_train = _build_eval_loader(
            config, "train", "reference", sat_transforms_val)

        if is_main:
            print("\nQuery Images Train:", len(query_dataset_train))
            print("Reference Images Train (unique):", len(reference_dataset_train))

    # ------------------------------------------------------------------
    # Loss, AMP scaler, optimizer, scheduler
    # ------------------------------------------------------------------
    loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
    loss_function = InfoNCE(
        loss_function=loss_fn,
        device=device,
        world_size=world_size
    )

    scaler = GradScaler(init_scale=2.**10) if config.mixed_precision else None

    optimizer = _build_optimizer(config, model)

    train_steps = len(train_dataloader) * config.epochs
    warmup_steps = len(train_dataloader) * config.warmup_epochs
    scheduler = _build_scheduler(config, optimizer, train_steps, warmup_steps, verbose=is_main)

    if is_main:
        print("Warmup Epochs: {} - Warmup Steps: {}".format(str(config.warmup_epochs).ljust(2), warmup_steps))
        print("Train Epochs:  {} - Train Steps:  {}".format(config.epochs, train_steps))

    # ------------------------------------------------------------------
    # Zero shot evaluation
    # ------------------------------------------------------------------
    sim_dict = None
    if config.zero_shot:
        if is_main:
            print("\n{}[{}]{}".format(30*"-", "Zero Shot", 30*"-"))

        evaluate(config=config,
                 model=model,
                 reference_dataloader=reference_dataloader_test,
                 query_dataloader=query_dataloader_test,
                 ranks=[1, 5, 10],
                 step_size=1000,
                 cleanup=True)

        if config.sim_sample:
            _, sim_dict = _compute_train_similarity(config,
                                                    model,
                                                    reference_dataloader_train,
                                                    query_dataloader_train)
            # Every rank continues with rank 0's similarity dict.
            sim_dict = _broadcast_from_rank0(sim_dict, device)

    # ------------------------------------------------------------------
    # Initial ordering of the training data
    # ------------------------------------------------------------------
    if config.custom_sampling:
        _reshuffle_train_data(config, train_dataloader.dataset, sim_dict)

    if is_main:
        print(f"训练数据加载器长度: {len(train_dataloader)}")
        print(f"世界大小: {world_size}")
        print(f"批次大小: {config.batch_size}")

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    for epoch in range(1, config.epochs + 1):
        if is_main:
            print("\n{}[Epoch: {}]{}".format(30*"-", epoch, 30*"-"))

        train_sampler.set_epoch(epoch)

        if config.custom_sampling and is_main:
            print(f"第{epoch}轮使用custom sampling")

        train_loss = train(config,
                           model,
                           dataloader=train_dataloader,
                           loss_function=loss_function,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           scaler=scaler)

        if is_main:
            print("Epoch: {}, Train Loss = {:.3f}, Lr = {:.6f}".format(epoch,
                                                                   train_loss,
                                                                   optimizer.param_groups[0]['lr']))

        # Periodic block: refresh the similarity dict and write a checkpoint.
        if epoch % config.eval_every_n_epoch == 0 or epoch == config.epochs:
            if is_main:
                print(f"开始第{epoch}轮评估...")

            dist.barrier()

            if config.sim_sample and epoch >= config.sim_sample_start_epoch:
                try:
                    r1_train, sim_dict = _compute_train_similarity(config,
                                                                   model,
                                                                   reference_dataloader_train,
                                                                   query_dataloader_train)
                    sim_dict = _broadcast_from_rank0(sim_dict, device)

                    if is_main:
                        print(f"第{epoch}轮评估完成，R1: {r1_train:.4f}")
                except Exception as e:
                    # Best effort: continue without a similarity dict.
                    # NOTE(review): if only some ranks fail here, the others
                    # may block inside a collective - confirm this is acceptable.
                    if is_main:
                        print(f"第{epoch}轮评估出错: {e}")
                    else:
                        # Previously silent on non-zero ranks.
                        print(f"[rank {rank}] 第{epoch}轮评估出错: {e}", file=sys.stderr)
                    sim_dict = None
            else:
                sim_dict = None
                if is_main and epoch < config.sim_sample_start_epoch:
                    print(f"第{epoch}轮使用random sampling")

            dist.barrier()

            if is_main:
                _save_weights(model, '{}/weights_e{}.pth'.format(model_path, epoch))
                print(f"第{epoch}轮模型已保存")

        # Re-order the training data for the next epoch.
        if config.custom_sampling:
            sim_dict_to_use = sim_dict if epoch >= config.sim_sample_start_epoch else None
            if is_main:
                print(f"开始第{epoch}轮数据重排...")
            try:
                _reshuffle_train_data(config, train_dataloader.dataset, sim_dict_to_use)
                if is_main:
                    print(f"第{epoch}轮数据重排完成")
            except Exception as e:
                if is_main:
                    print(f"第{epoch}轮数据重排出错: {e}")
                else:
                    # Previously silent on non-zero ranks.
                    print(f"[rank {rank}] 第{epoch}轮数据重排出错: {e}", file=sys.stderr)

            dist.barrier()

        if is_main:
            gc.collect()
            torch.cuda.empty_cache()

    if is_main:
        _save_weights(model, '{}/weights_end.pth'.format(model_path))

    dist.destroy_process_group()

if __name__ == '__main__':
    # Start training only when this file is executed as a script. main() reads
    # LOCAL_RANK from the environment and initializes an NCCL process group, so
    # the script has to be started by a launcher that sets that variable.
    main()