import torch
import torch.nn as nn
import torch.nn.functional as F

from Encoding.MatNet_LIB import AddAndInstanceNormalization, FeedForward, MixedScore_MultiHeadAttention

class MatNetEncoder(nn.Module):
    """Encoder that builds initial row/column embeddings and refines them.

    The time-window start and end values are each embedded with a
    ``Linear(1, embedding_dim)`` layer and summed. Depending on
    ``model_params['tw_row_emb']`` that sum seeds either the row or the
    column embedding; the other side is seeded with random one-hot vectors
    (``tw_row_emb`` truthy) or zeros (``tw_row_emb`` falsy). The pair is
    then passed through ``encoder_layer_num`` ``EncoderLayer`` modules.

    Keys read from ``model_params`` by this class: ``encoder_layer_num``,
    ``embedding_dim``, ``tw_row_emb`` and, when ``tw_row_emb`` is truthy,
    ``one_hot_seed_cnt``. The full dict is forwarded to every layer.
    """

    def __init__(self, **model_params):
        super().__init__()
        encoder_layer_num = model_params['encoder_layer_num']
        self.model_params = model_params
        self.layers = nn.ModuleList([EncoderLayer(**model_params) for _ in range(encoder_layer_num)])

        self.tw_start_embedder = nn.Linear(1, model_params['embedding_dim'])
        self.tw_end_embedder = nn.Linear(1, model_params['embedding_dim'])

    def forward(self, dists, tw_start, tw_end):
        """Encode a batch of instances.

        Args:
            dists: 4-D cost tensor; only its first two sizes
                (batch, n_nodes) are read here, the tensor itself is handed
                to the layers. Commented in this file as
                (batch, row_cnt, col_cnt, Nobj).
            tw_start: time-window starts; a trailing axis of size 1 is
                appended before embedding (assumed (batch, n_nodes) --
                TODO confirm against caller).
            tw_end: time-window ends, same layout as ``tw_start``.

        Returns:
            Tuple ``(row_emb, col_emb)`` as produced by the last layer (or
            the initial embeddings when there are no layers).

        Raises:
            ValueError: if ``tw_row_emb`` is set and ``one_hot_seed_cnt`` is
                not within ``[n_nodes, embedding_dim]``.
        """
        batch_size, n_nodes, _, _ = dists.shape
        embedding_dim = self.model_params['embedding_dim']

        tw_start_emb = self.tw_start_embedder(tw_start.unsqueeze(-1))
        tw_end_emb = self.tw_end_embedder(tw_end.unsqueeze(-1))
        tw_emb = tw_start_emb + tw_end_emb

        if self.model_params['tw_row_emb']:
            seed_cnt = self.model_params['one_hot_seed_cnt']
            # Each node needs its own one-hot position inside the embedding.
            # Without this check a seed count of 1 would silently broadcast
            # the same code to every node, and other bad values would fail
            # with an opaque indexing error.
            if not n_nodes <= seed_cnt <= embedding_dim:
                raise ValueError(
                    "one_hot_seed_cnt must satisfy n_nodes <= one_hot_seed_cnt <= embedding_dim, "
                    f"got one_hot_seed_cnt={seed_cnt}, n_nodes={n_nodes}, embedding_dim={embedding_dim}"
                )

            # Helper tensors follow the embedding's device so the module
            # works wherever the model and its inputs live.
            device = tw_emb.device
            rand = torch.rand(batch_size, seed_cnt, device=device)
            # argsort of i.i.d. uniforms is a random permutation per batch
            # row; the first n_nodes entries are distinct indices.
            batch_rand_perm = rand.argsort(dim=1)
            rand_idx = batch_rand_perm[:, :n_nodes]

            b_idx = torch.arange(batch_size, device=device)[:, None].expand(batch_size, n_nodes)
            n_idx = torch.arange(n_nodes, device=device)[None, :].expand(batch_size, n_nodes)
            # new_zeros inherits dtype and device from the learned embedding.
            col_emb = tw_emb.new_zeros((batch_size, n_nodes, embedding_dim))
            col_emb[b_idx, n_idx, rand_idx] = 1
            # col_emb.shape: (batch, col_cnt, embedding)
            # row_emb.shape: (batch, row_cnt, embedding)
            # cost_mat.shape: (batch, row_cnt, col_cnt, Nobj)
            row_emb = tw_emb
        else:
            row_emb = tw_emb.new_zeros((batch_size, n_nodes, embedding_dim))
            col_emb = tw_emb

        for layer in self.layers:
            row_emb, col_emb = layer(row_emb, col_emb, dists)

        return row_emb, col_emb


class EncoderLayer(nn.Module):
    """One encoder layer: a row-side and a column-side ``EncodingBlock``.

    Both blocks receive the same ``model_params`` and are applied to the
    layer's *input* embeddings, so neither sees the other's output.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.row_encoding_block = EncodingBlock(**model_params)
        self.col_encoding_block = EncodingBlock(**model_params)

    def forward(self, row_emb, col_emb, cost_mat):
        """Update both embeddings and return ``(row_emb_out, col_emb_out)``.

        Shapes, per the comments in this file:
            row_emb:  (batch, row_cnt, embedding)
            col_emb:  (batch, col_cnt, embedding)
            cost_mat: (batch, row_cnt, col_cnt, Nobj)
        """
        # The column block plays the "row" role, so it gets the cost
        # matrix with its row/column axes swapped.
        cost_mat_t = cost_mat.transpose(1, 2)
        return (
            self.row_encoding_block(row_emb, col_emb, cost_mat),
            self.col_encoding_block(col_emb, row_emb, cost_mat_t),
        )


class EncodingBlock(nn.Module):
    """Attention + feed-forward block that updates one side's embeddings.

    Queries are projected from the first input and keys/values from the
    second; ``MixedScore_MultiHeadAttention`` additionally receives the cost
    matrix. The attention output is combined back to ``embedding_dim`` and
    followed by two ``AddAndInstanceNormalization`` stages around a
    ``FeedForward`` module (all three come from ``Encoding.MatNet_LIB``).

    Keys read from ``model_params`` here: ``embedding_dim``, ``head_num``,
    ``qkv_dim``. The full dict is also forwarded to the sub-modules.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        emb_dim = model_params['embedding_dim']
        n_heads = model_params['head_num']
        per_head_dim = model_params['qkv_dim']
        # Width of the concatenated per-head projections.
        proj_dim = n_heads * per_head_dim

        self.Wq = nn.Linear(emb_dim, proj_dim, bias=False)
        self.Wk = nn.Linear(emb_dim, proj_dim, bias=False)
        self.Wv = nn.Linear(emb_dim, proj_dim, bias=False)
        self.mixed_score_MHA = MixedScore_MultiHeadAttention(**model_params)
        self.multi_head_combine = nn.Linear(proj_dim, emb_dim)

        self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params)
        self.feed_forward = FeedForward(**model_params)
        self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params)

    def forward(self, row_emb, col_emb, cost_mat):
        """Return the updated embedding for the ``row_emb`` side.

        Rows and columns may be swapped by the caller as long as
        ``cost_mat.transpose(1, 2)`` is passed alongside.

        Shapes, per the comments in this file:
            row_emb:  (batch, row_cnt, embedding)
            col_emb:  (batch, col_cnt, embedding)
            cost_mat: (batch, row_cnt, col_cnt, Nobj)
            return:   (batch, row_cnt, embedding)
        """
        n_heads = self.model_params['head_num']

        # queries: (batch, head_num, row_cnt, qkv_dim)
        queries = reshape_by_heads(self.Wq(row_emb), head_num=n_heads)
        # keys / values: (batch, head_num, col_cnt, qkv_dim)
        keys = reshape_by_heads(self.Wk(col_emb), head_num=n_heads)
        values = reshape_by_heads(self.Wv(col_emb), head_num=n_heads)

        # Attention output is (batch, row_cnt, head_num*qkv_dim); the
        # combine layer maps it back to (batch, row_cnt, embedding).
        attended = self.multi_head_combine(
            self.mixed_score_MHA(queries, keys, values, cost_mat)
        )

        after_attention = self.add_n_normalization_1(row_emb, attended)
        return self.add_n_normalization_2(
            after_attention, self.feed_forward(after_attention)
        )

########################################
# NN SUB FUNCTIONS
########################################

def reshape_by_heads(qkv, head_num):
    """Split the last axis of a projected tensor into attention heads.

    Args:
        qkv: tensor of shape (batch, n, head_num*key_dim); per the original
            note, n can be either 1 or PROBLEM_SIZE.
        head_num: number of heads to split the last axis into.

    Returns:
        Tensor of shape (batch, head_num, n, key_dim).
    """
    batch_s, n = qkv.size(0), qkv.size(1)
    # (batch, n, head_num, key_dim) -> (batch, head_num, n, key_dim)
    return qkv.reshape(batch_s, n, head_num, -1).transpose(1, 2)
