import torch
from torch_geometric.loader import DataLoader
import pickle
import os
import shutil
import yaml
from easydict import EasyDict
from torch.nn.parallel import DistributedDataParallel
import sys
sys.path.append('./')

from datasets.md17 import MD17Traj
from models.ETGNN import ETGNN
from diffusion.GeoTDM import GeoTDM, ModelMeanType, ModelVarType, LossType
from utils.misc import set_seed

# Device ordinal handed to every `.to(...)` call and to the diffusion wrapper.
device = 0

# Read the sampling/evaluation configuration and expose it with attribute access.
eval_yaml_file = 'configs/md17_sampling.yaml'
with open(eval_yaml_file, 'r') as cfg_stream:
    params = yaml.safe_load(cfg_stream)
config = EasyDict(params)

# Whether sampling is conditioned on given frames (drives the mask logic later on).
cond = config.eval.cond

# Directory of the training run whose config and checkpoint are reused.
train_output_path = os.path.join(
    config.eval.output_base_path,
    config.eval.train_exp_name,
)
print(f'Train output path: {train_output_path}')

# Seed the RNGs through the project helper before any sampling happens.
set_seed(config.eval.seed)

# Only GeoTDM runs are supported; fail early for anything else.
if config.eval.model != 'GeoTDM':
    raise NotImplementedError()

# Load the training configuration stored next to the checkpoints.
# The default file name is `md17_train.yaml`; when that file is absent the
# run is taken to come from submit_job.sh, which names the config after the
# experiment instead.
train_yaml_file = os.path.join(train_output_path, 'md17_train.yaml')
try:
    with open(train_yaml_file, 'r') as f:
        train_params = yaml.safe_load(f)
except FileNotFoundError:
    # Only a missing file triggers the fallback. Malformed YAML, permission
    # problems, or Ctrl-C propagate instead of being misreported.
    print('Notice: This exp was launched by submit_job.sh')
    train_yaml_file = os.path.join(train_output_path, f'{config.eval.train_exp_name}.yaml')  # Use this if batch
    with open(train_yaml_file, 'r') as f:
        train_params = yaml.safe_load(f)

# train_yaml_file = os.path.join(train_output_path, 'md17_train.yaml')



train_config = EasyDict(train_params)

# Default the evaluation experiment name to "<train_exp_name>_eval".
if config.eval.eval_exp_name is None:
    config.eval.eval_exp_name = config.eval.train_exp_name + '_eval'
eval_output_path = os.path.join(config.eval.output_base_path, config.eval.eval_exp_name)
print(f'Eval output path: {eval_output_path}')

# makedirs(exist_ok=True) avoids the check-then-create race of
# `if not exists: mkdir` and also works when the experiment name contains
# nested directories.
os.makedirs(eval_output_path, exist_ok=True)

# Keep a copy of the sampling config alongside the generated samples.
shutil.copy(eval_yaml_file, eval_output_path)

# Model and diffusion settings are taken from the training run, not from
# the sampling config.
if config.eval.model == 'GeoTDM':
    config.denoise_model = train_config.denoise_model
    config.diffusion = train_config.diffusion
    # An explicit number of sampling timesteps replaces the training value.
    sampling_steps = config.eval.sampling_timesteps
    if sampling_steps is not None:
        config.diffusion.num_timesteps = sampling_steps

# For conditional sampling without an explicit mask, reuse the training mask.
if cond and config.eval.cond_mask is None:
    config.eval.cond_mask = train_config.train.cond_mask

# Build the trajectory dataset and iterate it in a fixed (unshuffled) order.
dataset = MD17Traj(**config.data)
dataloader = DataLoader(
    dataset,
    batch_size=config.eval.batch_size,
    shuffle=False,
)

# Instantiate the denoising network on the target device.
if config.eval.model != 'GeoTDM':
    raise NotImplementedError()
denoise_network = ETGNN(**config.denoise_model).to(device)

# Load checkpoint
model_ckpt_path = os.path.join(train_output_path, f'ckpt_{config.eval.model_ckpt}.pt')
# Load onto the CPU first so a checkpoint written from another GPU index can
# still be read; load_state_dict copies the values into the parameters that
# already live on `device`.
state_dict = torch.load(model_ckpt_path, map_location='cpu')
try:
    denoise_network.load_state_dict(state_dict)
except RuntimeError:
    # load_state_dict raises RuntimeError on key mismatch. Checkpoints saved
    # from a DistributedDataParallel wrapper prefix every key with "module.",
    # so strip that prefix (only where it is present) and retry. Any other
    # mismatch is re-raised by the second call with the real error message.
    ddp_prefix = 'module.'
    state_dict = {
        (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
        for k, v in state_dict.items()
    }
    denoise_network.load_state_dict(state_dict)
print(f'Model loaded from {model_ckpt_path}')

# Wrap the denoiser in the diffusion process. The mean/variance
# parameterisation and the loss type are fixed here; everything else comes
# from the diffusion section of the config.
if config.eval.model == 'GeoTDM':
    diffusion = GeoTDM(
        denoise_network=denoise_network,
        model_mean_type=ModelMeanType.EPSILON,
        model_var_type=ModelVarType.FIXED_LARGE,
        loss_type=LossType.MSE,
        device=device,
        rescale_timesteps=False,
        **config.diffusion
    )

# Put the network in evaluation mode for sampling.
denoise_network.eval()

# Sampled batches (moved back to CPU) collected for saving afterwards.
all_data = []

# Roll out trajectories round by round: each round samples a window with the
# diffusion model, then builds the next round's input from the tail of that
# sample. NOTE: the `break` at the end of the body means only the first batch
# of the dataloader is ever processed.
for step, data in enumerate(dataloader):
    data = data.to(device)
    # Graph inputs forwarded unchanged to the denoiser on every sampling call.
    model_kwargs = {'h': data.h,
                    'edge_index': data.edge_index,
                    'edge_attr': data.edge_attr,
                    'batch': data.batch
                    }

    # Trajectory tensor; its last axis is the one the temporal mask below is
    # built over and the one all slicing/concatenation happens along.
    x_start = data.x

    # Number of rollout rounds (hard-coded).
    n_repeat = 6

    # The collected output starts with the first 10 entries of the last axis
    # of the input; each round's sample is appended after it.
    # NOTE(review): the literal 10 here and in the `x_out[..., 10:]` slice
    # below presumably equals the conditioning length in
    # config.eval.cond_mask — confirm.
    x_out_collected = [x_start[..., :10]]

    for rep in range(n_repeat):

        if cond:
            # Create temporal inpainting mask, 1 to keep the entries unchanged, 0 to modify it by diffusion
            # The mask has shape (1, 1, T) with T = x_start.size(-1) and takes
            # dtype/device from x_start.
            cond_mask = torch.zeros(1, 1, x_start.size(-1)).to(x_start)
            # Each configured interval [start, end) along the last axis is marked as given.
            for interval in config.eval.cond_mask:
                cond_mask[..., interval[0]: interval[1]] = 1
            model_kwargs['cond_mask'] = cond_mask
            # Only the non-masked entries are generated, so the sample shape
            # is that of x_start restricted to mask == 0.
            shape_to_pred = x_start[..., ~cond_mask.view(-1).bool()].shape
        else:
            # Unconditional case: truncate the last axis to the training
            # length, and store the truncated tensor back on the batch.
            x_start_ = x_start[..., :train_config.train.tot_len]
            shape_to_pred = x_start_.shape
            data.x = x_start_

        # The (untruncated) current x_start is passed as the given trajectory in both branches.
        model_kwargs['x_given'] = x_start

        if config.eval.model == 'GeoTDM':
            x_out = diffusion.p_sample_loop(shape=shape_to_pred, progress=True, model_kwargs=model_kwargs)

        else:
            raise NotImplementedError()

        x_out_collected.append(x_out)

        # Next round's input: everything after the first 10 entries of this
        # sample, followed by Gaussian noise shaped like the sample.
        # NOTE(review): the noise presumably only fills the slots that the
        # next round regenerates — confirm.
        x_start = x_out[..., 10:]
        x_start = torch.cat((x_start, torch.randn_like(x_out)), dim=-1)

    # NOTE(review): x_out is re-bound here but never read afterwards
    # (x_pred below is built from x_out_collected), so this looks like a
    # leftover. It indexes the already-advanced x_start with the mask from
    # the last round, which only works if both have the same last-axis length.
    if cond:
        x_out = torch.cat((x_start[..., cond_mask.view(-1).bool()], x_out), dim=-1)

    # data['x_pred'] = x_out.detach()

    # Full rollout: initial 10 entries plus every round's sample, joined along the last axis.
    data['x_pred'] = torch.cat(x_out_collected, dim=-1).detach()

    all_data.append(data.cpu())

    # Stop after the first batch.
    break

without_pyg = True  # Use true for visualization and false for evaluation

# The visualization dump and the evaluation dump go to different file names.
if without_pyg:
    samples_filename = 'samples_vis.pkl'
else:
    samples_filename = 'samples.pkl'
samples_save_path = os.path.join(eval_output_path, samples_filename)

with open(samples_save_path, 'wb') as out_file:
    if without_pyg:
        # Reduce each batch to plain numpy arrays (x, x_pred, batch) so the
        # drawer does not have to process PyG data. Node features are not
        # saved; they are read from the original data.
        data_save = [
            (batch_data.x.numpy(), batch_data.x_pred.numpy(), batch_data.batch.numpy())
            for batch_data in all_data
        ]
        pickle.dump(data_save, out_file)
    else:
        # Evaluation consumes the PyG batch objects directly.
        pickle.dump(all_data, out_file)

print(f'Samples saved to {samples_save_path}')
# print(x_out.shape)
# exit(0)


