import numpy as np
import torch.nn.functional as F
from torch import nn

import params
import torch
from cuda import use_cuda
from env.statement import num_statements
from env.operator import num_operators
from model.encoder import DenseEncoder, DenseQueryEncoder
from model.model import BaseModel

# Width of every hidden and latent layer in Query2's encoder/decoder MLPs;
# Query2.forward also uses it when regrouping the decoder output.
K=2

class Query(BaseModel):
    """Generate a (soft) query value for each of three slots.

    The encoder output is fed to two linear heads. One produces, per slot,
    a distribution over integers for every list position (a "list"
    candidate); the other produces a single distribution over integers (an
    "int" candidate, stored at position 0 with the remaining positions set
    to the null symbol). A third, constant "null" candidate puts all mass
    on the null symbol at every position. The candidates are then mixed
    per slot using the type tensor ``typ``.

    The last axis of every candidate has ``params.integer_range + 1``
    entries: the integer classes followed by one extra "null" class.

    Arguments:
        embedding: forwarded unchanged to ``DenseQueryEncoder``.
    """

    def __init__(self, embedding):
        super().__init__()
        self.encoder = DenseQueryEncoder(embedding)
        self.list_head = nn.Linear(params.dense_output_size, 3 * params.max_list_len * params.integer_range)
        self.int_head = nn.Linear(params.dense_output_size, 3 * 1 * params.integer_range)

    def forward(self, x, typ, hard_softmax=False):
        """Build the query distribution and its arg-max indices.

        Arguments:
            x: encoder input, passed through to ``self.encoder`` as is.
            typ: 4-D type tensor. Its first two channels on the last axis
                are given to the encoder, and ``typ[:, 0, :-1]`` must have
                shape ``(batch, 3, 3)``: its last axis weights the
                (int, list, null) candidates of each slot.
                NOTE(review): presumably a one-hot type encoding of the
                three inputs of the first example, with the last entry on
                axis 2 being the output -- confirm against the caller.
            hard_softmax (bool): if True, return a one-hot tensor in the
                forward pass while keeping the gradient of the soft
                distribution (straight-through estimator).

        Returns:
            tuple ``(query, index)`` where ``query`` has shape
            ``(batch, 1, 3, params.max_list_len, params.integer_range + 1)``
            and ``index`` holds the arg-max over the last axis of ``query``
            with shape ``(batch, 1, 3, params.max_list_len)``.
        """
        x = self.encoder(x, typ[:, :, :, :2])
        batch = x.shape[0]
        # Allocate every helper tensor next to the activations instead of on
        # a hard-coded 'cuda' device, so the model also runs on CPU and on
        # non-default GPUs.
        device = x.device

        # List candidate: one distribution over integers per list position,
        # with zero probability appended for the null class.
        list_ = self.list_head(x)
        list_ = list_.view(batch, 3, params.max_list_len, params.integer_range)
        list_ = torch.softmax(list_, -1)
        list_ = torch.cat([list_, torch.zeros(*list_.shape[:-1], 1, device=device)], -1)

        # Int candidate: a single distribution stored at position 0; the
        # remaining positions get zero mass on every integer ...
        int_ = self.int_head(x)
        int_ = int_.view(batch, 3, 1, params.integer_range)
        int_ = torch.softmax(int_, -1)
        int_ = torch.cat([int_, torch.zeros(*int_.shape[:2], params.max_list_len - 1, params.integer_range, device=device)], -2)
        # ... and full mass on the null class, except position 0 which
        # carries the actual integer distribution.
        int_ = torch.cat([int_, torch.ones(*int_.shape[:-1], 1, device=device)], -1)
        int_[:, :, 0, -1] = 0.

        # Null candidate: all mass on the null class at every position.
        null_ = torch.zeros(*list_.shape, device=device)
        null_[:, :, :, -1] = 1.

        # Flatten each candidate to (batch, 3, 1, positions * classes) and
        # stack them in the order (int, list, null) on axis 2.
        list_ = list_.view(batch, 3, -1).unsqueeze(2)
        int_ = int_.view(batch, 3, -1).unsqueeze(2)
        null_ = null_.view(batch, 3, -1).unsqueeze(2)
        choice = torch.cat([int_, list_, null_], 2)
        # input type in 1 example: weight the three candidates of each slot
        # by the type tensor, (batch, 3, 1, 3) @ (batch, 3, 3, D).
        x = typ[:, 0, :-1].unsqueeze(2).float() @ choice
        x = x.view(x.shape[0], 3, params.max_list_len, params.integer_range + 1)

        index = x.max(-1, keepdim=True)[1]

        if hard_softmax:
            # Straight-through: one-hot values forward, soft gradients back.
            x_hard = torch.zeros_like(x, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
            x_hard = (x_hard - x).detach() + x
            x = x_hard

        x = x.unsqueeze(1)
        return x, index.squeeze(-1).unsqueeze(1)

class Query2(BaseModel):
    """Noisy bottleneck auto-encoder that emits a one-hot query tensor.

    Embedded examples are flattened, squeezed through an MLP encoder of
    width ``K``, perturbed with unit Gaussian noise, decoded, averaged over
    all examples and projected to ``3 * 20 * 514`` logits. The output is
    discretised with a straight-through arg-max.

    Arguments:
        E (int): embedding dimension. The first encoder layer expects
            ``E * 20 * 4`` features per example.
            NOTE(review): 20 and 4 are hard-coded; presumably the list
            length and the number of values per example -- confirm.
        M (int): accepted for interface compatibility; not used.
        emb: accepted for interface compatibility; not used. The embedding
            module is supplied to ``forward`` instead.
    """

    def __init__(self, E=2, M=5, emb=None):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(E*20*4, K),
            nn.LayerNorm(K),
            nn.Sigmoid(),
            nn.Linear(K, K),
            nn.LayerNorm(K),
            nn.Sigmoid(),
            nn.Linear(K, K),
        )
        self.decoder = nn.Sequential(
            nn.Linear(K, K),
            nn.LayerNorm(K),
            nn.Sigmoid(),
            nn.Linear(K, K),
            nn.LayerNorm(K),
            nn.Sigmoid(),
            nn.Linear(K, K),
        )
        self.out = nn.Sequential(nn.Linear(K, 3 * 20 * 514))

    def forward(self, typ, x, embed):
        """Encode the examples and decode a one-hot query.

        Arguments:
            typ: not used by this model.
            x: input handed to ``embed``.
            embed: callable whose result is flattened to
                ``(dim0, dim1, E * 20 * 4)`` before the encoder.

        Returns:
            tuple ``(query, index)``. ``query`` has shape
            ``(batch, 1, 3, 20, 514)`` and is one-hot in the forward pass
            while carrying the softmax gradient; ``index`` has shape
            ``(batch, 1, 3, 20)`` and holds the arg-max class ids.
        """
        emb = embed(x)
        # Keep the two leading axes and merge everything after them into a
        # single feature axis.
        x = emb.view(*emb.shape[:2], -1)
        x = self.encoder(x)

        # Sample on the CPU generator and then follow the activations'
        # device, so the model is not tied to CUDA while seeded draws stay
        # the same as before on GPU.
        # NOTE(review): noise is injected regardless of train/eval mode.
        noise = torch.randn(x.size()).to(x.device)
        x += noise

        x = self.decoder(x)
        # Average the decoded codes over every example of a batch element.
        x = x.view(x.shape[0], -1, K)
        x = x.mean(-2)

        x = self.out(x)
        x = x.view(x.shape[0], 3, 20, 514)
        x = torch.softmax(x, -1)
        index = x.max(-1, keepdim=True)[1]
        # Straight-through: one-hot values forward, soft gradients back.
        x_hard = torch.zeros_like(x, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        x_hard = (x_hard - x).detach() + x
        x_hard = x_hard.unsqueeze(1)

        return x_hard, index.squeeze(-1).unsqueeze(1)


class BN1d(nn.Module):
    """Batch normalisation over the last axis of an arbitrary-rank tensor.

    ``nn.BatchNorm1d`` expects the channel axis in position 1, so the
    input is collapsed to ``(batch, -1, dim)``, the last two axes are
    swapped for the normalisation and then swapped back, and the original
    shape is restored.
    """

    def __init__(self, dim):
        super().__init__()
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, x):
        """Return ``x`` normalised per feature of its last axis, same shape."""
        original_shape = x.shape
        batch, features = original_shape[0], original_shape[-1]
        # (batch, ..., features) -> (batch, N, features) -> (batch, features, N)
        channels_first = x.view(batch, -1, features).transpose(1, 2)
        normalised = self.bn(channels_first)
        # Undo the axis swap and the flattening.
        return normalised.transpose(1, 2).view(original_shape)

class my_round(torch.autograd.Function):
    """Round to the nearest integer with a straight-through gradient.

    The forward pass applies ``torch.round``; the backward pass returns the
    incoming gradient unchanged, as if rounding were the identity.
    Use as ``my_round.apply(tensor)``.
    """

    @staticmethod
    def forward(ctx, input):
        # Nothing is stored on ``ctx``: backward does not need the input,
        # and keeping a reference would only hold the tensor in memory for
        # as long as the graph lives.
        return torch.round(input)

    @staticmethod
    def backward(ctx, g):
        # Pass the upstream gradient straight through.
        return g
