import argparse
import json
import gym
import sys
from gym import spaces
import numpy as np

import torch as th
import importlib
import multiprocessing

from stable_baselines3 import PPO
from pantheonrl.algos.ppo.mappo import MAPPO
from pantheonrl.algos.ppo.instruct_ppo import InstructPPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed

from pantheonrl.common.wrappers import frame_wrap, recorder_wrap
from pantheonrl.common.agents import OnPolicyAgent, StaticPolicyAgent, StaticRNNPolicyAgent
from pantheonrl.common.multiagentenv import DummyEnv

from overcookedgym.overcooked_utils import LAYOUT_LIST, DIVERSE_COORPERATION_STYLE_LIST, DIVERSE_ORDERS_STYLE_LIST, CENTER_POTS_STYLE_LIST, CROSSWAY_STYLE_LIST
from assistivegym.assistive_utils import ENV_NAME_LIST, HUMAN_CONFIG_LIST


# Make the vendored sub-projects importable by name at runtime, e.g. the
# ``importlib.import_module('assistive_gym.envs')`` calls further down.
# NOTE(review): these entries are added *after* the module-level imports
# above, so they only help imports that happen later, at runtime.
sys.path.extend([
    'assistivegym',
    'lb-foraging',
    'overcookedgym/human_aware_rl/overcooked_ai',
])

# Gym ids of the multi-agent environments this trainer can drive.
ENV_LIST = ['OvercookedMultiEnv-v0', 'AssistiveMultiEnv-v0', 'LBFMultiEnv-v0']

# Algorithm names accepted on the command line. NOTE: argparse accepts
# ModularAlgorithm / ADAP / ADAP_MULT / DEFAULT, but generate_ego and
# gen_partner currently reject them with EnvException.
ADAP_TYPES = ['ADAP', 'ADAP_MULT']
EGO_LIST = ['PPO', 'ModularAlgorithm', 'LOAD', *ADAP_TYPES]
PARTNER_LIST = ['PPO', 'DEFAULT', 'FIXED', *ADAP_TYPES]


class EnvException(Exception):
    """Raised when CLI arguments are inconsistent with the chosen environment or agent types."""


def input_check(args):
    """Validate parsed CLI arguments and fill in derived defaults in place.

    * Overcooked requires ``env_config['layout_name']`` and it must be in
      ``LAYOUT_LIST``.
    * ``args.alt_config`` defaults to one empty dict per partner; otherwise it
      must have exactly one entry per partner in ``args.alt``.
    * ``args.ego_config['verbose']`` defaults to 1.
    * ``--tensorboard-log`` and ``--tensorboard-name`` must be given together.

    Raises:
        EnvException: on any of the inconsistencies above.
    """
    if args.env == 'OvercookedMultiEnv-v0':
        if 'layout_name' not in args.env_config:
            raise EnvException(f"layout_name needed for {args.env}")
        layout = args.env_config['layout_name']
        if layout not in LAYOUT_LIST:
            raise EnvException(f"{layout} is not a valid layout")

    partner_count = len(args.alt)
    if args.alt_config is None:
        # A fresh dict per partner: they are later mutated independently.
        args.alt_config = [{} for _ in range(partner_count)]
    elif len(args.alt_config) != partner_count:
        raise EnvException(
            "Number of partners is different from number of configs")

    args.ego_config.setdefault('verbose', 1)

    has_log_dir = args.tensorboard_log is not None
    has_run_name = args.tensorboard_name is not None
    if has_log_dir != has_run_name:
        raise EnvException("Must define log and names for tensorboard")



def make_env(args, seed, human_config=None):
    """Return a zero-argument factory that builds one Assistive Gym env.

    The factory is intended for ``SubprocVecEnv``, which calls it inside each
    worker process.

    Args:
        args: parsed CLI namespace; ``args.env_config["env_name"]`` selects the
            class ``<part before the first '-'>Env`` from ``assistive_gym.envs``.
        seed: seed applied via ``set_random_seed`` in the process that builds
            the env (and, as before, in the calling process).
        human_config: currently unused; kept for interface compatibility.

    Returns:
        A callable that returns the env wrapped in ``Monitor``.
    """
    module = importlib.import_module('assistive_gym.envs')

    def _init():
        # Seed inside the factory: SubprocVecEnv runs it in a separate worker
        # process, so seeding only the parent left each worker's RNGs unseeded
        # and made the per-worker ``seed`` argument ineffective there.
        set_random_seed(seed)
        env_name = args.env_config["env_name"]
        env_class = getattr(module, env_name.split('-')[0] + 'Env')
        env = env_class()
        env.render_mode = 'human'
        return Monitor(env)  # Monitor records per-episode info (ep_info) for logging

    # Also seed the calling process, as the original implementation did.
    set_random_seed(seed)
    return _init

def generate_subprocenv(args, human_config=None, n_envs=None):
    """Build a ``SubprocVecEnv`` with one Assistive Gym env per worker.

    Args:
        args: parsed CLI namespace. ``args.seed`` (when not None) offsets the
            per-worker seeds, so that ``--seed`` also changes env seeding.
        human_config: forwarded unchanged to ``make_env``.
        n_envs: number of workers; defaults to ``multiprocessing.cpu_count()``.

    Returns:
        The ``SubprocVecEnv``.
    """
    if n_envs is None:
        n_envs = multiprocessing.cpu_count()
    # Worker ``rank`` gets ``base_seed + rank``. Rank-only seeds would be
    # identical across runs that differ only in --seed.
    base_seed = args.seed if args.seed is not None else 0
    return SubprocVecEnv([make_env(args, base_seed + rank, human_config)
                          for rank in range(n_envs)])

def generate_assistive_env(args):
    """Instantiate, in the current process, the Assistive Gym env named by the config.

    The class looked up in ``assistive_gym.envs`` is the part of
    ``args.env_config["env_name"]`` before the first '-', suffixed with 'Env'.
    """
    envs_module = importlib.import_module('assistive_gym.envs')
    class_name = args.env_config["env_name"].split('-')[0] + 'Env'
    return getattr(envs_module, class_name)()

def generate_env(args):
    """Create the main multi-agent env and its partner-side dummy env.

    The main env comes from ``gym.make(args.env, **args.env_config)``, and the
    partner env from ``env.getDummyEnv(1)``. When ``args.framestack > 1``, both
    are wrapped with ``frame_wrap``. When ``args.record`` is set, only the main
    env is wrapped with ``recorder_wrap``.

    Returns:
        ``(env, altenv)``
    """
    env = gym.make(args.env, **args.env_config)
    altenv = env.getDummyEnv(1)

    stack = args.framestack
    if stack > 1:
        env, altenv = frame_wrap(env, stack), frame_wrap(altenv, stack)

    if args.record is not None:
        env = recorder_wrap(env)

    return env, altenv

def generate_agent_pair(env, args, seed=None):
    """Build a MAPPO trainer around an ego PPO and, unless multi-task, a partner PPO.

    Args:
        env: training environment. For AssistiveMultiEnv-v0 it may be a
            ``SubprocVecEnv`` (robot/human action spaces read via ``get_attr``),
            a ``DummyEnv`` (multi-task training; the partner is supplied
            elsewhere), or a single env exposing ``action_space_robot`` and
            ``action_space_human``.
        args: parsed CLI namespace; ``args.ego_config`` supplies extra PPO kwargs.
        seed: overrides ``args.seed`` when given.

    Returns:
        A ``MAPPO`` instance. Its ``agent_1`` is the ego and, when
        ``args.multi_task`` is false, its ``agent_2`` is the partner.
    """
    # Copy instead of aliasing: writing env/seed/device/... straight into
    # args.ego_config leaked values between calls (e.g. a stale 'seed').
    kwargs = dict(args.ego_config)
    kwargs['env'] = env
    # Rollout length per update is fixed per environment family.
    kwargs['n_steps'] = 200 if args.env == "AssistiveMultiEnv-v0" else 2048
    kwargs['device'] = args.device
    if seed is not None:
        kwargs['seed'] = seed
    elif args.seed is not None:
        kwargs['seed'] = args.seed

    kwargs['tensorboard_log'] = args.tensorboard_log
    # NOTE: these hard-coded values override anything given via --ego-config.
    kwargs['ent_coef'] = 0.00
    kwargs['learning_rate'] = 1e-4

    ego_kwargs = kwargs.copy()
    partner_kwargs = kwargs.copy()
    if args.env == "AssistiveMultiEnv-v0":
        if isinstance(env, SubprocVecEnv):
            # Each agent gets a vec env of placeholder envs that only carry
            # the shared observation space and that agent's action space.
            obs_space = env.get_attr('observation_space')[0]
            ego_dummy_env = DummyEnv(obs_space, env.get_attr('action_space_robot')[0])
            ego_kwargs['env'] = SubprocVecEnv([lambda: ego_dummy_env for _ in range(env.num_envs)])
            partner_dummy_env = DummyEnv(obs_space, env.get_attr('action_space_human')[0])
            partner_kwargs['env'] = SubprocVecEnv([lambda: partner_dummy_env for _ in range(env.num_envs)])
        elif isinstance(env, DummyEnv):
            # multi-task training
            ego_kwargs['env'] = env
        else:
            ego_kwargs['env'] = DummyEnv(env.observation_space, env.action_space_robot)
            partner_kwargs['env'] = DummyEnv(env.observation_space, env.action_space_human)

    elif args.env == "LBFMultiEnv-v0":
        ego_kwargs['env'] = DummyEnv(env.observation_space, env.action_space)
        partner_kwargs['env'] = DummyEnv(env.observation_space, env.action_space)

    ego = PPO(policy='MlpPolicy', **ego_kwargs)
    # Only MAPPO receives n_updates: number of epochs to update the policy
    # after each single rollout.
    kwargs['n_updates'] = 1
    if not args.multi_task:
        partner = PPO(policy='MlpPolicy', **partner_kwargs)
        return MAPPO(ego, agent_2=partner, policy='MlpPolicy', **kwargs)
    # The partner is provided separately in multi-task training.
    return MAPPO(ego, policy='MlpPolicy', **kwargs)
    

def generate_ego(env, args, ego_id=None, seed=None):
    """Create the ego agent for single-ego training.

    Args:
        env: environment the ego acts in. A LOADed model is attached to it
            through ``DummyVecEnv`` + ``Monitor``.
        args: parsed CLI namespace. ``args.ego`` is 'PPO' (new model) or 'LOAD'
            (restore via ``gen_load``; ``args.ego_config`` must then contain
            'type' and 'location'). For 'PPO', the keys of ``args.ego_config``
            are passed to ``PPO`` as keyword arguments.
        ego_id: unused; kept for interface compatibility.
        seed: overrides ``args.seed`` when given.

    Returns:
        A ``PPO`` model.

    Raises:
        EnvException: if ``args.ego`` is neither 'PPO' nor 'LOAD', or the
            LOAD type is unsupported.
    """
    # Copy so env/device/seed/... are not written back into args.ego_config.
    kwargs = dict(args.ego_config)
    kwargs['env'] = env
    kwargs['device'] = args.device
    if seed is not None:
        kwargs['seed'] = seed
    elif args.seed is not None:
        kwargs['seed'] = args.seed

    kwargs['tensorboard_log'] = args.tensorboard_log

    if args.ego == 'LOAD':
        model = gen_load(kwargs, kwargs['type'], kwargs['location'])
        # wrap env in Monitor and VecEnv wrapper
        vec_env = DummyVecEnv([lambda: Monitor(env)])
        model.set_env(vec_env)
        # NOTE(review): gen_load only accepts 'PPO', so this branch is
        # currently unreachable.
        if kwargs['type'] == 'ModularAlgorithm':
            model.policy.do_init_weights(init_partner=True)
            model.policy.num_partners = len(args.alt)
        return model
    if args.ego == 'PPO':
        return PPO(policy='MlpPolicy', **kwargs)
    raise EnvException("Not a valid policy")


def gen_load(config, policy_type, location):
    """Load a saved model for a FIXED partner or a LOADed ego.

    Only ``policy_type == 'PPO'`` is supported; ``config`` is currently unused.

    Raises:
        EnvException: for any other ``policy_type``.
    """
    if policy_type != 'PPO':
        raise EnvException("Not a valid FIXED/LOAD policy")
    return PPO.load(location)


def gen_fixed(config, policy_type, location):
    """Load a saved model and wrap its policy as a non-learning ``StaticPolicyAgent``."""
    return StaticPolicyAgent(gen_load(config, policy_type, location).policy)

def gen_rnn_fixed(config, policy_type, location):
    """Load a saved model and wrap its policy as a non-learning ``StaticRNNPolicyAgent``."""
    return StaticRNNPolicyAgent(gen_load(config, policy_type, location).policy)


def gen_ppo_partner(args, altenv=None):
    """Create a learning PPO partner for the last entry of ``args.alt``.

    Args:
        args: parsed CLI namespace. ``args.partner_num`` is set to the index of
            the last partner, and ``args.alt_config[args.partner_num]`` is
            updated in place (env/device/seed/verbose) before PPO is built.
        altenv: partner-side env. When omitted, falls back to a module-level
            ``altenv`` (bound by this script's ``__main__`` section), which the
            previous implementation referenced implicitly.

    Returns:
        An ``OnPolicyAgent`` wrapping a new PPO model.

    Raises:
        EnvException: if no ``altenv`` is passed and none exists at module level.
    """
    if altenv is None:
        altenv = globals().get('altenv')
        if altenv is None:
            raise EnvException("gen_ppo_partner needs an altenv; pass it explicitly")
    args.partner_num = len(args.alt) - 1
    config = args.alt_config[args.partner_num]
    # gen_partner performs the same tensorboard naming and config updates, and
    # uses an empty agentarg when tensorboard logging is disabled. That avoids
    # the UnboundLocalError the inline copy raised when tensorboard_log was None.
    return gen_partner('PPO', config, altenv, None, args)

def gen_partner(type, config, altenv, ego, args):
    """Build one partner agent of kind ``type`` from its config dict.

    'FIXED' loads a saved policy (``config['type']`` / ``config['location']``),
    and 'DEFAULT' is rejected. Otherwise ``config`` is updated in place with
    env/device/seed/verbose, and only 'PPO' produces a learning
    ``OnPolicyAgent``. ``ego`` is unused.

    Raises:
        EnvException: for 'DEFAULT' or any unsupported ``type``.
    """
    if type == 'FIXED':
        return gen_fixed(config, config['type'], config['location'])
    if type == 'DEFAULT':
        raise EnvException("No default policy available")

    # Tensorboard run name for this partner, e.g. "<name>_alt_<index>".
    agentarg = {}
    if args.tensorboard_log is not None:
        agentarg = {
            'tensorboard_log': args.tensorboard_log,
            'tb_log_name': args.tensorboard_name + '_alt_' + str(args.partner_num),
        }

    config['env'] = altenv
    config['device'] = args.device
    if args.seed is not None:
        config['seed'] = args.seed
    config['verbose'] = args.verbose_partner

    if type != 'PPO':
        raise EnvException("Not a valid policy")
    return OnPolicyAgent(PPO(policy='MlpPolicy', **config), **agentarg)


def generate_partners(altenv, env, ego, args):
    """Create every partner listed in ``args.alt`` and register it with ``env``.

    ``args.partner_num`` is set to each partner's index before building it,
    because ``gen_partner`` uses it in the tensorboard run name.

    Returns:
        The list of created partner agents, in ``args.alt`` order.
    """
    created = []
    for idx, alt_kind in enumerate(args.alt):
        args.partner_num = idx
        partner = gen_partner(alt_kind, args.alt_config[idx], altenv, ego, args)
        print(f'Partner {idx}: {partner}')
        env.add_partner_agent(partner)
        created.append(partner)
    return created


def _style_table(layout_name):
    """Return the per-style masked-event table for an Overcooked layout, or None."""
    if layout_name == "diverse_coordination":
        return DIVERSE_COORPERATION_STYLE_LIST
    if layout_name == "diverse_orders":
        return DIVERSE_ORDERS_STYLE_LIST
    if layout_name == "center_pots":
        return CENTER_POTS_STYLE_LIST
    if layout_name == "crossway":
        return CROSSWAY_STYLE_LIST
    return None


def _default_save_path(args, env_name, role):
    """Return the preset checkpoint path for ``role`` ('ego' or 'partner')."""
    cfg = args.env_config
    if args.env == "OvercookedMultiEnv-v0":
        if args.style is None:
            # No style to encode ('%d' % None used to raise TypeError here).
            return 'diffusion_human_ai/models/%s/%s-%d' % (
                cfg['layout_name'], role, args.seed)
        return 'diffusion_human_ai/models/%s/%s_%d-%d' % (
            cfg['layout_name'], role, args.style, args.seed)
    if args.env == "AssistiveMultiEnv-v0":
        return 'diffusion_human_ai/models/assistive/%s_%s-%d' % (
            role, cfg["env_name"], args.seed)
    if args.env == "LBFMultiEnv-v0":
        # The tag is a *string* (e.g. '3' for a layout like 'lbf_3-v0'), so it
        # must be formatted with %s; %d raised TypeError for every LBF run.
        tag = cfg["layout_name"].split('_')[-1].split('-')[0]
        return 'diffusion_human_ai/models/lbf_spread/%s_%s-%d' % (
            role, tag, args.seed)
    return 'models/%s-%s-%s-%d' % (env_name, args.ego, role, args.seed)


def preset(args, preset_id):
    """Fill in default logging / checkpoint settings for ``preset_id``.

    Only preset 1 exists. It fills ``tensorboard_log``, ``tensorboard_name``,
    ``ego_save`` and ``alt_save`` when they are None. When ``--style`` is given
    for a style-aware Overcooked layout, it also fills
    ``env_config['masked_events']`` from that layout's style table.
    Values already provided by the user are left untouched.

    Returns:
        ``args`` (mutated in place).

    Raises:
        Exception: for an unknown ``preset_id``.
    """
    if preset_id != 1:
        raise Exception("Invalid preset id")

    cfg = args.env_config
    if 'layout_name' in cfg:
        env_name = "%s-%s" % (args.env, cfg['layout_name'])
    elif 'env_name' in cfg:
        env_name = "%s-%s" % (args.env, cfg['env_name'])
    else:
        env_name = "%s" % (args.env)

    if args.tensorboard_log is None:
        if args.env == "OvercookedMultiEnv-v0":
            args.tensorboard_log = 'logs/%s' % (cfg['layout_name'])
        elif args.env == "LBFMultiEnv-v0":
            args.tensorboard_log = 'logs/%s' % (args.env)
        elif args.env == "AssistiveMultiEnv-v0":
            args.tensorboard_log = 'logs/%s' % (cfg['env_name'])
        else:
            args.tensorboard_log = 'logs/%s' % (env_name)

    if args.tensorboard_name is None:
        if args.env == "OvercookedMultiEnv-v0" and args.style is not None:
            args.tensorboard_name = '%s-style_%d-seed_%d' % (
                env_name, args.style, args.seed)
        else:
            # LBF runs are named by layout; every other env uses env_name.
            run_prefix = cfg['layout_name'] if args.env == "LBFMultiEnv-v0" else env_name
            args.tensorboard_name = '%s-%s%s-%d' % (
                run_prefix, args.ego, args.alt[0], args.seed)

    if args.ego_save is None:
        args.ego_save = _default_save_path(args, env_name, 'ego')
    if args.alt_save is None:
        args.alt_save = _default_save_path(args, env_name, 'partner')

    if "masked_events" not in cfg and args.style is not None:
        # .get: envs without a layout (e.g. Assistive) used to raise KeyError.
        table = _style_table(cfg.get("layout_name"))
        if table is not None:
            cfg["masked_events"] = table[args.style]
            print("Masked events:", cfg["masked_events"])

    return args


def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse``'s ``type=bool`` calls ``bool(str)``, so any non-empty string,
    including ``"False"``, becomes True. This parser accepts the usual
    spellings instead and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _learn_config(args):
    """Build the keyword arguments for ``.learn()`` from the CLI arguments."""
    learn_config = {'total_timesteps': args.total_timesteps}
    if args.tensorboard_log:
        learn_config['tb_log_name'] = args.tensorboard_name
    return learn_config


def _save_partners(partners, alt_save):
    """Save each partner's model: a lone partner to ``alt_save``, several to ``alt_save/<i>``."""
    if len(partners) == 1:
        targets = [(partners[0], alt_save)]
    else:
        targets = [(partner, f"{alt_save}/{i}") for i, partner in enumerate(partners)]
    for partner, path in targets:
        try:
            partner.model.save(path)
        except AttributeError:
            # Presumably partners without a ``model`` (e.g. FIXED) — skip them.
            print("FIXED or DEFAULT partners are not saved")


def _train_and_save(env, ego, partners, args):
    """Train the ego against the embedded partners, then write the record and checkpoints."""
    ego.learn(**_learn_config(args))

    if args.record:
        transition = env.get_transitions()
        transition.write_transition(args.record)

    if args.ego_save:
        ego.save(args.ego_save)
    if args.alt_save:
        _save_partners(partners, args.alt_save)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description='''\
            Train ego and partner(s) in an environment.

            Environments:
            -------------
            All MultiAgentEnv environments are supported. Some have additional
            parameters that can be passed into --env-config. Specifically,
            OvercookedMultiEnv-v0 has a required layout_name parameter, so
            one must add:

                --env-config '{"layout_name":"[SELECTED_LAYOUT]"}'

            OvercookedMultiEnv-v0 also has parameters `ego_agent_idx` and
            `baselines`, but these have default initializations. LiarsDice-v0
            has an optional parameter, `probegostart`.

            The environment can be wrapped with a framestack, which transforms
            the observation to stack previous observations as a workaround
            for recurrent networks not being supported. It can also be wrapped
            with a recorder wrapper, which will write the transitions to the
            given file.

            Ego-Agent:
            ----------
            The ego-agent is considered the main agent in the environment.
            From the perspective of the ego agent, the environment functions
            like a regular gym environment.

            Supported ego-agent algorithms include PPO, ModularAlgorithm, ADAP,
            and ADAP_MULT. The default parameters of these algorithms can
            be overriden using --ego-config.

            Alt-Agents:
            -----------
            The alt-agents are the partner agents that are embedded in the
            environment. If multiple are listed, the environment randomly
            samples one of them to be the partner at the start of each episode.

            Supported alt-agent algorithms include PPO, ADAP, ADAP_MULT,
            DEFAULT, and FIXED. DEFAULT refers to the default hand-made policy
            in the environment (if it exists). FIXED refers to a policy that
            has already been saved to a file, and will not learn anymore.

            Default parameters for these algorithms can be overriden using
            --alt-config. For FIXED policies, one must have parameters for
            `type` and `location` to load in the policies. If the FIXED
            policy is an ADAP policy, it must also have a `latent_val`
            parameter.

            NOTE:
            All configs are based on the json format, and will be interpreted
            as dictionaries for the kwargs of their initializers.

            Example usage (Overcooked with ADAP agents that share the latent
            space):

            python3 trainer.py OvercookedMultiEnv-v0 ADAP ADAP --env-config
            '{"layout_name":"random0"}' -l
            ''')

    parser.add_argument('env',
                        choices=ENV_LIST,
                        help='The environment to train in')

    parser.add_argument('ego',
                        choices=EGO_LIST,
                        help='Algorithm for the ego agent')

    parser.add_argument('alt',
                        choices=PARTNER_LIST,
                        nargs='+',
                        help='Algorithm for the partner agent')

    parser.add_argument('--total-timesteps', '-t',
                        type=int,
                        default=600000,
                        help='Number of time steps to run (ego perspective)')

    parser.add_argument('--device', '-d',
                        default='auto',
                        help='Device to run pytorch on')
    parser.add_argument('--seed', '-s',
                        default=0,
                        type=int,
                        help='Seed for randomness')

    parser.add_argument('--ego-config',
                        type=json.loads,
                        default={},
                        help='Config for the ego agent')

    parser.add_argument('--alt-config',
                        type=json.loads,
                        nargs='*',
                        help='Config for the partner agent(s), one per partner')

    parser.add_argument('--env-config',
                        type=json.loads,
                        default={},
                        help='Config for the environment')

    # Wrappers
    parser.add_argument('--framestack', '-f',
                        type=int,
                        default=1,
                        help='Number of observations to stack')

    parser.add_argument('--record', '-r',
                        help='Saves joint trajectory into file specified')

    parser.add_argument('--ego-save',
                        help='File to save the ego agent into')
    parser.add_argument('--alt-save',
                        help='File to save the partner agent into')

    parser.add_argument('--share-latent', '-l',
                        action='store_true',
                        help='True when both actors are ADAP and want to sync \
                        latent values')

    parser.add_argument('--tensorboard-log',
                        help='Log directory for tensorboard')

    parser.add_argument('--tensorboard-name',
                        help='Name for ego in tensorboard')

    parser.add_argument('--verbose-partner',
                        action='store_true',
                        help='True when partners should log to output')

    parser.add_argument('--preset', type=int, help='Use preset args')

    parser.add_argument('--style', type=int, help='Determine the behavioral style of the partner agent', default=None)

    # type=bool would turn "--multi-task False" into True; parse explicitly.
    # A bare flag (no value) now also means True.
    parser.add_argument('--multi-task', type=_str2bool, nargs='?', const=True, default=False,
                        help='Multi-task training: the partner is provided separately (true/false)')
    parser.add_argument('--instruction-follow', type=_str2bool, nargs='?', const=True, default=False,
                        help='Boolean flag (true/false)')
    parser.add_argument('--instruction-type', type=str, default='label')
    parser.add_argument('--n_steps_per_iteration', type=int, default=10000)
    args = parser.parse_args()

    print(args)
    if args.preset:
        args = preset(args, args.preset)
    input_check(args)

    print(f"Arguments: {args}")

    if args.env == "OvercookedMultiEnv-v0":
        env, altenv = generate_env(args)
        print(f"Environment: {env}; Partner env: {altenv}")
        ego = generate_ego(env, args)
        print(f'Ego: {ego}')
        partners = generate_partners(altenv, env, ego, args)
        _train_and_save(env, ego, partners, args)

    elif args.env == "AssistiveMultiEnv-v0":
        env = generate_subprocenv(args)
        agent_pair = generate_agent_pair(env, args)
        agent_pair.learn(**_learn_config(args))

        if args.ego_save:
            agent_pair.agent_1.save(args.ego_save)
        if args.alt_save:
            agent_pair.agent_2.save(args.alt_save)

    elif args.env == "LBFMultiEnv-v0":
        # generate_env applies the --framestack / --record wrappers. Building
        # the env by hand skipped them, so --record crashed at
        # env.get_transitions() after training, before any model was saved.
        env, altenv = generate_env(args)
        args.ego_config['learning_rate'] = 3e-4
        args.ego_config['n_steps'] = 1024
        for alt_config in args.alt_config:
            alt_config['learning_rate'] = args.ego_config['learning_rate']
            alt_config['n_steps'] = args.ego_config['n_steps']
        ego = generate_ego(env, args)
        print(f'Ego: {ego}')
        partners = generate_partners(altenv, env, ego, args)
        _train_and_save(env, ego, partners, args)


