from rlf.algos.il.base_irl import BaseIRLAlgo
import torch
import torch.nn as nn
import torch.nn.functional as F
from rlf.rl import utils
from rlf.rl.model import InjectNet
from collections import defaultdict
from rlf.baselines.common.running_mean_std import RunningMeanStd
from rlf.algos.nested_algo import NestedAlgo
from rlf.algos.on_policy.ppo import PPO
from rlf.args import str2bool
import torch.optim as optim
from torch import autograd
import numpy as np
from goal_prox.method.value_traj_dataset import linear_discounted, exp_discounted
from functools import partial
from rlf.comps.ensemble import Ensemble


def get_default_discrim(hidden_dim=64):
    """Build the default discriminator head.

    The head is a two-layer MLP: ``Linear(hidden_dim, hidden_dim)``,
    ``Tanh``, then ``Linear(hidden_dim, 1)`` producing a single scalar.

    Args:
        hidden_dim: Width of the head's input and of its hidden layer.

    Returns:
        A ``(head, hidden_dim)`` tuple: the ``nn.Sequential`` module and the
        input width it expects.
    """
    layers = [
        nn.Linear(hidden_dim, hidden_dim),
        nn.Tanh(),
        nn.Linear(hidden_dim, 1),
    ]
    head = nn.Sequential(*layers)
    return head, hidden_dim


class GAILEvUncert(NestedAlgo):
    """Nested algorithm pairing the ensemble reward learner with a policy updater.

    The nested module list is ``[GailEvUncertDiscrim(get_discrim),
    agent_updater]``.

    Args:
        agent_updater: Algorithm that updates the policy. When ``None`` a
            fresh ``PPO()`` is created for this instance. (A ``PPO()``
            default in the signature would be evaluated once at definition
            time and shared by every instance.)
        get_discrim: Optional factory forwarded to ``GailEvUncertDiscrim``.
    """

    def __init__(self, agent_updater=None, get_discrim=None):
        if agent_updater is None:
            agent_updater = PPO()
        # NOTE(review): the trailing 1 is presumably the index of the policy
        # updater within the module list -- confirm against NestedAlgo.
        super().__init__([GailEvUncertDiscrim(get_discrim), agent_updater], 1)


class GailEvUncertDiscrim(BaseIRLAlgo):
    """Reward learner backed by an ensemble of regressors.

    Every ensemble member is trained with an MSE loss to predict the
    ``'prox'`` value stored with expert states and zero for agent states.
    The reward for a transition is ``-(1 - mean) - uncert_scale * std``,
    where mean (clamped to [0, 1]) and std are taken over the ensemble axis
    (dim 0) of the predictions for the next state.
    """

    def __init__(self, get_discrim=None):
        """Store the discriminator-head factory.

        Args:
            get_discrim: Zero-argument callable returning ``(head, in_dim)``.
                Defaults to ``get_default_discrim``.
        """
        super().__init__()
        if get_discrim is None:
            get_discrim = get_default_discrim
        self.get_base_discrim = get_discrim

    def _create_discrim(self):
        """Return an ``Ensemble`` of independently constructed discriminators."""
        def create_discrim():
            # Each member gets its own state encoder (built by the policy's
            # base-net factory) and its own head.
            state_enc = self.policy.get_base_net_fn(
                utils.get_obs_shape(self.policy.obs_space))
            discrim_head, in_dim = self.get_base_discrim()
            return InjectNet(state_enc.net, discrim_head,
                             state_enc.output_shape[0], in_dim,
                             utils.get_ac_dim(self.policy.action_space),
                             self.args.action_input).to(self.args.device)
        # The ensemble size is fixed at 5 members.
        return Ensemble(create_discrim, 5)

    def init(self, policy, args):
        """Create the ensemble, its optimizer and the return statistics."""
        super().init(policy, args)
        self.action_space = self.policy.action_space

        self.discrim_net = self._create_discrim()

        # Running discounted return and its statistics, used only when
        # --gail-reward-norm is enabled.
        self.returns = None
        self.ret_rms = RunningMeanStd(shape=())

        self.opt = optim.Adam(
            self.discrim_net.parameters(), lr=self.args.disc_lr)

    def _get_traj_dataset(self, traj_load_path):
        """Build the expert dataset with linearly discounted labels."""
        # NOTE(review): StateValueTrajDataset is not imported anywhere in the
        # visible part of this file; unless it is brought into scope
        # elsewhere, this call raises NameError. Confirm the defining module
        # and add the import.
        label_fn = partial(linear_discounted, delta=0.02)
        return StateValueTrajDataset(traj_load_path, label_fn)

    def _get_sampler(self, storage):
        """Return ``(expert_loader, agent_batches)`` for one training pass.

        Agent mini-batches are requested with the expert loader's batch size
        so the two streams can be zipped together.
        """
        agent_experience = storage.get_generator(
            None, mini_batch_size=self.expert_train_loader.batch_size)
        return self.expert_train_loader, agent_experience

    def _trans_batches(self, expert_batch, agent_batch):
        """Hook for transforming a batch pair; the default is the identity."""
        return expert_batch, agent_batch

    def _norm_expert_state(self, state, obsfilt):
        """Apply the environment observation filter to expert states.

        The states are moved to NumPy, passed through ``obsfilt`` (when one
        is given) with ``update=False``, and returned as a tensor on
        ``self.args.device``.
        """
        state = state.cpu().numpy()
        if obsfilt is not None:
            state = obsfilt(state, update=False)
        state = torch.tensor(state).to(self.args.device)
        return state

    def _compute_discrim_loss(self, agent_batch, expert_batch, obsfilt):
        """Run the ensemble on expert and agent states.

        Despite the name, this returns the raw predictions
        ``(expert_d, agent_d)``; the loss itself is computed by the caller.
        """
        expert_states = self._norm_expert_state(expert_batch['state'],
                                                obsfilt)
        agent_states = agent_batch['state']

        expert_d = self.discrim_net(expert_states, None)
        agent_d = self.discrim_net(agent_states, None)

        return expert_d, agent_d

    def get_env_settings(self, args):
        """Request the extra info keys this algorithm reads.

        Adds ``'ep_found_goal'`` and ``'final_obs'`` to the info keys
        collected from the environment. ``'final_obs'`` is read in
        ``_compute_discrim_reward``.
        """
        settings = super().get_env_settings(args)
        settings.include_info_keys.extend([
            ('ep_found_goal', lambda _: (1,)),
            ('final_obs', lambda env: utils.get_obs_shape(env.observation_space))
            ])
        return settings

    def _update_reward_func(self, storage):
        """Train the ensemble for ``n_gail_epochs`` passes.

        Returns:
            A dict with ``'discrim_loss'``, ``'expert_loss'`` and
            ``'agent_loss'`` averaged over all processed batches (empty when
            no batch was processed).
        """
        self.discrim_net.train()

        d = self.args.device
        log_vals = defaultdict(float)
        obsfilt = self.get_env_ob_filt()

        n = 0
        for _ in range(self.args.n_gail_epochs):
            # Rebuild the samplers every epoch: a one-shot agent batch
            # generator would be exhausted after the first pass, leaving
            # every later epoch with zero batches.
            expert_sampler, agent_sampler = self._get_sampler(storage)
            for expert_batch, agent_batch in zip(expert_sampler, agent_sampler):
                expert_batch, agent_batch = self._trans_batches(
                    expert_batch, agent_batch)
                n += 1
                expert_d, agent_d = self._compute_discrim_loss(
                    agent_batch, expert_batch, obsfilt)

                # Dim 0 indexes ensemble members; flatten the rest so every
                # member is regressed against the same per-sample targets.
                n_ensembles = expert_d.shape[0]
                expert_pred = expert_d.view(n_ensembles, -1)
                agent_pred = agent_d.view(n_ensembles, -1)

                expert_target = expert_batch['prox'].view(1, -1).repeat(
                    n_ensembles, 1).to(d)
                expert_loss = F.mse_loss(expert_pred, expert_target)
                # Agent states are regressed towards zero.
                agent_loss = F.mse_loss(agent_pred,
                                        torch.zeros_like(agent_pred))
                discrim_loss = expert_loss + agent_loss

                self.opt.zero_grad()
                discrim_loss.backward()
                self.opt.step()

                log_vals['discrim_loss'] += discrim_loss.item()
                log_vals['expert_loss'] += expert_loss.item()
                log_vals['agent_loss'] += agent_loss.item()

        # log_vals is only populated when n > 0, so this never divides by 0.
        for k in log_vals:
            log_vals[k] /= n

        return log_vals

    def _compute_discrim_reward(self, storage, step, add_info):
        """Compute the reward for the transition taken at ``step``.

        The next observation (index ``step + 1``) is evaluated. For
        environments whose mask at ``step + 1`` is zero, the stored
        observation is replaced by ``add_info['final_obs'][step]``.

        Returns:
            ``-(1 - next_prox) - uncert_scale * next_uncert``, with
            ``next_prox`` the ensemble mean clamped to [0, 1] and
            ``next_uncert`` the ensemble standard deviation.
        """
        next_idx = step + 1
        # Clone so substituting final observations never writes into storage.
        next_state = utils.get_def_obs(storage.get_obs(next_idx)).clone()

        masks = storage.masks[next_idx]
        finished_episodes = [i for i, m in enumerate(masks) if m == 0.0]
        if finished_episodes:
            final_obs = add_info['final_obs'][next_idx - 1]
            for i in finished_episodes:
                next_state[i] = final_obs[i]

        # One forward pass provides both statistics over the ensemble axis.
        next_preds = self.discrim_net(next_state, None)
        next_uncert = next_preds.std(0)
        next_prox = torch.clamp(next_preds.mean(0), 0.0, 1.0)

        return -1.0 * (1.0 - next_prox) - (self.args.uncert_scale * next_uncert)

    def _get_reward(self, step, storage, add_info):
        """Return ``(reward, info_dict)`` for ``step``, optionally normalized.

        With ``--gail-reward-norm`` the reward is divided by the running
        standard deviation of the discounted return.
        """
        masks = storage.masks[step]
        with torch.no_grad():
            self.discrim_net.eval()
            reward = self._compute_discrim_reward(storage, step, add_info)

            if self.args.gail_reward_norm:
                if self.returns is None:
                    self.returns = reward.clone()

                self.returns = self.returns * masks * self.args.gamma + reward
                self.ret_rms.update(self.returns.cpu().numpy())

                # NOTE(review): ret_rms is created with shape=(); indexing
                # var[0] presumably relies on update() turning var into a
                # one-element array -- confirm against RunningMeanStd.
                return reward / np.sqrt(self.ret_rms.var[0] + 1e-8), {}
            else:
                return reward, {}

    def get_add_args(self, parser):
        """Register this algorithm's command-line arguments."""
        super().get_add_args(parser)
        #########################################
        # New args
        parser.add_argument('--action-input', type=str2bool, default=False)
        parser.add_argument('--gail-reward-norm', type=str2bool, default=True)
        parser.add_argument('--disc-lr', type=float, default=0.001)
        parser.add_argument('--uncert-scale', type=float, default=0.01)
        parser.add_argument('--n-gail-epochs', type=int, default=1)

    def load_resume(self, checkpointer):
        """Restore the optimizer and ensemble weights from a checkpoint."""
        super().load_resume(checkpointer)
        self.opt.load_state_dict(checkpointer.get_key('gail_disc_opt'))
        self.discrim_net.load_state_dict(checkpointer.get_key('gail_disc'))

    def save(self, checkpointer):
        """Store the optimizer and ensemble weights in a checkpoint."""
        super().save(checkpointer)
        checkpointer.save_key('gail_disc_opt', self.opt.state_dict())
        checkpointer.save_key('gail_disc', self.discrim_net.state_dict())
