from asyncio import FastChildWatcher
from turtle import Turtle, numinput, shape
import numpy as np
import torch
import argparse
import os
import math
import gym
import sys
import random
import time
import json
# import dmc2gym
import copy

import utils
from logger import Logger
from video import VideoRecorder

from crsfd_sac import crsfdSacAgent
from torchvision import transforms
from datetime import datetime
from peginhole.peginhole import SinglePeginHole
from collections import deque
def parse_args():
    """Define the command-line interface for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding every option read by the training,
        evaluation and demo-collection entry points in this file.
    """
    parser = argparse.ArgumentParser()
    opt = parser.add_argument

    def switch(name):
        # Boolean option: False unless the flag appears on the command line.
        opt(name, default=False, action='store_true')

    # --- environment ---
    opt('--domain_name', default='peginhole')
    opt('--task_name', default='4')
    opt('--pre_transform_image_size', type=int, default=84)
    opt('--image_size', type=int, default=84)
    opt('--action_repeat', type=int, default=1)
    opt('--frame_stack', type=int, default=3)
    # --- replay buffer ---
    opt('--replay_buffer_capacity', type=int, default=200000)
    # --- training ---
    opt('--agent', type=str, default='crsfd_sac')
    opt('--init_steps', type=int, default=1000)
    opt('--num_train_steps', type=int, default=200000)
    opt('--batch_size', type=int, default=256)
    opt('--hidden_dim', type=int, default=1024)
    # --- evaluation ---
    opt('--eval_freq', type=int, default=10000)
    opt('--num_eval_episodes', type=int, default=10)
    # --- critic ---
    opt('--critic_lr', type=float, default=1e-3)
    opt('--critic_beta', type=float, default=0.9)
    opt('--critic_tau', type=float, default=0.01)  # try 0.05 or 0.1
    # try 1 together with critic_tau=0.01
    opt('--critic_target_update_freq', type=int, default=2)
    # --- actor ---
    opt('--actor_lr', type=float, default=1e-3)
    opt('--actor_beta', type=float, default=0.9)
    opt('--actor_log_std_min', type=float, default=-10)
    opt('--actor_log_std_max', type=float, default=2)
    opt('--actor_update_freq', type=int, default=1)
    # --- encoder ---
    opt('--encoder_type', type=str, default='identity')
    opt('--encoder_feature_dim', type=int, default=50)
    opt('--encoder_lr', type=float, default=1e-3)
    opt('--encoder_tau', type=float, default=0.05)
    opt('--num_layers', type=int, default=4)
    opt('--num_filters', type=int, default=32)
    opt('--crsfd_latent_dim', type=int, default=128)
    # --- SAC ---
    opt('--discount', type=float, default=0.99)
    opt('--init_temperature', type=float, default=0.1)
    opt('--alpha_lr', type=float, default=1e-4)
    opt('--alpha_beta', type=float, default=0.5)
    # --- misc ---
    opt('--seed', type=int, default=-1)
    opt('--work_dir', type=str, default='.')
    switch('--save_tb')
    switch('--save_buffer')
    switch('--save_video')
    switch('--save_model')
    switch('--detach_encoder')

    opt('--log_interval', type=int, default=1000)
    opt('--evaluate_dir', type=str, default=None)
    opt('--evaluate_step', type=int, default=40000)
    opt('--camera_id', type=int, default=5)
    opt('--hole_height', type=float, default=0.8)

    opt('--pretrain', type=str, default=None)
    opt('--pretrain_step', type=int, default=40000)

    opt('--load_buffer', type=str, default=None)
    opt('--per_step_update', type=int, default=2)
    opt('--pretrain_num', type=int, default=0)

    opt('--demo_dir', type=str, default=None)
    opt('--observation_mode', type=str, default='one')
    opt('--action_mode', type=str, default='free')
    opt('--horizon', type=int, default=50)
    opt('--control_freq', type=int, default=10)
    switch('--large_hole')
    switch('--sparse')
    switch('--img_demo')
    switch('--stochastic')
    opt('--demo_obs', type=str, default='standerd')
    switch('--collect_imgbuffer')
    opt('--demo_ratio', type=float, default=0.25)
    switch('--demo_decay')
    switch('--intrinsic_r')

    opt('--reward_scale', type=float, default=10)
    switch('--sparse3')

    opt('--insert_depth', type=float, default=0.03)
    switch('--self_imitation')

    return parser.parse_args()


def evaluate(env, agent, video, num_episodes, L, step, args):
    """Run deterministic evaluation episodes and log their statistics.

    Args:
        env: evaluation environment; must provide ``reset``, ``step`` and
            ``get_peg_pos_to_hole``.
        agent: agent providing ``select_action`` (and ``sample_action``).
        video: recorder; only the first episode is recorded
            (``video.init(enabled=(i == 0))``).
        num_episodes: number of evaluation episodes to run.
        L: logger providing ``log`` and ``dump``.
        step: global training step, used as the log key and video filename.
        args: parsed CLI namespace (reads ``encoder_type`` and ``image_size``).

    Returns:
        Mean episode reward over the evaluation episodes.
    """

    def run_eval_loop(sample_stochastically=True):
        start_time = time.time()
        prefix = 'stochastic_' if sample_stochastically else ''
        # Collected per call so that separate evaluation passes (e.g. stochastic
        # vs. deterministic) never mix their reward statistics.
        ep_rewards = []
        part_insert = 0
        total_insert = 0
        for i in range(num_episodes):
            obs = env.reset()
            video.init(enabled=(i == 0))
            done = False
            episode_reward = 0
            while not done:
                # center crop image
                if args.encoder_type == 'pixel':
                    obs = utils.center_crop_image(obs, args.image_size)
                with utils.eval_mode(agent):
                    if sample_stochastically:
                        action = agent.sample_action(obs)
                    else:
                        action = agent.select_action(obs)
                obs, reward, done, _ = env.step(action)
                video.record(env)
                episode_reward += reward

            video.save('%d.mp4' % step)
            L.log('eval/' + prefix + 'episode_reward', episode_reward, step)
            ep_rewards.append(episode_reward)
            # Final peg z-offset relative to the hole, queried once. The two
            # thresholds count fully (< -0.045) and partially (< -0.002)
            # inserted episodes.
            peg_depth = env.get_peg_pos_to_hole()[2]
            if peg_depth < -0.045:
                total_insert += 1
            if peg_depth < -0.002:
                part_insert += 1

        L.log('eval/' + prefix + 'eval_time', time.time() - start_time, step)
        mean_ep_reward = np.mean(ep_rewards)
        best_ep_reward = np.max(ep_rewards)
        L.log('eval/' + prefix + 'mean_episode_reward', mean_ep_reward, step)
        L.log('eval/' + prefix + 'best_episode_reward', best_ep_reward, step)
        L.log('eval/' + prefix + 'total_insert', total_insert, step)
        L.log('eval/' + prefix + 'part_insert', part_insert, step)
        return mean_ep_reward

    score = run_eval_loop(sample_stochastically=False)
    L.dump(step)

    return score


def make_agent(obs_shape, action_shape, args, device, cloning=True):
    """Construct the agent named by ``args.agent``.

    Args:
        obs_shape: shape of the observations the agent consumes.
        action_shape: shape of the action space.
        args: parsed CLI namespace; all agent hyper-parameters are read from it.
        device: torch device forwarded to the agent.
        cloning: forwarded unchanged to the agent constructor.

    Returns:
        A ``crsfdSacAgent`` instance.

    Raises:
        ValueError: if ``args.agent`` is not a supported agent name.
    """
    if args.agent != 'crsfd_sac':
        # Fail fast: returning None here would only surface later as an
        # obscure AttributeError at the first use of the agent.
        raise ValueError('agent is not supported: %s' % args.agent)
    return crsfdSacAgent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
        log_interval=args.log_interval,
        detach_encoder=args.detach_encoder,
        crsfd_latent_dim=args.crsfd_latent_dim,
        reward_scale=args.reward_scale,
        cloning=cloning
    )

def main():
    """Train the agent by behaviour cloning from the replay buffer.

    Steps:
      1. Parse the CLI and seed everything, drawing a random seed when
         ``--seed`` is -1.
      2. Build identically configured training and evaluation envs.
      3. Create the experiment directory and write ``args.json``.
      4. Optionally load a pretrained model and/or a demonstration buffer.
      5. For ``args.num_train_steps`` steps, evaluate every ``args.eval_freq``
         steps. Once ``step >= args.init_steps``, call ``agent.cloning``
         ``args.per_step_update`` times per step.

    The training env is never stepped here; updates come from the buffer only.
    """
    args = parse_args()
    if args.seed == -1:
        # The drawn seed is stored in args and therefore recorded in args.json.
        args.seed = int(np.random.randint(1, 1000000))
    utils.set_seed_everywhere(args.seed)

    env_kwargs = dict(
        observation_mode=args.observation_mode,
        action_mode=args.action_mode,
        peg_class=args.task_name,
        horizon=args.horizon,
        large_hole=args.large_hole,
        hole_height=args.hole_height,
        control_freq=args.control_freq,
        sparse=args.sparse,
        sparse3=args.sparse3,
        depth=args.insert_depth,
        has_offscreen_renderer=False,
    )
    # A fresh robots list per env, so the two instances never share a mutable argument.
    env = SinglePeginHole(robots=["Panda"], **env_kwargs)
    env_eval = SinglePeginHole(robots=["Panda"], **env_kwargs)

    env._max_episode_steps = args.horizon
    env_eval._max_episode_steps = args.horizon

    # stack several consecutive frames together
    if args.encoder_type == 'pixel' or args.observation_mode == 'ee':
        env = utils.FrameStack(env, k=args.frame_stack)
        env_eval = utils.FrameStack(env_eval, k=args.frame_stack)

    # make directory: <folder><domain>/<domain>-<task>_<YYYY-MM-DD_HH_MM_SS>
    env_name = args.domain_name + '-' + args.task_name
    exp_name = env_name + '_' + datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
    # NOTE: the output root is pinned; the per-mode roots ('./tmp_demo/',
    # './tmp_ee/', './tmp_pixel/') are currently not used.
    folder = './tmp518/'
    work_dir = folder + args.domain_name
    args.work_dir = work_dir + '/' + exp_name
    # makedirs also creates missing parents (e.g. ./tmp518/), which os.mkdir cannot.
    os.makedirs(args.work_dir, exist_ok=True)
    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))
    video = VideoRecorder(video_dir if args.save_video else None)
    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    # action, state
    action_shape = env.action_space.shape
    if args.encoder_type == 'pixel':
        obs_shape = (3 * args.frame_stack, args.image_size, args.image_size)
        pre_aug_obs_shape = (3 * args.frame_stack, args.pre_transform_image_size, args.pre_transform_image_size)
    else:
        obs_shape = env.observation_space.shape
        pre_aug_obs_shape = obs_shape
    print(pre_aug_obs_shape)
    print(action_shape)

    # replaybuffer, agent, logger
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    replay_buffer = utils.ReplayBuffer(
        obs_shape=pre_aug_obs_shape,
        action_shape=action_shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device,
        image_size=args.image_size,
        demo_ratio=args.demo_ratio,
        demo_decay=args.demo_decay,
        self_imitate_num=20000 if args.self_imitation else 0,
    )
    agent = make_agent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        args=args,
        device=device,
        cloning=False
    )
    L = Logger(args.work_dir, use_tb=args.save_tb)

    # pretrain or load buffer
    if args.pretrain is not None:
        agent.load(args.pretrain, args.pretrain_step)
    if args.load_buffer is not None:
        replay_buffer.load_peg2(args.load_buffer)
        for _ in range(args.pretrain_num):
            agent.update(replay_buffer, L, -1)
    if args.collect_imgbuffer:
        # np.zeros takes the shape as ONE tuple; uint8 matches the image dtype
        # used for pixel demos and keeps this ~0.85 GB instead of ~6.8 GB (float64).
        img_buffer = np.zeros((40000, 3, 84, 84), dtype=np.uint8)

    # start training
    best_score = 5  # a model is only saved once its eval score exceeds this
    episode = 0

    for step in range(args.num_train_steps):
        # evaluate agent periodically
        if step % args.eval_freq == 0:
            L.log('eval/episode', episode, step)
            score = evaluate(env_eval, agent, video, args.num_eval_episodes, L, step, args)
            if args.save_model and score > best_score:
                agent.save(model_dir, step)
                best_score = score
            if args.save_buffer:
                replay_buffer.save(buffer_dir)

        # run training update
        if step >= args.init_steps:
            for _ in range(args.per_step_update):
                agent.cloning(replay_buffer, L, step)

  
#########################################################################
#########################################################################
def evaluate_peg(sample_stochastically=False, num_episodes=1,step=40000):
    """Replay a saved model for several episodes and save one video of them all.

    Uses ``<--evaluate_dir>/model`` for the checkpoint (loaded at ``step``)
    and ``<--evaluate_dir>/video`` for the output. For each episode it prints
    the final peg z-offset from ``env.get_peg_pos_to_hole()``, followed by
    ``"<step>,<episode_reward>"``. The video is saved as
    ``<step>_<YYYY-MM-DD_HH_MM>_<task_name>_<large_hole>.mp4``.

    Args:
        sample_stochastically: use ``agent.sample_action`` instead of
            ``agent.select_action``.
        num_episodes: number of episodes to roll out.
        step: checkpoint step passed to ``agent.load``.
    """
    args = parse_args()
    video_dir = os.path.join(args.evaluate_dir, 'video')
    model_dir = os.path.join(args.evaluate_dir, 'model')
    video = VideoRecorder(video_dir, camera_id=args.camera_id, fps=args.control_freq)
    # NOTE(review): unlike main(), this env is built without sparse/sparse3/depth
    # and is not wrapped in utils.FrameStack when observation_mode == 'ee';
    # confirm this matches the configuration the model was trained with.
    env = SinglePeginHole(observation_mode = args.observation_mode, action_mode = args.action_mode, robots=["Panda"], peg_class=args.task_name, has_offscreen_renderer=True, horizon=args.horizon, large_hole=args.large_hole, hole_height=args.hole_height, control_freq=args.control_freq)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    action_shape = env.action_space.shape
    obs_shape = env.observation_space.shape

    agent = make_agent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        args=args,
        device=device
    )

    agent.load(model_dir, step)
    # One recording spans every episode; it is saved once after the loop.
    video.init(enabled=True)
    for _ in range(num_episodes):
        obs = env.reset()
        done = False
        episode_reward = 0
        while not done:
            with utils.eval_mode(agent):
                if sample_stochastically:
                    action = agent.sample_action(obs)
                else:
                    action = agent.select_action(obs)
            obs, reward, done, _ = env.step(action)
            if done:
                print(env.get_peg_pos_to_hole()[2])
            video.record(env)
            episode_reward += reward
        print('{},{}'.format(step, episode_reward))
    # strftime is robust, unlike slicing str(datetime.now()): that string drops
    # its '.ffffff' suffix when microsecond == 0, which broke the old slices.
    timenow = datetime.now().strftime('%Y-%m-%d_%H_%M')
    video.save('{}_{}_{}_{}.mp4'.format(step, timenow, args.task_name, args.large_hole))

def collect_demo2(collet_num=20, model_step=400001, threshold=-0.05, sparse=False,sparse3=False,early_stop=True, demo_obs='standerd', sample_stochastically=True):
    """Roll out a saved agent and store ``collet_num`` transitions as an ``.npz`` demo.

    The model is loaded from ``<--demo_dir>/model`` at ``model_step``. An
    episode ends, and the env is reset, when the peg z-offset
    ``env.get_peg_pos_to_hole()[2]`` drops below -0.045 or the env reports ``done``.

    Args:
        collet_num: total number of transitions to record, across episodes.
        model_step: checkpoint step passed to ``agent.load``.
        threshold: currently unused (the success depth is hard-coded to -0.045).
        sparse: replace the env reward with ``env.sparse_reward()``.
        sparse3: forwarded to the env. When ``sparse`` is off, the reward is
            replaced with ``env.sparse_reward3(depth=args.insert_depth)``.
            When neither is set, a step below -0.045 gets reward 50.
        early_stop: currently unused.
        demo_obs: what to store as the observation.
            ``'pixel'``: ``args.frame_stack`` stacked rendered frames (uint8).
            ``'eef'``: ``env.get_eef_info()`` (stored as 14 float32 values).
            Anything else: the env observation itself.
        sample_stochastically: use ``agent.sample_action`` instead of
            ``agent.select_action``.

    The arrays (obs, action, reward, next_obs, done) are saved positionally to
    ``./demo_03cm/<task_name>/<stochastic|deterministic>_<sparse|dense>_<large|small>_<demo_obs>.npz``.
    """
    args = parse_args()
    model_dir = os.path.join(args.demo_dir, 'model')
    img_demo = demo_obs == 'pixel'

    # load env, agent
    env = SinglePeginHole(observation_mode = args.observation_mode, action_mode = args.action_mode, robots=["Panda"], peg_class=args.task_name, has_offscreen_renderer=img_demo, horizon=args.horizon, large_hole=args.large_hole, control_freq=args.control_freq, depth=args.insert_depth,sparse3=sparse3)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    action_shape = env.action_space.shape
    obs_shape = env.observation_space.shape
    agent = make_agent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        args=args,
        device=device
    )
    agent.load(model_dir, model_step)

    # create demo buffer
    if img_demo:
        size = args.image_size
        obs_shape = [3 * args.frame_stack, size, size]
        # Exactly frame_stack frames, so the concatenation matches obs_shape.
        frames = deque([], maxlen=args.frame_stack)
    elif demo_obs == 'eef':
        obs_shape = [14]
    obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8
    print(obs_shape)
    print(action_shape)
    obs_array = np.zeros((collet_num, *obs_shape), dtype=obs_dtype)
    action_array = np.zeros((collet_num, *action_shape))
    r_array = np.zeros((collet_num, 1))
    nexto_array = np.zeros((collet_num, *obs_shape), dtype=obs_dtype)
    done_array = np.zeros((collet_num, 1))

    def render_frame():
        # Transpose (2, 0, 1): HWC -> CHW, assuming env.render returns HWC (TODO confirm).
        img = env.render(width=size, height=size, camera_id=args.camera_id)
        return np.transpose(img, (2, 0, 1)).copy()

    def episode_start_obs():
        # Stored observation for the first step of an episode. Call right after
        # env.reset() so pixel/eef demos never carry the previous episode's state.
        if img_demo:
            first = render_frame()
            for _ in range(args.frame_stack):
                frames.append(first)
            return np.concatenate(list(frames), axis=0)
        if demo_obs == 'eef':
            return env.get_eef_info()
        return None

    # start collection
    obs = env.reset()
    aux_obs = episode_start_obs()
    step = 0
    while step < collet_num:
        with utils.eval_mode(agent):
            if sample_stochastically:
                action = agent.sample_action(obs)
            else:
                action = agent.select_action(obs)
        next_obs, reward, done, _ = env.step(action)
        peg_depth = env.get_peg_pos_to_hole()[2]
        if sparse:
            reward = env.sparse_reward()
        elif sparse3:
            reward = env.sparse_reward3(depth=args.insert_depth)
        elif peg_depth < -0.045:
            reward = 50

        if img_demo:
            frames.append(render_frame())
            next_aux_obs = np.concatenate(list(frames), axis=0)
            stored_obs, stored_next = aux_obs, next_aux_obs
        elif demo_obs == 'eef':
            next_aux_obs = env.get_eef_info()
            stored_obs, stored_next = aux_obs, next_aux_obs
        else:
            next_aux_obs = None
            stored_obs, stored_next = obs, next_obs
        obs_array[step] = stored_obs
        action_array[step] = action
        r_array[step] = reward
        nexto_array[step] = stored_next
        done_array[step] = done
        print('reward:{},done:{}'.format(reward, done))

        obs = next_obs
        aux_obs = next_aux_obs
        step += 1
        if peg_depth < -0.045 or done:
            text = 'success' if peg_depth < -0.045 else 'time_limit'
            print(text + str(peg_depth))
            obs = env.reset()
            aux_obs = episode_start_obs()

    # save demo buffer; create the directory that is actually written to
    save_dir = os.path.join('./demo_03cm', args.task_name)
    os.makedirs(save_dir, exist_ok=True)
    # Filename labels reflect this call's parameters, not the raw CLI values.
    s = 'stochastic' if sample_stochastically else 'deterministic'
    a = 'sparse' if sparse else 'dense'
    b = 'large' if args.large_hole else 'small'
    c = demo_obs
    np.savez(os.path.join(save_dir, '{}_{}_{}_{}.npz'.format(s, a, b, c)), obs_array, action_array, r_array, nexto_array, done_array)


if __name__ == '__main__':
    # Choose the entry point from the CLI. --evaluate_dir takes priority over
    # --demo_dir; with neither given, run training.
    torch.multiprocessing.set_start_method('spawn')
    args = parse_args()
    if args.evaluate_dir is None and args.demo_dir is None:
        main()
    elif args.evaluate_dir is None:
        collect_demo2(
            collet_num=1000,
            model_step=args.evaluate_step,
            demo_obs=args.demo_obs,
            sample_stochastically=args.stochastic,
            sparse=args.sparse,
            sparse3=args.sparse3,
        )
    else:
        evaluate_peg(num_episodes=5, step=args.evaluate_step)
        