from .layers.transformer import *
from .layers.improved_transformer import *
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from tqdm import tqdm

from model.diffloss import DiffLoss

# Command-type ids. The first PIX_PAD channels of a command tensor hold a
# one-hot over these types:
#   0: SVG END
#   1: MOVE
#   2: LINE
#   3: CURVE
#   4: PAD
SVG_END, MOVE, LINE, CURVE, PIX_PAD = range(5)

# Per-class weights for the command-type cross-entropy, ordered as
# (SVG_END, MOVE, LINE, CURVE).
command_weight = torch.tensor((50, 6, 2, 1))

# Width of one command token: PIX_PAD command-type channels followed by
# CMD_TENSOR_DIM - PIX_PAD continuous coordinate channels.
CMD_TENSOR_DIM = 12

# FIGR-SVG-svgo
# NOTE(review): presumably the canvas size of the FIGR-SVG-svgo dataset;
# not referenced in this section — confirm.
BBOX = 200


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Filtering is applied along the last dimension, so ``logits`` may be a single
    distribution of shape ``(vocab_size,)`` or a batch of shape
    ``(..., vocab_size)``.

    Args:
        logits: tensor of logits whose last dimension indexes the vocabulary.
            It is modified in place (and also returned).
        top_k: if > 0, keep only logits >= the k-th largest one (top-k filtering).
        top_p: if > 0.0, keep the highest-probability tokens up to and including
            the first one whose cumulative probability exceeds ``top_p``
            (nucleus filtering, Holtzman et al., http://arxiv.org/abs/1904.09751).
            The most probable token is always kept.
        filter_value: value written into the filtered positions.

    Returns:
        ``logits`` (the same tensor object) with filtered positions set to
        ``filter_value``.

    Raises:
        ValueError: if ``logits`` is a 0-dim tensor.
    """
    if logits.dim() == 0:
        raise ValueError("logits must have at least one dimension")

    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove every logit strictly below the k-th largest of its row.
        kth_largest = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold ...
        sorted_to_remove = cumulative_probs > top_p
        # ... shifted right so the first token crossing the threshold is kept,
        # and never remove the most probable token.
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order, row by row.
        to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
        logits[to_remove] = filter_value
    return logits


class Embedder(nn.Module):
    """Thin wrapper around ``nn.Embedding`` mapping token ids to ``d_model`` vectors.

    Args:
        vocab_size: number of distinct token ids.
        d_model: embedding width.
        padding_idx: optional id whose embedding row is the padding row
            (see ``nn.Embedding``).
    """

    def __init__(self, vocab_size, d_model, padding_idx=None):
        super(Embedder, self).__init__()
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            padding_idx=padding_idx,
        )

    def forward(self, x):
        """Look up the embedding of every id in ``x`` (output adds a trailing ``d_model`` dim)."""
        embedded = self.embed(x)
        return embedded


class PositionalEncoding(nn.Module):
    """Learned absolute positional embedding added to a sequence-first input.

    Args:
        d_model: embedding width (must match the last dim of the input).
        dropout: dropout probability applied after the addition.
        max_len: number of positions that have an embedding row.
    """

    def __init__(self, d_model, dropout=0.1, max_len=250):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Column vector of position ids, shape [max_len, 1]; the trailing 1 lets
        # the looked-up embeddings broadcast over the batch dimension.
        self.register_buffer('position', torch.arange(max_len, dtype=torch.long)[:, None])
        self.pos_embed = nn.Embedding(max_len, d_model)
        self._init_embeddings()

    def _init_embeddings(self):
        """Re-initialise the positional table with Kaiming-normal (fan_in)."""
        nn.init.kaiming_normal_(tensor=self.pos_embed.weight, mode="fan_in")

    def forward(self, x):
        """Add the embedding of positions 0..x.size(0)-1 to ``x`` and apply dropout.

        ``x`` is indexed along dim 0 as the sequence axis; the [seq, 1, d_model]
        embeddings broadcast over the remaining dims (assumes x is
        [seq, batch, d_model] — TODO confirm for other callers).
        """
        seq_len = x.size(0)
        positional = self.pos_embed(self.position[:seq_len])
        return self.dropout(x + positional)


class SVGDecoder(nn.Module):
    """
    Autoregressive generative model for SVG command sequences conditioned on text.

    The decoder input is [zero start token, text tokens, command tokens]. A causal
    transformer decoder produces one hidden state per position. From it, an MLP
    head predicts the next command type (PIX_PAD classes, with an extra boost on
    the SVG_END logit from ``eos_mlp``), and a ``DiffLoss`` head models the
    continuous coordinate channels.
    """

    def __init__(self,
                 config,
                 command_len,
                 text_len,
                 num_text_token,
                 word_emb_path=None,
                 pos_emb_path=None,
                 diffloss_d=3,
                 diffloss_w=256,
                 num_sampling_steps='100',
                 diffusion_batch_mul = 4,
                 eos_alpha = 1,
                 length_loss_weight = 1,
                 grad_checkpointing=False,):
        """
        Initializes SVGDecoder.

        Args:
            config: dict with keys 'embed_dim', 'hidden_dim', 'num_heads',
                'dropout_rate' and 'num_layers' for the transformer decoder.
            command_len: maximum number of command tokens.
            text_len: number of text tokens the model is conditioned on.
            num_text_token: text vocabulary size (stored; not used in this class).
            word_emb_path: required path to a tensor loaded with ``torch.load``
                and used as a frozen pretrained text embedding. Its width must
                equal config['embed_dim'].
            pos_emb_path: optional path to a pretrained text positional embedding
                (loaded into ``self.text_pos_emb``; not used by ``forward``).
            diffloss_d, diffloss_w, num_sampling_steps, grad_checkpointing:
                depth, width, number of sampling steps and checkpointing flag
                passed to ``DiffLoss``.
            diffusion_batch_mul: how many times the diffusion targets and
                conditions are repeated when computing the diffusion loss.
            eos_alpha: scale of the ``eos_mlp`` boost added to the SVG_END logit.
            length_loss_weight: weight of the expected-length loss (0 disables it).
        """
        super(SVGDecoder, self).__init__()
        self.command_len = command_len
        self.embed_dim = config['embed_dim']

        self.text_len = text_len
        self.num_text_token = num_text_token
        self.num_command_token = CMD_TENSOR_DIM
        # Number of continuous (coordinate) channels after the command-type one-hot.
        self.loc_dim = CMD_TENSOR_DIM - PIX_PAD
        # self.total_token = self.num_command_token
        self.total_seq_len = text_len + command_len
        self.loss_pix_weight = 100
        self.length_loss_weight = length_loss_weight

        seq_range = torch.arange(self.total_seq_len)
        token_range = torch.arange(self.num_command_token)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        token_range = rearrange(token_range, 'd -> () () d')

        # True at every text position, for every channel (token_range > -1 is
        # always true). Kept as a non-persistent buffer; not read by the methods
        # of this class.
        output_mask = ((seq_range < text_len) & (token_range > -1))
        #
        self.register_buffer('output_mask', output_mask, persistent=False)
        # SVG and Text encoders

        self.command_embed = nn.Linear(self.num_command_token, self.embed_dim)
        self.pos_embed = PositionalEncoding(max_len=self.total_seq_len, d_model=self.embed_dim)
        self.softmax_layer_1 = nn.Linear(self.embed_dim,self.embed_dim)
        self.softmax_layer_2 = nn.Linear(self.embed_dim,self.embed_dim)

        self.eos_mlp = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(self.embed_dim * 2, 1)
        )
        self.eos_alpha = eos_alpha
        self.eos_token_idx = 0

        self.output_cmd = nn.Linear(self.embed_dim, PIX_PAD)


        decoder_layers = TransformerDecoderLayerImproved(d_model=self.embed_dim,
                                                         dim_feedforward=config['hidden_dim'],
                                                         nhead=config['num_heads'], dropout=config['dropout_rate'])
        decoder_norm = LayerNorm(self.embed_dim)
        self.decoder = TransformerDecoder(decoder_layers, config['num_layers'], decoder_norm)

        # NOTE: torch.load unpickles the file; only load trusted embedding files.
        assert word_emb_path is not None, 'text_emb_dir must be provided'
        if word_emb_path is not None:
            self.text_emb = nn.Embedding.from_pretrained(torch.load(word_emb_path, map_location='cpu'))
            assert self.embed_dim == self.text_emb.weight.shape[1], 'emb_dim must match pretrained text embedded dim'
        if pos_emb_path is not None:
            self.text_pos_emb = nn.Embedding.from_pretrained(torch.load(pos_emb_path, map_location='cpu'))
            assert self.embed_dim == self.text_pos_emb.weight.shape[
                1], 'emb_dim must match pretrained text embedded dim'

        self.diffloss = DiffLoss(
            target_channels=self.loc_dim,
            z_channels=config['embed_dim'],
            width=diffloss_w,
            depth=diffloss_d,
            num_sampling_steps=num_sampling_steps,
            grad_checkpointing=grad_checkpointing
        )
        self.diffusion_batch_mul = diffusion_batch_mul

    def forward(self, command, mask, text, diffusion_mask, return_loss=False):
        '''
        Run the decoder over [start token, text tokens, command tokens].

        Args:
            command: [batch_size, max_len, CMD_TENSOR_DIM] commands. The first
                PIX_PAD channels are the command-type one-hot; the rest are the
                continuous coordinates. At the first sampling step it may instead
                be a list whose first element is None, meaning a text-only prefix.
            mask: [batch_size, max_len] bool padding mask (True = padded
                position, excluded from the losses), or None.
            text: [batch_size, text_len] text token ids.
            diffusion_mask: unused; the diffusion mask is derived from ``mask``.
            return_loss: if True, teacher-force on command[:, :-1] and return
                the losses.

        Returns:
            If return_loss: (loss, command_CE, diffusion_loss, length_loss).
            Otherwise: [batch_size, PIX_PAD + loc_dim]. The first PIX_PAD columns
            are the command-type probabilities predicted at the last position;
            the remaining columns are the coordinates sampled from the
            diffusion head.
        '''
        command_v = command[:, :-1] if return_loss else command
        command_mask = mask[:, :-1] if return_loss else mask

        c_bs, c_seqlen, device = text.shape[0], text.shape[1], text.device
        has_commands = command_v[0] is not None
        if has_commands:
            c_seqlen += command_v.shape[1]

        # All-zero start token prepended before the text, [1, bs, dim].
        context_embedding = torch.zeros((1, c_bs, self.embed_dim), device=device)

        # tokens.shape [batch_size, text_len, emb_dim]
        tokens = self.text_emb(text)
        if has_commands:
            # command_embed.shape [batch_size, num_commands, emb_dim]
            command_embed = self.command_embed(command_v)
            tokens = torch.cat((tokens, command_embed), dim=1)

        # embeddings.shape [c_seqlen + 1, batch_size, emb_dim] (sequence-first)
        embeddings = torch.cat([context_embedding, tokens.transpose(0, 1)], dim=0)
        decoder_inputs = self.pos_embed(embeddings)

        # Decoder-only use: cross-attention sees a single all-zero memory vector.
        memory_encode = torch.zeros((1, c_bs, self.embed_dim), device=device)

        # Causal mask [c_seqlen + 1, c_seqlen + 1], -inf above the diagonal.
        nopeak_mask = torch.nn.Transformer.generate_square_subsequent_mask(c_seqlen + 1).to(device)
        if command_mask is not None:
            # Start token and text positions are never padded.
            prefix_mask = torch.zeros(
                [c_bs, context_embedding.shape[0] + self.text_len], dtype=torch.bool, device=device)
            command_mask = torch.cat([prefix_mask, command_mask], dim=1)
        decoder_out = self.decoder(tgt=decoder_inputs, memory=memory_encode, memory_key_padding_mask=None,
                                   tgt_mask=nopeak_mask, tgt_key_padding_mask=command_mask)

        # Command-type head: two ReLU hidden layers, then PIX_PAD logits.
        cmd_hidden = F.relu(self.softmax_layer_1(decoder_out))
        cmd_hidden = F.relu(self.softmax_layer_2(cmd_hidden))
        output_cmd = self.output_cmd(cmd_hidden).transpose(1, 0)  # [bs, c_seqlen + 1, PIX_PAD]

        # Conditioning for the EOS MLP and the diffusion head: the decoder output
        # with the positional embedding (and its dropout) applied once more.
        z = self.pos_embed(decoder_out).transpose(1, 0)  # [bs, c_seqlen + 1, dim]

        # Boost the SVG_END (eos_token_idx) logit by a learned, eos_alpha-scaled term.
        eos_enhance = self.eos_mlp(z).squeeze(-1) * self.eos_alpha
        output_cmd[:, :, self.eos_token_idx] += eos_enhance

        # Command-type probabilities, [bs, c_seqlen + 1, PIX_PAD].
        outputs_discrete = F.softmax(output_cmd, dim=2)

        if not return_loss:
            # Sampling: only the prediction at the last position is needed.
            sampled_loc = self.diffloss.sample(z[:, -1, :])
            return torch.cat([outputs_discrete[:, -1, :], sampled_loc], dim=1)

        bsz, seq_len, _ = command.shape

        # Diffusion loss on the coordinates. The state at position text_len + i
        # (i.e. after the last text token and i previous commands) predicts
        # command i. Conditions and targets are repeated diffusion_batch_mul times.
        loc_target = command[:, :, PIX_PAD:].reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z_tokens = z[:, self.text_len:, :]
        z_flat = z_tokens.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        diffusion_mask = (~mask).reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        diffusion_loss = self.diffloss(z=z_flat, target=loc_target, mask=diffusion_mask)

        # Per-command logits / probabilities, both [bs, seq_len, PIX_PAD].
        cmd_logits = output_cmd[:, self.text_len:, :]
        cmd_probs = outputs_discrete[:, self.text_len:, :]

        if self.length_loss_weight > 0:
            # Differentiable expected length:
            # P(end at t) = p_eos(t) * prod_{s<t} (1 - p_eos(s)).
            eos_probs = cmd_probs[:, :, self.eos_token_idx]  # [bs, seq_len]
            target_length = (~mask).sum(dim=1).float()
            survival_probs = torch.cumprod(1.0 - eos_probs + 1e-8, dim=1)
            survival_shifted = torch.cat(
                [torch.ones(c_bs, 1, device=device), survival_probs[:, :-1]],
                dim=1
            )
            ending_probs = eos_probs * survival_shifted

            positions = torch.arange(1, seq_len + 1, device=device, dtype=torch.float32).unsqueeze(0)
            expected_lengths = torch.sum(positions * ending_probs, dim=1)
            length_loss = F.mse_loss(expected_lengths, target_length, reduction='mean')
        else:
            length_loss = torch.tensor(0.0, device=device)

        # Weighted cross-entropy over non-padded positions. F.cross_entropy
        # applies log_softmax internally, so it must receive the raw logits;
        # passing the softmax probabilities would apply softmax twice and
        # squash the gradients.
        valid = ~mask.reshape(-1)
        command_logits = cmd_logits.reshape(-1, PIX_PAD)
        command_target = command[:, :, :PIX_PAD].reshape(-1, PIX_PAD)
        command_CE = F.cross_entropy(command_logits[valid], command_target[valid],
                                     weight=command_weight.to(device))

        loss = (self.loss_pix_weight * diffusion_loss + command_CE + length_loss * self.length_loss_weight)
        return loss, command_CE, diffusion_loss, length_loss

    def sample(self, n_samples, text, command_seq = None) :
        """Greedily generate command sequences conditioned on ``text``.

        Each step runs ``forward`` on the current prefix. The argmax command type
        becomes a one-hot, and coordinates are clipped to [-1, 1]. The command
        type is then overridden from the predicted coordinates:
          - all 8 coordinates < -0.98       -> SVG_END
          - else exactly 6 below -0.99      -> LINE
          - else exactly 2 below -0.99      -> CURVE
        (presumably padded coordinates are encoded as -1 — confirm against the
        data pipeline). Sequences that emit SVG_END leave the batch together
        with their text prompt and are collected.

        Args:
            n_samples: only sizes the placeholder prefix used at the first step;
                the batch size comes from ``text``.
            text: [batch, >= text_len] token ids; truncated to text_len.
            command_seq: optional [batch, k, CMD_TENSOR_DIM] command prefix
                (on the same device as ``text``) to continue from.

        Returns:
            List of CPU tensors, one for each step at which at least one sequence
            emitted SVG_END; each is [n_finished, length, CMD_TENSOR_DIM].
            Sequences still unfinished at total_seq_len are not returned.
        """
        results = []

        text = text[:, :self.text_len]
        device = text.device
        command_len = 0 if command_seq is None else command_seq.shape[1]
        for k in tqdm(range(text.shape[1] + command_len, self.total_seq_len)):
            if k == text.shape[1]:
                # No command tokens yet: forward() treats a prefix whose first
                # element is None as "text only".
                command_seq = [None,None,None,None,None,None,None,None,None,None,None,None,None] * n_samples

            # pass through model
            with torch.no_grad():
                command_pred = self.forward(command_seq, None, text, None)

            # Greedy (argmax) choice of the next command for every sequence.
            next_commands = []
            for logit in command_pred:
                probs = logit[:PIX_PAD]
                coords = logit[PIX_PAD:]
                best = torch.argmax(probs)
                next_command = torch.zeros_like(probs)
                next_command[best] = 1
                if (coords < -0.98).sum() == 8:
                    next_command[best] = 0
                    next_command[SVG_END] = 1
                elif (coords < -0.99).sum() == 6:
                    # NOTE(review): this also turns a predicted MOVE into LINE;
                    # confirm that is intended.
                    if best != LINE:
                        next_command[best] = 0
                        next_command[LINE] = 1
                elif (coords < -0.99).sum() == 2:
                    next_command[best] = 0
                    next_command[CURVE] = 1
                next_commands.append(torch.cat((next_command, torch.clip(coords, min=-1, max=1)), dim=0))

            # Add next tokens: [batch, 1, CMD_TENSOR_DIM] on the model's device.
            next_command_seq = torch.stack(next_commands).unsqueeze(1).float().to(device)
            if command_seq[0] is None:
                command_seq = next_command_seq
            else:
                command_seq = torch.cat([command_seq, next_command_seq], 1)

            # Early stopping: move finished sequences to the results and drop them
            # (and their own text prompts) from the running batch.
            check_end = next_command_seq[:, 0, SVG_END] == 1
            if check_end.sum() > 0:
                print('%d samples are finished' % (check_end.sum()))
                results.append(command_seq[check_end].cpu().detach())
                command_seq = command_seq[~check_end]
                text = text[~check_end]

            if len(command_seq) == 0:
                break
        return results


