from abc import ABC
from typing import List

import numpy as np
from torch import nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F

from Common import Config


class PolicyModel(nn.Module, ABC):
    """Actor-critic network: shared trunk, one policy head and two value heads.

    Layout (sizes come from ``config``):
        obs -> feature_extractor (CNN + MLP)
            -> extra_policy (optionally residual) -> policy logits
            -> extra_value  (optionally residual) -> int_value, ext_value

    ``config.policy_kwargs`` keys read here:
        head_hidden_size: width of the extra_policy / extra_value layers
            (defaults to ``config.extra_hidden_size``).
        first_layer_norm: forwarded to the CNN (default False).
        no_res: when truthy, the extra_* layers are applied without the
            residual connection (default False).
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.feature_extractor = self.init_feature_extractor()
        self.init_head()

    def init_head(self):
        """Create the extra_* layers and the policy / int_value / ext_value heads.

        Raises:
            ValueError: if the residual connection is enabled (``no_res`` not
                set) but ``head_hidden_size`` differs from
                ``extra_hidden_size``. ``x + relu(extra(x))`` needs both widths
                to match; if either width is 1 the addition would broadcast
                silently instead of failing.
        """
        trunk_size = self.config.extra_hidden_size
        head_hidden_size = self.config.policy_kwargs.get("head_hidden_size", trunk_size)
        use_residual = not self.config.policy_kwargs.get("no_res", False)
        if use_residual and head_hidden_size != trunk_size:
            raise ValueError(
                "head_hidden_size ({}) must equal extra_hidden_size ({}) unless "
                "policy_kwargs['no_res'] is set".format(head_hidden_size, trunk_size)
            )

        self.extra_value = make_fc_layer(trunk_size, head_hidden_size, True, 0.1, self.config.model_norm_type)
        self.extra_policy = make_fc_layer(trunk_size, head_hidden_size, True, 0.1, self.config.model_norm_type)
        # Gain 0.01 with zero bias keeps the initial logits / values close to zero.
        self.policy = make_fc_layer(head_hidden_size, self.config.n_actions, True, 0.01)
        self.int_value = make_fc_layer(head_hidden_size, 1, True, 0.01)
        self.ext_value = make_fc_layer(head_hidden_size, 1, True, 0.01)

    def init_feature_extractor(self):
        """Build the shared trunk: a CNN followed by an MLP ending in an activation."""
        return nn.Sequential(
            CNN(self.config.feature_net_arch['cnn'], "p_cnn", norm_type=self.config.model_norm_type,
                origin_size=self.config.state_shape,
                first_layer_norm=self.config.policy_kwargs.get("first_layer_norm", False)),
            MLP(self.config.feature_net_arch['mlp'], "p_mlp", non_linearity_last=True,
                norm_type=self.config.model_norm_type),
        )

    def forward(self, input):
        """Run the network on a batch of observations.

        Args:
            input: observation tensor accepted by ``feature_extractor``;
                presumably (batch, C, H, W) -- TODO confirm against callers.

        Returns:
            ``(dist, int_value, ext_value, probs)``: a ``Categorical`` over
            actions, the two value-head outputs, and the action probabilities
            (softmax over dim 1).
        """
        # `* 1.0` promotes integer observations to float.
        # NOTE(review): colour inputs are scaled by 1/256 rather than 1/255; kept
        # as-is so trained models keep seeing the same input scale.
        x = input * 1.0 if not self.config.obs_is_color else input / 256.
        x = self.feature_extractor(x)
        if self.config.policy_kwargs.get("no_res", False):
            x_p = F.relu(self.extra_policy(x))
            x_v = F.relu(self.extra_value(x))
        else:
            # Residual connection around each extra_* layer.
            x_p = x + F.relu(self.extra_policy(x))
            x_v = x + F.relu(self.extra_value(x))
        policy = self.policy(x_p)
        int_value = self.int_value(x_v)
        ext_value = self.ext_value(x_v)
        probs = F.softmax(policy, dim=1)
        dist = Categorical(probs)

        return dist, int_value, ext_value, probs

class PolicyModelRNN(PolicyModel):
    """``PolicyModel`` with a ``GRUCell`` between the shared trunk and the heads.

    ``forward`` takes ``(observations, hidden_state)`` and, in addition to the
    base outputs, returns the GRU cell's new hidden state.
    """

    def __init__(self, config: Config):
        super().__init__(config)
        width = self.config.extra_hidden_size
        # Input and hidden sizes both equal the trunk width, so the GRU output
        # can feed the same heads (and residual adds) built by the base class.
        self.gru = nn.GRUCell(width, width)

    def forward(self, input):
        """Run one recurrent step.

        Args:
            input: a pair ``(observations, hidden_state)``; ``hidden_state`` is
                passed straight to ``nn.GRUCell``.

        Returns:
            ``(dist, int_value, ext_value, probs, h)``, where ``h`` is the
            ``GRUCell`` output (the updated hidden state).
        """
        obs, hidden_state = input
        if self.config.obs_is_color:
            obs = obs / 256.
        else:
            obs = obs * 1.0
        h = self.gru(self.feature_extractor(obs), hidden_state)

        policy_branch = F.relu(self.extra_policy(h))
        value_branch = F.relu(self.extra_value(h))
        if not self.config.policy_kwargs.get("no_res", False):
            # Residual connection around each extra_* layer.
            policy_branch = h + policy_branch
            value_branch = h + value_branch

        logits = self.policy(policy_branch)
        int_value = self.int_value(value_branch)
        ext_value = self.ext_value(value_branch)
        probs = F.softmax(logits, dim=1)
        return Categorical(probs), int_value, ext_value, probs, h


class RNDModel(nn.Module, ABC):
    """Convolutional trunk followed by a single linear projection.

    Presumably used as the target / predictor network for RND-style intrinsic
    rewards (inferred from the name; confirm against callers).
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        conv_trunk = CNN(config.obs_net_arch['cnn'], "rnd_cnn")
        self.feature_extractor = nn.Sequential(conv_trunk)
        # obs_net_arch["mlp"][0] is presumably the flattened CNN output width.
        self.encoder = make_fc_layer(
            config.obs_net_arch["mlp"][0],
            config.rnd_hidden_size,
            True,
        )

    def forward(self, input):
        """Encode a batch of observations; `* 1.0` promotes integer input to float."""
        features = self.feature_extractor(input * 1.0)
        return self.encoder(features)

def make_fc_layer(in_features: int, out_features: int, use_bias=True, gain=np.sqrt(2), norm_type=0):
    """Build an orthogonally initialised ``nn.Linear``, optionally followed by a norm layer.

    Args:
        in_features: input width.
        out_features: output width.
        use_bias: whether the Linear layer has a bias (initialised to zero).
        gain: gain passed to ``nn.init.orthogonal_`` for the weight.
        norm_type: 0 = Linear only, 1 = Linear + ``BatchNorm1d``,
            2 = Linear + ``LayerNorm``.

    Returns:
        The ``nn.Linear`` when ``norm_type == 0``, otherwise
        ``nn.Sequential(linear, norm)``.

    Raises:
        ValueError: for any other ``norm_type``. Without this check the
            function would return ``None``, which only fails later when the
            layer is called.
    """
    if norm_type not in (0, 1, 2):
        raise ValueError("unsupported norm_type: {}".format(norm_type))

    fc_layer = nn.Linear(in_features, out_features, bias=use_bias)
    nn.init.orthogonal_(fc_layer.weight, gain=gain)
    if use_bias:
        nn.init.zeros_(fc_layer.bias)
    if norm_type == 0:
        return fc_layer
    if norm_type == 1:
        return nn.Sequential(
            fc_layer,
            nn.BatchNorm1d(out_features, momentum=0.1),
        )
    return nn.Sequential(
        fc_layer,
        nn.LayerNorm(out_features),
    )

def make_conv_layer(in_channel, out_channel, kernel_size, stride=1, padding=0, gain=np.sqrt(2)):
    """Create an ``nn.Conv2d`` with orthogonally initialised weights and a zero bias.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero padding added on each side.
        gain: gain passed to ``nn.init.orthogonal_`` for the weight.
    """
    layer = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride, padding=padding)
    nn.init.orthogonal_(layer.weight, gain=gain)
    nn.init.zeros_(layer.bias)
    return layer


def conv_shape(input, kernel_size, stride, padding=0):
    """Return the output size of a convolution along one or more spatial dims.

    Each size is ``(size + 2 * padding - kernel_size) // stride + 1``, which is
    the ``nn.Conv2d`` output formula with dilation 1.

    Args:
        input: a single size, or a tuple/list of sizes (one per dimension).
        kernel_size: kernel extent, applied to every dimension.
        stride: stride, applied to every dimension.
        padding: zero padding on each side, applied to every dimension.

    Returns:
        A single size for scalar input, otherwise a tuple with one entry per
        dimension. Lists are accepted too, so a list-typed input shape no
        longer crashes with ``list + int``.
    """
    def _out(size):
        return (size + 2 * padding - kernel_size) // stride + 1

    if isinstance(input, (tuple, list)):
        return tuple(_out(size) for size in input)
    return _out(input)

class CNN(nn.Module):
    """Stack of ``Conv2d`` layers with optional normalisation, followed by ``Flatten``.

    Args:
        cnn_feat_dim_list: one spec per conv layer, unpacked into
            ``make_conv_layer``:
            ``(in_channel, out_channel, kernel_size[, stride[, padding[, gain]]])``.
        name: prefix for the registered submodule names.
        non_linearity: activation class, instantiated after each conv layer.
        non_linearity_last: if False, no activation follows the last conv layer.
        norm_type: 1 adds ``BatchNorm2d`` after each conv; 2 adds ``LayerNorm``
            over the (C, H, W) of each conv output; any other value adds none.
        origin_size: input shape as (C, H, W); required when ``norm_type == 2``
            so each layer's spatial size can be computed.
        first_layer_norm: with ``norm_type == 2``, prepend a ``LayerNorm`` over
            the raw input.

    Raises:
        ValueError: ``norm_type == 2`` and ``origin_size`` is None.
    """

    def __init__(
            self,
            cnn_feat_dim_list,
            name,
            non_linearity=nn.ReLU,
            non_linearity_last=True,
            norm_type=0,
            origin_size=None,
            first_layer_norm=False
    ):
        super().__init__()
        cnn_inter_sizes = []
        if norm_type == 2:
            # LayerNorm needs the exact (C, H, W) of every conv output, so the
            # spatial size is walked through each layer up front.
            if origin_size is None:
                raise ValueError("origin_size is required when norm_type == 2")
            # tuple() so conv_shape treats the (H, W) slice per dimension even
            # when origin_size is a list.
            conv_size = tuple(origin_size[1:])
            for arch in cnn_feat_dim_list:
                # Short specs fall back to make_conv_layer's defaults
                # (stride=1, padding=0) instead of raising IndexError.
                stride = arch[3] if len(arch) > 3 else 1
                padding = arch[4] if len(arch) > 4 else 0
                conv_size = conv_shape(conv_size, arch[2], stride, padding)
                cnn_inter_sizes.append(conv_size)

        self.cnn = nn.Sequential()
        if norm_type == 2 and first_layer_norm:
            self.cnn.add_module("fl_norm",
                                nn.LayerNorm(origin_size))

        last_index = len(cnn_feat_dim_list) - 1
        for i, arch in enumerate(cnn_feat_dim_list):
            self.cnn.add_module("{0}_cnn{1}".format(name, i), make_conv_layer(*arch))
            if norm_type == 1:
                self.cnn.add_module("{0}_norm{1}".format(name, i), nn.BatchNorm2d(arch[1], momentum=0.1))
            elif norm_type == 2:
                self.cnn.add_module("{0}_norm{1}".format(name, i),
                                    nn.LayerNorm([arch[1], cnn_inter_sizes[i][0], cnn_inter_sizes[i][1]]))
            if i < last_index or non_linearity_last:
                self.cnn.add_module("{0}_nonl{1}".format(name, i), non_linearity())
        self.cnn.add_module("{0}_flatten".format(name), nn.Flatten())

    def forward(self, x):
        """Apply the conv stack and return the flattened features."""
        return self.cnn(x)

class MLP(nn.Module):
    """Stack of fully connected layers built with ``make_fc_layer``.

    ``fc_feat_dim_list`` lists the widths ``[in, hidden..., out]``; one Linear
    layer (plus optional norm) is created for each consecutive pair. An
    activation follows every layer except the last, which gets one only when
    ``non_linearity_last`` is True.

    Args:
        fc_feat_dim_list: layer widths.
        name: prefix for the registered submodule names.
        non_linearity: activation class used between layers.
        non_linearity_last: whether to add an activation after the last layer.
        gain: orthogonal-init gain forwarded to ``make_fc_layer``.
        non_linearity_module_last: activation class for the last layer;
            defaults to ``non_linearity``.
        norm_type: forwarded to ``make_fc_layer``.
    """

    def __init__(
            self,
            fc_feat_dim_list: List[int],
            name: str,
            non_linearity: nn.Module = nn.ReLU,
            non_linearity_last: bool = False,
            gain=np.sqrt(2),
            non_linearity_module_last=None,
            norm_type=0,
    ):
        super().__init__()
        final_activation = non_linearity if non_linearity_module_last is None else non_linearity_module_last
        self.fc_layers = nn.Sequential()
        layer_count = len(fc_feat_dim_list) - 1
        width_pairs = zip(fc_feat_dim_list[:-1], fc_feat_dim_list[1:])
        for idx, (width_in, width_out) in enumerate(width_pairs, start=1):
            self.fc_layers.add_module(
                "{0}_fc{1}".format(name, idx),
                make_fc_layer(width_in, width_out, gain=gain, norm_type=norm_type),
            )
            if idx != layer_count:
                self.fc_layers.add_module("{0}_non_linear{1}".format(name, idx), non_linearity())
            elif non_linearity_last:
                self.fc_layers.add_module("{0}_non_linear{1}".format(name, idx), final_activation())

    def forward(self, data):
        """Apply the layer stack to ``data``."""
        return self.fc_layers(data)


