# train_unet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pathlib import Path
import time
import numpy as np
import math
import os


import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler



from unet_model_new import UNetPotentialField


# --- Configuration ---
# Input dataset file, loaded in full by PFGDataset on every process.
PT_DATA_FILE = Path("new_randonLOC_mixed_dataset.pt")
# Weights of the model with the lowest validation loss seen so far.
MODEL_SAVE_PATH = Path("new_randonLOC_mixed_dataset/unet_potential_field_256ch_best_new_ddp.pth")
# Full training state (model, optimizer, scheduler, early-stopping counters)
# written after every epoch and used to resume training.
LATEST_CHECKPOINT_PATH = Path("new_randonLOC_mixed_dataset/unet_potential_field_256ch_latest_ddp.pth")

NUM_EPOCHS = 100
# Passed to each process's DataLoader, i.e. this is the per-process batch size.
BATCH_SIZE = 4096 
LEARNING_RATE = 5e-4
# Fraction of the samples held out for validation.
VALIDATION_SPLIT = 0.2
# Seed of the generator used for the train/validation split.
RANDOM_SEED_DATASET = 42
# Forwarded to UNetPotentialField as `bilinear_upsample`.
USE_BILINEAR_UPSAMPLE = False

# ReduceLROnPlateau settings (patience, factor, min_lr, threshold).
SCHEDULER_PATIENCE = 5
SCHEDULER_FACTOR = 0.1
SCHEDULER_MIN_LR = 1e-7
SCHEDULER_THRESHOLD = 1e-5

# Training stops after this many consecutive epochs in which the validation
# loss fails to drop by more than EARLY_STOPPING_MIN_DELTA.
EARLY_STOPPING_PATIENCE = 10
EARLY_STOPPING_MIN_DELTA = 1e-4
INIT_CHN = 256 # Forwarded to UNetPotentialField as `init_channel`.

# Weight of the Sobel gradient loss added to the MSE loss.
LAMBDA_GRAD = 0.1


def ddp_setup():
    """Join the NCCL process group and select this process's CUDA device.

    The device index is taken from the ``LOCAL_RANK`` environment variable.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set.
    """
    dist.init_process_group(backend="nccl")
    local_gpu_index = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_gpu_index)

def ddp_cleanup():
    """Destroy the default process group if one is currently active.

    The call is a no-op when distributed support is unavailable or when no
    process group has been initialised (or it was already destroyed), so it
    is safe to invoke from error paths and ``finally`` blocks.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def count_parameters(model: nn.Module):
    """Print the total and trainable parameter counts of *model*.

    Counts are printed both as exact integers and in millions. Nothing is
    returned.
    """
    total_params = 0
    trainable_params = 0
    # Single pass: every parameter counts towards the total, and only those
    # that require gradients count as trainable.
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size

    total_millions = total_params / 1e6
    trainable_millions = trainable_params / 1e6
    print(f"  total_params: {total_params:,}")
    print(f"  trainable_params: {trainable_params:,}")
    print(f"  (total_params: {total_millions:.2f}M, trainable_params: {trainable_millions:.2f}M)\n")
# --- PARAM COUNT END ---

class PFGDataset(Dataset):
    """Map-style dataset held fully in CPU memory, loaded from one ``.pt`` file.

    The file must unpack into three tensors, in this order: spatial inputs
    (4-D), non-spatial inputs (2-D with exactly two features) and targets
    (4-D with exactly one channel). All three must share the same size along
    their first dimension.
    """

    def __init__(self, pt_file_path, transform=None):
        """Load and validate the tensors stored at *pt_file_path*.

        Args:
            pt_file_path: Path of the ``.pt`` file to load.
            transform: Stored on the instance as ``self.transform``; it is
                not applied by ``__getitem__``.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the loaded tensors fail any of the shape checks.
        """
        self.pt_file_path = Path(pt_file_path)
        self.transform = transform
        if not self.pt_file_path.exists():
            raise FileNotFoundError(f"data file not found: {self.pt_file_path}")

        # Only the main process (RANK 0, or RANK unset) prints log messages.
        is_main_process = int(os.environ.get("RANK", 0)) == 0

        if is_main_process:
            print(f"load file: {self.pt_file_path}...")

        # Every process loads its own full copy of the data into CPU memory.
        loaded = torch.load(self.pt_file_path, map_location='cpu')
        self.X_spatial, self.X_nonspatial, self.Y_target = loaded

        if is_main_process:
            print("load ok.")

        # Sanity checks
        if self.X_spatial.ndim != 4 or self.X_nonspatial.ndim != 2 or self.Y_target.ndim != 4:
            raise ValueError("loaded tensors have incorrect dimensions.")

        sample_counts = (self.X_spatial.shape[0], self.X_nonspatial.shape[0], self.Y_target.shape[0])
        if sample_counts[0] != sample_counts[1] or sample_counts[1] != sample_counts[2]:
            raise ValueError("the number of samples in the loaded tensors does not match.")

        target_channels = self.Y_target.shape[1]
        if target_channels != 1:
            raise ValueError(f"target tensor should have 1 channel, but got {target_channels}")

        feature_count = self.X_nonspatial.shape[1]
        if feature_count != 2:
            raise ValueError(f"expected 2 non-spatial features, but got {feature_count} from {self.pt_file_path}")

        # The spatial tensor is known to be 4-D at this point.
        self.n_samples, self.spatial_channels, self.height, self.width = self.X_spatial.shape
        self.non_spatial_features = feature_count

        if is_main_process:
            print(f"  sample count: {self.n_samples}")
            print(f"  spatial feature channels: {self.spatial_channels}")
            print(f"  non-spatial features: {self.non_spatial_features}")
            print(f"  image height x width: {self.height}x{self.width}")

    def __len__(self):
        """Return the number of samples."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return ``(spatial, non_spatial, target)`` for sample *idx*."""
        return self.X_spatial[idx], self.X_nonspatial[idx], self.Y_target[idx]

def _sobel_gradients(image):
    """Return the ``(grad_x, grad_y)`` Sobel responses of *image*, per channel."""
    channels = image.shape[1]
    # Build the kernel directly with the input's dtype and device so that
    # non-float32 inputs work and no host-to-device copy happens per call.
    sobel_y = torch.tensor(
        [[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]],
        dtype=image.dtype,
        device=image.device,
    ).view(1, 1, 3, 3)
    sobel_x = sobel_y.transpose(2, 3)
    # One copy of the kernel per channel; with groups=channels every channel
    # is filtered independently (for a single channel this is a plain conv).
    kernel_x = sobel_x.repeat(channels, 1, 1, 1)
    kernel_y = sobel_y.repeat(channels, 1, 1, 1)
    grad_x = F.conv2d(image, kernel_x, padding=1, groups=channels)
    grad_y = F.conv2d(image, kernel_y, padding=1, groups=channels)
    return grad_x, grad_y


def gradient_loss(prediction, target):
    """Return the L1 distance between the Sobel gradients of two image batches.

    Both inputs are filtered with 3x3 Sobel kernels (zero padding of 1, so the
    spatial size is preserved). The result is the sum of the mean absolute
    differences of the x-gradients and of the y-gradients.

    Args:
        prediction: 4-D tensor laid out as (batch, channels, height, width).
        target: Tensor of the same shape as *prediction*.

    Returns:
        A scalar tensor holding the gradient loss.
    """
    pred_grad_x, pred_grad_y = _sobel_gradients(prediction)
    target_grad_x, target_grad_y = _sobel_gradients(target)
    return F.l1_loss(pred_grad_x, target_grad_x) + F.l1_loss(pred_grad_y, target_grad_y)


def train_model(rank, world_size, model, train_loader, val_loader, optimizer, scheduler, num_epochs, start_epoch,
                initial_best_val_loss_saving, initial_best_val_loss_early_stop, initial_epochs_no_improve):
    """Run the distributed training loop with validation, LR scheduling and early stopping.

    Every process trains on its shard of ``train_loader`` and evaluates its
    shard of ``val_loader``. The validation loss is averaged over all
    processes, and every process feeds the same value to ``scheduler`` so the
    learning rate stays identical across replicas. Rank 0 additionally prints
    progress, saves the best weights to ``MODEL_SAVE_PATH``, writes a resume
    checkpoint to ``LATEST_CHECKPOINT_PATH`` after each epoch, and decides on
    early stopping; that decision is broadcast to all other ranks.

    Args:
        rank: Global rank of this process; rank 0 is the logging/saving process.
        world_size: Number of processes (currently unused, kept for callers).
        model: Model to train, normally wrapped in DistributedDataParallel.
        train_loader: Training DataLoader whose sampler provides ``set_epoch``.
        val_loader: Validation DataLoader.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Scheduler stepped with the validation loss each epoch.
        num_epochs: Total number of epochs (exclusive upper bound).
        start_epoch: Zero-based epoch index to start or resume from.
        initial_best_val_loss_saving: Best validation loss seen so far for
            best-weights saving.
        initial_best_val_loss_early_stop: Best validation loss seen so far for
            the early-stopping comparison.
        initial_epochs_no_improve: Early-stopping counter to resume from.
    """
    best_val_loss_for_saving = initial_best_val_loss_saving
    best_val_loss_for_early_stop = initial_best_val_loss_early_stop
    epochs_no_improve = initial_epochs_no_improve

    is_main = rank == 0
    # Use the device the model actually lives on instead of assuming that the
    # global rank equals the local CUDA device index.
    device = next(model.parameters()).device
    # Unwrap DDP (if present) so saved state dicts carry no "module." prefix.
    raw_model = getattr(model, "module", model)

    mse_criterion = nn.MSELoss()
    log_every = len(train_loader) // 4 + 1
    has_val_data = bool(val_loader.dataset) and len(val_loader.dataset) > 0

    if is_main:
        count_parameters(model)
        if start_epoch > 0:
            print(f"  recover from checkpoint:  = {best_val_loss_for_saving:.4f}")
            print(f"best  {best_val_loss_for_early_stop:.4f},  = {epochs_no_improve}")
        print(f"learning rate scheduler: ReduceLROnPlateau (patience={SCHEDULER_PATIENCE}, factor={SCHEDULER_FACTOR}, threshold={SCHEDULER_THRESHOLD})")
        print(f"early stopping mechanism: patience={EARLY_STOPPING_PATIENCE}, min_delta={EARLY_STOPPING_MIN_DELTA}")
        print(f"gradient loss weight (LAMBDA_GRAD): {LAMBDA_GRAD}")

    for epoch in range(start_epoch, num_epochs):
        current_epoch_display = epoch + 1
        epoch_start_time = time.time()

        # Re-seed the distributed sampler so each epoch uses a new shuffle.
        train_loader.sampler.set_epoch(epoch)

        # ---- training ----
        model.train()
        running_loss = 0.0

        for i, (spatial_inputs, non_spatial_inputs, targets) in enumerate(train_loader):
            spatial_inputs = spatial_inputs.to(device)
            non_spatial_inputs = non_spatial_inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(spatial_inputs, non_spatial_inputs)

            main_loss = mse_criterion(outputs, targets)
            grad_loss = gradient_loss(outputs, targets)
            loss = main_loss + LAMBDA_GRAD * grad_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * spatial_inputs.size(0)

            if is_main and (i + 1) % log_every == 0:
                print(f"  round [{current_epoch_display}/{num_epochs}], step [{i+1}/{len(train_loader)}], batch loss: {loss.item():.4f}")

        # Local (this process's shard) average; only rank 0 reports it.
        epoch_train_loss = running_loss / len(train_loader.sampler)

        # ---- validation ----
        model.eval()
        current_val_loss = float('nan')

        if has_val_data:
            val_loss_sum = 0.0
            val_sample_count = 0
            with torch.no_grad():
                for spatial_inputs_val, non_spatial_inputs_val, targets_val in val_loader:
                    spatial_inputs_val = spatial_inputs_val.to(device)
                    non_spatial_inputs_val = non_spatial_inputs_val.to(device)
                    targets_val = targets_val.to(device)

                    outputs_val = model(spatial_inputs_val, non_spatial_inputs_val)

                    main_loss_val = mse_criterion(outputs_val, targets_val)
                    grad_loss_val = gradient_loss(outputs_val, targets_val)
                    loss_val = main_loss_val + LAMBDA_GRAD * grad_loss_val

                    batch_size = spatial_inputs_val.size(0)
                    val_loss_sum += loss_val.item() * batch_size
                    val_sample_count += batch_size

            # Reduce both the loss sum and the number of samples actually
            # evaluated: the distributed sampler may pad the shards, so the
            # dataset length is not a reliable denominator.
            totals = torch.tensor([val_loss_sum, float(val_sample_count)],
                                  dtype=torch.float64, device=device)
            dist.all_reduce(totals, op=dist.ReduceOp.SUM)
            total_loss, total_count = totals.tolist()
            if total_count > 0:
                # Every rank now holds the same global validation loss.
                current_val_loss = total_loss / total_count

        val_loss_valid = not math.isnan(current_val_loss)

        if is_main:
            epoch_duration = time.time() - epoch_start_time
            current_lr = optimizer.param_groups[0]['lr']

            print(f"round [{current_epoch_display}/{num_epochs}] in {epoch_duration:.2f}s: "
                  f"training loss: {epoch_train_loss:.4f}, validation loss: {current_val_loss:.4f}, current learning rate: {current_lr:.1e}")

        # The scheduler must be stepped on EVERY rank: each process owns its
        # own optimizer, so stepping only on rank 0 would leave the other
        # replicas on the old learning rate and let their weights diverge.
        if val_loss_valid:
            scheduler.step(current_val_loss)

        # Early-stop flag; rank 0 decides, everyone else receives it below.
        stop_signal = torch.zeros(1, device=device)

        if is_main:
            if val_loss_valid:
                if current_val_loss < best_val_loss_for_saving:
                    best_val_loss_for_saving = current_val_loss

                    # Save the unwrapped model's weights.
                    torch.save(raw_model.state_dict(), MODEL_SAVE_PATH)

                    print(f"  new best model weights saved to {MODEL_SAVE_PATH} (validation loss: {best_val_loss_for_saving:.4f})")

                if current_val_loss < best_val_loss_for_early_stop - EARLY_STOPPING_MIN_DELTA:
                    best_val_loss_for_early_stop = current_val_loss
                    epochs_no_improve = 0
                    print(f"  validation loss improved, resetting early stopping counter.")
                else:
                    epochs_no_improve += 1
                    print(f"  validation loss not significantly improved, early stopping counter: {epochs_no_improve}/{EARLY_STOPPING_PATIENCE}")

                if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
                    print(f"\nearly stopping triggered: validation loss not improved in {EARLY_STOPPING_PATIENCE} rounds.")
                    stop_signal.fill_(1.0)
            else:
                print("  validation set is empty or validation loss is NaN, skipping learning rate scheduling, model saving and early stopping check.")

            checkpoint_state = {
                'epoch': epoch,
                'model_state_dict': raw_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss_for_saving': best_val_loss_for_saving,
                'best_val_loss_for_early_stop': best_val_loss_for_early_stop,
                'epochs_no_improve': epochs_no_improve,
            }
            torch.save(checkpoint_state, LATEST_CHECKPOINT_PATH)
            print(f"  latest checkpoint saved to {LATEST_CHECKPOINT_PATH} (epoch {current_epoch_display} completed)")

        # All ranks must leave the loop together, otherwise the remaining
        # ones would block in the next collective call.
        dist.broadcast(stop_signal, src=0)
        if stop_signal.item() == 1:
            break

    if is_main:
        print("\nTraining completed.")
        if not math.isnan(best_val_loss_for_saving):
            print(f"Best validation loss (for model weight saving): {best_val_loss_for_saving:.4f}")
        print(f"Best model weights saved to {MODEL_SAVE_PATH}")
        print(f"Latest training state checkpoint saved to {LATEST_CHECKPOINT_PATH}")


def main():
    """Set up distributed training, build data/model/optimizer, and train.

    Reads ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment.
    The model is placed on the CUDA device indexed by ``LOCAL_RANK``. If
    ``LATEST_CHECKPOINT_PATH`` exists, training resumes from it; if it cannot
    be read or applied, training starts from epoch 0. The process group is
    always destroyed on exit, including when an exception propagates.
    """
    ddp_setup()
    try:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        # The CUDA device index is the LOCAL rank; the global rank only equals
        # it when all processes run on a single node.
        device = int(os.environ["LOCAL_RANK"])

        if rank == 0:
            print(f"use {world_size}  gpus for training.")
            MODEL_SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)

        try:
            full_dataset = PFGDataset(pt_file_path=PT_DATA_FILE)
        except Exception as e:
            if rank == 0:
                print(f"error: {e}")
            return

        num_samples = len(full_dataset)
        val_size = int(VALIDATION_SPLIT * num_samples)
        train_size = num_samples - val_size

        # A fixed seed makes every process compute the same train/val split.
        train_dataset, val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(RANDOM_SEED_DATASET)
        )

        if rank == 0:
            print(f"Training set size: {len(train_dataset)}")
            print(f"Validation set size: {len(val_dataset)}")

        # Each process iterates over its own shard of both subsets.
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)

        num_workers_dl = 6  # DataLoader worker processes per training process
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers_dl, pin_memory=True, sampler=train_sampler)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers_dl, pin_memory=True, sampler=val_sampler)

        n_spatial_channels = full_dataset.spatial_channels
        n_non_spatial_features = full_dataset.non_spatial_features

        model = UNetPotentialField(
            n_spatial_channels=n_spatial_channels,
            n_non_spatial_features=n_non_spatial_features,
            n_classes_out=1,
            bilinear_upsample=USE_BILINEAR_UPSAMPLE, init_channel=INIT_CHN
        ).to(device)

        model = DDP(model, device_ids=[device])

        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
        # NOTE(review): `verbose` is believed to be deprecated in recent
        # PyTorch releases — confirm against the installed version.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=SCHEDULER_FACTOR, patience=SCHEDULER_PATIENCE, verbose=(rank==0), min_lr=SCHEDULER_MIN_LR, threshold=SCHEDULER_THRESHOLD)

        start_epoch = 0
        initial_best_val_loss_saving = float('inf')
        initial_best_val_loss_early_stop = float('inf')
        initial_epochs_no_improve = 0

        if os.path.exists(LATEST_CHECKPOINT_PATH):
            if rank == 0:
                print(f"load : {LATEST_CHECKPOINT_PATH}")

            try:
                # Reading the file is inside the try block so that a truncated
                # or corrupt checkpoint falls back to a fresh start as well.
                checkpoint = torch.load(LATEST_CHECKPOINT_PATH, map_location=f'cuda:{device}')
                model.module.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                start_epoch = checkpoint['epoch'] + 1
                if rank == 0:
                    # Best-loss bookkeeping is only consulted on rank 0.
                    initial_best_val_loss_saving = checkpoint['best_val_loss_for_saving']
                    initial_best_val_loss_early_stop = checkpoint.get('best_val_loss_for_early_stop', float('inf'))
                    initial_epochs_no_improve = checkpoint.get('epochs_no_improve', 0)
                    print(f"Checkpoint loaded successfully. Continuing training from epoch {start_epoch + 1}.")
            except Exception as e:
                if rank == 0:
                    print(f"Failed to load checkpoint: {e}. Starting training from scratch.")
                # NOTE: state that was already applied before the failure
                # (e.g. model weights) is not rolled back.
                start_epoch = 0
                initial_best_val_loss_saving = float('inf')
                initial_best_val_loss_early_stop = float('inf')
                initial_epochs_no_improve = 0
        elif rank == 0:
            print(f"Checkpoint file not found: {LATEST_CHECKPOINT_PATH}. Starting training from scratch.")

        train_model(rank, world_size, model, train_loader, val_loader, optimizer, scheduler, NUM_EPOCHS,
                    start_epoch, initial_best_val_loss_saving, initial_best_val_loss_early_stop, initial_epochs_no_improve)
    finally:
        # Always tear the process group down, even when training raised.
        ddp_cleanup()


# Script entry point. The code reads RANK, LOCAL_RANK and WORLD_SIZE from the
# environment, so the launcher must export them for every process
# (presumably torchrun or an equivalent launcher — confirm against the run scripts).
if __name__ == "__main__":
    main()