# RLDF4CO_v4/train_diffusion_new_2gpu_new.py

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf
import os
import time

from data_loader_new import TSPConditionalSuffixDataset, custom_collate_fn
from diffusion_model_new import ConditionalTSPSuffixDiffusionModel
from discrete_diffusion_new_new_new import AdjacencyMatrixDiffusion
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import GradScaler, autocast

@torch.no_grad()
def validate_model(model, diffusion_handler, valid_dataloader, device):
    """
    Compute the average diffusion training loss over the validation data.

    Each process evaluates the batches produced by its own
    ``valid_dataloader``. When a process group is initialized, the local loss
    sums and batch counts are summed across all processes with
    ``dist.all_reduce``; otherwise the local values are used as-is, so the
    function also works in a single, non-distributed process.

    Timesteps are drawn at random for every batch, so the returned value is a
    stochastic estimate and differs slightly from call to call.

    Args:
        model: Network forwarded to ``diffusion_handler.training_loss``. It is
            switched to eval mode and left there; the caller is responsible
            for calling ``model.train()`` again.
        diffusion_handler: Object exposing ``num_timesteps`` and
            ``training_loss(...)``.
        valid_dataloader: Iterable of batch dicts with the keys
            ``instance_locs``, ``prefix_nodes``, ``prefix_lengths``,
            ``target_adj_matrix`` and ``node_prefix_state``.
        device: Device the batch tensors are moved to.

    Returns:
        The mean per-batch loss as a float, or ``float('inf')`` when no batch
        was evaluated at all.
    """
    model.eval()
    total_valid_loss = 0.0
    num_batches = 0

    for batch_data in valid_dataloader:
        instance_locs = batch_data["instance_locs"].to(device)
        prefix_nodes = batch_data["prefix_nodes"].to(device)
        prefix_lengths = batch_data["prefix_lengths"].to(device)
        x_0_adj_matrix = batch_data["target_adj_matrix"].to(device)
        node_prefix_state = batch_data["node_prefix_state"].to(device)

        # One random timestep in [1, num_timesteps] per sample of the batch.
        t = torch.randint(1, diffusion_handler.num_timesteps + 1, (instance_locs.size(0),), device=device).long()

        loss = diffusion_handler.training_loss(
            model, x_0_adj_matrix, t, instance_locs,
            prefix_nodes, prefix_lengths, node_prefix_state
        )
        total_valid_loss += loss.item()
        num_batches += 1

    # Do NOT return early when this process saw no batches: every rank has to
    # take part in the all_reduce below, otherwise the remaining ranks would
    # block forever waiting for the missing participant.
    totals = torch.tensor([total_valid_loss, num_batches], dtype=torch.float64, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)

    global_total_loss, global_num_batches = totals.tolist()

    if global_num_batches > 0:
        return global_total_loss / global_num_batches
    return float('inf')


def ddp_setup():
    """
    Join the NCCL process group and bind this process to its GPU.

    The GPU index is read from the ``LOCAL_RANK`` environment variable, which
    therefore has to be set by whatever launches the script.

    Returns:
        A ``(device, local_rank)`` tuple: the CUDA device selected for this
        process and the integer index it was built from.
    """
    dist.init_process_group(backend="nccl")
    gpu_index = int(os.environ["LOCAL_RANK"])
    # Make this GPU the default CUDA device for the current process.
    torch.cuda.set_device(gpu_index)
    selected_device = torch.device("cuda", gpu_index)
    return selected_device, gpu_index


def load_weights_for_transfer_learning(new_model, pretrained_ckpt_path, device):
    """
    Warm-start ``new_model`` from a checkpoint that may belong to a model of a
    different size.

    Only checkpoint entries whose name exists in ``new_model`` AND whose shape
    matches the corresponding entry are copied; every other entry of
    ``new_model`` keeps its current value. A short report is printed.

    Args:
        new_model: The freshly initialized model to receive the weights.
        pretrained_ckpt_path: Path to a checkpoint holding a state dict.
        device: ``map_location`` used when reading the checkpoint.

    Returns:
        ``new_model`` itself (modified in place). It is returned unchanged when
        the checkpoint file does not exist.
    """
    if not os.path.exists(pretrained_ckpt_path):
        print(f"Pretrained checkpoint not found at {pretrained_ckpt_path}. Starting from scratch.")
        return new_model

    print(f"Loading pretrained weights from {pretrained_ckpt_path} for transfer learning...")
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    source_state = torch.load(pretrained_ckpt_path, map_location=device)
    target_state = new_model.state_dict()

    # Collect the entries that can be copied over verbatim.
    transferable = {}
    for name, tensor in source_state.items():
        if name not in target_state:
            continue
        if tensor.shape != target_state[name].shape:
            continue
        transferable[name] = tensor

    # Overlay the transferable entries on the model's own state and load the
    # merged result back into the model.
    target_state.update(transferable)
    new_model.load_state_dict(target_state, strict=False)

    separator = "-" * 60
    print(separator)
    print("Transfer Learning Report:")
    print(f"Successfully transferred {len(transferable)} tensors from the pretrained model.")

    # Entries of the new model that did not receive a pretrained value.
    skipped = [name for name in target_state if name not in transferable]
    if not skipped:
        print("All layers from the new model were successfully initialized from the pretrained model!")
    else:
        print(f"Skipped or left uninitialized {len(skipped)} tensors (this is expected for size-dependent layers).")
    print(separator)

    return new_model


def run_training_stage(cfg: DictConfig, stage_name: str, prefix_k_options: list, epochs_for_stage: int, device, local_rank, checkpoint_to_load: str = None):
    """
    Executes a single stage of the training curriculum.

    Builds the data loaders and a fresh model for the stage, optionally
    warm-starts the model from ``checkpoint_to_load``, wraps it in DDP and
    trains it with mixed precision for up to ``epochs_for_stage`` epochs.
    When a validation set is configured, the model is validated after every
    epoch; rank 0 keeps the best checkpoint and decides on early stopping, and
    that decision is broadcast so that all ranks leave the loop together.

    This function issues collective operations and therefore has to be called
    by every process of the group.

    Args:
        cfg: Configuration; the ``train``, ``data``, ``model`` and
            ``diffusion`` sections are read.
        stage_name: Label used in log messages and checkpoint file names.
        prefix_k_options: Non-empty list of prefix lengths handed to the
            datasets of this stage.
        epochs_for_stage: Maximum number of epochs to train.
        device: Device the model and the batches are placed on.
        local_rank: GPU index passed to DDP as ``device_ids``.
        checkpoint_to_load: Optional path of a state dict to warm-start from.

    Returns:
        The path ``<ckpt_dir>/best_model_<stage_name>.pth``. It holds the
        weights with the lowest validation loss seen in this stage, or the
        final weights of the stage when no such checkpoint was written.
    """
    is_main_process = (dist.get_rank() == 0)

    if is_main_process:
        print(f"\n===== Starting Curriculum Stage: {stage_name} =====")
        print(f"===== Epochs: {epochs_for_stage}, Prefix K Range: {prefix_k_options[0]}-{prefix_k_options[-1]} =====")
        print(f"prefix in this stage is {prefix_k_options}")
        if checkpoint_to_load:
            print(f"===== Loading checkpoint from: {checkpoint_to_load} =====")

    time.sleep(2)  # Pause for readability

    ckpt_dir = cfg.train.get("ckpt_dir", "./ckpt_difusco_style")
    os.makedirs(ckpt_dir, exist_ok=True)
    best_model_path_stage = os.path.join(ckpt_dir, f"best_model_{stage_name}.pth")

    prefix_sampling_strategy = cfg.data.get('prefix_sampling_strategy', 'continuous_from_start')
    num_workers = cfg.train.get("num_workers", 4)

    # Setup Datasets for the current stage.
    # cfg.train.batch_size is used as the batch size of EACH process.
    train_dataset = TSPConditionalSuffixDataset(
        npz_file_path=cfg.data.train_path,
        prefix_k_options=prefix_k_options,
        prefix_sampling_strategy=prefix_sampling_strategy
    )
    train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, batch_size=cfg.train.batch_size, shuffle=False,
        sampler=train_sampler,
        num_workers=num_workers, collate_fn=custom_collate_fn
    )

    valid_dataloader = None
    if cfg.data.get("valid_path"):
        valid_dataset = TSPConditionalSuffixDataset(
            npz_file_path=cfg.data.valid_path,
            prefix_k_options=prefix_k_options,
            prefix_sampling_strategy=prefix_sampling_strategy
        )
        valid_sampler = DistributedSampler(valid_dataset, shuffle=False)
        valid_dataloader = DataLoader(
            valid_dataset,
            batch_size=cfg.train.batch_size,
            sampler=valid_sampler, shuffle=False,
            num_workers=num_workers,
            collate_fn=custom_collate_fn
        )

    # Initialize Model
    model = ConditionalTSPSuffixDiffusionModel(
        num_nodes=cfg.model.num_nodes, node_coord_dim=cfg.model.node_coord_dim,
        pos_embed_num_feats=cfg.model.pos_embed_num_feats, node_embed_dim=cfg.model.node_embed_dim,
        prefix_node_embed_dim=cfg.model.node_embed_dim,
        prefix_enc_hidden_dim=cfg.model.prefix_enc_hidden_dim, prefix_cond_dim=cfg.model.prefix_cond_dim,
        gnn_n_layers=cfg.model.gnn_n_layers, gnn_hidden_dim=cfg.model.gnn_hidden_dim,
        gnn_aggregation=cfg.model.gnn_aggregation, gnn_norm=cfg.model.gnn_norm,
        gnn_learn_norm=cfg.model.gnn_learn_norm, gnn_gated=cfg.model.gnn_gated,
        time_embed_dim=cfg.model.time_embed_dim
    ).to(device)

    if checkpoint_to_load:
        if os.path.exists(checkpoint_to_load):
            try:
                # The function prints its own detailed transfer report.
                model = load_weights_for_transfer_learning(model, checkpoint_to_load, device)
            except Exception as e:
                if is_main_process:
                    print(f"Could not load checkpoint: {e}. Starting from scratch.")
            else:
                if is_main_process:
                    print(f"Successfully loaded single-GPU checkpoint into base model from {checkpoint_to_load}")
        else:
            if is_main_process:
                print(f"Checkpoint file not found at {checkpoint_to_load}. Starting from scratch.")

    dist.barrier()  # Make sure every process has finished loading

    model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)

    diffusion_handler = AdjacencyMatrixDiffusion(
        num_nodes=cfg.model.num_nodes, num_timesteps=cfg.diffusion.num_timesteps,
        schedule_type=cfg.diffusion.schedule_type, device=device
    )

    optimizer = optim.Adam(model.parameters(), lr=cfg.train.learning_rate)
    scaler = GradScaler()

    best_valid_loss = float('inf')
    epochs_no_improve = 0
    early_stopping_patience = cfg.train.get("early_stopping_patience", 10)
    min_delta = cfg.train.get("early_stopping_min_delta", 0.00001)
    best_checkpoint_saved = False  # only ever set on rank 0

    # Main training loop for the stage
    for epoch in range(epochs_for_stage):
        train_sampler.set_epoch(epoch)  # reshuffle differently every epoch
        model.train()
        total_train_loss = 0
        num_train_batches = 0

        for batch_idx, batch_data in enumerate(train_dataloader):
            optimizer.zero_grad()
            instance_locs = batch_data["instance_locs"].to(device)
            prefix_nodes = batch_data["prefix_nodes"].to(device)
            prefix_lengths = batch_data["prefix_lengths"].to(device)
            x_0_adj_matrix = batch_data["target_adj_matrix"].to(device)
            node_prefix_state = batch_data["node_prefix_state"].to(device)

            # One random timestep in [1, num_timesteps] per sample.
            t = torch.randint(1, diffusion_handler.num_timesteps + 1, (instance_locs.size(0),), device=device).long()

            with autocast():
                loss = diffusion_handler.training_loss(
                    model, x_0_adj_matrix, t, instance_locs,
                    prefix_nodes, prefix_lengths, node_prefix_state
                )

            # Scaled backward pass + optimizer step for mixed precision.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_train_loss += loss.item()
            num_train_batches += 1

            if is_main_process and batch_idx % cfg.train.log_interval == 0 and batch_idx > 0:
                print(f"Stage '{stage_name}', Epoch {epoch+1}/{epochs_for_stage}, Batch {batch_idx}/{len(train_dataloader)}, Avg Train Loss: {(total_train_loss/num_train_batches):.5f}")

        if is_main_process:
            # Guard against an empty loader instead of dividing by zero.
            avg_train_loss = total_train_loss / num_train_batches if num_train_batches else float('nan')
            print(f"Stage '{stage_name}', Epoch {epoch+1} completed. Average Training Loss: {avg_train_loss:.5f}")

        # NOTE: truthiness of a DataLoader is its length, so an empty
        # validation loader skips validation just like a missing one.
        if valid_dataloader:
            current_valid_loss = validate_model(model, diffusion_handler, valid_dataloader, device)

            stop_now = False
            if is_main_process:
                print(f"Stage '{stage_name}', Epoch {epoch+1}: Validation Loss: {current_valid_loss:.5f}")
                if current_valid_loss < best_valid_loss - min_delta:
                    best_valid_loss = current_valid_loss
                    epochs_no_improve = 0
                    torch.save(model.module.state_dict(), best_model_path_stage)
                    best_checkpoint_saved = True
                    print(f"Validation loss improved. Saved best model for this stage to {best_model_path_stage}")
                else:
                    epochs_no_improve += 1

                if epochs_no_improve >= early_stopping_patience:
                    print(f"Early stopping triggered for stage '{stage_name}'.")
                    stop_now = True

            # Rank 0 owns the early-stopping decision. Broadcast it so that
            # every rank leaves the loop in the same epoch; if only rank 0
            # stopped, the other ranks would hang in their next collective.
            stop_flag = torch.tensor([1 if stop_now else 0], device=device)
            dist.broadcast(stop_flag, src=0)
            if stop_flag.item():
                break

        if is_main_process and (epoch + 1) % cfg.train.save_interval == 0:
            periodic_save_path = os.path.join(ckpt_dir, f"{stage_name}_epoch_{epoch+1}.pth")
            torch.save(model.module.state_dict(), periodic_save_path)
            print(f"Saved model checkpoint (periodic) at epoch {epoch+1} to {periodic_save_path}")

    if is_main_process:
        if not best_checkpoint_saved:
            # Without this the returned path would not exist (or would point to
            # a stale file from an earlier run) and the next stage would not
            # continue from the weights trained here.
            torch.save(model.module.state_dict(), best_model_path_stage)
            print(f"No best checkpoint was selected in stage '{stage_name}'. Saved final weights to {best_model_path_stage}")
        print(f"Finished stage '{stage_name}'. Best validation loss for this stage: {best_valid_loss:.5f}")

    # Wait until rank 0 has finished writing before any rank goes on to read
    # the checkpoint in the next stage.
    dist.barrier()
    return best_model_path_stage

def ddp_cleanup():
    """
    Tear down the default process group, if one exists.

    Doing nothing when no group was initialized makes this safe to call
    unconditionally from cleanup paths (e.g. a ``finally`` block that runs
    after a failed setup).
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

    
def train_with_curriculum(cfg: DictConfig):
    """
    Main function to orchestrate the curriculum learning process.

    Sets up DDP, then runs the curriculum stages one after another. Each stage
    is warm-started from the checkpoint path returned by the previous stage;
    the first stage starts from a fixed pretrained checkpoint path. After the
    last stage, rank 0 renames that stage's best checkpoint to
    ``Final_0_20_best_model_checkpoint.pth``. The process group is always torn
    down at the end, also when a stage raises.

    Args:
        cfg: Configuration forwarded to every stage; ``cfg.model.num_nodes``
            is additionally read here to build the stage 3 prefix range.
    """
    try:
        # Setup happens inside the try block so that a partially completed
        # setup is still cleaned up by the finally clause below.
        device, local_rank = ddp_setup()
        print(f"[Rank {dist.get_rank()}] DDP setup complete. Using device: {device}")

        # (stage name, prefix length options, number of epochs)
        curriculum = [
            # Stage 1: Easy task - long prefixes
            ("stage1_k0_100", list(range(50, 100)), 10),
            # Stage 2: Medium task - short prefixes
            ("stage2_k1_50", list(range(20, 50)), 10),
            # Stage 3: Full task - all prefixes
            ("stage3_k1_500_final", list(range(1, cfg.model.num_nodes)), 10),
            # Stage 4: Front task - [1-30]
            ("stage4_k1_30_last", list(range(1, 30)), 10),
            # Stage 5: prefixes in [1-20)
            ("stage5_k1_20_last", list(range(1, 20)), 10),
        ]

        # Checkpoint the first stage is warm-started from.
        latest_ckpt = "./ckpt_tsp_difusco_style_tsp100/stage1_k0_20_epoch_50.pth"

        for stage_name, k_options, stage_epochs in curriculum:
            latest_ckpt = run_training_stage(
                cfg=cfg,
                stage_name=stage_name,
                prefix_k_options=k_options,
                epochs_for_stage=stage_epochs,
                device=device,
                local_rank=local_rank,
                checkpoint_to_load=latest_ckpt
            )

        if dist.get_rank() == 0:
            print("\nCurriculum training finished!")
            # Publish the checkpoint of the LAST stage under the generic final
            # name (not the stage 1 checkpoint).
            final_generic_path = os.path.join(os.path.dirname(latest_ckpt), "Final_0_20_best_model_checkpoint.pth")
            if os.path.exists(latest_ckpt):
                os.rename(latest_ckpt, final_generic_path)
                print(f"Renamed final model to: {final_generic_path}")

    finally:
        if dist.is_initialized():
            rank = dist.get_rank()
            ddp_cleanup()
            print(f"[Rank {rank}] DDP resources cleaned up.")
        else:
            # Reached when setup failed before the process group existed.
            print("DDP was not initialized, no cleanup needed.")

if __name__ == "__main__":
    # Script entry point: load the YAML configuration and start the
    # curriculum training. Intended to be started once per process by a
    # launcher that sets LOCAL_RANK (read in ddp_setup).
    config_path = "tsp500_config.yaml"
    try:
        config = OmegaConf.load(config_path)
    except FileNotFoundError:
        print(f"ERROR: Configuration file '{config_path}' not found.")
        # Exit with a non-zero status so the launcher sees the failure.
        raise SystemExit(1)
    except Exception as e:
        print(f"Error loading configuration: {e}")
        raise SystemExit(1)
    print("Loaded configuration from:", config_path)

    train_with_curriculum(config)
