import sys
sys.path.insert(0, './')
from rlf.rl.loggers.base_logger import BaseLogger
from rlf.rl.loggers.wb_logger import WbLogger, get_wb_ray_kwargs, get_wb_ray_config
from rlf import run_policy
from rlf.run_settings import RunSettings

from rlf import DistActorCritic, BasicPolicy
from rlf import RandomPolicy
from rlf.policies.action_replay_policy import ActionReplayPolicy
from rlf import PPO, GAIL, BehavioralCloning, BaseAlgo
from rlf import GailDiscrim
from rlf import BehavioralCloningPretrain, BehavioralCloningFromObs
from rlf.rl.model import MLPBasic, CNNBase
from rlf.algos.il.base_il import BaseILAlgo
from goal_prox.baselines.good_gail import GoodGAIL, GoodGailDiscrim
from goal_prox.baselines.good_gaifo import GoodGAIFO
from goal_prox.baselines.gail_ev import GAILEv
from goal_prox.baselines.gail_ev_uncert import GAILEvUncert
from rlf.algos.nested_algo import NestedAlgo
from rlf.args import str2bool


import goal_prox.gym_minigrid

import goal_prox.envs.ball_in_cup
import goal_prox.envs.goal_check
import goal_prox.envs.gridworld

from goal_prox.method.discounted_pf import DiscountedProxIL, DiscountedProxFunc
from goal_prox.method.dist_discounted_pf import CatDistDiscountedProxIL, NormDistDiscountedProxIL
from goal_prox.method.ranked_pf import RankedProxIL
from goal_prox.method.airl import ProxAirl
from goal_prox.policies.grid_world_expert import GridWorldExpert
from goal_prox.baselines.dyn_dist import DynDistIL
from goal_prox.baselines.gw_exp_generator import GwExpGenerator
from goal_prox.envs.goal_traj_saver import GoalTrajSaver
from goal_prox.method.utils import trim_episodes_trans

import goal_prox.envs.fetch

from goal_prox.models import GwImgEncoder
from rlf.algos.il.gaifo import GAIFO
from functools import partial

def get_ppo_policy(env_name, args):
    """Return a ``DistActorCritic`` policy, choosing an encoder by environment.

    Args:
        env_name: Environment id. Names starting with ``'MiniGrid'`` or
            ``'Vizdoom'`` may get a dedicated base network.
        args: Parsed run arguments; ``args.gw_img`` is only read when the
            environment name starts with ``'MiniGrid'``.

    Returns:
        A ``DistActorCritic`` built with a custom ``get_base_net_fn`` when one
        of the special cases applies, otherwise one built with no arguments.
    """
    encoder_fn = None
    if env_name.startswith('MiniGrid') and args.gw_img:
        encoder_fn = lambda i_shape: GwImgEncoder(i_shape)
    elif env_name.startswith('Vizdoom'):
        # NOTE(review): VizDoomEncoder is not imported anywhere in the visible
        # part of this file, so invoking this factory looks like it would
        # raise NameError -- confirm where the encoder is meant to come from.
        encoder_fn = lambda i_shape: VizDoomEncoder(i_shape)

    if encoder_fn is None:
        return DistActorCritic()
    return DistActorCritic(get_base_net_fn=encoder_fn)

def get_deep_ppo_policy(env_name, args):
    """Return a ``DistActorCritic`` whose actor and critic are separate MLPs.

    Both networks are ``MLPBasic`` instances with ``hidden_size=256`` and
    ``num_layers=2``, sized from the first entry of the input shape.

    Args:
        env_name: Environment id (not used by this factory).
        args: Parsed run arguments (not used by this factory).
    """
    def build_mlp(i_shape):
        return MLPBasic(i_shape[0], hidden_size=256, num_layers=2)

    def actor_fn(_, i_shape):
        return build_mlp(i_shape)

    def critic_fn(_, i_shape, asp):
        # The third argument is accepted to match the expected callback
        # signature but does not influence the network that is built.
        return build_mlp(i_shape)

    return DistActorCritic(get_actor_fn=actor_fn, get_critic_fn=critic_fn)


def get_basic_policy(env_name, args, is_stoch=False):
    """Return a ``BasicPolicy`` with an encoder suited to the environment.

    Args:
        env_name: Environment id. When it starts with ``'MiniGrid'`` and
            ``args.gw_img`` is truthy, a ``GwImgEncoder`` base network is used;
            otherwise a two-layer, 256-unit ``MLPBasic``.
        args: Parsed run arguments; ``args.gw_img`` is only read for
            ``'MiniGrid'`` environments.
        is_stoch: Forwarded to ``BasicPolicy``. Defaults to ``False`` so the
            function can be used directly as a two-argument
            ``(env_name, args)`` policy factory, which is how the ``'gen-bc'``
            entry registers it; ``False`` mirrors what the ``'bc'`` entry
            passes explicitly.

    Returns:
        The configured ``BasicPolicy``.
    """
    if env_name.startswith('MiniGrid') and args.gw_img:
        def get_base_net(i_shape):
            return GwImgEncoder(i_shape)
    else:
        def get_base_net(i_shape):
            return MLPBasic(i_shape[0], hidden_size=256, num_layers=2)

    # Both branches above are exhaustive, so there is a single return point
    # (the original had an unreachable fallback after the if/else).
    return BasicPolicy(is_stoch=is_stoch, get_base_net_fn=get_base_net)


def get_deep_basic_policy(env_name, args):
    """Return a ``BasicPolicy`` backed by a deeper MLP base network.

    The base network is built as
    ``MLPBase(i_shape[0], False, (512, 512, 256, 128))``.

    Args:
        env_name: Environment id (not used by this factory).
        args: Parsed run arguments (not used by this factory).
    """
    # The module header imports only MLPBasic and CNNBase, so ``MLPBase`` was
    # an undefined name when the network factory ran. Import it here so the
    # closure below can resolve it.
    # NOTE(review): assumes MLPBase lives in rlf.rl.model next to MLPBasic --
    # confirm against the installed rlf version.
    from rlf.rl.model import MLPBase

    def get_base_net(i_shape):
        return MLPBase(i_shape[0], False, (512, 512, 256, 128))

    return BasicPolicy(get_base_net_fn=get_base_net)


class DpfBcPretrain(BehavioralCloningPretrain):
    """``BehavioralCloningPretrain`` over ``[DiscountedProxFunc(), agent_updater]``.

    The agent updater sits at index 1 of the module list, which is the index
    passed as ``designated_rl_idx``.
    """

    def __init__(self, agent_updater=None):
        """Create the nested algorithm.

        Args:
            agent_updater: Algorithm used as the second module. When ``None``
                (the default) a fresh ``PPO()`` is created per instance; a
                ``PPO()`` in the signature would be built once at definition
                time and shared by every instance.
        """
        if agent_updater is None:
            agent_updater = PPO()
        super().__init__([DiscountedProxFunc(), agent_updater],
                designated_rl_idx=1)

class GailBcPretrain(BehavioralCloningPretrain):
    """``BehavioralCloningPretrain`` over ``[GailDiscrim(), agent_updater]``.

    The agent updater sits at index 1 of the module list, which is the index
    passed as ``designated_rl_idx``.
    """

    def __init__(self, agent_updater=None):
        """Create the nested algorithm.

        Args:
            agent_updater: Algorithm used as the second module. When ``None``
                (the default) a fresh ``PPO()`` is created per instance; a
                ``PPO()`` in the signature would be built once at definition
                time and shared by every instance.
        """
        if agent_updater is None:
            agent_updater = PPO()
        super().__init__([GailDiscrim(), agent_updater],
                designated_rl_idx=1)

class GoodGailBcPretrain(BehavioralCloningPretrain):
    """``BehavioralCloningPretrain`` over ``[GoodGailDiscrim(), agent_updater]``.

    The agent updater sits at index 1 of the module list, which is the index
    passed as ``designated_rl_idx``.
    """

    def __init__(self, agent_updater=None):
        """Create the nested algorithm.

        Args:
            agent_updater: Algorithm used as the second module. When ``None``
                (the default) a fresh ``PPO()`` is created per instance; a
                ``PPO()`` in the signature would be built once at definition
                time and shared by every instance.
        """
        if agent_updater is None:
            agent_updater = PPO()
        super().__init__([GoodGailDiscrim(), agent_updater],
                designated_rl_idx=1)


def get_setup_dict():
    """Build the table mapping an ``--alg`` name to its run setup.

    Returns:
        A dict whose values are ``(algorithm, policy_factory)`` pairs. The
        policy factory is called as ``policy_factory(env_name, args)``.
        Every call constructs new algorithm instances for all entries.
    """
    setups = {}

    # GAIL variants.
    setups['gail'] = (GAIL(), get_ppo_policy)
    setups['gail-deep'] = (GAIL(), get_deep_ppo_policy)

    setups['gail-ev'] = (GAILEv(), get_deep_ppo_policy)
    setups['gail-ev-uncert'] = (GAILEvUncert(), get_deep_ppo_policy)

    # GAIfO variants.
    setups['gaifo'] = (GAIFO(), get_ppo_policy)
    setups['gaifo-deep'] = (GAIFO(), get_deep_ppo_policy)
    setups['good-gaifo'] = (GoodGAIFO(), get_ppo_policy)
    setups['good-gaifo-deep'] = (GoodGAIFO(), get_deep_ppo_policy)

    # GAIL with behavioral-cloning pretraining.
    setups['bc-gail'] = (GailBcPretrain(), get_ppo_policy)
    setups['bc-gail-deep'] = (GailBcPretrain(), get_deep_ppo_policy)

    setups['good-gail'] = (GoodGAIL(), get_ppo_policy)
    setups['good-gail-deep'] = (GoodGAIL(), get_deep_ppo_policy)
    setups['bc-good-gail-deep'] = (GoodGailBcPretrain(), get_deep_ppo_policy)

    # Assorted baselines and fixed policies.
    setups['dyn-dist'] = (DynDistIL(), get_ppo_policy)
    setups['ppo'] = (PPO(), get_ppo_policy)
    setups['ppo-deep'] = (PPO(), get_deep_ppo_policy)
    setups['gw-exp'] = (BaseAlgo(), lambda env_name, _: GridWorldExpert())
    setups['action-replay'] = (
        BaseAlgo(), lambda env_name, _: ActionReplayPolicy())
    setups['gen-bc'] = (BehavioralCloning(GwExpGenerator()), get_basic_policy)
    setups['cat-dist-dpf'] = (CatDistDiscountedProxIL(), get_ppo_policy)
    setups['norm-dist-dpf'] = (NormDistDiscountedProxIL(), get_ppo_policy)
    setups['rnd'] = (BaseAlgo(), lambda env_name, _: RandomPolicy())

    # Behavioral cloning.
    setups['bc'] = (
        BehavioralCloning(), partial(get_basic_policy, is_stoch=False))
    setups['bco'] = (
        BehavioralCloningFromObs(), partial(get_basic_policy, is_stoch=True))
    setups['bc-deep'] = (BehavioralCloning(), get_deep_basic_policy)

    # Discounted proximity function variants.
    setups['dpf'] = (DiscountedProxIL(), get_ppo_policy)
    setups['bc-dpf'] = (DpfBcPretrain(), get_ppo_policy)

    setups['dpf-deep'] = (DiscountedProxIL(), get_deep_ppo_policy)
    image_prox_algo = DiscountedProxIL(
        get_pf_base=lambda i_shape: CNNBase(i_shape[0], False, 256))
    setups['dpf-deep-im'] = (image_prox_algo, get_deep_ppo_policy)
    setups['bc-dpf-deep'] = (DpfBcPretrain(), get_deep_ppo_policy)

    setups['prox-deep'] = (ProxAirl(), get_deep_ppo_policy)

    # Ranked proximity function variants.
    setups['rpf'] = (RankedProxIL(), get_ppo_policy)
    setups['rpf-deep'] = (RankedProxIL(), get_deep_ppo_policy)

    return setups


class GoalProxSettings(RunSettings):
    """Run settings that resolve the ``--alg`` flag via ``get_setup_dict``."""

    def get_policy(self):
        """Build the policy registered for ``base_args.alg``."""
        setup = get_setup_dict()[self.base_args.alg]
        policy_factory = setup[1]
        return policy_factory(self.base_args.env_name, self.base_args)

    def create_traj_saver(self, save_path):
        """Return a ``GoalTrajSaver`` writing to ``save_path``."""
        # The second positional argument is passed as False; its meaning is
        # defined by GoalTrajSaver and is not visible in this file.
        return GoalTrajSaver(save_path, False)

    def get_algo(self):
        """Return the algorithm registered for ``base_args.alg``.

        When the algorithm is a ``NestedAlgo`` whose first module is a
        ``BaseILAlgo``, ``trim_episodes_trans`` is installed on that module as
        its demonstration-dataset transform.
        """
        setup = get_setup_dict()[self.base_args.alg]
        algo = setup[0]
        if not isinstance(algo, NestedAlgo):
            return algo
        if isinstance(algo.modules[0], BaseILAlgo):
            algo.modules[0].set_transform_dem_dataset_fn(trim_episodes_trans)
        return algo

    def get_logger(self):
        """Return a ``BaseLogger`` when ``--no-wb`` is set, else a ``WbLogger``."""
        return BaseLogger() if self.base_args.no_wb else WbLogger()

    def get_add_args(self, parser):
        """Register this script's extra command-line arguments on ``parser``."""
        parser.add_argument('--alg')
        parser.add_argument('--env-name')
        for bool_flag in ('--gw-img', '--no-wb'):
            parser.add_argument(bool_flag, action='store_true', default=False)

    def import_add(self):
        """Import the environment modules needed by this project."""
        import goal_prox.envs.fetch
        import goal_prox.envs.goal_check

    def get_add_ray_config(self, config):
        """Return ``config``, passed through ``get_wb_ray_config`` unless ``--no-wb``."""
        if not self.base_args.no_wb:
            return get_wb_ray_config(config)
        return config

    def get_add_ray_kwargs(self):
        """Return ``get_wb_ray_kwargs()``, or an empty dict when ``--no-wb`` is set."""
        if not self.base_args.no_wb:
            return get_wb_ray_kwargs()
        return {}


# Script entry point: hand the project-specific settings to the rlf runner.
if __name__ == '__main__':
    settings = GoalProxSettings()
    run_policy(settings)
