from softgym.registered_env import env_arg_dict
from softgym.registered_env import SOFTGYM_ENVS
import copy
import argparse
import pyflex
import numpy as np
import torch
import cv2
import os 
from chester import logger
import json
import os.path as osp
from GNS.fabric_vsf.scripts.format_hdf5_softgym import format_data
from GNS.camera_utils import get_matrix_world_to_camera, project_to_image
import math
import imageio
from softgym.envs.bimanual_env import BimanualEnv
from bimanual_flow.collect_data_bimanual import DatasetGenerator
import pickle

class VArgs(object):
    """Attribute-style view of a variant dictionary.

    Every ``key: value`` pair of the mapping passed to the constructor becomes
    an instance attribute, so ``VArgs({'show': False}).show`` is ``False``.
    """

    def __init__(self, vv):
        # Mirror each dictionary entry onto the instance.
        for attr_name, attr_value in vv.items():
            setattr(self, attr_name, attr_value)


def vv_to_args(vv):
    """Wrap the variant dictionary ``vv`` in a ``VArgs`` object and return it."""
    return VArgs(vv)

def build_collect_dataset(vv, env, action_range):
    """Build the ``DatasetGenerator`` that is used to sample random actions.

    Args:
        vv: variant dictionary; only ``vv['corner_biasing']`` is read here.
        env: environment instance handed through to ``DatasetGenerator``.
        action_range: stored unchanged under ``cfgs['action_range']``.

    Returns:
        A ``DatasetGenerator`` constructed from the configuration dictionary
        assembled below.
    """
    # The original author marked these four settings as "not used"; they
    # still feed the dataset name and derived settings, so they are kept.
    num_eps = 10000
    horizon = 2
    cloth_type = 'towel'
    img_type = 'depth'

    # Alternatives listed in the original comment: qnet, pickplace, debug.
    action_type = 'pickplace'

    # Settings derived from the cloth type.
    edgethresh = 10 if cloth_type == 'tshirt' else 5
    on_table = cloth_type != 'towel'

    # Corner biasing switches on both the action mask and corner usage.
    use_corner = bool(vv['corner_biasing'])
    actmaskprob = 0.9 if use_corner else 0

    cemaskratio = 0.5  # ratio of how often to sample cloth edge mask
    truecratio = 0.5

    cfgs = {
        'debug': False,  # overwrite old folder if True
        'num_eps': num_eps,
        'img_type': img_type,
        'cloth_type': cloth_type,
        'action_type': action_type,
        'edgethresh': edgethresh,
        'actmaskprob': actmaskprob,
        'cemaskratio': cemaskratio,
        'tshirtmap_path': None,
        'on_table': on_table,
        'horizon': horizon,
        'state_dim': 200*200*3,
        'dataset_folder': None,
        'action_dim': 7,
        'dataset_name': f'biman_{cloth_type}_act{action_type}_n{num_eps}_h{horizon}_co{int(use_corner)}_am{actmaskprob}_tc{truecratio}_cam0.65',
        'desc_path': False,
        'goals': [None],
        'use_corner': use_corner,
        'truecratio': truecratio,
        'headless': True,
        'action_range': action_range,
    }

    return DatasetGenerator(cfgs, env=env)

def move_picker_out_of_scene():
    """Overwrite two column groups of every pyflex shape state with -1 and step.

    The flat shape-state buffer is viewed as rows of 14 values; columns 0-2
    and 7-9 of every row are set to -1 (presumably the current and previous
    positions of each shape -- confirm against the pyflex shape-state layout).
    The modified states are written back and the simulation advances one step.
    """
    states = pyflex.get_shape_states().reshape((-1, 14))
    for columns in (slice(None, 3), slice(7, 10)):
        states[:, columns] = -1
    pyflex.set_shape_states(states)
    pyflex.step()


def normalize_action(action, imsize, action_range):
    """Convert a raw two-arm pick/place action into normalized pick + move values.

    Args:
        action: sequence ``(pick1, place1, pick2, place2)``; each entry is a
            ``(u, v)`` pair.
        imsize: image side length; pick coordinates ``0 .. imsize - 1`` are
            mapped linearly onto ``[-1, 1]``.
        action_range: ``[low, high]``; each place-minus-pick displacement is
            rescaled as ``(displacement - low) / (high - low)``.

    Returns:
        ``[pick_u1, pick_v1, move_u1, move_v1, pick_u2, pick_v2, move_u2,
        move_v2]`` with the picks and moves normalized as described above.

    Raises:
        AssertionError: if any rescaled displacement lies outside ``[-1, 1]``.
    """
    (pick_u_a, pick_v_a), (place_u_a, place_v_a), \
        (pick_u_b, pick_v_b), (place_u_b, place_v_b) = action

    # Raw displacement of each arm, per image axis.
    deltas = [
        place_u_a - pick_u_a,
        place_v_a - pick_v_a,
        place_u_b - pick_u_b,
        place_v_b - pick_v_b,
    ]

    span = action_range[1] - action_range[0]
    moves = [(delta - action_range[0]) / span for delta in deltas]

    # Every rescaled displacement has to stay within [-1, 1].
    for move in moves:
        assert move >= -1 and move <= 1

    def to_unit(pixel):
        # Map a pixel index in [0, imsize - 1] onto [-1, 1].
        return 2 * pixel / (imsize - 1) - 1

    return [
        to_unit(pick_u_a), to_unit(pick_v_a), moves[0], moves[1],
        to_unit(pick_u_b), to_unit(pick_v_b), moves[2], moves[3],
    ]

def _capture_rgbd(env, args):
    """Grab the current observation and sensor depth as a single RGB-D frame.

    Args:
        env: environment whose ``_get_obs(pure=True)`` supplies the colour
            image (assumed to be height x width x 3 -- confirm in the env).
        args: object providing ``camera_height``, ``camera_width`` and
            ``depth_max``.

    Returns:
        A tuple ``(rgbd, depth_raw)``. ``rgbd`` is the observation stacked
        with a uint8 depth channel; ``depth_raw`` is an unscaled copy of the
        rendered depth channel.
    """
    # Query the env observation first, then the raw sensor: this is the order
    # in which the simulator was driven before this helper was extracted.
    colour = env._get_obs(pure=True)
    sensor = np.array(pyflex.render_sensor()).reshape(
        args.camera_height, args.camera_width, 4)
    # Reverse the row order of the sensor image.
    sensor = sensor[::-1, :, :]
    depth_raw = sensor[:, :, 3].copy()
    # NOTE: scale depth to [0, 255]; values beyond depth_max saturate at 255.
    depth_u8 = np.clip(
        depth_raw[:, :, None] / args.depth_max * 255, a_min=0, a_max=255
    ).astype(np.uint8)
    return np.dstack([colour, depth_u8]), depth_raw


def run_task(vv, log_dir, exp_name):
    """Collect random pick-and-place rollouts in ``BimanualEnv`` and save them.

    For every trajectory id in ``range(vv['range'][0], vv['range'][1])`` the
    environment is reset, ``vv['num_traj_pick']`` random actions are sampled
    and executed, and an RGB-D frame is recorded before the first action and
    after every action. All frames and normalized actions are written with
    ``torch.save`` (and optionally ``pickle``) to ``../data/<exp_name>``
    relative to this file.

    Args:
        vv: variant dictionary. Keys read here: ``depth_max``,
            ``camera_param_height``, ``camera_width``, ``camera_height``,
            ``num_traj_pick``, ``range``, ``show``, ``small_action``,
            ``corner_biasing``, ``single_arm`` and the optional
            ``save_pickle``.
        log_dir: directory handed to the logger; preview strips for the first
            trajectories are also written below it.
        exp_name: experiment name; used by the logger and as the name of the
            output sub-directory.

    Raises:
        AssertionError: if the logger has no directory, or if ``depth_max`` or
            ``camera_param_height`` differ from 0.45.
    """
    # VArgs already exposes every entry of vv as an attribute, so no further
    # copying into args.__dict__ is required.
    args = vv_to_args(vv)

    # Configure logger
    logger.configure(dir=log_dir, exp_name=exp_name)
    logdir = logger.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)

    assert args.depth_max == 0.45, \
        "expected depth_max == 0.45, got {}".format(args.depth_max)
    assert args.camera_param_height == 0.45, \
        "expected camera_param_height == 0.45, got {}".format(args.camera_param_height)

    if vv['small_action']:
        print("using small action!")
        action_range = [0, 23]
    else:  # use large action
        print("using large action!")
        action_range = [0, 45]

    # Dump parameters
    with open(osp.join(logger.get_dir(), 'variant.json'), 'w') as f:
        json.dump(args.__dict__, f, indent=2, sort_keys=True)

    # TODO: rebuild the bimanual env
    env = BimanualEnv(use_depth=True,
                      use_cached_states=False,
                      horizon=args.num_traj_pick,
                      use_desc=False,
                      action_repeat=1,
                      headless=True,
                      render=True,
                      camera_width=args.camera_width,
                      camera_height=args.camera_height,
                      camera_param_height=args.camera_param_height)

    env.reset()

    collect_dataset = build_collect_dataset(vv, env, action_range)

    rgbd_data = []
    action_data = []
    for traj_id in range(args.range[0], args.range[1]):
        print("traj id: ", traj_id)
        env.reset()

        # Frame 0: the scene right after reset (pickers are NOT moved out here).
        rgbd, depth_original = _capture_rgbd(env, args)
        rgbd_traj = [rgbd]
        action_traj = []

        if args.show:
            # Channels are displayed as stored (no RGB <-> BGR swap).
            cv2.imshow("rgb", rgbd[:, :, :3])
            cv2.imshow("depth", rgbd[:, :, 3])
            cv2.imwrite("depth.png", rgbd[:, :, 3])
            cv2.waitKey()
            # Preview mode stops after the very first frame.
            exit()

        for t in range(args.num_traj_pick):
            print("\t pick-place-id: ", t)

            # TODO: change the sampled action to be from bimanual env
            # The sampler receives the unscaled depth of the latest frame.
            _, action_raw = collect_dataset.get_rand_action(
                None, None, depth_original, None, None,
                action_type=collect_dataset.cfgs['action_type'],
                debug_idx=traj_id,
                bias_towards_corner=vv['corner_biasing'],
            )

            # TODO: change the way of normalizing actions
            # camera_width is used as the image size for both image axes.
            action_normalized = normalize_action(action_raw, args.camera_width, action_range)
            if vv['single_arm']:
                # Keep only the first arm's (pick_u, pick_v, move_u, move_v).
                action_normalized = action_normalized[:4]
            action_traj.append(action_normalized)

            env.step(action_raw, pickplace=True,
                     on_table=collect_dataset.cfgs['on_table'],
                     single_arm=vv['single_arm'])
            # Hide the pickers so they do not appear in the recorded frame.
            move_picker_out_of_scene()

            rgbd, depth_original = _capture_rgbd(env, args)
            rgbd_traj.append(rgbd)

            # NOTE(review): never reached while the preview branch above
            # calls exit() -- confirm whether per-step preview is wanted.
            if args.show:
                cv2.imshow("rgb", rgbd[:, :, :3])
                cv2.imshow("depth", rgbd[:, :, 3])
                cv2.waitKey()

        # Save a side-by-side strip of the colour frames for the first ids.
        if traj_id < 20:
            rgb_strip = np.concatenate([frame[:, :, :3] for frame in rgbd_traj], axis=1)
            # NOTE(review): this goes to log_dir, not logger.get_dir();
            # confirm that both always refer to the same directory.
            imageio.imwrite(os.path.join(log_dir, './vsf_traj_{}.png'.format(traj_id)), rgb_strip)

        rgbd_data.append(rgbd_traj)
        action_data.append(action_traj)

    dir_ = os.path.dirname(__file__)
    save_path = os.path.join(dir_, '../data', exp_name)
    # exist_ok makes a separate existence check unnecessary (and race-free).
    os.makedirs(save_path, exist_ok=True)

    torch.save((rgbd_data, action_data), osp.join(save_path, "softgym_traj_{}_{}".format(
        args.range[0], args.range[1]
    )))

    if vv.get('save_pickle', False):
        with open(osp.join(save_path, "softgym_traj_{}_{}_pickle".format(
                args.range[0], args.range[1]
            )), 'wb') as f:
            pickle.dump((rgbd_data, action_data), f)

    # format_data()

