from abc import ABC

import torch
from torch import nn
from torch.nn import functional as F

from Algorithms.RIDE.config_ride import RIDEConfig
from Common.models.model_general import CNN, make_fc_layer


class RIDEModel(nn.Module, ABC):
    """Networks used by the RIDE intrinsic-reward module.

    Holds:
      * a state embedding: a CNN feature extractor followed by a fully
        connected encoder (``embedding``);
      * a forward-dynamics head that predicts the next-state embedding from
        the current embedding concatenated with a one-hot action;
      * an inverse-dynamics head that predicts action logits from the
        concatenated embeddings of the current and next state.
    """

    def __init__(self, config: RIDEConfig):
        super().__init__()
        self.config = config
        # NOTE(review): the CNN is named "rnd_cnn", which looks inherited from
        # an RND model; kept as-is because CNN may use the name — confirm intent.
        self.feature_extractor = nn.Sequential(
            CNN(config.feature_net_arch['cnn'], "rnd_cnn")
        )
        self.encoder = make_fc_layer(config.feature_net_arch["mlp"][0], config.extra_hidden_size, True)
        # Input: embedding of s concatenated with a one-hot action.
        self.forward_dynamic = make_fc_layer(config.extra_hidden_size + config.n_actions, config.extra_hidden_size)
        # Input: embeddings of s and s' concatenated; output: one logit per action.
        self.inverse_dynamic = make_fc_layer(config.extra_hidden_size * 2, config.n_actions)

    def embedding(self, x):
        """Return the learned embedding of a batch of observations.

        When ``config.obs_is_color`` is set, the input is divided by 256.
        Otherwise it is only multiplied by 1.0, which turns integer tensors
        into floating point.
        """
        if self.config.obs_is_color:
            # NOTE: divides by 256 (not 255), so 8-bit pixels land in [0, 255/256].
            x = x / 256.
        else:
            x = x * 1.0
        return self.encoder(self.feature_extractor(x))

    def reward_process(self, states, next_states, actions):
        """Compute the intrinsic reward for a batch of transitions.

        The reward is the root-mean-square difference between the embeddings
        of ``states`` and ``next_states``, reduced over dim 1. This equals the
        L2 distance divided by sqrt(embedding size). No gradients are tracked.

        Args:
            states: batch of observations accepted by ``embedding``.
            next_states: batch of successor observations, same layout.
            actions: unused here; presumably kept so all intrinsic-reward
                models share one call signature.

        Returns:
            ``(reward, info)``: ``reward`` is reduced over dim 1 (one value per
            sample for ``(batch, features)`` embeddings); ``info`` is an empty
            dict.
        """
        # NOTE(review): the RIDE paper also divides by sqrt of an episodic
        # visit count of the next state; that is not done here — presumably
        # handled by the caller, confirm.
        with torch.no_grad():
            emb_s, emb_s_ = self.embedding(states), self.embedding(next_states)
            reward = (emb_s - emb_s_).pow(2).mean(1).pow(0.5)
        return reward, {}

    def loss_process(self, states, next_states, actions):
        """Compute the joint forward/inverse-dynamics training loss.

        Args:
            states: batch of observations accepted by ``embedding``.
            next_states: batch of successor observations, same layout.
            actions: discrete action indices in ``[0, n_actions)``, shaped
                ``(batch,)`` or ``(batch, 1)``. Any integer or integral-float
                dtype is accepted; values are converted to int64.

        Returns:
            ``(loss, info)`` where ``loss`` is the forward-model MSE plus the
            inverse-model cross-entropy. ``info`` holds both terms as Python
            floats under ``"fwd_loss"`` and ``"inv_loss"``.
        """
        emb_s, emb_s_ = self.embedding(states), self.embedding(next_states)
        # Convert once so that one_hot and cross_entropy both receive int64
        # class indices of shape (batch,). cross_entropy/nll_loss reject
        # non-int64 targets.
        action_idx = actions.long().reshape(-1)
        # one_hot returns int64; cast to the embedding dtype instead of relying
        # on torch.cat type promotion.
        a = F.one_hot(action_idx, num_classes=self.config.n_actions).to(emb_s.dtype)
        pred_s_ = self.forward_dynamic(torch.cat([emb_s, a], dim=1))
        fwd_loss = (pred_s_ - emb_s_).pow(2).mean()
        inv_logit = self.inverse_dynamic(torch.cat([emb_s, emb_s_], dim=1))
        # Same value as nll_loss(log_softmax(logits)), in one fused call.
        inv_loss = F.cross_entropy(inv_logit, action_idx)
        loss = fwd_loss + inv_loss
        return loss, {"fwd_loss": fwd_loss.item(), "inv_loss": inv_loss.item()}

