from iq_learn.reward_models.proximity_reward_model import DistReward, IQGenReward, BasicReward, RegReward
import torch

# Registry of pretrained reward-model checkpoints, keyed by environment name.
# get_reward_model_path() indexes each environment entry by model type:
#   'basic'  -> {input_config: path}
#   'dist'   -> path
#   'reg'    -> {'none' | 'dist_constraint' | 'dim_reduction': path}
#   'iq_gen' -> path
# No checkpoints are registered yet, so every environment entry is empty.
reward_model_load_path = {
    'MiniGrid-FourRooms-v0': {},
    'MiniGrid-Deceptive-v0': {},
}

# Checkpoints used when args.exp.gamma_scale is set; each environment entry is
# indexed by the discount value args.exp.gamma. Currently empty.
gamma_scale_load_path = {
    'MiniGrid-FourRooms-v0': {},
    'MiniGrid-Deceptive-v0': {},
}

def _nested_get(table, *keys):
    """Index ``table`` by ``keys`` in turn; return None if any step is missing."""
    node = table
    for key in keys:
        try:
            node = node[key]
        except (KeyError, IndexError, TypeError):
            return None
    return node


def get_reward_model_path(args):
    """Resolve the checkpoint path of a pretrained reward model.

    Resolution order:

    1. ``args.reward_gen.model.load_path`` when it is set (truthy).
    2. ``gamma_scale_load_path[args.env_name][args.exp.gamma]`` when
       ``args.exp.gamma_scale`` is truthy.
    3. The per-model-type entry of ``reward_model_load_path[args.env_name]``:
       ``'basic'`` is further keyed by ``model.basic.input_config``, ``'reg'``
       by ``model.reg.type``, while ``'dist'`` and ``'iq_gen'`` map directly to
       a path. ``'iqgen'`` is accepted as an alias of ``'iq_gen'`` so that this
       function agrees with ``make_reward_model``.

    Args:
        args: Experiment config namespace (attribute access as used below).

    Returns:
        The checkpoint path, or None when no checkpoint is registered for the
        requested configuration (unknown model/reg type, or a missing entry in
        the registries).
    """
    model_cfg = args.reward_gen.model

    if model_cfg.load_path:
        return model_cfg.load_path

    if args.exp.gamma_scale:
        return _nested_get(gamma_scale_load_path, args.env_name, args.exp.gamma)

    model_type = model_cfg.type
    if model_type == 'basic':
        keys = ('basic', model_cfg.basic.input_config)
    elif model_type == 'dist':
        keys = ('dist',)
    elif model_type == 'reg':
        reg_type = model_cfg.reg.type
        # Previously an unrecognized reg type left the path unbound
        # (UnboundLocalError); treat it as "nothing registered" instead.
        if reg_type not in ('none', 'dist_constraint', 'dim_reduction'):
            return None
        keys = ('reg', reg_type)
    elif model_type in ('iq_gen', 'iqgen'):
        keys = ('iq_gen',)
    else:
        return None

    # The registries may be sparse (currently empty), so a missing entry means
    # "no pretrained checkpoint" rather than a crash.
    return _nested_get(reward_model_load_path, args.env_name, *keys)


def make_reward_model(ob_space, action_space, device, args, load_pretrained_model=False):
    """Build the reward model selected by ``args.reward_gen.model.type``.

    Args:
        ob_space: Observation space, forwarded to the model as ``ob_space``.
        action_space: Action space, forwarded to the model as ``ac_space``.
        device: Device the model is moved to; also used as ``map_location``
            when loading a checkpoint, so GPU-saved weights load on CPU.
        args: Experiment config namespace.
        load_pretrained_model: When True, try to load weights from the path
            returned by ``get_reward_model_path(args)``.

    Returns:
        The constructed reward model, moved to ``device``.

    Raises:
        NotImplementedError: If the configured model type is not supported.
    """
    # 'iq_gen' is the spelling get_reward_model_path uses; 'iqgen' is kept for
    # backward compatibility with existing configs.
    model_classes = {
        'iqgen': IQGenReward,
        'iq_gen': IQGenReward,
        'basic': BasicReward,
        'dist': DistReward,
        'reg': RegReward,
    }
    model_type = args.reward_gen.model.type
    try:
        model_cls = model_classes[model_type]
    except KeyError:
        raise NotImplementedError(f"Unknown reward model type: {model_type!r}") from None

    reward_model = model_cls(ob_space=ob_space, ac_space=action_space, args=args)
    reward_model.to(device)

    if not load_pretrained_model:
        print("Not loading reward model")
        return reward_model

    load_path = get_reward_model_path(args)
    if load_path:
        # map_location remaps stored tensors onto the target device.
        state_dict = torch.load(load_path, map_location=device)
        reward_model.load_state_dict(state_dict)
        print(f'Loaded reward model from {load_path}')
    else:
        print("No pretrain reward model loaded")
    return reward_model