#!/usr/bin/env python
# coding: utf-8

# In[2]:
import sys
import torchvision.transforms as transforms
import os
import socket
import random, threading, time, copy
from collections import namedtuple, deque, OrderedDict
import numpy as np, cv2
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from scipy import ndimage
import imutils
import matplotlib.pyplot as plt
from copy import deepcopy
from softgym.envs.bimanual_env import BimanualEnv
from softgym.envs.bimanual_tshirt import BimanualTshirtEnv
import softgym.envs.tshirt_descriptor as td
import pyflex
from plot_loss import plot_loss
import yaml
from models import QNetBimanual

# Every value resolved through arg(), keyed by tag, in first-lookup order.
HYPERS = OrderedDict()


def arg(tag, default):
    """Resolve a command-line override for `tag`, falling back to `default`.

    Overrides are bare ``tag value`` pairs anywhere in sys.argv
    (e.g. ``python train.py seed 3``). The token following `tag` is converted
    with ``type(default)``, so the default's type decides the parsed type.
    The resolved value is recorded in HYPERS and returned.

    NOTE: with a bool default, any non-empty token (even "False") parses as True.
    """
    value = default
    if tag in sys.argv:
        value = type(default)(sys.argv[sys.argv.index(tag) + 1])
    HYPERS[tag] = value
    return value

# Command-line overrides resolved through arg(): environment type, optional
# extra config name, base directory holding config/, and random seed.
ENV_TYPE = arg('env', 'towel')
ADD_CFG = arg('addcfg', '')
BASE_PATH = arg('base', os.getcwd())
SEED = arg('seed', 0)

# Base configuration for the chosen environment.
# NOTE(review): yaml.load with FullLoader is only appropriate for trusted, local config files.
with open(f'{BASE_PATH}/config/config_{ENV_TYPE}_offline.yaml') as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)

# Optional overlay config: its top-level keys replace the base config's (dict.update, not a deep merge).
if ADD_CFG != '':
    with open(f'{BASE_PATH}/config/{ADD_CFG}.yaml') as f:
        addcfg = yaml.load(f, Loader=yaml.FullLoader)
        cfg.update(addcfg)

# Paths are configured per machine: cfg has a section keyed by this host's name.
hostname = socket.gethostname()
buffer_folder = f"{cfg[hostname]['buffer_folder']}/{cfg['run_name']}"

if ENV_TYPE == 'tshirt':
    tshirtmap_path = f"{cfg[hostname]['tshirtmap_path']}"

# Record the effective seed so the dumped config.yaml reflects it.
cfg['seed'] = SEED

HER      = cfg['her']
BATCH    = cfg['batch']
GAMMA    = cfg['gamma']
TAU      = cfg['tau']
EPSILON  = cfg['epsilon']
# train() is a no-op while the global step is below EXP_STEP; 0 means train immediately.
EXP_STEP = 0
DR       = cfg['domain_rand']

# Probabilities of masking edges during training and evaluation.
# Example (train, eval) settings: (1.0, 1.0), (0.5, 0.0), (0.0, 0.0).
MASK_EDGE_TRAIN_PROB = cfg['mask_edge_train_prob']
MASK_EDGE_EVAL_PROB = cfg['mask_edge_eval_prob']
VIZ      = True
OBS_MODE = cfg['obs_mode']
RUN_NAME = cfg['run_name']

# One transition. In the index-buffer variant obs/nobs are first stored as integer
# image indices; sample_batch loads the images and also carries idx and the edge mask.
Experience = namedtuple('Experience', ('obs', 'goal', 'act', 'rew', 'nobs', 'done'))
ExperienceIndex = namedtuple('ExperienceIndex', ('obs', 'goal', 'act', 'rew', 'nobs', 'done', 'idx', 'edgemask'))


def get_mask(img):
    """Return the foreground (cloth) mask of an observation image.

    'depth' mode: a boolean HxW array that is True wherever any of the first
    three channels is > 0.
    Other modes: the uint8 (0/255) result of an HSV threshold (hue 20-40 or
    80-100, saturation >= 50, value >= 10), cleaned with a 5x5 morphological
    close followed by a 2x2 open.

    Note the two modes return different dtypes (bool vs. uint8).
    """
    if OBS_MODE not in ['depth']:
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        hue_band_a = cv2.inRange(hsv, np.array([20, 50., 10.]), np.array([40, 255., 255.]))
        hue_band_b = cv2.inRange(hsv, np.array([80, 50., 10.]), np.array([100, 255., 255.]))
        fg = cv2.bitwise_or(hue_band_a, hue_band_b)
        # Close small holes, then remove isolated specks.
        fg = cv2.morphologyEx(fg, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))
        fg = cv2.morphologyEx(fg, cv2.MORPH_OPEN, np.ones((2, 2), np.uint8))
        return fg
    return (img[:, :, 0] > 0) | (img[:, :, 1] > 0) | (img[:, :, 2] > 0)


def postprocess(heatmap, theta, scale):
    """Map a network heatmap back to the original image frame as a numpy array.

    Undoes, in reverse order, the rotation by `theta` radians, the zero padding
    to the IM_W canvas, and the resize by `scale` applied to the network input.
    """
    unrotated = rotate(heatmap, -theta).cpu().data.numpy()
    cropped = unpad(unrotated, scale)
    return unscale_img(cropped, scale)


def rotate(img, theta):
    """Rotate a 2-D tensor by `theta` radians about its centre (nearest-neighbour).

    Args:
        img: HxW floating-point tensor on any device.
        theta: rotation angle in radians; the affine matrix is built from -theta.

    Returns:
        HxW tensor with the same dtype and device as `img`. Pixels sampled from
        outside the input are 0 (grid_sample's default zero padding).
    """
    t = -theta
    batch = img.unsqueeze(0).unsqueeze(0)  # grid_sample expects N x C x H x W
    # Build the matrix on the input's device and dtype instead of hard-coding CUDA float32.
    affine_mat = torch.tensor(
        [[[np.cos(t), np.sin(t), 0.0], [-np.sin(t), np.cos(t), 0.0]]],
        dtype=batch.dtype, device=batch.device)
    # align_corners=False is the default; passing it explicitly silences the UserWarning.
    flow_grid = F.affine_grid(affine_mat, batch.size(), align_corners=False)
    out = F.grid_sample(batch, flow_grid, mode='nearest', align_corners=False)
    # Index out N and C explicitly so H or W of size 1 are preserved (squeeze() would drop them).
    return out[0, 0]


def scale_img(img, scale):
    """Resize `img` uniformly by factor `scale` (cv2.resize default interpolation)."""
    return cv2.resize(img, None, fx=scale, fy=scale)

def unscale_img(img, scale):
    """Undo scale_img: resize `img` uniformly by 1/`scale`."""
    inverse = 1.0 / scale
    return cv2.resize(img, None, fx=inverse, fy=inverse)


def pad(img):
    """Zero-pad the two spatial axes of `img` symmetrically toward the IM_W canvas.

    The pad amount is ceil((IM_W - img.shape[0]) / 2) on every side of both
    spatial axes; a third (channel) axis, if present, is not padded.
    """
    amount = int(np.ceil((IM_W - img.shape[0]) / 2))
    widths = [(amount, amount), (amount, amount)]
    if img.ndim == 3:
        widths.append((0, 0))
    return np.pad(img, widths, 'constant', constant_values=0)

def unpad(img, scale):
    """Crop the symmetric padding that was added to fit a `scale`-resized WxW image into IM_W.

    Args:
        img: array whose first two axes are spatial.
        scale: resize factor that was applied before padding.

    Returns:
        A copy of `img` with ceil((IM_W - ceil(W*scale)) / 2) pixels removed from
        each side of both spatial axes. A zero (or negative) pad amount returns the
        whole image rather than an empty slice.
    """
    pad_amount = max(int(np.ceil((IM_W - np.ceil(W * scale)) / 2)), 0)
    rows, cols = img.shape[0], img.shape[1]
    # Explicit end indices: `img[p:-p]` with p == 0 would be an empty slice.
    return copy.deepcopy(img[pad_amount:rows - pad_amount, pad_amount:cols - pad_amount])

def mask(img, mask):
    """Return `img` multiplied elementwise by `mask` (zero wherever mask is zero)."""
    return img * mask


def train():
    """Run one optimizer step of qnet on a minibatch from the global replay buffer.

    Returns None without training while the buffer holds fewer than BATCH
    transitions or the global step is below EXP_STEP; otherwise returns
    {"loss": <float>} for the step just taken.
    """
    not_ready = len(buf) < BATCH or step < EXP_STEP
    if not_ready:
        return None
    print(f"Training {model_name}")
    qnet.train()

    batch_loss = train_on_buf(buf)
    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    return {"loss": batch_loss.item()}

def sample_batch(buf):
    """Draw a random minibatch from `buf` without replacement.

    min(BATCH, len(buf)) experiences are drawn with random.sample. When
    cfg['index_buffer'] is True, obs/nobs are image indices: both images are
    loaded with load_buffer_data and, if cfg['inputedge'], the saved obs action
    mask is read (first channel) and divided by 255. In that case
    ExperienceIndex tuples are returned; otherwise the drawn experiences are
    returned as-is.
    """
    n_take = BATCH if len(buf) >= BATCH else len(buf)
    drawn = random.sample(buf, n_take)
    if not (cfg['index_buffer'] == True):
        return drawn

    loaded = []
    for exp in drawn:
        idx = exp.obs
        obs_img = load_buffer_data(exp.obs, isnobs=False)
        nobs_img = load_buffer_data(exp.nobs, isnobs=True)
        edgemask = None
        if cfg['inputedge']:
            raw = cv2.imread(f'{cfg[hostname]["out_path"]}/output/{model_name}/online_data/obs_actmask_{str(idx).zfill(6)}.png')
            edgemask = raw[:, :, 0] / 255
        loaded.append(ExperienceIndex(obs_img, exp.goal, exp.act, exp.rew, nobs_img, exp.done, idx, edgemask))
    return loaded


def _augment(img, angle, shift):
    """Rotate `img` by `angle` degrees, then translate it by `shift` = (dx, dy) pixels (imutils)."""
    img = imutils.rotate(img, angle)
    return imutils.translate(img, *shift)


def replace_goals(batch, buf):
    """Apply hindsight goal relabelling (if HER) and random rotate/shift augmentation.

    Relabelling when HER is set:
      - sparse reward: with probability 0.5 the goal becomes a random nobs from
        `buf` (reward 1 / done only if it equals this nobs); otherwise the goal
        becomes this nobs with reward 1 and done.
      - dense reward: with probability 0.5 the original goal is kept with reward
        b.rew / cfg['reward_magn'] and done=False; otherwise the goal becomes this
        nobs with reward 0.0 and done.
    Without HER the stored goal, reward and done flag are kept unchanged.

    With probability 0.9, obs, nobs, goal (when not None) and the edge mask (when
    present) are all rotated by one angle in [-5, 5) degrees and shifted by one
    offset in [-5, 5) pixels per axis.

    Returns:
        List of ExperienceIndex (index buffer) or Experience tuples.
    """
    her_batch = []
    for b in batch:
        obs, nobs, act = b.obs, b.nobs, b.act
        # Default to the stored values so the non-HER path has a defined reward.
        goal, rew, done = b.goal, b.rew, b.done
        edgemask = b.edgemask if cfg['inputedge'] else None

        if HER:
            if cfg['sparse']:
                # 50 percent chance of sampling a random nobs, otherwise use actual nobs as goal
                if np.random.uniform() < 0.5:
                    goal = random_goal(buf)
                    # The random goal may happen to be this very nobs.
                    if (nobs == goal).all():
                        rew, done = 1, True
                    else:
                        rew, done = 0, False
                else:
                    goal = b.nobs
                    rew, done = 1, True
            else:  # dense reward
                # 50 percent chance of keeping the original goal, otherwise relabel goal with nobs
                if np.random.uniform() < 0.5:
                    rew = b.rew / cfg['reward_magn']
                    done = False
                else:
                    goal = b.nobs
                    rew = 0.0
                    done = True

        # Drawn unconditionally so the RNG stream does not depend on the branch taken.
        angle = np.random.uniform(-5, 5)
        shift = (np.random.uniform(-5, 5), np.random.uniform(-5, 5))

        if np.random.uniform() < 0.9:
            obs = _augment(obs, angle, shift)
            nobs = _augment(nobs, angle, shift)
            if goal is not None:
                goal = _augment(goal, angle, shift)
            if edgemask is not None:
                edgemask = _augment(edgemask, angle, shift)

        if cfg['index_buffer'] == True:
            exp = ExperienceIndex(obs, goal, act, rew, nobs, done, b.idx, edgemask)
        else:
            exp = Experience(obs, goal, act, rew, nobs, done)
        her_batch.append(exp)
    return her_batch


def _nan_debug_dump(b, maps, heatmap, what):
    """Show obs/nobs/goal of experience `b`, save `maps` and `heatmap` to the run dir, then raise.

    Raises:
        ValueError: always, with message "<what> is nan".
    """
    fig, ax = plt.subplots(1, 3)
    for axis, title, img in zip(ax, ("obs", "nobs", "goal"), (b.obs, b.nobs, b.goal)):
        axis.set_title(title)
        axis.imshow(img)
    out_dir = "{}/output/{}".format(cfg[hostname]['out_path'], model_name)
    torch.save(maps, "{}/maps.dmp".format(out_dir))
    # Save the online heatmap itself (not `maps` a second time).
    torch.save(heatmap, "{}/heatmap.dmp".format(out_dir))
    plt.show()
    raise ValueError('{} is nan'.format(what))


def _background_losses(obs_t, heatmap, arm_transforms):
    """Smooth-L1 losses pulling Q toward 0 at one random background pixel per arm.

    Args:
        obs_t: 3xHxW observation tensor (obs / 255).
        heatmap: per-arm Q maps, heatmap[arm][row, col], in the resized/padded frame
            produced by arm_transforms[arm] (the frame the pick indices use).
        arm_transforms: one resize+pad transform per arm.

    Background = pixels whose first channel is 0 after the arm's transform
    (the zero padding included). An arm with no background pixel adds no loss.
    """
    losses = []
    for arm, tf in enumerate(arm_transforms):
        bg_pixels = torch.nonzero(tf(obs_t)[0] == 0)
        if bg_pixels.shape[0] == 0:
            continue
        row, col = bg_pixels[torch.randint(bg_pixels.shape[0], (1,)).item()]
        bg_q = heatmap[arm][row, col]
        losses.append(F.smooth_l1_loss(bg_q, torch.zeros_like(bg_q)))
    return losses


def train_on_buf(buf):
    """Compute the mean TD loss of the online Q-network on one sampled minibatch.

    For each experience b (after sample_batch, replace_goals, and optional domain
    randomization / Perlin noise):
      1. Target: evaluate tnet on (nobs, goal) for every (rotation, scale1, scale2),
         map each heatmap back to the WxW image frame, keep only cloth pixels of
         nobs, and take Qn = max over variants of (max arm-1 value + max arm-2 value) / 2.
      2. y = rew if done else rew + GAMMA * Qn.
      3. Online: evaluate qnet on obs at the stored action's rotation/scales and read
         q = mean of the two arms' values at the stored pick pixels.
      4. Smooth-L1(q, y); with probability cfg['samplebg'] background losses are added.

    Returns:
        0-dim tensor: mean of all collected losses (gradients flow into qnet).

    Raises:
        ValueError: if Qn or q is NaN (after dumping debug data).
        RuntimeError: if every sample in the batch produced a NaN loss.
    """
    batch = sample_batch(buf)
    batch = replace_goals(batch, buf)

    if DR:
        batch = apply_dr(batch) # Warning: does not include ExperienceIndex

    if OBS_MODE == 'depth' and cfg['perlin_noise']:
        batch = apply_perlin(batch, index_buffer=cfg['index_buffer'])

    n_maps = num_rotations * num_scales * num_scales
    losses = []
    for b in batch:
        in_edgemask = b.edgemask if cfg['inputedge'] else None
        out_hw = (b.nobs.shape[0], b.nobs.shape[1])

        ## Target value Qn from the target network on (nobs, goal)
        heatmaps, angles_scale_idxs = get_heatmaps(b.nobs/255.0, b.goal/255.0, target=True, inedgemask=in_edgemask)
        maps = torch.zeros([n_maps, 2, W, W], dtype=torch.float32, device='cuda')
        for i, (heatmap, (rot_idx, s1_idx, s2_idx)) in enumerate(zip(heatmaps, angles_scale_idxs)):
            theta = np.degrees(rot_idx*ROT_BIN)
            for arm, s_idx in ((0, s1_idx), (1, s2_idx)):
                scale = resize_scales[s_idx][0]
                padding = int(np.ceil((IM_W - int(b.nobs.shape[0]*scale))/2))
                # Undo get_heatmaps' resize+pad: crop away the padding, resize back to the image.
                to_image = torch.nn.Sequential(
                    transforms.CenterCrop(IM_W - padding*2),
                    transforms.Resize(out_hw, interpolation=2))
                maps[i, arm, :, :] = to_image(transforms.functional.rotate(torch.unsqueeze(heatmap[arm], 0), -theta))

        # The max over a' must range over actions valid in the NEXT state, i.e. cloth pixels of nobs.
        # Compare with 0 so color-mode uint8 masks become boolean before negation.
        nobs_valid = torch.from_numpy(np.asarray(get_mask(b.nobs)) != 0).to(maps.device)
        maps[:, :, ~nobs_valid] = torch.finfo(torch.float32).min

        flat1 = maps[:, 0, :, :].reshape(n_maps, -1)
        flat2 = maps[:, 1, :, :].reshape(n_maps, -1)
        Qn = torch.max(flat1.max(dim=1).values + flat2.max(dim=1).values)/2

        # Branch instead of multiplying by (not done): 0 * (-inf) would give NaN.
        y = b.rew if b.done else b.rew + GAMMA*Qn

        ## Online Q(obs, goal, act)
        theta = np.degrees(b.act[0]*ROT_BIN)
        scale1 = resize_scales[b.act[1]][0]
        scale2 = resize_scales[b.act[2]][0]

        obs_t = torch.from_numpy(b.obs/255.0).float().cuda().permute(2, 0, 1)
        goal_t = torch.from_numpy(b.goal/255.0).float().cuda().permute(2, 0, 1)
        size1 = (int(b.obs.shape[0]*scale1), int(b.obs.shape[1]*scale1))
        size2 = (int(b.obs.shape[0]*scale2), int(b.obs.shape[1]*scale2))
        padding1 = int(np.ceil((IM_W - size1[0])/2))
        padding2 = int(np.ceil((IM_W - size2[0])/2))
        transform1 = torch.nn.Sequential(
            transforms.Resize(size1, interpolation=2),
            transforms.Pad(padding1, fill=0, padding_mode='constant'))
        transform2 = torch.nn.Sequential(
            transforms.Resize(size2, interpolation=2),
            transforms.Pad(padding2, fill=0, padding_mode='constant'))

        obs1 = transforms.functional.rotate(transform1(obs_t), theta).unsqueeze(0)
        goal1 = transforms.functional.rotate(transform1(goal_t), theta).unsqueeze(0)
        obs2 = transforms.functional.rotate(transform2(obs_t), theta).unsqueeze(0)
        goal2 = transforms.functional.rotate(transform2(goal_t), theta).unsqueeze(0)
        # NOTE(review): get_heatmaps calls the network as (imgs1, imgs2, goals1, goals2), while
        # this call passes (obs1, goal1, obs2, goal2). One of the two orders is likely wrong;
        # check QNetBimanual.forward.
        if in_edgemask is not None:
            edgemasks = np.stack([in_edgemask, in_edgemask, in_edgemask], axis=2)
            edgemasks_t = torch.from_numpy(edgemasks).float().cuda().permute(2, 0, 1)
            # NOTE(review): the arm-1 transform is assumed for the single edge-mask tensor; confirm.
            edgemasks_tf = transforms.functional.rotate(transform1(edgemasks_t), theta).unsqueeze(0)
            heatmap_unrot = qnet(obs1, goal1, obs2, goal2, inedgemasks=edgemasks_tf)
        else:
            heatmap_unrot = qnet(obs1, goal1, obs2, goal2,)
        heatmap = transforms.functional.rotate(heatmap_unrot, -theta).squeeze()

        # Locate the stored pick pixels inside each arm's resized/padded frame.
        mask1 = torch.zeros((W, W), dtype=torch.uint8, device='cuda')
        mask1[b.act[3], b.act[4]] = 255
        mask1_rot = transform1(mask1.unsqueeze(0)).squeeze()
        mask1_rot[mask1_rot != 0] = 255

        mask2 = torch.zeros((W, W), dtype=torch.uint8, device='cuda')
        mask2[b.act[5], b.act[6]] = 255
        mask2_rot = transform2(mask2.unsqueeze(0)).squeeze()
        mask2_rot[mask2_rot != 0] = 255

        ind1 = [torch.floor(torch.argmax(mask1_rot) / mask1_rot.size()[1]).long(),
                torch.argmax(mask1_rot) % mask1_rot.size()[1]]
        ind2 = [torch.floor(torch.argmax(mask2_rot) / mask2_rot.size()[1]).long(),
                torch.argmax(mask2_rot) % mask2_rot.size()[1]]

        q = (heatmap[0][ind1[0], ind1[1]] + heatmap[1][ind2[0], ind2[1]])/2
        if torch.isnan(Qn):
            _nan_debug_dump(b, maps, heatmap, 'Qn')
        if torch.isnan(q):
            _nan_debug_dump(b, maps, heatmap, 'q')

        # Same 0-dim shape as q, so smooth_l1_loss does not broadcast (and warn).
        target = torch.as_tensor(y, dtype=q.dtype, device=q.device)
        loss = F.smooth_l1_loss(q, target)
        if torch.isnan(loss):
            continue
        losses.append(loss)

        if np.random.sample() < cfg['samplebg']:
            losses.extend(_background_losses(obs_t, heatmap, (transform1, transform2)))

    if not losses:
        raise RuntimeError("train_on_buf: every sample in the batch produced a NaN loss")
    return torch.mean(torch.stack(losses))


def get_heatmaps(img, goal, target=False, inedgemask=None):
    """Evaluate the Q-network on every (rotation, scale1, scale2) variant of (img, goal).

    Args:
        img, goal: HxWx3 numpy arrays (callers in this file pass images already divided by 255).
        target: use the target network tnet instead of qnet.
        inedgemask: optional HxW edge mask, replicated to 3 channels and
            transformed exactly like the images.

    Returns:
        (heatmaps, angles_scale_idxs): the network output for the stacked batch of
        num_rotations*num_scales*num_scales inputs, plus the matching
        (rotate_idx, scale_idx1, scale_idx2) tuples in batch order
        (rotation-major, then scale_idx1, then scale_idx2).

    Each variant resizes an input by resize_scales[k][0], zero-pads it toward the
    IM_W canvas, then rotates it by rotate_idx*ROT_BIN. Runs under torch.no_grad().
    """
    with torch.no_grad():
        n = num_rotations*num_scales*num_scales

        def _canvas():
            return torch.zeros([n, 3, IM_W, IM_W], dtype=torch.float32, device='cuda')

        imgs1, goals1, imgs2, goals2 = _canvas(), _canvas(), _canvas(), _canvas()
        img_t = torch.from_numpy(img).float().cuda().permute(2, 0, 1)
        goal_t = torch.from_numpy(goal).float().cuda().permute(2, 0, 1)

        em_t = None
        ems1 = ems2 = None
        if inedgemask is not None:
            ems1, ems2 = _canvas(), _canvas()
            em_np = np.stack([inedgemask, inedgemask, inedgemask], axis=2)
            em_t = torch.from_numpy(em_np).float().cuda().permute(2, 0, 1)

        # Resize + pad each input once per scale; the result does not depend on rotation.
        scaled = []
        for scale_idx in range(num_scales):
            scale = resize_scales[scale_idx][0]
            size = (int(img.shape[0]*scale), int(img.shape[1]*scale))
            padding = int(np.ceil((IM_W - size[0])/2))
            tf = torch.nn.Sequential(
                transforms.Resize(size, interpolation=2),
                transforms.Pad(padding, fill=0, padding_mode='constant'))
            scaled.append([tf(t) if t is not None else None for t in (img_t, goal_t, em_t)])

        angles_scale_idxs = []
        i = 0
        for rotate_idx in range(num_rotations):
            angle = np.degrees(rotate_idx*ROT_BIN)
            # Rotate each scaled input once per rotation, then reuse it for every scale pair.
            rotated = [[transforms.functional.rotate(t, angle) if t is not None else None for t in per_scale]
                       for per_scale in scaled]
            for scale_idx1 in range(num_scales):
                for scale_idx2 in range(num_scales):
                    img1, goal1, em1 = rotated[scale_idx1]
                    img2, goal2, em2 = rotated[scale_idx2]
                    imgs1[i] = img1
                    goals1[i] = goal1
                    imgs2[i] = img2
                    goals2[i] = goal2
                    if em_t is not None:
                        ems1[i] = em1
                        ems2[i] = em2
                    angles_scale_idxs.append((rotate_idx, scale_idx1, scale_idx2))
                    i += 1

        net = tnet if target else qnet
        # NOTE(review): the network takes a single edge-mask tensor; the arm-1 (scale_idx1)
        # version is passed, mirroring train_on_buf. Confirm against QNetBimanual.forward.
        # That function also passes its inputs as (obs1, goal1, obs2, goal2) rather than this order.
        heatmaps = net(imgs1, imgs2, goals1, goals2, inedgemasks=ems1)
        return heatmaps, angles_scale_idxs


def select_action(obs, goal, rgb, viz=False, act_gt=None, random=False, inedgemask=None, outedgemask=None):
    """Choose a bimanual fold action for `obs` toward `goal` and draw it on `rgb`.

    Args:
        obs, goal: HxWx3 images in 0..255 (divided by 255 before the network).
        rgb: image the chosen action is drawn on; a uint8 copy is made, so rgb is not modified.
        viz: also render per-(rotation, scale) heatmap overviews.
        act_gt: unused.
        random: draw two pick pixels from the cloth mask and a random rotation and
            scales instead of taking the greedy argmax.
        inedgemask: optional edge mask fed to the network after division by 255.
        outedgemask: optional mask of valid pick pixels; when given it replaces the
            cloth mask of obs for masking the heatmaps.

    Returns:
        (action, act_type, disp_img, [viz_img1, viz_img2]) where action is
        [rotate_idx, scale_idx1, scale_idx2, u1, v1, u2, v2] (u = row, v = column).
        The viz images are None unless `viz`.
    """
    mask = get_mask(obs)
    with torch.no_grad():
        qnet.eval()
        inedgemask_int = inedgemask / 255.0 if inedgemask is not None else None
        heatmaps, angles_scale_idxs = get_heatmaps(obs/255.0, goal/255.0, inedgemask=inedgemask_int)

        valid = mask if outedgemask is None else outedgemask
        maps = []
        vmaps = []  # unmasked copies, for visualization
        for heatmap, (rot_idx, s1_idx, s2_idx) in zip(heatmaps, angles_scale_idxs):
            m1 = postprocess(heatmap[0, :, :], rot_idx*ROT_BIN, resize_scales[s1_idx][0])
            m2 = postprocess(heatmap[1, :, :], rot_idx*ROT_BIN, resize_scales[s2_idx][0])
            m = [m1, m2]
            vmaps.append(copy.deepcopy(m))
            # Large negative value so invalid pixels never win the argmax.
            m[0][valid == 0] = -10000000
            m[1][valid == 0] = -10000000
            maps.append(m)
        maps = np.array(maps)

        if random:
            uv = td.random_sample_from_masked_image(mask, 2)
            u1, v1 = uv[0][0], uv[1][0]
            u2, v2 = uv[0][1], uv[1][1]
            action = [np.random.randint(num_rotations), np.random.randint(num_scales), np.random.randint(num_scales), u1, v1, u2, v2]
        else:
            # Best (rotation, scale pair) maximizes the sum of the two arms' peak values.
            n, _, r, c = maps.shape
            ch1 = maps[:, 0, :, :].reshape(n, r*c)
            ch2 = maps[:, 1, :, :].reshape(n, r*c)
            best_ind = np.argmax(np.amax(ch1, axis=1) + np.amax(ch2, axis=1))
            rotate_max_idx, scale1_max_idx, scale2_max_idx = angles_scale_idxs[best_ind]
            u1, v1 = divmod(int(np.argmax(ch1[best_ind])), c)
            u2, v2 = divmod(int(np.argmax(ch2[best_ind])), c)
            action = [rotate_max_idx, scale1_max_idx, scale2_max_idx, u1, v1, u2, v2]

        viz_img1, viz_img2 = None, None
        if viz:
            # render_heatmaps multiplies its color argument by 255, so pass obs in [0, 1].
            color = obs/255.0
            viz_img1, viz_img2 = render_heatmaps(color, np.array(vmaps), action)
            if outedgemask is not None:
                maps1 = deepcopy(maps)  # for viz
                maps1[:, :, outedgemask == 0] = 0
                masked1, masked2 = render_heatmaps(color, maps1, action)
                viz_img1 = np.concatenate((viz_img1, masked1), axis=0)
                viz_img2 = np.concatenate((viz_img2, masked2), axis=0)

        MID_FOLD_DIST = 78
        theta = np.deg2rad((360.0/num_rotations) * action[0])
        fold_dist1 = MID_FOLD_DIST * resize_scales[action[1]][1]
        fold_dist2 = MID_FOLD_DIST * resize_scales[action[2]][1]
        # cv2 drawing takes (x, y) = (column, row) as plain Python ints.
        x1, y1 = int(action[4]), int(action[3])
        x2, y2 = int(action[6]), int(action[5])

        arrow1 = (int(x1 + fold_dist1*np.cos(theta)), int(y1 + fold_dist1*np.sin(theta)))
        arrow2 = (int(x2 + fold_dist2*np.cos(theta)), int(y2 + fold_dist2*np.sin(theta)))

        disp_img = rgb.astype(np.uint8)
        color1 = (0,0,200)
        color2 = (0,200,0)
        cv2.circle(disp_img, (x1, y1), 6, color1, 2)
        cv2.arrowedLine(disp_img, (x1, y1), arrow1, color1, 2, tipLength=0.1)
        cv2.circle(disp_img, (x2, y2), 6, color2, 2)
        cv2.arrowedLine(disp_img, (x2, y2), arrow2, color2, 2, tipLength=0.1)

        # NOTE(review): reported as "greedy" even when random=True; kept for caller compatibility.
        act_type = "greedy"
    return action, act_type, disp_img, [viz_img1, viz_img2]

def _blend_heatmap(hm, color):
    """Jet-colorize a heatmap with values in [0, 1] and blend it 50/50 with color*255 (uint8)."""
    jet = cv2.applyColorMap((hm*255).astype(np.uint8), cv2.COLORMAP_JET)
    return (0.5*jet + 0.5*(color*255)).astype(np.uint8)


def render_heatmaps(color, heatmaps, max_idx):
    """Tile per-(rotation, scale1, scale2) heatmaps into two overview images, one per arm.

    Args:
        color: background image blended under every map; it is multiplied by 255,
            so it is expected in [0, 1].
        heatmaps: array-like of shape (num_rotations*num_scales*num_scales, 2, H, W),
            ordered rotation-major, then scale1, then scale2. It is min-max
            normalized as a whole.
        max_idx: the chosen action [rot, s1, s2, u1, v1, u2, v2]. Its pick points
            are circled on the maps whose rotation and arm scale match.

    Returns:
        (img1, img2): uint8 images, one per arm. Each rotation is a column of
        100x100 tiles (num_scales*num_scales tiles tall); columns are ordered by rotation.
    """
    heatmaps = np.asarray(heatmaps)
    lo, hi = np.min(heatmaps), np.max(heatmaps)
    # Min-max normalize; a constant stack becomes all zeros instead of NaN.
    heatmaps = (heatmaps - lo) / (hi - lo) if hi > lo else np.zeros(heatmaps.shape)

    line_start = np.array([W/2, W/2])
    start_pt = (int(line_start[1]), int(line_start[0]))
    cols1, cols2 = [], []
    i = 0
    for r in range(num_rotations):
        theta = r*ROT_BIN
        unit = np.array([np.sin(theta), np.cos(theta)])
        tiles1, tiles2 = [], []
        for s1 in range(num_scales):
            for s2 in range(num_scales):
                fold_dist1 = 0.1*((s1)+1.0) # because of flip
                fold_dist2 = 0.1*((s2)+1.0) # because of flip
                # Plain int (not int8) so long arrows on larger images cannot overflow.
                line_end1 = line_start + (W*(fold_dist1*unit)).astype(int)
                line_end2 = line_start + (W*(fold_dist2*unit)).astype(int)

                m1 = _blend_heatmap(heatmaps[i, 0, :, :], color)
                m1 = cv2.arrowedLine(m1, start_pt, (int(line_end1[1]), int(line_end1[0])), (100,100,100), 3, tipLength=0.3)
                if r == max_idx[0] and s1 == max_idx[1]:
                    m1 = cv2.circle(m1, (int(max_idx[4]), int(max_idx[3])), 7, (255,255,255), 2)
                tiles1.append(cv2.resize(m1, (100, 100)))

                m2 = _blend_heatmap(heatmaps[i, 1, :, :], color)
                m2 = cv2.arrowedLine(m2, start_pt, (int(line_end2[1]), int(line_end2[0])), (100,100,100), 3, tipLength=0.3)
                # Arm-2 pick point goes on the arm-2 map, before it is downsized.
                if r == max_idx[0] and s2 == max_idx[2]:
                    m2 = cv2.circle(m2, (int(max_idx[6]), int(max_idx[5])), 7, (255,255,255), 2)
                tiles2.append(cv2.resize(m2, (100, 100)))

                i += 1
        cols1.append(np.concatenate(tiles1, axis=0))
        cols2.append(np.concatenate(tiles2, axis=0))
    return np.concatenate(cols1, axis=1), np.concatenate(cols2, axis=1)


def prepare_buffer(buf):
    """Export an index-based offline buffer into the online_data layout read by load_buffer_data.

    For each experience b, b.obs and b.nobs are integer indices into buffer_folder:
      - 'color' mode: read images/%06d_rgb_before.png and images/%06d_rgb_after.png;
      - 'depth' mode: read rendered_images/%06d_depth_{before,after}.npy, multiply by
        255, cast to uint8 and replicate to 3 channels.
    Both are resized to WxW and written as obs_{OBS_MODE}_{b.obs:06d}.png and
    nobs_{OBS_MODE}_{b.nobs:06d}.png. If cfg['inputedge'], an obs action mask is
    also computed and written as obs_actmask_{b.obs:06d}.png.

    Returns:
        The next free buffer index (max obs index + 1, which must equal len(buf)).

    Raises:
        ValueError: unsupported OBS_MODE, an all-zero depth image, or obs indices
            that are not exactly 0..len(buf)-1.
        FileNotFoundError: a color image could not be read.
    """
    out_dir = f'{cfg[hostname]["out_path"]}/output/{model_name}/online_data'
    os.makedirs(out_dir, exist_ok=True)

    def _read_color(ind, when):
        # Read one raw RGB frame and resize it to the WxW training resolution.
        path = os.path.join(buffer_folder, "images/%06d_rgb_%s.png" % (ind, when))
        im = cv2.imread(path)
        if im is None:
            raise FileNotFoundError(f"could not read {path}")
        return cv2.resize(im, (W, W))

    def _read_depth(ind, when, label):
        # Load one depth map, quantize it to uint8 and replicate it to 3 channels.
        depth = np.load(os.path.join(buffer_folder, "rendered_images/%06d_depth_%s.npy" % (ind, when)))
        depth = (depth*255).astype(np.uint8)
        if np.sum(depth) == 0:
            raise ValueError(f"zero sum found in {label}")
        return cv2.resize(np.dstack([depth, depth, depth]), (W, W))

    max_ind = -1  # so an empty buffer yields next index 0
    count = 0
    for b in buf:
        obs_ind = b.obs
        nobs_ind = b.nobs

        if OBS_MODE == 'color':
            obs = _read_color(obs_ind, 'before')
            nobs = _read_color(nobs_ind, 'after')
        elif OBS_MODE == 'depth':
            obs = _read_depth(obs_ind, 'before', 'obs')
            nobs = _read_depth(nobs_ind, 'after', 'nobs')
        else:
            raise ValueError(f"prepare_buffer: unsupported OBS_MODE {OBS_MODE!r}")

        if cfg['inputedge']:
            # Compute the edge mask just once and store it
            coords_o = np.load(f'{buffer_folder}/coords/{str(obs_ind).zfill(6)}_coords_before.npy') # N x 3
            rgb_o = cv2.imread(f'{buffer_folder}/images/{str(obs_ind).zfill(6)}_rgb_before.png') # 720 x 720 x 3
            depth_o = np.load(f'{buffer_folder}/rendered_images/{str(obs_ind).zfill(6)}_depth_before.npy') # 200 x 200
            mask_o = cv2.imread(f'{buffer_folder}/image_masks/{str(obs_ind).zfill(6)}_mask_before.png')[:, :, 0] # 200 x 200 x 1

            # NOTE(review): `em` is not imported in the visible part of this file; confirm it is in scope.
            actmask_o, _, _ = em.get_act_mask(env, coords_o, rgb_o, depth_o, mask_o) # only need obs act mask

            cv2.imwrite(f'{out_dir}/obs_actmask_{str(obs_ind).zfill(6)}.png', actmask_o)

        cv2.imwrite(f'{out_dir}/obs_{OBS_MODE}_{str(obs_ind).zfill(6)}.png', obs)
        # Saved under nobs_ind because load_buffer_data(b.nobs, isnobs=True) reads that name.
        cv2.imwrite(f'{out_dir}/nobs_{OBS_MODE}_{str(nobs_ind).zfill(6)}.png', nobs)
        max_ind = max(max_ind, obs_ind)
        count += 1

    if max_ind + 1 != count:
        raise ValueError(f"prepare_buffer: obs indices must be 0..{count - 1}, found max index {max_ind}")
    return max_ind + 1


def add_buffer_data(buf_ind, obs, nobs, inedgemask=None, outedgemask=None):
    """Write one online transition's images under index `buf_ind`; return `buf_ind + 1`.

    obs and nobs are converted RGB->BGR and written as obs_/nobs_{OBS_MODE}_{buf_ind:06d}.png.
    An optional inedgemask is written as-is to obs_actmask_*; an optional outedgemask
    is cast to uint8 and written to obs_outmask_*.
    """
    obs_bgr = cv2.cvtColor(obs, cv2.COLOR_RGB2BGR)
    nobs_bgr = cv2.cvtColor(nobs, cv2.COLOR_RGB2BGR)
    tag = str(buf_ind).zfill(6)
    cv2.imwrite(f'{cfg[hostname]["out_path"]}/output/{model_name}/online_data/obs_{OBS_MODE}_{tag}.png', obs_bgr)
    cv2.imwrite(f'{cfg[hostname]["out_path"]}/output/{model_name}/online_data/nobs_{OBS_MODE}_{tag}.png', nobs_bgr)
    if inedgemask is not None:
        cv2.imwrite(f'{cfg[hostname]["out_path"]}/output/{model_name}/online_data/obs_actmask_{tag}.png', inedgemask)
    if outedgemask is not None:
        cv2.imwrite(f'{cfg[hostname]["out_path"]}/output/{model_name}/online_data/obs_outmask_{tag}.png', outedgemask.astype(np.uint8))
    return buf_ind + 1


def load_buffer_data(ind, isnobs=False):
    """Load a stored obs (or nobs) image by buffer index as RGB with background zeroed.

    Args:
        ind: buffer index used in the file name (zero-padded to 6 digits).
        isnobs: load the next-observation image instead of the observation.

    Returns:
        HxWx3 RGB array in which pixels outside get_mask(im) are set to 0.

    Raises:
        FileNotFoundError: if the image file is missing or unreadable
            (cv2.imread signals this by returning None).
    """
    pre = 'nobs' if isnobs else 'obs'
    path = f'{cfg[hostname]["out_path"]}/output/{model_name}/online_data/{pre}_{OBS_MODE}_{str(ind).zfill(6)}.png'
    im = cv2.imread(path)
    if im is None:
        raise FileNotFoundError(f"could not read buffer image {path}")
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im[get_mask(im) == 0] = 0
    return im


def random_goal(buf):
    """Return the nobs of one uniformly drawn experience, for use as a goal.

    For an index buffer the stored nobs is an index, so the image is loaded
    with load_buffer_data.
    """
    (picked,) = random.sample(buf, 1)
    candidate = picked.nobs
    if cfg['index_buffer'] == True:
        return load_buffer_data(candidate, isnobs=True)
    return candidate


def get_goal_im(fn):
    """Load goal image `fn` for the current OBS_MODE plus its particle positions.

    Reads ../goals/{fn}.png ('color'), ../goals/{fn}_depth.png ('depth') or
    ../goals/{fn}_desc.png ('desc'), converts it to RGB and zeroes pixels outside
    get_mask. Particle positions come from the first three columns of
    ../goals/particles/{fn}.npy.

    Returns:
        (im, pos)

    Raises:
        ValueError: OBS_MODE has no goal-image variant.
        FileNotFoundError: the goal image could not be read.
    """
    suffixes = {'color': '', 'depth': '_depth', 'desc': '_desc'}
    if OBS_MODE not in suffixes:
        raise ValueError(f"get_goal_im: unsupported OBS_MODE {OBS_MODE!r}")
    fpath = f'../goals/{fn}{suffixes[OBS_MODE]}.png'
    im = cv2.imread(fpath)
    if im is None:
        raise FileNotFoundError(f"could not read goal image {fpath}")
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im[get_mask(im) == 0] = 0
    pos = np.load(f'../goals/particles/{fn}.npy')[:, :3]
    return im, pos

#------------------------------------------------------------------------------
# Setup
#------------------------------------------------------------------------------

# Seed every RNG used in this script (numpy, stdlib random, torch CPU and CUDA).
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Action discretization: num_rotations fold directions x num_scales^2 (one scale per arm).
num_rotations = 8
num_scales = 3
# Each entry: (input image resize factor, fold-distance multiplier used when drawing actions).
resize_scales = [(2,0.5),(1,1),(0.5,2)]
MID_FOLD_DIST = 78
# Working image resolution (W x W).
W = 200

# Angle between adjacent rotation bins, in radians.
ROT_BIN = (2*np.pi/num_rotations)

# Padding
# The canvas must contain the diagonal of the largest resized image so rotation does not clip it.
diag_len = float(W)*np.sqrt(2)*resize_scales[0][0] # because of flip
PAD = int(np.ceil((diag_len - W)/2))
IM_W = W+2*PAD

# Simulation environment: t-shirt for ENV_TYPE 'tshirt', otherwise the default (towel) env.
if ENV_TYPE == 'tshirt':
    env = BimanualTshirtEnv(use_depth=True,
            use_cached_states=False,
            horizon=5,
            use_desc=False,
            action_repeat=1,
            cam_height=cfg['cam_height'],
            headless=cfg['headless'])
else:
    env = BimanualEnv(use_depth=True,
            use_cached_states=False,
            horizon=5,
            use_desc=False,
            action_repeat=1,
            cam_height=cfg['cam_height'],
            headless=cfg['headless'])

# Run name encoding the key hyperparameters; also the output directory name.
model_name = f'biman_{ENV_TYPE}_{cfg["obs_mode"]}' + \
             f'_itr{cfg["episodes"]}' + \
             f'_lr{cfg["learn_rate"]}' + \
             f'_gamma{cfg["gamma"]}' + \
             f'_idxbuf{1 if cfg["index_buffer"] else 0}' + \
             f'{"_actmaskprob"+str(cfg["actmaskprob"])+"_truecratio"+str(cfg["truecratio"]) if cfg["maskoutput"] else ""}' + \
             f'_seed{SEED}' + \
             f'{"_online" if cfg["online"] else "_offline"}' + \
             f'{cfg["model_suffix"]}'
# In debug mode an existing run directory is deleted; otherwise os.mkdir raises
# FileExistsError if the run already exists, which protects previous results.
if cfg['debug'] and os.path.exists(f'{cfg[hostname]["out_path"]}/output/{model_name}'):
    import shutil
    shutil.rmtree(f'{cfg[hostname]["out_path"]}/output/{model_name}')
os.mkdir(f'{cfg[hostname]["out_path"]}/output/{model_name}')
os.mkdir(f'{cfg[hostname]["out_path"]}/output/{model_name}/images')
os.mkdir(f'{cfg[hostname]["out_path"]}/output/{model_name}/evals')
os.mkdir(f'{cfg[hostname]["out_path"]}/output/{model_name}/unfold_evals')
os.mkdir(f'{cfg[hostname]["out_path"]}/output/{model_name}/weights')
# Save the effective (merged) config next to the run's outputs.
with open(f'{cfg[hostname]["out_path"]}/output/{model_name}/config.yaml', 'w') as f:
    yaml.dump(cfg, f)

# Load the offline replay buffer, truncated to the first cfg['max_buf'] transitions.
# NOTE(review): torch.load unpickles the file; only load buffers from trusted sources.
print("Loading replay buffer")
if cfg['index_buffer'] == True:
    # Index buffer: experiences hold image indices; prepare_buffer writes the images to
    # online_data and returns the next free index for online additions.
    buf = torch.load(os.path.join(buffer_folder,'{}_idx.buf'.format(RUN_NAME)))
    buf = buf[:cfg['max_buf']]
    buf_ind = prepare_buffer(buf)
else:
    # Image buffer: zero the background (outside get_mask) of obs, nobs and goal, working on copies.
    # NOTE: the loop variables (i, b, obs, nobs, goal, ...) remain as module-level globals.
    buf_nomask  = torch.load(f'{buffer_folder}/{RUN_NAME}_{OBS_MODE}.buf')
    buf_nomask  = buf_nomask[:cfg['max_buf']]
    buf = []
    for i, b in enumerate(buf_nomask):
        obs = deepcopy(b.obs)
        nobs = deepcopy(b.nobs)
        goal = deepcopy(b.goal)

        obs_mask = get_mask(obs) != 0
        obs[obs_mask == False, :] = 0

        nobs_mask = get_mask(nobs) != 0
        nobs[nobs_mask == False, :] = 0

        if goal is not None:
            goal_mask = get_mask(goal) != 0
            goal[goal_mask == False, :] = 0

        new_b = Experience(obs, goal, b.act, b.rew, nobs, b.done)
        buf.append(new_b)
print("done")

print("load QNet")
# NOTE(review): presumably 4 stacked 3-channel inputs (obs/goal for each arm); confirm against QNetBimanual.
in_channels = 12
qnet = QNetBimanual(in_channels,IM_W).cuda()
# Target network used by get_heatmaps(target=True); starts as an exact copy of qnet.
tnet = copy.deepcopy(qnet)
print("load optimizer")
opt = optim.Adam(qnet.parameters(), lr=cfg['learn_rate'])
# NOTE(review): train_on_buf calls F.smooth_l1_loss directly; this object looks unused.
loss_function = nn.SmoothL1Loss(reduction='none')
print("done")

updating_epsilon = EPSILON

# Global step counter read by train() (compared against EXP_STEP).
step = 0

#### Corner masking

def get_corner_particles():
    state = env.reset(given_goal=None, given_goal_pos=None)
    uv = td.particle_uv_pos(env.camera_params,None)
    uv[:,[1,0]]=uv[:,[0,1]]
    uv = (uv/719) * 199

    # corner 1
    dists = np.linalg.norm((uv - np.array([25,25])),axis=1)
    c1 = dists.argmin()

    # corner 2
    dists = np.linalg.norm((uv - np.array([25,175])),axis=1)
    c2 = dists.argmin()

    # corner 3
    dists = np.linalg.norm((uv - np.array([175,175])),axis=1)
    c3 = dists.argmin()

    # corner 4
    dists = np.linalg.norm((uv - np.array([175,25])),axis=1)
    c4 = dists.argmin()

    return c1,c2,c3,c4

corners = get_corner_particles()

def get_harris(mask, thresh=0.2, block_size=5, ksize=5, k=0.01, dilate_size=7):
    """Harris corner detector on a binary mask, returning a dilated 0/1 corner map.

    Params
    ------
        - mask: 2-D image of 0.0 and 1.0. It is cast to np.float32 because
          cv2.cornerHarris only accepts 8-bit or float32 single-channel input.
        - thresh: responses below thresh * (max response) are discarded
        - block_size: neighbourhood size considered for corner detection (blockSize)
        - ksize: aperture of the Sobel derivative (OpenCV accepts 1, 3, 5 or 7)
        - k: Harris detector free parameter
        - dilate_size: side of the square kernel used to grow each detection
    Returns
    -------
        - harris: np.float32 array with the same shape as mask. It is 1.0 on the
          dilated corner regions that lie on the mask (mask != 0), else 0.0.
    """
    # https://docs.opencv.org/master/dd/d1a/group__imgproc__feature.html#gac1fc3598018010880e370e2f709b4345
    mask_f32 = np.asarray(mask, dtype=np.float32)
    harris = cv2.cornerHarris(mask_f32, blockSize=block_size, ksize=ksize, k=k)
    harris[harris < thresh * harris.max()] = 0.0  # filter small values
    harris[harris != 0] = 1.0                     # binarize surviving responses
    kernel = np.ones((dilate_size, dilate_size), np.uint8)
    harris_dilated = cv2.dilate(harris, kernel=kernel)
    harris_dilated[mask_f32 == 0] = 0  # keep only corner pixels on the cloth
    return harris_dilated

def get_particle_uv(idx):
    """Return the (u, v) image position of particle ``idx``.

    All particles are projected with ``td.particle_uv_pos(env.camera_params, None)``.
    The two coordinate columns are swapped in place and the result is rescaled
    by 199/719 (the same transform used by get_corner_particles). The row for
    ``idx`` is returned.
    """
    coords = td.particle_uv_pos(env.camera_params, None)
    coords[:, [1, 0]] = coords[:, [0, 1]]  # swap the two coordinate columns
    scaled = coords / 719 * 199
    u, v = scaled[idx]
    return u, v

def get_true_corner_mask(clothmask, r=4, corner_ids=None):
    """Rasterize disks around the tracked corner particles, restricted to the cloth.

    For each particle index, its current image position (u, v) is looked up
    with get_particle_uv. Every pixel of a 200x200 grid within distance ``r`` of
    that position is set to 1. In the grid, u indexes columns and v indexes rows.
    The result is multiplied by ``clothmask``, so only corner pixels on the
    cloth survive.

    An integer pixel grid replaces the original float-valued ``np.ogrid``. The
    old code needed shape truncation and a try/except that dropped into
    ``IPython.embed()``, which blocks non-interactive runs. Here the mask always
    has shape (200, 200) by construction.

    Args:
        clothmask: cloth mask multiplied into the result. It must broadcast
            against (200, 200); presumably a (200, 200) 0/1 array.
        r: disk radius in pixels.
        corner_ids: optional iterable of particle indices. Defaults to the
            module-level ``corners`` found by get_corner_particles.

    Returns:
        ``true_corners * clothmask``, where true_corners is a (200, 200) float
        array of 0/1 disk indicators.
    """
    if corner_ids is None:
        corner_ids = corners
    h, w = 200, 200
    true_corners = np.zeros((h, w))
    rows, cols = np.ogrid[:h, :w]  # rows: (h, 1), cols: (1, w)
    for c in corner_ids:
        u, v = get_particle_uv(c)
        disk = (cols - u) ** 2 + (rows - v) ** 2 <= r * r
        true_corners[disk] = 1
    return true_corners * clothmask

#------------------------------------------------------------------------------
# Run
#------------------------------------------------------------------------------

# Offline training loop. Each update calls train() once, soft-updates the target
# network, and every cfg['eval_interval'] updates saves weights, optimizer,
# loss/eval arrays, the buffer and a loss plot. Timing stats are written at the end.
# NOTE(review): eval_list / unfold_eval_list are saved and plotted here but not
# appended in this loop -- confirm where (if anywhere) they are populated.
print(len(buf))
loss_list = []
eval_list = []
unfold_eval_list = []

train_times = []
total_updates = cfg['episodes']+1
update = 0
start = time.time()
# All artifacts of this run share the run directory created during setup.
_run_dir = f'{cfg[hostname]["out_path"]}/output/{model_name}'

while update < total_updates:
    # train
    ti = time.time()
    print(len(buf))
    loss_metrics = train()
    time_diff = time.time()-ti
    train_times.append(time_diff)
    print("time: {}".format(time_diff))
    print("loss: {}".format(loss_metrics['loss']))
    loss_list.append(loss_metrics['loss'])

    # Soft (Polyak) target update t <- TAU*o + (1-TAU)*t, done in place rather
    # than allocating two cloned tensors per parameter every update.
    with torch.no_grad():
        for o, t in zip(qnet.parameters(), tnet.parameters()):
            t.mul_(1 - TAU).add_(o, alpha=TAU)

    # Periodic checkpoint (weights, optimizer, loss/eval arrays, buffer, plot).
    if cfg['eval_interval'] > 0 and update % cfg['eval_interval'] == 0:
        print("Update:",update)

        ckpt_suffix = f'{RUN_NAME}_rotations_{num_rotations}_gamma_{GAMMA}_update_{update}.pt'
        torch.save(qnet, f'{_run_dir}/weights/qnet_{ckpt_suffix}')
        torch.save(tnet, f'{_run_dir}/weights/tnet_{ckpt_suffix}')
        torch.save(opt, f'{_run_dir}/weights/opt_{ckpt_suffix}')
        np.save(f"{_run_dir}/loss.npy", loss_list)
        np.save(f"{_run_dir}/eval.npy", eval_list)
        torch.save(buf, f"{_run_dir}/online_data.buf")

        # Previously saved to a cwd-relative '../output/...' path, inconsistent
        # with every other artifact (and fatal if that directory was missing).
        plot_loss(np.array(loss_list), np.array(eval_list), eval_interval=cfg['eval_interval'], savepath=f'{_run_dir}/loss_eval.png', showfig=False)
    update += 1
total_time = time.time() - start
avg_iter = np.mean(np.array(train_times))
print("total time: {}".format(total_time))
print("average iter: {}".format(avg_iter))
with open(f"{_run_dir}/time.txt", 'w') as f:
    f.write(str(total_time) + '\n')
    f.write(str(avg_iter))

# Plot and save the final loss/eval curves into the run directory, the same
# place as every other artifact (previously a cwd-relative '../output/...').
# Best-effort: a plotting failure is printed rather than raised.
_run_dir = f'{cfg[hostname]["out_path"]}/output/{model_name}'
try:
    plot_loss(np.array(loss_list), np.array(eval_list), eval_interval=cfg['eval_interval'], savepath=f'{_run_dir}/loss_eval.png', showfig=False)
except Exception as e:
    print(e)