import torch
import torch.nn as nn
from torch.distributions import Categorical
import os
import numpy as np
import sys

sys.path.append("../")
from bpp_env import bpp_env


class AddAndNormalization(nn.Module):
    """Residual connection followed by instance normalisation.

    The two inputs are summed, then each embedding channel is normalised
    across the sequence axis using ``nn.InstanceNorm1d`` with a learnable
    affine transform and no running statistics.
    """

    def __init__(self, **model_params):
        super().__init__()
        channels = model_params['embedding_dim']
        # Alternative kept from the original author: nn.BatchNorm1d(channels).
        self.norm = nn.InstanceNorm1d(channels, affine=True, track_running_stats=False)

    def forward(self, input1, input2):
        """Return ``norm(input1 + input2)``.

        :param input1: (batch, n, embedding)
        :param input2: tensor broadcastable to ``input1``
        :return: (batch, n, embedding)
        """
        summed = input1 + input2
        # InstanceNorm1d expects (batch, channels, length): move the embedding
        # axis to dim 1, normalise over n, then restore the original layout.
        return self.norm(summed.transpose(1, 2)).transpose(1, 2)


class EncoderLayer(nn.Module):
    """One post-norm Transformer-style encoder block.

    Self-attention -> add & instance-norm -> position-wise feed-forward ->
    add & instance-norm.
    """

    def __init__(self, **model_params):
        super().__init__()

        d_model = model_params['embedding_dim']
        heads = model_params['head_num']
        hidden = model_params['ff_dim']

        self.mha = nn.MultiheadAttention(d_model, heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d_model),
        )
        self.add_n_normalization_1 = AddAndNormalization(**model_params)
        self.add_n_normalization_2 = AddAndNormalization(**model_params)

    def forward(self, x, key_padding_mask):
        """Apply the block.

        :param x: (batch, n_orders, embed_dim)
        :param key_padding_mask: forwarded to ``nn.MultiheadAttention``;
            True entries are excluded as attention keys.
        :return: (batch, n_orders, embed_dim)
        """
        # Only the attended values are needed; the averaged weights are discarded.
        attended = self.mha(x, x, x, key_padding_mask=key_padding_mask)[0]
        hidden_state = self.add_n_normalization_1(x, attended)  # residual 1
        return self.add_n_normalization_2(hidden_state, self.feed_forward(hidden_state))  # residual 2


class Encoder(nn.Module):
    """Embed raw order features and pass them through a stack of EncoderLayers.

    Input rows that are entirely zero are treated as padding and excluded as
    attention keys inside every layer.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.input_dim = model_params['input_dim']
        embed_dim = model_params['embedding_dim']
        n_layers = model_params['encoder_layer_num']

        # No bias, so an all-zero (padding) row embeds to the zero vector.
        self.input_embedding = nn.Linear(self.input_dim, embed_dim, bias=False)
        self.layers = nn.ModuleList(
            [EncoderLayer(**model_params) for _ in range(n_layers)]
        )

    def forward(self, x):
        """Encode a batch of orders.

        :param x: (batch, n_orders, input_dim) order features
        :return: (batch, n_orders, embed_dim) encodings
        """
        # True where every feature of an order is zero -> padding; (batch, n_orders).
        padding = (x == 0).all(dim=2)
        hidden = self.input_embedding(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, padding)
        return hidden


class Decoder(nn.Module):
    """Pointer-style decoder producing a probability distribution over orders.

    A query attends over the encoded orders cached by ``set_kv``. The attended
    vectors are dotted with every encoded order to form logits, which are
    tanh-clipped, max-pooled over the query positions, masked and softmaxed.
    """

    def __init__(self, **model_params):
        super(Decoder, self).__init__()
        self.model_params = model_params
        input_dim = model_params['input_dim']
        embed_dim = model_params['embedding_dim']
        n_heads = model_params['head_num']

        self.query_embedding = nn.Linear(input_dim, embed_dim, bias=False)
        self.mha = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)

        # Learnable start query of shape (1, embed_dim), initialised U(-1, 1).
        # Not used by this class's forward(); exposed for callers.
        self.decoder_start = nn.Parameter(torch.zeros(1, embed_dim))
        self.decoder_start.data.uniform_(-1., 1.)
        # Filled by set_kv().
        self.k = None
        self.v = None
        self.single_head_key = None

    def set_kv(self, encoded_orders):
        """Cache the encoder output as attention keys/values.

        :param encoded_orders: (bs, n_orders, embed_dim)
        """
        self.k = encoded_orders
        self.v = encoded_orders
        self.single_head_key = encoded_orders.transpose(1, 2)  # (bs, embed, n)

    def forward(self, query, mask):
        """Return selection probabilities over orders.

        :param query: (bs, n_query, input_dim) query features
        :param mask: (bs, n_orders); a nonzero/True entry marks an order that
            must not be selected. Bool, integer or 0/1 float tensors are accepted.
        :return: (bs, n_orders) probabilities; masked entries are exactly 0.
            A row with every order masked yields NaN.
        """
        # nn.MultiheadAttention rejects integer key_padding_mask, and treats a
        # float mask as an additive bias rather than a mask. Normalise to bool
        # so both uses below agree on which orders are excluded.
        mask = mask.bool()

        q = self.query_embedding(query)  # (bs, n_query, embed)
        attn_out, _ = self.mha(q, self.k, self.v, key_padding_mask=mask)  # (bs, n_query, embed)
        score = torch.matmul(attn_out, self.single_head_key)  # (bs, n_query, n)

        logit_clipping = self.model_params['logit_clipping']
        score_clipped = logit_clipping * torch.tanh(score)  # (bs, n_query, n)
        # For each order keep its best score over all query positions.
        score_clipped, _ = score_clipped.max(dim=1)  # (bs, n)

        score_masked = score_clipped.masked_fill(mask, float('-inf'))
        return torch.softmax(score_masked, dim=1)  # (bs, n)


class Policy(nn.Module):
    """Encoder/decoder policy that picks the next order at each decoding step."""

    def __init__(self, **model_params):
        super(Policy, self).__init__()
        self.model_params = model_params

        self.encoder = Encoder(**model_params)
        self.decoder = Decoder(**model_params)
        # Batch-expanded copy of decoder.decoder_start; set by pre_forward().
        # Not read by forward() in this class.
        self.decoder_start = None

    def pre_forward(self, weights, classes, capacity):
        """Encode the orders once and cache keys/values in the decoder.

        Each order's feature vector is ``[weight / capacity, *classes]``, so
        ``model_params['input_dim']`` must equal ``1 + classes.size(2)``.

        :param weights: item weights, presumably (bs, n_orders) — TODO confirm.
        :param classes: per-order features concatenated along dim 2,
            presumably (bs, n_orders, input_dim - 1).
        :param capacity: bin capacity used to normalise ``weights``
            (scalar or tensor broadcastable against ``weights``).
        """
        norm_weights = (weights / capacity).unsqueeze(2)
        encoded_orders = self.encoder(torch.cat((norm_weights, classes), dim=2))
        self.decoder.set_kv(encoded_orders)

        batch_size = encoded_orders.size(0)
        embed_dim = encoded_orders.size(2)
        # (1, embed) -> (bs, 1, embed) view; no copy.
        self.decoder_start = self.decoder.decoder_start[None, :, :].expand(batch_size, 1, embed_dim)

    def forward(self, state, guided_action=None):
        """Sample an action, or score a given one, for the current step.

        :param state: ``(query, mask)`` pair forwarded to the decoder.
        :param guided_action: optional actions to evaluate instead of sampling.
        :return: ``(selected, log_probs)`` — the chosen order indices and
            their log-probabilities under the current distribution.
        """
        query, mask = state
        probs = self.decoder(query, mask)  # (bs, n)

        item_sampler = Categorical(probs)
        # `is None`, not `== None`: `==` on an array/tensor is elementwise.
        selected = item_sampler.sample() if guided_action is None else guided_action
        log_probs = item_sampler.log_prob(selected)

        return selected, log_probs


# Demo / smoke test: encode one BOM matrix with the Encoder, then roll out one
# episode with Policy in the environment and print the outcome.
# NOTE(review): this script appears out of date with the classes above and will
# not run as written — see the NOTE(review) comments below.
if __name__ == '__main__':

    # Load the first file (by sorted name) from the data directory.
    root_dir = '../data/september'
    files = os.listdir(root_dir)
    files.sort()
    bom_matrix = np.load(os.path.join(root_dir, files[0]))

    model_params = {
        'input_dim': bom_matrix.shape[1],
        'embedding_dim': 256,
        'sqrt_embedding_dim': 256 ** (1 / 2),  # NOTE(review): not read by the classes above
        'encoder_layer_num': 5,
        'head_num': 4,
        'logit_clipping': 10,
        'ff_dim': 512,

        'eval_type': 'argmax',  # NOTE(review): not read by the classes above
        'one_hot_seed_cnt': 20,  # must be >= node_cnt
    }

    order_n = len(bom_matrix)
    batch_size = 128
    # permutaion = np.arange(order_n)
    # Replicate the single instance batch_size times along a new leading axis.
    per_batch = np.tile(bom_matrix, (batch_size, 1, 1))
    # per_batch = np.random.permuted(per_batch, axis=1)
    input_orders_vec = torch.FloatTensor(per_batch)
    #
    print('输入向量的维度', input_orders_vec.size())  # label: "input tensor shape"
    input_dim = bom_matrix.shape[1]
    embed_dim = 256
    encoder = Encoder(**model_params)
    enc_output = encoder(input_orders_vec)
    print('编码器输出向量的维度', enc_output.size())  # label: "encoder output shape"

    # NOTE(review): `Env` is not defined or imported in this file (only
    # `bpp_env` is imported from bpp_env), so this raises NameError as written.
    env = Env(bom_matrix)
    query, mask = env.reset()
    # Add batch and query-position axes.
    query = torch.FloatTensor(query)[None, None, :]
    # NOTE(review): recent PyTorch nn.MultiheadAttention only accepts bool or
    # float key_padding_mask; an integer mask here (and in the step update
    # below) is likely to be rejected. A bool mask is probably intended.
    mask = torch.LongTensor(mask)[None, :]
    state = [query, mask]

    policy = Policy(**model_params)
    bom_matrix_ts = torch.FloatTensor(bom_matrix)[None, :, :]
    # NOTE(review): Policy.pre_forward takes (weights, classes, capacity); this
    # single-argument call raises TypeError. pre_forward also feeds the encoder
    # 1 + classes.size(2) features, which must equal model_params['input_dim'].
    policy.pre_forward(bom_matrix_ts)
    actions = []
    done = False
    # Roll out one episode, sampling one action per step until env reports done.
    while not done:
        a, log_p = policy(state)
        # print(log_p.size())
        a = int(a.squeeze(0).long())
        actions.append(int(a))
        next_state, r, done = env.step(a)
        state = [torch.FloatTensor(next_state[0])[None, None, :], torch.LongTensor(next_state[1])[None, :]]

    R = env.rewards[-1]
    print(f'开始动作：{actions[0]}, 装箱个数：{-R}')  # labels: "first action", "packing count" (negated final reward)
    print(f"动作数：{len(set(actions))}")  # label: "number of actions" (distinct actions taken)
    # NOTE(review): `normalize_return` is not defined or imported in this file.
    discounted_returns = normalize_return(env.rewards)
