import torch
from torch.nn import Linear, TransformerEncoderLayer, LayerNorm
from torchvision.ops import MLP
import numpy as np
import time

from torch.nn.functional import softmax
from torch.distributions import Categorical

class PottsDecoder(torch.nn.Module):

    def __init__(self, q, n_layers, d_model, input_encoding_dim, param_embed_dim, n_heads, n_param_heads, dropout=0.0):

        super().__init__()
        self.q = q
        self.n_layers = n_layers
        self.d_model = d_model
        self.input_encoding_dim = input_encoding_dim
        self.param_embed_dim = param_embed_dim
        self.n_heads = n_heads
        self.n_param_heads = n_param_heads
        self.dropout = dropout

        self.input_MLP = Linear(self.input_encoding_dim, self.d_model)


        self.attention_layers = torch.nn.ModuleList([])
        self.relu = torch.nn.ReLU()
        for _ in range(n_layers):
            attention_layer = TransformerEncoderLayer(self.d_model, self.n_heads,
                                                      dropout=self.dropout, batch_first=True)
            self.attention_layers.append(attention_layer)


        
        self.P = Linear(self.d_model, self.n_param_heads*self.d_model, bias=False)   ## this uses a more sensible initialization
        self.output_linear = Linear(self.d_model, self.q)

        self.field_linear = Linear(self.q, self.q)

    def _get_params(self, param_embeddings, N, padding_mask):
        
        padding_mask_inv = (~padding_mask)

        # set embeddings to zero where padding is present
        param_embeddings = param_embeddings * padding_mask_inv.unsqueeze(1).unsqueeze(3)

        # get fields ---> here I sum over K!
        fields = torch.sum(self.field_linear(param_embeddings), dim=1)

        # set fields to 0 depending on the padding
        fields = fields * padding_mask_inv.unsqueeze(2)

        # flatten fields
        fields = fields.view(-1, N * self.q)

        # flatten to (B, n_param_heads, N*q)
        param_embeddings = param_embeddings.flatten(start_dim=2, end_dim=3)

        # outer to (B, N*q, N*q)
        couplings = torch.einsum('bpi, bpj -> bij', (param_embeddings, param_embeddings))

        # create mask for couplings
        t = torch.ones(self.q, self.q)
        mask_couplings = (1 - torch.block_diag(*([t] * N))).to(couplings.device)
        mask_couplings.requires_grad = False

        couplings = couplings * mask_couplings

        return couplings/np.sqrt(self.n_param_heads), fields/np.sqrt(self.n_param_heads)
    
    def forward(self, encodings, padding_mask):

        B, N, _ = encodings.shape

        assert B == padding_mask.shape[0]
        assert N == padding_mask.shape[1]
        
        #with profiler.record_function("Embeddings"):
        embeddings = self.input_MLP(encodings)
        #with profiler.record_function("Attention Layers"):
        for attention_layer in self.attention_layers:
            embeddings = attention_layer(embeddings, src_key_padding_mask=padding_mask)
            embeddings = self.relu(embeddings)

        param_embeddings = torch.transpose(self.P(embeddings).reshape(B, N, self.n_param_heads, self.d_model), 1, 2)

        # apply relu
        param_embeddings = self.relu(param_embeddings)

        # (B, n_param_heads, N, q)
        param_embeddings = self.output_linear(param_embeddings)
        couplings, fields = self._get_params(param_embeddings, N, padding_mask)

        return couplings, fields
    
    def _get_params_new(self, param_embeddings, N, padding_mask):
        
        padding_mask_inv = (~padding_mask)

        # set embeddings to zero where padding is present
        param_embeddings = param_embeddings * padding_mask_inv.unsqueeze(1).unsqueeze(3) 

        # get fields
        fields = torch.sum(self.field_linear(param_embeddings), dim=1) *  self.n_param_heads**(-1/2)

        # set fields to 0 depending on the padding
        fields = fields * padding_mask_inv.unsqueeze(2)

        ## To normalize later computations
        param_embeddings = param_embeddings * self.n_param_heads**(-1/4)



        return param_embeddings, fields


    
    def forward_new(self, encodings, padding_mask):

        B, N, _ = encodings.shape

        assert B == padding_mask.shape[0]
        assert N == padding_mask.shape[1]

        embeddings = self.input_MLP(encodings)
        for attention_layer in self.attention_layers:
            embeddings = attention_layer(embeddings, src_key_padding_mask=padding_mask)
            embeddings = self.relu(embeddings)

        param_embeddings = torch.transpose(self.P(embeddings).reshape(B, N, self.n_param_heads, self.d_model), 1, 2) #@ embeddings.unsqueeze(1).unsqueeze(4)
        # (1, n_param_heads, 1, d_model, d_model) x (B, 1, N, d_model, 1) -> (B, n_param_heads, N, d_model)
        param_embeddings = self.relu(param_embeddings)

        # (B, n_param_heads, N, q)
        param_embeddings = self.output_linear(param_embeddings)
        param_embeddings, fields = self._get_params_new(param_embeddings, N, padding_mask)

        return param_embeddings, fields

    
    def _get_params_ardca(self, param_embeddings, N, padding_mask):
        
        padding_mask_inv = (~padding_mask)

        # set embeddings to zero where padding is present
        param_embeddings = param_embeddings * padding_mask_inv.unsqueeze(1).unsqueeze(3)

        # get fields ---> here I sum over K!
        fields = torch.sum(self.field_linear(param_embeddings), dim=1)

        # set fields to 0 depending on the padding
        fields = fields * padding_mask_inv.unsqueeze(2)

        # flatten fields
        fields = fields.view(-1, N * self.q)

        # flatten to (B, n_param_heads, N*q)
        param_embeddings = param_embeddings.flatten(start_dim=2, end_dim=3)

        # outer to (B, N*q, N*q)
        couplings = torch.einsum('bpi, bpj -> bij', (param_embeddings, param_embeddings))

        # create mask for couplings
        t = torch.ones(self.q, self.q)
        mask_couplings = (1 - torch.block_diag(*([t] * N))).to(couplings.device)
        mask_couplings.requires_grad = False

        couplings = couplings * mask_couplings

        #### We keen only lower triangular since we want to do arDCA
        couplings = torch.tril(couplings)

        return couplings/np.sqrt(self.n_param_heads), fields/np.sqrt(self.n_param_heads)
    
    def forward_ardca(self, encodings, padding_mask):

        B, N, _ = encodings.shape

        assert B == padding_mask.shape[0]
        assert N == padding_mask.shape[1]
        
        embeddings = self.input_MLP(encodings)
        for attention_layer in self.attention_layers:
            embeddings = attention_layer(embeddings, src_key_padding_mask=padding_mask)
            embeddings = self.relu(embeddings)

        param_embeddings = torch.transpose(self.P(embeddings).reshape(B, N, self.n_param_heads, self.d_model), 1, 2)

        # apply relu
        param_embeddings = self.relu(param_embeddings)

        # (B, n_param_heads, N, q)
        param_embeddings = self.output_linear(param_embeddings)
        #with profiler.record_function("Get params"):
        couplings, fields = self._get_params_ardca(param_embeddings, N, padding_mask)

        return couplings, fields
    
    def forward_ardca_scaled(self, encodings, padding_mask):

        B, N, _ = encodings.shape

        assert B == padding_mask.shape[0]
        assert N == padding_mask.shape[1]
        
        embeddings = self.input_MLP(encodings)
        for attention_layer in self.attention_layers:
            embeddings = attention_layer(embeddings, src_key_padding_mask=padding_mask)
            embeddings = self.relu(embeddings)

        param_embeddings = torch.transpose(self.P(embeddings).reshape(B, N, self.n_param_heads, self.d_model), 1, 2)

        # apply relu
        param_embeddings = self.relu(param_embeddings)

        # (B, n_param_heads, N, q)
        param_embeddings = self.output_linear(param_embeddings)
        #with profiler.record_function("Get params"):
        couplings, fields = self._get_params_ardca(param_embeddings, N, padding_mask)

        aux1 = torch.tensor(np.arange(N), dtype=torch.float, device=couplings.device).reshape(N,1)
        aux1[0] = 1.0
        aux1 = torch.matmul(aux1, torch.ones(1,self.q))
        aux1 = torch.matmul(aux1.reshape(N*self.q,1), torch.ones(1,N*self.q))
        aux1=torch.einsum('i,jk->ijk', torch.ones(B), aux1)
        ### AUX1 SHOULD BE [B, Nq, Nq]
        ### Scale the couplings
        couplings = couplings/aux1
        
        return couplings, fields

    def sample_ardca_vect(self, fields, couplings, N, n_samples=100, q=21, rec_times=False, device='cpu'):
        """ This function currently works for a single sample"""
        ### Assume no batch-dimension
        fields = fields.reshape(N, q).to(device)
        couplings = couplings.reshape((N, q, N*q)).to(device)
        samples = torch.zeros((n_samples, N), requires_grad=False, dtype=torch.int).to(device)
        selected = torch.zeros((N*q,n_samples), requires_grad=False).to(device)
        p_pos = softmax(-fields[0], dim=0)

        val = Categorical(p_pos).sample((n_samples,))
        samples[:,0] = val
        selected[val, 0] = 1

        if rec_times:
            times = []

        #Ham = torch.zeros((n_samples, q), requires_grad=False)
        Ham = torch.zeros((q), requires_grad=False).to(device)
        for sam in range(n_samples):
            start = time.time()
            for pos in range(1,N):
                Ham = fields[pos, :] + torch.matmul(couplings[pos,:,:], selected[:, sam]).squeeze()
                p_pos = softmax(-Ham, dim=0)
                val = Categorical(p_pos).sample()
                samples[sam, pos] = val.item()
                selected[q*pos+val.item(), sam] = 1
            if rec_times:
                times.append(time.time()-start)
        return samples, times
    
    def samples_ardca_vect_batch(self, fields, couplings, N, n_samples=1000, q=21, rec_times=False, device='cpu'):
        """ These function currently work for a single structure at the time"""
        fields = fields.reshape(N, q).to(device)
        couplings = couplings.reshape((N, q, N*q)).to(device)

        #def sample_ardca_new(fields, couplings, n_samples=100)
        samples = torch.zeros((n_samples, N), requires_grad=False, dtype=torch.int).to(device)
        selected = torch.zeros((N*q,n_samples), requires_grad=False).to(device)


        p_pos = softmax(-fields[0], dim=0)

        val = Categorical(p_pos).sample((n_samples,))
        samples[:,0] = val
        selected[val, range(n_samples)] = 1
        Ham = torch.zeros((n_samples, q), requires_grad=False).to(device)
        #for sam in range(n_samples):
        with torch.no_grad():
            start = time.time()
            for pos in range(1,N):
                Ham = fields[pos, :].unsqueeze(-1) + torch.matmul(couplings[pos,:,:], selected).squeeze()
                p_pos = softmax(-Ham, dim=0)
                #p_pos[:] = softmax(Ham, dim=0)
                val = Categorical(torch.transpose(p_pos,0,1)).sample()
                samples[:, pos] = val
                selected[q*pos+val, range(n_samples)] = 1
        if rec_times:
            return samples, time.time()-start

        return samples
    
    

    
    
    