import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import BertTokenizer

from models.blip import create_vit, init_tokenizer, load_checkpoint
from models.loss import CrossEntropyLoss, HardNegativeNCE
from models.med import BertConfig, BertModel


class BLIPCirEmbs(nn.Module):
    """BLIP-style query encoder trained against precomputed target embeddings.

    The reference image is encoded by a vision transformer; its token
    embeddings are handed to the text encoder as ``encoder_hidden_states``
    together with the tokenized caption.  The hidden state at position 0 is
    projected, L2-normalized and compared with the (normalized) target
    features by the configured loss.
    """

    def __init__(
        self,
        med_config="configs/med_config.json",
        image_size=384,
        vit="base",
        vit_grad_ckpt=False,
        vit_ckpt_layer=0,
        embed_dim=256,
        queue_size=57600,
        momentum=0.995,
        negative_all_rank=False,
        train_vit=True,
        beta=0,
        hard_negatives=False,
    ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
            vit_grad_ckpt (bool): forwarded to ``create_vit``
            vit_ckpt_layer (int): forwarded to ``create_vit``
            embed_dim (int): output size of the vision/text projection layers
            queue_size (int): stored on the instance; not used elsewhere in this class
            momentum (float): accepted for interface compatibility; unused in this class
            negative_all_rank (bool): stored on the instance; not used elsewhere in this class
            train_vit (bool): when False the visual encoder is frozen and run without grad
            beta (float): when > 0, ``HardNegativeNCE(beta=beta)`` is used and ``temp``
                is frozen; otherwise ``CrossEntropyLoss`` is used
            hard_negatives (bool): when True, ``forward`` builds a pairwise bias matrix
                from ``soft_inputs``/``soft_targets`` and hands it to the loss
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(
            vit, image_size, vit_grad_ckpt, vit_ckpt_layer
        )
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.queue_size = queue_size

        self.negative_all_rank = negative_all_rank

        self.train_vit = train_vit
        if not self.train_vit:
            # Do not train visual encoder
            for p in self.visual_encoder.parameters():
                p.requires_grad = False

        # The vision projection is never applied in forward() (target features
        # arrive precomputed), so it is always frozen.
        for p in self.vision_proj.parameters():
            p.requires_grad = False

        self.temp = nn.Parameter(0.07 * torch.ones([]))

        self.hard_negatives = hard_negatives
        self.hard_temp = nn.Parameter(0.1 * torch.ones([]))

        if self.hard_negatives:
            print("Using hard negatives")
        else:
            print("Not using hard negatives")
            # hard_temp only scales the hard-negative bias; freeze it when unused.
            self.hard_temp.requires_grad = False

        self.beta = beta
        if beta > 0:
            self.loss = HardNegativeNCE(beta=beta)
            self.temp.requires_grad = False
            print(f"Using HardNegativeNCE loss with beta={beta}")
        else:
            self.loss = CrossEntropyLoss()
            print("Using CrossEntropyLoss")

    def forward(
        self, ref_img, tar_feat, caption, alpha, idx, soft_inputs, soft_targets
    ):
        """Compute the training loss for one batch.

        Args:
            ref_img: reference image batch fed to the visual encoder; its device
                is used for every tensor created or moved here.
            tar_feat: precomputed target features; moved to the device and
                L2-normalized along the last dimension.
            caption: text passed to the tokenizer (padded/truncated to 35 tokens).
            alpha: unused.
            idx: unused.
            soft_inputs: only read when ``self.hard_negatives`` is True; compared
                pairwise for equality (presumably 1-D id tensors -- TODO confirm).
            soft_targets: same usage as ``soft_inputs``.

        Returns:
            Whatever ``self.loss`` returns for the gathered query/target features.
        """
        device = ref_img.device
        # is_initialized() does not exist on builds without distributed support,
        # so check availability first.
        distributed = (
            torch.distributed.is_available() and torch.distributed.is_initialized()
        )

        # Keep the learnable temperatures inside a sane range.
        with torch.no_grad():
            self.temp.clamp_(0.001, 0.5)
            self.hard_temp.clamp_(0.001, 1)

        if self.train_vit:
            ref_img_embs = self.visual_encoder(ref_img)
        else:
            with torch.no_grad():
                ref_img_embs = self.visual_encoder(ref_img)

        # Target features are precomputed; only normalize them.
        tar_img_feat = F.normalize(tar_feat.to(device), dim=-1)

        # Attend to every reference-image token.
        ref_img_atts = torch.ones(
            ref_img_embs.size()[:-1], dtype=torch.long, device=device
        )

        text = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(device)

        # Replace the first token id with the tokenizer's enc_token_id before
        # running the text encoder on top of the image embeddings.
        encoder_input_ids = text.input_ids.clone()
        encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
        query_embs = self.text_encoder(
            encoder_input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=ref_img_embs,
            encoder_attention_mask=ref_img_atts,
            return_dict=True,
        )
        query_feat = query_embs.last_hidden_state[:, 0, :]
        query_feat = F.normalize(self.text_proj(query_feat), dim=-1)

        ###============== Contrastive Learning ===================###
        if distributed:
            query_feat = all_gather_with_grad(query_feat)
            tar_img_feat = all_gather_without_grad(tar_img_feat)

        if self.hard_negatives:
            # Move to the compute device *before* gathering: query_feat was
            # already gathered on this device, so the backend is known to
            # support it, whereas the incoming id tensors may still live on CPU.
            soft_inputs = soft_inputs.to(device, non_blocking=True)
            soft_targets = soft_targets.to(device, non_blocking=True)

            if distributed:
                soft_inputs = all_gather_with_grad(soft_inputs)
                soft_targets = all_gather_with_grad(soft_targets)

            same_inputs = torch.eq(soft_inputs.unsqueeze(1), soft_inputs.unsqueeze(0))
            same_targets = torch.eq(
                soft_targets.unsqueeze(1), soft_targets.unsqueeze(0)
            )

            # Same input but different target -> hard negative.
            hard_negatives = torch.logical_and(same_inputs, ~same_targets).float()

            # Same input and same target (excluding the sample itself) ->
            # duplicate positive that must not be treated as a negative.
            easy_positives = torch.logical_and(same_inputs, same_targets).float()
            easy_positives = easy_positives * (
                1 - torch.eye(hard_negatives.size(0), device=device)
            )  # Zeros in the diagonal

            # Hard negatives get a +hard_temp bias, duplicate positives a huge
            # negative bias.
            hard_negatives = hard_negatives * self.hard_temp - easy_positives * 1e9

            if distributed:
                # The bias matrix is built from already-gathered ids, so no
                # further gather is needed here; only synchronize the workers.
                torch.distributed.barrier()

            return self.loss(query_feat, tar_img_feat, hard_negatives)

        if self.beta > 0:
            return self.loss(query_feat, tar_img_feat)

        return self.loss(query_feat, tar_img_feat, self.temp)


def blip_cir_embs(pretrained="", **kwargs):
    """Build a ``BLIPCirEmbs`` model, optionally loading a checkpoint.

    Args:
        pretrained: checkpoint location handed to ``load_checkpoint``; any
            falsy value (the default empty string) skips loading.
        **kwargs: forwarded unchanged to the ``BLIPCirEmbs`` constructor.

    Returns:
        The constructed model (as returned by ``load_checkpoint`` when a
        checkpoint was loaded).
    """
    model = BLIPCirEmbs(**kwargs)
    if not pretrained:
        return model

    model, load_msg = load_checkpoint(model, pretrained)
    # Report which parameters the checkpoint did not provide.
    print("missing keys:")
    print(load_msg.missing_keys)
    return model


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every worker and concatenate along dim 0.

    Runs under ``torch.no_grad()``: ``torch.distributed.all_gather`` does not
    propagate gradients, so the result is detached from the autograd graph.
    """
    world_size = torch.distributed.get_world_size()
    # One receive buffer per worker, shaped like the local tensor.
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)


class GatherLayer(torch.autograd.Function):
    """
    Differentiable all-gather.

    Forward collects ``x`` from every worker; backward routes gradients back
    to the worker that produced each slice, so the autograd graph is not cut
    the way a plain ``torch.distributed.all_gather`` would cut it.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a tuple holding one gathered tensor per worker."""
        world_size = torch.distributed.get_world_size()
        buffers = [torch.zeros_like(x) for _ in range(world_size)]
        torch.distributed.all_gather(buffers, x)
        return tuple(buffers)

    @staticmethod
    def backward(ctx, *grads):
        """Reduce the per-slice gradients across workers and keep this rank's slice."""
        stacked = torch.stack(grads)
        # all_reduce combines (sums, by default) every worker's gradient stack in place.
        torch.distributed.all_reduce(stacked)
        rank = torch.distributed.get_rank()
        return stacked[rank]


def all_gather_with_grad(tensors):
    """
    Gather ``tensors`` from all workers and concatenate along dim 0,
    keeping the autograd graph connected so gradients flow back.
    """
    # With a single process there is nothing to gather; hand the input back as-is.
    if torch.distributed.get_world_size() == 1:
        return tensors

    gathered = GatherLayer.apply(tensors)
    return torch.cat(gathered, dim=0)


def all_gather_without_grad(tensors):
    """
    Gather ``tensors`` from all workers and concatenate along dim 0.
    The gathered result is detached from the autograd graph.
    """
    # With a single process there is nothing to gather; hand the input back as-is.
    if torch.distributed.get_world_size() == 1:
        return tensors

    return concat_all_gather(tensors)
