import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import find_corresponding_tunnel

class TSPModel(nn.Module):
    """Policy network that selects the next node for every POMO rollout.

    The model owns a node encoder, a tunnel encoder and a decoder.
    ``pre_forward`` hands already-computed embeddings to the decoder and
    ``forward`` produces one action per (batch, pomo) rollout for the
    current decoding step.
    """

    def __init__(self, model_params,device):
        """Build the encoders and the decoder.

        Args:
            model_params: dict of hyper-parameters; this class itself reads
                ``'eval_type'`` and forwards the dict to the sub-modules.
            device: device on which index tensors and masks are placed.
        """
        super().__init__()
        self.model_params = model_params
        # Raw feature widths (16 for nodes, 40 for tunnels) are fixed here.
        self.encoder_nodes = TSP_Encoder_1(model_params,device,16)
        self.encoder_tunnels = TSP_Encoder_2(model_params,device,40)
        self.decoder = TSP_Decoder(model_params,device)
        self.encoded_nodes = None
        self.encoded_tunnels = None
        # Set by pre_forward(); declared here so the attribute always exists.
        self.batch_tunnel = None
        self.device = device
        # shape: (batch, problem, EMBEDDING_DIM)

    def pre_forward(self,embeddings_nodes,embeddings_tunnels,batch_tunnel):
        """Cache the embeddings for a rollout and prime the decoder keys/values.

        The encoders are NOT run here; the caller supplies the embeddings.

        Args:
            embeddings_nodes: node embeddings -- presumably
                (batch, problem, embedding); confirm against the caller.
            embeddings_tunnels: tunnel embeddings, same layout assumed.
            batch_tunnel: opaque lookup data passed to
                ``find_corresponding_tunnel`` on every step.
        """
        self.encoded_nodes = embeddings_nodes
        self.encoded_tunnels = embeddings_tunnels
        self.batch_tunnel = batch_tunnel
        self.decoder.set_kv(self.encoded_nodes,self.encoded_tunnels)

    def forward(self,state,batch_size,pomo_size,batch_idx_range,group_idx_range):
        """Choose one action per rollout for the current step.

        Args:
            state: object exposing ``current_node`` (``None`` on the first
                step, otherwise an index tensor) and ``ninf_mask``.
            batch_size: number of problem instances in the batch.
            pomo_size: number of parallel rollouts per instance.
            batch_idx_range: index tensor used to address the batch axis of
                the probability tensor.
            group_idx_range: index tensor used to address the pomo axis.

        Returns:
            ``(action, prob)``. ``action`` has shape (batch, pomo). ``prob``
            is the probability of the chosen action, or ``None`` when the
            greedy (argmax) branch is taken.
        """
        # Identity test: `== None` on a tensor is an (overloadable) equality
        # comparison, not a reliable "is it missing" check.
        if state.current_node is None:
            # First step: rollout i starts at node i.
            action = torch.arange(pomo_size)[None,:].expand(batch_size,pomo_size)
            action_on_device = action.to(self.device)
            action_tunnel = find_corresponding_tunnel(self.batch_tunnel,action_on_device)
            prob = torch.ones(size = (batch_size,pomo_size))
            encoded_first_node = _get_encoding(encoded_nodes=self.encoded_nodes, node_index_to_pick=action_on_device)
            encoded_first_tunnel = _get_encoding(encoded_nodes=self.encoded_tunnels, node_index_to_pick=action_tunnel.to(self.device))
            self.decoder.set_q1(encoded_first_node,encoded_first_tunnel)
        else:
            current_node = state.current_node.to(self.device)
            encoded_last_node = _get_encoding(self.encoded_nodes, current_node)
            current_tunnel = find_corresponding_tunnel(self.batch_tunnel,current_node)
            encoded_last_tunnel = _get_encoding(self.encoded_tunnels, current_tunnel.to(self.device))
            # shape: (batch, pomo, embedding)
            probs = self.decoder(encoded_last_node, encoded_last_tunnel, ninf_mask=state.ninf_mask)
            # shape: (batch, pomo, problem)
            if self.training or self.model_params['eval_type'] == 'softmax':
                # Resample until every rollout picked an action with a
                # non-zero probability.
                while True:
                    action = probs.reshape(batch_size * pomo_size, -1).multinomial(1) \
                        .squeeze(dim=1).reshape(batch_size, pomo_size)
                    # shape: (batch, pomo)

                    prob = probs[batch_idx_range, group_idx_range, action] \
                        .reshape(batch_size, pomo_size)
                    # shape: (batch, pomo)

                    if (prob != 0).all():
                        break

            else:
                # Greedy decoding: no probability is reported.
                action = probs.argmax(dim=2)
                # shape: (batch, pomo)
                prob = None

        return action,prob

def _get_encoding(encoded_nodes, node_index_to_pick):
    """Pick one embedding row per (batch, pomo) index.

    Args:
        encoded_nodes: tensor of shape (batch, problem, embedding).
        node_index_to_pick: integer tensor of shape (batch, pomo) holding
            indices into the ``problem`` axis.

    Returns:
        Tensor of shape (batch, pomo, embedding) where entry ``[b, p]`` is
        ``encoded_nodes[b, node_index_to_pick[b, p]]``.
    """
    n_batch, n_pomo = node_index_to_pick.size(0), node_index_to_pick.size(1)
    width = encoded_nodes.size(2)

    # Repeat each index along a new trailing axis so gather copies the whole
    # embedding vector of the selected row.
    index = node_index_to_pick.unsqueeze(2).expand(n_batch, n_pomo, width)
    # shape: (batch, pomo, embedding)
    return torch.gather(encoded_nodes, 1, index)


########################################
# ENCODER
########################################

class TSP_Encoder_1(nn.Module):
    """Node encoder: a linear embedding followed by a stack of encoder layers."""

    def __init__(self,model_params,device,input_dimm):
        """
        Args:
            model_params: dict providing ``'embedding_dim'`` and ``'n_layers'``.
            device: stored on the instance and forwarded to each layer.
            input_dimm: width of the raw input features.
        """
        super().__init__()
        self.model_params = model_params
        self.device = device
        self.embedding_dim = model_params['embedding_dim']
        self.encoder_layer_num = model_params['n_layers']

        # Lift raw features to the embedding width.
        self.embedding_layer = nn.Linear(input_dimm, self.embedding_dim)

        stack = []
        for _ in range(self.encoder_layer_num):
            stack.append(TSP_EncoderLayer(model_params, device))
        self.layers = nn.ModuleList(stack)

    def forward(self,data):
        """Embed ``data`` and run it through every layer in order.

        Args:
            data: tensor whose last axis has ``input_dimm`` features.

        Returns:
            Tensor with the last axis replaced by ``embedding_dim``.
        """
        hidden = self.embedding_layer(data)
        for block in self.layers:
            hidden = block(hidden)
        return hidden

class TSP_Encoder_2(nn.Module):
    """Tunnel encoder: linear embedding, encoder layers, then a fold of the
    sequence axis that sums its first half with its second half."""

    def __init__(self,model_params,device,input_dimm):
        """
        Args:
            model_params: dict providing ``'embedding_dim'`` and ``'n_layers'``.
            device: stored on the instance and forwarded to each layer.
            input_dimm: width of the raw input features.
        """
        super().__init__()
        self.model_params = model_params
        self.device = device
        self.embedding_dim = model_params['embedding_dim']
        self.encoder_layer_num = model_params['n_layers']

        # Lift raw features to the embedding width.
        self.embedding_layer = nn.Linear(input_dimm, self.embedding_dim)

        stack = []
        for _ in range(self.encoder_layer_num):
            stack.append(TSP_EncoderLayer(model_params, device))
        self.layers = nn.ModuleList(stack)

    def forward(self,data):
        """Encode ``data`` and merge the two halves of axis 1.

        Args:
            data: 3-D tensor whose last axis has ``input_dimm`` features.

        Returns:
            Tensor whose axis 1 is half as long as the input's: row ``i`` is
            the sum of encoded rows ``i`` and ``i + len // 2``.
        """
        hidden = self.embedding_layer(data)
        for block in self.layers:
            hidden = block(hidden)

        # Applies when direction is considered: the sequence is treated as two
        # stacked halves that are summed element-wise.
        # NOTE(review): presumably the halves are the two travel directions of
        # each tunnel; an odd-length axis 1 yields unequal halves -- confirm
        # callers always pass an even length.
        _, length, _ = hidden.shape
        half = length // 2
        return hidden[:, :half, :] + hidden[:, half:, :]

class TSP_EncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention and a feed-forward
    network, each followed by a residual add-and-normalise step."""

    def __init__(self,model_params,device):
        """
        Args:
            model_params: dict providing ``'embedding_dim'``, ``'mid_dim'``
                (per-head q/k/v width), ``'n_heads'`` and ``'ff_hidden_rate'``
                (feed-forward hidden width as a multiple of the embedding).
            device: stored on the instance.
        """
        super().__init__()
        self.model_params = model_params
        self.device = device
        self.embedding_dim = model_params['embedding_dim']
        self.qkv_dim = model_params['mid_dim']
        self.n_heads = model_params['n_heads']
        self.ff_hidden_rate = model_params['ff_hidden_rate']

        attn_width = self.qkv_dim * self.n_heads
        ff_width = self.embedding_dim * self.ff_hidden_rate

        # Bias-free projections producing all heads at once.
        self.Wq = nn.Linear(self.embedding_dim, attn_width, bias=False)
        self.Wk = nn.Linear(self.embedding_dim, attn_width, bias=False)
        self.Wv = nn.Linear(self.embedding_dim, attn_width, bias=False)

        # Maps the concatenated heads back to the embedding width.
        self.multi_head_combine = nn.Linear(attn_width, self.embedding_dim)

        self.AddAndNormal1 = Add_and_Norm_module(model_params)
        self.FF = nn.Sequential(
            nn.Linear(self.embedding_dim, ff_width),
            nn.ReLU(),
            nn.Linear(ff_width, self.embedding_dim),
        )
        self.AddAndNormal2 = Add_and_Norm_module(model_params)

    def forward(self,input):
        """Apply self-attention and the feed-forward network.

        Args:
            input: tensor of shape (batch, problem, embedding_dim).

        Returns:
            Tensor of the same shape as ``input``.
        """
        projections = (self.Wq, self.Wk, self.Wv)
        q, k, v = (
            reshape_by_heads(proj(input), head_num=self.n_heads)
            for proj in projections
        )

        attended = self.multi_head_combine(multi_head_attention(q, k, v))
        after_attention = self.AddAndNormal1(input, attended)
        return self.AddAndNormal2(after_attention, self.FF(after_attention))


########################################
# DECODER
########################################

class TSP_Decoder(nn.Module):
    """Attention decoder over node and tunnel embeddings.

    Usage order per rollout: ``set_kv`` (once), ``set_q1`` (once, after the
    first node is chosen), then ``forward`` for every following step.
    """

    def __init__(self,model_params,device):
        """
        Args:
            model_params: dict providing ``'embedding_dim'``, ``'mid_dim'``
                (per-head q/k/v width) and ``'n_heads'``; ``forward`` also
                reads ``'tanh_clipping'``.
            device: device the mask is moved to in ``forward``.
        """
        super().__init__()
        self.model_params = model_params
        self.device = device
        self.embedding_dim = self.model_params['embedding_dim']
        self.qkv_dim = self.model_params['mid_dim']
        self.n_heads = self.model_params['n_heads']

        # Key/value projections for nodes (n) and tunnels (t).
        self.Wkn = nn.Linear(self.embedding_dim, self.qkv_dim * self.n_heads, bias = False)
        self.Wvn = nn.Linear(self.embedding_dim, self.qkv_dim * self.n_heads, bias = False)
        self.Wkt = nn.Linear(self.embedding_dim, self.qkv_dim * self.n_heads, bias = False)
        self.Wvt = nn.Linear(self.embedding_dim, self.qkv_dim * self.n_heads, bias = False)
        # Query projections for the first and the most recent node/tunnel.
        self.Wqn_first = nn.Linear(self.embedding_dim, self.qkv_dim * self.n_heads, bias = False)
        self.Wqt_first = nn.Linear(self.embedding_dim, self.qkv_dim * self.n_heads, bias = False)
        self.Wqn_last = nn.Linear(self.embedding_dim, self.qkv_dim * self.n_heads, bias = False)
        self.Wqt_last = nn.Linear(self.embedding_dim, self.qkv_dim * self.n_heads, bias = False)

        self.multi_head_combine_1 = nn.Linear(self.n_heads * self.qkv_dim, self.embedding_dim)
        self.multi_head_combine_2 = nn.Linear(self.n_heads * self.qkv_dim, self.embedding_dim)

        # Kept for backward compatibility; nothing in this class reads them.
        self.k = None
        self.v = None
        self.q_first = None

        # Caches filled by set_kv().
        self.knodes = None
        self.vnodes = None
        self.single_head_key_nodes = None
        self.ktunnels = None
        self.vtunnels = None
        self.single_head_key_tunnels = None
        # Caches filled by set_q1().
        self.q_first_nodes = None
        self.q_first_tunnels = None

    def set_kv(self,kv_matrix_nodes,kv_matrix_tunnels):
        """Pre-compute keys and values from the encoder outputs.

        Args:
            kv_matrix_nodes: node embeddings, (batch, problem, embedding).
            kv_matrix_tunnels: tunnel embeddings; must have the same shape as
                the node embeddings because ``forward`` adds the two
                transposed key matrices element-wise.
        """
        self.knodes = reshape_by_heads(self.Wkn(kv_matrix_nodes),head_num=self.n_heads)
        self.vnodes = reshape_by_heads(self.Wvn(kv_matrix_nodes),head_num=self.n_heads)
        self.single_head_key_nodes = kv_matrix_nodes.transpose(1, 2)
        self.ktunnels = reshape_by_heads(self.Wkt(kv_matrix_tunnels),head_num=self.n_heads)
        self.vtunnels = reshape_by_heads(self.Wvt(kv_matrix_tunnels),head_num=self.n_heads)
        self.single_head_key_tunnels = kv_matrix_tunnels.transpose(1, 2)
        # shape: (batch, embedding, problem)

    def set_q1(self,q1_matrix_nodes,q1_matrix_tunnels):
        """Cache the query contribution of the first node and tunnel.

        Args:
            q1_matrix_nodes: embedding of each rollout's first node,
                (batch, pomo, embedding).
            q1_matrix_tunnels: embedding of the matching tunnel, same shape.
        """
        self.q_first_nodes = reshape_by_heads(self.Wqn_first(q1_matrix_nodes),head_num=self.n_heads)
        self.q_first_tunnels = reshape_by_heads(self.Wqt_first(q1_matrix_tunnels),head_num=self.n_heads)


    def forward(self, encoded_last_node, encoded_last_tunnel, ninf_mask):
        """Compute the next-node distribution for every rollout.

        Args:
            encoded_last_node: (batch, pomo, embedding) embedding of the
                current node.
            encoded_last_tunnel: (batch, pomo, embedding) embedding of the
                current tunnel.
            ninf_mask: (batch, pomo, problem) additive mask; it is added to
                the attention scores and to the final logits.

        Returns:
            Probabilities of shape (batch, pomo, problem), softmax-normalised
            over the last axis.

        Raises:
            RuntimeError: if ``set_kv`` or ``set_q1`` has not been called.
        """
        # encoded_last_node.shape: (batch, pomo, embedding)
        # ninf_mask.shape: (batch, pomo, problem)
        if self.knodes is None or self.q_first_nodes is None:
            raise RuntimeError(
                "TSP_Decoder.forward() called before set_kv() and set_q1()")

        # Move the mask once; it is needed three times below.
        mask = ninf_mask.to(self.device)

        #  Multi-Head Attention
        #######################################################
        qn_last = reshape_by_heads(self.Wqn_last(encoded_last_node), head_num=self.n_heads)
        qt_last = reshape_by_heads(self.Wqt_last(encoded_last_tunnel), head_num=self.n_heads)
        # shape: (batch, head_num, pomo, qkv_dim)

        qn = self.q_first_nodes + qn_last
        qt = self.q_first_tunnels + qt_last
        # shape: (batch, head_num, pomo, qkv_dim)

        # The same combined query attends over both the node and the tunnel
        # key/value sets, so it is built once.
        query = qn + qt

        out_concat_1 = multi_head_attention(query, self.knodes, self.vnodes, rank3_ninf_mask=mask)
        out_concat_2 = multi_head_attention(query, self.ktunnels, self.vtunnels, rank3_ninf_mask=mask)
        # shape: (batch, pomo, head_num*qkv_dim)

        mh_atten_out = self.multi_head_combine_1(out_concat_1)+self.multi_head_combine_2(out_concat_2)
        # shape: (batch, pomo, embedding)

        #  Single-Head Attention, for probability calculation
        #######################################################
        score = torch.matmul(mh_atten_out, self.single_head_key_nodes+self.single_head_key_tunnels)
        # shape: (batch, pomo, problem)

        sqrt_embedding_dim = np.sqrt(self.model_params['embedding_dim'])
        logit_clipping = self.model_params['tanh_clipping']

        score_scaled = score / sqrt_embedding_dim
        # shape: (batch, pomo, problem)

        # tanh bounds the logits to [-logit_clipping, logit_clipping] before
        # the mask is applied.
        score_clipped = logit_clipping * torch.tanh(score_scaled)

        score_masked = score_clipped + mask

        probs = F.softmax(score_masked, dim=2)
        # shape: (batch, pomo, problem)

        return probs




########################################
# NN SUB CLASS / FUNCTIONS
########################################

class Add_and_Norm_module(nn.Module):
    """Residual sum followed by instance normalisation over the sequence axis."""

    def __init__(self,model_params):
        """
        Args:
            model_params: dict providing ``'embedding_dim'``, used as the
                channel count of the normalisation layer.
        """
        super().__init__()
        channels = model_params['embedding_dim']
        self.norm = nn.InstanceNorm1d(channels, affine=True, track_running_stats=False)

    def forward(self,input1,input2):
        """Return ``normalise(input1 + input2)``.

        Args:
            input1: tensor of shape (batch, problem, embedding).
            input2: tensor of the same shape.

        Returns:
            Tensor of shape (batch, problem, embedding).
        """
        # InstanceNorm1d expects channels on axis 1, so the embedding axis is
        # swapped in before normalising and swapped back afterwards.
        channels_first = (input1 + input2).transpose(1, 2)
        return self.norm(channels_first).transpose(1, 2)


def reshape_by_heads(qkv, head_num):
    """Split the last axis into heads and move the head axis forward.

    Args:
        qkv: tensor of shape (batch, problem, head_num*key_dim).
        head_num: number of attention heads.

    Returns:
        Tensor of shape (batch, head_num, problem, key_dim).
    """
    n_batch = qkv.size(0)
    length = qkv.size(1)

    # (batch, problem, head_num, key_dim) -> (batch, head_num, problem, key_dim)
    per_head = qkv.reshape(n_batch, length, head_num, -1)
    return per_head.permute(0, 2, 1, 3)


def multi_head_attention(q,k,v,rank2_ninf_mask=None, rank3_ninf_mask=None):
    """Scaled dot-product attention with the heads concatenated on output.

    Args:
        q: queries, shape (batch, head_num, n, key_dim).
        k: keys, shape (batch, head_num, problem, key_dim).
        v: values, shape (batch, head_num, problem, key_dim).
        rank2_ninf_mask: optional additive mask of shape (batch, problem),
            shared by all heads and all queries.
        rank3_ninf_mask: optional additive mask of shape
            (batch, group, problem), shared by all heads.

    Returns:
        Tensor of shape (batch, n, head_num*key_dim).
    """
    batch_s, head_num, n, key_dim = q.size(0), q.size(1), q.size(2), q.size(3)
    inputs = k.size(2)
    score_shape = (batch_s, head_num, n, inputs)

    scale = torch.sqrt(torch.tensor(key_dim, dtype=torch.float))
    logits = torch.matmul(q, k.transpose(2, 3)) / scale
    # shape: (batch, head_num, n, problem)

    # Masks are expanded explicitly so a shape mismatch fails here instead of
    # broadcasting silently.
    if rank2_ninf_mask is not None:
        logits = logits + rank2_ninf_mask[:, None, None, :].expand(*score_shape)
    if rank3_ninf_mask is not None:
        logits = logits + rank3_ninf_mask[:, None, :, :].expand(*score_shape)

    weight = F.softmax(logits, dim=3)
    # shape: (batch, head_num, n, problem)

    attended = torch.matmul(weight, v)
    # shape: (batch, head_num, n, key_dim)

    # Put the heads next to key_dim and merge the two axes.
    return attended.transpose(1, 2).reshape(batch_s, n, head_num * key_dim)
