#!/usr/bin/env python
# coding: utf-8

# In[2]:
import sys
import torchvision.transforms as transforms
import os
import random, threading, time, copy
from collections import namedtuple, deque, OrderedDict
import numpy as np, cv2
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from scipy import ndimage
import imutils
import matplotlib.pyplot as plt
from copy import deepcopy
from softgym.envs.bimanual_env import BimanualEnv
from softgym.envs.bimanual_tshirt import BimanualTshirtEnv
import softgym.envs.tshirt_descriptor as td
import pyflex
from plot_loss import plot_loss
import yaml
import socket
import traceback


class QNetBimanual(nn.Module):
    """Convolutional Q-network that outputs a two-channel heatmap.

    A strided convolutional trunk downsamples the channel-stacked inputs and
    a small upsampling head produces two output channels.  The result is
    finally resized to a fixed square canvas (566 x 566 with the constants
    used in ``forward``).
    """

    def __init__(self, in_channels=12):
        super().__init__()
        # Layer order is kept fixed so parameter names (and therefore saved
        # state dicts) stay the same.
        trunk_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=5, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=5, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=5, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=5, stride=1),
            nn.ReLU(inplace=True),
        ]
        head_layers = [
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(32, 2, kernel_size=3, stride=1),
        ]
        self.state_trunk = nn.Sequential(*trunk_layers)
        self.head = nn.Sequential(*head_layers)

    def forward(self, obs1, goal1, obs2, goal2, inedgemasks=None):
        """Run the network on channel-concatenated inputs.

        Without ``inedgemasks`` the four tensors are concatenated along
        dim 1 in the order ``obs1, goal1, obs2, goal2``.  With
        ``inedgemasks`` only ``obs1, inedgemasks, goal1`` are concatenated
        (``obs2`` and ``goal2`` are ignored).  The channel total has to match
        the ``in_channels`` the module was built with.

        Returns a tensor whose spatial size is ``(IM_W, IM_W)``.
        """
        resize_scales = [(2, 0.5), (1, 1), (0.5, 2)]
        base_w = 200

        # Canvas wide enough for the diagonal of the largest rescaled image.
        diag_len = float(base_w) * np.sqrt(2) * resize_scales[0][0]
        margin = int(np.ceil((diag_len - base_w) / 2))
        canvas_w = base_w + 2 * margin

        if inedgemasks is None:
            stacked = torch.cat([obs1, goal1, obs2, goal2], dim=1)
        else:
            stacked = torch.cat([obs1, inedgemasks, goal1], dim=1)

        features = self.state_trunk(stacked)
        coarse = self.head(features)
        return F.interpolate(coarse, size=(canvas_w, canvas_w), mode="bilinear")

def arg(tag, default):
    """Read a ``tag value`` pair from ``sys.argv`` and record it in ``HYPERS``.

    If ``tag`` appears in ``sys.argv`` the token that follows it is converted
    to the type of ``default``; otherwise ``default`` itself is used.  The
    resulting value is stored under ``HYPERS[tag]`` and returned.

    Boolean defaults are parsed textually: ``bool("False")`` is ``True`` in
    Python, so a naive ``type(default)(raw)`` would switch a flag on when the
    user typed ``False``.  The strings listed below (case-insensitive) map to
    ``False``; any other string maps to ``True``.

    Raises:
        IndexError: if ``tag`` is the last element of ``sys.argv``.
        ValueError: if the token cannot be converted to ``type(default)``.
    """
    false_tokens = ('', '0', 'false', 'f', 'no', 'n', 'off')

    if tag in sys.argv:
        raw = sys.argv[sys.argv.index(tag) + 1]
        if isinstance(default, bool):
            value = raw.strip().lower() not in false_tokens
        else:
            value = type(default)(raw)
    else:
        value = default

    HYPERS[tag] = value
    return value

def get_mask(img, obs_mode):
    """Return a foreground mask for ``img``.

    For ``obs_mode == 'depth'`` the mask is a boolean array that is true
    wherever any of the first three channels is positive.  For every other
    mode the image is converted with ``cv2.COLOR_BGR2HSV``, two HSV ranges
    are thresholded and OR-ed together, and the result is cleaned with a
    morphological close (5x5) followed by an open (2x2); that branch returns
    whatever ``cv2.morphologyEx`` produces.
    """
    if obs_mode not in ('depth',):
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        first_band = cv2.inRange(hsv, np.array([20, 50., 10.]), np.array([40, 255., 255.]))
        second_band = cv2.inRange(hsv, np.array([80, 50., 10.]), np.array([100, 255., 255.]))
        combined = cv2.bitwise_or(first_band, second_band)
        # Close small holes first, then remove isolated specks.
        combined = cv2.morphologyEx(combined, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))
        return cv2.morphologyEx(combined, cv2.MORPH_OPEN, np.ones((2, 2), np.uint8))

    return np.any(img[:, :, :3] > 0, axis=2)


def postprocess(heatmap, theta, scale):
    """Undo the rotate / pad / scale preprocessing on a single heatmap.

    The tensor is rotated by ``-theta`` with ``rotate``, moved to the CPU as
    a NumPy array, cropped with ``unpad`` and finally resized by ``1/scale``
    with ``unscale_img``.
    """
    derotated = rotate(heatmap, -theta)
    as_array = derotated.cpu().data.numpy()
    cropped = unpad(as_array, scale)
    return unscale_img(cropped, scale)


def rotate(img, theta):
    """Rotate a 2-D tensor about its centre using nearest-neighbour sampling.

    Args:
        img: 2-D tensor (two leading singleton dims are added internally so
            it can be fed to ``F.grid_sample``).
        theta: angle in radians.  The sign is flipped before the affine
            matrix is built, so the rotation sense is the opposite of the
            raw matrix passed to ``F.affine_grid``.

    Returns:
        The resampled tensor with singleton dimensions squeezed out, on the
        same device as ``img``.
    """
    theta = -theta
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    affine_np = np.array([[cos_t, sin_t, 0], [-sin_t, cos_t, 0]])
    # Build the matrix on the input's own device rather than forcing CUDA,
    # so the function also works for CPU tensors.
    affine_mat = torch.as_tensor(affine_np, dtype=torch.float32, device=img.device).unsqueeze(0)

    batched = img.unsqueeze(0).unsqueeze(0)
    flow_grid = F.affine_grid(affine_mat, batched.size())
    resampled = F.grid_sample(batched, flow_grid, mode='nearest')
    return resampled.squeeze()


def scale_img(img, scale):
    """Resize ``img`` by the factor ``scale`` along both axes with ``cv2.resize``."""
    return cv2.resize(img, None, fx=scale, fy=scale)

def unscale_img(img, scale):
    """Resize ``img`` by ``1/scale`` along both axes (inverse of ``scale_img``)."""
    inverse = 1.0 / scale
    return cv2.resize(img, None, fx=inverse, fy=inverse)


def pad(img):
    """Zero-pad ``img`` symmetrically towards the fixed network canvas.

    The canvas width is derived from a 200 px base image, the largest resize
    factor and the image diagonal (566 px with these constants).  The border
    width is computed from ``img.shape[0]`` only and applied to both spatial
    axes; a third (channel) axis, if present, is left unpadded.
    """
    resize_scales = [(2, 0.5), (1, 1), (0.5, 2)]
    base_w = 200

    diag_len = float(base_w) * np.sqrt(2) * resize_scales[0][0]
    canvas_w = base_w + 2 * int(np.ceil((diag_len - base_w) / 2))

    border = int(np.ceil((canvas_w - img.shape[0]) / 2))
    widths = [(border, border), (border, border)]
    if len(img.shape) == 3:
        widths = widths + [(0, 0)]
    return np.pad(img, widths, 'constant', constant_values=0)

def unpad(img, scale):
    """Crop the border that padding added around an image resized by ``scale``.

    The border width is recomputed from the fixed canvas size (566 px) and
    ``ceil(200 * scale)``; the same number of rows/columns is removed from
    every side.  A deep copy of the cropped region is returned, so the
    result does not share memory with ``img``.
    """
    resize_scales = [(2, 0.5), (1, 1), (0.5, 2)]
    base_w = 200

    diag_len = float(base_w) * np.sqrt(2) * resize_scales[0][0]
    canvas_w = base_w + 2 * int(np.ceil((diag_len - base_w) / 2))

    border = int(np.ceil((canvas_w - np.ceil(base_w * scale)) / 2))
    interior = slice(border, -border)
    return copy.deepcopy(img[interior, interior])

def mask(img, mask):
    """Return the element-wise product ``img * mask``."""
    return img * mask


def get_heatmaps(qnet, img, goal, target=False, inedgemask=None):
    """Evaluate the network on every (rotation, scale1, scale2) variant.

    ``img`` and ``goal`` (H x W x C NumPy arrays, permuted to channel-first
    here) are resized by each of the three scale factors, zero-padded to the
    square canvas and rotated by each of the eight rotation bins.  One batch
    of 8 * 3 * 3 entries is built per input stream on the CUDA device and
    passed to the network in a single call.

    Returns:
        ``(heatmaps, angles_scale_idxs)`` where ``angles_scale_idxs[i]`` is
        the ``(rotate_idx, scale_idx1, scale_idx2)`` triple of batch entry
        ``i``.

    NOTE(review): the network is called as ``(imgs1, imgs2, goals1, goals2)``
    while ``QNetBimanual.forward`` names its parameters
    ``(obs1, goal1, obs2, goal2)`` -- confirm this matches how the weights
    were trained before changing either side.
    NOTE(review): with ``target=True`` the code calls ``tnet``, which is not
    defined in the visible part of this module, and ``ems`` is only bound
    when ``inedgemask`` is ``None``.
    """
    num_rotations = 8
    num_scales = 3
    resize_scales = [(2, 0.5), (1, 1), (0.5, 2)]
    W = 200
    ROT_BIN = (2*np.pi/num_rotations)

    # Canvas wide enough for the diagonal of the largest rescaled image.
    diag_len = float(W)*np.sqrt(2)*resize_scales[0][0]
    IM_W = W + 2*int(np.ceil((diag_len - W)/2))

    batch_shape = [num_rotations*num_scales*num_scales, 3, IM_W, IM_W]

    def blank_batch():
        return torch.zeros(batch_shape, dtype=torch.float32, device='cuda')

    def to_chw(array):
        return torch.from_numpy(array).float().cuda().permute(2, 0, 1)

    def resize_and_pad(scale):
        # Resize to the scaled size, then pad symmetrically up to the canvas.
        size = (int(img.shape[0]*scale), int(img.shape[1]*scale))
        margin = int(np.ceil((IM_W - size[0])/2))
        return torch.nn.Sequential(
            transforms.Resize(size, interpolation=2),
            transforms.Pad(margin, fill=0, padding_mode='constant'))

    with torch.no_grad():
        imgs1 = blank_batch()
        goals1 = blank_batch()
        imgs2 = blank_batch()
        goals2 = blank_batch()
        img_t = to_chw(img)
        goal_t = to_chw(goal)

        if inedgemask is not None:
            # TODO not working yet
            ems1 = blank_batch()
            ems2 = blank_batch()
            em_t = to_chw(np.stack([inedgemask] * 3, axis=2))
        else:
            ems = None

        angles_scale_idxs = [
            (rot, first, second)
            for rot in range(num_rotations)
            for first in range(num_scales)
            for second in range(num_scales)
        ]

        for i, (rotate_idx, scale_idx1, scale_idx2) in enumerate(angles_scale_idxs):
            angle_deg = np.degrees(rotate_idx*ROT_BIN)
            transform1 = resize_and_pad(resize_scales[scale_idx1][0])
            transform2 = resize_and_pad(resize_scales[scale_idx2][0])

            imgs1[i] = transforms.functional.rotate(transform1(img_t), angle_deg)
            goals1[i] = transforms.functional.rotate(transform1(goal_t), angle_deg)

            imgs2[i] = transforms.functional.rotate(transform2(img_t), angle_deg)
            goals2[i] = transforms.functional.rotate(transform2(goal_t), angle_deg)

            if inedgemask is not None:
                ems1[i] = transforms.functional.rotate(transform1(em_t), angle_deg)
                ems2[i] = transforms.functional.rotate(transform2(em_t), angle_deg)

        if target:
            heatmaps = tnet(imgs1, imgs2, goals1, goals2, inedgemasks=ems)
        else:
            heatmaps = qnet(imgs1, imgs2, goals1, goals2)
        return heatmaps, angles_scale_idxs


def select_action(qnet, obs, goal, rgb, viz=True, act_gt=None, random=False, inedgemask=None, outedgemask=None):
    """Pick a two-gripper fold action from the network's Q heatmaps.

    Args:
        qnet: network handed to ``get_heatmaps``; put into eval mode here.
        obs: current observation image.  It is divided by 255 before
            inference and also used to build the cloth mask through
            ``get_mask(obs, 'depth')``.
        goal: goal image, divided by 255 before inference.
        rgb: image on which the selected picks and fold arrows are drawn
            (a ``uint8`` copy is drawn on and returned).
        viz: when true, heatmap visualisations are rendered with
            ``render_heatmaps``; otherwise both are ``None``.
        act_gt: unused.
        random: when true, sample rotation, scales and pick pixels at random
            instead of taking the arg-max.
        inedgemask: optional mask, divided by 255 and forwarded to
            ``get_heatmaps``.
        outedgemask: optional mask; Q-values are suppressed wherever it
            equals 0 (instead of using the cloth mask).  Assumed to be 2-D
            with the same height/width as the post-processed heatmaps --
            TODO confirm against the caller.

    Returns:
        ``(action, act_type, disp_img, [viz_img1, viz_img2])`` where
        ``action`` is ``[rotate_idx, scale1_idx, scale2_idx, u1, v1, u2, v2]``
        with ``(u, v)`` = (row, column) of each pick.

    NOTE(review): ``act_type`` is always ``"greedy"``, even when
    ``random=True``; kept as-is because callers may depend on it.
    """
    num_rotations = 8
    num_scales = 3
    resize_scales = [(2, 0.5), (1, 1), (0.5, 2)]
    MID_FOLD_DIST = 78
    ROT_BIN = (2*np.pi/num_rotations)
    SUPPRESSED_Q = -10000000

    mask = get_mask(obs, 'depth')
    # Pixels whose Q-value stays eligible for selection.
    allowed = mask if outedgemask is None else outedgemask

    # Defined up front so the return statement is valid when viz is False.
    viz_img1 = None
    viz_img2 = None

    with torch.no_grad():
        qnet.eval()
        inedgemask_int = inedgemask / 255.0 if inedgemask is not None else None
        heatmaps, angles_scale_idxs = get_heatmaps(qnet, obs/255.0, goal/255.0, inedgemask=inedgemask_int)

        maps = []   # heatmaps with disallowed pixels suppressed (for arg-max)
        vmaps = []  # untouched copies (for visualisation)
        for heatmap, (rot_idx, s1_idx, s2_idx) in zip(heatmaps, angles_scale_idxs):
            m1 = postprocess(heatmap[0, :, :], rot_idx*ROT_BIN, resize_scales[s1_idx][0])
            m2 = postprocess(heatmap[1, :, :], rot_idx*ROT_BIN, resize_scales[s2_idx][0])
            vmaps.append([m1.copy(), m2.copy()])
            # Apply the suppression to each gripper's map separately; the
            # pair is a Python list and cannot be mask-indexed as a whole.
            m1[allowed == 0] = SUPPRESSED_Q
            m2[allowed == 0] = SUPPRESSED_Q
            maps.append([m1, m2])

        maps = np.array(maps)

        if random:
            uv = td.random_sample_from_masked_image(mask, 2)
            u1, v1 = uv[0][0], uv[1][0]
            u2, v2 = uv[0][1], uv[1][1]
            action = [np.random.randint(num_rotations), np.random.randint(num_scales), np.random.randint(num_scales), u1, v1, u2, v2]
        else:
            # Choose the batch entry whose two per-gripper maxima sum highest,
            # then take each gripper's arg-max pixel inside that entry.
            n, _, rows, cols = maps.shape
            flat1 = maps[:, 0, :, :].reshape(n, rows*cols)
            flat2 = maps[:, 1, :, :].reshape(n, rows*cols)
            best_pix1 = np.argmax(flat1, axis=1)
            best_pix2 = np.argmax(flat2, axis=1)
            best_ind = np.argmax(np.amax(flat1, axis=1) + np.amax(flat2, axis=1))
            rotate_max_idx, scale1_max_idx, scale2_max_idx = angles_scale_idxs[best_ind]
            u1, v1 = divmod(int(best_pix1[best_ind]), cols)
            u2, v2 = divmod(int(best_pix2[best_ind]), cols)
            action = [rotate_max_idx, scale1_max_idx, scale2_max_idx, u1, v1, u2, v2]

        if viz:
            viz_img1, viz_img2 = render_heatmaps(obs, np.asarray(vmaps), action)
            if outedgemask is not None:
                # Also show the masked maps, stacked below the raw ones.
                masked_maps = deepcopy(maps)
                masked_maps[:, :, outedgemask == 0] = 0
                masked1, masked2 = render_heatmaps(obs, masked_maps, action)
                viz_img1 = np.concatenate((viz_img1, masked1), axis=0)
                viz_img2 = np.concatenate((viz_img2, masked2), axis=0)

        theta = np.deg2rad((360.0/num_rotations) * action[0])
        fold_dist1 = MID_FOLD_DIST * resize_scales[action[1]][1]
        fold_dist2 = MID_FOLD_DIST * resize_scales[action[2]][1]
        # action stores (row, col); OpenCV points are (x, y) = (col, row).
        x1, y1 = int(action[4]), int(action[3])
        x2, y2 = int(action[6]), int(action[5])

        arrow1_x = int(x1 + fold_dist1*np.cos(theta))
        arrow1_y = int(y1 + fold_dist1*np.sin(theta))
        arrow2_x = int(x2 + fold_dist2*np.cos(theta))
        arrow2_y = int(y2 + fold_dist2*np.sin(theta))

        disp_img = rgb.astype(np.uint8)
        color1 = (0, 0, 200)
        color2 = (0, 200, 0)
        cv2.circle(disp_img, (x1, y1), 6, color1, 2)
        cv2.arrowedLine(disp_img, (x1, y1), (arrow1_x, arrow1_y), color1, 2, tipLength=0.1)
        cv2.circle(disp_img, (x2, y2), 6, color2, 2)
        cv2.arrowedLine(disp_img, (x2, y2), (arrow2_x, arrow2_y), color2, 2, tipLength=0.1)

        act_type = "greedy"
    return action, act_type, disp_img, [viz_img1, viz_img2]

def _render_heatmap_tile(color, heat, theta, fold_dist, pick_rc, width=200, tile_size=(100, 100)):
    """Render one normalised heatmap as a colour tile with a fold arrow.

    ``pick_rc`` is an optional ``(row, col)`` pick location; when given, a
    white circle is drawn there before the tile is shrunk to ``tile_size``.
    """
    tile = cv2.applyColorMap((heat*255).astype(np.uint8), cv2.COLORMAP_JET)
    # NOTE(review): ``color*255`` suggests ``color`` is expected in [0, 1],
    # while select_action passes the un-normalised observation -- confirm.
    tile = (0.5*tile + 0.5*(color*255)).astype(np.uint8)

    # Arrow from the image centre along the fold direction; (row, col) order.
    start = np.array([width/2, width/2])
    offset = (width*(fold_dist*np.array([np.sin(theta), np.cos(theta)]))).astype(int)
    end = start + offset
    tile = cv2.arrowedLine(tile, (int(start[1]), int(start[0])), (int(end[1]), int(end[0])), (100, 100, 100), 3, tipLength=0.3)

    if pick_rc is not None:
        tile = cv2.circle(tile, (int(pick_rc[1]), int(pick_rc[0])), 7, (255, 255, 255), 2)
    return cv2.resize(tile, tile_size)


def render_heatmaps(color, heatmaps, max_idx):
    """Build one overview image per gripper from all heatmaps.

    Args:
        color: image blended 50/50 with every colour-mapped heatmap.
        heatmaps: array-like indexed as ``[i, gripper, row, col]`` with
            ``i`` running over rotation, then scale1, then scale2
            (8 * 3 * 3 entries).
        max_idx: the selected action
            ``[rotate_idx, scale1_idx, scale2_idx, u1, v1, u2, v2]``; its pick
            points are circled on the tiles that match rotation and scale.

    Returns:
        ``(img1, img2)``: for each gripper, the nine tiles of a rotation are
        stacked vertically and the eight rotations are placed side by side.
    """
    num_rotations = 8
    num_scales = 3
    ROT_BIN = (2*np.pi/num_rotations)

    heatmaps = np.asarray(heatmaps)
    lowest = np.min(heatmaps)
    span = np.max(heatmaps) - lowest
    if span == 0:
        # Constant input: avoid 0/0 and render everything at the minimum.
        span = 1.0
    heatmaps = (heatmaps - lowest) / span

    columns1 = []
    columns2 = []
    i = 0
    for r in range(num_rotations):
        theta = r*ROT_BIN
        tiles1 = []
        tiles2 = []
        for s1 in range(num_scales):
            for s2 in range(num_scales):
                fold_dist1 = 0.1*((s1)+1.0)  # because of flip
                fold_dist2 = 0.1*((s2)+1.0)  # because of flip

                pick1 = (max_idx[3], max_idx[4]) if (r == max_idx[0] and s1 == max_idx[1]) else None
                # The second gripper's pick belongs on the second gripper's
                # tile (it used to be drawn on the already-resized first one).
                pick2 = (max_idx[5], max_idx[6]) if (r == max_idx[0] and s2 == max_idx[2]) else None

                tiles1.append(_render_heatmap_tile(color, heatmaps[i, 0, :, :], theta, fold_dist1, pick1))
                tiles2.append(_render_heatmap_tile(color, heatmaps[i, 1, :, :], theta, fold_dist2, pick2))
                i += 1
        columns1.append(np.concatenate(tiles1, axis=0))
        columns2.append(np.concatenate(tiles2, axis=0))

    img1 = np.concatenate(columns1, axis=1)
    img2 = np.concatenate(columns2, axis=1)
    return img1, img2

def get_eval_im(fn, obs_mode, base_path):
    """Load an evaluation image and its particle positions.

    Args:
        fn: base file name (without extension) inside
            ``{base_path}/goals/square_final``.
        obs_mode: ``'color'`` loads ``{fn}.png``; ``'depth'`` loads
            ``{fn}_depth.png``.
        base_path: root directory that contains ``goals/``.

    Returns:
        ``(im, pos)``: the image converted with ``cv2.COLOR_BGR2RGB`` and
        zeroed wherever ``get_mask`` reports background, and the array loaded
        from ``particles/{fn}.npy``.

    Raises:
        ValueError: if ``obs_mode`` is neither ``'color'`` nor ``'depth'``.
        FileNotFoundError: if the image cannot be read.
    """
    suffixes = {'color': '', 'depth': '_depth'}
    if obs_mode not in suffixes:
        raise ValueError(f"unsupported obs_mode {obs_mode!r}; expected 'color' or 'depth'")

    goal_dir = f'{base_path}/goals/square_final'
    fpath = f'{goal_dir}/{fn}{suffixes[obs_mode]}.png'

    im = cv2.imread(fpath)
    if im is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'could not read evaluation image: {fpath}')
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    foreground = get_mask(im, obs_mode) != 0
    im[~foreground, :] = 0

    pos = np.load(f'{goal_dir}/particles/{fn}.npy')
    return im, pos

def run_goal(train_iter, qnet, env, cfg, save_dir, eval_ims, em):
    """Roll out ``qnet`` on every start/goal combination in ``cfg['eval_combos']``.

    For each combination the environment is reset with the final goal, the
    cloth particles are set to the start configuration, and actions from
    ``select_action`` are executed until the environment reports ``done``.
    Per-step visualisations are written to ``save_dir`` and the best particle
    distance of the episode is appended to ``{save_dir}/metrics.csv``.

    Args:
        train_iter: training iteration of the weights (used in file names).
        qnet: network to evaluate.
        env: environment providing ``reset``, ``step`` and ``_get_obs``.
        cfg: config dict; reads ``eval_combos``, ``obs_mode`` and ``on_table``.
        save_dir: existing directory for outputs.
        eval_ims: mapping from image name to ``(image, particle_positions)``.
        em: unused (edge-mask support is currently disabled).

    A combination is skipped when one of its output images for this exact
    ``train_iter`` already exists.  Reads the module-level ``STEP`` setting.
    """
    for combo in cfg['eval_combos']:
        is_multistep = isinstance(combo[0], list)
        if is_multistep:
            start_name = combo[0][0]
            goal_name = combo[-1][1]
            if STEP == 'multi':  # Provide intermediate steps
                name = f'ms-{start_name}-{goal_name}'
            else:  # Only provide final goal
                name = f'ms-{start_name}-{goal_name}-finalonly'
        else:  # single step goal
            start_name, goal_name = combo
            name = f'{start_name}-{goal_name}'

        out_paths = {
            kind: f'{save_dir}/{name}_{kind}_step_{train_iter}.png'
            for kind in ('disp', 'viz_0', 'viz_1')
        }
        # Match exact output files: a substring test would confuse iteration
        # 500 with 2500, or name "a-b" with "ms-a-b".
        if any(os.path.exists(path) for path in out_paths.values()):
            print(f"Skipping iter {train_iter} for {name} as it already exists")
            continue

        stepcount = 0
        start, start_pos = eval_ims[start_name]
        goal, goal_pos = eval_ims[goal_name]
        goal_pos = goal_pos[:, :3]
        curr_goal = goal

        # reset env
        print(f"goal set to {goal_name}")
        obs = env.reset(given_goal=goal, given_goal_pos=goal_pos)  # provide final goal for reward computation
        pyflex.set_positions(start_pos)
        pyflex.step()
        obs = env._get_obs()
        done = False

        vizs = [[], []]
        disps = []
        distances = []
        while not done:
            # Edge-mask input/output is disabled in this evaluation script.
            in_edgemask = None
            out_edgemask = None

            if is_multistep and STEP == 'multi' and stepcount < len(combo):
                print("updating current goal to " + combo[stepcount][1])
                curr_goal, _ = eval_ims[combo[stepcount][1]]
            act, _, disp_img, viz_img = select_action(qnet, obs[cfg['obs_mode']], curr_goal, rgb=obs["color"], inedgemask=in_edgemask, outedgemask=out_edgemask)

            # step env
            nobs, rew, done, _ = env.step(act, on_table=cfg['on_table'])
            # Zero out background pixels of the new observation.
            nobs_mask = get_mask(nobs[cfg['obs_mode']], cfg['obs_mode']) != 0
            nobs[cfg['obs_mode']][~nobs_mask, :] = 0

            # Calc particle metrics
            pos = pyflex.get_positions().reshape(-1, 4)[:, :3]
            pos_metric = td.calc_metric_abs(pos, goal_pos)
            distances.append(pos_metric[1])

            vizs[0].append(viz_img[0])
            vizs[1].append(viz_img[1])
            disps.append(disp_img)

            obs = copy.deepcopy(nobs)
            stepcount += 1

        disps.append(obs["color"])  # final obs

        cv2.imwrite(out_paths['disp'], np.vstack(disps))
        cv2.imwrite(out_paths['viz_0'], np.vstack(vizs[0]))
        cv2.imwrite(out_paths['viz_1'], np.vstack(vizs[1]))

        # add the best metrics for episodes
        min_dist = min(distances)
        min_step = np.argmin(distances)
        print("Min distance {} achieved at step {}".format(min_dist, min_step))
        with open(f'{save_dir}/metrics.csv', 'a') as f:
            f.write(f'{train_iter},{name},{min_dist},{min_step}\n')

def evaluate(EVAL_CFG, EVAL_PATH, ENV_TYPE, env, maxiter=25000, base_path='.', out_path='.'):
    """Evaluate saved Q-network checkpoints of one run on the configured goals.

    The run config ``{out_path}/output/{EVAL_PATH}/config.yaml`` is loaded and
    overridden by ``{base_path}/config/{EVAL_CFG}.yaml``.  All start/goal
    images named in ``cfg['eval_combos']`` are loaded once, then checkpoints
    ``<prefix>_<iter>.pt`` for ``iter = 500, 1000, ..., maxiter`` are loaded
    from the run's ``weights`` directory and passed to ``run_goal``.  The loop
    stops at the first checkpoint that raises (for example a missing file),
    after printing the traceback.

    Args:
        EVAL_CFG: name (without ``.yaml``) of the evaluation config.
        EVAL_PATH: run directory name below ``{out_path}/output``.
        ENV_TYPE: currently unused.
        env: environment forwarded to ``run_goal``.
        maxiter: highest training iteration to evaluate (inclusive).
        base_path: root containing ``config/`` and ``goals/``.
        out_path: root containing ``output/``.

    Raises:
        FileNotFoundError: if the weights directory has no ``qnet`` file.

    Reads the module-level ``STEP`` setting for the output directory name and
    drops into an IPython shell once the loop has finished.
    """
    run_dir = f'{out_path}/output/{EVAL_PATH}'

    # NOTE: FullLoader can construct arbitrary Python objects; only use this
    # with trusted, locally generated config files.
    with open(f'{run_dir}/config.yaml') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)

    with open(f'{base_path}/config/{EVAL_CFG}.yaml') as f:
        eval_cfg = yaml.load(f, Loader=yaml.FullLoader)
    cfg.update(eval_cfg)

    def add_to_eval_ims(cfg, base_path, eval_ims, combo):
        """Load the start and goal image of ``combo`` unless already cached."""
        start_fn, goal_fn = combo
        print(start_fn, goal_fn)
        for fn in (start_fn, goal_fn):
            if fn not in eval_ims:
                eval_ims[fn] = get_eval_im(fn, cfg['obs_mode'], base_path)
        return eval_ims

    # test goals
    eval_ims = dict()
    for combo in cfg['eval_combos']:
        if isinstance(combo[0], list):  # multi step goal
            for c in combo:
                eval_ims = add_to_eval_ims(cfg, base_path, eval_ims, c)
        else:  # single step goal
            eval_ims = add_to_eval_ims(cfg, base_path, eval_ims, combo)

    # Derive the checkpoint file-name prefix from the latest qnet file.
    # Target-network files are not needed for evaluation, so their absence
    # must not abort the run.
    weight_dir = f'{run_dir}/weights'
    qnet_fns = sorted(
        (x for x in os.listdir(weight_dir) if 'qnet' in x),
        key=lambda k: int(k.split('_')[-1].replace('.pt', '')),
        reverse=True)
    if not qnet_fns:
        raise FileNotFoundError(f'no qnet weight files found in {weight_dir}')
    qnet_tmpl = '_'.join(qnet_fns[0].split('_')[:-1])

    # Make save directory
    save_dir = f'{run_dir}/{"debug" if cfg["debug"] else cfg["eval_type"]}_evals_{STEP}'
    if os.path.exists(save_dir) and cfg['overwrite']:
        print(f"deleting {save_dir}")
        import shutil
        shutil.rmtree(save_dir)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    # For every weight file, load the weights and evaluate on goals
    for train_iter in np.arange(500, maxiter+1, step=500):
        try:
            qnet_fn = f'{qnet_tmpl}_{train_iter}.pt'

            # load weights (the checkpoint stores the whole pickled module)
            qnet = torch.load(f'{weight_dir}/{qnet_fn}')
            qnet.eval()

            print(f"evaluating weight {train_iter} {EVAL_PATH}")
            run_goal(train_iter, qnet, env, cfg, save_dir, eval_ims, None)
        except Exception:
            # Typically the first missing checkpoint; report and stop.
            print(traceback.format_exc())
            break

    # Interactive inspection of the results after evaluation.
    import IPython; IPython.embed()

if __name__ == '__main__':
    # Command-line settings are recorded here by arg().
    HYPERS = OrderedDict()

    OUTMASK = arg('outmask', False)
    EVAL_PATH = arg('eval_path', '')
    STEP = arg('step', 'multi') # 'single'
    ENV_TYPE = 'tshirt' if 'tshirt' in EVAL_PATH else 'towel'

    # Both environments are constructed with the same settings; only the
    # class depends on the cloth type.
    env_class = BimanualTshirtEnv if ENV_TYPE == 'tshirt' else BimanualEnv
    env = env_class(
        use_depth=True,
        use_cached_states=False,
        horizon=5,
        cam_height=0.65,
        use_desc=False,
        action_repeat=1,
        headless=True,
    )

    evaluate('eval', EVAL_PATH, ENV_TYPE, env, base_path='..', out_path='..')
    # if OUTMASK:
        # evaluate('eval_outmask1.0', EVAL_PATH, ENV_TYPE, env)
        