from rlf.algos.il.base_irl import BaseIRLAlgo
import torch
import torch.nn as nn
import torch.nn.functional as F
import rlf.rl.utils as rutils
import rlf.algos.utils as autils
from collections import defaultdict
from rlf.baselines.common.running_mean_std import RunningMeanStd
from rlf.algos.nested_algo import NestedAlgo
from rlf.algos.on_policy.ppo import PPO
from rlf.args import str2bool
import torch.optim as optim
import numpy as np
from rlf.rl.model import ConcatLayer
from rlf.rl.model import InjectNet
from functools import partial
from rlf.exp_mgr.viz_utils import append_text_to_image

import os.path as osp
from goal_prox.envs.gw_helper import GwIQPlotter
from goal_prox.envs.debug_viz import DebugViz, LineDebugViz


def get_visualizer(args, policy, viz_type):
    """
    Build the debug visualizer selected by `viz_type`.

    - args: run arguments. `save_dir`, `env_name` and `prefix` are joined
      into the directory handed to the visualizer; `args` itself is also
      passed through to the visualizer.
    - policy: only read when `viz_type == 'gw'`, where its `obs_space`
      provides the observation shape for the plotter.
    - viz_type: 'gw' for `GwIQPlotter`, None for the plain `DebugViz`.
    Returns: the constructed visualizer.
    Raises: ValueError for any other `viz_type`.
    """
    out_dir = osp.join(args.save_dir, args.env_name, args.prefix)

    if viz_type == 'gw':
        obs_shape = rutils.get_obs_shape(policy.obs_space)
        return GwIQPlotter(out_dir, args, obs_shape)

    if viz_type is None:
        return DebugViz(out_dir, args)

    raise ValueError(f"Unexpected viz type {viz_type}")


def get_default_discrim(hidden_dim=64):
    """
    Build the default discriminator head.

    The head has no input layer of its own: its first layer is
    `Linear(hidden_dim, hidden_dim)`, so it must be fed features that are
    already `hidden_dim` wide. Projecting the raw state (and action, if used)
    to that width is left to the caller.

    - hidden_dim: (int) width of the expected input features and of the
      hidden layer. Defaults to 64.
    Returns: (nn.Module, int) the head, mapping `(..., hidden_dim)` inputs to
    `(..., 1)` outputs, and the `hidden_dim` it was built with.
    """
    layers = [
        nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
        nn.Linear(hidden_dim, 1),
    ]

    return nn.Sequential(*layers), hidden_dim

class IQReward(NestedAlgo):
    """
    Nested algorithm that pairs an `IQDiscrim` reward module with a policy
    updater.
    """

    def __init__(self, agent_updater=None):
        """
        - agent_updater: algorithm used to update the policy. When omitted, a
          new `PPO()` is created for this instance.
        """
        # Build the default here rather than in the signature: a `PPO()`
        # default argument is evaluated once, when the class is defined, and
        # would then be shared by every IQReward created without an argument.
        if agent_updater is None:
            agent_updater = PPO()
        # NOTE(review): the trailing 1 presumably tells NestedAlgo which entry
        # of the list is the policy updater -- confirm against NestedAlgo.
        super().__init__([IQDiscrim(), agent_updater], 1)

class IQDiscrim(BaseIRLAlgo):
    """
    Reward module that scores agent transitions with an IQ reward model.

    The reward model is built in `init` with `load_pretrained_model=True` and
    is only ever queried under `torch.no_grad()`. Unlike a GAIL-style
    discriminator, this class defines no sampler, discriminator loss or
    gradient-penalty hooks, and `_update_reward_func` is a no-op, so nothing
    here trains the reward model.
    """

    def __init__(self):
        super().__init__()
        # Built in `init`, once the policy and run arguments are available.
        self.iq_reward_model = None

    def init(self, policy, args):
        """
        Prepare expert data, the reward model and the debug visualizer.

        - policy: the policy being trained; passed to the expert-data setup,
          to `super().init` and to the visualizer factory.
        - args: run arguments. Reads `exp_gen_num_trans` (only when an expert
          generator is set), `device` and `iq_viz_type` here.
        """
        # Imported inside the method, as in the original code, so that
        # `iq_learn` is only required once an IQDiscrim is initialised.
        from iq_learn.make_reward_models import make_reward_model

        # `exp_generator` is not assigned in this class; presumably it comes
        # from BaseIRLAlgo -- TODO confirm.
        if self.exp_generator is None:
            self._load_expert_data(policy, args)
        else:
            self.exp_generator.init(policy, args, args.exp_gen_num_trans)
            print(f"Generating {args.exp_gen_num_trans} transitions for imitation")
        super().init(policy, args)
        self.args = args

        # Load the IQ reward model (`load_pretrained_model=True`).
        self.iq_reward_model = make_reward_model(
            self.policy.obs_space,
            self.policy.action_space,
            args.device,
            args,
            load_pretrained_model=True)

        self.debug_viz = get_visualizer(args, policy, args.iq_viz_type)

    def first_train(self, log, eval_policy, env_interface):
        """
        Hand the reward model's query functions to the debug visualizer.

        `log`, `eval_policy` and `env_interface` are accepted but not used
        here.
        """
        # NOTE(review): the leading 0 is presumably the update index the plot
        # is labelled with -- confirm against DebugViz.plot.
        self.debug_viz.plot(0, ["expert"], self._get_plot_funcs())

        rutils.pend_sep()

    def _get_irl_V(self, state):
        """Return `iq_reward_model.getV(state)`, computed without gradients."""
        with torch.no_grad():
            value = self.iq_reward_model.getV(state)
        return value

    def _get_iq_action(self, state):
        """
        Return the action the underlying IQ model chooses for `state` with
        `sample=False`, computed without gradients.
        """
        with torch.no_grad():
            action = self.iq_reward_model.iq_model.choose_action(
                state, sample=False)
        return action

    def _get_iq_reward(self, state, action, next_state, done):
        """
        Return the reward model's output for one transition as a Python
        number.

        `.item()` is called on the result, so the reward model must return a
        single-element value for these inputs.
        """
        with torch.no_grad():
            reward = self.iq_reward_model(state, action, next_state, done)
        return reward.item()

    def _get_plot_funcs(self):
        """
        Select the functions the visualizer can plot for the configured
        reward model type (`args.reward_gen.model.type`).

        Returns: (dict) plot name -> callable.
        - "iqgen": value, greedy action and reward.
        - "dist": value and reward.
        - anything else: reward only.
        """
        model_type = self.args.reward_gen.model.type
        # The branches must be mutually exclusive (`elif`): with two separate
        # `if` statements the trailing `else` also runs for "iqgen" and
        # replaces its dictionary with the reward-only one.
        if model_type == "iqgen":
            plot_funcs = {
                "irlV": partial(self._get_irl_V),
                "getAction": partial(self._get_iq_action),
                "getReward": partial(self._get_iq_reward),
            }
        elif model_type == "dist":
            plot_funcs = {
                "irlV": partial(self._get_irl_V),
                "getReward": partial(self._get_iq_reward),
            }
        else:
            plot_funcs = {
                "getReward": partial(self._get_iq_reward),
            }
        return plot_funcs

    def _trans_agent_state(self, state, other_state=None):
        """
        Return `rutils.get_def_obs(state)`.

        `other_state` is accepted but ignored.
        """
        return rutils.get_def_obs(state)

    def _update_reward_func(self, storage):
        """No-op: the reward model is not updated here. Returns an empty dict."""
        return {}

    def _get_reward(self, step, storage, add_info):
        """
        Compute the IQ reward for the transition stored at `step`.

        - step: index into `storage`.
        - storage: rollout storage; `get_obs`, `get_next_obs`, `actions` and
          `dones` are read at `step`.
        - add_info: unused.
        Returns: (reward, dict) the reward model's output and an empty dict.
        """
        state = self._trans_agent_state(storage.get_obs(step))
        next_state = self._trans_agent_state(storage.get_next_obs(step))
        action = storage.actions[step]
        done = storage.dones[step]

        with torch.no_grad():
            reward = self.iq_reward_model(state, action, next_state, done)
        return reward, {}

    def get_add_args(self, parser):
        """
        Register the command line options of the IQ reward model, its agent
        and its training method on `parser`, after the parent's options.

        - parser: argument parser to extend in place.
        """
        super().get_add_args(parser)

        # The dotted option names mirror a nested configuration, e.g.
        #   reward_gen:
        #     model:
        #       lr: 1e-3
        #       type: "basic"  # 'basic' or 'dist' or 'iqgen'
        parser.add_argument('--pretrain', type=str, default=None)

        parser.add_argument('--exp.gamma_scale', type=str2bool, default=False)
        parser.add_argument('--exp.gamma', type=float, default=None)

        # reward_gen.model
        parser.add_argument('--reward_gen.model.lr', type=float, default=1e-3)
        parser.add_argument('--reward_gen.model.type', type=str)
        parser.add_argument('--reward_gen.model.load_path', type=str, default=None)

        # One of 'sas', 'sa', 'ss', 's', 'ns'.
        parser.add_argument('--reward_gen.model.basic.input_config', type=str, default=None)

        parser.add_argument('--reward_gen.model.dist.type', type=str, default='proximity')
        parser.add_argument('--reward_gen.model.dist.num_ensembles', type=int, default=5)
        parser.add_argument('--reward_gen.model.dist.panel_var', type=str2bool, default=True)
        parser.add_argument('--reward_gen.model.dist.mse_before_mean', type=str2bool, default=False)
        parser.add_argument('--reward_gen.model.dist.var_for_both_train_and_eval', type=str2bool, default=True)

        parser.add_argument('--reward_gen.model.reg.type', type=str, default='dist_constraint')
        parser.add_argument('--reward_gen.model.reg.coef', type=float, default=1e-3)

        # reward_gen.train
        parser.add_argument('--reward_gen.train.batch', type=int, default=32)
        parser.add_argument('--reward_gen.train.learn_steps', type=int, default=10000)
        parser.add_argument('--reward_gen.train.eval_interval', type=int, default=100)
        parser.add_argument('--reward_gen.train.log_interval', type=int, default=100)
        parser.add_argument('--reward_gen.train.save_interval', type=int, default=100)
        parser.add_argument('--reward_gen.train.video_eval_interval', type=int, default=1000)

        # Online data (the original author noted an expert dataset size of
        # 13556 next to the 500000 default).
        parser.add_argument('--reward_gen.train.add_online_data', type=str2bool, default=True)
        parser.add_argument('--reward_gen.train.online_data_size', type=int, default=500000)
        parser.add_argument('--reward_gen.train.online_rand_threshold', type=float, default=1.0)

        # agent
        parser.add_argument('--agent.name', type=str, default="softq")
        parser.add_argument('--agent.class', type=str, default="agent.softq.SoftQ")
        parser.add_argument('--agent.obs_dim', type=str, default=None)
        parser.add_argument('--agent.action_dim', type=str, default=None)
        parser.add_argument('--agent.critic_lr', type=float, default=1e-4)
        # Two floats, e.g. `--agent.critic_betas 0.9 0.999`. `type=list` would
        # split the command line string into single characters instead.
        parser.add_argument('--agent.critic_betas', type=float, nargs=2, default=[0.9, 0.999])
        parser.add_argument('--agent.init_temp', type=float, default=0.01)
        parser.add_argument('--agent.critic_target_update_frequency', type=int, default=4)
        parser.add_argument('--agent.critic_tau', type=float, default=0.1)

        # q_net
        parser.add_argument('--q_net.hidden_dim', type=int, default=256)
        parser.add_argument('--q_net.hidden_depth', type=int, default=2)
        parser.add_argument('--q_net._target_', type=str, default="MiniGridQNetwork")

        parser.add_argument('--train.batch', type=int, default=32)

        # method
        parser.add_argument('--method.type', type=str, default="iq")
        parser.add_argument('--method.loss', type=str, default="value")
        parser.add_argument('--method.constrain', type=str2bool, default=False)
        parser.add_argument('--method.grad_pen', type=str2bool, default=False)
        parser.add_argument('--method.chi', type=str2bool, default=False)
        parser.add_argument('--method.tanh', type=str2bool, default=False)
        parser.add_argument('--method.regularize', type=str2bool, default=False)
        parser.add_argument('--method.div', type=str, default=None)

        parser.add_argument('--method.alpha', type=float, default=0.5)
        parser.add_argument('--method.lambda_gp', type=float, default=10)
        parser.add_argument('--method.mix_coeff', type=float, default=1)

        # Read in `init` as `args.iq_viz_type`.
        parser.add_argument('--iq-viz-type', type=str, default=None)



