import functools
import logging
import os
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from pkg_resources import packaging
from torch import nn

from .simple_tokenizer import SimpleTokenizer as _Tokenizer

logger = logging.getLogger(__name__)


# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
MODEL_PATH = 'https://huggingface.co/laion'
_MODELS = {
    "ViT-L/14": os.path.join(MODEL_PATH, "CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "vit_l14_text.pth"),
    "ViT-B/16": os.path.join(MODEL_PATH, "CLIP-ViT-B-16-DataComp.XL-s13B-b90K", "vit_b16_text.pth"),
}


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
                 checkpoint_num: int = 0):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

        self.checkpoint_num = checkpoint_num

    def forward(self, x: torch.Tensor):
        if self.checkpoint_num > 0:
            segments = min(self.checkpoint_num, len(self.resblocks))
            return checkpoint.checkpoint_sequential(self.resblocks, segments, x)
        else:
            return self.resblocks(x)


class CLIP_TEXT(nn.Module):
    def __init__(
            self,
            embed_dim: int,
            context_length: int,
            vocab_size: int,
            transformer_width: int,
            transformer_heads: int,
            transformer_layers: int,
            checkpoint_num: int,
        ):
        super().__init__()

        self.context_length = context_length
        self._tokenizer = _Tokenizer()

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            checkpoint_num=checkpoint_num,
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
    
    def no_weight_decay(self):
        return {'token_embedding', 'positional_embedding'}

    @functools.lru_cache(maxsize=None)
    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def tokenize(self, texts, context_length=77, truncate=True):
        """
        Returns the tokenized representation of given input string(s)
        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length
        truncate: bool
            Whether to truncate the text in case its encoding is longer than the context length
        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
        We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self._tokenizer.encoder["<|startoftext|>"]
        eot_token = self._tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
        if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        else:
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate:
                    tokens = tokens[:context_length]
                    tokens[-1] = eot_token
                else:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result

    def forward(self, text):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x


def clip_text_b16(
    embed_dim=512,
    context_length=77,
    vocab_size=49408,
    transformer_width=512,
    transformer_heads=8,
    transformer_layers=12,
    checkpoint_num=0,
    pretrained=True,
):
    # raise NotImplementedError
    model = CLIP_TEXT(
        embed_dim,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        checkpoint_num,
    )
    # pretrained = _MODELS["ViT-B/16"]
    # logger.info(f"Load pretrained weights from {pretrained}")
    # state_dict = torch.load(pretrained, map_location='cpu')
    # model.load_state_dict(state_dict, strict=False)
    # return model.eval()
    if pretrained:
        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
            pretrained = _MODELS[pretrained]
        else:
            pretrained = _MODELS["ViT-B/16"]
        logger.info(f"Load pretrained weights from {pretrained}")
        state_dict = torch.load(pretrained, map_location='cpu')
        if context_length != state_dict["positional_embedding"].size(0):
            # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
            print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
            if context_length < state_dict["positional_embedding"].size(0):
                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
            else:
                state_dict["positional_embedding"] = F.pad(
                    state_dict["positional_embedding"],
                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
                    value=0,
                )

        message = model.load_state_dict(state_dict, strict=False)
        print(f"Load pretrained weights from {pretrained}: {message}")
    return model.eval()


def clip_text_l14(
    embed_dim=768,
    context_length=77,
    vocab_size=49408,
    transformer_width=768,
    transformer_heads=12,
    transformer_layers=12,
    checkpoint_num=0,
    pretrained=True,
):
    model = CLIP_TEXT(
        embed_dim,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        checkpoint_num,
    )
    if pretrained:
        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
            pretrained = _MODELS[pretrained]
        else:
            pretrained = _MODELS["ViT-L/14"]
        logger.info(f"Load pretrained weights from {pretrained}")
        state_dict = torch.load(pretrained, map_location='cpu')
        if context_length != state_dict["positional_embedding"].size(0):
            # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
            print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
            if context_length < state_dict["positional_embedding"].size(0):
                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
            else:
                state_dict["positional_embedding"] = F.pad(
                    state_dict["positional_embedding"],
                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
                    value=0,
                )

        message = model.load_state_dict(state_dict, strict=False)
        print(f"Load pretrained weights from {pretrained}: {message}")
    return model.eval()


def clip_text_l14_336(
    embed_dim=768,
    context_length=77,
    vocab_size=49408,
    transformer_width=768,
    transformer_heads=12,
    transformer_layers=12,
):
    raise NotImplementedError
    model = CLIP_TEXT(
        embed_dim,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers
    )
    pretrained = _MODELS["ViT-L/14_336"]
    logger.info(f"Load pretrained weights from {pretrained}")
    state_dict = torch.load(pretrained, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    return model.eval()


def build_clip(config):
    model_cls = config.text_encoder.clip_teacher
    model = eval(model_cls)()
    return model