import torch
from torch_geometric.loader import DataLoader
import pickle
import os
import shutil
import yaml
from easydict import EasyDict
from torch.nn.parallel import DistributedDataParallel
import sys
sys.path.append('./')

from datasets.md17 import MD17Traj
from models.ETGNN import ETGNN
from diffusion.GeoTDM import GeoTDM, ModelMeanType, ModelVarType, LossType
from utils.misc import set_seed

device = 0

# Load args
eval_yaml_file = 'configs/md17_sampling.yaml'
with open(eval_yaml_file, 'r') as f:
    params = yaml.safe_load(f)
config = EasyDict(params)
cond = config.eval.cond
train_output_path = os.path.join(config.eval.output_base_path, config.eval.train_exp_name)
print(f'Train output path: {train_output_path}')

set_seed(config.eval.seed)

if config.eval.model == 'GeoTDM':
    try:
        train_yaml_file = os.path.join(train_output_path, 'md17_train.yaml')
    except:
        print('Notice: This exp was launched by submit_job.sh')
        train_yaml_file = os.path.join(train_output_path, f'{config.eval.train_exp_name}.yaml')  # Use this if batch
else:
    raise NotImplementedError()

# train_yaml_file = os.path.join(train_output_path, 'md17_train.yaml')


with open(train_yaml_file, 'r') as f:
    train_params = yaml.safe_load(f)
train_config = EasyDict(train_params)
if config.eval.eval_exp_name is None:
    config.eval.eval_exp_name = config.eval.train_exp_name + '_eval'
eval_output_path = os.path.join(config.eval.output_base_path, config.eval.eval_exp_name)
print(f'Eval output path: {eval_output_path}')
if not os.path.exists(eval_output_path):
    os.mkdir(eval_output_path)
shutil.copy(eval_yaml_file, eval_output_path)

# Overwrite model configs from training config
if config.eval.model == 'GeoTDM':
    config.denoise_model = train_config.denoise_model
    config.diffusion = train_config.diffusion
    # Overwrite diffusion timesteps for sampling
    if config.eval.sampling_timesteps is not None:
        config.diffusion.num_timesteps = config.eval.sampling_timesteps
# Overwrite cond_mask
if cond:
    if config.eval.cond_mask is None:
        config.eval.cond_mask = train_config.train.cond_mask

# Load dataset
dataset = MD17Traj(**config.data)
dataloader = DataLoader(dataset, batch_size=config.eval.batch_size, shuffle=False)

# Init model and optimizer
if config.eval.model == 'GeoTDM':
    # Init model and optimizer
    denoise_network = ETGNN(**config.denoise_model).to(device)
else:
    raise NotImplementedError()

# Load checkpoint
model_ckpt_path = os.path.join(train_output_path, f'ckpt_{config.eval.model_ckpt}.pt')
state_dict = torch.load(model_ckpt_path)
try:
    denoise_network.load_state_dict(state_dict)
except:
    state_dict = {k[7:]: v for k, v in state_dict.items()}
    # denoise_network = DistributedDataParallel(denoise_network, device_ids=[0])
    denoise_network.load_state_dict(state_dict)
print(f'Model loaded from {model_ckpt_path}')

if config.eval.model == 'GeoTDM':
    diffusion = GeoTDM(denoise_network=denoise_network,
                       model_mean_type=ModelMeanType.EPSILON,
                       model_var_type=ModelVarType.FIXED_LARGE,
                       loss_type=LossType.MSE,
                       device=device,
                       # rescale_timesteps=True,
                       rescale_timesteps=False,
                       **config.diffusion)

denoise_network.eval()

all_data = []

for step, data in enumerate(dataloader):
    data = data.to(device)
    model_kwargs = {'h': data.h,
                    'edge_index': data.edge_index,
                    'edge_attr': data.edge_attr,
                    'batch': data.batch
                    }

    x_start = data.x

    if cond:
        # Create temporal inpainting mask, 1 to keep the entries unchanged, 0 to modify it by diffusion
        cond_mask = torch.zeros(1, 1, x_start.size(-1)).to(x_start)
        for interval in config.eval.cond_mask:
            cond_mask[..., interval[0]: interval[1]] = 1
        model_kwargs['cond_mask'] = cond_mask
        shape_to_pred = x_start[..., ~cond_mask.view(-1).bool()].shape
    else:
        x_start_ = x_start[..., :train_config.train.tot_len]
        shape_to_pred = x_start_.shape
        data.x = x_start_

    model_kwargs['x_given'] = x_start

    if config.eval.model == 'GeoTDM':
        x_out = diffusion.p_sample_loop(shape=shape_to_pred, progress=True, model_kwargs=model_kwargs)
    else:
        raise NotImplementedError()

    if cond:
        x_out = torch.cat((x_start[..., cond_mask.view(-1).bool()], x_out), dim=-1)

    data['x_pred'] = x_out.detach()

    all_data.append(data.cpu())

without_pyg = False  # Use true for visualization and false for evaluation
samples_save_path = os.path.join(eval_output_path, 'samples_vis.pkl' if without_pyg else 'samples.pkl')
with open(samples_save_path, 'wb') as f:
    # data['x_pred'] = x_out
    # Parse data in case the drawer cannot process PyG data
    if without_pyg:
        data_save = []
        for i in range(len(all_data)):
            data_save.append((all_data[i].x.numpy(), all_data[i].x_pred.numpy(), all_data[i].batch.numpy()))
        # Currently no need to save node feature, since we read it from original data
        pickle.dump(data_save, f)
    else:
        pickle.dump(all_data, f)

print(f'Samples saved to {samples_save_path}')
# print(x_out.shape)
# exit(0)


