import torch
from torch_geometric.loader import DataLoader
import pickle
import os
import shutil
import yaml
from easydict import EasyDict
from torch.nn.parallel import DistributedDataParallel
import sys
sys.path.append('./')

from datasets.nbody import NBody
from models.ETGNN import ETGNN
from diffusion.GeoTDM import GeoTDM, ModelMeanType, ModelVarType, LossType
from utils.misc import set_seed

from torch_geometric.nn import global_mean_pool

# Index of the device that the network, data and diffusion process are moved to.
device = 0

# Read the optimization/evaluation configuration.
eval_yaml_file = 'configs/nbody_optimize.yaml'
with open(eval_yaml_file, 'r') as yaml_fp:
    raw_eval_params = yaml.safe_load(yaml_fp)
config = EasyDict(raw_eval_params)

# Whether a temporal conditioning mask is applied during optimization.
cond = config.eval.cond
train_output_path = os.path.join(
    config.eval.output_base_path,
    config.eval.train_exp_name,
)
print(f'Train output path: {train_output_path}')

set_seed(config.eval.seed)

# Locate and read the training config saved next to the checkpoint; the
# denoiser and diffusion settings used for optimization are taken from it.
if config.eval.model != 'GeoTDM':
    raise NotImplementedError(f'Unsupported model: {config.eval.model}')
# Conditional and unconditional runs currently share the same training config
# file name.
train_yaml_file = os.path.join(train_output_path, 'nbody_train.yaml')

with open(train_yaml_file, 'r') as f:
    train_params = yaml.safe_load(f)
train_config = EasyDict(train_params)

# Default the evaluation experiment name to '<train_exp_name>_eval'.
if config.eval.eval_exp_name is None:
    config.eval.eval_exp_name = config.eval.train_exp_name + '_eval'
eval_output_path = os.path.join(config.eval.output_base_path, config.eval.eval_exp_name)
print(f'Eval output path: {eval_output_path}')
# makedirs also creates missing parent directories, and exist_ok avoids a
# separate (racy) existence check.
os.makedirs(eval_output_path, exist_ok=True)
# Keep a copy of the evaluation config alongside the outputs.
shutil.copy(eval_yaml_file, eval_output_path)

# The denoiser/diffusion hyper-parameters have to match the trained
# checkpoint, so they are copied from the training config.
if config.eval.model == 'GeoTDM':
    config.denoise_model = train_config.denoise_model
    config.diffusion = train_config.diffusion
    # Optionally use a different number of diffusion timesteps for sampling.
    if config.eval.sampling_timesteps is not None:
        config.diffusion.num_timesteps = config.eval.sampling_timesteps

# Without an explicit evaluation mask, reuse the training-time cond_mask.
if cond and config.eval.cond_mask is None:
    config.eval.cond_mask = train_config.train.cond_mask

# Build the dataset in a fixed order (shuffle=False): each batch index is
# later paired with the base sample stored at the same index.
dataset = NBody(**config.data)
dataloader = DataLoader(
    dataset,
    batch_size=config.eval.batch_size,
    shuffle=False,
)

# Only the GeoTDM setup (an ETGNN denoiser) is supported.
if config.eval.model != 'GeoTDM':
    raise NotImplementedError()
denoise_network = ETGNN(**config.denoise_model).to(device)

# Load checkpoint
model_ckpt_path = os.path.join(train_output_path, f'ckpt_{config.eval.model_ckpt}.pt')
# Deserialize onto the CPU: load_state_dict copies the tensors into the
# network's parameters, which already live on `device`. This also works for
# checkpoints saved from a GPU that is not present here.
state_dict = torch.load(model_ckpt_path, map_location='cpu')
try:
    denoise_network.load_state_dict(state_dict)
except RuntimeError:
    # Key mismatch: checkpoints saved from a DistributedDataParallel wrapper
    # prefix every key with 'module.'. Strip the prefix only where it is
    # actually present, so an unrelated error resurfaces unchanged on retry.
    ddp_prefix = 'module.'
    state_dict = {
        (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
        for k, v in state_dict.items()
    }
    denoise_network.load_state_dict(state_dict)
print(f'Model loaded from {model_ckpt_path}')

# Wrap the denoiser in the GeoTDM diffusion process. Mean/variance/loss types
# are fixed here; the remaining keyword arguments come from config.diffusion.
if config.eval.model == 'GeoTDM':
    diffusion = GeoTDM(
        denoise_network=denoise_network,
        model_mean_type=ModelMeanType.EPSILON,
        model_var_type=ModelVarType.FIXED_LARGE,
        loss_type=LossType.MSE,
        device=device,
        rescale_timesteps=False,
        **config.diffusion,
    )

# Switch the network to inference mode.
denoise_network.eval()

# Load the base trajectories that will be refined by optimization.
samples_path = os.path.join(config.eval.base_samples_path, 'samples.pkl')
# NOTE: pickle.load can execute arbitrary code; only point base_samples_path
# at files produced by a trusted sampling run.
with open(samples_path, 'rb') as f:
    base_data = pickle.load(f)
# Dataloader batch `step` is paired with base_data[step], so sampling and
# optimization must use the same batch size and dataset order. Fail early
# rather than with an IndexError after many expensive optimization batches.
# (More base batches than needed is tolerated; only the leading ones are used.)
if len(base_data) < len(dataloader):
    raise ValueError(
        f'{samples_path} contains {len(base_data)} batches but the dataloader '
        f'yields {len(dataloader)}; sampling and optimization must use the '
        f'same batch size.'
    )

# Per-batch Data objects holding the optimized trajectory under 'x_pred'.
all_data = []

# Per-graph metrics accumulated over batches:
#   all_ADE/all_FDE                 optimized vs. ground truth
#   all_ADE_base/all_FDE_base       optimized vs. base trajectory
#   all_ADE_original/all_FDE_original  base trajectory vs. ground truth
all_ADE, all_FDE = [], []
all_ADE_base, all_FDE_base = [], []
all_ADE_original, all_FDE_original = [], []

def _ade_fde(pred, ref, batch):
    """Return per-graph (ADE, FDE) between two node trajectories.

    ``pred`` and ``ref`` are [BN, 3, T] tensors. The per-frame Euclidean
    distance ([BN, T]) is averaged over the nodes of each graph with
    ``global_mean_pool`` ([B, T]). ADE is the mean over frames and FDE is the
    distance at the last frame, both of shape [B].
    """
    dist = (pred - ref).square().sum(dim=1).sqrt()  # [BN, T]
    dist = global_mean_pool(dist, batch)  # [B, T]
    return dist.mean(dim=-1), dist[..., -1]


for step, data in enumerate(dataloader):
    data = data.to(device)
    model_kwargs = {'h': data.h,
                    'edge_index': data.edge_index,
                    'edge_attr': data.edge_attr,
                    'batch': data.batch,
                    }

    # Optimization starts from the base prediction for this batch; data.x
    # keeps the ground truth used for evaluation.
    x_start = base_data[step].x_pred.to(data.x)

    if cond:
        # Temporal inpainting mask: 1 keeps the entry unchanged, 0 lets
        # diffusion modify it.
        cond_mask = torch.zeros(1, 1, x_start.size(-1)).to(x_start)
        for interval in config.eval.cond_mask:
            cond_mask[..., interval[0]: interval[1]] = 1
        model_kwargs['cond_mask'] = cond_mask
        pred_frames = ~cond_mask.view(-1).bool()  # frames to be predicted
        x_start_ = x_start[..., pred_frames]
        x_target = data.x[..., pred_frames]
    else:
        # Unconditional: the first tot_len frames are all predicted. Truncate
        # the ground truth to the same length rather than overwriting it with
        # the base prediction (which made "ground truth" equal to the base),
        # and avoid referencing the undefined cond_mask.
        tot_len = train_config.train.tot_len
        x_start_ = x_start[..., :tot_len]
        data.x = data.x[..., :tot_len]
        x_target = data.x

    model_kwargs['x_given'] = x_start

    if config.eval.model == 'GeoTDM':
        x_out = diffusion.optimize(x_start_, optimize_step=config.eval.optimize_step, model_kwargs=model_kwargs)
    else:
        raise NotImplementedError()
    # Detach before computing metrics so the accumulated per-batch tensors do
    # not keep an autograd graph alive (no-op if x_out carries no graph).
    x_out = x_out.detach()  # [BN, 3, T_f]

    # Errors: optimized vs. ground truth, optimized vs. base, base vs. ground truth.
    x_base = x_start_
    ADE, FDE = _ade_fde(x_out, x_target, data.batch)
    ADE_base, FDE_base = _ade_fde(x_out, x_base, data.batch)
    ADE_original, FDE_original = _ade_fde(x_base, x_target, data.batch)
    all_ADE.append(ADE)
    all_FDE.append(FDE)
    all_ADE_base.append(ADE_base)
    all_FDE_base.append(FDE_base)
    all_ADE_original.append(ADE_original)
    all_FDE_original.append(FDE_original)

    if cond:
        # Place the conditioning frames before the optimized frames.
        x_out = torch.cat((x_start[..., ~pred_frames], x_out), dim=-1)

    data['x_pred'] = x_out

    all_data.append(data.cpu())

# Report run settings and the metrics averaged over all graphs:
# opt-gt (optimized vs. ground truth), opt-base (optimized vs. base) and
# base-gt (base vs. ground truth).
# NOTE(review): all_data is filled in the loop but not saved or used in the
# visible code — confirm whether it should be written to eval_output_path.
print(f'opt step: {config.eval.optimize_step}')
print(f'seed: {config.eval.seed}')
metric_series = (
    ('opt-gt ADE', all_ADE),
    ('opt-gt FDE', all_FDE),
    ('opt-base ADE', all_ADE_base),
    ('opt-base FDE', all_FDE_base),
    ('base-gt ADE', all_ADE_original),
    ('base-gt FDE', all_FDE_original),
)
for metric_label, per_batch_values in metric_series:
    print(metric_label, torch.cat(per_batch_values).mean().item())



