import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class TSPModel(nn.Module):
    """POMO-style policy that decodes over node and tunnel encodings.

    The model owns a node encoder, a tunnel encoder and a ``TSP_Decoder``.
    ``pre_forward`` receives already-computed node/tunnel embeddings (the
    encoders are not invoked inside this class -- presumably the caller runs
    ``encoder_nodes`` / ``encoder_tunnels``), appends the decoder's learned
    regret embedding as one extra trailing entry, and primes the decoder's
    keys/values. ``forward`` is then called once per decoding step.
    """

    def __init__(self, device, model_params):
        super().__init__()
        self.model_params = model_params
        self.device = device

        # Node encoder over 16-dim node features.
        # Alternative: RevMHALowerEncoder_2Encoder_Nodes(..., input_dim=16).
        self.encoder_nodes = TSP_Encoder_1(model_params, device, 16)
        # Reversible tunnel encoder over 40-dim tunnel features.
        # Alternative: TSP_Encoder_2(model_params, device, 40).
        self.encoder_tunnels = RevMHALowerEncoder_2Encoder_Tunnels(
            device=device,
            n_layers=model_params['encoder_layer_num'],
            n_heads=model_params['head_num'],
            embedding_dim=model_params['embedding_dim'],
            input_dim=40,
            intermediate_dim=model_params['embedding_dim'] * 4,
        )
        self.decoder = TSP_Decoder(model_params, device)

        # Per-batch caches populated by pre_forward().
        self.encoded_nodes = None    # (batch, problem + 1, embedding)
        self.encoded_tunnels = None  # (batch, tunnel entries + 1, embedding)
        self.batch_tunnel = None     # table passed to find_corresponding_tunnel

    def pre_forward(self, embeddings_nodes, embeddings_tunnels, batch_tunnel):
        """Cache encodings for a new batch and prime the decoder's keys/values.

        Args:
            embeddings_nodes: (batch, problem, embedding) node embeddings.
            embeddings_tunnels: (batch, n, embedding) tunnel embeddings.
            batch_tunnel: tunnel membership table, later passed to
                ``find_corresponding_tunnel`` to map node ids to tunnel indices.
        """
        emb_dim = self.decoder.regret_embedding.size(-1)
        # One (broadcast) copy of the learned regret embedding per batch row,
        # appended as the last entry of both encodings.
        regret = self.decoder.regret_embedding[None, None, :].expand(
            embeddings_nodes.size(0), 1, emb_dim)
        self.encoded_nodes = torch.cat((embeddings_nodes, regret), dim=1)
        self.encoded_tunnels = torch.cat((embeddings_tunnels, regret), dim=1)
        self.batch_tunnel = batch_tunnel
        self.decoder.set_kv(self.encoded_nodes, self.encoded_tunnels)

    def forward(self, state):
        """Choose the next index for every (batch, pomo) rollout.

        On the first call (``state.current_node is None``) rollout ``j`` is
        assigned node ``j`` with probability 1 and the decoder's first-node
        query is cached. Afterwards the decoder scores all candidates: in
        training or when ``eval_type == 'softmax'`` an action is sampled,
        otherwise the argmax over the candidates excluding the trailing
        regret entry is taken and ``prob`` is None.

        Returns:
            (selected, prob): (batch, pomo) chosen indices and their
            probabilities (None in greedy mode).
        """
        batch_size = state.BATCH_IDX.size(0)
        pomo_size = state.BATCH_IDX.size(1)

        if state.current_node is None:
            selected = torch.arange(pomo_size)[None, :].expand(batch_size, pomo_size)
            prob = torch.ones(size=(batch_size, pomo_size))
            # Use the device copy for BOTH gathers so node and tunnel lookups
            # index tensors living on the same device as the cached encodings.
            first_node = selected.to(self.device)
            first_tunnel = find_corresponding_tunnel(self.batch_tunnel, first_node)
            encoded_first_node = _get_encoding(self.encoded_nodes, first_node)
            encoded_first_tunnel = _get_encoding(self.encoded_tunnels, first_tunnel)
            # shape: (batch, pomo, embedding)
            self.decoder.set_q1(encoded_first_node, encoded_first_tunnel)
            return selected, prob

        current_node = state.current_node.to(self.device)
        current_tunnel = find_corresponding_tunnel(self.batch_tunnel, current_node)
        encoded_last_node = _get_encoding(self.encoded_nodes, current_node)
        encoded_last_tunnel = _get_encoding(self.encoded_tunnels, current_tunnel)
        # shape: (batch, pomo, embedding)
        probs = self.decoder(encoded_last_node, encoded_last_tunnel,
                             ninf_mask=state.ninf_mask)
        # shape: (batch, pomo, problem + 1)

        if self.training or self.model_params['eval_type'] == 'softmax':
            # Re-sample until every rollout drew an action with non-zero
            # probability (guards against multinomial returning a masked index).
            while True:
                selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1) \
                    .squeeze(dim=1).reshape(batch_size, pomo_size)
                prob = probs[state.BATCH_IDX, state.POMO_IDX, selected] \
                    .reshape(batch_size, pomo_size)
                if (prob != 0).all():
                    break
        else:
            # Greedy decoding never picks the trailing regret entry.
            selected = probs[:, :, :-1].argmax(dim=2)
            prob = None

        return selected, prob


def _get_encoding(encoded_nodes, node_index_to_pick):
    """Gather one embedding row per (batch, pomo) index.

    Args:
        encoded_nodes: (batch, problem, embedding) tensor of encodings.
        node_index_to_pick: (batch, pomo) integer indices into the problem axis.

    Returns:
        (batch, pomo, embedding) tensor where entry [b, p] is
        ``encoded_nodes[b, node_index_to_pick[b, p]]``.
    """
    emb_dim = encoded_nodes.size(2)
    rows, cols = node_index_to_pick.size(0), node_index_to_pick.size(1)
    # Broadcast each index across the embedding axis so gather copies whole rows.
    index = node_index_to_pick.unsqueeze(-1).expand(rows, cols, emb_dim)
    return torch.gather(encoded_nodes, 1, index)


########################################
# ENCODER
########################################

class TSP_Encoder_1(nn.Module):
    """Linear input embedding followed by a stack of ``EncoderLayer`` modules.

    Args:
        model_params: dict with at least 'embedding_dim' and
            'encoder_layer_num'; the whole dict is forwarded to every
            ``EncoderLayer``.
        device: stored on the instance; not used by this module's computation.
        input_dimm: feature size of each input element.
    """

    def __init__(self, model_params, device, input_dimm):
        super().__init__()
        self.model_params = model_params
        self.device = device
        self.embedding = nn.Linear(input_dimm, model_params['embedding_dim'])
        self.layers = nn.ModuleList(
            EncoderLayer(**model_params)
            for _ in range(model_params['encoder_layer_num'])
        )

    def forward(self, data):
        """Encode (batch, problem, input_dimm) features into (batch, problem, embedding)."""
        hidden = self.embedding(data)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        return hidden

class TSP_Encoder_2(nn.Module):
    """Like ``TSP_Encoder_1``, then folds the sequence axis in half.

    After the encoder layers, the first and second halves of the sequence
    axis are summed element-wise, so an input of length 2*M yields M
    embeddings.

    Args:
        model_params: dict with at least 'embedding_dim' and
            'encoder_layer_num'; the whole dict is forwarded to every
            ``EncoderLayer``.
        device: stored on the instance; not used by this module's computation.
        input_dimm: feature size of each input element.
    """

    def __init__(self, model_params, device, input_dimm):
        super().__init__()
        self.model_params = model_params
        self.device = device
        embedding_dim = self.model_params['embedding_dim']
        encoder_layer_num = self.model_params['encoder_layer_num']

        self.embedding = nn.Linear(input_dimm, embedding_dim)
        self.layers = nn.ModuleList([EncoderLayer(**model_params) for _ in range(encoder_layer_num)])

    def forward(self, data):
        """Encode (batch, 2*M, input_dimm) features into (batch, M, embedding).

        Raises:
            ValueError: if the sequence length is odd (the halves would not
                line up; length 1 previously produced an empty tensor silently).
        """
        out = self.embedding(data)
        for layer in self.layers:
            out = layer(out)

        seq_len = out.size(1)
        if seq_len % 2:
            raise ValueError(
                f"TSP_Encoder_2 expects an even sequence length, got {seq_len}")
        half = seq_len // 2
        return out[:, :half, :] + out[:, half:, :]

class EncoderLayer(nn.Module):
    """One encoder layer: multi-head self-attention and a feed-forward
    sub-layer, each followed by a residual add + normalisation.

    Reads 'embedding_dim', 'head_num' and 'qkv_dim' from the keyword
    parameters; the full dict is forwarded to the sub-modules.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        emb = model_params['embedding_dim']
        proj_dim = model_params['head_num'] * model_params['qkv_dim']

        self.Wq = nn.Linear(emb, proj_dim, bias=False)
        self.Wk = nn.Linear(emb, proj_dim, bias=False)
        self.Wv = nn.Linear(emb, proj_dim, bias=False)
        self.multi_head_combine = nn.Linear(proj_dim, emb)

        self.addAndNormalization1 = Add_And_Normalization_Module(**model_params)
        self.feedForward = Feed_Forward_Module(**model_params)
        self.addAndNormalization2 = Add_And_Normalization_Module(**model_params)

    def forward(self, input1):
        """Map (batch, problem, embedding) to (batch, problem, embedding)."""
        heads = self.model_params['head_num']
        # Per-head projections, each (batch, head_num, problem, qkv_dim).
        q, k, v = (reshape_by_heads(proj(input1), head_num=heads)
                   for proj in (self.Wq, self.Wk, self.Wv))

        attended = self.multi_head_combine(multi_head_attention(q, k, v))
        hidden = self.addAndNormalization1(input1, attended)
        return self.addAndNormalization2(hidden, self.feedForward(hidden))

import revtorch as rv
from torch import Tensor
class RevMHALowerEncoder_2Encoder_Nodes(nn.Module):
    """Reversible transformer encoder (revtorch) with one output per input element.

    Each reversible block uses ``MHABlock`` as F and ``FFBlock`` as G. The
    (optionally projected) input is duplicated along the feature axis so the
    blocks, which split along the last axis, receive two equal-width streams.
    The second stream of the final output is returned.

    Args:
        device: device the projection and per-block sub-modules are moved to.
        n_layers: number of reversible blocks.
        n_heads: attention heads per ``MHABlock``.
        embedding_dim: model width.
        input_dim: feature size of each input element.
        intermediate_dim: hidden width of each ``FFBlock``.
        add_init_projection: if False and ``input_dim == embedding_dim``, the
            input goes to the blocks without a linear projection.
    """

    def __init__(
        self,
        device,
        n_layers: int,
        n_heads: int,
        embedding_dim: int,
        input_dim: int,
        intermediate_dim: int,
        add_init_projection=True,
    ):
        super().__init__()
        self.device = device
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        if add_init_projection or input_dim != embedding_dim:
            self.init_projection_layer = torch.nn.Linear(input_dim, embedding_dim).to(self.device)
        self.num_hidden_layers = n_layers
        blocks = []
        for _ in range(n_layers):
            f_func = MHABlock(embedding_dim, n_heads).to(device)
            g_func = FFBlock(embedding_dim, intermediate_dim).to(device)
            # A reversible block built from our F and G functions.
            blocks.append(rv.ReversibleBlock(f_func, g_func, split_along_dim=-1))

        self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks))

    def forward(self, x: Tensor, mask=None):
        """Encode (batch, seq, input_dim) into (batch, seq, embedding_dim).

        ``mask`` is accepted for interface compatibility and is ignored.
        """
        if hasattr(self, "init_projection_layer"):
            x = self.init_projection_layer(x)
        x = torch.cat([x, x], dim=-1)
        out = self.sequence(x)
        # Second half of the feature axis. chunk() returns a view, avoiding the
        # full copy of both halves that torch.stack(...)[-1] used to make.
        return out.chunk(2, dim=-1)[-1]

class RevMHALowerEncoder_2Encoder_Tunnels(nn.Module):
    """Reversible transformer encoder (revtorch) that folds the sequence in half.

    Same architecture as the node variant: ``MHABlock`` as F and ``FFBlock``
    as G in each reversible block, with the projected input duplicated along
    the feature axis to form two streams. The second stream of the output is
    then folded: the first and second halves of the sequence axis are summed,
    so an input of length 2*M yields M embeddings.

    Args:
        device: device the projection and per-block sub-modules are moved to.
        n_layers: number of reversible blocks.
        n_heads: attention heads per ``MHABlock``.
        embedding_dim: model width.
        input_dim: feature size of each input element.
        intermediate_dim: hidden width of each ``FFBlock``.
        add_init_projection: if False and ``input_dim == embedding_dim``, the
            input goes to the blocks without a linear projection.
    """

    def __init__(
        self,
        device,
        n_layers: int,
        n_heads: int,
        embedding_dim: int,
        input_dim: int,
        intermediate_dim: int,
        add_init_projection=True,
    ):
        super().__init__()
        self.device = device
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        if add_init_projection or input_dim != embedding_dim:
            self.init_projection_layer = torch.nn.Linear(input_dim, embedding_dim).to(self.device)
        self.num_hidden_layers = n_layers
        blocks = []
        for _ in range(n_layers):
            f_func = MHABlock(embedding_dim, n_heads).to(device)
            g_func = FFBlock(embedding_dim, intermediate_dim).to(device)
            # A reversible block built from our F and G functions.
            blocks.append(rv.ReversibleBlock(f_func, g_func, split_along_dim=-1))

        self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks))

    def forward(self, x: Tensor, mask=None):
        """Encode (batch, 2*M, input_dim) into (batch, M, embedding_dim).

        ``mask`` is accepted for interface compatibility and is ignored.

        Raises:
            ValueError: if the sequence length is odd (length 1 previously
                produced an empty tensor silently).
        """
        if hasattr(self, "init_projection_layer"):
            x = self.init_projection_layer(x)
        x = torch.cat([x, x], dim=-1)
        out = self.sequence(x)
        # Second half of the feature axis, taken as a view (no copy).
        stream = out.chunk(2, dim=-1)[-1]

        seq_len = stream.size(1)
        if seq_len % 2:
            raise ValueError(
                "RevMHALowerEncoder_2Encoder_Tunnels expects an even sequence "
                f"length, got {seq_len}")
        half = seq_len // 2
        return stream[:, :half, :] + stream[:, half:, :]
    
class MHABlock(nn.Module):
    """Pre-normalised self-attention sub-layer.

    Input and output are (batch, seq_len, hidden_size). BatchNorm1d is applied
    over the feature axis, then bias-free ``nn.MultiheadAttention`` attends
    each sequence to itself.
    """

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.mixing_layer_norm = nn.BatchNorm1d(hidden_size)
        self.mha = nn.MultiheadAttention(hidden_size, num_heads, bias=False)

    def forward(self, hidden_states: Tensor):
        assert hidden_states.dim() == 3
        # BatchNorm1d normalises dim 1, so move features there and back.
        normed = self.mixing_layer_norm(hidden_states.permute(0, 2, 1)).permute(0, 2, 1)
        # nn.MultiheadAttention (default layout) expects (seq, batch, feature).
        seq_first = normed.permute(1, 0, 2)
        attended, _ = self.mha(seq_first, seq_first, seq_first)
        return attended.permute(1, 0, 2)


class FFBlock(nn.Module):
    """Pre-normalised position-wise feed-forward sub-layer.

    Input and output are (batch, seq_len, hidden_size): BatchNorm1d over the
    feature axis, then Linear(hidden -> intermediate) -> GELU ->
    Linear(intermediate -> hidden).
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.feed_forward = nn.Linear(hidden_size, intermediate_size)
        self.output_dense = nn.Linear(intermediate_size, hidden_size)
        self.output_layer_norm = nn.BatchNorm1d(hidden_size)
        self.activation = nn.GELU()

    def forward(self, hidden_states: Tensor):
        # BatchNorm1d normalises dim 1, so move features there and back.
        normed = self.output_layer_norm(hidden_states.transpose(1, 2))
        normed = normed.transpose(1, 2).contiguous()
        expanded = self.activation(self.feed_forward(normed))
        return self.output_dense(expanded)

########################################
# DECODER
########################################

class TSP_Decoder(nn.Module):
    """Two-source attention decoder over node and tunnel encodings.

    Usage per batch: ``set_kv`` once with the node and tunnel encodings,
    ``set_q1`` once with the first node/tunnel encodings, then ``forward``
    once per decoding step to get a probability over candidates.
    """

    def __init__(self, model_params, device):
        super().__init__()
        self.model_params = model_params
        self.device = device
        self.embedding_dim = self.model_params['embedding_dim']
        self.qkv_dim = self.model_params['qkv_dim']
        self.n_heads = self.model_params['head_num']
        proj_dim = self.qkv_dim * self.n_heads

        # Key/value projections for node and tunnel encodings.
        self.Wkn = nn.Linear(self.embedding_dim, proj_dim, bias=False)
        self.Wvn = nn.Linear(self.embedding_dim, proj_dim, bias=False)
        self.Wkt = nn.Linear(self.embedding_dim, proj_dim, bias=False)
        self.Wvt = nn.Linear(self.embedding_dim, proj_dim, bias=False)
        # Query projections for the first and the most recent node/tunnel.
        self.Wqn_first = nn.Linear(self.embedding_dim, proj_dim, bias=False)
        self.Wqt_first = nn.Linear(self.embedding_dim, proj_dim, bias=False)
        self.Wqn_last = nn.Linear(self.embedding_dim, proj_dim, bias=False)
        self.Wqt_last = nn.Linear(self.embedding_dim, proj_dim, bias=False)

        # Learned "regret" embedding; not used inside this class (presumably
        # appended to the encodings by the caller before set_kv).
        self.regret_embedding = nn.Parameter(torch.Tensor(self.embedding_dim))
        self.regret_embedding.data.uniform_(-1, 1)

        self.multi_head_combine_1 = nn.Linear(self.n_heads * self.qkv_dim, self.embedding_dim)
        self.multi_head_combine_2 = nn.Linear(self.n_heads * self.qkv_dim, self.embedding_dim)

        # Legacy attribute names kept for compatibility; not read by this class.
        self.k = None
        self.v = None
        self.q_first = None

        # Caches populated by set_kv() / set_q1().
        self.knodes = None
        self.vnodes = None
        self.ktunnels = None
        self.vtunnels = None
        self.single_head_key_nodes = None
        self.single_head_key_tunnels = None
        self.single_head_key = None
        self.q_first_nodes = None
        self.q_first_tunnels = None

    def set_kv(self, kv_matrix_nodes, kv_matrix_tunnels):
        """Cache per-head keys/values and the single-head scoring key.

        Both inputs are (batch, candidates, embedding) and must have the same
        number of candidates (their transposes are summed).
        """
        self.knodes = reshape_by_heads(self.Wkn(kv_matrix_nodes), head_num=self.n_heads)
        self.vnodes = reshape_by_heads(self.Wvn(kv_matrix_nodes), head_num=self.n_heads)
        self.ktunnels = reshape_by_heads(self.Wkt(kv_matrix_tunnels), head_num=self.n_heads)
        self.vtunnels = reshape_by_heads(self.Wvt(kv_matrix_tunnels), head_num=self.n_heads)

        self.single_head_key_nodes = kv_matrix_nodes.transpose(1, 2)
        self.single_head_key_tunnels = kv_matrix_tunnels.transpose(1, 2)
        # shape: (batch, embedding, candidates)
        # Summed once here instead of on every decoding step in forward().
        self.single_head_key = self.single_head_key_nodes + self.single_head_key_tunnels

    def set_q1(self, q1_matrix_nodes, q1_matrix_tunnels):
        """Cache per-head queries for the first node and its tunnel (batch, pomo, embedding)."""
        self.q_first_nodes = reshape_by_heads(self.Wqn_first(q1_matrix_nodes), head_num=self.n_heads)
        self.q_first_tunnels = reshape_by_heads(self.Wqt_first(q1_matrix_tunnels), head_num=self.n_heads)

    def forward(self, encoded_last_node, encoded_last_tunnel, ninf_mask):
        """Score every cached candidate for each (batch, pomo) rollout.

        Args:
            encoded_last_node: (batch, pomo, embedding) current-node encodings.
            encoded_last_tunnel: (batch, pomo, embedding) current-tunnel encodings.
            ninf_mask: (batch, pomo, candidates) additive mask (0 / -inf).

        Returns:
            (batch, pomo, candidates) probabilities (softmax over the last axis).
        """
        mask = ninf_mask.to(self.device)

        #  Multi-Head Attention
        #######################################################
        qn_last = reshape_by_heads(self.Wqn_last(encoded_last_node), head_num=self.n_heads)
        qt_last = reshape_by_heads(self.Wqt_last(encoded_last_tunnel), head_num=self.n_heads)
        # shape: (batch, head_num, pomo, qkv_dim)

        qn = self.q_first_nodes + qn_last
        qt = self.q_first_tunnels + qt_last
        q = qn + qt
        # shape: (batch, head_num, pomo, qkv_dim)

        out_concat_1 = multi_head_attention(q, self.ktunnels, self.vtunnels, rank3_ninf_mask=mask)
        out_concat_2 = multi_head_attention(q, self.knodes, self.vnodes, rank3_ninf_mask=mask)
        # shape: (batch, pomo, head_num*qkv_dim)

        mh_atten_out = self.multi_head_combine_1(out_concat_1) + self.multi_head_combine_2(out_concat_2)
        # shape: (batch, pomo, embedding)

        #  Single-Head Attention, for probability calculation
        #######################################################
        score = torch.matmul(mh_atten_out, self.single_head_key)
        # shape: (batch, pomo, candidates)

        sqrt_embedding_dim = np.sqrt(self.model_params['embedding_dim'])
        logit_clipping = self.model_params['logit_clipping']

        # Scale, clip to [-logit_clipping, logit_clipping] via tanh, then mask.
        score_clipped = logit_clipping * torch.tanh(score / sqrt_embedding_dim)
        score_masked = score_clipped + mask

        probs = F.softmax(score_masked, dim=2)
        # shape: (batch, pomo, candidates)
        return probs


########################################
# NN SUB CLASS / FUNCTIONS
########################################

def reshape_by_heads(qkv, head_num):
    """Split the last axis into heads and move the head axis forward.

    (batch, n, head_num * key_dim) -> (batch, head_num, n, key_dim), where
    n can be 1 or the problem size.
    """
    batch_s, n = qkv.size(0), qkv.size(1)
    per_head = qkv.reshape(batch_s, n, head_num, -1)
    return per_head.transpose(1, 2)


def multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None):
    """Scaled dot-product attention over pre-split heads.

    Args:
        q: (batch, head_num, n, key_dim) queries; n can be 1 or PROBLEM_SIZE.
        k: (batch, head_num, problem, key_dim) keys.
        v: (batch, head_num, problem, val_dim) values; val_dim may differ
            from key_dim.
        rank2_ninf_mask: optional additive mask (batch, problem), shared by
            every query row.
        rank3_ninf_mask: optional additive mask (batch, n, problem), one row
            per query.

    Returns:
        (batch, n, head_num * val_dim) tensor of concatenated head outputs.
    """
    batch_s = q.size(0)
    head_num = q.size(1)
    n = q.size(2)
    key_dim = q.size(3)
    val_dim = v.size(3)
    input_s = k.size(2)

    score = torch.matmul(q, k.transpose(2, 3))
    # shape: (batch, head_num, n, problem)

    # Python-float scale: no per-call tensor allocation.
    score_scaled = score / (key_dim ** 0.5)
    if rank2_ninf_mask is not None:
        score_scaled = score_scaled + rank2_ninf_mask[:, None, None, :].expand(batch_s, head_num, n, input_s)
    if rank3_ninf_mask is not None:
        score_scaled = score_scaled + rank3_ninf_mask[:, None, :, :].expand(batch_s, head_num, n, input_s)

    weights = torch.softmax(score_scaled, dim=3)
    # shape: (batch, head_num, n, problem)

    out = torch.matmul(weights, v)
    # shape: (batch, head_num, n, val_dim)

    # (batch, n, head_num, val_dim) -> (batch, n, head_num * val_dim). The
    # width is explicit (not -1) so empty batches still reshape.
    return out.transpose(1, 2).reshape(batch_s, n, head_num * val_dim)


class Add_And_Normalization_Module(nn.Module):
    """Residual add followed by InstanceNorm over the problem axis.

    Inputs are (batch, problem, embedding). Each embedding channel is
    normalised across the problem positions of its own sample, then scaled
    and shifted by learned affine parameters. Only 'embedding_dim' is read
    from the keyword parameters.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.norm = nn.InstanceNorm1d(
            model_params['embedding_dim'], affine=True, track_running_stats=False
        )

    def forward(self, input1, input2):
        residual = input1 + input2
        # InstanceNorm1d expects (N, C, L): put embedding on the channel axis.
        channels_first = residual.transpose(1, 2)
        return self.norm(channels_first).transpose(1, 2)


class Feed_Forward_Module(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Linear.

    Reads 'embedding_dim' and 'ff_hidden_dim' from the keyword parameters;
    other keys are ignored. Input and output are (batch, problem, embedding).
    """

    def __init__(self, **model_params):
        super().__init__()
        emb_dim = model_params['embedding_dim']
        hidden_dim = model_params['ff_hidden_dim']
        self.W1 = nn.Linear(emb_dim, hidden_dim)
        self.W2 = nn.Linear(hidden_dim, emb_dim)

    def forward(self, input1):
        hidden = F.relu(self.W1(input1))
        return self.W2(hidden)


def find_corresponding_tunnel(tensor1, tensor2):
    """Map each queried node id to the index of the tunnel that contains it.

    Args:
        tensor1: (B, N, T) tensor; row ``tensor1[b, n]`` lists the node ids
            belonging to tunnel ``n`` of batch ``b``.
        tensor2: (B, K) tensor of node ids to look up.

    Returns:
        (B, K) int64 tensor whose entry [b, k] is the index of the first
        tunnel in batch b whose row contains ``tensor2[b, k]``.

    Raises:
        RuntimeError: if some queried node id appears in no tunnel. (The
            previous nonzero()/view() approach could instead silently return
            misaligned indices when duplicates and misses balanced out.)
    """
    # (B, K, 1, 1) == (B, 1, N, T) -> (B, K, N, T); reduce over T -> (B, K, N)
    membership = (tensor2[:, :, None, None] == tensor1[:, None, :, :]).any(dim=-1)
    found = membership.any(dim=-1)
    if not bool(found.all()):
        missing = tensor2[~found].tolist()
        raise RuntimeError(
            f"find_corresponding_tunnel: node id(s) {missing} not found in any tunnel")
    # argmax returns the first maximal index, i.e. the first matching tunnel.
    return membership.int().argmax(dim=-1)