from abc import ABC

import torch
from torch import nn
from torch.nn import functional as F

from Algorithms.Slav3.config_slav3 import Slav3Config
from Common.models.model_general import CNN, MLP, make_fc_layer


class SlaValueModel(nn.Module, ABC):
    """Scalar SLA value head on top of a CNN + MLP encoder.

    The CNN's ``origin_size`` is chosen from ``config.state_shape``,
    ``config.obs_shape`` or ``config.cell_shape`` by indexing with
    ``config.sla_value_state_type`` (0, 1 or 2). After encoding, a ReLU-gated
    residual projection is added to the features before the final size-1
    linear output.
    """

    def __init__(self, config: Slav3Config):
        super().__init__()
        self.config = config
        norm = config.model_norm_type
        arch = config.sla_value_net_arch
        candidate_shapes = (config.state_shape, config.obs_shape, config.cell_shape)
        input_shape = candidate_shapes[config.sla_value_state_type]

        # Encoder: convolutional trunk followed by an MLP whose last layer is
        # also followed by a non-linearity.
        conv_trunk = CNN(arch['cnn'], "p_cnn", norm_type=norm, origin_size=input_shape)
        dense_trunk = MLP(arch['mlp'], "p_mlp", non_linearity_last=True, norm_type=norm)
        self.feature_extractor = nn.Sequential(conv_trunk, dense_trunk)

        width = arch['mlp'][-1]
        # Residual branch (normalized like the encoder) and the scalar output layer.
        self.extra_value = make_fc_layer(width, width, True, 0.1, norm_type=norm)
        self.sla_value = make_fc_layer(width, 1, True, 0.01)

    def forward(self, input):
        """Return the predicted SLA value for a batch of inputs.

        When ``config.obs_is_color`` is truthy and the state type is not 2,
        the input is divided by 256 (presumably 8-bit pixel data — TODO
        confirm). Otherwise it is multiplied by 1.0, which promotes integer
        tensors to floating point.
        """
        scale_pixels = self.config.obs_is_color and self.config.sla_value_state_type != 2
        if scale_pixels:
            prepared = input / 256.
        else:
            prepared = input * 1.0

        features = self.feature_extractor(prepared)
        gated = F.relu(self.extra_value(features))
        return self.sla_value(features + gated)

class SlaValueRNDModel(nn.Module,ABC):
    """Prediction-error value model: a shared CNN encoder feeding two MLP heads.

    The selected SLA value state is encoded by ``feature_extractor``. Two MLPs
    with the same architecture (``config.sla_value_net_arch['mlp']``) then map
    the features: ``pred`` and ``target``. The output is ``scale`` times the
    mean squared difference between their outputs. Gradients from this output
    reach ``pred`` and the shared encoder only. ``target`` is evaluated
    without autograd.
    """

    def __init__(self,config:Slav3Config,is_rnd=False):
        """Build the encoder and the two MLP heads.

        Args:
            config: experiment configuration. Reads ``sla_value_net_arch``,
                ``model_norm_type``, ``sla_value_state_type``, ``state_shape``,
                ``obs_shape`` and ``cell_shape``. When ``is_rnd`` is False it
                also reads ``sla_reward`` and ``max_frames_per_episode``.
            is_rnd: when True the output is left unscaled (``scale == 1.0``).
                Otherwise it is multiplied by
                ``sla_reward * max_frames_per_episode``.
        """
        super().__init__()
        self.config=config
        self.scale=self.config.sla_reward*self.config.max_frames_per_episode if not is_rnd else 1.0
        norm_type=self.config.model_norm_type
        # sla_value_state_type (0/1/2) selects state/obs/cell shape as the CNN input size.
        self.feature_extractor=nn.Sequential(
            CNN(config.sla_value_net_arch['cnn'],"p_cnn",norm_type=norm_type,origin_size=[self.config.state_shape,self.config.obs_shape,self.config.cell_shape][self.config.sla_value_state_type]),
        )
        self.pred=MLP(config.sla_value_net_arch['mlp'],"pred_mlp",non_linearity_last=False,norm_type=norm_type)
        # NOTE(review): the target reuses the name "pred_mlp". This looks like a
        # copy-paste of the line above. It is kept as-is because MLP may derive
        # sub-module / checkpoint key names from it. Confirm before renaming
        # (e.g. to "target_mlp").
        self.target = MLP(config.sla_value_net_arch['mlp'], "pred_mlp", non_linearity_last=False,
                        norm_type=norm_type)

    def forward(self,input):
        """Return the scaled prediction error of ``pred`` against ``target``.

        Args:
            input: batch of SLA value states. When ``config.obs_is_color`` is
                truthy and ``sla_value_state_type != 2``, it is divided by 256
                (presumably 8-bit pixel data — TODO confirm). Otherwise it is
                multiplied by 1.0, which promotes integer tensors to floating
                point.

        Returns:
            Tensor with a trailing dimension of size 1: ``scale`` times the
            mean, over the last feature dimension, of
            ``(pred(x) - target(x)) ** 2``.
        """
        x=input*1.0 if (not self.config.obs_is_color) or self.config.sla_value_state_type==2 else input/256.
        x=self.feature_extractor(x)
        x1=self.pred(x)
        # No gradient may flow through the target's output (it previously went
        # through .detach()). Running it under no_grad yields identical values
        # and gradients, and also skips recording an autograd graph that would
        # be thrown away anyway.
        with torch.no_grad():
            x2=self.target(x)
        # keepdim=True is equivalent to .mean(-1).unsqueeze(-1).
        sla_value=self.scale*(x1-x2).pow(2).mean(-1,keepdim=True)

        return sla_value

