import torch
import torch.nn as nn
from tqdm import tqdm
import yaml
import argparse
from torch_geometric.data import DataLoader
from torch_geometric.nn.pool import global_mean_pool, global_add_pool
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist
from torchmetrics.aggregation import CatMetric
from easydict import EasyDict
import wandb
import os
import copy
import shutil
import pickle
import sys
sys.path.append('./')
import numpy
from experiments.fixddp import DistributedSamplerNoDuplicate
from datasets.cmu import CMU
from models.EGCTN import EGCTN
from models.Encoder_cond import Cencoder
from diffusion.GeoTDM import GeoTDM, ModelMeanType, ModelVarType, LossType
from utils.misc import set_seed, gather_across_gpus

# Give the wandb background service longer to come up before wandb.init gives up.
os.environ.update(WANDB__SERVICE_WAIT="300")
# Use the 'file_system' strategy for sharing tensors between DataLoader worker processes.
torch.multiprocessing.set_sharing_strategy('file_system')

# Training-loss reduction: 'node' keeps the per-node loss; 'graph' mean-pools it per graph.
opt_mode = 'node'

# Name of the probe weight used to check that checkpoints were actually loaded.
_PROBE_PARAM_NAME = "s_modules.0.edge_mlp.actions.0.weight"
# Evaluation horizons in frames; keys are labelled with 40 ms per frame
# (assumes the data is sampled at 25 fps, as the labels imply -- TODO confirm).
_EVAL_FRAMES = (2, 4, 8, 10, 14, 25)
_EVAL_STEPS = {f'{frame * 40}ms': frame - 1 for frame in _EVAL_FRAMES}


def _snapshot_param(module, name=_PROBE_PARAM_NAME):
    """Return a detached copy of parameter ``name`` of ``module``, or None if it does not exist."""
    for param_name, param in module.named_parameters():
        if param_name == name:
            return param.detach().clone()
    return None


def _strip_module_prefix(state_dict):
    """Drop the ``module.`` key prefix that DistributedDataParallel adds to state-dict keys."""
    prefix = 'module.'
    return {(key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()}


def _model_kwargs(data):
    """Build the graph keyword arguments passed to the diffusion model for one batch."""
    return {'h': data.h,
            'edge_index': data.edge_index,
            'edge_attr': data.edge_attr,
            'batch': data.batch}


def _build_masks(x_given, cond_intervals, em_T_out):
    """Split a trajectory tensor into target and conditioning frames.

    Args:
        x_given: tensor whose last dimension indexes time frames.
        cond_intervals: iterable of ``[start, end)`` frame intervals given as conditioning.
        em_T_out: number of leading entries of ``em_cond_mask`` marked as conditioning.

    Returns:
        ``(x_target, x_cond, em_cond_mask)``: frames outside every interval, frames inside
        the intervals, and a boolean vector of length ``x_target.size(-1) + em_T_out`` whose
        first ``em_T_out`` entries are True.
    """
    cond_mask = torch.zeros(x_given.size(-1), dtype=torch.bool, device=x_given.device)
    for interval in cond_intervals:
        cond_mask[interval[0]:interval[1]] = True
    x_target = x_given[..., ~cond_mask]
    x_cond = x_given[..., cond_mask]
    em_cond_mask = torch.zeros(x_target.size(-1) + em_T_out, dtype=torch.bool, device=x_given.device)
    em_cond_mask[:em_T_out] = True
    return x_target, x_cond, em_cond_mask


def _warn_on_unexpected_grads(network):
    """Print a warning when a frozen branch got a gradient or a trainable branch did not."""
    for label, module in (("s_modules", network.s_modules), ("t_modules", network.t_modules)):
        for para_name, param in module.named_parameters():
            if param.grad is not None:
                print(f"WARNING: have gradient for {label}: {para_name}")
    trainable = (("zero_x_params", network.zero_x_params), ("zero_h_params", network.zero_h_params),
                 ("copy_s_modules", network.copy_s_modules), ("copy_t_modules", network.copy_t_modules))
    for label, module in trainable:
        for para_name, param in module.named_parameters():
            if param.grad is None:
                print(f"WARNING: No gradient for {label} {para_name}")


def _collect_errors(diffusion, dataloader, cond_intervals, em_T_out, num_samples, rank, calc_nll=False):
    """Draw ``num_samples`` predictions per batch and collect per-graph trajectory errors.

    Results are local to this process (not gathered across ranks).

    Returns:
        ``(Error_K_all, system_id_all, nll_all, eps_mse_all)`` where ``Error_K_all`` stacks, for
        every sample, the node-mean Euclidean distance per predicted frame (samples on the last
        dim), ``system_id_all`` concatenates ``data.system_id``, and the NLL/MSE tensors are None
        unless ``calc_nll`` is set.
    """
    nll_list, mse_list, error_list, system_ids = [], [], [], []
    # Evaluation only: no backward pass, so do not build autograd graphs.
    with torch.no_grad():
        for data in tqdm(dataloader, disable=rank != 0, total=len(dataloader)):
            data = data.to(rank)
            model_kwargs = _model_kwargs(data)
            x_target, x_cond, em_cond_mask = _build_masks(data.x, cond_intervals, em_T_out)

            if calc_nll:
                val_results = diffusion.calc_bpd_loop(x_start=x_target, x_cond=x_cond, h=data.h,
                                                      em_cond_mask=em_cond_mask, model_kwargs=model_kwargs)
                nll_list.append(global_add_pool(val_results['total_bpd'], data.batch))
                mse_list.append(global_mean_pool(val_results['mse'].mean(dim=1), data.batch))

            errors = []
            for _ in range(num_samples):
                x_out = diffusion.p_sample_loop(shape=x_target.shape, x_cond=x_cond, progress=False,
                                                h=data.h, em_cond_mask=em_cond_mask, model_kwargs=model_kwargs)
                # Euclidean distance over dim 1, then averaged over the nodes of each graph.
                distance = (x_out - x_target).square().sum(dim=1).sqrt()
                errors.append(global_mean_pool(distance, data.batch))
            error_list.append(torch.stack(errors, dim=-1))
            system_ids.append(data.system_id)

    Error_K_all = torch.cat(error_list, dim=0)
    system_id_all = torch.cat(system_ids, dim=0)
    nll_all = torch.cat(nll_list, dim=0) if calc_nll else None
    eps_mse_all = torch.cat(mse_list, dim=0) if calc_nll else None
    return Error_K_all, system_id_all, nll_all, eps_mse_all


def run(rank, world_size, args):
    """Fine-tune the zero/copy branches of an EGCTN denoiser on CMU, then optionally test it.

    The configuration is read from ``args.train_yaml_file``. Only ``zero_x_params``,
    ``zero_h_params``, ``copy_s_modules`` and ``copy_t_modules`` are optimised; ``s_modules``
    and ``t_modules`` are frozen.

    Args:
        rank: process rank; also used as the CUDA device index for models and batches.
        world_size: number of processes; DDP and distributed samplers are used when > 1.
        args: parsed CLI namespace providing ``train_yaml_file``.
    """
    # Load args
    yaml_file = args.train_yaml_file
    with open(yaml_file, 'r') as f:
        params = yaml.safe_load(f)
    config = EasyDict(params)
    K = config.train.K
    em_T_out = config.con_encoder.em_T_out

    # Init model and optimizer
    denoise_network = EGCTN(**config.denoise_model).to(rank)
    con_encoder = Cencoder(**config.con_encoder).to(rank)
    print("init denoise_network")
    # NOTE(review): `diffusion` receives the unwrapped network (DDP wrapping happens below).
    # If GeoTDM calls it directly, DDP's forward hooks -- and thus gradient all-reduce -- never
    # run, so each rank trains independently. Confirm.
    diffusion = GeoTDM(denoise_network=denoise_network, con_encoder=con_encoder,
                       model_mean_type=ModelMeanType.EPSILON,
                       model_var_type=ModelVarType.FIXED_LARGE,
                       loss_type=LossType.MSE,
                       device=rank,
                       rescale_timesteps=False,
                       **config.diffusion)
    print("init diffusion")
    # `.module` only exists after DDP wrapping, which is skipped for a single process,
    # so keep a direct handle on the underlying network.
    raw_network = denoise_network
    tuned_params = (list(raw_network.zero_x_params.parameters())
                    + list(raw_network.zero_h_params.parameters())
                    + list(raw_network.copy_s_modules.parameters())
                    + list(raw_network.copy_t_modules.parameters()))
    optimizer = torch.optim.Adam(tuned_params, lr=config.train.lr)
    print("init optimizer")

    # Freeze before DDP wrapping: DDP records which parameters require grad when constructed.
    for module in (raw_network.s_modules, raw_network.t_modules):
        for param in module.parameters():
            param.requires_grad = False
    print("freeze denoise_network")
    for param in tuned_params:
        param.requires_grad = True
    print("Trainable zero modules")

    if world_size > 1:
        denoise_network = DistributedDataParallel(denoise_network, device_ids=[rank])
        print("init Distributed")

    # Probe weight before any checkpoint is loaded, used to verify later loads.
    init_param_1 = _snapshot_param(diffusion.denoise_network)

    # Load dataset
    dataset_train = CMU(**config.data.train)
    dataset_val = CMU(**config.data.val)
    task_name = config.data.train.act

    # Outputs go to <output_base_path>/<act>/<exp_name>
    name_path = os.path.join(config.train.output_base_path, task_name)
    output_path = os.path.join(name_path, config.train.exp_name)
    print(output_path)
    if rank == 0:
        # Report parameter counts (sizes assume 4-byte float parameters).
        total_params = sum(p.numel() for p in raw_network.parameters())
        trainable_params = sum(p.numel() for p in raw_network.parameters() if p.requires_grad)
        for name, param in raw_network.named_parameters():
            if param.requires_grad:
                print(f"{name}: {param.numel()}")
        print(f"Total parameters: {total_params}")
        print(f"Trainable parameters: {trainable_params}")
        print(f"Total parameters: {total_params} ({total_params * 4 / (1024 * 1024):.2f} MB)")
        print(f"Trainable parameters: {trainable_params} ({trainable_params * 4 / (1024 * 1024):.2f} MB)")

    if config.train.train_mode:
        if rank == 0:
            os.makedirs(output_path, exist_ok=True)
            shutil.copy(yaml_file, output_path)

        set_seed(config.train.seed)
        if world_size > 1:
            sampler_train = DistributedSampler(dataset_train)
            sampler_val = DistributedSamplerNoDuplicate(dataset_val, shuffle=False, drop_last=False)
        else:
            sampler_train = None
            sampler_val = None

        dataloader_train = DataLoader(dataset_train, batch_size=config.train.batch_size // world_size,
                                      shuffle=(sampler_train is None), sampler=sampler_train, pin_memory=True)
        dataloader_val = DataLoader(dataset_val, batch_size=config.train.eval_batch_size // world_size,
                                    shuffle=False, sampler=sampler_val)

        # Load base checkpoint
        base_model_ckpt_path = "cond_outputs_CMU/ckpt_base.pt"
        if os.path.exists(base_model_ckpt_path):
            checkpoint = torch.load(base_model_ckpt_path, map_location=torch.device(rank))
            # Accept both a raw state dict and the {'model': ..., ...} layout this script saves,
            # with or without the DDP `module.` prefix.
            base_weights = _strip_module_prefix(checkpoint.get('model', checkpoint))
            model_dict = raw_network.state_dict()
            model_dict.update({k: v for k, v in base_weights.items()
                               if k in model_dict and v.shape == model_dict[k].shape})
            raw_network.load_state_dict(model_dict)

            # Initialise every copy branch from its base layer.
            for i in config.denoise_model.n_copy_layer_list:
                layer_idx = str(i)
                raw_network.copy_s_modules[layer_idx].load_state_dict(
                    copy.deepcopy(raw_network.s_modules[i].state_dict()))
                raw_network.copy_t_modules[layer_idx].load_state_dict(
                    copy.deepcopy(raw_network.t_modules[i].state_dict()))

        base_train_param_1 = _snapshot_param(diffusion.denoise_network)
        assert (init_param_1 is not None and base_train_param_1 is not None
                and not torch.equal(init_param_1, base_train_param_1)), "base train model load failed"
        print(f'Base Train Model loaded from {base_model_ckpt_path} success')
        param1 = raw_network.copy_s_modules["0"].edge_mlp.actions[0].weight
        param2 = raw_network.s_modules[0].edge_mlp.actions[0].weight
        assert torch.equal(param1.cpu(), param2.cpu()), "copy layer failed"
        print("Copy layer load success")
        for key, param in raw_network.zero_x_params.items():
            assert torch.equal(param, torch.zeros_like(param)), f"zero_x_params[{key}] is not all zeros"
        print("zero params load success")

        for para_name, param in raw_network.s_modules.named_parameters():
            if param.grad is not None:
                print(f"WARNING: have gradient for s_modules: {para_name}")

        if rank == 0:
            mode = 'disabled' if config.wandb.no_wandb else 'online'
            kwargs_train = {'entity': config.wandb.wandb_usr, 'name': task_name + "_" + config.train.exp_name,
                            'project': config.wandb.project, 'config': params,
                            'settings': wandb.Settings(_disable_stats=True), 'mode': mode}
            train_log = wandb.init(**kwargs_train)
            train_log.save('*.txt')

        # Start training
        num_epochs = config.train.num_epochs
        tot_step = 0
        best_val_final_error = 1e10
        reduce_placeholder = CatMetric()
        if rank == 0:
            progress_bar = tqdm(total=num_epochs)
        calc_nll = False

        for epoch in range(1, num_epochs + 1):
            if rank == 0:
                progress_bar.set_description(f"Epoch {epoch}")
            denoise_network.train()
            if sampler_train is not None:
                # Reshuffle differently every epoch in distributed mode.
                sampler_train.set_epoch(epoch)

            train_loss_epoch, counter = torch.zeros(1).to(rank), torch.zeros(1).to(rank)

            for data in dataloader_train:
                tot_step += 1
                data = data.to(rank)
                model_kwargs = _model_kwargs(data)
                x_target, x_cond, em_cond_mask = _build_masks(data.x, config.train.cond_mask, em_T_out)

                training_losses = diffusion.training_losses(x_target=x_target, x_cond=x_cond, h=data.h,
                                                            em_cond_mask=em_cond_mask, t=None,
                                                            model_kwargs=model_kwargs)
                loss = training_losses['loss']  # [BN]
                if opt_mode == 'graph':  # graph-wise loss
                    loss = global_mean_pool(loss, data.batch)  # [B]

                step_loss_synced = gather_across_gpus(loss, reduce_placeholder).mean().item()
                if rank == 0 and tot_step % config.train.log_every_step == 0:
                    train_log.log({"Step train loss": step_loss_synced}, commit=True, step=tot_step)
                    progress_bar.set_postfix(loss=step_loss_synced, step=tot_step)

                # Detach so the epoch accumulator does not keep every step's autograd graph alive.
                train_loss_epoch = train_loss_epoch + loss.detach().sum()
                counter = counter + loss.size(0)

                loss = loss.mean()
                loss.backward()
                _warn_on_unexpected_grads(raw_network)
                nn.utils.clip_grad_norm_(tuned_params, max_norm=1.0)
                optimizer.step()
                optimizer.zero_grad()

            train_loss_epoch = gather_across_gpus(train_loss_epoch, reduce_placeholder).sum().item()
            counter = gather_across_gpus(counter, reduce_placeholder).sum().item()
            if rank == 0:
                train_log.log({"Epoch train loss": train_loss_epoch / counter}, commit=True)

            # Eval on validation set
            if epoch % config.train.eval_every_epoch == 0:
                denoise_network.eval()
                reduce_placeholder = CatMetric()
                Error_K_all, system_id_all, nll_all, eps_mse_all = _collect_errors(
                    diffusion, dataloader_val, config.train.cond_mask, em_T_out, K, rank, calc_nll)
                Error_min_all = Error_K_all.min(dim=2).values  # best of K samples
                Error_ave_all = Error_K_all.mean(dim=2)  # mean over K samples

                results = {}
                if world_size > 1:
                    if calc_nll:
                        nll_all = gather_across_gpus(nll_all, reduce_placeholder)
                        eps_mse_all = gather_across_gpus(eps_mse_all, reduce_placeholder)
                    system_id_all = gather_across_gpus(system_id_all, reduce_placeholder)
                if calc_nll:
                    results['nll'] = nll_all.mean().item()
                results['system_id_range'] = [system_id_all.min().item(), system_id_all.max().item()]

                # NOTE(review): validation errors are scaled by the *test* split's scale; confirm intended.
                for key, t_idx in _EVAL_STEPS.items():
                    cur_Error_min = Error_min_all[..., t_idx]
                    if world_size > 1:
                        reduce_placeholder = reduce_placeholder.to(rank)
                        cur_Error_min = gather_across_gpus(cur_Error_min, reduce_placeholder)
                    results[f'min_{K}_' + key] = cur_Error_min.mean().item() / config.data.test.scale

                    cur_Error_ave = Error_ave_all[..., t_idx]
                    if world_size > 1:
                        cur_Error_ave = gather_across_gpus(cur_Error_ave, reduce_placeholder)
                    results[f'ave_{K}_' + key] = cur_Error_ave.mean().item() / config.data.test.scale

                if rank == 0:
                    log_names = [f'{stat}_{K}_{horizon}' for horizon in ('80ms', '1000ms') for stat in ('ave', 'min')]
                    for idx, name in enumerate(log_names):
                        train_log.log({name: results[name]}, commit=(idx == len(log_names) - 1))

                    print(results)
                    print(f'counter: {system_id_all.size()}')

                    final_error = results[f'ave_{K}_' + '1000ms']
                    if final_error < best_val_final_error:
                        best_val_final_error = final_error
                        save_dict = {'model': denoise_network.state_dict(),
                                     'optimizer': optimizer.state_dict(),
                                     'tot_step': tot_step}
                        torch.save(save_dict, os.path.join(output_path, 'ckpt_best.pt'))

            # Save model
            if rank == 0 and config.train.save_model:
                save_dict = {'model': denoise_network.state_dict(),
                             'optimizer': optimizer.state_dict(),
                             'tot_step': tot_step}
                if epoch % config.train.save_every_epoch == 0:
                    torch.save(save_dict, os.path.join(output_path, f'ckpt_{epoch}.pt'))
                torch.save(save_dict, os.path.join(output_path, 'ckpt_last.pt'))

            if world_size > 1:
                dist.barrier()
            if rank == 0:
                progress_bar.update(1)

        if rank == 0:
            progress_bar.close()
            train_log.finish()

    # Final testing
    if config.test.final_test:
        if rank == 0:
            print('Start final testing')
        dataset_test = CMU(**config.data.test)
        test_output_path = os.path.join(output_path, config.test.exp_name)
        os.makedirs(test_output_path, exist_ok=True)
        if rank == 0:
            shutil.copy(yaml_file, test_output_path)

        if world_size > 1:
            sampler = DistributedSamplerNoDuplicate(dataset_test, shuffle=False, drop_last=False)
        else:
            sampler = None
        test_dataloader = DataLoader(dataset_test, batch_size=config.train.eval_batch_size // world_size,
                                     shuffle=False, sampler=sampler)

        # Load checkpoint (saved with or without the DDP `module.` prefix).
        test_model_ckpt_path = os.path.join(output_path, f'ckpt_{config.test.final_test_ckpt}.pt')
        checkpoint = torch.load(test_model_ckpt_path, map_location=torch.device(rank))
        raw_network.load_state_dict(_strip_module_prefix(checkpoint['model']))

        test_param_1 = _snapshot_param(diffusion.denoise_network)
        assert (init_param_1 is not None and test_param_1 is not None
                and not torch.equal(init_param_1, test_param_1)), "test model load failed"
        if rank == 0:
            print(f"Test Model loaded from {test_model_ckpt_path} success")
        denoise_network.eval()

        if rank == 0:
            mode = 'disabled' if config.wandb.no_wandb else 'online'
            kwargs_test = {'entity': config.wandb.wandb_usr, 'name': task_name + "_" + config.test.exp_name,
                           'project': config.wandb.project, 'config': params,
                           'settings': wandb.Settings(_disable_stats=True), 'mode': mode}
            test_log = wandb.init(**kwargs_test, allow_val_change=True)
            test_log.save('*.txt')

        calc_nll = False
        reduce_placeholder = CatMetric()
        Error_K_all, system_id_all, nll_all, eps_mse_all = _collect_errors(
            diffusion, test_dataloader, config.test.cond_mask, em_T_out, K, rank, calc_nll)
        Error_min_all = Error_K_all.min(dim=2).values
        Error_ave_all = Error_K_all.mean(dim=2)

        results = {}
        if world_size > 1:
            if calc_nll:
                nll_all = gather_across_gpus(nll_all, reduce_placeholder)
                eps_mse_all = gather_across_gpus(eps_mse_all, reduce_placeholder)
            system_id_all = gather_across_gpus(system_id_all, reduce_placeholder)
        if calc_nll:
            results['nll'] = nll_all.mean().item()
        results['system_id_range'] = [system_id_all.min().item(), system_id_all.max().item()]

        scale = config.data.test.scale
        act_index_map = dataset_test.act_index_map
        for key, t_idx in _EVAL_STEPS.items():
            print("key", key)
            cur_Error_min = Error_min_all[..., t_idx]
            if world_size > 1:
                reduce_placeholder = reduce_placeholder.to(rank)
                cur_Error_min = gather_across_gpus(cur_Error_min, reduce_placeholder)
            cur_Error_ave = Error_ave_all[..., t_idx]
            if world_size > 1:
                reduce_placeholder = reduce_placeholder.to(rank)
                cur_Error_ave = gather_across_gpus(cur_Error_ave, reduce_placeholder)
            cur_Error_k_all = []
            for i in range(K):
                per_sample_error = Error_K_all[:, t_idx, i]
                if world_size > 1:
                    reduce_placeholder = reduce_placeholder.to(rank)
                    per_sample_error = gather_across_gpus(per_sample_error, reduce_placeholder)
                cur_Error_k_all.append(per_sample_error)

            # Reorder by system id so that act_index_map's dataset indices line up.
            sort_idx = torch.argsort(system_id_all)
            cur_Error_ave = cur_Error_ave[sort_idx]
            cur_Error_min = cur_Error_min[sort_idx]
            cur_Error_k_all = [errors[sort_idx] for errors in cur_Error_k_all]
            assert (system_id_all[sort_idx] - torch.arange(len(dataset_test)).to(system_id_all)).square().sum().item() == 0

            cur_min_, cur_ave_ = [], []
            for scenario in act_index_map:
                cur_index = torch.from_numpy(act_index_map[scenario]).long().to(system_id_all.device)
                cur_ave = cur_Error_ave[cur_index].mean().item() / scale
                cur_min = cur_Error_min[cur_index].mean().item() / scale
                cur_k_all = []
                for i in range(K):
                    cur_k = cur_Error_k_all[i][cur_index].mean().item() / scale
                    cur_k_all.append(cur_k)
                    results[f'{scenario}_{i}_' + key] = cur_k
                results[f'{scenario}_min_{K}_' + key] = cur_min
                results[f'{scenario}_ave_{K}_' + key] = cur_ave
                print("cur_k_all", cur_k_all)
                # Sample std across the K draws (undefined for a single draw).
                results[f'{scenario}_std_{K}_' + key] = float(numpy.std(cur_k_all, ddof=1)) if K > 1 else float('nan')
                results[f'{scenario}_aveK_{K}_' + key] = float(numpy.mean(cur_k_all))
                cur_min_.append(cur_min)
                cur_ave_.append(cur_ave)
            results[f'AVE_min_{K}_' + key] = sum(cur_min_) / len(cur_min_)
            results[f'AVE_ave_{K}_' + key] = sum(cur_ave_) / len(cur_ave_)

        if rank == 0:
            print(results)
            log_names = [f'AVE_{stat}_{K}_{horizon}' for horizon in ('80ms', '1000ms') for stat in ('ave', 'min')]
            for idx, name in enumerate(log_names):
                test_log.log({'Test_' + name: results[name]}, commit=(idx == len(log_names) - 1))

            # NOTE(review): results go to output_path, not test_output_path; confirm intended.
            save_path = os.path.join(output_path, 'results.pkl')
            with open(save_path, 'wb') as f:
                pickle.dump(results, f)
            print(f'Results saved to {save_path}')

    if world_size > 1:
        dist.barrier()
        dist.destroy_process_group()



def main():
    """Parse CLI arguments, set up torch.distributed when several processes are used, and call `run`."""
    parser = argparse.ArgumentParser(description='GeoTDM')
    parser.add_argument('--train_yaml_file', type=str, help='path of the train yaml file',
                        default='configs/cmu_train.yaml')
    # torch.distributed.launch passes --local-rank; torchrun only sets the LOCAL_RANK env var.
    parser.add_argument('--local-rank', dest='local_rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', 0)))

    args = parser.parse_args()
    print(args)

    # Prefer the launcher's process count; fall back to the number of visible GPUs.
    world_size = int(os.environ.get('WORLD_SIZE', torch.cuda.device_count()))
    print('Let\'s use', world_size, 'GPUs!')

    if world_size > 1:
        # Bind this process to its GPU before creating the NCCL group so collectives use it.
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group('nccl', rank=args.local_rank, world_size=world_size)
    run(args.local_rank, world_size, args)


if __name__ == '__main__':
    main()

