"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import warnings
from copy import deepcopy

import torch
import torch.nn.functional as F
from lavis.common.registry import registry
from lavis.models.albef_models import AlbefBase
from lavis.models.albef_models.albef_outputs import (
    AlbefIntermediateOutput,
    AlbefOutputWithLogits,
)
from lavis.models.base_model import MomentumDistilationMixin
from lavis.models.albef_models.albef_retrieval_cons_unlearn import AlbefRetrievalConsUnlearn
from lavis.models.albef_models.albef_classification import AlbefClassification
from lavis.models.med import XBertEncoder
from lavis.models.vit import VisionTransformerEncoder
from torch import nn


@registry.register_model("albef_classification_cons_unlearn")
class AlbefClassificationConsUnlearn(AlbefClassification):
    """ALBEF classifier extended with contrastive projection heads and queues.

    On top of the parent classifier this class adds:

    * linear projection heads (``vision_proj`` / ``text_proj``) and their
      momentum copies, mapping encoder features to ``embed_dim``;
    * fixed-size buffers (``image_queue``, ``text_queue``, ``idx_queue``,
      ``queue_ptr``) used as extra columns when building similarity matrices.

    NOTE(review): nothing in this class writes to the queues after
    construction, so they keep their initial random (normalised) content and
    ``idx_queue`` stays at ``-100`` unless updated elsewhere.
    """

    def __init__(
            self,
            image_encoder,
            text_encoder,
            num_classes,
            momentum=0.995,
            alpha=0.4,
            use_distill=True,
            max_txt_len=40,
            embed_dim=256,
            queue_size=65536,
    ):
        """Build the classifier plus the contrastive projection heads/queues.

        Args:
            image_encoder: vision backbone; must expose ``vision_width``.
            text_encoder: text/fusion backbone; must expose
                ``config.hidden_size``.
            num_classes: forwarded to the parent constructor.
            momentum: forwarded to the parent constructor.
            alpha: forwarded to the parent constructor.
            use_distill: forwarded to the parent constructor.
            max_txt_len: forwarded to the parent constructor.
            embed_dim: output width of the projection heads and number of
                rows of the feature queues.
            queue_size: number of columns of every queue buffer.
        """
        super().__init__(image_encoder,
                         text_encoder,
                         num_classes,
                         momentum,
                         alpha,
                         use_distill,
                         max_txt_len,)

        text_width = text_encoder.config.hidden_size
        vision_width = image_encoder.vision_width

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.vision_proj_m = deepcopy(self.vision_proj)
        self.text_proj_m = deepcopy(self.text_proj)

        # The parent class is not visible here and may only create
        # `model_pairs` when distillation is enabled; make sure the list
        # exists before extending it so construction never fails on it.
        if not hasattr(self, "model_pairs"):
            self.model_pairs = []
        self.model_pairs.extend(
            [
                [self.vision_proj, self.vision_proj_m],
                [self.text_proj, self.text_proj_m],
            ]
        )
        self.copy_params()

        # create the queue
        self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
        self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
        self.register_buffer("idx_queue", torch.full((1, queue_size), -100))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        # Unit-normalise every queue column (dim 0 is the embedding axis).
        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)

    def _contrastive_targets(self, image_ids):
        """Return row-normalised positive masks for the batch against batch + queue.

        Args:
            image_ids: tensor of sample ids; reshaped to a column here.

        Returns:
            Float tensor with one row per sample and one column per
            (batch sample + queue slot); entries with a matching id share the
            row's probability mass equally.
        """
        idx = image_ids.view(-1, 1)
        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)
        pos_idx = torch.eq(idx, idx_all).float()
        # Every sample matches at least itself, so the row sum is never zero.
        return pos_idx / pos_idx.sum(1, keepdim=True)

    def forward(self, samples, is_train=True):
        """Run the classifier.

        Args:
            samples: dict with ``"text_input"``, ``"image"`` and ``"label"``.
                When training with distillation it must also contain
                ``"image_id"``, ``"epoch"``, ``"iters"`` and
                ``"num_iters_per_epoch"``. The tokenised text is written back
                under ``"tokenized_text"``.
            is_train: when true return a loss-bearing output object,
                otherwise return a dict of predictions and features.

        Returns:
            ``AlbefOutputWithLogits`` in training mode; in evaluation mode a
            dict with ``"predictions"``, ``"targets"``, ``"image_embeds"``
            and ``"encoder_output"``.
        """
        tokenized = self.tokenizer(
            samples["text_input"],
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)
        samples.update({"tokenized_text": tokenized})

        targets = samples["label"]
        image = samples["image"]

        image_embeds = self.visual_encoder.forward_features(image)
        encoder_output = self.text_encoder.forward_automask(
            samples["tokenized_text"], image_embeds
        )

        # Position 0 of the fused sequence feeds the classifier
        # (presumably the [CLS] token -- confirm against the encoder).
        cls_state = encoder_output.last_hidden_state[:, 0, :]
        prediction = self.cls_head(cls_state)

        if not is_train:
            return {"predictions": prediction, "targets": targets,
                    "image_embeds": image_embeds,
                    'encoder_output': cls_state}

        if self.use_distill:
            # Contrastive features and targets are only needed in this
            # branch, so evaluation / non-distilled training do not require
            # an "image_id" entry in `samples`.
            image_feat = F.normalize(
                self.vision_proj(image_embeds[:, 0, :]), dim=-1
            )
            # Index the hidden states, not the output container itself, to
            # stay consistent with how `prediction` is computed above.
            text_feat = F.normalize(self.text_proj(cls_state), dim=-1)
            sim_targets = self._contrastive_targets(samples["image_id"])

            with torch.no_grad():
                self._momentum_update()

                image_embeds_m = self.visual_encoder_m(image)
                image_feat_m = F.normalize(
                    self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1
                )
                image_feat_all = torch.cat(
                    [image_feat_m.t(), self.image_queue.clone().detach()], dim=1
                )

                encoder_output_m = self.text_encoder_m.forward_automask(
                    samples["tokenized_text"], image_embeds_m
                )
                cls_state_m = encoder_output_m.last_hidden_state[:, 0, :]
                text_feat_m = F.normalize(self.text_proj_m(cls_state_m), dim=-1)
                text_feat_all = torch.cat(
                    [text_feat_m.t(), self.text_queue.clone().detach()], dim=1
                )

                prediction_m = self.cls_head_m(cls_state_m)

            alpha = self.alpha * self._rampup_factor(
                epoch=samples["epoch"],
                iters=samples["iters"],
                num_iters_per_epoch=samples["num_iters_per_epoch"],
            )

            # NOTE(review): `self.temp` is not assigned anywhere in this
            # class; it has to be provided by the parent -- confirm.
            sim_i2t_m = image_feat_m @ text_feat_all / self.temp
            sim_t2i_m = text_feat_m @ image_feat_all / self.temp
            sim_i2t_targets = (
                alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
            )
            sim_t2i_targets = (
                alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
            )

            sim_i2t = image_feat @ text_feat_all / self.temp
            sim_t2i = text_feat @ image_feat_all / self.temp

            loss_i2t = -torch.sum(
                F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1
            ).mean()
            loss_t2i = -torch.sum(
                F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1
            ).mean()

            # NOTE(review): `loss_itc` is computed but is not part of the
            # returned loss, so the projection heads receive no gradient from
            # this forward pass. Kept as-is to avoid silently changing the
            # training objective -- confirm whether it should be added.
            loss_itc = (loss_i2t + loss_t2i) / 2  # noqa: F841

            # Hard-label cross entropy blended with soft targets from the
            # momentum classifier.
            loss = (1 - alpha) * F.cross_entropy(
                prediction, targets
            ) - alpha * torch.sum(
                F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1),
                dim=1,
            ).mean()
        else:
            loss = F.cross_entropy(prediction, targets)

            image_embeds_m, encoder_output_m, prediction_m = None, None, None

        return AlbefOutputWithLogits(
            loss=loss,
            intermediate_output=AlbefIntermediateOutput(
                image_embeds=image_embeds,
                image_embeds_m=image_embeds_m,
                encoder_output=encoder_output,
                encoder_output_m=encoder_output_m,
            ),
            logits=prediction,
            logits_m=prediction_m,
        )


# class AlbefClassificationConsUnlearn(AlbefRetrievalConsUnlearn):
#     PRETRAINED_MODEL_CONFIG_DICT = {
#         "ve": "configs/models/albef_classification_ve.yaml",
#     }
    # def __init__(self,
    #              image_encoder,
    #              text_encoder,
    #              queue_size,
    #              num_classes,
    #              embed_dim=256,
    #              temp=0.07,
    #              use_distill=True,
    #              momentum=0.995,
    #              alpha=0.4,
    #              max_txt_len=30,
    #              ):
    #     super().__init__(image_encoder,
    #                      text_encoder,
    #                      queue_size,
    #                      embed_dim,
    #                      temp,
    #                      use_distill,
    #                      momentum,
    #                      alpha,
    #                      max_txt_len,)
    #
    #     hidden_size = text_encoder.config.hidden_size
    #
    #     if num_classes > 0:
    #         self.cls_head = nn.Sequential(
    #             nn.Linear(hidden_size, hidden_size),
    #             nn.ReLU(),
    #             nn.Linear(hidden_size, num_classes),
    #         )
    #     else:
    #         warnings.warn(
    #             f"Found num_classes=0, initializing {type(self)} without classifier."
    #         )
    #
    #     if self.use_distill:
    #         self.cls_head_m = deepcopy(self.cls_head)
    #
    #         self.model_pairs.append([self.cls_head, self.cls_head_m])
    #
    #         self.copy_params()
    #
    # def forward(self, samples, is_train=True):
    #     cons_output = super().forward(samples)
    #     loss_cons = cons_output.loss
    #
    #     sentences = samples["text_input"]
    #     sentences = self.tokenizer(
    #         sentences,
    #         padding="longest",
    #         truncation=True,
    #         max_length=self.max_txt_len,
    #         return_tensors="pt",
    #     ).to(self.device)
    #     samples.update({"tokenized_text": sentences})
    #
    #     targets = samples["label"]
    #
    #     image_embeds = cons_output.intermediate_output.image_embeds
    #     encoder_output = self.text_encoder.forward_automask(
    #         samples["tokenized_text"], image_embeds
    #     )
    #
    #     prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])
    #
    #     if is_train:
    #         if self.use_distill:
    #             with torch.no_grad():
    #                 self._momentum_update()
    #
    #                 image_embeds_m = cons_output.intermediate_output.image_embeds_m
    #                 encoder_output_m = self.text_encoder_m.forward_automask(
    #                     samples["tokenized_text"], image_embeds_m
    #                 )
    #
    #                 prediction_m = self.cls_head_m(
    #                     encoder_output_m.last_hidden_state[:, 0, :]
    #                 )
    #
    #             alpha = self.alpha * self._rampup_factor(
    #                 epoch=samples["epoch"],
    #                 iters=samples["iters"],
    #                 num_iters_per_epoch=samples["num_iters_per_epoch"],
    #             )
    #
    #             loss_cls = (1 - alpha) * F.cross_entropy(
    #                 prediction, targets
    #             ) - alpha * torch.sum(
    #                 F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1),
    #                 dim=1,
    #             ).mean()
    #         else:
    #             loss_cls = F.cross_entropy(prediction, targets)
    #
    #             image_embeds_m, encoder_output_m, prediction_m = None, None, None
    #
    #         # return {"loss": loss}
    #         return AlbefOutputWithLogits(
    #             loss=loss_cons + loss_cls,
    #             intermediate_output=AlbefIntermediateOutput(
    #                 image_embeds=image_embeds,
    #                 image_embeds_m=image_embeds_m,
    #                 encoder_output=encoder_output,
    #                 encoder_output_m=encoder_output_m,
    #             ),
    #             logits=prediction,
    #             logits_m=prediction_m,
    #         )
    #     else:
    #         return {"predictions": prediction, "targets": targets,
    #         "image_embeds": image_embeds,
    #         'encoder_output': encoder_output.last_hidden_state[:, 0, :]}
    #
    # @classmethod
    # def from_config(cls, cfg=None):
    #     image_encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=False)
    #     text_encoder = XBertEncoder.from_config(cfg)
    #
    #     embed_dim = cfg.get("embed_dim", 256)
    #     momentum = cfg.get("momentum", 0.995)
    #     alpha = cfg.get("alpha", 0.4)
    #     temp = cfg.get("temp", 0.07)
    #     max_txt_len = cfg.get("max_txt_len", 30)
    #     queue_size = cfg.get("queue_size", 0)
    #     use_distill = cfg.get("use_distill", True)
    #
    #     num_classes = cfg.get("num_classes", -1)
    #
    #     assert num_classes > 1, "Invalid number of classes provided, found {}".format(
    #         num_classes
    #     )
    #
    #     model = cls(
    #         image_encoder=image_encoder,
    #         text_encoder=text_encoder,
    #         queue_size=queue_size,
    #         embed_dim=embed_dim,
    #         temp=temp,
    #         momentum=momentum,
    #         alpha=alpha,
    #         max_txt_len=max_txt_len,
    #         use_distill=use_distill,
    #         num_classes=num_classes,
    #     )
    #
    #     load_pretrained = cfg.get("load_pretrained", True)
    #     load_finetuned = cfg.get("load_finetuned", False)
    #     if load_pretrained or load_finetuned:
    #         model.load_checkpoint_from_config(cfg)
    #
    #     return model
