import gym
from args import get_args
from model import ActorCritic
import pickle
import os
import json
import pathlib
import torch



def write_to_json(ips, filename):
    """Serialize ``ips`` as 4-space-indented JSON into ``filename``.

    The file is created or truncated and written as UTF-8.

    Args:
        ips: any JSON-serializable object (e.g. ``vars(args)``).
        filename: destination path.
    """
    with open(filename, 'w', encoding='utf-8') as handle:
        json.dump(obj=ips, fp=handle, indent=4)

# Environments created through the project's ``new_environment`` builder.
env_group_2 = [
    'Point2D-FourRoom-v1',
    'Reacher-v2',
    'SawyerDoor-v0',
]
# Goal-conditioned gym Fetch environments created with ``gym.make``.
env_group_1 = [
    'FetchPush-v1',
    'FetchSlide-v1',
    'FetchPickAndPlace-v1',
]

import sys
class Logger(object):
    """Tee everything written to it into a terminal stream and a log file.

    Meant to be assigned to ``sys.stdout`` so that printed output is both shown
    and persisted.

    Args:
        filename: log file path; opened in append mode, so it is created if
            missing and extended otherwise.
        stream: stream that also receives every message. The default is bound
            to ``sys.stdout`` as it was when this module was imported.
    """

    def __init__(self, filename='default.log', stream=sys.stdout):
        self.terminal = stream
        self.log = open(filename, 'a')
        # Announce the log file on whatever sys.stdout currently is.
        print(filename, flush=True)

    def write(self, message):
        """Write ``message`` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both the terminal stream and the log file.

        Previously this returned True without flushing, so buffered log output
        could be lost on a crash. It still returns True for compatibility.
        """
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()
        return True

    def close(self):
        """Flush and close the log file; the terminal stream is left open.

        Safe to call more than once.
        """
        if not self.log.closed:
            self.log.flush()
            self.log.close()

    def __getattr__(self, name):
        # Only reached for attributes not defined above. Delegate them
        # (isatty, encoding, fileno, ...) to the terminal stream so code that
        # inspects sys.stdout keeps working. Guard against recursion when
        # 'terminal' itself is not set yet (e.g. during copy/unpickle).
        if name == 'terminal':
            raise AttributeError(name)
        return getattr(self.terminal, name)

def _build_env(args):
    """Create the environment named by ``args.env_name`` and read its sizes.

    Returns:
        ``(env, state_dim, goal_dim, num_actions, splid)``. ``splid`` is the
        training function imported from ``splid`` for ``env_group_1`` or from
        ``splid_new_environment`` for ``env_group_2``.

    Raises:
        ValueError: if ``args.env_name`` is in neither environment group.
    """
    if args.env_name in env_group_1:
        env = gym.make(args.env_name)
        state_dim = env.observation_space.spaces['observation'].shape[0]
        goal_dim = env.observation_space.spaces['desired_goal'].shape[0]  # extended state
        from splid import splid
        num_actions = env.action_space.shape[0]
        return env, state_dim, goal_dim, num_actions, splid

    if args.env_name in env_group_2:
        from splid_new_environment import splid
        # NOTE(review): most of these imports are unused here. They are kept in
        # case importing them has side effects (e.g. registering environments);
        # confirm before removing.
        from new_environment.common.vec_env import VecEnv
        from new_environment.common.env_util import get_env_type, build_env, get_game_envs
        from new_environment.common.parse_args import common_arg_parser, parse_unknown_args
        from new_environment.common import logger
        from new_environment.common.parse_args import get_learn_function_defaults, parse_cmdline_kwargs
        from new_environment.util import init_logger

        _game_envs = get_game_envs()
        env = build_env(args, _game_envs)
        env.reset()
        # Take one random step to observe the shapes of the observation dict.
        obs, _, _, info = env.step(env.action_space.sample())
        # NOTE(review): sum(shape) - 1 presumably strips a leading size-1
        # (vectorised-env) dimension; confirm.
        state_dim = sum(obs['observation'].shape) - 1
        goal_dim = sum(obs['desired_goal'].shape) - 1
        num_actions = sum(env.action_space.shape)
        print("shape", state_dim, goal_dim, num_actions)
        return env, state_dim, goal_dim, num_actions, splid

    # raise("...") with a plain string is itself a TypeError; raise a real exception.
    raise ValueError(f"Unknown Environment! {args.env_name!r}")


def main():
    """Run ``args.num_parallel_run`` SPLID trainings and pickle each result.

    Outputs go to ``../data/<args.name>/``:
    - ``data{i}.pickle``: whatever ``splid`` returns for run ``i``;
    - ``log{i}.log``: stdout captured during run ``i``;
    - ``setting.json``: the arguments, rewritten before each run.
    """
    args = get_args()
    prefix = "../data/" + args.name
    pathlib.Path(prefix).mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    env, state_dim, goal_dim, num_actions, splid = _build_env(args)

    testnetwork = False
    # Keep the real stdout so each run's Logger tees only to the terminal and
    # its own file. Wrapping the current sys.stdout would chain Loggers and
    # copy run i's output into every earlier log file.
    original_stdout = sys.stdout
    # joint train
    for i in range(int(args.num_parallel_run)):
        network = ActorCritic(goal_dim + state_dim, num_actions, layer_norm=args.layer_norm)
        network.to(device)
        # NOTE(review): cumulative, so run i uses seed + (i + 1) * rank; with
        # rank == 0 every run shares one seed. Kept as-is; confirm intent.
        args.seed += int(args.rank)
        args.factor = float(args.factor)
        path = prefix + f'/data{i}.pickle'
        run_logger = Logger(stream=original_stdout, filename=prefix + f'/log{i}.log')
        sys.stdout = run_logger
        try:
            write_to_json(vars(args), prefix + '/setting.json')
            storage = splid(args, network, env, goal_dim, state_dim, device, testnetwork)
            # Open only after training succeeds, so a failed run leaves no
            # empty pickle and leaks no handle.
            with open(path, 'wb') as f:
                pickle.dump(storage, f)
        finally:
            sys.stdout = original_stdout
            run_logger.log.close()
        
if __name__ == "__main__":  # run training only when executed as a script, not on import
    main()



