import torch
from torch_geometric.loader import DataLoader
import pickle
import os
import shutil
import yaml
from easydict import EasyDict
from torch.nn.parallel import DistributedDataParallel
import sys
from tqdm import tqdm
sys.path.append('./')

from datasets.md17 import MD17Traj

from models.EA_zero import EA
from models.EAencoder import EAencoder
from diffusion.EATDM import EATDM, ModelMeanType, ModelVarType, LossType

from utils.misc import set_seed

# Device handle passed to every `.to(...)` call below (a bare int is an
# accelerator index in PyTorch).
device = 0

# Read the sampling (evaluation) configuration from its YAML file.
eval_yaml_file = "/nlp/scr/jiangm/wproject/CGeoDM/cond_configs_MD/md17_sampling.yaml"
with open(eval_yaml_file, 'r') as f:
    params = yaml.safe_load(f)
config = EasyDict(params)

# Flag from the config; when truthy, a missing `cond_mask` is later filled in
# from the training config.
cond = config.eval.cond

# Output layout: <output_base_path>/<molecule_name>/<train_exp_name>
train_output_path_mo = os.path.join(
    config.eval.output_base_path,
    config.data.molecule_name,
)
train_output_path = os.path.join(
    train_output_path_mo,
    config.eval.train_exp_name,
)
print(f'Train output path: {train_output_path}')

# Seed via the project's helper, using the value stored in the eval config.
set_seed(config.eval.seed)

# Locate the YAML config that the training run left in its output directory.
#
# `os.path.join` only builds a string and never raises for a missing file, so
# wrapping it in try/except could never reach the fallback name. The first
# candidate is therefore checked for existence explicitly.
if config.eval.model == 'GeoTDM':
    train_yaml_file = os.path.join(train_output_path, 'md17_train_cond_ours_EA.yaml')
    if not os.path.isfile(train_yaml_file):
        print('Notice: This exp was launched by submit_job.sh')
        train_yaml_file = os.path.join(train_output_path, f'{config.eval.train_exp_name}.yaml')  # Use this if batch
else:
    # Without this branch `train_yaml_file` would be undefined and the `open`
    # below would fail with an unrelated-looking NameError.
    raise ValueError(
        f"Unsupported config.eval.model {config.eval.model!r}: "
        "this script only knows where to find the training config for 'GeoTDM'"
    )

# Parse the training config; later steps copy the model/diffusion sections
# from it so sampling uses the architecture the checkpoint was trained with.
with open(train_yaml_file, 'r') as f:
    train_params = yaml.safe_load(f)
train_config = EasyDict(train_params)
# Default the evaluation experiment name to "<train_exp_name>_eval".
if config.eval.eval_exp_name is None:
    config.eval.eval_exp_name = config.eval.train_exp_name + '_eval'
eval_output_path = os.path.join(train_output_path_mo, config.eval.eval_exp_name)
print(f'Eval output path: {eval_output_path}')
# `makedirs(exist_ok=True)` replaces the exists()/mkdir() pair: it has no
# check-then-create race and also creates missing parent directories.
os.makedirs(eval_output_path, exist_ok=True)
# Keep a copy of the sampling config next to the generated samples.
shutil.copy(eval_yaml_file, eval_output_path)

# Take the model-related sections from the training config so the networks
# built below are configured the way the training run configured them.
if config.eval.model == 'GeoTDM':
    for section in ('denoise_model', 'diffusion', 'ea_encoder'):
        setattr(config, section, getattr(train_config, section))
    # An explicit sampling step count in the eval config overrides the
    # diffusion section's `num_timesteps`.
    sampling_steps = config.eval.sampling_timesteps
    if sampling_steps is not None:
        config.diffusion.num_timesteps = sampling_steps

# When conditioning is enabled and the eval config gives no mask, reuse the
# one stored in the training config.
if cond and config.eval.cond_mask is None:
    config.eval.cond_mask = train_config.train.cond_mask

# Build the dataset from the `data` section of the config and iterate it in
# its stored order (no shuffling).
dataset = MD17Traj(**config.data)
dataloader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=config.eval.batch_size,
)

# Instantiate the networks and move them to `device`. No optimizer is created
# in this script.
if config.eval.model == 'GeoTDM':
    denoise_network = EA(**config.denoise_model)
    denoise_network = denoise_network.to(device)
    ea_encoder = EAencoder(**config.ea_encoder)
    ea_encoder = ea_encoder.to(device)


# Parameter used to confirm that loading the checkpoint really changed the
# weights. The SAME name must be read before and after loading: comparing two
# different parameters (layer 1 before vs. layer 0 after) always reports
# "different" and therefore verifies nothing.
probe_param_name = "s_modules.1.edge_mlp.actions.0.weight"

named_params = dict(denoise_network.named_parameters())
if probe_param_name not in named_params:
    raise KeyError(
        f"Probe parameter {probe_param_name!r} not found in denoise_network; "
        "cannot verify checkpoint loading"
    )
# Snapshot of the freshly initialised probe weights.
init_param_1 = named_params[probe_param_name].detach().clone()
print(probe_param_name)

# Load checkpoint
model_ckpt_path = os.path.join(train_output_path, f'ckpt_{config.eval.model_ckpt}.pt')
# Deserialize onto the CPU so loading does not depend on the device the
# checkpoint was saved from; `load_state_dict` copies the values into the
# parameters that already live on `device`.
state_dict = torch.load(model_ckpt_path, map_location='cpu')

# Checkpoints written from a DistributedDataParallel wrapper prefix every key
# with "module.". Strip the prefix only where it is present instead of blindly
# dropping the first seven characters of every key.
ddp_prefix = 'module.'
model_state = state_dict["model"]
if any(key.startswith(ddp_prefix) for key in model_state):
    print('Notice: This exp was launched by submit_job.sh')
state_dict = {
    (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key): value
    for key, value in model_state.items()
}
denoise_network.load_state_dict(state_dict)

# `load_state_dict` copies in place, so the same Parameter object now holds
# the checkpoint values.
test_param_1 = named_params[probe_param_name].detach().clone()

# Explicit raise rather than `assert`, which is stripped under `python -O`.
if torch.equal(init_param_1, test_param_1):
    raise RuntimeError("test model load failed")
print(f'Model loaded from {model_ckpt_path}')

# Wrap the denoiser and the encoder in the diffusion process. The mean type,
# variance type, loss type and `rescale_timesteps` are fixed here; everything
# else comes from the `diffusion` config section.
if config.eval.model == 'GeoTDM':
    diffusion = EATDM(
        denoise_network=denoise_network,
        ea_encoder=ea_encoder,
        model_mean_type=ModelMeanType.EPSILON,
        model_var_type=ModelVarType.FIXED_LARGE,
        loss_type=LossType.MSE,
        device=device,
        rescale_timesteps=False,
        **config.diffusion
    )

# Switch the denoiser to inference mode.
# NOTE(review): in this script `ea_encoder` is neither restored from the
# checkpoint nor switched to eval mode -- confirm that this is intended.
denoise_network.eval()

all_data = []
print('Start sampling...')

# `config.eval.cond_mask` is still None when `cond` is false and the eval
# config provides no mask; iterating None would raise a TypeError. Treat that
# case as "no conditioned frames".
# NOTE(review): confirm `p_sample_loop` accepts an empty `x_cond`.
cond_intervals = [] if config.eval.cond_mask is None else config.eval.cond_mask

for data in tqdm(dataloader, total=len(dataloader)):
    data = data.to(device)
    # Graph inputs forwarded to the sampler via `model_kwargs`.
    model_kwargs = {
        'edge_index': data.edge_index,
        'edge_attr': data.edge_attr,
        'batch': data.batch,
    }

    x_given = data.x

    # Create temporal inpainting mask over the last axis of `x`:
    # 1 to keep the entries unchanged, 0 to modify them by diffusion.
    cond_mask = torch.zeros(1, 1, x_given.size(-1)).to(x_given)
    for interval in cond_intervals:
        cond_mask[..., interval[0]: interval[1]] = 1

    # Compute the boolean selector once and use it for both splits.
    keep = cond_mask.view(-1).bool()
    x_target = x_given[..., ~keep]
    x_cond = x_given[..., keep]
    shape_to_pred = x_target.shape

    if config.eval.model == 'GeoTDM':
        x_out = diffusion.p_sample_loop(shape=shape_to_pred, x_cond=x_cond, progress=False,
                                        h=data.h, model_kwargs=model_kwargs)  # [BN, 3, T_p]
    print(x_out.shape)

    # Attach the prediction to the batch and move it off the device before
    # accumulating it.
    data['x_pred'] = x_out.detach()

    all_data.append(data.cpu())

without_pyg = True  # Use true for visualization and false for evaluation
samples_save_path = os.path.join(eval_output_path, 'samples_vis.pkl' if without_pyg else 'samples.pkl')

# Build the payload BEFORE opening the file, so a failure during conversion
# does not leave an empty or truncated pickle behind.
if without_pyg:
    # Parse data in case the drawer cannot process PyG data: one
    # (x, x_pred, batch) tuple of NumPy arrays per sampled batch.
    # Currently no need to save node feature, since we read it from original data
    data_save = [
        (batch_data.x.numpy(), batch_data.x_pred.numpy(), batch_data.batch.numpy())
        for batch_data in all_data
    ]
    payload = data_save
else:
    payload = all_data

with open(samples_save_path, 'wb') as f:
    pickle.dump(payload, f)

print(f'Samples saved to {samples_save_path}')
# print(x_out.shape)
# exit(0)


