#!/usr/bin/env python
# coding: utf-8

# In[2]:
import sys
import torchvision.transforms as transforms
import os
import random, threading, time, copy
from collections import namedtuple, deque, OrderedDict
import numpy as np, cv2
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from scipy import ndimage
import imutils
import matplotlib.pyplot as plt
from copy import deepcopy
from softgym.envs.corl_baseline import GCFold
from softgym.envs.tshirt_base_env import TshirtBaselineFold
import softgym.envs.tshirt_descriptor as td
import pyflex
import yaml
import socket
import traceback
from models import QNet

def arg(tag, default):
    """Read a command-line override for ``tag`` and record it in ``HYPERS``.

    The command line is scanned for the literal token ``tag``; the token that
    follows it is converted to the type of ``default``.  When ``tag`` is not
    present, ``default`` is used unchanged.

    Args:
        tag: Option name exactly as it appears in ``sys.argv`` (no dashes are
            added or stripped).
        default: Fallback value; its type selects the conversion.

    Returns:
        The parsed (or default) value, which is also stored in ``HYPERS[tag]``.

    Raises:
        IndexError: If ``tag`` is the last token and therefore has no value.
        ValueError: If the value cannot be converted to ``type(default)``.
    """
    if tag not in sys.argv:
        HYPERS[tag] = default
        return default

    value_pos = sys.argv.index(tag) + 1
    if value_pos >= len(sys.argv):
        raise IndexError(f"command-line option '{tag}' expects a value after it")
    raw = sys.argv[value_pos]

    if isinstance(default, bool):
        # bool("False") is True, so the text has to be interpreted explicitly.
        # Anything that is not a recognised "false" spelling stays truthy,
        # which matches what plain bool(raw) did for those inputs.
        value = raw.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')
    else:
        value = type(default)(raw)

    HYPERS[tag] = value
    return value

def get_mask(img, obs_mode):
    """Compute a foreground mask for an observation image.

    Args:
        img: Image array indexed as ``[row, col, channel]`` with at least
            three channels.
        obs_mode: ``'depth'`` selects the "any of the first three channels is
            positive" rule; every other value selects HSV colour thresholding.

    Returns:
        For ``'depth'`` a boolean array of shape ``img.shape[:2]``; otherwise
        the mask produced by ``cv2.inRange`` after morphological clean-up.
    """
    if obs_mode != 'depth':
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        # Two hue bands (20-40 and 80-100) are accepted and merged.
        band_a = cv2.inRange(hsv, np.array([20, 50., 10.]), np.array([40, 255., 255.]))
        band_b = cv2.inRange(hsv, np.array([80, 50., 10.]), np.array([100, 255., 255.]))
        merged = cv2.bitwise_or(band_a, band_b)
        # Close with a 5x5 kernel, then open with a 2x2 kernel.
        closed = cv2.morphologyEx(merged, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))
        return cv2.morphologyEx(closed, cv2.MORPH_OPEN, np.ones((2, 2), np.uint8))

    positive = img > 0
    return positive[:, :, 0] | positive[:, :, 1] | positive[:, :, 2]

def postprocess(heatmap, theta, scale):
    """Map a network heatmap back to the un-rotated, un-padded, un-scaled frame.

    The heatmap tensor is rotated with ``rotate(heatmap, -theta)``, moved to
    the CPU as a numpy array, cropped with ``unpad`` and finally resized with
    ``unscale_img``.

    Args:
        heatmap: Tensor accepted by ``rotate``.
        theta: Rotation angle in radians that was applied to the input.
        scale: Scale factor that was applied to the input.

    Returns:
        The resized numpy array returned by ``unscale_img``.
    """
    rotated_back = rotate(heatmap, -theta)
    as_array = rotated_back.cpu().data.numpy()
    cropped = unpad(as_array, scale)
    return unscale_img(cropped, scale)


def rotate(img, theta):
    """Rotate a tensor about its centre with nearest-neighbour resampling.

    A 2x3 affine matrix is built from ``-theta`` and applied through
    ``F.affine_grid`` / ``F.grid_sample``.  The matrix is created on the same
    device (and floating dtype) as ``img``, so the function works for CPU and
    GPU tensors alike.

    Args:
        img: Tensor of shape ``(C, H, W)``; a batch dimension is added
            internally.
        theta: Angle in radians.  It is negated before the matrix is built.

    Returns:
        The resampled tensor with all size-1 dimensions squeezed away.
    """
    theta = -theta
    batch = img.unsqueeze(0)
    cos_t = float(np.cos(theta))
    sin_t = float(np.sin(theta))
    # Non-floating inputs keep the float32 matrix the original code used.
    mat_dtype = batch.dtype if batch.is_floating_point() else torch.float32
    affine_mat = torch.tensor(
        [[cos_t, sin_t, 0.0],
         [-sin_t, cos_t, 0.0]],
        dtype=mat_dtype,
        device=batch.device,
    ).unsqueeze(0)
    flow_grid = F.affine_grid(affine_mat, batch.size())
    resampled = F.grid_sample(batch, flow_grid, mode='nearest')
    return resampled.squeeze()


def scale_img(img, scale):
    """Resize ``img`` by the factor ``scale`` along both axes with ``cv2.resize``."""
    return cv2.resize(img, None, fx=scale, fy=scale)

def unscale_img(img, scale):
    """Undo a resize by ``scale``: resize ``img`` by ``1 / scale`` on both axes."""
    inverse = 1.0 / scale
    return cv2.resize(img, None, fx=inverse, fy=inverse)


def pad(img):
    """Zero-pad the first two axes of ``img`` up to the network canvas size.

    The canvas width is the 200-pixel base image, scaled by the largest resize
    factor (2), measured along its diagonal and rounded up to an even border.
    The border width is derived from ``img.shape[0]`` only and is applied to
    both spatial axes; a third (channel) axis, if present, is left unpadded.

    Args:
        img: 2-D or 3-D numpy array.

    Returns:
        A new array padded with zeros on every side of the first two axes.
    """
    W = 200
    largest_scale = 2  # first entry of the (2, 1, 0.5) scale set
    diag_len = float(W) * np.sqrt(2) * largest_scale
    IM_W = W + 2 * int(np.ceil((diag_len - W) / 2))

    border = int(np.ceil((IM_W - img.shape[0]) / 2))
    widths = [(border, border), (border, border)]
    if img.ndim == 3:
        widths.append((0, 0))
    return np.pad(img, widths, 'constant', constant_values=0)

def unpad(img, scale):
    """Crop away the border that padding added for a given scale.

    The border width is recomputed from the canvas size (derived from the
    200-pixel base image and the largest resize factor, 2) and the scaled
    image width ``ceil(200 * scale)``.

    Args:
        img: Array whose first two axes are the padded spatial axes.
        scale: Resize factor that was applied before padding.

    Returns:
        A deep copy of the central region of ``img``.  When the computed
        border is zero or negative the whole array is returned (copied)
        instead of an empty or wrapped-around slice.
    """
    W = 200
    largest_scale = 2  # first entry of the (2, 1, 0.5) scale set
    diag_len = float(W) * np.sqrt(2) * largest_scale
    IM_W = W + 2 * int(np.ceil((diag_len - W) / 2))

    # Clamp at zero: a negative border would turn into wrap-around indices.
    pad_amount = max(int(np.ceil((IM_W - np.ceil(W * scale)) / 2)), 0)

    # Explicit end indices: ``img[p:-p]`` is empty when p == 0.
    row_end = img.shape[0] - pad_amount
    col_end = img.shape[1] - pad_amount
    return copy.deepcopy(img[pad_amount:row_end, pad_amount:col_end])

def mask(img, mask):
    """Return the element-wise product ``img * mask``."""
    return img * mask


def get_heatmaps(qnet, img, goal, target=False, inedgemask=None):
    """Run the network on every rotation/scale variant of an observation.

    For each of the 8 rotations and 3 scales the observation and goal are
    resized, zero-padded to the square canvas, rotated, and written into slot
    ``rotate_idx * 3 + scale_idx`` of a CUDA batch.  Observation and goal
    batches are concatenated along the channel axis before the forward pass.

    Args:
        qnet: Network called on the batch when ``target`` is false.
        img: Numpy observation indexed ``[row, col, channel]``; it is written
            into a 3-channel buffer.
        goal: Numpy goal image with the same layout as ``img``.
        target: When true the module-level name ``tnet`` is called instead of
            ``qnet``.  NOTE(review): ``tnet`` is not defined anywhere in this
            function, so it must exist as a global for this path to work.
        inedgemask: Optional 2-D numpy mask.  It is stacked to three channels
            and transformed like the images, but NOTE(review): the resulting
            batch is neither fed to the network nor returned.

    Returns:
        ``(heatmaps, angles_scale_idxs)`` where ``heatmaps`` is the network
        output and ``angles_scale_idxs`` lists ``(rotate_idx, scale_idx)`` for
        each batch slot in order.
    """
    num_rotations = 8
    resize_scales = [(2,0.5),(1,1),(0.5,2)]
    num_scales = 3
    W = 200
    ROT_BIN = (2*np.pi/num_rotations)

    # Canvas wide enough for the largest scaled image along its diagonal.
    diag_len = float(W)*np.sqrt(2)*resize_scales[0][0] # because of flip
    IM_W = W + 2*int(np.ceil((diag_len - W)/2))
    batch_shape = [num_rotations*num_scales, 3, IM_W, IM_W]

    with torch.no_grad():
        imgs = torch.zeros(batch_shape, dtype=torch.float32, device='cuda')
        goals = torch.zeros(batch_shape, dtype=torch.float32, device='cuda')
        img_t = torch.from_numpy(img).float().cuda().permute(2, 0, 1)
        goal_t = torch.from_numpy(goal).float().cuda().permute(2, 0, 1)

        use_edgemask = inedgemask is not None
        ems = None
        if use_edgemask:
            ems = torch.zeros(batch_shape, dtype=torch.float32, device='cuda')
            stacked = np.stack([inedgemask] * 3, axis=2)
            em_t = torch.from_numpy(stacked).float().cuda().permute(2, 0, 1)

        # The resize+pad pipeline depends only on the scale, so build one per scale.
        pipelines = []
        for scale, _ in resize_scales:
            size = (int(img.shape[0]*scale), int(img.shape[1]*scale))
            border = int(np.ceil((IM_W - size[0])/2))
            pipelines.append(torch.nn.Sequential(
                transforms.Resize(size, interpolation=2),
                transforms.Pad(border, fill=0, padding_mode='constant')))

        angles_scale_idxs = []
        for rotate_idx in range(num_rotations):
            degrees = np.degrees(rotate_idx*ROT_BIN)
            for scale_idx, prepare in enumerate(pipelines):
                slot = rotate_idx*num_scales + scale_idx
                imgs[slot] = transforms.functional.rotate(prepare(img_t), degrees)
                goals[slot] = transforms.functional.rotate(prepare(goal_t), degrees)
                if use_edgemask:
                    ems[slot] = transforms.functional.rotate(prepare(em_t), degrees)
                angles_scale_idxs.append((rotate_idx, scale_idx))

        x = torch.cat([imgs, goals], dim=1)
        heatmaps = tnet(x) if target else qnet(x)
        return heatmaps, angles_scale_idxs


def select_action(qnet, obs_mode, obs, goal, rgb, viz=True, act_gt=None, random=False, inedgemask=None, outedgemask=None):
    """Pick a fold action from the network heatmaps and draw it.

    The network is evaluated on all rotation/scale variants of ``obs``; each
    heatmap is mapped back to the observation frame and pixels outside the
    allowed region are suppressed before the global maximum is taken.

    Args:
        qnet: Network; it is switched to eval mode here.
        obs_mode: Observation mode forwarded to ``get_mask``.
        obs: Observation image; it is divided by 255 before inference.
        goal: Goal image; it is divided by 255 before inference.
        rgb: Image on which the chosen action is drawn (cast to uint8).
        viz: When true, a heatmap visualisation is rendered.
        act_gt: Unused; kept for interface compatibility.
        random: When true, a random rotation, scale and masked pixel are
            sampled instead of taking the heatmap maximum.
        inedgemask: Accepted for interface compatibility; currently not
            forwarded to the network.
        outedgemask: Optional mask; when given it replaces the observation
            mask for suppressing heatmap pixels (zero entries are disallowed).

    Returns:
        ``(action, act_type, disp_img, viz_img)``.  ``action`` is
        ``[rotation index, scale index, row, col]`` for the greedy choice
        (for the random choice the last two entries are whatever
        ``td.random_sample_from_masked_image`` returns — presumably row and
        column; confirm against that helper).  ``act_type`` is ``"random"``
        or ``"greedy"``.  ``viz_img`` is ``None`` when ``viz`` is false.
    """
    num_rotations = 8
    num_scales = 3
    resize_scales = [(2,0.5),(1,1),(0.5,2)]
    MID_FOLD_DIST = 78
    ROT_BIN = (2*np.pi/num_rotations)

    mask = get_mask(obs, obs_mode)
    allowed = mask if outedgemask is None else outedgemask
    blocked = allowed == 0

    # Bound up front so the return statement is valid when viz is False.
    viz_img = None

    with torch.no_grad():
        qnet.eval()
        heatmaps, angles_scale_idxs = get_heatmaps(qnet, obs/255.0, goal/255.0)

        maps = []
        vmaps = []  # unmasked copies, used only for visualisation
        for heatmap, (rot_idx, scale_idx) in zip(heatmaps, angles_scale_idxs):
            scale = resize_scales[scale_idx][0]
            m = postprocess(heatmap, rot_idx*ROT_BIN, scale)
            vmaps.append(copy.deepcopy(m))
            m[blocked] = -10000000
            maps.append(m)
        maps = np.array(maps)

        if random:
            uv = td.random_sample_from_masked_image(mask, 1)
            u, v = uv[0][0], uv[1][0]
            action = [np.random.randint(num_rotations), np.random.randint(num_scales), u, v]
        else:
            act = list(np.unravel_index(np.argmax(maps), maps.shape))
            # The batch index interleaves rotations (outer) and scales (inner).
            rotate_max_idx = int(np.floor(act[0]/num_scales))
            scale_max_idx = act[0] % num_scales
            action = [rotate_max_idx, scale_max_idx, act[1], act[2]]

        if viz:
            viz_img = render_heatmaps(obs, vmaps, action)
            if outedgemask is not None:
                maps1 = deepcopy(maps) # for viz
                maps1[:, outedgemask==0] = 0
                viz_img_masked = render_heatmaps(obs, maps1, action)
                viz_img = np.concatenate((viz_img, viz_img_masked), axis=0)

        theta = np.deg2rad((360.0/num_rotations) * action[0])
        fold_dist = MID_FOLD_DIST * resize_scales[action[1]][1]
        # Plain Python ints keep the OpenCV drawing calls happy.
        x = int(action[3])
        y = int(action[2])

        arrow_x = int(x + fold_dist*np.cos(theta))
        arrow_y = int(y + fold_dist*np.sin(theta))

        disp_img = rgb.astype(np.uint8)
        color = (0,0,200)
        cv2.circle(disp_img, (x, y), 6, color, 2)
        cv2.arrowedLine(disp_img, (x, y), (arrow_x, arrow_y), color, 2, tipLength=0.1)

        act_type = "random" if random else "greedy"
    return action, act_type, disp_img, viz_img

def render_heatmaps(color, heatmaps, max_idx):
    """Render all rotation/scale heatmaps as one tiled overview image.

    The heatmaps are jointly min-max normalised, colour-mapped with JET,
    blended 50/50 with ``color * 255``, annotated with an arrow for the fold
    direction, and shrunk to 100x100 tiles.  Scales are stacked vertically
    within a column and the 8 rotations form the columns.

    Args:
        color: Image blended under every heatmap; its shape must match the
            colour-mapped heatmap.  NOTE(review): the ``* 255`` factor looks
            like it expects values in [0, 1] — confirm against the caller.
        heatmaps: Sequence or array of 24 2-D heatmaps ordered by
            ``rotation * 3 + scale``.
        max_idx: ``[rotation index, scale index, row, col]`` of the selected
            action; a white circle marks it on the matching tile.

    Returns:
        The tiled uint8 image.
    """
    num_rotations = 8
    num_scales = 3
    W = 200
    ROT_BIN = (2*np.pi/num_rotations)

    heatmaps = np.asarray(heatmaps)
    lowest = np.min(heatmaps)
    span = np.max(heatmaps) - lowest
    if span > 0:
        heatmaps = (heatmaps - lowest) / span
    else:
        # All values equal: avoid 0/0 and render a uniform map instead.
        heatmaps = np.zeros_like(heatmaps, dtype=float)

    line_start = np.array([W/2, W/2])
    columns = []
    for r in range(num_rotations):
        theta = r*ROT_BIN
        tiles = []
        for s in range(num_scales):
            m = heatmaps[r*num_scales + s]
            m = cv2.applyColorMap((m*255).astype(np.uint8), cv2.COLORMAP_JET)
            m = (0.5*m + 0.5*(color*255)).astype(np.uint8)

            fold_dist = 0.1*((s)+1.0) # because of flip
            offset = W*(fold_dist*np.array([np.sin(theta), np.cos(theta)]))
            # Truncate to whole pixels; a plain int avoids int8's +/-127 limit.
            line_end = line_start + offset.astype(int)

            m = cv2.arrowedLine(m, (int(line_start[1]),int(line_start[0])), (int(line_end[1]),int(line_end[0])), (100,100,100), 3, tipLength=0.3)
            if r == max_idx[0] and s == max_idx[1]:
                m = cv2.circle(m, (int(max_idx[3]), int(max_idx[2])), 7, (255,255,255), 2)
            tiles.append(cv2.resize(m, (100,100)))
        columns.append(np.concatenate(tiles, axis=0))
    return np.concatenate(columns, axis=1)

def get_eval_im(fn, obs_mode, base_path):
    """Load a goal image and its particle positions from ``../goals``.

    The sub-folder is ``square_final`` when the global ``HEIGHT`` equals the
    string ``'0.65'`` and ``square_low`` otherwise.  The image is converted
    from BGR to RGB and every pixel outside ``get_mask`` is set to zero.

    Args:
        fn: File stem of the goal (without extension).
        obs_mode: ``'color'`` loads ``<fn>.png``; ``'depth'`` loads
            ``<fn>_depth.png``.
        base_path: Unused; kept for interface compatibility.

    Returns:
        ``(im, pos)`` — the masked image and the array loaded from
        ``particles/<fn>.npy``.

    Raises:
        ValueError: If ``obs_mode`` is neither ``'color'`` nor ``'depth'``.
        FileNotFoundError: If the image cannot be read.
    """
    folder = 'square_final' if HEIGHT == '0.65' else 'square_low'
    if obs_mode == 'color':
        fpath = f'../goals/{folder}/{fn}.png'
    elif obs_mode == 'depth':
        fpath = f'../goals/{folder}/{fn}_depth.png'
    else:
        raise ValueError(f"unsupported obs_mode {obs_mode!r}; expected 'color' or 'depth'")

    im = cv2.imread(fpath)
    if im is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f"could not read goal image '{fpath}'")
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    background = get_mask(im, obs_mode) == 0
    im[background] = 0

    pos = np.load(f'../goals/{folder}/particles/{fn}.npy')
    return im, pos

def run_goal(train_iter, qnet, env, cfg, save_dir, eval_ims, em):
    """Roll out the policy on every configured start/goal combination.

    For each entry of ``cfg['eval_combos']`` the environment is reset to the
    goal, the particles are set to the start configuration, and actions from
    ``select_action`` are executed until the environment reports ``done``.
    Per-step images are stacked and written to ``save_dir`` and the best
    particle distance is appended to ``metrics.csv``.

    Args:
        train_iter: Training iteration of the loaded weights; used in output
            file names and the metrics row.
        qnet: Network passed to ``select_action``.
        env: Environment providing ``reset``, ``step`` and ``_get_obs``.
        cfg: Config dict; ``eval_combos``, ``obs_mode`` and ``on_table`` are
            read here.
        save_dir: Existing directory for images and ``metrics.csv``.
        eval_ims: Mapping from goal name to ``(image, particle positions)``.
        em: Unused; kept for interface compatibility.

    A combination is skipped when its display or visualisation image for this
    exact iteration already exists in ``save_dir``.
    """
    for combo in cfg['eval_combos']:
        is_multistep = isinstance(combo[0], list)
        if is_multistep:
            start_name = combo[0][0]
            goal_name = combo[-1][1]
            if STEP == 'multi': # Provide intermediate steps
                name = f'ms-{start_name}-{goal_name}'
            else: # Only provide final goal
                name = f'ms-{start_name}-{goal_name}-finalonly'
        else: # single step goal
            start_name, goal_name = combo
            name = f'{start_name}-{goal_name}'

        # Exact file names are checked: substring matching would let iteration
        # 500 be "found" in step_1500 files, or name "a-b" in "ms-a-b-..." files.
        disp_path = f'{save_dir}/{name}_disp_step_{train_iter}.png'
        viz_path = f'{save_dir}/{name}_viz_step_{train_iter}.png'
        if os.path.exists(disp_path) or os.path.exists(viz_path):
            print(f"Skipping iter {train_iter} for {name} as it already exists")
            continue

        obs_mode = cfg['obs_mode']
        stepcount = 0
        _, start_pos = eval_ims[start_name]
        goal, goal_pos = eval_ims[goal_name]
        goal_pos = goal_pos[:, :3]
        curr_goal = goal

        # reset env
        print(f"goal set to {goal_name}")
        env.reset(given_goal=goal, given_goal_pos=goal_pos) # provide final goal for reward computation
        pyflex.set_positions(start_pos)
        pyflex.step()
        obs = env._get_obs()
        done = False

        vizs = []
        disps = []
        distances = []
        while not done:
            if is_multistep and STEP == 'multi' and stepcount < len(combo):
                print("updating current goal to " + combo[stepcount][1])
                curr_goal, _ = eval_ims[combo[stepcount][1]]
            act, _, disp_img, viz_img = select_action(
                qnet, obs_mode, obs[obs_mode], curr_goal,
                rgb=obs["color"], inedgemask=None, outedgemask=None)

            # step env
            nobs, rew, done, _ = env.step(act, on_table=cfg['on_table'])
            # Zero out everything outside the mask in the next observation.
            background = get_mask(nobs[obs_mode], obs_mode) == 0
            nobs[obs_mode][background] = 0

            # Calc particle metrics
            pos = pyflex.get_positions().reshape(-1, 4)[:,:3]
            pos_metric = td.calc_metric_abs(pos, goal_pos)
            distances.append(pos_metric[1])

            vizs.append(viz_img)
            disps.append(disp_img)

            obs = copy.deepcopy(nobs)
            stepcount += 1

        disps.append(obs["color"]) # final obs

        cv2.imwrite(disp_path, np.vstack(disps))
        cv2.imwrite(viz_path, np.vstack(vizs))

        # add the best metrics for episodes
        min_dist = min(distances)
        min_step = np.argmin(distances)
        print("Min distance {} achieved at step {}".format(min_dist, min_step))
        with open(f'{save_dir}/metrics.csv', 'a+') as f:
            f.write(f'{train_iter},{name},{min_dist},{min_step}\n')

def evaluate(EVAL_CFG, EVAL_PATH, ENV_TYPE, env, maxiter=25000, base_path='.', out_path='.'):
    """Evaluate every saved Q-network checkpoint of a run on the eval goals.

    The run's ``config.yaml`` is loaded and overridden by the evaluation
    config, all referenced goal images are loaded once, and checkpoints are
    evaluated at iterations 500, 1000, ... up to ``maxiter``.  The loop stops
    at the first checkpoint that fails to load or evaluate (the traceback is
    printed).

    Args:
        EVAL_CFG: Stem of the evaluation YAML file inside ``base_path``.
        EVAL_PATH: Run directory name below ``{out_path}/output``.
        ENV_TYPE: Unused; kept for interface compatibility.
        env: Environment handed to ``run_goal``.
        maxiter: Highest training iteration to evaluate (inclusive).
        base_path: Directory containing the evaluation YAML file.
        out_path: Directory containing the ``output`` folder.

    Raises:
        FileNotFoundError: If the weights directory holds no ``qnet`` file.
    """
    run_dir = f'{out_path}/output/{EVAL_PATH}'

    # NOTE: yaml.load with FullLoader is kept as-is; these are local config
    # files.  Switch to yaml.safe_load if they can ever come from outside.
    with open(f'{run_dir}/config.yaml') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)

    with open(f'{base_path}/{EVAL_CFG}.yaml') as f:
        cfg.update(yaml.load(f, Loader=yaml.FullLoader))

    em = None

    def add_to_eval_ims(cfg, base_path, eval_ims, combo):
        """Load the start and goal image of ``combo`` unless already cached."""
        start_fn, goal_fn = combo
        print(start_fn, goal_fn)
        for fn in (start_fn, goal_fn):
            if fn not in eval_ims:
                eval_ims[fn] = get_eval_im(fn, cfg['obs_mode'], base_path)
        return eval_ims

    # test goals
    eval_ims = dict()
    for combo in cfg['eval_combos']:
        if isinstance(combo[0], list): # multi step goal
            for c in combo:
                eval_ims = add_to_eval_ims(cfg, base_path, eval_ims, c)
        else: # single step goal
            eval_ims = add_to_eval_ims(cfg, base_path, eval_ims, combo)

    # Derive the checkpoint file-name prefix from the latest qnet checkpoint.
    # Only qnet files are needed; target-network files are not required.
    weight_fns = os.listdir(f'{run_dir}/weights')
    qnet_fns = [x for x in weight_fns if 'qnet' in x]
    if not qnet_fns:
        raise FileNotFoundError(f"no qnet weight files found in '{run_dir}/weights'")
    latest_qnet = max(qnet_fns, key=lambda k: int(k.split('_')[-1].replace('.pt', '')))
    qnet_tmpl = '_'.join(latest_qnet.split('_')[:-1])

    # Make save directory
    save_dir = f'{run_dir}/{"debug" if cfg["debug"] else cfg["eval_type"]}_evals_{STEP}'
    if os.path.exists(save_dir) and cfg['overwrite']:
        print(f"deleting {save_dir}")
        import shutil
        shutil.rmtree(save_dir)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    for train_iter in np.arange(500, maxiter+1, step=500):
        try:
            qnet_fn = f'{qnet_tmpl}_{train_iter}.pt'

            # load weights (the checkpoint holds a whole pickled module)
            qnet = torch.load(f'{run_dir}/weights/{qnet_fn}')
            qnet.eval()

            print(f"evaluating weight {train_iter} {EVAL_PATH}")
            run_goal(train_iter, qnet, env, cfg, save_dir, eval_ims, em)
        except Exception:
            # Best effort: report the failure and stop at the first bad checkpoint.
            print(traceback.format_exc())
            break
        
if __name__ == '__main__':
    # Script entry point: read the command-line overrides, build the matching
    # simulation environment and evaluate all checkpoints of the run.
    HYPERS = OrderedDict()

    OUTMASK = arg('outmask', False)
    EVAL_PATH = arg('eval_path', '')
    HEIGHT = arg('height', '')
    STEP = arg('step', 'multi') # 'single'

    # The environment type is inferred from the run directory name.
    ENV_TYPE = 'tshirt' if 'tshirt' in EVAL_PATH else 'towel'
    env_cls = TshirtBaselineFold if ENV_TYPE == 'tshirt' else GCFold

    env = env_cls(
        use_depth=True,
        use_cached_states=False,
        horizon=5,
        use_desc=False,
        cam_height=HEIGHT, # TODO should be int but this arg doesn't do anything
        action_repeat=1,
        headless=True,
    )

    evaluate('eval', EVAL_PATH, ENV_TYPE, env, base_path='.', out_path='..')
