from rlf.algos.nested_algo import NestedAlgo
from rlf.algos.on_policy.ppo import PPO
from goal_prox.method.prox_func import ProxFunc
from goal_prox.method.value_traj_dataset import *
import torch.nn.functional as F
import torch
import torch.nn as nn
from functools import partial
from goal_prox.method.discounted_pf import DiscountedProxFunc
from rlf.rl.distributions import Categorical, FixedCategorical, DiagGaussian


class CatDistDiscountedProxIL(NestedAlgo):
    """Nested algorithm pairing a categorical proximity function with an RL updater.

    The proximity function is placed at index 0 and the RL updater at
    index 1, which is also passed as ``designated_rl_idx``.
    """

    def __init__(self, agent_updater=None):
        """Build the nested algorithm.

        Args:
            agent_updater: RL updater to nest. When omitted, a new ``PPO()``
                is created for this instance.
        """
        # A `PPO()` default in the signature would be created once at import
        # time and shared by every instance; build it lazily instead.
        if agent_updater is None:
            agent_updater = PPO()
        super().__init__([CatDistDiscountedProxFunc(), agent_updater], designated_rl_idx=1)

class NormDistDiscountedProxIL(NestedAlgo):
    """Nested algorithm pairing a Gaussian proximity function with an RL updater.

    The proximity function is placed at index 0 and the RL updater at
    index 1, which is also passed as ``designated_rl_idx``.
    """

    def __init__(self, agent_updater=None):
        """Build the nested algorithm.

        Args:
            agent_updater: RL updater to nest. When omitted, a new ``PPO()``
                is created for this instance.
        """
        # A `PPO()` default in the signature would be created once at import
        # time and shared by every instance; build it lazily instead.
        if agent_updater is None:
            agent_updater = PPO()
        super().__init__([NormDistDiscountedProxFunc(), agent_updater], designated_rl_idx=1)


class DistDiscountedProxFunc(DiscountedProxFunc):
    """Discounted proximity function whose network outputs distributions.

    ``self.prox_func(states, actions)`` is iterated over and each element is
    treated as a distribution exposing ``sample()``. Subclasses pick the
    distribution family by implementing ``label_data``, ``delabel_data`` and
    ``_compute_dist_loss``.
    """

    def _convert_expert_dataset(self):
        """Relabel the proximity entry (tuple index 1) of every expert sample in place."""
        for i in range(len(self.expert_dataset)):
            d = self.expert_dataset.data[i]
            # label_data works on a sequence, so wrap and unwrap the single value.
            prox_lbl = self.label_data([d[1]])[0]
            self.expert_dataset.data[i] = (d[0], prox_lbl, *d[2:])

    def label_data(self, prox_vals):
        """Map a sequence of proximity values to training targets.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement label_data')

    def delabel_data(self, idx):
        """Map sampled network outputs back to proximity values.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement delabel_data')

    def _compute_dist_loss(self, guess_prox_dist, proximity):
        """Return the training loss between predicted distributions and targets.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement _compute_dist_loss')

    def _prox_func_iter(self, data_batch):
        """Compute the loss for one batch.

        Args:
            data_batch: mapping with ``'state'`` and ``'prox'`` entries and an
                optional ``'action'`` entry; each is moved to ``args.device``.

        Returns:
            Whatever ``_compute_dist_loss`` returns for this batch.
        """
        states = data_batch['state'].to(self.args.device)
        proximity = data_batch['prox'].to(self.args.device)
        actions = None
        if 'action' in data_batch:
            actions = data_batch['action'].to(self.args.device)

        states = self._preproc_pf_input(states)
        guess_prox_dist = self.prox_func(states, actions)

        return self._compute_dist_loss(guess_prox_dist, proximity)

    def _get_prox_uncert(self, state, action):
        """Return the standard deviation, across predicted distributions, of one sample from each.

        The state is preprocessed before being fed to the network and the
        samples are passed through ``delabel_data``, mirroring ``_get_prox``.
        """
        # The preprocessed state (not the raw one) must be what the network sees.
        state = self._preproc_pf_input(state)
        prox_dists = self.prox_func(state, action)
        # Delabel so the spread is measured over proximity values rather than
        # raw samples (e.g. class indices), consistent with _get_prox.
        pval = torch.stack(
            [self.delabel_data(prox_dist.sample()) for prox_dist in prox_dists])
        return pval.std(0)

    def _get_prox(self, state, action, should_clip):
        """Return the mean, across predicted distributions, of one delabeled sample from each.

        Args:
            state: input passed through ``_preproc_pf_input`` before prediction.
            action: forwarded to ``prox_func`` unchanged.
            should_clip: when true, the result is clamped to ``[0.0, 1.0]``.
        """
        state = self._preproc_pf_input(state)
        prox_dists = self.prox_func(state, action)
        samples = [prox_dist.sample() for prox_dist in prox_dists]
        pval = torch.stack([self.delabel_data(sample) for sample in samples])
        pval = pval.mean(0)

        if should_clip:
            pval = torch.clamp(pval, 0.0, 1.0)
        return pval


class NormDistDiscountedProxFunc(DistDiscountedProxFunc):
    """Distributional proximity function with a ``DiagGaussian`` output head.

    Proximity values are used as targets directly, so labelling and
    de-labelling are both the identity.
    """

    def __init__(self, get_pf=None):
        """Create the proximity function.

        Args:
            get_pf: factory forwarded to the parent constructor. When it is
                ``None`` the default factory defined below is used.
        """
        def get_default_dist_pf(hidden_dim=64):
            # Linear + Tanh followed by a 1-dimensional DiagGaussian output;
            # returned together with the hidden size.
            head = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.Tanh(),
                DiagGaussian(hidden_dim, 1),
            )
            return head, hidden_dim

        super().__init__(get_default_dist_pf if get_pf is None else get_pf)

    def label_data(self, prox_vals):
        """Return ``prox_vals`` unchanged."""
        return prox_vals

    def delabel_data(self, idx):
        """Return ``idx`` unchanged."""
        return idx

    def _compute_dist_loss(self, guess_prox_dist, proximity):
        """Return the MSE between one ``rsample()`` per distribution and the targets."""
        # rsample keeps the draw differentiable (reparameterization trick).
        drawn = [member.rsample() for member in guess_prox_dist]
        samples = torch.stack(drawn)
        n_members = samples.shape[0]
        # Tile the targets so every distribution is compared to the same values.
        target = proximity.view(1, -1, 1).repeat(n_members, 1, 1)
        return F.mse_loss(samples, target)



class CatDistDiscountedProxFunc(DistDiscountedProxFunc):
    """Distributional proximity function with a categorical (binned) output head.

    Proximity is discretized into ``args.pf_nbins`` classes. ``self.bins``
    holds ``pf_nbins - 1`` edges from ``min_bin_val`` to ``1.0``, and class
    ``k`` is decoded by ``delabel_data`` as ``[0.0, *self.bins][k]``.
    """

    def __init__(self, get_pf=None):
        """Create the proximity function.

        Args:
            get_pf: factory forwarded to the parent constructor. When it is
                ``None`` the default factory defined below is used.
        """
        def get_default_dist_pf(obs_shape, action_dim, args):
            # NOTE(review): reads `obs_shape.shape[0]`, so `obs_shape` must
            # expose a `.shape` attribute -- confirm against the caller.
            input_dim = obs_shape.shape[0]
            if args.action_input:
                input_dim += action_dim

            return nn.Sequential(
                nn.Linear(input_dim, 100), nn.Tanh(),
                nn.Linear(100, 100), nn.Tanh(),
                Categorical(100, args.pf_nbins))

        if get_pf is None:
            get_pf = get_default_dist_pf
        super().__init__(get_pf)

    def init(self, policy, args):
        """Initialize the parent and build the bin edges from expert statistics.

        Raises:
            ValueError: if ``args.pf_nbins`` is smaller than 3, since the edge
                spacing ``(max - min) / (pf_nbins - 2)`` is undefined below that.
        """
        super().init(policy, args)
        if args.pf_nbins < 3:
            raise ValueError(
                f'pf_nbins must be at least 3, got {args.pf_nbins}')

        self.min_prox, self.max_prox = self.expert_dataset.get_prox_stats()
        # Lower the smallest edge below the minimum expert proximity by the
        # fraction `pf_ep_len_buf`.
        self.min_bin_val = self.min_prox - (self.min_prox * args.pf_ep_len_buf)
        self.max_bin_val = 1.0

        # linspace yields exactly `pf_nbins - 1` evenly spaced edges with both
        # end points included. arange with a float step can produce one edge
        # too many or too few because of rounding.
        self.bins = np.linspace(
            self.min_bin_val, self.max_bin_val, args.pf_nbins - 1)

    def label_data(self, prox_vals):
        """Convert proximity values to class indices.

        Args:
            prox_vals: sequence of proximity values.

        Returns:
            A list with one class index per value, each ``digitize + 1``
            capped at the last valid class ``len(self.bins)``.
        """
        bin_labels = np.digitize(prox_vals, self.bins) + 1
        # Values at or above the last edge would otherwise get index
        # len(bins) + 1, which is outside both the one-hot target built in
        # _compute_dist_loss and the lookup table used by delabel_data.
        max_label = len(self.bins)
        return list(np.minimum(bin_labels, max_label))

    def delabel_data(self, idx):
        """Look up the proximity value for class indices ``idx``.

        Index 0 decodes to ``0.0``; index ``k >= 1`` decodes to
        ``self.bins[k - 1]``.
        """
        bins = torch.tensor([0.0, *self.bins]).to(self.args.device)
        return bins[idx]

    def _compute_dist_loss(self, guess_prox_dist, proximity):
        """Return the mean KL divergence between one-hot targets and the predictions.

        Args:
            guess_prox_dist: iterable of distributions exposing ``probs``.
            proximity: class indices, one per batch element.

        Returns:
            A scalar tensor that stays attached to the autograd graph.
        """
        real_prox_dist = torch.zeros(
            len(proximity), self.args.pf_nbins).to(self.args.device)
        # scatter_ needs integer indices; `.long()` is a no-op when they
        # already are.
        real_prox_dist.scatter_(1, proximity.view(-1, 1).long(), 1)

        losses = []
        for dist in guess_prox_dist:
            # F.kl_div expects log-probabilities as its first argument; clamp
            # first so a zero probability does not become -inf.
            eps = torch.finfo(dist.probs.dtype).eps
            log_probs = torch.log(dist.probs.clamp_min(eps))
            # 'batchmean' is the reduction that matches the mathematical KL.
            losses.append(
                F.kl_div(log_probs, real_prox_dist, reduction='batchmean'))

        # torch.stack keeps the graph; torch.tensor([...]) would copy the
        # values into a new leaf tensor and block all gradients.
        return torch.stack(losses).mean()

    def compute_good_traj_prox(self, obs, actions):
        """Return bin labels for the discounted proximity of a trajectory of ``len(obs)`` steps.

        ``actions`` is not used.
        """
        proxs = compute_discounted_prox(len(obs), self.compute_prox_fn)
        # NOTE(review): `_get_bin_for_prox` is not defined in this class or in
        # the visible parent; confirm it exists upstream (`label_data` performs
        # the binning here).
        bin_labels = self._get_bin_for_prox(proxs)
        return torch.FloatTensor(bin_labels)

    def compute_bad_traj_prox(self, obs, actions):
        """Return a one-hot vector of length ``pf_nbins`` selecting class 0.

        Neither ``obs`` nor ``actions`` is used.
        """
        label = torch.zeros(self.args.pf_nbins)
        label[0] = 1
        return label

    def get_add_args(self, parser):
        """Register the parent's arguments plus the binning options."""
        super().get_add_args(parser)
        #########################################
        # New args
        # Number of proximity classes predicted by the categorical head.
        parser.add_argument('--pf-nbins', type=int, default=5)
        # Fraction by which the lowest bin edge is placed below the minimum
        # expert proximity.
        parser.add_argument('--pf-ep-len-buf', type=float, default=0.2)


