import torch
from torch_geometric.loader import DataLoader
import pickle
import os
import shutil
import yaml
from tqdm import tqdm
from easydict import EasyDict
from torch.nn.parallel import DistributedDataParallel
import sys
sys.path.append('./')
from torch_geometric.nn import global_mean_pool
from datasets.cmu import CMU
from models.EGCTN import EGCTN
from models.Encoder_cond import Cencoder
from diffusion.GeoTDM import GeoTDM, ModelMeanType, ModelVarType, LossType

# Device ordinal handed to every `.to(...)` call and to the diffusion wrapper below.
device = 0

# Load the sampling/evaluation config.
eval_yaml_file = "cond_configs_CMU/cmu_sampling.yaml"
with open(eval_yaml_file, 'r') as f:
    params = yaml.safe_load(f)
config = EasyDict(params)
# Outputs are organised as <output_base_path>/<act>/<experiment name>.
train_output_path_act = os.path.join(config.eval.output_base_path, f"{config.data.act}")
train_output_path = os.path.join(train_output_path_act, config.eval.train_exp_name)

print(f'Train output path: {train_output_path}')
# The training config is read back from the training run's output directory.
train_yaml_file = os.path.join(train_output_path, 'cmu_train_cond_ours.yaml')
with open(train_yaml_file, 'r') as f:
    train_params = yaml.safe_load(f)
train_config = EasyDict(train_params)
# Default the evaluation experiment name to "<train_exp_name>_eval".
if config.eval.eval_exp_name is None:
    config.eval.eval_exp_name = config.eval.train_exp_name + '_eval'
eval_output_path = os.path.join(train_output_path_act, config.eval.eval_exp_name)
print(f'Eval output path: {eval_output_path}')
# exist_ok=True replaces the racy "check, then mkdir" pattern and tolerates reruns.
os.makedirs(eval_output_path, exist_ok=True)
# Keep a copy of the sampling config next to the results it produced.
shutil.copy(eval_yaml_file, eval_output_path)

# Overwrite model configs from training config
config.denoise_model = train_config.denoise_model
config.diffusion = train_config.diffusion
config.con_encoder = train_config.con_encoder
# Overwrite diffusion timesteps for sampling
if config.eval.sampling_timesteps is not None:
    config.diffusion.num_timesteps = config.eval.sampling_timesteps
# Fall back to the conditioning mask used during training when none is given
if config.eval.cond_mask is None:
    config.eval.cond_mask = train_config.train.cond_mask

# Build the evaluation dataset; batches are served in dataset order (no shuffling).
dataset = CMU(**config.data)
dataloader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=config.eval.batch_size,
)

# Instantiate the denoiser and the condition encoder, then move each to `device`.
denoise_network = EGCTN(**config.denoise_model)
denoise_network = denoise_network.to(device)
con_encoder = Cencoder(**config.con_encoder)
con_encoder = con_encoder.to(device)


# One denoiser parameter is used as a probe: its freshly-initialised value is
# snapshotted here and compared after loading to confirm the weights changed.
probe_param_name = "s_modules.1.edge_mlp.actions.0.weight"
init_param_1 = dict(denoise_network.named_parameters()).get(probe_param_name)
if init_param_1 is None:
    # Without the probe the comparison below cannot run; load_state_dict still
    # validates the key set on its own.
    print(f'Warning: probe parameter {probe_param_name} not found; '
          'skipping the post-load sanity check')
else:
    init_param_1 = init_param_1.detach().clone()
    print(probe_param_name)

# Load checkpoint
model_ckpt_path = os.path.join(train_output_path, f'ckpt_{config.eval.model_ckpt}.pt')
# map_location='cpu' lets the checkpoint load even if it was saved from a GPU
# index that does not exist on this machine; load_state_dict copies the values
# into the parameters already living on `device`.
checkpoint = torch.load(model_ckpt_path, map_location='cpu')
model_state = checkpoint["model"]

# Keys carry a 'module.' prefix when the model was saved from a wrapped
# (e.g. DistributedDataParallel) module. Strip it only where present instead of
# blindly dropping the first seven characters of every key.
ddp_prefix = 'module.'
if any(k.startswith(ddp_prefix) for k in model_state):
    print('Notice: This exp was launched by submit_job.sh')
state_dict = {
    (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
    for k, v in model_state.items()
}
denoise_network.load_state_dict(state_dict)
# NOTE(review): only the denoiser weights are restored here; nothing visible in
# this script loads weights into con_encoder. Confirm whether the checkpoint
# stores them and whether they should be loaded as well.

if init_param_1 is not None:
    test_param_1 = dict(denoise_network.named_parameters())[probe_param_name].detach().clone()
    print(probe_param_name)
    # Identical values before and after loading mean the checkpoint had no effect.
    # Raised explicitly so the check also runs under `python -O`.
    if torch.equal(init_param_1, test_param_1):
        raise RuntimeError("test model load failed")
print(f'Model loaded from {model_ckpt_path}')

# Wrap the denoiser and the condition encoder in the diffusion process used for
# sampling; the remaining hyper-parameters come from the training config.
diffusion = GeoTDM(denoise_network=denoise_network, con_encoder=con_encoder,
                   model_mean_type=ModelMeanType.EPSILON,
                   model_var_type=ModelVarType.FIXED_LARGE,
                   loss_type=LossType.MSE,
                   device=device,
                   rescale_timesteps=True,
                   **config.diffusion)

# Only sampling happens below, so switch both networks to evaluation mode.
# The original left con_encoder in training mode although it is also used
# during sampling.
denoise_network.eval()
con_encoder.eval()

# Per-batch accumulators, concatenated after the loop.
x_out_ = []  # sampled trajectories, one [K, BN, 3, T_p] tensor per batch
x_in_ = []  # conditioning frames per batch
x_gt_ = []  # ground-truth frames that the model has to generate
batch_ = []  # per-node system id (data.system_id indexed by data.batch)
Error_K_all = []  # per batch: [B, T_p, K]
system_id_all = []  # the index in the test dataset
for data in tqdm(dataloader, total=len(dataloader)):
    data = data.to(device)
    model_kwargs = {'h': data.h,
                    'edge_index': data.edge_index,
                    'edge_attr': data.edge_attr,
                    'batch': data.batch
                    }

    x_given = data.x

    # Temporal inpainting mask over the last axis of x_given: True marks frames
    # that are given as conditioning, False marks frames produced by diffusion.
    # Built once per batch as a 1-D bool tensor and reused for every selection.
    keep_mask = torch.zeros(x_given.size(-1), dtype=torch.bool, device=x_given.device)
    for interval in config.eval.cond_mask:
        keep_mask[interval[0]: interval[1]] = True
    x_target = x_given[..., ~keep_mask]
    x_cond = x_given[..., keep_mask]

    # Mask over (em_T_out + number of target frames) entries whose first
    # em_T_out positions are flagged True; it is passed through to the sampler.
    em_cond_mask = torch.zeros(x_target.size(-1) + config.con_encoder.em_T_out,
                               dtype=torch.bool, device=x_target.device)
    em_cond_mask[:config.con_encoder.em_T_out] = True
    shape_to_pred = x_target.shape  # [BN, 3, T_p]

    all_x_out = []
    Error_K = []
    # Draw K independent samples for the same conditioning.
    for _ in tqdm(range(config.eval.K)):
        x_out = diffusion.p_sample_loop(shape=shape_to_pred, x_cond=x_cond, progress=True,
                                        h=data.h, em_cond_mask=em_cond_mask,
                                        model_kwargs=model_kwargs)  # [BN, 3, T_p]
        # Nothing below needs gradients; detaching makes sure the stored samples
        # never keep an autograd graph alive across batches.
        x_out = x_out.detach()
        all_x_out.append(x_out)
        # Euclidean distance per node and frame, then averaged over the nodes of each graph.
        distance = (x_out - x_target).square().sum(dim=1).sqrt()  # [BN, T_p]
        distance = global_mean_pool(distance, data.batch)  # [B, T_p]
        Error_K.append(distance)

    all_x_out = torch.stack(all_x_out, dim=0)  # [K, BN, 3, T_p]

    x_out_.append(all_x_out)
    # x_cond / x_target are exactly the conditioning and ground-truth slices of
    # x_given, so they are stored directly instead of being re-sliced.
    x_in_.append(x_cond)
    x_gt_.append(x_target)
    batch_.append(data.system_id[data.batch])
    Error_K = torch.stack(Error_K, dim=-1)  # [B, T_p, K]
    system_id_all.append(data.system_id)  # [B]
    Error_K_all.append(Error_K)

# Merge the per-batch lists into single tensors. Samples carry a leading K axis,
# so they are joined along dim 1; everything else is joined along dim 0.
x_out_ = torch.cat(x_out_, 1)
x_in_ = torch.cat(x_in_, 0)
x_gt_ = torch.cat(x_gt_, 0)
batch_ = torch.cat(batch_, 0)

# Analyze: reduce the K samples to best-of-K and mean-of-K errors.
Error_K_all = torch.cat(Error_K_all, 0)  # [B_tot, T_p, K]
Error_min_all = torch.min(Error_K_all, dim=2)[0]  # [B_tot, T_p]
Error_ave_all = torch.mean(Error_K_all, dim=2)  # [B_tot, T_p]
system_id_all = torch.cat(system_id_all, 0)  # [B_tot]
results = {
    'system_id_range': [system_id_all.min().item(), system_id_all.max().item()],
}
# One-based frame numbers at which the errors are reported further below.
eval_index = [2, 4, 8, 10, 14, 25]

samples_save_path = os.path.join(eval_output_path, 'samples.pkl')
with open(samples_save_path, 'wb') as f:
    # Store plain numpy arrays in case the drawer cannot process PyG data.
    # Node features are not saved; they can be read from the original data.
    data_save = {
        key: tensor.detach().cpu().numpy()
        for key, tensor in (
            ('x_out', x_out_),
            ('x_in', x_in_),
            ('x_gt', x_gt_),
            ('system_id', batch_),
        )
    }
    pickle.dump(data_save, f)
print(f'Samples saved to {samples_save_path}')
for saved_tensor in (x_out_, x_in_, batch_):
    print(saved_tensor.shape)


# training_losses = diffusion.training_losses(x_start=x_start, t=None, model_kwargs=model_kwargs)
# exit(0)
# loss = training_losses['loss'].mean()

# logs = {"loss": loss.detach().item(), "step": step}


# Horizon label -> zero-based frame index. Frame numbers in eval_index are
# one-based and each frame is labelled as 40 ms.
eval_steps = {f'{frame * 40}ms': frame - 1 for frame in eval_index}

# Everything below is independent of the horizon, so it is computed once.
act_index_map = dataset.act_index_map
# Reorder the per-system errors so that row i corresponds to dataset index i.
sort_idx = torch.argsort(system_id_all)
expected_ids = torch.arange(len(dataset)).to(system_id_all)
# Raised explicitly (instead of `assert`) so the check also runs under `python -O`.
if not torch.equal(system_id_all[sort_idx], expected_ids):
    raise RuntimeError('system ids collected during sampling do not cover '
                       'the dataset indices 0..len(dataset)-1 exactly once')
# Index tensors selecting the systems that belong to each scenario.
scenario_indices = {
    scenario: torch.from_numpy(act_index_map[scenario]).long().to(system_id_all.device)
    for scenario in act_index_map
}
num_pred_frames = Error_min_all.size(-1)

for key, frame_idx in eval_steps.items():
    if frame_idx >= num_pred_frames:
        # The requested horizon lies beyond the predicted window; report it
        # rather than failing with an IndexError after sampling has finished.
        print(f'Skipping {key}: only {num_pred_frames} predicted frames available')
        continue
    cur_Error_min = Error_min_all[..., frame_idx][sort_idx]  # [B_tot], dataset order
    cur_Error_ave = Error_ave_all[..., frame_idx][sort_idx]  # [B_tot], dataset order
    cur_min_, cur_ave_ = [], []
    for scenario, cur_index in scenario_indices.items():
        # Errors are divided by config.data.scale before being reported.
        cur_ave = cur_Error_ave[cur_index].mean().item() / config.data.scale
        cur_min = cur_Error_min[cur_index].mean().item() / config.data.scale
        results[f'{scenario}_min_{config.eval.K}_' + key] = cur_min
        results[f'{scenario}_ave_{config.eval.K}_' + key] = cur_ave
        cur_min_.append(cur_min)
        cur_ave_.append(cur_ave)
    # Unweighted mean over scenarios; skipped when there are no scenarios.
    if cur_min_:
        results[f'AVE_min_{config.eval.K}_' + key] = sum(cur_min_) / len(cur_min_)
        results[f'AVE_ave_{config.eval.K}_' + key] = sum(cur_ave_) / len(cur_ave_)

# Print the complete table once instead of dumping the growing dict per horizon.
print(results)