import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from torch import nn
from delphicORL.networks.vae_network import *
from delphicORL.utils import data



class WorldModel_Markov(BetaVAE):
    """Conditional VAE world model over a trajectory of steps.

    The flattened trajectory is encoded into one latent ``z`` per sequence. ``z`` is tiled
    along the time axis and concatenated with every step of the trajectory, and the decoder
    predicts a per-step target of size ``target_dim`` (by default ``input_dim - action_dim``,
    e.g. the next observation).
    """

    def __init__(self, kl_weight, input_dim,
                    latent_dim,
                    max_len=1,
                    hidden_size=None,
                    action_dim=1, target_dim=None):
        """Build the encoder (via ``BetaVAE``) and a decoder over ``[z, step]`` inputs.

        Args:
            kl_weight: default weight of the KL term used by :meth:`loss`.
            input_dim: number of features per trajectory step.
            latent_dim: size of the latent code ``z``.
            max_len: number of steps per trajectory.
            hidden_size: hidden layer sizes forwarded to ``BetaVAE`` and ``Decoder``.
                ``None`` (default) means ``[128, 64, 32]``; a fresh list is built per
                instance so a shared default list can never be mutated across models.
            action_dim: number of action features per step.
            target_dim: size of the per-step prediction; defaults to ``input_dim - action_dim``.
        """
        if hidden_size is None:
            hidden_size = [128, 64, 32]
        super().__init__(kl_weight=kl_weight, input_dim=input_dim, max_len=max_len, latent_dim=latent_dim, hidden_size=hidden_size)
        self.action_dim = action_dim
        if target_dim is None:
            target_dim = input_dim - action_dim
        self.target_dim = target_dim

        # The decoder sees the latent code concatenated with a single trajectory step.
        self.decoder = Decoder(latent_dim + self.input_dim,
                                self.target_dim,
                                hidden_size)

    def forward(self, traj):
        """Encode ``traj`` and decode a per-step prediction conditioned on a sampled latent.

        ``traj`` is flattened per batch element for the encoder; the sampled ``z`` is tiled
        ``max_len`` times along dim 1 and concatenated with ``traj`` on the last dim, so
        ``traj`` must be laid out as (batch, max_len, input_dim).

        Returns:
            ``(p_x, q_z)``: the decoder output (used as a distribution in :meth:`loss`) and
            the encoder posterior over ``z``.
        """
        x = traj.reshape(traj.shape[0], -1)
        q_z = self.encoder(x)
        z = q_z.rsample()  # reparameterised sample keeps gradients flowing to the encoder
        z = torch.unsqueeze(z, 1)
        decoder_input = torch.cat([torch.tile(z, (1, self.max_len, 1)), traj], dim=-1)
        return self.decoder(decoder_input), q_z

    def loss(self, traj, qfunc_target_seq, kl_weight=None, mask=None, reduce='mean'):
        """Compute the negative (beta-weighted) ELBO of the model.

        The log-likelihood is measured with respect to the target sequence
        ``qfunc_target_seq`` — e.g. the next observation, or the Q-function, given the
        current observation and action. The KL divergence is computed between the latent
        posterior and a standard normal prior.

        Args:
            traj: input trajectory, passed through ``self.format_input``.
            qfunc_target_seq: per-step prediction target, passed through ``self.format_input``.
            kl_weight: KL weight; defaults to ``self.kl_weight``.
            mask: optional (batch, max_len) step mask; converted to bool and used to
                select valid steps. ``None`` means every step is used.
            reduce: ``'mean'`` for a scalar loss, or ``None`` for per-sequence losses
                (per valid step when ``mask`` is given).

        Returns:
            dict with ``'loss'``, ``'Reconstruction_LL'`` (detached) and ``'KL'`` (negated, detached).

        Raises:
            ValueError: if ``reduce`` is neither ``'mean'`` nor ``None``.
        """
        if kl_weight is None:
            kl_weight = self.kl_weight
        if reduce not in ('mean', None):
            raise ValueError(f"reduce must be 'mean' or None, got {reduce!r}")

        traj = self.format_input(traj)
        qfunc_target_seq = self.format_input(qfunc_target_seq)
        # Only format/cast the mask when one is given; calling .to() on None would crash.
        if mask is not None:
            mask = self.format_input(mask).to(torch.bool)
        p_x, q_z = self.forward(traj)

        # Evaluate likelihood of the target given the current observation, action, and latent representation
        log_likelihood = p_x.log_prob(qfunc_target_seq)

        kl = torch.distributions.kl_divergence(
            q_z,
            torch.distributions.Normal(0, 1.)
        ).sum(-1)

        if reduce == 'mean':
            if mask is not None:
                bs, max_len = mask.shape
                log_likelihood = log_likelihood.reshape(bs, max_len, -1)[mask]

            log_likelihood = log_likelihood.sum(-1).mean()
            kl = kl.mean()

            loss = -(log_likelihood - kl_weight * kl)

        else:  # reduce is None
            if mask is not None:
                bs, max_len = mask.shape
                log_likelihood = log_likelihood.reshape(bs, max_len, -1).sum(-1)
                # One KL value per sequence, broadcast over its steps before masking.
                loss = -(log_likelihood - kl_weight * kl.unsqueeze(-1).expand((-1, max_len)))
                loss = loss[mask]
                log_likelihood = log_likelihood[mask]

            else:
                log_likelihood = log_likelihood.sum(-1)
                loss = -(log_likelihood - kl_weight * kl)

        return {'loss': loss, 'Reconstruction_LL':log_likelihood.detach(), 'KL':-kl.detach()}


class WorldModel_Q(WorldModel_Markov):
    """World model whose decoder reconstructs a scalar Q-value per step.

    Given a sequence of observations, the model learns a latent representation of the
    sequence such that the Q function can be reconstructed from the latent representation
    and an input observation and action (``target_dim`` is fixed to 1).

    Training objectives (inherited ``loss``) are the log likelihood of the Q function for
    each pair, and the KL divergence between the latent posterior and prior.
    """

    def __init__(self, kl_weight, input_dim, latent_dim, max_len=1, hidden_size=None,
                    action_dim=1):
        """Args mirror ``WorldModel_Markov``; ``hidden_size=None`` means ``[128, 64, 32]``
        (a fresh list per instance, avoiding a shared mutable default)."""
        if hidden_size is None:
            hidden_size = [128, 64, 32]
        super().__init__(kl_weight=kl_weight, input_dim=input_dim, latent_dim=latent_dim, max_len=max_len,
            hidden_size=hidden_size,
            action_dim=action_dim,
            target_dim=1)


class WorldModel_QPi(BetaVAE):
    """Conditional VAE with a Q-function head and a behaviour-policy head.

    The flattened (state, action) trajectory is encoded into a latent ``z``. The Q head
    decodes ``[z, state, action]`` per step into a 1-d output; the policy head decodes
    ``[z, state]`` per step into an ``action_dim``-d output. Both head outputs are used as
    distributions (``log_prob`` / ``mean``) by callers.
    """

    def __init__(self, kl_weight, input_dim, latent_dim, pi_weight=1.0, max_len=1, hidden_size=None,
                    action_dim=1):
        """
        Args:
            kl_weight: default KL weight used by :meth:`loss`.
            input_dim: per-step feature size of ``[state, action]``.
            latent_dim: size of the latent code ``z``.
            pi_weight: default weight of the policy log-likelihood term in :meth:`loss`.
            max_len: number of steps per trajectory.
            hidden_size: hidden layer sizes; ``None`` means ``[128, 64, 32]`` (a fresh
                list per instance, avoiding a shared mutable default).
            action_dim: number of action features per step.
        """
        if hidden_size is None:
            hidden_size = [128, 64, 32]
        super().__init__(kl_weight=kl_weight, input_dim=input_dim, latent_dim=latent_dim, max_len=max_len,
            hidden_size=hidden_size)
        self.action_dim = action_dim
        state_dim = input_dim - action_dim

        self.qmodel = Decoder(latent_dim + self.input_dim, 1, hidden_size)
        self.pimodel = Decoder(latent_dim + state_dim, action_dim, hidden_size)
        self.pi_weight = pi_weight

    def forward(self, state, action):
        """Encode ``[state, action]`` and decode both heads from a sampled latent.

        ``state``/``action`` are concatenated on the last dim and then with ``z`` tiled
        ``max_len`` times along dim 1, so they must be laid out as (batch, max_len, features).

        Returns:
            ``(qfunc_out, pi_out, q_z)``: Q-head output, policy-head output, encoder posterior.
        """
        traj = torch.cat([state, action], -1)
        x = traj.reshape(traj.shape[0], -1)
        q_z = self.encoder(x)

        z = torch.unsqueeze(q_z.rsample(), 1)
        z_seq = torch.tile(z, (1, self.max_len, 1))
        qfunc_out = self.qmodel(torch.cat([z_seq, traj], dim=-1))
        pi_out = self.pimodel(torch.cat([z_seq, state], dim=-1))
        return qfunc_out, pi_out, q_z

    def loss(self, states, actions, qfunc_target_seq, kl_weight=None, mask=None, pi_weight=None):
        """Compute the negative weighted ELBO for the Q and policy heads.

        ``loss = -(LL_Q + pi_weight * LL_pi - kl_weight * KL)``, where ``LL_Q`` is the
        log-likelihood of ``qfunc_target_seq`` under the Q head, ``LL_pi`` the log-likelihood
        of ``actions`` under the policy head, and ``KL`` the divergence between the latent
        posterior and a standard normal prior. Both likelihoods are *maximised*.

        Args:
            states, actions: model inputs (see :meth:`forward`).
            qfunc_target_seq: Q-value targets.
            kl_weight: defaults to ``self.kl_weight``.
            mask: optional (batch, max_len) mask used to index valid steps — presumably a
                bool tensor; it is used as-is.
            pi_weight: defaults to ``self.pi_weight``.

        Returns:
            dict with ``'loss'``, ``'LL_Q'``, ``'LL_pi'`` (detached) and ``'KL'`` (negated, detached).
        """
        if kl_weight is None:
            kl_weight = self.kl_weight
        if pi_weight is None:
            pi_weight = self.pi_weight

        p_qfunc, p_pi, q_z = self.forward(states, actions)

        # Evaluate likelihood of the targets given the current observation, action, and latent representation
        log_likelihood_q = p_qfunc.log_prob(qfunc_target_seq)
        log_likelihood_pi = p_pi.log_prob(actions)

        kl = torch.distributions.kl_divergence(
            q_z,
            torch.distributions.Normal(0, 1.)
        ).sum(-1)

        if mask is not None:
            bs, max_len = mask.shape
            log_likelihood_q = log_likelihood_q.reshape(bs, max_len, -1)[mask]
            log_likelihood_pi = log_likelihood_pi.reshape(bs, max_len, -1)[mask]

        log_likelihood_q = log_likelihood_q.sum(-1).mean()
        log_likelihood_pi = log_likelihood_pi.sum(-1).mean()
        kl = kl.mean()

        # The policy likelihood must be added (maximised); subtracting it would drive the
        # behaviour model away from the observed actions.
        loss = -(log_likelihood_q + pi_weight * log_likelihood_pi - kl_weight * kl)

        return {'loss': loss, 'LL_Q':log_likelihood_q.detach(), 'LL_pi':log_likelihood_pi.detach(), 'KL':-kl.detach()}






###############################################################

class EpistemicWorldModel(WorldModel_QPi):
    """``WorldModel_QPi`` with an ensemble of ``n_bootstraps`` Q heads and policy heads.

    All heads share one encoder/latent; each head is trained on its own random subset of
    every batch, and disagreement between heads is used as an epistemic-uncertainty signal.
    """

    def __init__(self, kl_weight, input_dim, latent_dim, pi_weight=1.0, max_len=1, hidden_size=None,
                    action_dim=1, n_bootstraps=5):
        """Args mirror ``WorldModel_QPi``, plus ``n_bootstraps`` heads per output.

        ``hidden_size=None`` means ``[128, 64, 32]`` (a fresh list per instance).
        """
        if hidden_size is None:
            hidden_size = [128, 64, 32]
        super().__init__(kl_weight=kl_weight, input_dim=input_dim, latent_dim=latent_dim, pi_weight=pi_weight, max_len=max_len,
                hidden_size=hidden_size, action_dim=action_dim)
        self.action_dim = action_dim
        state_dim = input_dim - action_dim

        # The single heads built by the parent are replaced by the head ensembles below.
        self.qmodel = None
        self.pimodel = None
        self.n_bootstraps = n_bootstraps

        self.qmodels = nn.ModuleList(
            [Decoder(latent_dim + self.input_dim, 1, hidden_size) for _ in range(n_bootstraps)])
        self.pimodels = nn.ModuleList(
            [Decoder(latent_dim + state_dim, action_dim, hidden_size) for _ in range(n_bootstraps)])

    def forward(self, state, action, decoder_k = None):
        """Encode ``[state, action]`` and decode with one head (``decoder_k``) or all heads.

        Returns:
            ``(qfunc_out, pi_out, q_z)`` — single head outputs when ``decoder_k`` is given,
            otherwise lists with one output per head.
        """
        traj = torch.cat([state, action], -1)
        x = traj.reshape(traj.shape[0], -1)
        q_z = self.encoder(x)
        z = torch.unsqueeze(q_z.rsample(), 1)
        z_seq = torch.tile(z, (1, self.max_len, 1))
        q_input = torch.cat([z_seq, traj], dim=-1)
        pi_input = torch.cat([z_seq, state], dim=-1)

        if decoder_k is not None:
            return self.qmodels[decoder_k](q_input), self.pimodels[decoder_k](pi_input), q_z
        # A single torch.cat: the previous extra torch.cat around a tensor raised TypeError.
        qfunc_out = [qmodel(q_input) for qmodel in self.qmodels]
        pi_out = [pimodel(pi_input) for pimodel in self.pimodels]
        return qfunc_out, pi_out, q_z

    def loss(self, states, actions, qfunc_target_seq, kl_weight=None, mask=None, pi_weight=None):
        """Sum of per-head negative log-likelihoods plus a shared, scaled KL term.

        Each head ``k`` is scored on a random 80% subset of the batch (at least one
        element). The objective is
        ``n_bootstraps * kl_weight * KL - sum_k (LL_Q_k + pi_weight * LL_pi_k)``.

        Returns:
            dict with ``'KL'`` (negated, detached), per-head
            ``'Reconstruction_LLQ_dec{k}'`` / ``'Reconstruction_LLpi_dec{k}'`` (detached)
            and ``'loss'``.
        """
        if kl_weight is None:
            kl_weight = self.kl_weight
        if pi_weight is None:
            pi_weight = self.pi_weight

        bs = states.shape[0]
        # Keep at least one element so tiny batches cannot produce an empty mean (NaN).
        n_keep = max(1, int(0.8 * bs))
        list_p_qfunc, list_p_pi, q_z = self.forward(states, actions)

        # Evaluate likelihood of the targets given the current observation, action, and latent representation
        log_likelihood_q = [p_qfunc.log_prob(qfunc_target_seq) for p_qfunc in list_p_qfunc]
        log_likelihood_pi = [p_pi.log_prob(actions) for p_pi in list_p_pi]

        kl = torch.distributions.kl_divergence(q_z, torch.distributions.Normal(0, 1.)).sum(-1).mean()

        metrics = {}
        # The KL is shared by all heads; it is scaled by n_bootstraps because the loss sums
        # one likelihood term per head.
        loss = kl_weight * kl * self.n_bootstraps
        metrics['KL'] = -kl.detach()

        for k, (llq_k, llpi_k) in enumerate(zip(log_likelihood_q, log_likelihood_pi)):
            keep = torch.randperm(bs)[:n_keep]
            llq_k = llq_k[keep]
            llpi_k = llpi_k[keep]
            if mask is not None:
                mask_k = mask[keep]
                llq_k = llq_k[mask_k]
                llpi_k = llpi_k[mask_k]

            llq_k = llq_k.sum(-1).mean()
            llpi_k = llpi_k.sum(-1).mean()
            metrics[f'Reconstruction_LLQ_dec{k}'] = llq_k.detach()
            metrics[f'Reconstruction_LLpi_dec{k}'] = llpi_k.detach()
            # Both likelihoods are maximised; the policy term was previously subtracted,
            # which pushed the behaviour heads away from the observed actions.
            loss = loss - (llq_k + pi_weight * llpi_k)

        metrics['loss'] = loss
        return metrics

    def predict_aleatoric_epistemic_uncertainty(self, state, action, mask = None):
        """Estimate aleatoric and epistemic uncertainty of the Q heads.

        Inputs are moved to the device of the model parameters (instead of being forced
        onto CUDA). ``mask``, if given, is cast to bool and used to select valid steps.

        Returns:
            ``(aleatoric, epistemic)`` where aleatoric is the squared mean of the heads'
            predicted variances and epistemic is ``var(head means) + var(head stds)``.
            NOTE(review): squaring the mean variance gives variance^2 units while the
            epistemic term is in variance units — confirm this is intended.
        """
        self.eval()
        if len(state.shape) > len(action.shape):
            action = torch.unsqueeze(action, -1)
        device = next(self.parameters()).device
        state = state.to(device)
        action = action.to(device)
        list_p_q, _, _ = self.forward(state, action)

        if mask is not None:
            mask = mask.to(device=device, dtype=torch.bool)
            bs, max_len = mask.shape

        def _stack_heads(tensors):
            # Stack per-head tensors on dim 0, keeping only valid steps when masked.
            if mask is None:
                return torch.stack(tensors, 0)
            return torch.stack([t.reshape(bs, max_len, -1)[mask] for t in tensors], 0)

        variances = _stack_heads([p.variance for p in list_p_q])
        means = _stack_heads([p.mean for p in list_p_q])

        sq_mean_output_var = torch.mean(variances, 0)**2
        var_output_std = torch.var(torch.sqrt(variances), 0)
        var_output_mean = torch.var(means, 0)

        return sq_mean_output_var, var_output_mean + var_output_std # aleatoric, epistemic



def get_counterfactual_q_func(q_func, behav, policy):
    """Reweight a Q estimate towards a new policy with a one-step importance ratio.

    The Q-function was trained under a different (behaviour) policy; this returns
    ``q_func * policy / (behav + 1e-8)`` elementwise. All inputs are already evaluated at
    the (state, action) pairs of interest. The small constant guards against division by
    zero when the behaviour model assigns zero probability.

    Args:
        q_func: Q-value estimate from the behaviour-policy model.
        behav: behaviour-policy model output at the same pairs.
        policy: target-policy probability at the same pairs.

    Returns:
        The counterfactual Q estimate (broadcast shape of the inputs).
    """
    safe_behav = behav + 1e-8
    return q_func * policy / safe_behav




###############################################################

class CompatibleWorldModels(nn.Module):
    """Ensemble of ``n_worldmodels`` separately constructed ``WorldModel_QPi`` models.

    The variance of the members' counterfactual Q estimates is returned by
    :meth:`predict_delphic_uncertainty`.
    """

    def __init__(self, kl_weight, input_dim, latent_dim, max_len=1,
        hidden_size=None, action_dim=1, target_dim=None, pi_weight=1.0, n_worldmodels=5):
        """Build the ensemble; ``hidden_size=None`` means ``[128, 64, 32]`` (a fresh list
        per instance, avoiding a shared mutable default). Other args are forwarded to each
        ``WorldModel_QPi`` member; ``target_dim`` is only stored."""
        super().__init__()
        if hidden_size is None:
            hidden_size = [128, 64, 32]
        self.kl_weight = kl_weight
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.max_len = max_len
        self.hidden_size = hidden_size
        self.action_dim = action_dim
        self.target_dim = target_dim
        self.n_worldmodels = n_worldmodels
        self.world_models = nn.ModuleList(
            [WorldModel_QPi(kl_weight=kl_weight, input_dim=input_dim, latent_dim=latent_dim, pi_weight=pi_weight, max_len=max_len,
                hidden_size=hidden_size, action_dim=action_dim)
             for _ in range(n_worldmodels)])

    def format_input(self, x):
        """Zero-pad or truncate dim 1 of ``x`` to ``self.max_len``.

        ``None`` and tensors with fewer than 2 dims are returned unchanged. For 3-d inputs
        the padding is applied to dim 1 only (the last dim is left as is).
        """
        if x is None:
            return x
        if len(x.shape) < 2:
            return x
        if x.shape[1] < self.max_len:
            pad_len = self.max_len - x.shape[1]
            if len(x.shape) == 3:
                # F.pad pads from the last dim backwards: (last_lo, last_hi, dim1_lo, dim1_hi).
                x = torch.nn.functional.pad(x, (0, 0, 0, pad_len))
            else:
                x = torch.nn.functional.pad(x, (0, pad_len), value=0)

        elif x.shape[1] > self.max_len:
            x = x[:, :self.max_len]
        return x

    def forward(self, state, action, model_k=None):
        """Run one member (``model_k``) or all members.

        Returns:
            The member's ``(p_q, p_pi, q_z)`` when ``model_k`` is given, otherwise three
            tuples with one entry per member.
        """
        if model_k is None:
            p_q, p_pi, q_x = zip(*[model(state, action) for model in self.world_models])
            return p_q, p_pi, q_x
        else:
            return self.world_models[model_k](state, action)

    def loss(self, state, action, qfunc_target_seq, model_k=None, kl_weight=None, mask=None, pi_weight=None, reduce='mean'):
        """Loss metrics for one member or for the whole ensemble.

        With ``model_k`` given, returns that member's metrics with keys prefixed by
        ``'WM{model_k}_'``. With ``model_k=None``, returns the merged metrics of all members
        plus a ``'loss'`` entry: the mean of the members' losses when ``reduce == 'mean'``,
        otherwise the stacked per-member losses. Members' ``loss`` methods take no
        ``reduce`` argument, so it is not forwarded to them.
        """
        if model_k is None:
            per_model = [self.loss(state, action, qfunc_target_seq, k, kl_weight, mask, pi_weight, reduce)
                         for k in range(len(self.world_models))]
            metrics = {}
            for member_metrics in per_model:
                metrics.update(member_metrics)
            # Stack the scalar member losses; stacking the metric dicts themselves fails.
            losses = torch.stack([m[f'WM{k}_loss'] for k, m in enumerate(per_model)])
            metrics['loss'] = torch.mean(losses) if reduce == 'mean' else losses
            return metrics

        member_metrics = self.world_models[model_k].loss(
            state, action, qfunc_target_seq, kl_weight=kl_weight, mask=mask, pi_weight=pi_weight)
        return {f'WM{model_k}_{key}': value for key, value in member_metrics.items()}

    def predict_delphic_uncertainty(self, state, action, policy=None, mask=None, statistics = None):
        """Variance across members of the counterfactual Q estimate for ``policy``.

        Args:
            state, action: raw inputs (see NOTE below about normalisation).
            policy: object exposing ``predict_probas(state, action)``; required.
            mask: optional step mask; padded/truncated to ``max_len`` and cast to bool.
            statistics: forwarded to ``data.normalise``.

        Raises:
            ValueError: if ``policy`` is ``None``.
        """
        if policy is None:
            raise ValueError("predict_delphic_uncertainty requires a `policy` exposing predict_probas")
        self.eval()
        # NOTE(review): the return value of data.normalise is discarded — this only has an
        # effect if it normalises in place; confirm.
        data.normalise(state, action, statistics = statistics)
        policy_probas = policy.predict_probas(state, action)
        if len(state.shape) > len(action.shape):
            action = torch.unsqueeze(action, -1)
        state = self.format_input(state)
        action = self.format_input(action)
        if mask is not None:
            mask = self.format_input(mask).to(torch.bool)
            bs, max_len = mask.shape

        counterfactuals = []
        for world_model in self.world_models:
            # NOTE(review): relies on a `predict` method inherited from BetaVAE returning
            # (p_q, p_pi, q_z) for (state, action) — confirm its signature.
            p_q, p_pi, _ = world_model.predict(state, action)
            p_counterfactual = get_counterfactual_q_func(q_func = p_q.mean, behav = p_pi.mean, policy = policy_probas)

            if mask is not None:
                p_counterfactual = p_counterfactual.reshape(bs, max_len, -1)[mask]
            counterfactuals.append(p_counterfactual)

        return torch.var(torch.stack(counterfactuals, 0), 0)


class ComaptibleEpistemicWorldModels(CompatibleWorldModels):
    """Ensemble of ``EpistemicWorldModel`` members, each with bootstrapped Q/policy heads.

    The class name keeps its historical spelling ("Comaptible") so existing callers work.
    """

    def __init__(self, kl_weight, input_dim, latent_dim, max_len=1,
        hidden_size=None, action_dim=1, target_dim=1, pi_weight=1.0, n_worldmodels=5, n_bootstraps=5):
        """Args mirror ``CompatibleWorldModels`` plus ``n_bootstraps`` heads per member.

        ``hidden_size=None`` means ``[128, 64, 32]``. ``pi_weight`` is forwarded to every
        member (it was previously dropped, silently using the default 1.0).
        """
        if hidden_size is None:
            hidden_size = [128, 64, 32]
        super().__init__(kl_weight=kl_weight, input_dim=input_dim, latent_dim=latent_dim, max_len=max_len,
        hidden_size= hidden_size, action_dim=action_dim, target_dim=target_dim, pi_weight=pi_weight, n_worldmodels=n_worldmodels)

        # Replace the WorldModel_QPi members built by the parent with epistemic ones.
        self.world_models = nn.ModuleList(
                [EpistemicWorldModel(kl_weight=kl_weight,
                                    input_dim=input_dim,
                                    latent_dim=latent_dim,
                                    pi_weight=pi_weight,
                                    max_len=max_len,
                                    hidden_size=hidden_size,
                                    action_dim=action_dim,
                                    n_bootstraps=n_bootstraps)
                for _ in range(self.n_worldmodels)])

    def predict_delphic_uncertainty(self, state, action, policy=None, mask=None, statistics = None):
        """Variance across members of the counterfactual Q estimate for ``policy``.

        Within each member, the Q and policy head means are averaged elementwise across
        heads before computing the counterfactual.

        Raises:
            ValueError: if ``policy`` is ``None``.
        """
        if policy is None:
            raise ValueError("predict_delphic_uncertainty requires a `policy` exposing predict_probas")
        self.eval()
        # NOTE(review): the return value of data.normalise is discarded — confirm it works in place.
        data.normalise(state, action, statistics = statistics)
        policy_probas = policy.predict_probas(state, action)
        if len(state.shape) > len(action.shape):
            action = torch.unsqueeze(action, -1)
        state = self.format_input(state)
        action = self.format_input(action)
        if mask is not None:
            mask = self.format_input(mask).to(torch.bool)
            bs, max_len = mask.shape

        counterfactuals = []
        for world_model in self.world_models:
            list_p_q, list_p_pi, _ = world_model.forward(state, action)
            # torch.mean cannot take a Python list; stack the heads and average over them.
            q_mean = torch.stack([q.mean for q in list_p_q], 0).mean(0)
            pi_mean = torch.stack([pi.mean for pi in list_p_pi], 0).mean(0)
            p_counterfactual = get_counterfactual_q_func(q_func = q_mean, behav = pi_mean, policy = policy_probas)

            if mask is not None:
                p_counterfactual = p_counterfactual.reshape(bs, max_len, -1)[mask]
            counterfactuals.append(p_counterfactual)
        return torch.var(torch.stack(counterfactuals, 0), 0)

    def predict_aleatoric_epistemic_uncertainty(self, state, action, mask = None, statistics = None):
        """Aleatoric and epistemic uncertainty of the Q heads, averaged over members.

        Per member: aleatoric = (mean over heads of predicted variance)**2, epistemic =
        var over heads of the means + var over heads of the stds. ``mask``, if given, is
        padded/truncated to ``max_len`` and cast to bool like in
        :meth:`predict_delphic_uncertainty`.

        Returns:
            ``(aleatoric, epistemic)``.
        """
        self.eval()
        # NOTE(review): the return value of data.normalise is discarded — confirm it works in place.
        data.normalise(state, action, statistics = statistics)
        if len(state.shape) > len(action.shape):
            action = torch.unsqueeze(action, -1)
        state = self.format_input(state)
        action = self.format_input(action)
        if mask is not None:
            mask = self.format_input(mask).to(torch.bool)
            bs, max_len = mask.shape

        def _stack_heads(tensors):
            # Stack per-head tensors on dim 0, keeping only valid steps when masked.
            if mask is None:
                return torch.stack(tensors, 0)
            return torch.stack([t.reshape(bs, max_len, -1)[mask] for t in tensors], 0)

        sqmeanvar, var_mean, var_std = [], [], []
        for world_model in self.world_models:
            list_p_q, _, _ = world_model.forward(state, action)
            variances = _stack_heads([p.variance for p in list_p_q])
            means = _stack_heads([p.mean for p in list_p_q])
            sqmeanvar.append(torch.mean(variances, 0)**2)
            var_std.append(torch.var(torch.sqrt(variances), 0))
            var_mean.append(torch.var(means, 0))

        sqmeanvar = torch.mean(torch.stack(sqmeanvar, 0), 0)
        var_std = torch.mean(torch.stack(var_std, 0), 0)
        var_mean = torch.mean(torch.stack(var_mean, 0), 0)
        return sqmeanvar, var_mean + var_std # aleatoric, epistemic
