import argparse
import copy
import os

import time
import cv2
import torch.nn as nn
import numpy as np
import scipy.misc
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch_geometric.nn.pool import radius_graph
# from models.SGNN_batch import SGNN
import GNN
from data_new import PhysicsFleXDataset, new_collate , con_network
from config import gen_args
from data_new import normalize_scene_param
from data_new import load_data, get_scene_info
from data_new import get_env_group, prepare_input, denormalize

from utils import add_log, convert_groups_to_colors
from utils import create_instance_colors, set_seed, Tee, count_parameters
from data_new import calc_rigid_transform

# Parse the experiment configuration and make runs reproducible.
args = gen_args()
set_seed(args.random_seed)
# Rollout horizon used by the evaluation loop, overriding the config value.
# NOTE(review): ground truth is only loaded for args.time_step frames, so this
# horizon is only fully usable when args.time_step >= 130 — confirm per dataset.
args.sequence_length = 130

# Create the evaluation output directory and its 'render' sub-directory
# (makedirs creates args.evalf as a parent). Avoids shelling out with a
# string built from a config value.
os.makedirs(os.path.join(args.evalf, 'render'), exist_ok=True)

# Mirror stdout into <evalf>/eval.log.
tee = Tee(os.path.join(args.evalf, 'eval.log'), 'w')


### evaluating

data_names = args.data_names

use_gpu = torch.cuda.is_available()
# One device for every network and statistics tensor, so the script does not
# fail inside .cuda() on a machine without a GPU.
device = torch.device('cuda' if use_gpu else 'cpu')

# Create the non-contact force predictor, the contact force predictor and the
# GNS dynamics model (all with the same hyper-parameters).
model_nocon_force = GNN.ForcePredictionLayer(n_layer=1, p_step=4, s_dim=2, hidden_dim=200, activation=nn.SiLU()).to(device)
model_con_force = GNN.ConForcePredictionLayer(n_layer=1, p_step=4, s_dim=2, hidden_dim=200, activation=nn.SiLU()).to(device)
model = GNN.GNS(n_layer=1, p_step=4, s_dim=2, hidden_dim=200, activation=nn.SiLU()).to(device)

print("model_kp #params: %d" % count_parameters(model))

# Pick the checkpoint: the best model, or a specific epoch/iteration snapshot.
if args.eval_epoch < 0:
    model_name = 'net_best.pth'
else:
    model_name = 'net_epoch_%d_iter_%d.pth' % (args.eval_epoch, args.eval_iter)

model_path = os.path.join(args.outf, model_name)
print("Loading network from %s" % model_path)

if args.stage != 'dy':
    # Previously the exception was only constructed, never raised, so the
    # script carried on evaluating an untrained model.
    raise ValueError("Unsupported stage %s, using other evaluation scripts" % args.stage)

# map_location lets a GPU-saved checkpoint load on the selected device.
pretrained_dict = torch.load(model_path, map_location=device)
model.load_state_dict(pretrained_dict, strict=True)
# NOTE(review): only `model` is restored from the checkpoint; the two force
# predictors keep their freshly initialised weights. Confirm where their trained
# parameters are stored and load them here.

# Inference mode for every network used in the rollout.
for net in (model_nocon_force, model_con_force, model):
    net.eval()

# Normalisation statistics for positions (p) and displacements/velocities (d).
mean_p = torch.FloatTensor(args.mean_p).to(device)
std_p = torch.FloatTensor(args.std_p).to(device)
mean_d = torch.FloatTensor(args.mean_d).to(device)
std_d = torch.FloatTensor(args.std_d).to(device)

# Number of validation episodes to roll out.
# NOTE(review): the original line was `np.arange(n)` with `n` never defined
# (NameError); 10 matches the commented-out default that preceded it.
N_EVAL_EPISODES = 10
infos = np.arange(N_EVAL_EPISODES)

# Particle layout assumed by this rollout. obj_id 0 covers particles [0, 171),
# and obj_ids 1-3 cover three 5-particle groups. Only the first 186 particles
# are simulated; any remaining ones are copied from ground truth.
OBJECT_SEGMENTS = [(0, 171), (171, 176), (176, 181), (181, 186)]
N_SIM_PARTICLES = OBJECT_SEGMENTS[-1][1]

# Every per-step tensor is moved to the device of the normalisation statistics.
device = mean_d.device


def _make_obj_id(segments, device):
    """Return a LongTensor giving the object index of each simulated particle."""
    obj_id = torch.zeros(segments[-1][1], dtype=torch.long, device=device)
    for obj_i, (start, end) in enumerate(segments):
        obj_id[start:end] = obj_i
    return obj_id


def _project_rigid(cur_x, next_pos, segments):
    """Snap each object's predicted positions onto a rigid motion of ``cur_x``.

    For every ``(start, end)`` segment, ``calc_rigid_transform`` fits ``R, T``
    from ``cur_x[start:end]`` to ``next_pos[start:end]``. That slice of
    ``next_pos`` is then replaced, in place, by ``R @ cur_x.T + T``.

    If the fit raises ``LinAlgError``, a message is printed and the remaining
    segments are left unprojected.
    """
    cur_np = cur_x.detach().cpu().numpy()
    next_np = next_pos.detach().cpu().numpy()  # segments are disjoint, so one snapshot suffices
    try:
        for start, end in segments:
            R, T = calc_rigid_transform(cur_np[start:end], next_np[start:end])
            rigid = (np.dot(R, cur_np[start:end].T) + T).T
            next_pos[start:end] = torch.from_numpy(rigid).to(device=next_pos.device, dtype=next_pos.dtype)
    except np.linalg.LinAlgError:
        # Best effort: keep the unconstrained prediction for the rest.
        print('svd does not converge')


loss_keeper_all = []
loss_raw_keeper_all = []

for idx_episode, episode in enumerate(infos):

    print("Rollout %d / %d" % (idx_episode, len(infos)))

    B = 1
    n_particle, n_shape = 0, 0

    # Load the ground truth for every frame of this episode.
    datas = []
    p_gt = []
    s_gt = []
    v_gt = []
    for step in range(args.time_step):
        data_path = os.path.join(args.dataf, 'valid', str(episode), str(step) + '.h5')

        data = load_data(data_names, data_path)

        if n_particle == 0 and n_shape == 0:
            n_particle, n_shape, scene_params = get_scene_info(data)
            scene_params = torch.FloatTensor(scene_params).unsqueeze(0)

        if args.verbose_data:
            print("n_particle", n_particle)
            print("n_shape", n_shape)

        datas.append(data)

        p_gt.append(data[0])
        s_gt.append(data[1])
        v_gt.append(data[3])

    # p_gt: time_step x N x state_dim
    # s_gt: time_step x n_s x 4
    p_gt = torch.FloatTensor(np.stack(p_gt))
    s_gt = torch.FloatTensor(np.stack(s_gt))
    v_gt = torch.FloatTensor(np.stack(v_gt))
    p_pred = torch.zeros(args.time_step, n_particle + n_shape, args.state_dim)

    # Initialise the particle grouping.
    group_gt = get_env_group(args, n_particle, scene_params, use_gpu=use_gpu)

    print('scene_params:', group_gt[-1][0, 0].item())

    # From here on, only the first N_SIM_PARTICLES particles are simulated.
    n_particle = N_SIM_PARTICLES
    obj_id = _make_obj_id(OBJECT_SEGMENTS, device)

    # Model rollout.
    loss = 0.
    loss_raw = 0.
    loss_counter = 0.
    loss_raw_keeper = []
    loss_keeper = []
    st_idx = args.n_his
    # Ground truth only exists for args.time_step frames; clamping keeps
    # p_gt[step_id] in range.
    ed_idx = min(args.sequence_length, args.time_step)

    with torch.no_grad():

        # History window (unnormalised): n_his x (n_p + n_s) x state_dim,
        # seeded from ground truth.
        state_cur_pos = p_gt[st_idx - args.n_his:st_idx].to(device)
        state_cur_vel = v_gt[st_idx - args.n_his:st_idx].to(device)

        for step_id in range(st_idx, ed_idx):

            if step_id % 50 == 0:
                print("Step %d / %d" % (step_id, ed_idx))

            # Latest frame of the history, restricted to simulated particles.
            # Index -1 matches the frame the network inputs below are built
            # from; the old index 1 only agreed with it when n_his == 2.
            cur_x = state_cur_pos[-1, :n_particle]

            # Node attributes. Column 1 flags particles whose axis-1
            # coordinate (presumably height) is below args.neighbor_radius;
            # column 0 stays zero. Built on `device` so the mask and the
            # tensor it indexes share a device.
            attr = torch.zeros(n_particle, 2, device=device)
            attr[cur_x[:, 1] < args.neighbor_radius, 1] = 1
            h = attr

            # Radius graph, split into intra-object and inter-object edges.
            edge_index = radius_graph(cur_x, r=1, loop=False)  # [2, M]
            same_obj = obj_id[edge_index[0]] == obj_id[edge_index[1]]
            edge_index_inner = edge_index[:, same_obj]  # [2, M_in]
            edge_index_inter = edge_index[:, ~same_obj]  # [2, M_out]

            # Normalised network inputs for the simulated particles.
            x_norm = (state_cur_pos[-1:] - mean_p) / std_p
            v_norm = (state_cur_vel[-1:] - mean_d) / std_d
            x_in = x_norm.reshape(-1, 3)[:n_particle]
            v_in = v_norm.reshape(-1, 3)[:n_particle]
            h_in = h.reshape(-1, 2)[:n_particle]

            # Non-contact forces use the radius graph; contact forces use the
            # graph returned by con_network. Both are in normalised units.
            nocon_force_norm = model_nocon_force(x_in, v_in, h_in, edge_index_inner, edge_index_inter, obj_id).reshape(B, -1, 3)
            edge_con_inner, edge_con_inter = con_network(x_in)
            con_force_norm = model_con_force(x_in, v_in, h_in, edge_con_inner, edge_con_inter, obj_id).reshape(B, -1, 3)
            force_norm = con_force_norm + nocon_force_norm

            _, pred_vel = model(x_in, v_in, force_norm.reshape(-1, 3)[:n_particle],
                                h_in, edge_index_inner, edge_index_inter, obj_id)
            pred_vel = pred_vel.reshape(-1, 3)

            # Apply the de-normalised velocity as a per-frame displacement,
            # then project each object onto a rigid motion (Kabsch).
            next_pos = cur_x + (pred_vel * std_d) + mean_d
            _project_rigid(cur_x, next_pos, OBJECT_SEGMENTS)
            pred_motion_norm = (next_pos - cur_x - mean_d) / std_d
            pred_pos = state_cur_pos[-1][:n_particle] + (pred_motion_norm * std_d + mean_d)

            # Non-simulated particles take their ground-truth positions.
            # pred_pos (unnormalised): (n_p + n_s) x state_dim
            gt_pos = p_gt[step_id].to(device)
            pred_pos = torch.cat([pred_pos, gt_pos[n_particle:]], 0)

            # NOTE(review): the "raw" and reported losses are the same L1 on
            # unnormalised positions; kept separate for the existing logs.
            loss_cur = F.l1_loss(pred_pos, gt_pos)
            loss_cur_raw = loss_cur

            loss += loss_cur
            loss_raw += loss_cur_raw
            loss_keeper.append(loss_cur.item())
            loss_raw_keeper.append(loss_cur_raw.item())
            loss_counter += 1

            # Advance the history. The velocity channel previously stayed at
            # its initial ground-truth value for the whole rollout. It is now
            # the displacement just predicted, matching how next_pos
            # integrates velocity above.
            new_vel = pred_pos - state_cur_pos[-1]
            state_cur_pos = torch.cat([state_cur_pos[1:], pred_pos.unsqueeze(0)], 0)
            state_cur_vel = torch.cat([state_cur_vel[1:], new_vel.unsqueeze(0)], 0)

            # Record the prediction.
            p_pred[step_id] = state_cur_pos[-1].detach().cpu()

    loss_keeper_all.append(loss_keeper)
    loss_raw_keeper_all.append(loss_raw_keeper)

    # Mean per-step loss over the rollout (guarded against an empty rollout).
    if loss_counter > 0:
        loss /= loss_counter
        loss_raw /= loss_counter
    print("loss: %.6f, loss_raw: %.10f" % (float(loss), float(loss_raw)))

    group_gt = [d.data.cpu().numpy()[0, ...] for d in group_gt]
    p_pred = p_pred.numpy()[st_idx:ed_idx]
    p_gt = p_gt.numpy()[st_idx:ed_idx]
    s_gt = s_gt.numpy()[st_idx:ed_idx]
    vis_length = ed_idx - st_idx

  