from ast import arg
from env import Unlock, LiftEnv, CrashEnv
from agents import SAC, BC
from because import Because
import numpy as np
from utils.utils import load_config
import copy
import argparse
import time, sys, os
import torch
from utils.utils import CUDA, CPU
import datetime
# from stable_baselines3.common.vec_env import SubprocVecEnv
from utils.wrapper import SubprocVecEnv
np.set_printoptions(linewidth=np.inf)

# collect data scripts
def collect_data(save_path, time_stamp, agent, env, env_name, offline_data, num_episodes=1000):
    """Roll out ``agent`` in ``env`` and record every transition.

    ``env`` is stepped with the batch of actions returned by
    ``agent.select_action_parallel(env, state, True)`` until the running sum
    of ``done`` flags reaches ``num_episodes``.

    Args:
        save_path, env_name, offline_data, time_stamp: only combined into a
            file name; nothing is written to disk by this function.
        agent: object exposing ``select_action_parallel(env, state, flag)``.
        env: object exposing ``reset()`` and ``step(action)`` returning
            ``(next_state, reward, done, info)``; ``done`` must support
            ``.sum()``.
        num_episodes: number of finished episodes after which to stop.

    Returns:
        dict with keys ``state``, ``action``, ``next_state``, ``reward`` and
        ``done``, each mapping to a list with one entry per environment step.
    """
    # The target file name is assembled but the buffer is only returned.
    file_name = os.path.join(save_path, env_name + '_' + offline_data + '_' + time_stamp)

    fields = ('state', 'action', 'next_state', 'reward', 'done')
    data_buffer = {field: [] for field in fields}

    finished = 0
    obs = env.reset()
    while finished < num_episodes:
        act = agent.select_action_parallel(env, obs, True)
        next_obs, rew, terminated, _info = env.step(act)

        # One entry per field, in the same order as `fields`.
        for field, value in zip(fields, (obs, act, next_obs, rew, terminated)):
            data_buffer[field].append(value)

        finished += terminated.sum()
        # Copy so later steps cannot alias the stored observation.
        obs = copy.deepcopy(next_obs)

    return data_buffer
    
# Command-line interface of the training / evaluation script.
parser = argparse.ArgumentParser()
# Evaluation setting forwarded to the environments as `test_mode`.
parser.add_argument('--mode', type=str, required=True, help='IID / OOD')
# The help text lists the agents the main block can actually instantiate.
parser.add_argument('--agent', type=str, default='Because', help='Because / SAC / BC')
parser.add_argument('--env', type=str, default='unlock', help='unlock / crash / lift')
# Only read when --agent is Because; defaults to None when omitted.
parser.add_argument('--Because_model', type=str, help='causal / full / mopo / gnn')
# Sub-directory of ./log that receives the test-reward files.
parser.add_argument('--log_dir', type=str, required=True)
parser.add_argument('--offline_data', type=str, required=True, help='random / medium / expert')
parser.add_argument('--dataset', type=str, default='./', help='dataset path')
parser.add_argument('--collect_data', action='store_true', help='whether to collect data')

def make_env_unlock(test_mode='IID', stage='train'):
    """Return a new ``Unlock`` environment built with the given ``test_mode`` and ``stage``."""
    return Unlock(test_mode=test_mode, stage=stage)

def make_env_crash(test_mode='IID', stage='train'):
    """Return a new ``CrashEnv`` environment built with the given ``test_mode`` and ``stage``."""
    return CrashEnv(test_mode=test_mode, stage=stage)

def make_env_lift(test_mode='IID', stage='train'):
    """Return a new ``LiftEnv`` environment built with the given ``test_mode`` and ``stage``."""
    return LiftEnv(test_mode=test_mode, stage=stage)

if __name__ == '__main__':
    # Script entry point: build the environments for the selected task, then
    # for each trial create an agent, optionally pre-load an offline dataset,
    # train it, and periodically evaluate it on the vectorised test
    # environment. Test rewards are saved under ./log/<log_dir>.
    #
    # make envs should be put into main loop to avoid multi-processing error
    args = parser.parse_args()

    # environment parameters
    # Each branch builds one plain env (used to read dimensions and for the
    # online rollout below), loads the matching config and creates the
    # vectorised test env.
    # NOTE(review): `env_train` is created when --collect_data is set but is
    # not referenced again in this block; the rollout below steps `env`.
    if args.env == 'unlock':
        env = Unlock(test_mode=args.mode)
        config = load_config(config_path="config/unlock_config.yaml")
        agent_config = config[args.agent]
        env_params = {
            'action_dim': env.action_dim,
            'state_dim': env.state_dim,
            'goal_dim': 0,
            'room_size': env.room_size,
            'move_dim': env.move_dim,
            'pick_key_dim': env.pick_key_dim,
            'open_door_dim': env.open_door_dim,
            'max_key_num': env.max_key_num,
        }
        episode = 200
        test_episode = 20
        num_envs = agent_config['planner']['num_envs']
        if args.collect_data:
            env_train = SubprocVecEnv([lambda: make_env_unlock(args.mode, 'train') for _ in range(num_envs)], start_method='spawn')

        env_test = SubprocVecEnv([lambda: make_env_unlock(args.mode, 'test') for _ in range(num_envs)], start_method='spawn')

        print(env.random_action())
    elif args.env == 'lift':
        env = LiftEnv(test_mode=args.mode)
        print(env)
        config = load_config(config_path="config/lift_config.yaml")
        agent_config = config[args.agent]
        env_params = {
            'action_dim': env.action_dim,
            'state_dim': env.state_dim,
            'goal_dim': 0,
        }
        print(env_params)
        episode = 200
        test_episode = 20
        try:
            num_envs = agent_config['planner']['num_envs']
        except (KeyError, TypeError):
            # The agent config has no usable planner/num_envs entry:
            # fall back to 4 parallel environments.
            num_envs = 4
        if args.collect_data:
            env_train = SubprocVecEnv([lambda: make_env_lift(args.mode, 'train') for _ in range(num_envs)], start_method='spawn')
        env_test = SubprocVecEnv([lambda: make_env_lift(args.mode, 'test') for _ in range(num_envs)], start_method='spawn')

    elif args.env == 'crash':
        env = CrashEnv(test_mode=args.mode, use_render=False, save_gif=False)
        config = load_config(config_path="config/crash_config.yaml")

        agent_config = config[args.agent]
        env_params = {
            'action_dim': env.action_dim,
            'state_dim': env.state_dim,
            'goal_dim': 0,
            'agent_state_dim': env.agent_state_dim,
            'agent_action_dim': env.agent_action_dim,
            'n_agents': env.n_agents,
            'map_scale': env.map_scale,
            'collision_threshold': env.collision_threshold,
            'collision_dim': env.collision_dim,
        }
        episode = 200
        test_episode = 20
        num_envs = agent_config['planner']['num_envs']
        if args.collect_data:
            env_train = SubprocVecEnv([lambda: make_env_crash(args.mode, 'train') for _ in range(num_envs)], start_method='spawn')
        env_test = SubprocVecEnv([lambda: make_env_crash(args.mode, 'test') for _ in range(num_envs)], start_method='spawn')
        print(env.random_action())
    else:
        raise ValueError('Wrong environment name')
    env_params['env_name'] = args.env
    agent_config['env_params'] = env_params
    save_path = os.path.join('./log', args.log_dir)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_path, exist_ok=True)
    print(agent_config)

    render = False
    test_only = False
    trails = 10            # number of independent trials
    test_interval = 10     # evaluate every `test_interval` training episodes
    save_interval = 100000
    minimal_test = 0       # no evaluation until the episode index exceeds this
    all_q_values = []

    for t_i in range(trails):
        # create agent
        if args.agent == 'Because':
            agent_config['Because_model'] = args.Because_model
            agent = Because(agent_config)
        elif args.agent == 'BC':
            agent = CUDA(BC(agent_config))
        elif args.agent == 'SAC':
            agent = SAC(agent_config)
        else:
            # Fail with a clear message instead of a NameError on `agent`.
            raise ValueError('Wrong agent name')
        if test_only:
            agent.model_id = 1000
            agent.load_model()

        if not args.collect_data:  # pre-load dataset in offline learning
            if args.offline_data not in ('expert', 'medium', 'random'):
                # Otherwise `data` / `label` would be left undefined below.
                raise ValueError('Wrong offline_data name: {!r}'.format(args.offline_data))

            if args.env in ('unlock', 'crash'):
                # Resolves to data/<env>/<quality>_data.pt and _label.pt.
                data = torch.load('data/{}/{}_data.pt'.format(args.env, args.offline_data))
                label = torch.load('data/{}/{}_label.pt'.format(args.env, args.offline_data))

            elif args.env == 'lift':
                data_raw = np.load(
                    'data/robosuite/lift_{}_1000.npy'.format(args.offline_data),
                    allow_pickle=True,
                )

                # Model input is [obs, action]; the target is the state delta.
                data, label = [], []
                for traj in data_raw:
                    data.append(np.concatenate([traj['obs'], traj['acts']], axis=1))
                    label.append(traj['obs_next'] - traj['obs'])
                data = np.concatenate(data, axis=0)
                label = np.concatenate(label, axis=0)
                print(data.shape, label.shape)
                data = CUDA(torch.from_numpy(data).float())
                label = CUDA(torch.from_numpy(label).float())

            if agent.name in ['Because']:
                agent.planner.model.data = data
                agent.planner.model.label = label

            elif agent.name in ['BC']:
                if args.env in ['unlock', 'crash']:
                    agent.load_data(data, label)
                elif args.env == 'lift':
                    agent.load_data_from_path(args.dataset)

        save_gif_count = 0
        test_reward = []
        train_reward = []
        for e_i in range(episode):
            if not test_only:
                # reset the training environment
                state = env.reset()
                if args.collect_data:
                    # Online rollout of one episode on the single (non-vectorised) env.
                    done = False
                    one_train_reward = 0
                    while not done:
                        action = agent.select_action(env, state, False)
                        next_state, reward, done, info = env.step(action)
                        one_train_reward += reward
                        agent.store_transition([state, action, next_state])
                        state = copy.deepcopy(next_state)

                if agent.name in ['PPO']:  # on-policy methods
                    agent.train(state)
                elif agent.name in ['Because']:  # mbrl
                    agent.train()
                    print('test loss: ', agent.planner.model.best_test_loss)
                elif agent.name in ['BC']:
                    agent.update()
                    print('test loss: ', agent.train_loss)

                # save model
                if (e_i+1) % save_interval == 0:
                    agent.model_id = e_i + 1
                    agent.save_model()

            if (e_i+1) % test_interval == 0 and e_i > minimal_test:
                episode_q_values = []
                count_done = 0
                done = np.array([False]*num_envs)
                state = env_test.reset()
                total_reward = 0
                # Step the vectorised test env until at least `test_episode`
                # episodes have finished across all sub-environments.
                while count_done < test_episode:
                    # NOTE(review): step_reward is re-created on every step, so
                    # the Q computation below only ever sees one reward.
                    step_reward = []
                    action = agent.select_action_parallel(env_test, state, True)
                    next_state, reward, done, info = env_test.step(action)
                    count_done += done.sum()
                    total_reward += reward.sum()

                    if render:
                        env_test.render()
                        time.sleep(0.05)

                    state = copy.deepcopy(next_state)
                    step_reward.append(reward)

                    # calculate Q value
                    gamma = 0.99
                    q_values = []
                    for q_i in range(len(step_reward)):
                        q = 0
                        # Keep the discount separate from gamma so it decays
                        # geometrically (gamma, gamma^2, gamma^3, ...) instead
                        # of being squared at every step. The first reward is
                        # still weighted by gamma, as before.
                        discount = gamma
                        for q_j in range(q_i, len(step_reward)):
                            q += discount * step_reward[q_j]
                            discount *= gamma
                        q_values.append(q)
                    episode_q_values.append(q_values)
                all_q_values.append(episode_q_values)

                # NOTE(review): count_done can overshoot test_episode when
                # several envs finish on the same step, yet the mean divides by
                # test_episode -- confirm this is the intended normalisation.
                test_reward_mean = total_reward / test_episode
                print(test_reward_mean)
                print('[{}/{}] [{}/{}] Test Reward: {}'.format(t_i, trails, e_i, episode, test_reward_mean))
                test_reward.append(test_reward_mean)
                np.save(os.path.join(save_path, 'test.reward.' + str(t_i) + '.npy'), test_reward)