import torch
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.autograd import Function
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Cartesian, Distance
from models import GenGeomAutoencoder, MMDLoss
from datatools import WindTerrainDataset, compute_dataset_stats, norm_data
from box import Box
import yaml
import os
from tqdm import tqdm
import string
import random
# Use the 'file_system' strategy for sharing CPU tensors between processes
# (e.g. DataLoader workers) instead of the platform default.
# NOTE(review): presumably chosen to avoid running out of file descriptors
# when many worker processes exchange tensors — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

def setup(rank, world_size):
    """Initialize the distributed environment.

    Joins the default NCCL process group as process ``rank`` out of
    ``world_size``.

    ``MASTER_ADDR`` / ``MASTER_PORT`` are only filled in when the environment
    does not already define them, so a launcher (or a user running several
    jobs on one host) can pick a different rendezvous address or a free port.
    Without an override the previous defaults (localhost:12355) are used.

    Args:
        rank: index of this process within the group.
        world_size: total number of processes in the group.
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Clean up the distributed environment.

    Destroys the default process group if one exists. Calling this when no
    group has been initialized (or calling it twice) is a no-op instead of an
    error, so it is safe to use on error paths.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
    
class GatherLayer(Function):
    """
    Gather tensors from all processes, supporting backward propagation.

    Wrapping ``dist.all_gather`` in an autograd ``Function`` lets the gathered
    copies take part in a loss while gradients still reach the local tensor
    each rank contributed.
    """
    @staticmethod
    def forward(ctx, input):
        """Return a tuple with every rank's ``input``, indexed by rank.

        Every rank must pass a tensor of the same shape and dtype.
        """
        # all_gather needs contiguous send/receive buffers.
        local = input.contiguous()
        gathered = [torch.empty_like(local) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, local)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        """Return the gradient for this rank's ``input``.

        ``grads[k]`` is the gradient of the *local* loss with respect to the
        copy of rank k's tensor. Other ranks' losses depend on this rank's
        tensor as well, so the per-slot gradients are summed over all ranks
        before this rank's slot is returned; taking ``grads[rank]`` alone
        would drop those contributions.
        """
        stacked = torch.stack(grads)
        dist.all_reduce(stacked)  # default reduction is SUM
        return stacked[dist.get_rank()]

def compute_loss(data, rank=None, world_size=None):
    """Return ``(total, reconstruction, mmd)`` losses for a forward-passed batch.

    Args:
        data: batch object exposing ``x``, ``y`` and the latents ``z``.
        rank: rank of the calling process; forwarded to the distributed MMD.
        world_size: number of processes. ``None`` or ``1`` selects the purely
            local MMD computation.

    Returns:
        Tuple ``(loss, recon_loss, mmd_loss)`` where ``loss`` is the sum of
        the other two.
    """
    # Mean squared difference between data.x and data.y, local to this process
    reconstruction = torch.mean((data.x - data.y)**2.)
    latents = data.z

    runs_distributed = not (world_size is None or world_size == 1)
    if runs_distributed:
        # MMD evaluated on the latents gathered from every process
        divergence = compute_distributed_mmd(latents, rank, world_size)
    else:
        # Local MMD against standard-normal samples shaped like the latents
        target_device = data.x.device
        criterion = MMDLoss(device=target_device)
        prior_samples = torch.randn_like(latents).to(target_device)
        divergence = criterion(latents, prior_samples)

    return reconstruction + divergence, reconstruction, divergence

def compute_distributed_mmd(z_local, rank, world_size):
    """Compute the MMD loss on the latents gathered from all processes.

    Args:
        z_local: latents produced by this process. Every process must pass
            the same shape, since they are combined with an all-gather.
        rank: rank of the calling process (unused; kept for the interface).
        world_size: number of processes (unused; kept for the interface).

    Returns:
        The MMD between the concatenated latents of all processes and a
        standard-normal sample of the same shape. The value is returned
        unscaled so it is directly comparable with the single-process path.
    """
    device = z_local.device

    # Differentiable all-gather, then stack the per-rank latents into one batch
    z_all = torch.cat(GatherLayer.apply(z_local), dim=0)

    # Draw a *fresh* prior sample on every call and broadcast rank 0's draw so
    # all ranks compare against the same sample. Re-seeding with a constant on
    # each call would reuse the identical noise for every batch, i.e. the
    # latents would be matched to one fixed point set instead of the prior.
    Xd = torch.randn_like(z_all)
    dist.broadcast(Xd, src=0)

    mmd = MMDLoss(device=device)

    # No division by world_size: DistributedDataParallel already averages
    # parameter gradients over processes, so scaling the loss down as well
    # would weaken the MMD term relative to the reconstruction loss and make
    # the logged MMD depend on the number of GPUs.
    return mmd(z_all, Xd)

def _select_transform(name):
    """Map the configured transform name to a transform instance (or None)."""
    if name == 'Cartesian':
        return Cartesian(norm=False)
    if name == 'Distance':
        return Distance(norm=False)
    return None


def _save_checkpoint(model, path):
    """Save the unwrapped model weights and its training-set stats to ``path``."""
    torch.save({
        'model_state_dict': model.module.state_dict(),
        'trainset_stats': model.module.trainset_stats
    }, path)


def _run_epoch(model, loader, device, rank, world_size, optimizer=None):
    """Run one pass over ``loader`` and return the losses averaged over ranks.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch; without one the model is put in eval mode and gradients are
    disabled. Returns ``[loss, recon_loss, mmd_loss]`` as floats: per-rank
    means over the batches, then averaged across ranks.
    """
    training = optimizer is not None
    model.train(training)
    totals = [0.0, 0.0, 0.0]

    with torch.set_grad_enabled(training):
        for data in loader:
            # Normalize with the training-set stats, then move to this rank's GPU
            data = norm_data(data, model.module.trainset_stats)
            data = data.to(device)

            if training:
                optimizer.zero_grad()

            data = model(data)
            losses = compute_loss(data, rank, world_size)
            for slot, value in enumerate(losses):
                totals[slot] += value.item()

            if training:
                losses[0].backward()
                optimizer.step()

    num_batches = len(loader)
    means = torch.tensor([total / num_batches for total in totals], device=device)
    dist.all_reduce(means, op=dist.ReduceOp.AVG)
    return means.tolist()


def train_distributed(rank, world_size, config_path):
    """Main training function, run once per process/GPU.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: total number of processes.
        config_path: path to the YAML configuration file.

    Raises:
        ValueError: if the configured batch size is smaller than
            ``world_size``, or if a data loader would be empty on each rank.
    """
    # Setup distributed training
    setup(rank, world_size)

    # Load the config file
    config = Box.from_yaml(filename=config_path, Loader=yaml.FullLoader)

    transform = _select_transform(config.data_settings.transform)

    # The configured batch size is the global one; split it across GPUs
    batch_size_per_gpu = config.hyperparameters.batch_size // world_size
    if batch_size_per_gpu < 1:
        raise ValueError(
            f'batch_size={config.hyperparameters.batch_size} is smaller than '
            f'the number of processes ({world_size})'
        )

    # Initialize the datasets
    train_dataset = WindTerrainDataset(
        filename=config.io_settings.train_dataset_path,
        transform=transform,
        channels=config.data_settings.channels,
        max_cells_above_terrain=config.data_settings.max_cells_above_terrain,
        mode='train'  # Ensure training mode for data augmentation and cropping
    )

    # drop_last=True gives every rank the same number of samples, which keeps
    # batch counts and shapes aligned for the collectives in compute_loss.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        drop_last=True
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size_per_gpu,
        sampler=train_sampler,
        exclude_keys=['terrain_mask', 'fluid_indices'],
        num_workers=config.run_settings.num_t_workers,
        pin_memory=True,
        persistent_workers=config.run_settings.num_t_workers != 0
    )
    if len(train_loader) == 0:
        raise ValueError('training dataset is too small: no batches per process')

    if config.run_settings.validate:
        validate_dataset = WindTerrainDataset(
            filename=config.io_settings.valid_dataset_path,
            transform=transform,
            channels=config.data_settings.channels,
            max_cells_above_terrain=config.data_settings.max_cells_above_terrain,
            mode='eval'  # Validation mode, no augmentation or cropping (samples already resampled)
        )

        # Split validation indices across GPUs. The tail that does not divide
        # evenly (at most world_size - 1 samples) is dropped so that every
        # rank sees the same number of equally sized batches; otherwise the
        # collective inside compute_loss would hang or fail on shape mismatch.
        usable = len(validate_dataset) - len(validate_dataset) % world_size
        val_indices_per_rank = list(range(rank, usable, world_size))
        if not val_indices_per_rank:
            raise ValueError('validation dataset is too small: no samples per process')
        val_subset = torch.utils.data.Subset(validate_dataset, val_indices_per_rank)

        validate_loader = DataLoader(
            val_subset,
            batch_size=batch_size_per_gpu,
            shuffle=False,
            num_workers=config.run_settings.num_v_workers,
            pin_memory=True,
            exclude_keys=['terrain_mask', 'fluid_indices'],
            persistent_workers=config.run_settings.num_v_workers != 0
        )

    # Get the dimensions of the data
    config.data_dims = train_dataset.get_data_dims_dict()

    # Only rank 0 creates the run directory and writes files
    if rank == 0:
        uid = ''.join(random.choices(string.ascii_letters + string.digits, k=4))
        run_name = '{}_dim_{}_uid_{}_{}gpu'.format(
            config.model_settings.model_type,
            config.model_settings.latent_dim,
            uid,
            world_size
        )

        # Create model saving dir and save config
        current_run_dir = os.path.join(config.io_settings.run_dir, run_name)
        os.makedirs(os.path.join(current_run_dir, 'trained_models'), exist_ok=True)
        config.to_yaml(filename=os.path.join(current_run_dir, 'config.yml'))
    else:
        run_name = None
        current_run_dir = None

    # Set device for this rank
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(rank)

    # Initialize the model
    model = GenGeomAutoencoder(**config.data_dims, **config.hyperparameters, **config.model_settings)

    # Load pretrained model if specified
    if config.io_settings.pretrained_model:
        checkpoint = torch.load(config.io_settings.pretrained_model, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    # Compute dataset stats on rank 0 only, then broadcast them to all ranks.
    # NOTE(review): train_loader uses the distributed sampler, so the stats
    # appear to be computed from rank 0's shard only — confirm this is intended.
    stats_holder = [compute_dataset_stats(train_loader, device) if rank == 0 else None]
    dist.broadcast_object_list(stats_holder, src=0)
    model.trainset_stats = stats_holder[0]

    # Move model to device and wrap with DDP
    model = model.to(device)
    model = DistributedDataParallel(
        model,
        device_ids=[rank],
        output_device=rank,
        static_graph=True  # Set static graph for DDP
    )

    # Define optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=float(config.hyperparameters.start_lr),
        weight_decay=float(config.hyperparameters.weight_decay)
    )

    # Define scheduler
    scheduler = optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=float(config.hyperparameters.lr_decay)
    )

    # Training loop
    if rank == 0:
        print(f'Starting run {run_name} on {world_size} GPUs')
        # Report the batch size actually used (per-GPU size times GPUs), which
        # differs from the configured value when it is not evenly divisible.
        print(f'Effective batch size for MMD: {batch_size_per_gpu * world_size}')
        pbar = tqdm(total=config.hyperparameters.epochs)
        pbar.set_description('Training')

    best_validation_loss = float('inf')

    for epoch in range(config.hyperparameters.epochs):
        # Set epoch for distributed sampler (important for shuffling)
        train_sampler.set_epoch(epoch)

        train_loss, _, _ = _run_epoch(
            model, train_loader, device, rank, world_size, optimizer=optimizer
        )

        # Step scheduler
        scheduler.step()

        # Save periodic model checkpoints (only on rank 0)
        if rank == 0 and (epoch + 1) % config.io_settings.save_epochs == 0:
            _save_checkpoint(
                model,
                os.path.join(current_run_dir, 'trained_models', f'e{epoch + 1}.pt')
            )

        postfix = {'Train Loss': f'{train_loss:.8f}'}

        # Validation (all ranks take part because of the collectives)
        if config.run_settings.validate:
            validation_loss, _, _ = _run_epoch(
                model, validate_loader, device, rank, world_size
            )
            postfix['Val Loss'] = f'{validation_loss:.8f}'

        if rank == 0:
            pbar.set_postfix(postfix)
            pbar.update(1)

            # Save best model
            if config.run_settings.validate and validation_loss < best_validation_loss:
                best_validation_loss = validation_loss
                _save_checkpoint(
                    model,
                    os.path.join(current_run_dir, 'trained_models', 'best.pt')
                )

    if rank == 0:
        pbar.close()

    cleanup()


def main():
    """Parse the command line and spawn one training process per CUDA device.

    Raises:
        RuntimeError: if no CUDA device is visible. Spawning zero worker
            processes would train nothing, so this is reported explicitly.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', help="path to the yaml config file", type=str, required=True)
    args = parser.parse_args()

    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError('No CUDA devices available: distributed training requires at least one GPU')

    print(f"Starting training on {world_size} GPUs")
    mp.spawn(train_distributed, args=(world_size, args.config), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()