import torch
import torch.distributed as dist
import os
import json

from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size, master_addr='localhost', master_port='8888'):
    """Initialise the default NCCL process group for this process.

    Every rank rendezvouses at ``master_addr:master_port``; the defaults
    assume all ranks run on the same machine.

    Args:
        rank: This process's rank (0 for the master process). Also used as
            the CUDA device index this process binds to.
        world_size: Total number of processes in the group.
        master_addr: Rendezvous host (default ``'localhost'``).
        master_port: Rendezvous port (default ``'8888'``).
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    # NCCL needs each process on a distinct GPU; without this every rank
    # would use the current default device (cuda:0).
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group if one is currently initialised.

    Safe to call more than once, or when setup never completed: in those
    cases it does nothing instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def main(rank=None, world_size=2):
    """Load the configured DDPM pipeline and wrap its UNet in DDP, one GPU per rank.

    When called without a rank (as from the ``__main__`` guard), this spawns
    ``world_size`` worker processes with ``torch.multiprocessing.spawn``.
    Each worker re-enters ``main(rank, world_size)``. A single process
    calling ``setup(0, 2)`` would block in ``init_process_group`` waiting
    for a rank 1 that never starts.

    Args:
        rank: Index of this worker process and of the GPU it uses. ``None``
            means "launch all workers".
        world_size: Number of worker processes (one per GPU).
    """
    if rank is None:
        # Parent process: spawn passes the worker index as the first argument.
        import torch.multiprocessing as mp
        mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)
        return

    # Bind this worker to its own GPU before creating the NCCL group.
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    setup(rank, world_size)
    try:
        # Load config.
        config_path = "/home/***/work/doob_apps/hug/configs/configs.json"
        with open(config_path, "r") as f:
            config = json.load(f)
        image_path = config["image_ref_path"]
        # Load to CPU first so every rank doesn't allocate on the GPU the
        # tensor was saved from, then move it to this rank's device.
        images = torch.load(image_path, map_location="cpu")
        images_batch = images[:2].to(device)  # NOTE(review): not used yet
        model_path = config["model_path"]

        # Load the DDPM pipeline from the configured checkpoint.
        pipeline = DDPMPipeline.from_pretrained(
            model_path
        ).to(device)

        unet = pipeline.unet.to(device)
        # A single-device module takes exactly one device id: this rank's GPU.
        unet = DDP(unet, device_ids=[rank], output_device=rank)
    finally:
        # Always release the process group, even if loading fails.
        cleanup()

if __name__ == "__main__":
    # Script entry point: run main() only when executed directly, not on import.
    main()
