from abc import ABC

import torch
from torch import nn
from torch.nn import functional as F

from Algorithms.ETD.config_etd import ETDConfig
from Common.models.model_general import CNN, MLP, make_fc_layer



class ETDModel(nn.Module, ABC):
    """ETD embedding model: a shared CNN/MLP trunk followed by two heads.

    * ``encoder`` maps trunk features to the embedding on which
      :meth:`mrn_distance` is evaluated (rewards, similarities, loss).
    * ``potential`` maps trunk features to one score per sample; within this
      class it is only used as a per-candidate bias in :meth:`loss_process`.
    """

    def __init__(self, config: "ETDConfig"):
        """Build the trunk and both heads from ``config``.

        Args:
            config: ETD configuration. Reads ``feature_net_arch``,
                ``model_norm_type``, ``state_shape``, ``policy_kwargs``,
                ``etd_model_kwargs`` and, when no ``etd_encoder_arch`` is
                configured, ``extra_hidden_size``.

        Raises:
            ValueError: if ``etd_encoder_arch`` is non-empty but has fewer
                than two widths (the heads need an input and an output width).
        """
        super().__init__()
        self.config = config
        norm_type = config.model_norm_type
        model_kwargs = config.etd_model_kwargs

        # Module name strings ("etd_cnn", "etd_mdp", ...) are kept verbatim:
        # they are passed to project helpers and may determine parameter
        # names, so renaming them could break loading existing checkpoints.
        self.feature_extractor = nn.Sequential(
            CNN(
                config.feature_net_arch["cnn"],
                "etd_cnn",
                norm_type=norm_type,
                origin_size=config.state_shape,
                first_layer_norm=config.policy_kwargs.get("first_layer_norm", False),
            ),
            MLP(
                model_kwargs.get("etd_mlp_arch", []),
                "etd_mdp",
                non_linearity_last=True,
                norm_type=norm_type,
            ),
        )

        etd_encoder_arch = model_kwargs.get("etd_encoder_arch", [])
        if etd_encoder_arch:
            if len(etd_encoder_arch) < 2:
                raise ValueError(
                    "etd_encoder_arch needs at least two widths (input and "
                    f"output), got {etd_encoder_arch!r}"
                )
            last_hidden = etd_encoder_arch[-2]
            # Encoder: hidden MLP, then a projection to the embedding width.
            self.encoder = nn.Sequential(
                MLP(etd_encoder_arch[:-1], "etd_encoder", non_linearity_last=True, norm_type=norm_type),
                make_fc_layer(last_hidden, etd_encoder_arch[-1]),
            )
            # Potential: same hidden layout, projected to a single score.
            self.potential = nn.Sequential(
                MLP(etd_encoder_arch[:-1], "etd_p_mlp", non_linearity_last=True, norm_type=norm_type),
                make_fc_layer(last_hidden, 1, True),
            )
        else:
            in_dim = config.feature_net_arch["mlp"][0]
            hidden = config.extra_hidden_size
            self.encoder = make_fc_layer(in_dim, hidden, True)
            self.potential = nn.Sequential(
                MLP([in_dim, hidden], "etd_p_mlp", non_linearity_last=True, norm_type=norm_type),
                make_fc_layer(hidden, 1, True),
            )

    def embedding(self, x: torch.Tensor) -> torch.Tensor:
        """Convert raw observations to float and run them through the trunk."""
        if self.config.obs_is_color:
            # NOTE(review): divides by 256, so a pixel value of 255 maps to
            # ~0.996 rather than 1.0. Kept as-is because changing it would
            # alter the input scale seen by already-trained models.
            x = x / 256.
        else:
            x = x * 1.0  # promotes integer observations to floating point
        return self.feature_extractor(x)

    @staticmethod
    def mrn_distance(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Asymmetric distance from ``x`` to ``y`` over the last dimension.

        The last dimension (size ``d``) is split in two:

        * the first ``d // 2`` components contribute
          ``max_k relu(x_k - y_k)``, which is direction-dependent;
        * the remaining components contribute the Euclidean distance.

        Leading dimensions follow normal torch broadcasting. This lets
        callers pass ``a[:, None]`` and ``b[None, :]`` to get a pairwise
        matrix.

        Args:
            x, y: embeddings with the same last-dimension size.
            eps: added under the square root so the gradient stays finite
                when the two halves coincide.

        Returns:
            The broadcast shape of ``x`` and ``y`` without the last dimension.
        """
        half = x.shape[-1] // 2
        max_component = torch.max(F.relu(x[..., :half] - y[..., :half]), dim=-1).values
        l2_component = torch.sqrt(torch.square(x[..., half:] - y[..., half:]).sum(-1) + eps)
        return max_component + l2_component

    def etd_embedding_process(self, x: torch.Tensor) -> torch.Tensor:
        """Map raw observations to encoder embeddings (trunk followed by ``encoder``)."""
        return self.encoder(self.embedding(x))

    def reward_process(self, next_states: torch.Tensor, goal_states: torch.Tensor):
        """Compute the gradient-free distance from each goal to its next state.

        Returns:
            ``(mrn_distance(goal_emb, next_emb), {})``. The empty dict is an
            info slot kept for interface symmetry with :meth:`loss_process`.
        """
        with torch.no_grad():
            goal_emb = self.etd_embedding_process(goal_states)
            next_emb = self.etd_embedding_process(next_states)
            return self.mrn_distance(goal_emb, next_emb), {}

    def similarity_process(self, states: torch.Tensor) -> torch.Tensor:
        """Compute the gradient-free pairwise matrix ``D[i, j] = mrn_distance(emb_i, emb_j)``."""
        with torch.no_grad():
            embs = self.etd_embedding_process(states)
            dists = self.mrn_distance(embs[:, None], embs[None, :])
        return dists

    def loss_process(self, states: torch.Tensor, future_states: torch.Tensor):
        """Symmetric contrastive loss pairing ``states[i]`` with ``future_states[i]``.

        ``logits[i, j] = potential(future_j) - mrn_distance(enc(state_i), enc(future_j))``.
        The matching pair (the diagonal) is the positive. Cross-entropy is
        applied over rows and over columns, and the two terms are averaged.
        Both inputs must have the same batch size.

        Returns:
            ``(loss, {"etd_loss": loss.item()})``.
        """
        emb_s = self.embedding(states)
        emb_future = self.embedding(future_states)
        f_s = self.encoder(emb_s)
        f_future = self.encoder(emb_future)
        # Presumably (batch, 1), since the head ends in a width-1 layer;
        # transposed so the score biases each candidate column.
        h_future = self.potential(emb_future)
        logits = h_future.T - self.mrn_distance(f_s[:, None], f_future[None, :])
        # Class-index targets for the diagonal. This is equivalent to one-hot
        # (identity-matrix) probability targets, but avoids building a B x B
        # matrix on the host and copying it to the device on every step.
        targets = torch.arange(logits.shape[0], device=logits.device)
        loss = 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))
        return loss, {"etd_loss": loss.item()}



