# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import subprocess
import sys
import socket

import torch
import torch.distributed as dist
from ipdb import set_trace

def get_model(model):
    """
    Unwrap a data-parallel wrapper and return the model it contains.

    If ``model`` is a ``torch.nn.DataParallel`` or
    ``torch.nn.parallel.DistributedDataParallel``, its ``.module`` attribute is
    returned; any other object is returned as-is.
    """
    wrapper_types = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    if not isinstance(model, wrapper_types):
        return model
    return model.module


def setup_for_distributed(is_master):
    """
    Silence the builtin ``print`` on every process except the master.

    The global ``builtins.print`` is replaced by a wrapper. On the master
    process it behaves like the original ``print``. On other processes it is
    a no-op unless the call passes ``force=True``. The ``force`` keyword is
    always consumed and never forwarded to the real ``print``.
    """
    import builtins
    original_print = builtins.print

    def print(*args, **kwargs):
        # pop unconditionally so `force` never reaches the real print()
        should_emit = kwargs.pop('force', False) or is_master
        if should_emit:
            original_print(*args, **kwargs)

    builtins.print = print


def is_dist_avail_and_initialized():
    """
    Return True only if ``torch.distributed`` is available and a default
    process group has been initialized.
    """
    # short-circuit: is_initialized() is only queried when the package is available
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """
    Return the number of processes in the default process group.

    Falls back to 1 when distributed training is unavailable or the process
    group has not been initialized.
    """
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """
    Return this process's rank in the default process group.

    Falls back to 0 when distributed training is unavailable or the process
    group has not been initialized.
    """
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0


def is_main_process():
    """Return True when this process has rank 0 (also true when not distributed)."""
    rank = get_rank()
    return rank == 0


def find_free_port():
    """
    Ask the OS for a currently unused TCP port and return its number.

    A throwaway IPv4 socket is bound to port 0 on all interfaces (``''``), so
    the kernel picks a free port. The socket is closed before returning.
    Another process could therefore claim the port before the caller binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(('', 0))
        probe.listen(1)
        _host, chosen_port = probe.getsockname()
    return chosen_port


def _atomic_torch_save(state, path):
    """
    Save ``state`` with ``torch.save`` to ``path`` via a temporary file.

    The data is written to ``<path>.tmp`` and then moved over ``path`` with
    ``os.replace``, so ``path`` is never left half-written. If saving or the
    rename fails (e.g. the disk fills up), the partial temporary file is
    removed before the exception is re-raised, so no orphaned ``.tmp``
    checkpoint is left consuming disk space.
    """
    tmp_path = f'{path}.tmp'
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup; the original error is what the caller needs to see.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def save_on_master(state, is_best, output_dir, is_epoch=True, save_latest_checkpoint=True):
    """
    Write checkpoint file(s) from the rank-0 process, then synchronize.

    Args:
        state: object passed to ``torch.save``. When ``is_epoch`` is true it
            must support ``state['epoch']``. An ``int`` epoch is formatted as
            ``checkpoint_{:04d}.pt``; any other value is formatted as
            ``checkpoint_{:.4f}.pt``.
        is_best: if true, also write ``<output_dir>/checkpoint_best.pt``.
        output_dir: directory for checkpoints; created if it does not exist.
        is_epoch: if true, write the epoch-specific checkpoint (and, if
            ``save_latest_checkpoint``, the rolling ``checkpoint.pt``).
        save_latest_checkpoint: whether to also maintain ``checkpoint.pt``.

    Every file is written through a temporary file plus ``os.replace``. All
    ranks wait at ``dist.barrier()`` afterwards when a process group is
    initialized.
    """
    if is_main_process():
        os.makedirs(output_dir, exist_ok=True)

        ckpt_path = f'{output_dir}/checkpoint.pt'
        best_path = f'{output_dir}/checkpoint_best.pt'

        if is_best:
            _atomic_torch_save(state, best_path)

        if is_epoch:
            epoch = state['epoch']
            if isinstance(epoch, int):
                ckpt2_path = '{}/checkpoint_{:04d}.pt'.format(output_dir, epoch)
            else:
                ckpt2_path = '{}/checkpoint_{:.4f}.pt'.format(output_dir, epoch)

            # Optional rolling "latest" checkpoint (checkpoint.pt). Continual
            # learning runs in this repo typically do not need it, and it costs disk space.
            if save_latest_checkpoint:
                _atomic_torch_save(state, ckpt_path)

            _atomic_torch_save(state, ckpt2_path)

    # Keep non-master ranks from racing ahead of the checkpoint write.
    if is_dist_avail_and_initialized():
        dist.barrier()


def init_distributed_mode(args):
    """
    Initialize a ``torch.distributed`` NCCL process group from the launch environment.

    Launch modes are checked in this order:

    * ``RANK`` and ``WORLD_SIZE`` in the environment (torchrun /
      ``torch.distributed.launch``). Rank, world size and ``LOCAL_RANK`` are
      read from the environment.
    * ``SLURM_PROCID`` in the environment (e.g. submitit). Rank, world size,
      local rank and the master address/port are derived from SLURM variables
      and exported as environment variables. ``args.dist_url`` is set to
      ``'env://'``.
    * CUDA is available (plain ``python script.py``). A single-process group
      is used: rank 0, world size 1, GPU 0, ``MASTER_ADDR=127.0.0.1`` and a
      free ``MASTER_PORT``.
    * Otherwise no process group is created. ``args.distributed`` is set to
      False, ``args.rank`` to 0 and ``args.world_size`` to 1, and the
      function returns.

    In the first and third modes, a missing or empty ``args.dist_url``
    defaults to ``'env://'``, which is the init method that reads the
    ``MASTER_ADDR``/``MASTER_PORT`` variables exported here.

    On success ``args.distributed`` is True, the current CUDA device is set to
    ``args.gpu``, all ranks pass a barrier, and ``print`` is silenced on
    non-zero ranks via ``setup_for_distributed``.
    """
    # launched with torch.distributed.launch / torchrun
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    # launched with submitit on a slurm cluster
    elif 'SLURM_PROCID' in os.environ:
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = os.environ['SLURM_NTASKS']
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list)
        )
        master_port = os.environ.get('MASTER_PORT', '29486')
        # NOTE(review): local rank = proc_id % num_gpus assumes one task per GPU,
        # with tasks packed contiguously per node; SLURM_LOCALID may be more robust.
        local_rank = proc_id % num_gpus
        os.environ['MASTER_PORT'] = master_port
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(local_rank)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = int(ntasks)
        args.rank = int(proc_id)
        args.gpu = int(local_rank)
        print(f'SLURM MODE: proc_id: {proc_id}, ntasks: {ntasks}, node_list: {node_list}, num_gpus:{num_gpus}, addr:{addr}, master port:{master_port}' )

    # launched naively with `python main_dino.py`
    # we manually add MASTER_ADDR and MASTER_PORT to env variables
    elif torch.cuda.is_available():
        print('Will run the code on one GPU.')
        args.rank, args.gpu, args.world_size = 0, 0, 1
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        # Use free port to avoid conflicts when running multiple jobs on same server
        free_port = find_free_port()
        os.environ['MASTER_PORT'] = str(free_port)
        print(f'Using free port: {free_port}')
    else:
        # No launcher and no GPU: run as a single non-distributed process.
        # Set the same attributes the distributed path sets so callers can
        # read them without an AttributeError.
        print('Training without GPU')
        args.rank, args.world_size = 0, 1
        args.distributed = False
        return

    # MASTER_ADDR/MASTER_PORT are only honored by the env:// init method.
    if not getattr(args, 'dist_url', None):
        args.dist_url = 'env://'

    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    dist.barrier()
    setup_for_distributed(args.rank == 0)