from softgym.envs.tshirt_base_env import TshirtBaselineFold
from softgym.envs.corl_baseline import GCFold
import numpy as np
import pyflex
import softgym.envs.tshirt_descriptor as td
import torch
import os
from PIL import Image
import json
import random
import cv2
from collections import namedtuple
import copy
import matplotlib.pyplot as plt
import socket
import argparse
from edge_masker import EdgeMasker
from utils import get_harris

# One recorded transition: observation, goal, action, reward, next observation, done flag.
Experience = namedtuple('Experience', ['obs', 'goal', 'act', 'rew', 'nobs', 'done'])

class DatasetGenerator(object):
    def __init__(self, cfgs):
        """Build the simulation environment and the helpers used for sampling.

        Args:
            cfgs: configuration mapping. This constructor reads the keys
                'cloth_type', 'img_type', 'horizon', 'cam_height', 'headless',
                'tshirtmap_path' and 'edgethresh'; the mapping is also kept on
                ``self.cfgs`` for the other methods.

        Raises:
            NotImplementedError: if ``cfgs['cloth_type']`` is 'tshirt'.
            ValueError: if ``cfgs['cloth_type']`` is any other unknown value.
        """
        self.cfgs = cfgs

        cloth_type = cfgs['cloth_type']
        if cloth_type == 'towel':
            self.env = GCFold(use_depth=cfgs['img_type'] == 'depth',
                    use_cached_states=False,
                    horizon=cfgs['horizon'],
                    use_desc=False,
                    cam_height=cfgs['cam_height'],
                    action_repeat=1,
                    headless=cfgs['headless'])
        elif cloth_type == 'tshirt':
            raise NotImplementedError("cloth_type 'tshirt' is not supported by this generator yet")
        else:
            # Without this branch self.env would stay unset and the failure
            # would surface later as an unrelated AttributeError.
            raise ValueError("unknown cloth_type: %r (expected 'towel' or 'tshirt')" % (cloth_type,))

        # Particle indices of the cloth corners, looked up once on a fresh reset.
        self.corners = self.get_corner_particles()
        self.em = EdgeMasker(self.env, cloth_type, tshirtmap_path=cfgs['tshirtmap_path'], edgethresh=cfgs['edgethresh'])

    def makedirs(self):
        """Create the dataset directory tree and return its root path.

        The root is ``cfgs['dataset_folder'] / cfgs['dataset_name']`` and it
        contains the sub-folders 'images', 'coords', 'rendered_images',
        'knots' and 'actions'. When ``cfgs['debug']`` is truthy any existing
        root is deleted first so the dataset is regenerated from scratch.

        Returns:
            The path of the dataset root folder.
        """
        # Local import keeps this method self-contained.
        import shutil

        save_folder = os.path.join(self.cfgs['dataset_folder'], self.cfgs['dataset_name'])
        if self.cfgs['debug']:
            # Delete in-process rather than through a shell command: paths with
            # spaces or shell metacharacters are handled correctly. A missing
            # folder is not an error.
            shutil.rmtree(save_folder, ignore_errors=True)

        # exist_ok lets a pre-existing (possibly incomplete) root be completed
        # instead of being left without the sub-folders the writers rely on.
        for sub in ('images', 'coords', 'rendered_images', 'knots', 'actions'):
            os.makedirs(os.path.join(save_folder, sub), exist_ok=True)
        return save_folder

    def get_masked(self, img):
        """Return a binary mask of the sufficiently saturated pixels of ``img``.

        Just used for masking goals, otherwise we use depth. The image is
        converted with ``COLOR_BGR2HSV`` and thresholded so that only the second
        channel is constrained (>= 15); the bounds on the other two channels
        span 0-255. A 3x3 morphological closing is applied to the result.
        """
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        lower_bound = np.array([0., 15., 0.])
        upper_bound = np.array([255, 255., 255.])
        thresholded = cv2.inRange(hsv, lower_bound, upper_bound)
        closing_kernel = np.ones((3, 3), np.uint8)
        return cv2.morphologyEx(thresholded, cv2.MORPH_CLOSE, closing_kernel)

    def get_rgbd(self):
        """Render the scene and return ``(img, depth, mask)``.

        ``depth`` is the fourth channel of the pyflex sensor render with its
        row order reversed, ``mask`` marks where that depth is positive, and
        ``img`` comes from ``self.env.get_image`` at the camera resolution.
        """
        sensor = np.array(pyflex.render_sensor())
        cam_h, cam_w = self.env.camera_height, self.env.camera_width
        sensor = sensor.reshape(cam_h, cam_w, 4)
        # Reverse the rows and keep only the last (depth) channel.
        depth = sensor[::-1, :, 3]
        img = self.env.get_image(cam_h, cam_w)
        return img, depth, depth > 0

    def get_corner_particles(self):
        """Reset the environment and find the particles nearest four anchor pixels.

        Particle pixel coordinates are taken from ``td.particle_uv_pos``, their
        two columns are swapped and they are rescaled by 199/719 (presumably
        from a 720px render to the 200px working image -- TODO confirm).

        Returns:
            A 4-tuple of particle indices closest to (25, 25), (25, 175),
            (175, 175) and (175, 25), in that order.
        """
        self.env.reset(given_goal=None, given_goal_pos=None)
        pix = td.particle_uv_pos(self.env.camera_params, None)
        pix[:, [1, 0]] = pix[:, [0, 1]]
        pix = (pix / 719) * 199

        anchors = ([25, 25], [25, 175], [175, 175], [175, 25])
        c1, c2, c3, c4 = [
            np.linalg.norm((pix - np.array(anchor)), axis=1).argmin()
            for anchor in anchors
        ]
        return c1, c2, c3, c4

    def get_particle_uv(self, idx):
        """Return the two rescaled pixel coordinates of particle ``idx``.

        Coordinates come from ``td.particle_uv_pos`` with their two columns
        swapped and are scaled by 199/719.
        """
        pix = td.particle_uv_pos(self.env.camera_params, None)
        # Swap the two coordinate columns in place.
        pix[:, [1, 0]] = pix[:, [0, 1]]
        scaled = (pix / 719) * 199
        first, second = scaled[idx]
        return first, second

    def get_true_corner_mask(self, clothmask, r=4):
        """Return a 200x200 mask of discs around the corner particles, limited to the cloth.

        Args:
            clothmask: array multiplied element-wise with the 200x200 disc
                mask; pixels where it is zero are cleared.
            r: disc radius in pixels.

        Returns:
            A float array with 1 inside any disc of radius ``r`` centred on a
            corner particle (as located by ``get_particle_uv``) and scaled by
            ``clothmask``; 0 elsewhere.
        """
        true_corners = np.zeros((200,200))
        h,w = true_corners.shape
        # Build the pixel grid from integer ranges and shift it afterwards.
        # Slicing ogrid with float bounds (-a:h-a) sizes the grid by
        # ceil(stop - start), which rounding can push to h + 1 and then the
        # boolean index below no longer matches the image shape.
        rows,cols = np.ogrid[:h, :w]
        for c in self.corners:
            b,a = self.get_particle_uv(c)  # b indexes columns, a indexes rows
            y = rows - a
            x = cols - b
            cmask = x*x + y*y <= r*r
            true_corners[cmask] = 1

        true_corners = true_corners * clothmask
        return true_corners

    def get_rand_action(self, state, img, depth, edgemasks, action_type='corlbaseline'):
        """Sample a random action whose pick point lies on a randomly chosen mask.

        With probability ``cfgs['actmaskprob']`` the pick point is drawn from a
        specialised mask: corner masks when ``cfgs['use_corner']`` is set
        (true corners with probability ``cfgs['truecratio']``, else Harris
        corners, else the cloth mask), otherwise the edge masks in
        ``edgemasks`` (cloth-edge mask with probability
        ``cfgs['cemaskratio']``, else the foreground-edge mask). In all other
        cases the pick point is drawn from the cloth mask ``depth > 0``.

        Args:
            state: unused.
            img: unused.
            depth: depth image; positive values mark the cloth.
            edgemasks: 3-tuple whose last two entries are the foreground-edge
                and cloth-edge masks; only read when ``use_corner`` is false
                and the specialised mask is selected.
            action_type: 'uniform' or 'corlbaseline'.

        Returns:
            For 'uniform': ``np.array([pick_idx, place_idx])`` with a place
            index drawn uniformly over a 200x200-resized render.
            For 'corlbaseline': ``np.array([idx, dist, u, v])`` where ``idx``
            is one of 8 direction bins derived from the angle between the pick
            point and (100, 100), and ``dist`` is a random integer in [0, 3).

        Raises:
            ValueError: if ``action_type`` is not one of the supported values.
        """
        # Validate up front: an unsupported type used to fall off the end of
        # the function and return None after sampling had already happened.
        if action_type not in ('uniform', 'corlbaseline'):
            raise ValueError("unsupported action_type: %r" % (action_type,))

        # Choose mask for action
        clothmask = depth > 0
        if np.random.uniform() < self.cfgs['actmaskprob']:
            if self.cfgs['use_corner']:
                print("GETTING CORNER")
                harris_corners = get_harris(clothmask.astype(np.float32))
                true_corners = self.get_true_corner_mask(clothmask)
                if np.sum(true_corners) > 2 and np.random.uniform() < self.cfgs['truecratio']:
                    mask = true_corners > 0
                elif np.sum(harris_corners) > 2:
                    mask = harris_corners > 0
                else:
                    mask = clothmask
            else:
                _, fge_mask, ce_mask = edgemasks
                if np.any(ce_mask != 0) and np.random.uniform() < self.cfgs['cemaskratio']: # sample from cloth edge mask
                    mask = ce_mask > 0
                else: # sample from fg edge mask
                    mask = fge_mask > 0
        else: # Cloth mask
            mask = clothmask

        # presumably returns a (rows, cols) pair of index arrays -- verify against td
        pick_idx = td.random_sample_from_masked_image(mask, 1)
        if action_type == 'uniform':
            obs, _, _ = self.get_rgbd()
            obs = cv2.resize(obs, (200, 200))
            n_rows, n_cols = obs.shape[0], obs.shape[1]
            place_idx = np.unravel_index(np.random.choice(n_rows*n_cols, 1, replace=False), (n_rows, n_cols))
            return np.array([pick_idx, place_idx])

        # action_type == 'corlbaseline'
        u,v = pick_idx[0][0], pick_idx[1][0]
        angle = np.rad2deg(np.arctan2(100 - u, 100 - v))
        # Quantise the angle into 8 bins; the modulo wraps negative bins.
        idx = int(np.round((angle / 360.0) * 8))
        idx = ((idx) % 8)
        dist = np.random.randint(3)
        return np.array([idx,dist,u,v])

    def save_data(self, idx, state, coords, img, depth, dataset_path, beforeact=False):
        """Write one observation (RGB image, depth, coords, knots) to disk.

        File names are prefixed with the zero-padded ``idx`` and suffixed with
        'before' or 'after' depending on ``beforeact``. The knots are the
        current ``td.particle_uv_pos`` coordinates with their two columns
        swapped. ``state`` is unused.
        """
        if beforeact:
            save_time = 'before'
        else:
            save_time = 'after'

        knots = td.particle_uv_pos(self.env.camera_params, None)
        knots[:, [1, 0]] = knots[:, [0, 1]]

        rgb_path = os.path.join(dataset_path, 'images', '%06d_rgb_%s.png'% (idx, save_time))
        Image.fromarray(img, 'RGB').save(rgb_path)

        # (sub-folder, file name pattern, array) for every .npy output.
        outputs = (
            ('rendered_images', '%06d_depth_%s.npy', depth),
            ('coords', '%06d_coords_%s.npy', coords),
            ('knots', '%06d_knots_%s.npy', knots),
        )
        for folder, pattern, array in outputs:
            np.save(os.path.join(dataset_path, folder, pattern % (idx, save_time)), array)

    def save_action(self, idx, action, dataset_path):
        """Save ``action`` as ``actions/<idx zero-padded to 6>_action.npy`` under ``dataset_path``."""
        filename = '%06d_action.npy' % (idx)
        target = os.path.join(dataset_path, 'actions', filename)
        np.save(target, action)

    def get_obs(self):
        """Collect the current observation.

        Returns:
            ``(coords, img, depth, edgemasks)`` where ``coords`` are the first
            three columns of the pyflex particle positions, ``img`` and
            ``depth`` are resized to 200x200, and ``edgemasks`` is ``None``
            when ``cfgs['use_corner']`` is set, otherwise the result of the
            edge masker computed on the full-resolution render.
        """
        coords = pyflex.get_positions().reshape(-1, 4)[:,:3]
        img, depth, mask = self.get_rgbd()

        edgemasks = None
        if not self.cfgs['use_corner']:
            edgemasks = self.em.get_act_mask(self.env, coords, img, depth, mask)

        target_size = (200, 200)
        img_small = cv2.resize(img, target_size)
        depth_small = cv2.resize(depth, target_size)
        return coords, img_small, depth_small, edgemasks

    def _load_goals(self):
        """Load every configured goal as a ``[goal_image, goal_particle_positions]`` pair."""
        img_type = self.cfgs['img_type']
        goals = []
        for g in self.cfgs['goals']:
            if g is None:
                # A None entry means "no goal": passed through unchanged.
                goals.append([None, None])
                continue

            if img_type == 'color':
                path = f"../goals/{g}.png"
            elif img_type == 'depth':
                path = f"../goals/{g}_depth.png"
            elif img_type == 'desc':
                path = f"../goals/{g}_desc.png"
            else:
                raise ValueError("unsupported img_type: %r" % (img_type,))

            goal = cv2.imread(path)
            if goal is None:
                # cv2.imread signals an unreadable file by returning None.
                raise FileNotFoundError(f"could not read goal image: {path}")
            if img_type == 'color':
                goal = cv2.cvtColor(goal, cv2.COLOR_BGR2RGB)
                goal_mask = self.get_masked(goal) != 0
                goal[~goal_mask, :] = 0  # blank everything outside the mask

            goal_pos = np.load('../goals/particles/{}.npy'.format(g))[:,:3]
            goals.append([goal, goal_pos])
        return goals

    def _write_knots_info(self, save_folder):
        """Collect all saved knot arrays into ``images/knots_info.json``."""
        knots_dir = os.path.join(save_folder, 'knots')
        kdict = {}
        # Keys are the positions of the files in sorted name order.
        for i, name in enumerate(sorted(os.listdir(knots_dir))):
            knot = np.load(os.path.join(knots_dir, name))
            knot = np.reshape(knot,(knot.shape[0],1,knot.shape[1]))
            kdict[str(i)] = knot.tolist()
        with open(os.path.join(save_folder,'images','knots_info.json'),'w') as f:
            json.dump(kdict,f)

    def generate(self):
        """Roll out random actions and write the resulting dataset to disk.

        For each of ``cfgs['num_eps']`` episodes a goal is picked at random,
        the environment is reset and random actions are executed until the
        environment reports ``done``. Every kept transition is saved as
        before/after observations plus the action, and appended to two
        buffers (full observations and index-only) that are re-saved with
        ``torch.save`` after each episode. Transitions where fewer than 250
        depth pixels are positive after the action are discarded and the
        environment is reset. Finally the knot files are summarised into
        ``images/knots_info.json`` and the reward range into ``rewards.npy``.

        Raises:
            ValueError: if ``cfgs['img_type']`` is unsupported while a goal
                image has to be loaded.
            FileNotFoundError: if a goal image cannot be read.
        """
        # Initial bounds kept as in the original data format: the minimum
        # starts at 0 and the maximum at -10000.
        min_reward = 0
        max_reward = -10000

        goals = self._load_goals()

        save_folder = self.makedirs()
        buf = []
        idx_buf = [] # buffer with only indexes, no images
        idx = 0
        im_type = self.cfgs['img_type']
        action_type = self.cfgs['action_type']
        for ep in range(self.cfgs['num_eps']):
            goal, goal_pos = random.choice(goals)
            state = self.env.reset(given_goal=goal, given_goal_pos=goal_pos)
            done = False

            while not done:
                # Get edgemask of observation
                coords, img, depth, edgemasks = self.get_obs()

                # check if out of screen
                if np.sum(depth > 0) < 250:
                    # Keep `state` in sync with the freshly reset scene so the
                    # buffered observation matches the images saved below.
                    state = self.env.reset(given_goal=goal, given_goal_pos=goal_pos)
                    coords, img, depth, edgemasks = self.get_obs()

                action = self.get_rand_action(state, img, depth, edgemasks, action_type=action_type)
                self.save_action(idx, action, save_folder)

                next_state, reward, done, _ = self.env.step(action, pickplace=action_type == 'uniform', on_table=self.cfgs['on_table'])
                coords_next, img_next, depth_next, _ = self.get_obs()

                # check if out of screen
                if np.sum(depth_next > 0) < 250:
                    # Discard this transition; idx is not advanced, so the
                    # action file written above is overwritten next time.
                    state = self.env.reset(given_goal=goal, given_goal_pos=goal_pos)
                    continue

                self.save_data(idx, state, coords, img, depth, save_folder, beforeact=True)
                self.save_data(idx, next_state, coords_next, img_next, depth_next, save_folder, beforeact=False)

                min_reward = min(min_reward, reward)
                max_reward = max(max_reward, reward)

                buf.append(Experience(state[im_type], state["goal"], action, reward, next_state[im_type], done))
                idx_buf.append(Experience(idx, state["goal"], action, reward, idx, done))

                state = copy.deepcopy(next_state)
                self.env.render(mode='rgb_array')
                idx += 1

            # Re-save after every episode so an interrupted run keeps its data.
            torch.save(buf, os.path.join(save_folder,f'{self.cfgs["dataset_name"]}.buf'))
            torch.save(idx_buf, os.path.join(save_folder,f'{self.cfgs["dataset_name"]}_idx.buf'))

        # create knots info
        self._write_knots_info(save_folder)

        print(f"min reward: {min_reward}, max reward: {max_reward}")
        np.save(os.path.join(save_folder, 'rewards.npy'), [min_reward, max_reward])

def parse_args(argv=None):
    """Parse the command-line options of the dataset generator.

    Args:
        argv: list of argument strings; ``None`` makes argparse read
            ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_eps', dest='num_eps', type=int, default=60)
    parser.add_argument('--cloth_type', dest='cloth_type', default='towel')
    parser.add_argument('--action_type', dest='action_type', default='corlbaseline')
    parser.add_argument('--img_type', dest='img_type', default='depth')
    parser.add_argument('--actmaskprob', dest='actmaskprob', type=float, default=0.0)
    parser.add_argument('--cemaskratio', dest='cemaskratio', type=float, default=0.0)
    parser.add_argument('--cam_height', dest='cam_height', type=float, default=0.65)
    parser.add_argument('--horizon', dest='horizon', type=int, default=5)
    parser.add_argument('--overwrite', dest='overwrite', action='store_true')
    parser.add_argument('--headless', dest='headless', action='store_true')
    parser.add_argument('--use_corner', dest='use_corner', action='store_true')
    # type=float is required: without it a value given on the command line
    # stays a string and the later `uniform() < truecratio` comparison fails.
    parser.add_argument('--truecratio', dest='truecratio', type=float, default=0.5)
    return parser.parse_args(argv)


def build_cfgs(args):
    """Turn parsed arguments into the configuration dict used by DatasetGenerator.

    Args:
        args: namespace as returned by ``parse_args``.

    Returns:
        A dict with the generator settings, including the derived
        'edgethresh', 'on_table', 'dataset_name' and 'goals' entries.
    """
    cloth_type = args.cloth_type
    edgethresh = 10 if cloth_type == 'tshirt' else 5
    on_table = cloth_type != 'towel'

    # The dataset name records whichever sampling ratio is actually in use.
    if args.use_corner:
        ratio_tag = f'cratio{args.truecratio}'
    else:
        ratio_tag = f'ceratio{args.cemaskratio}'
    dataset_name = (
        f'sg_{cloth_type}_act{args.action_type}_n{args.num_eps}_edgethresh{edgethresh}'
        f'_horiz{args.horizon}_corner{args.use_corner}_actmask{args.actmaskprob}'
        f'_{ratio_tag}_ontable{1 if on_table else 0}_1'
    )

    if cloth_type == 'towel':
        goals = [f'towel_train_{i}' for i in range(32)]
    elif cloth_type == 'tshirt':
        goals = ['tsf', 'two_tsf', 'three_step', 'tstsf', 'partial_horz_1', 'partial_vert_1', 'partial_vert_2', 'partial_diag_0', 'partial_diag_1', 'partial_diag_2']
    else:
        goals = []

    return {
        'debug': args.overwrite, # overwrite old folder if True
        'num_eps': args.num_eps,
        'img_type': args.img_type,
        'cloth_type': cloth_type,
        'action_type': args.action_type,
        'edgethresh': edgethresh,
        'actmaskprob': args.actmaskprob,
        'cemaskratio': args.cemaskratio,
        'tshirtmap_path': '/path/to/tshirtmap.txt',
        'on_table': on_table,
        'horizon': args.horizon,
        'state_dim': 200*200*3,
        'dataset_folder': 'path/to/data',
        'action_dim': 4,
        'dataset_name': dataset_name,
        'goals': goals,
        'headless': args.headless,
        'cam_height': args.cam_height,
        'use_corner': args.use_corner,
        'truecratio': args.truecratio,
    }


if __name__ == '__main__':
    dataset = DatasetGenerator(build_cfgs(parse_args()))
    dataset.generate()
