import torch
import torch.nn as nn
from tqdm import tqdm
import yaml
import argparse
from torch_geometric.data import DataLoader
from torch_geometric.nn.pool import global_mean_pool, global_add_pool
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist
from torchmetrics.aggregation import CatMetric
from easydict import EasyDict
import wandb
import os
import shutil
import pickle
import copy
import sys
sys.path.append('./')

from experiments.fixddp import DistributedSamplerNoDuplicate
from datasets.nbody import NBody
from datasets.md17 import MD17Traj

from models.EA import EA
from models.EAencoder import EAencoder
from diffusion.EATDM import EATDM, ModelMeanType, ModelVarType, LossType
from utils.misc import set_seed, gather_across_gpus

# Select the 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid running out of file descriptors
# when DataLoader workers pass many tensors around — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

def run(rank, world_size, args):
    """Build the models, freeze the base layers and load the train/val datasets.

    Args:
        rank: Device index the modules are moved to. It is also used as the
            DDP device id and to gate the rank-0 parameter report.
        world_size: Number of participating processes. The denoise network is
            wrapped in DistributedDataParallel only when this is greater
            than 1.
        args: Parsed command-line arguments; only ``train_yaml_file`` is read.

    Raises:
        ValueError: If ``config.data.train.root`` is not one of the dataset
            roots handled below.
    """
    # Load args
    yaml_file = args.train_yaml_file
    with open(yaml_file, 'r') as f:
        params = yaml.safe_load(f)
    config = EasyDict(params)

    # Init model and optimizer
    denoise_network = EA(**config.denoise_model).to(rank)
    ea_encoder = EAencoder(**config.ea_encoder).to(rank)
    # print("ea_encoder.named_parameters",list(ea_encoder.named_parameters()))
    # NOTE(review): `diffusion` keeps a reference to the *unwrapped* network;
    # confirm that is intended when running under DDP.
    diffusion = EATDM(denoise_network=denoise_network, ea_encoder=ea_encoder,
                      model_mean_type=ModelMeanType.EPSILON,
                      model_var_type=ModelVarType.FIXED_LARGE,
                      loss_type=LossType.MSE,
                      device=rank,
                      rescale_timesteps=False,
                      **config.diffusion)
    optimizer = torch.optim.Adam(list(denoise_network.parameters()), lr=config.train.lr)

    print("init optimizer")

    # Keep a handle on the raw module: the DDP wrapper only exists when
    # world_size > 1, so `denoise_network.module` is not always available.
    base_network = denoise_network

    # Ensure only ref frame parameters are trainable.
    # This is done *before* DDP wrapping because DDP registers its gradient
    # reduction hooks at construction time, based on requires_grad.
    for frozen_group in (base_network.s_modules, base_network.t_modules):
        for param in frozen_group.parameters():
            param.requires_grad = False
    print("freeze denoise_network")
    trainable_groups = (
        base_network.zero_x_params,
        base_network.zero_h_params,
        base_network.copy_s_modules,
        base_network.copy_t_modules,
    )
    for trainable_group in trainable_groups:
        for param in trainable_group.parameters():
            param.requires_grad = True
    # for param in ea_encoder.parameters():
    #     param.requires_grad = True
    print("Trainable zero modules")

    if world_size > 1:
        denoise_network = DistributedDataParallel(denoise_network, device_ids=[rank])
        # ea_encoder = DistributedDataParallel(ea_encoder, device_ids=[rank])
    # optimizer = torch.optim.Adam(list(denoise_network.module.zero_x_params.parameters())+list(denoise_network.module.zero_h_params.parameters())+list(denoise_network.module.copy_s_modules.parameters())+list(denoise_network.module.copy_t_modules.parameters()),lr=config.train.lr)

    # Snapshot one weight tensor on CPU and print it.
    for init_param_name, init_param in diffusion.denoise_network.named_parameters():
        if init_param_name == "s_modules.1.edge_mlp.actions.0.weight":
            init_param_1 = init_param.detach().clone()
            init_param_1 = init_param_1.cpu()
            print("init_param_1:", init_param_1)

    if rank == 0:
        # Parameter report, printed once from rank 0.
        total_params = sum(p.numel() for p in base_network.parameters())
        trainable_params = sum(p.numel() for p in base_network.parameters() if p.requires_grad)
        for name, param in base_network.named_parameters():
            if param.requires_grad:
                print(f"{name}: {param.numel()}")

        print(f"Total parameters: {total_params}")
        print(f"Trainable parameters: {trainable_params}")

        # Size estimate uses 4 bytes per parameter (assumes float32 weights).
        param_size_bytes = total_params * 4
        param_size_mb = param_size_bytes / (1024 * 1024)
        print(f"Total parameters: {total_params} {param_size_mb:.2f} MB")

        trainable_size_bytes = trainable_params * 4
        trainable_size_mb = trainable_size_bytes / (1024 * 1024)
        print(f"Trainable parameters: {trainable_params}  {trainable_size_mb:.2f} MB")

    # Load train/val dataset
    if config.data.train.root == 'data/md17':
        dataset_train = MD17Traj(**config.data.train)
        dataset_val = MD17Traj(**config.data.val)
        task_name = config.data.train.molecule_name
        name_path = os.path.join(config.train.output_base_path, task_name)
    elif config.data.train.root == 'datasets/datagen':
        dataset_train = NBody(**config.data.train)
        dataset_val = NBody(**config.data.val)
        name_path = config.train.output_base_path
    else:
        # Fail with a clear message instead of an unbound `name_path` below.
        raise ValueError(
            f"Unsupported dataset root {config.data.train.root!r}; "
            "expected 'data/md17' or 'datasets/datagen'"
        )
    output_path = os.path.join(name_path, config.train.exp_name)
     
# #start training
#     if config.train.train_mode == True:
#         # Save args yaml file
#         if rank == 0:
#             if not os.path.exists(output_path):
#                 os.makedirs(output_path, exist_ok=True)
#             shutil.copy(yaml_file, output_path)
#         print(output_path)

#         set_seed(config.train.seed)
#         if world_size > 1:
#             sampler_train = DistributedSampler(dataset_train)
#             sampler_val = DistributedSamplerNoDuplicate(dataset_val, shuffle=False, drop_last=False)
#         else:
#             sampler_train = None
#             sampler_val = None

#         dataloader_train = DataLoader(dataset_train, batch_size=config.train.batch_size // world_size,
#                                     shuffle=(sampler_train is None), sampler=sampler_train, pin_memory=True)
#         dataloader_val = DataLoader(dataset_val, batch_size=config.test.eval_batch_size // world_size,
#                                     shuffle=False, sampler=sampler_val, pin_memory=True)

        
#         # Load base checkpoint
#         base_model_ckpt_path = os.path.join(name_path, 'ckpt_base.pt')
#         print("base_model_ckpt_path:",base_model_ckpt_path)
#         print(os.path.exists(base_model_ckpt_path))
#         if os.path.exists(base_model_ckpt_path):
#             device = torch.device(f'cuda:{args.local_rank}')
#             state_dict = torch.load(base_model_ckpt_path, map_location=device)
#             model_dict=denoise_network.state_dict()
#             filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
#             model_dict.update(filtered_state_dict)
#             # print("model_dict:",model_dict)
#             denoise_network.load_state_dict(model_dict)
                    

#             # denoise_network.to(rank)
#             tot_step = state_dict.get('tot_step', 0)  
            
#             for i in config.denoise_model.n_copy_layer_list:
#                 layer_idx = str(i)
#                 print("copy layer_idx:",layer_idx)
#                 denoise_network.module.copy_s_modules[layer_idx].load_state_dict(copy.deepcopy(denoise_network.module.s_modules[i].state_dict()))
#                 denoise_network.module.copy_t_modules[layer_idx].load_state_dict(copy.deepcopy(denoise_network.module.t_modules[i].state_dict()))


            
        
#         for base_train_param_name, base_train_param in diffusion.denoise_network.named_parameters():
#             if  "s_modules.1.edge_mlp.actions.0.weight" == base_train_param_name:
#                 base_train_param_1=base_train_param.detach().clone()
#                 base_train_param_1=base_train_param_1.cpu()
                
#         # print("base_train_param_1:",base_train_param_1) 
#         assert not torch.equal(init_param_1, base_train_param_1), "base train model load failed"
#         print(f'Base Train Model loaded from {base_model_ckpt_path} success')
        
#         param1 = denoise_network.module.copy_s_modules["0"].edge_mlp.actions[0].weight
#         param2 = denoise_network.module.s_modules[0].edge_mlp.actions[0].weight
#         assert torch.equal(param1.cpu(), param2.cpu()),"copy layer failed"
#         print("Copy layer load success")
        
#         # for key, param in denoise_network.module.zero_x_params.items():
#         #     assert torch.equal(param, torch.zeros_like(param)), f"zero_x_params[{key}] is not all zeros"      
#         # for key, param in denoise_network.module.zero_h_params.items():
#         #     assert torch.equal(param, torch.zeros_like(param)), f"zero_h_params[{key}] is not all zeros"  
#         # print("zero params load success")

#         for param in denoise_network.module.s_modules.parameters():
#             param.requires_grad = False
#         for param in denoise_network.module.t_modules.parameters():
#             param.requires_grad = False
#         print("freeze denoise_network")   
#         for param in denoise_network.module.zero_x_params.parameters():
#             param.requires_grad = True
#         for param in denoise_network.module.zero_h_params.parameters():
#             param.requires_grad = True
#         print(param.requires_grad)
#         for param in denoise_network.module.copy_s_modules.parameters():
#             param.requires_grad = True
#         for param in denoise_network.module.copy_t_modules.parameters():
#             param.requires_grad = True
            
#         optimizer = torch.optim.Adam(list(denoise_network.module.parameters()), lr=config.train.lr)

        
        
#         if rank == 0:
#             # Wandb_train config
#             if config.wandb.no_wandb:
#                 mode = 'disabled'
#             else:
#                 mode = 'online'
#             if config.train.train_mode == True:
#                 kwargs_train = {'entity': config.wandb.wandb_usr, 'name':  task_name+"_"+ config.train.exp_name, 'project': config.wandb.project,
#                         'config': params, 'settings': wandb.Settings(_disable_stats=True), 'mode': mode}
#             train_log = wandb.init(**kwargs_train)
#             train_log.save('*.txt')

#         # Start training
#         num_epochs = config.train.num_epochs
#         tot_step = 0

#         best_val_nll, best_val_mse = 1e10, 1e10
#         reduce_placeholder = CatMetric()

#         if rank == 0:
#             progress_bar = tqdm(total=num_epochs)

#         for epoch in range(1, num_epochs + 1):
#             if rank == 0:
#                 # print(f'Start epoch {epoch}')
#                 progress_bar.set_description(f"Epoch {epoch}")
#             denoise_network.train()
#             if world_size > 1:
#                 sampler_train.set_epoch(epoch)


#             # Training

#             train_loss_epoch, counter = torch.zeros(1).to(rank), torch.zeros(1).to(rank)

#             for step, data in enumerate(dataloader_train):
#                 tot_step += 1

#                 data = data.to(rank)
#                 model_kwargs = {'edge_index': data.edge_index,
#                                 'edge_attr': data.edge_attr,
#                                 'batch': data.batch}
#                 x_given = data.x
#                 if diffusion.mode == 'cond':
#                     # Construct cond mask
#                     cond_mask = torch.zeros(1, 1, x_given.size(-1)).to(x_given)
#                     for interval in config.train.cond_mask:
#                         cond_mask[..., interval[0]: interval[1]] = 1
#                     # model_kwargs['x_given'] = x_start
#                     x_target = x_given[..., ~cond_mask.view(-1).bool()]
#                     x_cond = x_given[..., cond_mask.view(-1).bool()]

#                     # em_cond_mask = torch.zeros(1, 1, (x_target.size(-1)+config.ea_encoder.em_T_out)).to(x_target)
#                     # em_cond_mask[..., 0: config.ea_encoder.em_T_out] = 1
#                     # em_cond_mask = em_cond_mask.view(-1).bool()


#                 else:
#                     x_target = x_given[..., :config.train.tot_len]
#                     x_cond=None

#                 training_losses = diffusion.training_losses(x_target=x_target, x_cond=x_cond, h=data.h, t=None, model_kwargs=model_kwargs)
#                 loss = training_losses['loss']  # [BN]
#                 loss = global_mean_pool(loss, data.batch)  # [B]
                
#                 if world_size > 1:
                    
#                     step_loss_synced = gather_across_gpus(loss, reduce_placeholder).mean().item()
#                 else:
#                     step_loss_synced = loss.mean().item()
#                 if rank == 0 and tot_step % config.train.log_every_step == 0:
#                     train_log.log({"Step train loss": step_loss_synced}, commit=True, step=tot_step)
#                     logs = {"loss": step_loss_synced, "step": tot_step}
#                     progress_bar.set_postfix(**logs)

#                 train_loss_epoch = train_loss_epoch + loss.sum()
#                 counter = counter + loss.size(0)

#                 loss = loss.mean()
#                 loss.backward()
#                 for para_name, param in denoise_network.module.s_modules.named_parameters():
#                     if param.grad is not None:
#                         print(f"WARNING: have gradient for s_modules: {para_name}")
#                 for para_name, param in denoise_network.module.t_modules.named_parameters():
#                     if param.grad is not None:
#                         print(f"WARNING: have gradient for t_modules: {para_name}")

#                 for para_name, param in denoise_network.module.zero_x_params.named_parameters():
#                     if param.grad is None:
#                         print(f"WARNING: No gradient for zero_x_params {para_name}")
                        
#                 for para_name, param in denoise_network.module.zero_h_params.named_parameters():
#                     if param.grad is None:
#                         print(f"WARNING: No gradient for zero_h_params {para_name}")

#                 for para_name, param in denoise_network.module.copy_s_modules.named_parameters():
#                     if param.grad is None:
#                         print(f"WARNING: No gradient for copy_s_modules {para_name}")
#                 for para_name, param in denoise_network.module.copy_t_modules.named_parameters():
#                     if param.grad is None:
#                         print(f"WARNING: No gradient for copy_t_modules {para_name}")

#                 # for para_name, param in ea_encoder.named_parameters():
#                 #     if param.grad is None:
#                 #         print(f"WARNING: No gradient for ea_encoder: {para_name}")
#                 nn.utils.clip_grad_norm_(
#                     list(denoise_network.module.zero_x_params.parameters()) +
#                     list(denoise_network.module.zero_h_params.parameters()) +
#                     list(denoise_network.module.copy_s_modules.parameters()) +
#                     list(denoise_network.module.copy_t_modules.parameters()),
#                     max_norm=1.0  
#                 )
#                 optimizer.step()
#                 optimizer.zero_grad()

#             train_loss_epoch = gather_across_gpus(train_loss_epoch, reduce_placeholder).sum().item()
#             counter = gather_across_gpus(counter, reduce_placeholder).sum().item()

#             if rank == 0:
#                 train_log.log({"Epoch train loss": train_loss_epoch / counter}, commit=True)

#             # Eval on validation set
#             if epoch % config.train.eval_every_epoch == 0 and epoch != 0:
#                 if rank == 0:
#                     print(f'Validating at epoch {epoch}')
#                 denoise_network.eval()
#                 val_nll_epoch, val_mse_epoch = torch.zeros(1).to(rank), torch.zeros(1).to(rank)
#                 counter = torch.zeros(1).to(rank)

#                 for step, data in enumerate(dataloader_val):
#                     data = data.to(rank)
#                     model_kwargs = {'edge_index': data.edge_index,
#                                     'edge_attr': data.edge_attr,
#                                     'batch': data.batch}
#                     x_given = data.x
#                     if diffusion.mode == 'cond':
#                         # Construct cond mask
#                         cond_mask = torch.zeros(1, 1, x_given.size(-1)).to(x_given)
#                         for interval in config.train.cond_mask:
#                             cond_mask[..., interval[0]: interval[1]] = 1
#                         # model_kwargs['cond_mask'] = cond_mask
#                         x_target  = x_given[..., ~cond_mask.view(-1).bool()]
#                         x_cond = x_given[..., cond_mask.view(-1).bool()]
#                         # em_cond_mask = torch.zeros(1, 1, (x_target.size(-1)+config.ea_encoder.em_T_out)).to(x_target)
#                         # em_cond_mask[..., 0: config.ea_encoder.em_T_out] = 1
#                         # em_cond_mask = em_cond_mask.view(-1).bool()
#                     else:
#                         # model_kwargs['x_given'] = x_start
#                         x_target = x_given[..., :config.train.tot_len]

                        

#                     val_results = diffusion.calc_bpd_loop(x_start=x_target,x_cond=x_cond,h=data.h,model_kwargs=model_kwargs)
#                     total_bpd = val_results['total_bpd']  # [BN]
#                     mse = val_results['mse'].mean(dim=1)  # [BN, T] -> [BN]
#                     total_bpd = global_add_pool(total_bpd, data.batch)  # [B]
#                     mse = global_mean_pool(mse, data.batch)  # [B]

#                     val_nll_epoch += total_bpd.sum()
#                     val_mse_epoch += mse.sum()
#                     counter += total_bpd.size(0)

#                 val_nll_epoch = gather_across_gpus(val_nll_epoch, reduce_placeholder).sum().item()
#                 val_mse_epoch = gather_across_gpus(val_mse_epoch, reduce_placeholder).sum().item()
#                 counter = gather_across_gpus(counter, reduce_placeholder).sum().item()

#                 if rank == 0:
#                     print(f'Val counter: {counter}')
#                     val_nll_epoch = val_nll_epoch / counter
#                     val_mse_epoch = val_mse_epoch / counter
#                     print(f'Val nll: {val_nll_epoch}')
#                     train_log.log({"Val nll": val_nll_epoch}, commit=False)
#                     train_log.log({"Val mse": val_mse_epoch}, commit=True)

#                     better = False

#                     if val_nll_epoch < best_val_nll:
#                         best_val_nll = val_nll_epoch
#                         better = True
#                     if val_mse_epoch < best_val_mse:
#                         best_val_mse = val_mse_epoch
#                     if better:
#                         save_dict = {
#                         # 'EAencoder':ea_encoder.state_dict(),
#                         'model': denoise_network.state_dict(),  
#                         'optimizer': optimizer.state_dict(),  
#                         'tot_step': tot_step } 
#                     torch.save(save_dict,
#                                os.path.join(output_path, f'ckpt_best.pt'))

#             # Save model
#             if rank == 0 and config.train.save_model:
#                 save_dict = {
#                     # 'EAencoder':ea_encoder.state_dict(),
#                     'model': denoise_network.state_dict(),  
#                     'optimizer': optimizer.state_dict(),  
#                     'tot_step': tot_step  
#         }
#                 if epoch % config.train.save_every_epoch == 0:
#                     torch.save(save_dict,
#                             os.path.join(output_path, f'ckpt_{epoch}.pt'))
#                 torch.save(save_dict,
#                         os.path.join(output_path, f'ckpt_last.pt'))

#             if world_size > 1:
#                 dist.barrier()
#             if rank == 0:
#                 progress_bar.update(1)
                
#         if rank == 0:
#             progress_bar.close()
#             train_log.finish()


#     # Start testing
#     if config.test.final_test and diffusion.mode == 'cond':

#         # Load dataset
#         if config.data.test.root == 'data/md17':
#             test_dataset = MD17Traj(**config.data.test)

#         elif config.data.test.root == 'datasets/datagen':
#             test_dataset =  NBody(**config.data.test)
            
#         test_output_path= os.path.join(output_path, config.test.exp_name)

#         if not os.path.exists(test_output_path):
#             os.makedirs(test_output_path, exist_ok=True)
#         if rank == 0:
#             shutil.copy(yaml_file, test_output_path)

#         if world_size > 1:
#             sampler = DistributedSamplerNoDuplicate(test_dataset, shuffle=False, drop_last=False)
#         else:
#             sampler = None
#         test_dataloader = DataLoader(test_dataset, batch_size=config.test.eval_batch_size // world_size,
#                                      shuffle=False, sampler=sampler)

#         # Load checkpoint
#         # test_model_ckpt_path = os.path.join(output_path, f'ckpt_{config.test.final_test_ckpt}.pt')
#         # if os.path.exists(test_model_ckpt_path):
#         #     state_dict = torch.load(test_model_ckpt_path)
#         #     new_state_dict = {}
#         #     for k, v in state_dict.items():
#         #         # new_key = k.replace("module.", "", 1) if k.startswith("module.") else k
#         #         new_key = k
#         #         new_state_dict[new_key] = v
#         #     model_dict=denoise_network.state_dict()
#         #     filtered_state_dict = {k: v for k, v in new_state_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
#         #     model_dict.update(filtered_state_dict)
#         #     denoise_network.load_state_dict(model_dict['model'])
            
           
#         test_model_ckpt_path = os.path.join(output_path, f'ckpt_{config.test.final_test_ckpt}.pt')
#         print(   "test_model_ckpt_path:",test_model_ckpt_path)
#         device = torch.device(f'cuda:{args.local_rank}')
#         state_dict = torch.load( test_model_ckpt_path, map_location=device)       

#         denoise_network.load_state_dict(state_dict['model'])
#         # # ea_encoder.load_state_dict(state_dict["EAencoder"])


#         # for test_param_name, test_param in diffusion.denoise_network.named_parameters():
#         #     if "s_modules.0.edge_mlp.actions.0.weight" == test_param_name:
#         #         test_param_1=test_param.detach().clone()
#         #         print(rank)
#         #         print("test_param_1:",test_param_1)
                
              
#         # assert not torch.equal(init_param_1, test_param_1), "test model load failed"
#         # print(f"Test Model loaded from { test_model_ckpt_path } success")
#         # for key, param in denoise_network.module.zero_x_params.items():
#         #     # param.data.zero_()
#         #     assert not torch.equal(param, torch.zeros_like(param)), f"zero_x_params[{key}] is zeros"      
#         # for key, param in denoise_network.module.zero_h_params.items():
#         #     assert not torch.equal(param, torch.zeros_like(param)), f"zero_h_params[{key}] is zeros"  
#         #     # param.data.zero_()
#         # print("zero params load success")
#         # print("zero params set to zeros")

#         # for encoder_test_param_name, encoder_test_param in diffusion.ea_encoder.named_parameters():
#         #     if "s_modules.0.edge_mlp.actions.0.weight" == encoder_test_param_name:
#         #         encoder_test_param_1=encoder_test_param.detach().clone()
#         # assert not torch.equal(encoder_init_param_1, encoder_test_param_1), "test model load failed"
#         # print(f"Test Model loaded from { test_model_ckpt_path } success")

#         denoise_network.eval()
#         if rank == 0:
#             # Wandb test config
#             if config.wandb.no_wandb:
#                 mode = 'disabled'
#             else:
#                 mode = 'online'
            
#             kwargs_test = {'entity': config.wandb.wandb_usr, 'name':  task_name+"_"+config.test.exp_name, 'project': config.wandb.project,
#                     'config': params, 'settings': wandb.Settings(_disable_stats=True), 'mode': mode}
#             test_log = wandb.init(**kwargs_test,allow_val_change=True)
#             test_log.save('*.txt')

               

#         test_nll_epoch_all, test_mse_epoch_all = [], []
#         minADE_K_all, minFDE_K_all = [], []  # distance is L2-norm
#         aveADE_K_all, aveFDE_K_all = [], []  # distance is L2-norm
#         system_id_all = []  # the index in the test dataset
#         reduce_placeholder = CatMetric()

#         for step, data in tqdm(enumerate(test_dataloader), disable=rank != 0):
#             print("step:",step)
#             # if step >= 1:
#             #     break
#             data = data.to(rank)
#             model_kwargs = {'edge_index': data.edge_index,
#                             'edge_attr': data.edge_attr,
#                             'batch': data.batch}

#             x_given = data.x
#             h_given=data.h

#             # Create temporal inpainting mask, 1 to keep the entries unchanged, 0 to modify it by diffusion
#             cond_mask = torch.zeros(1, 1, x_given.size(-1)).to(x_given)
#             for interval in config.test.cond_mask:
#                 cond_mask[..., interval[0]: interval[1]] = 1
#             # model_kwargs['cond_mask'] = cond_mask
#             # model_kwargs['x_given'] = x_start
#             x_target = x_given[..., ~cond_mask.view(-1).bool()]
            
#             x_cond = x_given[..., cond_mask.view(-1).bool()]
#             # em_cond_mask = torch.zeros(1, 1, (x_target.size(-1)+config.ea_encoder.em_T_out)).to(x_target)
#             # em_cond_mask[..., 0: config.ea_encoder.em_T_out] = 1
#             # em_cond_mask = em_cond_mask.view(-1).bool()
        
#             val_results = diffusion.calc_bpd_loop(x_start=x_target, x_cond=x_cond, h=data.h, model_kwargs=model_kwargs)
#             # for key, param in denoise_network.module.zero_x_params.items():
#             #     assert torch.equal(param, torch.zeros_like(param)), f"zero_x_params[{key}] is not all zeros"      
#             # for key, param in denoise_network.module.zero_h_params.items():
#             #     assert torch.equal(param, torch.zeros_like(param)), f"zero_h_params[{key}] is not all zeros"  
#             total_bpd = val_results['total_bpd']  # [BN]
#             mse = val_results['mse'].mean(dim=1)  # [BN, T] -> [BN]

#             total_bpd = global_add_pool(total_bpd, data.batch)  # [B]
#             mse = global_mean_pool(mse, data.batch)  # [B]
#             test_nll_epoch_all.append(total_bpd)
#             test_mse_epoch_all.append(mse)


#             shape_to_pred = x_target.shape  # [BN, 3, T_p]
            
#             ADE_K, FDE_K = [], []
          
#             # Compute traj distance
#             for k in tqdm(range(config.test.K), disable=rank != 0):
#                 if rank == 0:
#                     print(f'Predicting {k}')
                
                
#                 x_out = diffusion.p_sample_loop(shape=shape_to_pred, x_cond=x_cond, progress=False,
#                                                 h=data.h, model_kwargs=model_kwargs)  # [BN, 3, T_p]
#                 distance = (x_out - x_target).square().sum(dim=1).sqrt()  # [BN, T_p]
#                 distance = global_mean_pool(distance, data.batch)  # [B, T_p]
#                 ADE_K.append(distance.mean(dim=1))  # [B]
#                 FDE_K.append(distance[..., -1])  # [B]

#             # Compute minADE, minFDE
#             ADE_K_all = [[] for _ in range(5)]  # store ADE_1~ADE_5 separately
#             FDE_K_all = [[] for _ in range(5)]  # store FDE_1~FDE_5 separately
#             ADE_K = torch.stack(ADE_K, dim=-1)  # [B, K]
#             FDE_K = torch.stack(FDE_K, dim=-1)  # [B, K]
#             for i in range(config.test.K):
#                 ADE_K_all[i].append(ADE_K[:, i])  # [B]
#                 FDE_K_all[i].append(FDE_K[:, i])  # [B]
#             aveADE_K_all.append(ADE_K.mean(dim=-1))  # [B]
#             aveFDE_K_all.append(FDE_K.mean(dim=-1))  # [B]
#             system_id_all.append(data.system_id)  # [B]

#         # Analyze
#         # minADE_K_all = torch.cat(minADE_K_all, dim=0)  # [B_tot]
#         # minFDE_K_all = torch.cat(minFDE_K_all, dim=0)  # [B_tot]
#         for i in range(config.test.K):
#             ADE_K_all[i] = torch.cat(ADE_K_all[i], dim=0)
#             FDE_K_all[i] = torch.cat(FDE_K_all[i], dim=0)
#             # print(ADE_K_all[i])
#             # print(FDE_K_all[i])
#         aveADE_K_all = torch.cat(aveADE_K_all, dim=0)  # [B_tot]
#         aveFDE_K_all = torch.cat(aveFDE_K_all, dim=0)  # [B_tot]
#         nll_all = torch.cat(test_nll_epoch_all, dim=0)  # [B_tot]
#         system_id_all = torch.cat(system_id_all, dim=0)  # [B_tot]
       
#         # print(aveADE_K_all)
#         # print(aveFDE_K_all)
#         device = torch.device(f"cuda:{rank}")
#         for i in range(config.test.K):
#             ADE_K_all[i] =  ADE_K_all[i].to(device)
#             FDE_K_all[i] =  FDE_K_all[i].to(device)
#         aveADE_K_all = aveADE_K_all.to(device)
#         aveFDE_K_all = aveFDE_K_all.to(device)
#         nll_all = nll_all.to(device)
#         system_id_all = system_id_all.float().to(device)  # if int64, convert to float first
#         reduce_placeholder = reduce_placeholder.to(device)

#         # Reduce from all gpus and compute metrics
#         if world_size > 1:
#             # minADE_K_all = gather_across_gpus(minADE_K_all, reduce_placeholder)  # [B_tot * num_gpus]
#             # minFDE_K_all = gather_across_gpus(minFDE_K_all, reduce_placeholder)
#             for i in range(config.test.K):
#                 ADE_K_all[i] = gather_across_gpus(ADE_K_all[i], reduce_placeholder)
#                 FDE_K_all[i] = gather_across_gpus(FDE_K_all[i], reduce_placeholder)
#             aveADE_K_all = gather_across_gpus(aveADE_K_all, reduce_placeholder)
#             aveFDE_K_all = gather_across_gpus(aveFDE_K_all, reduce_placeholder)
#             nll_all = gather_across_gpus(nll_all, reduce_placeholder)
#             system_id_all = gather_across_gpus(system_id_all, reduce_placeholder)

#         results = {
#             f'aveADE_{config.test.K}': aveADE_K_all.mean().item(),
#             f'aveFDE_{config.test.K}': aveFDE_K_all.mean().item(),
#             'nll': nll_all.mean().item(),
#             'system_id_range': [system_id_all.min().item(), system_id_all.max().item()]
#         }
#         for i in range(config.test.K):
#             results[f'ADE_{i}'] = ADE_K_all[i].mean().item()
#             results[f'FDE_{i}'] = FDE_K_all[i].mean().item()
#         ade_means = [ADE_K_all[i].mean().item() for i in range(config.test.K)]
#         fde_means = [FDE_K_all[i].mean().item() for i in range(config.test.K)]

#         results['ADE_std_across_K'] = torch.tensor(ade_means).std().item()
#         results['FDE_std_across_K'] = torch.tensor(fde_means).std().item()

           

#         if rank == 0:
#             print(results)
#             for i in range(config.test.K):
#                test_log.log({f'Test ADE_{i}': ADE_K_all[i].mean().item()}, commit=False)
#                test_log.log({f'Test FDE_{i}': FDE_K_all[i].mean().item()}, commit=False)
#             test_log.log({f'Test ADE_std': torch.tensor(ade_means).std().item()}, commit=False)
#             test_log.log({f'Test FDE_std': torch.tensor(fde_means).std().item()}, commit=False)
                             
#             # test_log.log({f'Test minADE_{config.test.K}': minADE_K_all.mean().item()}, commit=False)
#             # test_log.log({f'Test minFDE_{config.test.K}': minFDE_K_all.mean().item()}, commit=False)
            
#             test_log.log({f'Test aveADE_{config.test.K}': aveADE_K_all.mean().item()}, commit=False)
#             test_log.log({f'Test aveFDE_{config.test.K}': aveFDE_K_all.mean().item()}, commit=False)
#             test_log.log({f'Test nll': nll_all.mean().item()}, commit=True)

#             # Save
#             save_path = os.path.join(test_output_path, 'results.pkl')
            
#             save_results = {
#                 f'aveADE_{config.test.K}': aveADE_K_all.detach().cpu().numpy(),
#                 f'aveFDE_{config.test.K}': aveFDE_K_all.detach().cpu().numpy(),
                
#                 'nll': nll_all.detach().cpu().numpy(),
#                 'system_id': system_id_all.detach().cpu().numpy()
#             }
#             for i in range(config.test.K):
#                 save_results[f'ADE_{i}'] = ADE_K_all[i].detach().cpu().numpy()
#                 save_results[f'FDE_{i}'] = FDE_K_all[i].detach().cpu().numpy()
#             with open(save_path, 'wb') as f:
#                 pickle.dump(save_results, f)
#             print(f'Results saved to {save_path}')

#         if world_size > 1:
#             dist.barrier()
#             dist.destroy_process_group()
        


def main():
    """Parse the command line, set up the process group if needed, and call ``run``.

    The world size is taken from the number of visible CUDA devices. A
    process group is initialised only when more than one device is visible,
    using ``--local-rank`` as this process's rank.
    """

    parser = argparse.ArgumentParser(description='EATDM')
    parser.add_argument('--train_yaml_file', type=str, help='path of the train yaml file',
                        default='configs/nbody_train_cond.yaml')
    parser.add_argument('--local-rank', dest='local_rank', type=int, default=0)

    args = parser.parse_args()
    print(args)

    world_size = torch.cuda.device_count()
    # world_size = 1

    print('Let\'s use', world_size, 'GPUs!')

    if world_size > 1:
        if torch.cuda.is_available():
            # Make this process's GPU the current CUDA device so collectives
            # and implicit 'cuda' allocations land on it, not the default one.
            torch.cuda.set_device(args.local_rank)
            dist.init_process_group('nccl', rank=args.local_rank, world_size=world_size)
        else:
            dist.init_process_group('gloo', rank=args.local_rank, world_size=world_size)

    run(args.local_rank, world_size, args)

    # Tear the process group down on a clean exit. The guard keeps this safe
    # if `run` has already destroyed it.
    if world_size > 1 and dist.is_initialized():
        dist.destroy_process_group()


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

