import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from models.resnets import resnet18, resnet50
import numpy as np

from models.model_utils import batched_pooling
import torch.nn.functional as F


def rescale(values, current_min, current_max, new_min, new_max):
    """Linearly map ``values`` from one interval onto another.

    ``values`` is converted with ``.float()`` first, so the result is a
    floating-point tensor. ``current_min`` maps to ``new_min`` and
    ``current_max`` maps to ``new_max``; nothing is clamped.
    """
    source_span = current_max - current_min
    target_span = new_max - new_min
    offset_from_min = values.float() - current_min
    return offset_from_min * target_span / source_span + new_min


def top_k_logits(logits, k):
    """Mask logits so that only the ``k`` largest per row stay competitive.

    Args:
        logits: tensor whose last dimension indexes the candidates.
        k: number of candidates to keep. ``0`` disables the filter and the
            input tensor is returned as-is. Values larger than the size of
            the last dimension are clamped, which keeps every candidate.

    Returns:
        A tensor shaped like ``logits`` in which every entry strictly smaller
        than the k-th largest entry of its row is replaced by ``-1e9``. The
        input tensor is not modified.
    """
    if k == 0:
        return logits

    # torch.topk raises when k exceeds the dimension size; clamp instead.
    k = min(k, logits.shape[-1])
    top_values, _ = torch.topk(logits, k=k, dim=-1)
    kth_largest = torch.min(top_values, dim=-1, keepdim=True)[0]
    # Out-of-place so the caller's tensor is left untouched.
    return logits.masked_fill(logits < kth_largest, -1e9)


def top_p_logits(logits, p):
    """Mask logits using nucleus (top-p) sampling.

    Candidates are sorted by probability; every candidate whose cumulative
    probability (including itself) exceeds ``p`` has ``1e9`` subtracted from
    its logit. The most probable candidate is never masked, so at least one
    candidate always survives.

    Args:
        logits: 3-D tensor; the last dimension indexes the candidates.
        p: cumulative-probability threshold. ``1`` disables the filter and
            the input tensor is returned as-is.

    Returns:
        A new tensor of shape ``[-1, logits.shape[1], logits.shape[2]]``. The
        input tensor is not modified.
    """
    if p == 1:
        return logits

    seq, dim = logits.shape[1], logits.shape[2]
    flat_logits = logits.reshape(-1, dim)

    sort_indices = torch.argsort(flat_logits, dim=1, descending=True)
    sorted_probs = torch.gather(torch.softmax(flat_logits, dim=1), 1, sort_indices)
    cumprobs = torch.cumsum(sorted_probs, dim=-1)

    # The top-1 candidate is never masked, which guarantees that at least
    # one index stays selectable.
    sort_mask = cumprobs > p
    sort_mask[:, 0] = False

    # Route the penalty from sorted order back to the original positions in
    # a single scatter instead of a Python loop over every element.
    sorted_penalty = sort_mask.to(flat_logits.dtype) * 1e9
    penalty = torch.zeros_like(flat_logits).scatter(1, sort_indices, sorted_penalty)

    return (flat_logits - penalty).reshape(-1, seq, dim)


def scaled_dot_product(q, k, v, bias=None, mask=None, possible_zero_attentions=False):
    """Scaled dot-product attention.

    Args:
        q, k, v: query, key and value tensors; the last dimension of ``q`` is
            the per-head feature size used for scaling.
        bias: optional additive bias. Dimension 1 is squeezed (when it has
            size 1) and the result is added to the attention logits, so it
            must broadcast against them.
        mask: optional tensor; positions where ``mask == 0`` get a logit of
            ``-9e15``.
        possible_zero_attentions: when True and a ``bias`` is given, attention
            weights at positions with a negative bias are forced to exactly
            zero after the softmax. This covers rows in which every key is
            biased away, where the softmax alone would return a uniform
            distribution. Ignored when ``bias`` is None.

    Returns:
        ``(values, attention)`` where ``values = attention @ v``.
    """
    d_k = q.size(-1)
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    additive_bias = None
    if bias is not None:
        additive_bias = torch.squeeze(bias, 1)
        attn_logits = attn_logits + additive_bias
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)

    if possible_zero_attentions and additive_bias is not None:
        # Out-of-place: writing into the softmax output in place would
        # invalidate the tensor autograd saved for the backward pass.
        attention = attention.masked_fill(additive_bias < 0, 0.0)

    values = torch.matmul(attention, v)
    return values, attention


class MultiheadAttention(nn.Module):
    """Multi-head attention with separate q/k/v/output projections.

    With ``add_context=False`` the layer performs self-attention: queries,
    keys and values are all projected from ``x``. With ``add_context=True``
    keys and values are projected from a separate ``context`` sequence whose
    feature size is ``context_size`` (cross-attention).

    Args:
        input_dim: feature size of the query input ``x``.
        embed_dim: total size of the projected q/k/v and of the output; must
            be divisible by ``num_heads``.
        num_heads: number of attention heads.
        add_context: if True, keys and values come from ``context``.
        context_size: feature size of ``context``; required when
            ``add_context`` is True.
    """

    def __init__(
        self, input_dim, embed_dim, num_heads, add_context=False, context_size=None
    ):
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), "Embedding dimension must be 0 modulo number of heads."
        assert not add_context or context_size is not None

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.add_context = add_context
        self.context_size = context_size

        # One projection per role; each one produces all heads at once.
        self.q_proj = nn.Linear(input_dim, embed_dim)

        # Keys and values read from the context when cross-attending.
        kv_input_dim = context_size if self.add_context else input_dim
        self.k_proj = nn.Linear(kv_input_dim, embed_dim)
        self.v_proj = nn.Linear(kv_input_dim, embed_dim)

        self.o_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected, batch_size, length):
        """Reshape ``[batch, length, embed_dim]`` to ``[batch, heads, length, head_dim]``."""
        return projected.view(
            batch_size, length, self.num_heads, self.head_dim
        ).permute(0, 2, 1, 3)

    def forward(
        self,
        x,
        context=None,
        bias=None,
        mask=None,
        return_attention=False,
        possible_zero_attentions=False,
    ):
        """Attend from ``x`` to itself or to ``context``.

        Args:
            x: 3-D query input ``[batch, seq_length, input_dim]``.
            context: key/value source ``[batch, context_length, context_size]``;
                required when the layer was built with ``add_context=True``.
            bias, mask, possible_zero_attentions: forwarded unchanged to
                ``scaled_dot_product``.
            return_attention: if True, also return the attention weights.

        Returns:
            The output projection ``[batch, seq_length, embed_dim]``, or the
            tuple ``(output, attention)`` when ``return_attention`` is True.
        """
        assert not self.add_context or context is not None
        batch_size, seq_length, _ = x.size()

        q = self.q_proj(x)

        if self.add_context:
            kv_source = context
            context_seq_length = context.shape[1]
        else:
            kv_source = x
            context_seq_length = seq_length
        k = self.k_proj(kv_source)
        v = self.v_proj(kv_source)

        q = self._split_heads(q, batch_size, seq_length)
        k = self._split_heads(k, batch_size, context_seq_length)
        v = self._split_heads(v, batch_size, context_seq_length)

        # Determine value outputs
        values, attention = scaled_dot_product(
            q,
            k,
            v,
            bias=bias,
            mask=mask,
            possible_zero_attentions=possible_zero_attentions,
        )
        values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, Dims]
        # Merge the heads back into embed_dim. The feature size of ``x`` is
        # input_dim, which need not equal embed_dim, so it must not be used here.
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        return o


class TransformerEncoder(nn.Module):
    """
    Transformer encoder. Vaswani et al. 2017.

    Each of the ``num_layers`` layers applies, as pre-norm residual branches:
    self-attention, optional cross-attention into every configured context
    sequence, and a two-layer GELU MLP. A single ``LayerNorm`` module, a
    single ``Dropout`` module and (with ``re_zero``) a single learnable
    residual scale initialised to 0 are shared by all branches of all layers.

    Args:
        hidden_size: feature size of the inputs and outputs.
        fc_size: hidden size of the MLP.
        num_heads: attention heads per attention layer.
        layer_norm: apply the shared LayerNorm before every branch and on the
            final output.
        num_layers: number of layers.
        dropout_rate: dropout probability applied to every branch output.
        re_zero: multiply every branch output by the shared learnable scale.
        context_embedding_size: optional sequence of feature sizes, one per
            context sequence to cross-attend to. Cross-attention layers are
            created only if it is not None and its first entry is not None.
    """

    def __init__(
        self,
        hidden_size=256,
        fc_size=1024,
        num_heads=4,
        layer_norm=True,
        num_layers=8,
        dropout_rate=0.2,
        re_zero=True,
        context_embedding_size=None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.layer_norm = layer_norm
        self.fc_size = fc_size
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate
        self.re_zero = re_zero

        self.dropout = nn.Dropout(self.dropout_rate)

        self.layer_norm_op = nn.LayerNorm(self.hidden_size)

        self.attention_layers = nn.ModuleList([])
        self.fc1s = nn.ModuleList([])
        self.fc2s = nn.ModuleList([])

        for _ in range(self.num_layers):
            self.attention_layers.append(
                MultiheadAttention(self.hidden_size, self.hidden_size, self.num_heads)
            )

            self.fc1s.append(nn.Linear(self.hidden_size, self.fc_size))
            self.fc2s.append(nn.Linear(self.fc_size, self.hidden_size))

        # Always define the context bookkeeping so that forward() can fail
        # with its assertion message instead of an AttributeError when a
        # context is passed to an encoder built without cross-attention.
        self.context_embedding = False
        self.num_contexs = 0

        if context_embedding_size is not None and context_embedding_size[0] is not None:
            # Attend to all the sequences provided
            self.num_contexs = len(context_embedding_size)

            self.context_embedding = True
            self.context_attention_layers = nn.ModuleList([])

            # Stored flat: the layer for (layer_num, i) lives at index
            # layer_num * num_contexs + i.
            for _ in range(self.num_layers):
                for i in range(self.num_contexs):
                    self.context_attention_layers.append(
                        MultiheadAttention(
                            self.hidden_size,
                            self.hidden_size,
                            self.num_heads,
                            add_context=True,
                            context_size=context_embedding_size[i],
                        )
                    )

        if self.re_zero:
            self.context_scale = nn.Parameter(torch.tensor([0.0]), requires_grad=True)

        self._reset_parameters()

    def _reset_parameters(self):
        """Re-initialise the MLP weights (Kaiming normal) and biases (std 1e-6)."""
        for fc in list(self.fc1s) + list(self.fc2s):
            nn.init.kaiming_normal_(fc.weight, mode="fan_out", nonlinearity="relu")
            nn.init.normal_(fc.bias, std=1e-6)

    @staticmethod
    def _padding_attention_bias(sequence):
        """Build an additive attention bias that hides all-zero elements.

        Elements of ``sequence`` whose features are all exactly 0 are treated
        as padding and receive a bias of ``-1e9``; all others receive 0. The
        result has three singleton dimensions inserted after the batch
        dimension.
        """
        padding = (sequence == 0).all(dim=-1).float()
        return (padding * (-1e9)).unsqueeze(1).unsqueeze(1).unsqueeze(1)

    def _residual_branch(self, res):
        """Apply the optional re-zero scale and dropout to a branch output."""
        if self.re_zero:
            res = res * self.context_scale
        return self.dropout(res)

    def forward(self, inputs, sequential_context_embeddings=None):
        """Encode ``inputs``, optionally cross-attending to context sequences.

        Args:
            inputs: tensor whose last dimension is ``hidden_size``. Elements
                that are all zeros are treated as padding and masked out of
                self-attention.
            sequential_context_embeddings: optional list of context tensors,
                one per configured context. Ignored when None or when its
                first entry is None.

        Returns:
            Tensor shaped like ``inputs`` (layer-normalised if ``layer_norm``).
        """
        has_context = (
            sequential_context_embeddings is not None
            and sequential_context_embeddings[0] is not None
        )

        # Identify elements with all zeros as padding, and create bias to mask
        # out padding elements in self attention.
        encoder_self_attention_bias = self._padding_attention_bias(inputs)

        encoder_decoder_attention_biases = []
        if has_context:
            assert (
                len(sequential_context_embeddings) == self.num_contexs
            ), "Make sure the appropriate number of contexs is provided in the decoder"
            assert self.context_embedding

            encoder_decoder_attention_biases = [
                self._padding_attention_bias(sequential_context_embeddings[i])
                for i in range(self.num_contexs)
            ]

        x = inputs
        for layer_num in range(self.num_layers):

            # Multihead self-attention from Tensor2Tensor.
            res = x
            if self.layer_norm:
                res = self.layer_norm_op(res)

            res = self.attention_layers[layer_num](
                res, bias=encoder_self_attention_bias
            )
            x = x + self._residual_branch(res)

            # Optional cross attention into sequential context
            if has_context:
                for i in range(self.num_contexs):
                    res = x
                    if self.layer_norm:
                        res = self.layer_norm_op(res)
                    res = self.context_attention_layers[
                        layer_num * self.num_contexs + i
                    ](
                        res,
                        context=sequential_context_embeddings[i],
                        bias=encoder_decoder_attention_biases[i],
                    )
                    x = x + self._residual_branch(res)

            # MLP
            res = x
            if self.layer_norm:
                res = self.layer_norm_op(res)

            res = self.fc1s[layer_num](res)
            res = F.gelu(res)
            res = self.fc2s[layer_num](res)
            x = x + self._residual_branch(res)

        if self.layer_norm:
            return self.layer_norm_op(x)
        return x


class MyEmbed(nn.Module):
    """``nn.Embedding`` wrapper whose table is Kaiming-normal initialised.

    Args:
        classes: number of rows in the embedding table.
        embedding_size: size of each embedding vector.
    """

    def __init__(self, classes, embedding_size):
        super().__init__()
        table = nn.Embedding(classes, embedding_size)
        nn.init.kaiming_normal_(table.weight, mode="fan_out", nonlinearity="relu")
        self.embedding = table

    def forward(self, x):
        """Look up the embedding vectors for the integer indices ``x``."""
        looked_up = self.embedding(x)
        return looked_up

class EnhancedLinesModel(nn.Module):
    def __init__(
        self,
        encoder_config=None,
        decoder_config=None,
        embedding_dim=256,
        image_features=960,
        image_features_final_dimension=256,
        num_input_channels=3,
        image_size=224,
        use_discrete_embeddings=True,
        fix_backbone=False,
        max_seq_length=151,
        use_postion_embeddings=True,
        position_encodings_for_keypoints=False,
        resnet_size=18,
    ):
        """Build the image backbone, keypoint encoder and line decoder.

        Args:
            encoder_config: keyword arguments for the keypoint
                ``TransformerEncoder``; a default 8-layer, 256-wide config is
                used when None.
            decoder_config: keyword arguments for the line
                ``TransformerEncoder`` used as decoder; same default.
            embedding_dim: size of keypoint, position and type embeddings.
            image_features: size of the per-keypoint image feature vector
                concatenated to the keypoint embedding.
            image_features_final_dimension: input size of the linear layer
                producing the global image context.
            num_input_channels: image channels; the backbone's first
                convolution is replaced when this is not 3.
            image_size: number of discrete coordinate bins per axis.
            use_discrete_embeddings: embed quantised coordinates with lookup
                tables instead of a linear layer on rescaled coordinates.
            fix_backbone: disable gradients for the backbone parameters that
                exist at that point. The layers replaced afterwards (``conv1``
                when applicable, and ``fc``) are new modules.
            max_seq_length: size of the position-embedding tables and default
                sampling length.
            use_postion_embeddings: add position embeddings to line tokens.
            position_encodings_for_keypoints: add position embeddings to
                keypoint tokens.
            resnet_size: 18 or 50.

        Raises:
            NotImplementedError: if ``resnet_size`` is not 18 or 50.
        """
        # nn.Module must be initialised before attributes are assigned to it.
        super().__init__()

        if encoder_config is None:
            encoder_config = {
                "hidden_size": 256,
                "fc_size": 256,
                "num_layers": 8,
                "layer_norm": True,
                "dropout_rate": 0.1,
            }

        if decoder_config is None:
            decoder_config = {
                "hidden_size": 256,
                "fc_size": 256,
                "num_layers": 8,
                "layer_norm": True,
                "dropout_rate": 0.1,
            }

        self.encoder_config = encoder_config
        self.decoder_config = decoder_config

        self.embedding_dim = embedding_dim
        self.num_image_features = image_features
        self.position_encodings_for_keypoints = position_encodings_for_keypoints

        # Cross-attention into image features is enabled per transformer by
        # the presence of a non-None "context_embedding_size" in its config.
        self.cross_attention_keypoints = (
            "context_embedding_size" in encoder_config
            and encoder_config["context_embedding_size"] is not None
        )
        self.cross_attention_edges = (
            "context_embedding_size" in decoder_config
            and decoder_config["context_embedding_size"] is not None
        )

        self.max_seq_length = max_seq_length
        self.image_size = image_size
        self.num_input_channels = num_input_channels

        print("Using as backbone resnet:{}".format(resnet_size))

        if resnet_size == 18:
            self.image_backbone = resnet18(pretrained=True)
        elif resnet_size == 50:
            self.image_backbone = resnet50(pretrained=True)
        else:
            raise NotImplementedError(
                "Unsupported resnet_size {!r}; expected 18 or 50".format(resnet_size)
            )

        if fix_backbone:
            self.image_backbone.requires_grad_(False)

        if num_input_channels != 3:
            self.image_backbone.conv1 = nn.Conv2d(
                num_input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
            )

        self.image_backbone.fc = nn.Linear(
            self.image_backbone.fc.in_features, embedding_dim
        )

        self.use_discrete_embeddings = use_discrete_embeddings
        self.use_postion_embeddings = use_postion_embeddings

        self.encoder = TransformerEncoder(**encoder_config)
        self.decoder = TransformerEncoder(**decoder_config)

        if self.use_discrete_embeddings:
            self.keypoints_encode_embeddings_x = MyEmbed(
                self.image_size, self.embedding_dim
            )
            self.keypoints_encode_embeddings_y = MyEmbed(
                self.image_size, self.embedding_dim
            )
        else:
            self.keypoints_encode_embeddings = nn.Linear(2, self.embedding_dim)

        # Learned token prepended to the keypoint sequence.
        self.stopping_embeddings = nn.Parameter(
            torch.randn([1, 1, self.embedding_dim]) * 1e-6, requires_grad=True
        )
        # Learned token prepended to the line-token sequence.
        self.zero_embed = nn.Parameter(
            torch.randn([1, 1, self.embedding_dim]) * 1e-6, requires_grad=True
        )

        if self.use_postion_embeddings:
            self.pos_embeddings = MyEmbed(self.max_seq_length, self.embedding_dim)

        # first and second edge types
        self.type_embeddings = MyEmbed(2, self.embedding_dim)

        self.project_to_logits = nn.Linear(
            decoder_config["hidden_size"], self.embedding_dim
        )

        self.keypoints_features_projection = nn.Linear(
            self.embedding_dim + image_features, encoder_config["hidden_size"]
        )

        if self.position_encodings_for_keypoints:
            self.pos_embeddings_keypoints = MyEmbed(
                self.max_seq_length, self.embedding_dim
            )

        self.global_context_image_embedding = nn.Linear(
            image_features_final_dimension, self.embedding_dim
        )

        if self.cross_attention_keypoints or self.cross_attention_edges:
            # TODO fix features image size: 512 is hard-coded and must match
            # the channel count of the feature map given to _embed_image.
            self.image_coord_embed = nn.Linear(2, 512, bias=True)

    def _embed_image(self, image_embeddings):
        """Turn an image feature map into a sequence with 2-D position info.

        Args:
            image_embeddings: feature map ``[batch, channels, height, width]``.
                ``channels`` must equal the output size of
                ``self.image_coord_embed``.

        Returns:
            Tensor ``[batch, height * width, channels]`` in row-major spatial
            order, where every location has had a linear embedding of its
            ``(x, y)`` coordinate added; both coordinates range over
            ``[-1, 1]``.
        """
        # [B, C, H, W] -> [B, H, W, C]
        image_embeddings = image_embeddings.permute(0, 2, 3, 1)
        batch_size, height, width, channels = image_embeddings.shape

        # Coordinate grid built from the actual feature-map size; entry
        # [row, col] holds (x of col, y of row).
        device = image_embeddings.device
        xs = torch.linspace(-1.0, 1.0, width, device=device)
        ys = torch.linspace(-1.0, 1.0, height, device=device)
        image_coords = torch.stack(
            [
                xs[None, :].expand(height, width),
                ys[:, None].expand(height, width),
            ],
            dim=-1,
        )
        image_coord_embeddings = self.image_coord_embed(image_coords)

        image_embeddings = image_embeddings + image_coord_embeddings[None]

        # Reshape spatial grid to sequence
        return image_embeddings.reshape(batch_size, -1, channels)

    def _embed_keypoints(
        self, keypoints, keypoints_mask, keypoints_image_features, cross_attention_features
    ):
        """Embed keypoints and run them through the keypoint encoder.

        The coordinate embedding (lookup tables on quantised coordinates, or
        a linear layer on coordinates rescaled to ``[-0.5, 0.5]``) is
        concatenated with ``keypoints_image_features``, projected, optionally
        given position embeddings, and multiplied by ``keypoints_mask``. The
        learned stopping embedding is prepended as token 0 before encoding.

        ``cross_attention_features`` is passed to the encoder only when
        ``self.cross_attention_keypoints`` is set; otherwise None is passed.

        Returns:
            The encoder output for the sequence of length
            ``1 + keypoints.shape[1]``.
        """
        if self.use_discrete_embeddings:
            # Snap coordinates onto the integer bins [0, image_size - 1].
            bin_indices = torch.round(
                torch.clamp(keypoints, 0, self.image_size - 1)
            ).long()
            x_embed = self.keypoints_encode_embeddings_x(bin_indices[..., 0])
            y_embed = self.keypoints_encode_embeddings_y(bin_indices[..., 1])
            coordinate_embeddings = x_embed + y_embed
        else:
            coordinate_embeddings = self.keypoints_encode_embeddings(
                rescale(keypoints, 0, self.image_size - 1, -0.5, 0.5)
            )

        fused = torch.cat(
            [coordinate_embeddings, keypoints_image_features], -1
        ).float()
        fused = self.keypoints_features_projection(fused)

        if self.position_encodings_for_keypoints:
            positions = torch.arange(0, fused.shape[1], device=fused.device)
            fused = fused + self.pos_embeddings_keypoints(positions)[None]

        # Zero out padded keypoints.
        fused = fused * keypoints_mask[..., None]

        stop_token = torch.tile(self.stopping_embeddings, (keypoints.shape[0], 1, 1))
        encoder_inputs = torch.cat([stop_token, fused], dim=1)

        context = cross_attention_features if self.cross_attention_keypoints else None
        return self.encoder(encoder_inputs, sequential_context_embeddings=[context])

    def _embed_inputs(
        self, lines_long, lines_mask, keypoints_embeddings, global_context_embedding=None
    ):
        """Build the decoder input sequence from already chosen line tokens.

        Args:
            lines_long: integer tensor ``[batch, length]`` of indices into the
                keypoint axis of ``keypoints_embeddings``.
            lines_mask: tensor ``[batch, length]`` multiplied onto the token
                embeddings.
            keypoints_embeddings: tensor ``[batch, num_keypoints, dim]``.
            global_context_embedding: optional ``[batch, dim]`` tensor added
                to the learned start token.

        Returns:
            Tensor ``[batch, length + 1, dim]``: the start token followed by,
            for every line token, the embedding of the referenced keypoint
            plus position and type embeddings, times ``lines_mask``. The type
            embedding alternates with token parity.
        """
        seq_length = lines_long.shape[1]
        positions = torch.arange(0, seq_length, device=lines_long.device)

        # Pick, for every sample, the keypoint embeddings its tokens point
        # at. One batched advanced-indexing op replaces a per-sample loop.
        batch_index = torch.arange(
            keypoints_embeddings.shape[0], device=lines_long.device
        )[:, None]
        lines_embeddings = keypoints_embeddings[batch_index, lines_long]

        if self.use_postion_embeddings:
            # Position embeddings
            pos_embeddings = self.pos_embeddings(positions)
        else:
            pos_embeddings = torch.zeros(
                seq_length,
                lines_embeddings.shape[-1],
                device=lines_embeddings.device,
            )

        # Even positions get type 0, odd positions type 1.
        type_embeddings = self.type_embeddings(positions % 2)

        # Step zero embeddings
        batch_size = lines_embeddings.shape[0]
        zero_embed_tiled = torch.tile(self.zero_embed, [batch_size, 1, 1])
        if global_context_embedding is not None:
            zero_embed_tiled = global_context_embedding[:, None] + zero_embed_tiled

        # Aggregate embeddings
        embeddings = lines_embeddings + pos_embeddings[None] + type_embeddings[None]
        embeddings = embeddings * lines_mask[..., None]
        return torch.cat([zero_embed_tiled, embeddings], dim=1)

    def _create_dist(
        self,
        keypoints_embeddings,
        keypoints_mask,
        lines_long,
        lines_masks,
        global_context_embedding=None,
        cross_attention_features=None,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        return_last=False,
        cache=None,
        return_logits=False,
        return_decoder_outputs=False,
    ):
        """Predict the next line token as a pointer over the keypoints.

        The decoder runs on the embedded prefix and only its output at
        position 0 (the start token) is used. That vector is projected and
        dotted with every keypoint embedding; entries whose mask is 0 get a
        logit of ``-1e9``. The mask is ``keypoints_mask`` left-padded with a
        1 for the prepended stopping token.

        Returns, in order of precedence:
            * the decoder output ``[batch, 1, hidden]`` if
              ``return_decoder_outputs``;
            * the masked logits ``[batch, 1, num_keypoints + 1]`` if
              ``return_logits`` (temperature, top-k and top-p NOT applied);
            * otherwise a ``Categorical`` over the logits after temperature,
              top-k and top-p filtering.

        ``cache`` is accepted but unused. ``return_last`` has no additional
        effect because the decoder output is already reduced to one position.
        """
        # Embed inputs
        decoder_inputs = self._embed_inputs(
            lines_long, lines_masks, keypoints_embeddings, global_context_embedding
        )

        context = cross_attention_features if self.cross_attention_edges else None
        decoder_outputs = self.decoder(
            decoder_inputs,
            sequential_context_embeddings=[context],
        )
        # Keep only the first sequence position, preserving the length-1 axis.
        decoder_outputs = decoder_outputs[:, 0].unsqueeze(1)

        if return_decoder_outputs:
            return decoder_outputs

        if return_last:
            decoder_outputs = decoder_outputs[:, -1:]

        pred_pointers = self.project_to_logits(decoder_outputs)

        scale = math.sqrt(float(self.embedding_dim))
        logits = torch.matmul(pred_pointers, keypoints_embeddings.permute(0, 2, 1)) / scale

        # The prepended stopping token is always a valid target.
        f_verts_mask = F.pad(keypoints_mask, pad=(1, 0), value=1)[:, None]
        logits = logits * f_verts_mask - (1.0 - f_verts_mask) * 1e9

        if return_logits:
            return logits

        logits = logits / temperature
        logits = top_k_logits(logits, top_k)
        logits = top_p_logits(logits, top_p)
        return torch.distributions.categorical.Categorical(logits=logits)

    def _prepare_context(self, images, keypoints, masks, keypoints_images=None):
        """Run the image backbone and encode the keypoints.

        The backbone is called with ``return_predictions=True``; its first
        four outputs are treated as image feature maps and the fifth as the
        prediction vector that feeds the global context embedding.
        ``keypoints_images`` is accepted but unused.

        Returns:
            ``(keypoints_embeddings, cross_attention_features, global_context)``
            where ``cross_attention_features`` is None unless cross-attention
            is enabled for the encoder or the decoder.
        """
        backbone_outputs = self.image_backbone(
            images.float(),
            return_predictions=True,
        )
        image_features = backbone_outputs[:4]
        predictions = backbone_outputs[4]

        has_keypoints = keypoints.shape[1] > 0
        if has_keypoints:
            keypoints_features = batched_pooling(
                image_features, keypoints, image_size=self.image_size
            )
        else:
            # No keypoints: use an empty feature tensor of the expected width.
            keypoints_features = torch.empty(
                (keypoints.shape[0], 0, self.num_image_features),
                device=keypoints.device,
            )

        global_context = self.global_context_image_embedding(predictions)

        cross_attention_features = None
        if self.cross_attention_keypoints or self.cross_attention_edges:
            cross_attention_features = self._embed_image(image_features[3])

        keypoints_embeddings = self._embed_keypoints(
            keypoints,
            masks,
            keypoints_features,
            cross_attention_features=cross_attention_features,
        )
        return keypoints_embeddings, cross_attention_features, global_context

    def forward(self, images, keypoints, keypoints_masks, lines):
        """Return one next-token distribution per position of ``lines``.

        For step ``i`` the decoder sees the first ``i`` ground-truth tokens
        (with an all-ones mask) and predicts token ``i``. The decoder does
        not employ triangular masking, so every prefix is re-run from
        scratch.

        Returns:
            A list of ``lines.shape[1]`` distributions as produced by
            ``_create_dist``.
        """
        # inside forward to make multi-gpu training easier
        context = self._prepare_context(images, keypoints, keypoints_masks)
        keypoints_embeddings, cross_attention_features, global_context_embedding = context

        batch_size, num_steps = lines.shape[0], lines.shape[1]
        line_tokens = lines.long()

        return [
            self._create_dist(
                keypoints_embeddings,
                keypoints_masks,
                line_tokens[:, :step],
                torch.ones((batch_size, step), device=lines.device),
                global_context_embedding=global_context_embedding,
                cross_attention_features=cross_attention_features,
            )
            for step in range(num_steps)
        ]

    def sample(
        self,
        images,
        keypoints,
        keypoints_masks,
        max_sample_length=None,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        only_return_complete=False,
        return_last_logits=False,
        samples=None,
        replace_indices=None,
    ):
        """Autoregressively sample line-token sequences.

        Sampling continues until every sequence contains the stop token
        ``0`` or ``max_sample_length`` tokens exist. Sequences that already
        contain a stop token keep receiving tokens while others are
        unfinished.

        Args:
            images, keypoints, keypoints_masks: passed to ``_prepare_context``.
            max_sample_length: maximum sequence length; falls back to
                ``self.max_seq_length`` when None or 0.
            temperature, top_k, top_p: sampling controls applied to the
                logits at every step.
            only_return_complete: drop sequences without a stop token.
            return_last_logits: also return unfiltered logits computed on the
                final sequences.
            samples: optional ``[batch, length]`` int64 prefix to continue.
            replace_indices: optional selector (presumably a boolean mask over
                the batch, TODO confirm) of samples whose next token is drawn
                uniformly from the entries not masked out with ``-1e9``.

        Returns:
            A dict with ``"completed"`` (bool per sequence), ``"lines"`` (the
            token sequences) and ``"num_lines_indices"`` (int32; index of the
            first stop token plus one, or sequence length plus one if there
            is none). When ``return_last_logits`` is True the result is the
            tuple ``(dict, last_logits)``.
        """
        (
            keypoints_embeddings,
            cross_attention_features,
            global_context,
        ) = self._prepare_context(images, keypoints, keypoints_masks)
        num_samples = keypoints_embeddings.shape[0]

        if samples is None:
            # Initial values for loop variables
            samples = torch.zeros(
                [num_samples, 0],
                dtype=torch.int64,
                device=next(self.parameters()).device,
            )

        max_sample_length = max_sample_length or self.max_seq_length

        i = samples.shape[1]

        # Continue while at least one sequence has no stop token yet.
        while (samples != 0).all(dim=-1).any() and i < max_sample_length:
            i += 1
            pred_logits = self._create_dist(
                keypoints_embeddings,
                keypoints_masks,
                samples,
                torch.ones((samples.shape[0], samples.shape[1])).to(samples.device),
                global_context_embedding=global_context,
                cross_attention_features=cross_attention_features,
                cache=None,
                return_last=True,
                return_logits=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )

            # return_logits=True yields raw logits, so the sampling controls
            # are applied here.
            logits = pred_logits / temperature
            logits = top_k_logits(logits, top_k)
            logits = top_p_logits(logits, top_p)

            next_sample = torch.distributions.categorical.Categorical(
                logits=logits
            ).sample()[:, -1:]

            if replace_indices is not None and torch.sum(replace_indices) > 0:
                logits = pred_logits[replace_indices]
                logits[
                    logits > -0.5e9
                ] = 0  # set all values equal for random sampling ..
                next_sample[
                    replace_indices
                ] = torch.distributions.categorical.Categorical(logits=logits).sample()[
                    :, -1:
                ]

            samples = torch.cat([samples, next_sample], axis=1)

        if return_last_logits:
            # True (unfiltered) logits for the final sequences. The image
            # context is passed so they are computed under the same
            # conditioning as the logits that were sampled from.
            last_logits = self._create_dist(
                keypoints_embeddings,
                keypoints_masks,
                samples,
                torch.ones((samples.shape[0], samples.shape[1])).to(samples.device),
                global_context_embedding=global_context,
                cross_attention_features=cross_attention_features,
                cache=None,
                return_last=False,
                temperature=1.0,
                top_k=0,
                top_p=1.0,
                return_logits=True,
            )

        f = samples

        # Record completed samples
        is_stop = f == 0
        complete_samples = is_stop.any(-1)

        # argmax returns the first maximal index, i.e. the first stop token.
        zero_inds = torch.argmax(is_stop.int(), axis=-1).int()

        # Sequences without a stop token count their full length.
        full_length = torch.full_like(zero_inds, f.shape[1])
        num_lines_indices = torch.where(complete_samples, zero_inds, full_length) + 1

        if only_return_complete:
            f = f[complete_samples]
            num_lines_indices = num_lines_indices[complete_samples]
            complete_samples = complete_samples[complete_samples]

        # outputs
        outputs = {
            "completed": complete_samples,
            "lines": f,
            "num_lines_indices": num_lines_indices,
        }

        if return_last_logits:
            return outputs, last_logits

        return outputs