import argparse
import copy
import json
import os
import random
import shutil
import socket
from collections import namedtuple

# Third-party and project imports keep their original relative order: native
# extensions (pyflex, torch, cv2) may be sensitive to load order.
from softgym.envs.bimanual_env import BimanualEnv
from softgym.envs.bimanual_tshirt import BimanualTshirtEnv
import numpy as np
import pyflex
import softgym.envs.tshirt_descriptor as td
import torch
from PIL import Image
import cv2
import matplotlib.pyplot as plt

# One stored transition: observation, goal, action, reward, next observation
# and the episode-termination flag.
Experience = namedtuple('Experience', 'obs goal act rew nobs done')


def get_corner_particles(env):
    """Find the particle closest to each of four fixed image-space targets.

    Resets ``env`` without a goal, projects all particles with
    ``td.particle_uv_pos``, swaps the two coordinate columns and rescales by
    199/719. Then, for each target pixel, it picks the particle with the
    smallest Euclidean distance.

    Returns
    -------
    tuple
        ``(c1, c2, c3, c4)`` particle indices nearest to the targets
        (25, 25), (25, 175), (175, 175) and (175, 25), in that order.
    """
    env.reset(given_goal=None, given_goal_pos=None)
    projected = td.particle_uv_pos(env.camera_params, None)
    projected[:, [1, 0]] = projected[:, [0, 1]]
    projected = (projected / 719) * 199

    targets = ([25, 25], [25, 175], [175, 175], [175, 25])
    nearest = []
    for target in targets:
        distances = np.linalg.norm(projected - np.array(target), axis=1)
        nearest.append(distances.argmin())
    return tuple(nearest)

def get_harris(mask, thresh=0.2, *, block_size=5, ksize=5, k=0.01, dilate_size=7):
    """Harris corner detector on a binary mask.

    Params
    ------
        - mask: np.float32 image of 0.0 and 1.0 (cv2.cornerHarris needs a
          single-channel float32 input)
        - thresh: responses below ``thresh * max(response)`` are discarded
        - block_size: neighbourhood size considered for corner detection
          (cv2.cornerHarris ``blockSize``)
        - ksize: aperture of the Sobel derivative (cv2.cornerHarris ``ksize``)
        - k: Harris detector free parameter (cv2.cornerHarris ``k``)
        - dilate_size: side length of the square kernel used to grow each
          detected corner pixel into a blob

    Returns
    -------
        - harris_dilated: np.float32 array with the same shape as ``mask``;
          1.0 on (dilated) corner regions that lie on the mask, 0.0 elsewhere
    """
    # https://docs.opencv.org/master/dd/d1a/group__imgproc__feature.html#gac1fc3598018010880e370e2f709b4345
    harris = cv2.cornerHarris(mask, blockSize=block_size, ksize=ksize, k=k)
    harris[harris < thresh * harris.max()] = 0.0  # filter small responses
    harris[harris != 0] = 1.0  # binarize the surviving responses
    harris_dilated = cv2.dilate(harris, kernel=np.ones((dilate_size, dilate_size), np.uint8))
    harris_dilated[mask == 0] = 0  # keep only corner blobs that lie on the mask
    return harris_dilated

def get_particle_uv(idx, env):
    """Return the rescaled image coordinates of particle ``idx``.

    All particles are projected with ``td.particle_uv_pos``. The two coordinate
    columns are swapped and the result is rescaled by 199/719 (the same
    transform used by get_corner_particles). Only row ``idx`` is returned, as
    a ``(u, v)`` pair.
    """
    projected = td.particle_uv_pos(env.camera_params, None)
    projected[:, [1, 0]] = projected[:, [0, 1]]
    rescaled = (projected / 719) * 199
    first, second = rescaled[idx]
    return first, second

def get_true_corner_mask(clothmask, corners, env, r=4):
    """Mark discs of radius ``r`` around the ground-truth corner particles.

    Parameters
    ----------
    clothmask : array broadcastable against a (200, 200) image
        Cloth mask; disc pixels off the cloth are zeroed by multiplication.
    corners : iterable of int
        Particle indices (e.g. as returned by get_corner_particles).
    env : environment
        Passed to get_particle_uv to project each corner particle.
    r : int or float
        Disc radius in pixels.

    Returns
    -------
    np.ndarray
        ``clothmask`` times a (200, 200) indicator that is 1 within ``r``
        pixels of any projected corner and 0 elsewhere.
    """
    true_corners = np.zeros((200, 200))
    h, w = true_corners.shape
    # Integer pixel grids, built once; the corner offset is applied by
    # subtraction. Float slice bounds (np.ogrid[-a:h-a, -b:w-b]) make numpy size
    # each axis as ceil(stop - start), which rounding can push to h + 1 (e.g.
    # for corners projected far off-screen). That mask would then fail to
    # index the (h, w) array.
    rows, cols = np.ogrid[:h, :w]
    for c in corners:
        b, a = get_particle_uv(c, env)
        cmask = (cols - b) ** 2 + (rows - a) ** 2 <= r * r
        true_corners[cmask] = 1

    return true_corners * clothmask

class DatasetGenerator(object):
    """Collect random-action transitions in a softgym bimanual cloth env.

    ``cfgs`` is the configuration dict assembled by the ``__main__`` section of
    this script. Keys read by this class: ``cloth_type``, ``img_type``,
    ``horizon``, ``cam_height``, ``headless``, ``dataset_folder``,
    ``dataset_name``, ``debug``, ``actmaskprob``, ``truecratio``,
    ``action_type``, ``on_table``, ``goals``, ``num_eps`` and ``record``.
    """

    # Fixed actions used by action_type='debug' and by collect_goals, in the
    # same [rot_idx, dist1, dist2, u1, v1, u2, v2] layout that get_rand_action
    # returns for 'perp'.
    _DEBUG_ACTIONS = (
        (2, 1, 1, 50, 50, 50, 150),
        (0, 1, 1, 50, 50, 150, 50),
        (4, 1, 1, 50, 150, 150, 150),
        (6, 1, 1, 150, 50, 150, 150),
    )

    # Action types get_rand_action can produce.
    _ACTION_TYPES = ('perp', 'debug')

    # Sub-directories created under the dataset folder.
    _SUBDIRS = ('images', 'coords', 'image_masks', 'rendered_images', 'knots', 'edge_masks')

    # An observation whose cloth mask (depth > 0) covers fewer pixels than this
    # is treated as "cloth out of screen".
    _MIN_CLOTH_PIXELS = 250

    # Episodes between intermediate checkpoints of the replay buffers.
    _CHECKPOINT_EVERY = 500

    def __init__(self, cfgs):
        """Build the environment for ``cfgs['cloth_type']`` and locate its corners.

        Raises
        ------
        NotImplementedError
            For ``cloth_type == 'tshirt'``, which is not wired up yet.
        ValueError
            For any other unknown ``cloth_type``.
        """
        self.cfgs = cfgs

        cloth_type = cfgs['cloth_type']
        if cloth_type == 'towel':
            self.env = BimanualEnv(use_depth=cfgs['img_type'] == 'depth',
                                   use_cached_states=False,
                                   horizon=cfgs['horizon'],
                                   use_desc=False,
                                   cam_height=cfgs['cam_height'],
                                   action_repeat=1,
                                   headless=cfgs['headless'])
        elif cloth_type == 'tshirt':
            # BimanualTshirtEnv support is not implemented in this generator.
            raise NotImplementedError("tshirt datasets are not supported yet")
        else:
            raise ValueError(f"unknown cloth_type: {cloth_type!r}")

        # Particle indices nearest to four fixed image-space targets; used to
        # bias pick points towards the true cloth corners.
        self.corners = get_corner_particles(self.env)

    def makedirs(self):
        """Create the dataset folder tree and return the dataset folder path.

        When ``cfgs['debug']`` is true, any existing folder is deleted first.
        Missing sub-directories are created even when the dataset folder
        itself already exists.
        """
        save_folder = os.path.join(self.cfgs['dataset_folder'], self.cfgs['dataset_name'])
        if self.cfgs['debug']:
            # shutil instead of shelling out to `rm -r`, which breaks on paths
            # containing spaces or shell metacharacters.
            shutil.rmtree(save_folder, ignore_errors=True)
        for sub in self._SUBDIRS:
            os.makedirs(os.path.join(save_folder, sub), exist_ok=True)
        return save_folder

    def get_masked(self, img):
        """Just used for masking goals, otherwise we use depth.

        Returns the uint8 mask from ``cv2.inRange`` on the HSV image (saturation
        >= 15), followed by a 3x3 morphological close.
        """
        img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        mask = cv2.inRange(img_hsv, np.array([0., 15., 0.]), np.array([255, 255., 255.]))
        kernel = np.ones((3, 3), np.uint8)
        morph = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        return morph

    def get_rgbd(self):
        """Render the current scene at camera resolution.

        Returns
        -------
        img : image from ``env.get_image``
        depth : 4th channel of ``pyflex.render_sensor()``, rows flipped
        mask : boolean array, True where ``depth > 0``
        """
        rgbd = pyflex.render_sensor()
        rgbd = np.array(rgbd).reshape(self.env.camera_height, self.env.camera_width, 4)
        rgbd = rgbd[::-1, :, :]  # flip rows
        img = self.env.get_image(self.env.camera_height, self.env.camera_width)
        depth = rgbd[:, :, 3]
        mask = depth > 0
        return img, depth, mask

    def get_rand_action(self, state, img, depth, action_type='perp', single_scale=True, debug_idx=None):
        """Sample a random bimanual action.

        Two pick pixels are sampled from a mask. With probability
        ``cfgs['actmaskprob']`` the mask is biased towards corners:
          - the ground-truth corner discs, with probability ``cfgs['truecratio']``,
            when they cover more than two pixels;
          - else the Harris corners, when they cover more than two pixels;
          - else the whole cloth.
        Otherwise the whole cloth mask (``depth > 0``) is used. ``state`` and
        ``img`` are unused.

        Returns
        -------
        np.ndarray
            ``[rot_idx, dist1, dist2, u1, v1, u2, v2]``:
              - ``rot_idx`` is one of 8 rotation bins over the circle;
              - ``dist1``/``dist2`` are in {0, 1, 2}; ``dist2 == dist1`` when
                ``single_scale``.
            For ``action_type='debug'``, the fixed action
            ``_DEBUG_ACTIONS[debug_idx]`` is returned instead.

        Raises
        ------
        ValueError
            If ``action_type`` is not one of ``_ACTION_TYPES``. Unknown types
            used to fall through and return None, which was then passed to
            ``env.step``.
        """
        if action_type not in self._ACTION_TYPES:
            raise ValueError(f"unknown action_type: {action_type!r}")

        clothmask = depth > 0
        mask = clothmask
        if np.random.sample() < self.cfgs['actmaskprob']:
            true_corners = get_true_corner_mask(clothmask, self.corners, self.env)
            if np.sum(true_corners) > 2 and np.random.uniform() < self.cfgs['truecratio']:
                mask = true_corners > 0
            else:
                harris_corners = get_harris(clothmask.astype(np.float32))
                if np.sum(harris_corners) > 2:
                    mask = harris_corners > 0

        # Two index arrays (one per image axis), each with one entry per pick.
        pick_idx = td.random_sample_from_masked_image(mask, 2)

        if action_type == 'debug':
            return np.array(self._DEBUG_ACTIONS[debug_idx])

        # 'perp': two randomly selected pick points; angle discretized to 8
        # rotation bins over the circle.
        u1, v1 = pick_idx[0][0], pick_idx[1][0]
        u2, v2 = pick_idx[0][1], pick_idx[1][1]

        angle = np.rad2deg(np.arctan2(u1 - u2, v1 - v2))
        # Rotate by -/+90 degrees depending on which side of the pick line the
        # point (100, 100) lies.
        if self.line_pt_dir([u1, v1], [u2, v2], [100, 100]) < 0:
            angle -= 90
        else:
            angle += 90

        idx = int(np.round((angle / 360.0) * 8)) % 8

        dist1 = np.random.randint(3)
        dist2 = dist1 if single_scale else np.random.randint(3)

        return np.array([idx, dist1, dist2, u1, v1, u2, v2])

    def line_pt_dir(self, a, b, p):
        """Return the side of the directed line a -> b on which point p lies.

        Returns 1 when the 2-D cross product (b - a) x (p - a) is positive,
        -1 when it is negative, and 0 when p is on the line.
        """
        ax, ay = a
        bx, by = b
        px, py = p

        cross_prod = (bx - ax) * (py - ay) - (by - ay) * (px - ax)
        if cross_prod > 0:
            return 1
        if cross_prod < 0:
            return -1
        return 0

    def save_data(self, idx, state, coords, img, depth, dataset_path, beforeact=False):
        """Write one observation under ``dataset_path``.

        File names carry ``idx`` and a ``before``/``after`` suffix chosen by
        ``beforeact``. Written files:
          - the RGB image;
          - the mask ``depth > 0``;
          - the raw depth;
          - the particle coordinates;
          - the projected particle positions ("knots").
        ``state`` is unused.
        """
        save_time = 'before' if beforeact else 'after'
        mask = depth > 0
        uv = td.particle_uv_pos(self.env.camera_params, None)
        uv[:, [1, 0]] = uv[:, [0, 1]]
        # NOTE(review): unlike get_particle_uv, these knots are not rescaled by
        # 199/719, although generate() passes images resized to 200x200 by
        # get_obs. Confirm consumers of knots_info.json expect this scale.

        rgb_img = Image.fromarray(img, 'RGB')
        rgb_img.save(os.path.join(dataset_path, 'images', '%06d_rgb_%s.png' % (idx, save_time)))
        mask_img = Image.fromarray(mask)
        mask_img.save(os.path.join(dataset_path, 'image_masks', '%06d_mask_%s.png' % (idx, save_time)))

        np.save(os.path.join(dataset_path, 'rendered_images', '%06d_depth_%s.npy' % (idx, save_time)), depth)
        np.save(os.path.join(dataset_path, 'coords', '%06d_coords_%s.npy' % (idx, save_time)), coords)
        np.save(os.path.join(dataset_path, 'knots', '%06d_knots_%s.npy' % (idx, save_time)), uv)

    def get_obs(self):
        """Return ``(coords, img, depth)`` for the current scene.

        ``coords`` are the first three components of every pyflex particle
        position. ``img`` and ``depth`` are resized to 200x200.
        """
        coords = pyflex.get_positions().reshape(-1, 4)[:, :3]
        img, depth, _ = self.get_rgbd()
        img = cv2.resize(img, (200, 200))
        depth = cv2.resize(depth, (200, 200))
        return coords, img, depth

    def _load_goals(self):
        """Load ``[goal_image, goal_particle_positions]`` pairs for ``cfgs['goals']``.

        A ``None`` entry yields ``[None, None]``. Images are read from
        ``../goals`` and particle positions from ``../goals/particles``.

        Raises
        ------
        ValueError
            On an unknown ``img_type``.
        FileNotFoundError
            If a goal image cannot be read (``cv2.imread`` returns None).
        """
        img_type = self.cfgs['img_type']
        goals = []
        for g in self.cfgs['goals']:
            if g is None:
                goals.append([None, None])
                continue

            if img_type == 'color':
                path = f"../goals/{g}.png"
            elif img_type == 'depth':
                path = f"../goals/{g}_depth.png"
            elif img_type == 'desc':
                path = f"../goals/{g}_desc.png"
            else:
                raise ValueError(f"unknown img_type: {img_type!r}")

            goal = cv2.imread(path)
            if goal is None:
                raise FileNotFoundError(path)
            if img_type == 'color':
                goal = cv2.cvtColor(goal, cv2.COLOR_BGR2RGB)
                goal_mask = self.get_masked(goal) != 0
                goal[~goal_mask, :] = 0  # blank out background pixels

            goal_pos = np.load('../goals/particles/{}.npy'.format(g))[:, :3]
            goals.append([goal, goal_pos])
        return goals

    def _save_buffers(self, save_folder, buf, idx_buf):
        """Write the image buffer and the index-only buffer with torch.save."""
        name = self.cfgs["dataset_name"]
        torch.save(buf, os.path.join(save_folder, f'{name}.buf'))
        torch.save(idx_buf, os.path.join(save_folder, f'{name}_idx.buf'))

    @staticmethod
    def _write_knots_info(save_folder):
        """Collect every saved knots array into ``images/knots_info.json``.

        Files in ``knots/`` are read in sorted-name order and stored under the
        keys "0", "1", ... after reshaping each (N, D) array to (N, 1, D).
        """
        knots_dir = os.path.join(save_folder, 'knots')
        kdict = {}
        for i, name in enumerate(sorted(os.listdir(knots_dir))):
            knot = np.load(os.path.join(knots_dir, name))
            knot = np.reshape(knot, (knot.shape[0], 1, knot.shape[1]))
            kdict[str(i)] = knot.tolist()
        with open(os.path.join(save_folder, 'images', 'knots_info.json'), 'w') as f:
            json.dump(kdict, f)

    def generate(self):
        """Roll out ``cfgs['num_eps']`` random-action episodes and save them.

        For every kept transition:
          - save_data writes the before/after observations;
          - an Experience is appended to two replay buffers, one holding
            images and one holding only observation indexes.
        Transitions whose next observation shows too little cloth are dropped
        and the env is reset. Buffers are checkpointed every
        ``_CHECKPOINT_EVERY`` episodes and saved once more at the end. Then
        ``knots_info.json`` and the observed reward range (``rewards.npy``)
        are written.
        """
        min_reward = np.inf
        max_reward = -np.inf

        goals = self._load_goals()
        save_folder = self.makedirs()
        im_type = self.cfgs['img_type']
        pickplace = self.cfgs['action_type'] == 'uniform'
        buf = []
        idx_buf = []  # same transitions, with observation indexes instead of images
        idx = 0
        for ep in range(self.cfgs['num_eps']):
            goal, goal_pos = random.choice(goals)
            state = self.env.reset(given_goal=goal, given_goal_pos=goal_pos)
            done = False

            if self.cfgs['record']:
                self.env.start_record()

            while not done:
                coords, img, depth = self.get_obs()

                # Cloth (nearly) out of screen: reset and observe again. The
                # state returned by reset must replace the stale one, or the
                # stored transition pairs a pre-reset state with post-reset
                # observations.
                if np.sum(depth > 0) < self._MIN_CLOTH_PIXELS:
                    state = self.env.reset(given_goal=goal, given_goal_pos=goal_pos)
                    coords, img, depth = self.get_obs()

                action = self.get_rand_action(state, img, depth, action_type=self.cfgs['action_type'], debug_idx=ep)
                next_state, reward, done, _ = self.env.step(action, pickplace=pickplace, on_table=self.cfgs['on_table'])
                coords_next, img_next, depth_next = self.get_obs()

                # The action pushed the cloth out of screen: drop this
                # transition and continue from a fresh reset.
                if np.sum(depth_next > 0) < self._MIN_CLOTH_PIXELS:
                    state = self.env.reset(given_goal=goal, given_goal_pos=goal_pos)
                    continue

                self.save_data(idx, state, coords, img, depth, save_folder, beforeact=True)
                self.save_data(idx, next_state, coords_next, img_next, depth_next, save_folder, beforeact=False)

                min_reward = min(min_reward, reward)
                max_reward = max(max_reward, reward)

                buf.append(Experience(state[im_type], state["goal"], action, reward, next_state[im_type], done))
                idx_buf.append(Experience(idx, state["goal"], action, reward, idx, done))

                state = copy.deepcopy(next_state)
                self.env.render(mode='rgb_array')
                idx += 1

            if self.cfgs['record']:
                self.env.end_record(video_path=f'{ep}.gif')

            if ep % self._CHECKPOINT_EVERY == 0:
                self._save_buffers(save_folder, buf, idx_buf)

        self._save_buffers(save_folder, buf, idx_buf)
        self._write_knots_info(save_folder)

        print(f"min reward: {min_reward}, max reward: {max_reward}")
        np.save(os.path.join(save_folder, 'rewards.npy'), [min_reward, max_reward])

    def collect_goals(self, goals_dir='/home/user/cloth_folding/softagent/tshirt_exp/corlbaseline/goals'):
        """Run each fixed debug action from a goal-less reset and save the result.

        For every (name, action) pair:
          - a 3-channel uint8 image of ``depth * 255`` is written to
            ``{goals_dir}/{name}_depth.png``;
          - the raw pyflex particle positions are written to
            ``{goals_dir}/particles/{name}.npy``.
        """
        # NOTE(review): 'debug_bimnual_2'/'debug_bimnual_3' look like typos of
        # 'debug_bimanual_*'; kept because existing goal files may use them.
        names = ['debug_bimanual_1', 'debug_bimnual_2', 'debug_bimnual_3', 'debug_bimanual_4']

        for name, action in zip(names, self._DEBUG_ACTIONS):
            self.env.reset(given_goal=None, given_goal_pos=None)
            self.get_obs()  # render once before acting, as before
            self.env.step(list(action), pickplace=self.cfgs['action_type'] == 'uniform', on_table=self.cfgs['on_table'])
            _, _, depth_next = self.get_obs()

            depth_u8 = (depth_next * 255).astype(np.uint8)
            nobs = np.dstack([depth_u8, depth_u8, depth_u8])
            nobs = cv2.resize(nobs, (200, 200))

            pos = pyflex.get_positions().reshape(-1, 4)

            cv2.imwrite(os.path.join(goals_dir, f'{name}_depth.png'), nobs)
            np.save(os.path.join(goals_dir, 'particles', f'{name}.npy'), pos)


def main():
    """Parse command-line options, assemble the generator config and build the dataset."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_eps', type=int, default=60)
    parser.add_argument('--cloth_type', default='towel', choices=['towel', 'tshirt'])
    # Only the action types DatasetGenerator.get_rand_action can produce.
    parser.add_argument('--action_type', default='perp', choices=['perp', 'debug'])
    parser.add_argument('--img_type', default='depth', choices=['color', 'depth', 'desc'])
    parser.add_argument('--actmaskprob', type=float, default=0.0)
    parser.add_argument('--cemaskratio', type=float, default=0.0)
    parser.add_argument('--truecratio', type=float, default=0.0)
    parser.add_argument('--cam_height', type=float, default=0.45)
    parser.add_argument('--horizon', type=int, default=5)
    parser.add_argument('--overwrite', action='store_true')
    parser.add_argument('--headless', action='store_true')
    parser.add_argument('--record', action='store_true')
    args = parser.parse_args()

    num_eps = args.num_eps
    cloth_type = args.cloth_type
    action_type = args.action_type
    actmaskprob = args.actmaskprob
    horizon = args.horizon
    truecratio = args.truecratio

    edgethresh = 10 if cloth_type == 'tshirt' else 5
    on_table = cloth_type != 'towel'
    cfgs = {
        'debug': args.overwrite,  # overwrite old folder if True
        'num_eps': num_eps,
        'img_type': args.img_type,
        'cloth_type': cloth_type,
        'action_type': action_type,
        'edgethresh': edgethresh,
        'actmaskprob': actmaskprob,
        'cemaskratio': args.cemaskratio,
        'tshirtmap_path': None,
        'on_table': on_table,
        'horizon': horizon,
        'state_dim': 200*200*3,
        'dataset_folder': '',  # relative to the working directory unless set per host below
        'action_dim': 7,
        'dataset_name': f'biman_{cloth_type}_act{action_type}_n{num_eps}_horiz{horizon}_ontable{1 if on_table else 0}{"_actmaskprob" + str(actmaskprob) + "truecratio" + str(truecratio) + "_cornerbias" if actmaskprob > 0 else ""}',
        'goals': [],
        'headless': args.headless,
        'cam_height': args.cam_height,
        'record': args.record,
        'truecratio': truecratio,
    }

    if cloth_type == 'towel':
        cfgs['goals'] = [f'towel_train_{i}' for i in range(32)]
    elif cloth_type == 'tshirt':
        cfgs['goals'] = ['tsf', 'two_tsf', 'three_step', 'tstsf', 'partial_horz_1', 'partial_vert_1', 'partial_vert_2', 'partial_diag_0', 'partial_diag_1', 'partial_diag_2']

    # Machine-specific output locations.
    hostname = socket.gethostname()
    if hostname == 'TheCat':
        cfgs['dataset_folder'] = '/media/ExtraDrive1/temp_data'
        cfgs['tshirtmap_path'] = '/home/user/cloth_folding/softagent/softgym/PyFlexRobotics/data/tshirt_edgemap_id.txt'
    elif hostname == 'u109974':
        cfgs['dataset_folder'] = '/data/fabric_data/replay_data'
        cfgs['tshirtmap_path'] = '/home/exx/projects/softagent/softgym/PyFlexRobotics/data/tshirt_edgemap_id.txt'

    dataset = DatasetGenerator(cfgs)
    dataset.generate()
    # To regenerate the fixed debug goal images instead, call dataset.collect_goals().


if __name__ == '__main__':
    main()
