import struct
import gymnasium as gym
import gymnasium_robotics
import glob
import os
import numpy as np
from minari import StepDataCallback, DataCollector
from gym.wrappers import TimeLimit

#print(gym.envs.registry.keys())

DIR_PATH = {'complete': '*microwave_kettle_switch_slide', 'partial': '*', 'mixed': '*'}
GOAL_TASKS = {'complete': ['microwave', 'kettle', 'light switch', 'slide cabinet'], 'partial': ['microwave', 'kettle', 'light switch', 'slide cabinet'], 'mixed': ['microwave', 'kettle', 'bottom burner', 'light switch']}
DESCRIPTIONS = {
    'complete': 'The complete dataset includes demonstrations of all 4 target subtasks being completed, in order.',
    'partial': 'The mixed dataset contains various subtasks being performed, but the 4 target subtasks are never completed in sequence together.',
    'mixed': 'The partial dataset includes other tasks being performed, but there are subtrajectories where the 4 target subtasks are completed in sequence.'
}

def parse_mjl_logs(read_filename, skipamount):
    print(f'READ FILE {read_filename}')
    with open(read_filename, mode='rb') as file:
        fileContent = file.read()
    headers = struct.unpack('iiiiiii', fileContent[:28])
    nq = headers[0]
    nv = headers[1]
    nu = headers[2]
    nmocap = headers[3]
    nsensordata = headers[4]
    nuserdata = headers[5]
    name_len = headers[6]
    name = struct.unpack(str(name_len) + 's', fileContent[28:28+name_len])[0]
    rem_size = len(fileContent[28 + name_len:])
    num_floats = int(rem_size/4)
    dat = np.asarray(struct.unpack(str(num_floats) + 'f', fileContent[28+name_len:]))
    recsz = 1 + nq + nv + nu + 7*nmocap + nsensordata + nuserdata
    if rem_size % recsz != 0:
        print("ERROR")
    else:
        dat = np.reshape(dat, (int(len(dat)/recsz), recsz))
        dat = dat.T

    time = dat[0,:][::skipamount] - 0*dat[0, 0]
    qpos = dat[1:nq + 1, :].T[::skipamount, :]
    qvel = dat[nq+1:nq+nv+1,:].T[::skipamount, :]
    ctrl = dat[nq+nv+1:nq+nv+nu+1,:].T[::skipamount,:]
    mocap_pos = dat[nq+nv+nu+1:nq+nv+nu+3*nmocap+1,:].T[::skipamount, :]
    mocap_quat = dat[nq+nv+nu+3*nmocap+1:nq+nv+nu+7*nmocap+1,:].T[::skipamount, :]
    sensordata = dat[nq+nv+nu+7*nmocap+1:nq+nv+nu+7*nmocap+nsensordata+1,:].T[::skipamount,:]
    userdata = dat[nq+nv+nu+7*nmocap+nsensordata+1:,:].T[::skipamount,:]

    data = dict(nq=nq,
               nv=nv,
               nu=nu,
               nmocap=nmocap,
               nsensordata=nsensordata,
               name=name,
               time=time,
               qpos=qpos,
               qvel=qvel,
               ctrl=ctrl,
               mocap_pos=mocap_pos,
               mocap_quat=mocap_quat,
               sensordata=sensordata,
               userdata=userdata,
               logName = read_filename
               )
    return data


class KitchenStepDataCallback(StepDataCallback):
    def __call__(self, env, obs, info, action=None, rew=None, terminated=None, truncated=None):
        step_data = super().__call__(env, obs, info, action, rew, terminated, truncated)
        
        step_data['info'] = {"ctrl": obs['observation'][:9]}
        
        return step_data


def generate_datasets():
    for dst, dir in DIR_PATH.items():
        print(f'CREATING DATASET: {dst}')
        max_episode_steps = 450 if (dst == 'mixed' or dst == 'partial') else 280
        
        base_env = gym.make('FrankaKitchen-v1', remove_task_when_completed=False, terminate_on_tasks_completed=False, tasks_to_complete=GOAL_TASKS[dst], max_episode_steps=max_episode_steps)
        env = TimeLimit(base_env, max_episode_steps=max_episode_steps)
        data_collector = DataCollector(env, step_data_callback=KitchenStepDataCallback, record_infos=True)
        
        act_mid = np.zeros(9)
        act_rng = np.ones(9) * 2
        dataset = []
        pattern = f'/home/mrq/project/diffusion_policy-main/data/kitchen/kitchen_demos_multitask/{dir}'
        max_steps_episode = 0
        demo_subdirs = sorted(glob.glob(pattern))
        
        for subdir in demo_subdirs:
            demo_files = glob.glob(os.path.join(subdir, '*.mjl'))
            print(f'Demo: {subdir}')
            
            for demo_file in demo_files:
                episode_steps = 0
                data = parse_mjl_logs(demo_file, 40)
                obs, _ = data_collector.reset()
                
                if data['ctrl'].shape[0] > max_steps_episode:
                    max_steps_episode = data['ctrl'].shape[0]
                
                for i_frame in range(data['ctrl'].shape[0] - 1):
                    frame_skip = base_env.robot_env.frame_skip
                    timestep = base_env.robot_env.model.opt.timestep

                    ctrl = (data['ctrl'][i_frame] - obs['observation'][:9]) / (frame_skip * timestep)
                    act = (ctrl - act_mid) / act_rng
                    
                    noise = np.random.normal(0, 0.1, size=act.shape)
                    act = act + noise
                    act = np.clip(act, -0.999, 0.999)
                    
                    obs, reward, terminated, truncated, env_info = data_collector.step(act)
                    episode_steps += 1

        data_collector.close()      
        print(f'MAX EPISODE STEPS: {max_steps_episode}')
                

if __name__ == '__main__':
    generate_datasets()