#!/usr/bin/env python3

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Module-wide default device: CUDA when torch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Size notations:
# B = batch_size, H = hidden_size, M = block_size, L = attn_span,
# K = nb_heads, D = head_dim (= H // K), Q_M = number of decoder query positions
def requ(x):
    """Rectified quadratic unit: elementwise max(x, 0) squared."""
    rectified = torch.relu(x)
    return rectified.pow(2)

def requr(x):
    """Difference of rectified quadratics: ReQU(x) - ReQU(x - 0.5), elementwise.

    Piecewise: 0 for x <= 0, x**2 for 0 < x <= 0.5, and x - 0.25 for x > 0.5.

    Args:
        x: tensor of any shape.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # ``torch.requ`` does not exist in PyTorch; the intended op is ReQU,
    # i.e. relu(x) ** 2. It is written inline so this function does not
    # depend on any other module-level helper.
    return torch.relu(x) ** 2 - torch.relu(x - 0.5) ** 2

class SeqAttention(nn.Module):
    """Single-head scaled dot-product attention.

    Every query position attends to every key position; no mask is applied.
    ``attn_span`` is only stored and reported by ``get_cache_size``;
    ``enable_mem`` is stored but not read by this class.
    """

    def __init__(self, hidden_size, enable_mem, attn_span):
        super().__init__()
        self.hidden_size = hidden_size  # feature size of one head; sets the score scale
        self.attn_span = attn_span
        self.enable_mem = enable_mem

    def forward(self, query, key, value):
        """Return softmax(query @ key^T / sqrt(hidden_size)) @ value.

        Shapes (B may be batch * heads when called from a multi-head wrapper):
            query: B x M x H; key, value: B x (M+L) x H; result: B x M x H.
        """
        scores = query @ key.transpose(-2, -1)            # B x M x (M+L)
        scaled = scores / math.sqrt(self.hidden_size)
        weights = scaled.softmax(dim=-1)                  # rows sum to 1 over sources
        return weights @ value                            # B x M x H

    def get_cache_size(self):
        """Return the configured attention span (used as the memory length)."""
        return self.attn_span


class MultiHeadSeqAttention(nn.Module):
    """Multi-head attention built from per-head ``SeqAttention``.

    Query/key/value are linearly projected (no bias), split into ``nb_heads``
    heads of size ``hidden_size // nb_heads``, attended per head, merged back,
    and passed through an output projection (no bias).
    """

    def __init__(self, hidden_size, enable_mem, nb_heads, attn_span):
        super().__init__()
        assert hidden_size % nb_heads == 0
        self.nb_heads = nb_heads
        self.head_dim = hidden_size // nb_heads
        self.attn = SeqAttention(
            hidden_size=self.head_dim, enable_mem=enable_mem, attn_span=attn_span)
        # Creation order is kept as-is: it fixes the sequence of random draws
        # used by the default (Kaiming-uniform) Linear initialisation.
        self.proj_query = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj_out = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj_val = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj_key = nn.Linear(hidden_size, hidden_size, bias=False)

    def head_reshape(self, x):
        """Split the last dim into heads and fold heads into the batch dim.

        B x T x (K*D)  ->  (B*K) x T x D
        """
        split = x.view(*x.shape[:-1], self.nb_heads, self.head_dim)  # B x T x K x D
        per_head = split.transpose(1, 2).contiguous()                 # B x K x T x D
        return per_head.view(-1, *per_head.shape[-2:])                # (B*K) x T x D

    def forward(self, query, key, value):
        """Attend ``query`` (B x M x H) over ``key``/``value``; return B x M x H."""
        batch, q_len = query.size(0), query.size(1)

        q_heads = self.head_reshape(self.proj_query(query))
        v_heads = self.head_reshape(self.proj_val(value))
        k_heads = self.head_reshape(self.proj_key(key))

        context = self.attn(q_heads, k_heads, v_heads)                # (B*K) x M x D
        context = context.view(batch, self.nb_heads, q_len, self.head_dim)
        merged = context.transpose(1, 2).contiguous().view(batch, q_len, -1)  # B x M x (K*D)
        return self.proj_out(merged)


class FeedForwardLayer(nn.Module):
    """Two-layer position-wise MLP with ReQU-style activations and a skip.

    Computes ``requ(fc2(requr(fc1(h)))) + h``.
    NOTE(review): the input is added back here, and TransformerSeqLayer.forward
    adds ``h`` again before ``norm2``, so the skip is applied twice — confirm
    this is intended.
    """

    def __init__(self, hidden_size, inner_hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, inner_hidden_size)
        self.fc2 = nn.Linear(inner_hidden_size, hidden_size)

    def forward(self, h):
        expanded = requr(self.fc1(h))        # ... x inner_hidden_size
        projected = requ(self.fc2(expanded))  # back to ... x hidden_size
        return projected + h

class TransformerSeqLayer(nn.Module):
    """Attention + feed-forward block, each followed by BatchNorm over features.

    With ``enable_mem`` the keys/values are ``h_cache`` followed by ``h``;
    otherwise ``h_cache`` alone is used as the keys/values (e.g. an encoder
    output supplied by a decoder).
    """

    def __init__(self, hidden_size, enable_mem, nb_heads, attn_span, inner_hidden_size):
        super().__init__()
        self.attn = MultiHeadSeqAttention(
            hidden_size=hidden_size, enable_mem=enable_mem, nb_heads=nb_heads, attn_span=attn_span)
        self.ff = FeedForwardLayer(hidden_size=hidden_size, inner_hidden_size=inner_hidden_size)
        self.norm1 = nn.BatchNorm1d(hidden_size)
        self.norm2 = nn.BatchNorm1d(hidden_size)
        self.enable_mem = enable_mem

    def forward(self, h, h_cache):
        """h: B x M x H; h_cache: B x L x H (memory) or the external sequence."""
        context = torch.cat([h_cache, h], dim=1) if self.enable_mem else h_cache
        attn_out = self.attn(h, context, context)

        shape = h.size()
        width = h.size(-1)
        # BatchNorm1d normalises each feature over all B*M positions.
        h = self.norm1((h + attn_out).view(-1, width)).view(*shape)  # B x M x H
        ff_out = self.ff(h)
        return self.norm2((h + ff_out).view(-1, width)).view(*shape)  # B x M x H


class EncoderSeq(nn.Module):
    """Linear embedding followed by a stack of memory-enabled transformer layers.

    Each layer attends over its cached memory concatenated with the current
    block; ``forward`` also returns the rolled-forward caches for the next call.
    """

    def __init__(self, state_size, hidden_size, nb_heads, encoder_nb_layers,
                 attn_span, inner_hidden_size):
        nn.Module.__init__(self)
        # init embeddings
        self.init_embed = nn.Linear(state_size, hidden_size)

        self.layers = nn.ModuleList()
        self.layers.extend(
            TransformerSeqLayer(
                hidden_size=hidden_size, enable_mem=True, nb_heads=nb_heads,
                attn_span=attn_span, inner_hidden_size=inner_hidden_size)
            for _ in range(encoder_nb_layers))

    def forward(self, x, h_cache):
        """Encode ``x`` using per-layer memories.

        Args:
            x: B x M x state_size (fed to ``init_embed``).
            h_cache: indexable per layer; ``h_cache[l]`` is B x L x H.

        Returns:
            (h, h_cache_next): the B x M x H output of the last layer, and a list
            with one detached cache per layer. Each new cache keeps the newest
            ``cache_size - 1`` entries of the old one and appends position 0 of
            that layer's input.
            NOTE(review): only ``h[:, 0:1]`` is appended, which covers the whole
            block only when M == 1 — confirm for M > 1.
        """
        h = self.init_embed(x)  # B x M x H
        h_cache_next = []
        for l, layer in enumerate(self.layers):
            cache_size = layer.attn.attn.get_cache_size()
            prev_cache = h_cache[l]

            # Keep the newest (cache_size - 1) entries. A negative-start slice
            # ``[-cache_size + 1:]`` breaks for cache_size == 1 because ``-0:``
            # selects the whole cache, making it grow on every call.
            keep = max(cache_size - 1, 0)
            start = max(prev_cache.size(1) - keep, 0)

            # B x L x H
            h_cache_next_l = torch.cat(
                [prev_cache[:, start:, :], h[:, 0:1, :]],
                dim=1).detach()
            h_cache_next.append(h_cache_next_l)

            h = layer(h, prev_cache)  # B x M x H

        return h, h_cache_next


class QDecoder(nn.Module):
    """Linear embedding followed by transformer layers that attend to an external sequence.

    Layers are built with ``enable_mem=False``, so each one uses ``embedding``
    alone as its keys/values (no self-memory).
    """

    def __init__(self, state_size, hidden_size, nb_heads, decoder_nb_layers,
                 attn_span, inner_hidden_size):
        nn.Module.__init__(self)
        # init embeddings
        self.init_embed = nn.Linear(state_size, hidden_size)

        self.layers = nn.ModuleList()
        self.layers.extend(
            TransformerSeqLayer(
                hidden_size=hidden_size, enable_mem=False, nb_heads=nb_heads,
                attn_span=attn_span, inner_hidden_size=inner_hidden_size)
            for _ in range(decoder_nb_layers))

    def forward(self, x, embedding):
        """Decode queries ``x`` against ``embedding``.

        Args:
            x: B x Q_M x state_size (fed to ``init_embed``).
            embedding: keys/values for every layer, presumably the encoder
                output of shape B x M x H — TODO confirm against callers.

        Returns:
            B x Q_M x H output of the last layer.
        """
        h = self.init_embed(x)  # B x Q_M x H
        for layer in self.layers:
            h = layer(h, embedding)  # B x Q_M x H
        return h

class PackDecoder(nn.Module):
    """``QDecoder`` followed by a two-layer MLP head mapping H -> ``res_size``."""

    def __init__(self, head_hidden_size, res_size, state_size, hidden_size,nb_heads, decoder_layers,attn_span, inner_hidden_size):
        super().__init__()

        self.att_decoder = QDecoder(
            state_size,
            hidden_size,
            nb_heads=nb_heads,
            decoder_nb_layers=decoder_layers,
            attn_span=attn_span,
            inner_hidden_size=inner_hidden_size,
        )

        # Built after the decoder so parameter initialisation order is unchanged.
        head_layers = [
            nn.Linear(hidden_size, head_hidden_size),
            nn.ReLU(),
            nn.Linear(head_hidden_size, res_size),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x, embedding):
        """Return head logits of shape B x Q_M x res_size."""
        decoded = self.att_decoder(x, embedding)
        return self.head(decoded)
     
class RCQL(nn.Module):
    """Sequence encoder plus four ``PackDecoder`` heads (s, r, x, y).

    ``calc_seq_idx`` runs the encoder and picks an index with the s head;
    ``calc_ori_idx``, ``calc_x_idx`` and ``calc_y_idx`` pick indices with the
    r, x and y heads from the encoder output. Each picks greedily (argmax) or
    by sampling from a softmax over dim 1.
    """

    def __init__(self, state_size, hidden_size, nb_heads, encoder_nb_layers, attn_span, inner_hidden_size,src_head_hidden_size,pos_head_hidden_size, 
                 s_res_size, r_res_size,x_res_size,y_res_size,decoder_nb_layers,item_state_size):
        nn.Module.__init__(self)

        self.encoder=EncoderSeq(state_size=state_size,hidden_size=hidden_size,nb_heads=nb_heads,encoder_nb_layers=encoder_nb_layers,attn_span=attn_span,
                                inner_hidden_size=inner_hidden_size)
        self.s_decoder=PackDecoder(head_hidden_size=src_head_hidden_size,res_size=s_res_size,state_size=item_state_size,hidden_size=hidden_size,
                                   decoder_layers=decoder_nb_layers,attn_span=attn_span,inner_hidden_size=inner_hidden_size,nb_heads=nb_heads)
        self.r_decoder=PackDecoder(head_hidden_size=src_head_hidden_size,res_size=r_res_size,state_size=item_state_size,hidden_size=hidden_size,
                                   decoder_layers=decoder_nb_layers,attn_span=attn_span,inner_hidden_size=inner_hidden_size,nb_heads=nb_heads)
        self.x_decoder=PackDecoder(head_hidden_size=pos_head_hidden_size,res_size=x_res_size,state_size=item_state_size,hidden_size=hidden_size,
                                   decoder_layers=decoder_nb_layers,attn_span=attn_span,inner_hidden_size=inner_hidden_size,nb_heads=nb_heads)
        self.y_decoder=PackDecoder(head_hidden_size=pos_head_hidden_size,res_size=y_res_size,state_size=item_state_size,hidden_size=hidden_size,
                                   decoder_layers=decoder_nb_layers,attn_span=attn_span,inner_hidden_size=inner_hidden_size,nb_heads=nb_heads)
        self.softmax=nn.Softmax(dim=1)
        self.alpha=torch.nn.Parameter(torch.tensor(1.0),requires_grad=True)

    @staticmethod
    def _squeeze_keep_batch(t):
        """Drop every size-1 dim except dim 0, so a batch of one keeps its batch dim."""
        # Walk from the last dim down so earlier dim indices stay valid.
        for dim in range(t.dim() - 1, 0, -1):
            if t.size(dim) == 1:
                t = t.squeeze(dim)
        return t

    @staticmethod
    def _pick(probs, isGreedy):
        """Choose one column per row of ``probs`` (B x N).

        Returns ``(idx, prob)``. ``idx`` has shape (B,). ``prob`` is the
        probability of each sampled index, or None when ``isGreedy``.
        """
        if isGreedy:
            return torch.max(probs, 1)[1], None
        # squeeze(1) (not a bare squeeze) keeps shape (B,) even when B == 1,
        # matching the greedy branch.
        idx = torch.multinomial(probs, 1).squeeze(1)
        rows = list(range(probs.size(0)))
        return idx, probs[rows, idx]

    def calc_seq_idx(self,packed_state,unpacked_state,isGreedy,maskseq,h_caches):
        """Encode ``packed_state`` and pick an index from the s head.

        Args:
            packed_state: encoder input (batch first).
            unpacked_state: queries for the s head.
            isGreedy: argmax if true, else sample.
            maskseq: indices to exclude per sample. It is converted with
                ``np.array(...).astype(int)`` and transposed, so row b of a
                rectangular B x K array masks sample b.
            h_caches: per-layer encoder memories.

        Returns:
            Greedy: ``(seqidx, actor_encoder_out, h_caches)``.
            Sampling: ``(seqidx, seqpro, actor_encoder_out, h_caches)``, where
            ``seqpro`` is the probability of each sampled index.
        """
        actor_encoder_out, h_caches = self.encoder(packed_state, h_caches)
        s_out = self.s_decoder(unpacked_state, actor_encoder_out).squeeze(-1)
        batch_size = packed_state.size(0)
        rows = list(range(batch_size))
        # K x B after the transpose, so it broadcasts against the B row indices:
        # entry [k, b] masks column maskseq[b][k] of row b.
        maskseq = np.array(maskseq).astype(int).transpose(1, 0)
        s_out[rows, maskseq] = -np.inf
        s_out = self.softmax(s_out)
        seqidx, seqpro = self._pick(s_out, isGreedy)
        if isGreedy:
            return seqidx, actor_encoder_out, h_caches
        return seqidx, seqpro, actor_encoder_out, h_caches

    def calc_ori_idx(self,actor_encoder_out,select_seq_idx,isGreedy):
        """Pick an index from the r head.

        Returns ``oriidx`` when greedy, else ``(oriidx, oripro)``.
        """
        r_out = self._squeeze_keep_batch(self.r_decoder(select_seq_idx, actor_encoder_out))
        r_out = self.softmax(r_out)
        oriidx, oripro = self._pick(r_out, isGreedy)
        if isGreedy:
            return oriidx
        return oriidx, oripro

    def calc_x_idx(self,actor_encoder_out,select_seq_idx_ori,isGreedy,maskx):
        """Pick an index from the x head after masking ``maskx[i]`` in row i.

        Returns ``xidx`` when greedy, else ``(xidx, xpro)``.
        """
        x_out = self._squeeze_keep_batch(self.x_decoder(select_seq_idx_ori, actor_encoder_out))
        batch_size = actor_encoder_out.size(0)
        for i in range(batch_size):
            x_out[i, maskx[i]] = -np.inf
        x_out = self.softmax(x_out)
        xidx, xpro = self._pick(x_out, isGreedy)
        if isGreedy:
            return xidx
        return xidx, xpro

    def calc_y_idx(self,actor_encoder_out,select_seq_idx_ori,isGreedy,masky):
        """Pick an index from the y head after masking ``masky[i]`` in row i.

        Returns ``yidx`` when greedy, else ``(yidx, ypro)``.
        """
        y_out = self._squeeze_keep_batch(self.y_decoder(select_seq_idx_ori, actor_encoder_out))
        batch_size = actor_encoder_out.size(0)
        for i in range(batch_size):
            y_out[i, masky[i]] = -np.inf
        y_out = self.softmax(y_out)
        yidx, ypro = self._pick(y_out, isGreedy)
        if isGreedy:
            return yidx
        return yidx, ypro

    def set_device(self,device):
        """Move all submodules and parameters, including ``alpha``, to ``device``."""
        # Module.to() moves registered parameters in place. Moving only the five
        # submodules left ``alpha`` behind on its original device.
        self.to(device)

    def set_eval(self):
        """Put the encoder and the four decoders into eval mode (affects BatchNorm)."""
        for module in (self.encoder, self.s_decoder, self.r_decoder,
                       self.x_decoder, self.y_decoder):
            module.eval()
    

