# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import json
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from torch.cuda.amp import autocast
import torch.distributed as dist
from tqdm import tqdm
from torchvision.utils import save_image
import sys
import pdb
import wandb
import logging
import torch.nn.functional as F
from third_party.open_clip.clip import tokenize, _transform
from third_party.open_clip.simple_tokenizer import SimpleTokenizer
from utils import is_master
import random


def get_text_features(model, token_features, args):
    """Encode the fixed prompt "a photo of" together with ``token_features``.

    :param model: object exposing ``encode_text_img(text, token_features)``
    :param token_features: features handed to ``encode_text_img`` alongside the prompt
    :param args: unused; kept so all feature helpers share a signature
    :return: whatever ``model.encode_text_img`` returns for the prompt
    """
    prompt_tokens = tokenize("a photo of")
    return model.encode_text_img(prompt_tokens, token_features)

def get_intent_text_features(model, intent_token_features, args):
    """Encode a prompt that consists only of the ``<|replace|>`` placeholder.

    :param model: object exposing ``encode_rewrite_text(text, intent_token_features)``
    :param intent_token_features: features handed to ``encode_rewrite_text``
    :param args: unused; kept so all feature helpers share a signature
    :return: whatever ``model.encode_rewrite_text`` returns
    """
    placeholder_tokens = tokenize("<|replace|>")
    return model.encode_rewrite_text(placeholder_tokens, intent_token_features)

def get_text_features_global(model, token_features, args):
    """Encode the prompt "a photo of <|replace|>" together with ``token_features``.

    :param model: object exposing ``encode_text_img_replace(text, token_features)``
    :param token_features: features handed to ``encode_text_img_replace``
    :param args: unused; kept so all feature helpers share a signature
    :return: whatever ``model.encode_text_img_replace`` returns
    """
    prompt_tokens = tokenize("a photo of <|replace|>")
    return model.encode_text_img_replace(prompt_tokens, token_features)

def get_text_features_replace(model, token_features, intention_text, args):
    """Encode already-tokenized text together with ``token_features``.

    :param model: object exposing ``encode_text_img_replace(text, token_features)``
    :param token_features: features handed to ``encode_text_img_replace``
    :param intention_text: tokenized text; moved to ``args.gpu`` before encoding
    :param args: namespace providing ``gpu``
    :return: whatever ``model.encode_text_img_replace`` returns
    """
    text_on_device = intention_text.cuda(args.gpu, non_blocking=True)
    return model.encode_text_img_replace(text_on_device, token_features)

def get_rewrited_text_features(model, token_features, intent_token_features, args):
    """Encode the prompt "a photo of <|replace|> , " with both feature sets.

    :param model: object exposing
        ``encode_text_img_rewrite(text, token_features, intent_token_features)``
    :param token_features: first feature argument forwarded to the model
    :param intent_token_features: second feature argument forwarded to the model
    :param args: unused; kept so all feature helpers share a signature
    :return: whatever ``model.encode_text_img_rewrite`` returns
    """
    prompt_tokens = tokenize("a photo of <|replace|> , ")
    return model.encode_text_img_rewrite(
        prompt_tokens, token_features, intent_token_features
    )

def get_loss_img2text(model, img2text, images, loss_img, loss_txt, args, memory=None):
    """Compute the symmetric image/text contrastive loss for one batch.

    Images are encoded without gradient tracking, mapped through ``img2text``
    and combined with the fixed prompt via ``get_text_features``.  The actual
    loss arithmetic is delegated to ``calu_loss_muti`` (cross-GPU negatives)
    or ``calu_loss`` (local batch only) so the logic lives in one place
    instead of being duplicated here.

    :param model: CLIP-style model exposing ``encode_image`` and ``logit_scale``
    :param img2text: callable mapping image features to token features
    :param images: batch of images passed to ``model.encode_image``
    :param loss_img: loss callable applied to image->text logits
    :param loss_txt: loss callable applied to text->image logits
    :param args: namespace providing ``distributed``, ``aggregate`` and ``gpu``
    :param memory: unused; accepted for backward compatibility with callers
    :return: scalar tensor, the mean of the image and text losses
    """
    # Image encoding is excluded from autograd; gradients reach img2text only.
    with torch.no_grad():
        image_features = model.encode_image(images)
    token_features = img2text(image_features)
    text_features = get_text_features(model, token_features, args)
    logit_scale = model.logit_scale.exp().mean()

    if args.distributed and args.aggregate:
        return calu_loss_muti(
            loss_img,
            loss_txt,
            args,
            dist.get_world_size(),
            dist.get_rank(),
            logit_scale,
            image_features,
            text_features,
        )
    return calu_loss(loss_img, loss_txt, args, logit_scale, image_features, text_features)

def calu_loss_muti(loss_img, loss_txt, args, world_size, rank, logit_scale, image_features, text_features):
    """Symmetric contrastive loss using features gathered from every rank.

    Requires an initialised ``torch.distributed`` process group.

    :param loss_img: loss callable applied to image->text logits
    :param loss_txt: loss callable applied to text->image logits
    :param args: namespace providing ``gpu`` (device for the target indices)
    :param world_size: number of participating processes
    :param rank: index of the current process
    :param logit_scale: multiplier applied to the similarity logits
    :param image_features: local image features
    :param text_features: local text features
    :return: mean of the image-side and text-side losses
    """

    def _gather_local_first(local):
        # Collect every rank's tensor, then rebuild the batch with this rank's
        # own (original) tensor in front and the other ranks' copies behind it.
        # NOTE(review): the original comment says this is what lets gradients
        # flow back; presumably because gathered copies carry no autograd
        # history, confirm if this matters for a change.
        buffers = [torch.zeros_like(local) for _ in range(world_size)]
        dist.all_gather(buffers, local)
        return torch.cat([local] + buffers[:rank] + buffers[rank + 1:])

    all_image_features = _gather_local_first(image_features)
    all_text_features = _gather_local_first(text_features)

    # Row k of the image features is the positive for row k of the text features.
    targets = torch.arange(all_image_features.shape[0], dtype=torch.long)
    if args.gpu is not None:
        targets = targets.cuda(args.gpu, non_blocking=True)

    image_to_text = (logit_scale * all_image_features) @ all_text_features.t()
    text_to_image = image_to_text.t()

    image_side = loss_img(image_to_text, targets)
    text_side = loss_txt(text_to_image, targets)
    return (image_side + text_side) / 2

def calu_loss(loss_img, loss_txt, args, logit_scale, image_features, text_features):
    """Symmetric contrastive loss over the local batch only (no gathering).

    :param loss_img: loss callable applied to image->text logits
    :param loss_txt: loss callable applied to text->image logits
    :param args: namespace providing ``gpu`` (device for the target indices)
    :param logit_scale: multiplier applied to the similarity logits
    :param image_features: image features, one row per sample
    :param text_features: text features, one row per sample
    :return: mean of the image-side and text-side losses
    """
    # Row k of the image features is the positive for row k of the text features.
    targets = torch.arange(image_features.shape[0], dtype=torch.long)
    if args.gpu is not None:
        targets = targets.cuda(args.gpu, non_blocking=True)

    image_to_text = (logit_scale * image_features) @ text_features.t()
    text_to_image = (logit_scale * text_features) @ image_features.t()

    image_side = loss_img(image_to_text, targets)
    text_side = loss_txt(text_to_image, targets)
    return (image_side + text_side) / 2

def calu_intent_loss_muti(args, v1: torch.Tensor, v2: torch.Tensor, temperature: float, world_size: int, rank: int) -> torch.Tensor:
    """Contrastive loss between ``v1`` and ``v2`` using rows from every rank.

    Each rank's tensors are all-gathered, the batch is rebuilt with the local
    (original) tensors first, and the result is scored by
    ``calu_intent_loss``.  Delegating keeps the distributed and local variants
    numerically identical and places the diagonal mask on the features' own
    device rather than on ``args.gpu``.

    Requires an initialised ``torch.distributed`` process group.

    :param args: unused; accepted for backward compatibility with callers
    :param v1: local first set of vectors, one row per sample
    :param v2: local second set of vectors, row-aligned with ``v1``
    :param temperature: divisor applied to similarities before exponentiation
    :param world_size: number of participating processes
    :param rank: index of the current process
    :return: scalar loss tensor
    """
    # We gather tensors from all gpus to get more negatives to contrast with.
    gathered_v1 = [torch.zeros_like(v1) for _ in range(world_size)]
    gathered_v2 = [torch.zeros_like(v2) for _ in range(world_size)]
    dist.all_gather(gathered_v1, v1)
    dist.all_gather(gathered_v2, v2)

    # The local tensors go first in place of their gathered copies; the other
    # ranks' rows only serve as additional negatives.
    all_v1 = torch.cat([v1] + gathered_v1[:rank] + gathered_v1[rank + 1:])
    all_v2 = torch.cat([v2] + gathered_v2[:rank] + gathered_v2[rank + 1:])

    return calu_intent_loss(all_v1, all_v2, temperature)


def calu_intent_loss(v1: torch.Tensor, v2: torch.Tensor, temperature: float) -> torch.Tensor:
    """Contrastive loss pulling row ``k`` of ``v1`` towards row ``k`` of ``v2``.

    Both inputs are L2-normalised along ``dim=1``.  For each of the ``2N``
    stacked rows the positive term is ``exp(sim(v1_k, v2_k) / temperature)``
    and the denominator sums ``exp(sim / temperature)`` over all other
    stacked rows.

    Based on https://github.com/NVlabs/PALAVRA/blob/main/utils/nv.py

    :param v1: first set of vectors, one row per sample
    :param v2: second set of vectors, row-aligned with ``v1``
    :param temperature: divisor applied to similarities before exponentiation
    :return: scalar loss tensor
    """
    target_device = v1.device
    unit_a = F.normalize(v1, dim=1)
    unit_b = F.normalize(v2, dim=1)

    # Positive-pair term, repeated once for each half of the stacked batch.
    positive = torch.exp(torch.diag(torch.inner(unit_a, unit_b)) / temperature)
    positive = positive.repeat(2)

    stacked = torch.cat((unit_a, unit_b), 0)
    similarity = torch.exp(torch.mm(stacked, stacked.t()) / temperature)

    # Remove each row's similarity with itself before summing.
    identity = torch.eye(stacked.shape[0]).to(target_device, non_blocking=True)
    without_self = similarity - similarity * identity
    normaliser = torch.sum(without_self, 0)

    return -torch.mean(torch.log(positive / normaliser))

def _intent_analyser(model, intent_analyser, text, token_features, args):
    '''
    Run the intent analyser on token-level features of the given text.
    :param model: original CLIP model; must expose ``encode_text_img_rewrite_token``
    :param intent_analyser: callable returning ``(features, gate)`` for the token-level features
    :param text: tokenized intent / summary text; moved to ``args.gpu`` first
    :param token_features: features forwarded to ``encode_text_img_rewrite_token``
    :param args: namespace providing ``gpu``
    :return: the ``(intention_text_features, intent_gate)`` pair from ``intent_analyser``
    '''
    text_on_device = text.cuda(args.gpu, non_blocking=True)
    # Per the original note, token-level features were observed with shape
    # [256, 77, 768] in one run; the shape is not enforced here.
    token_level_features = model.encode_text_img_rewrite_token(text_on_device, token_features)
    intention_text_features, intent_gate = intent_analyser(token_level_features)
    return intention_text_features, intent_gate

def get_intent_texts_with_blank(texts_with_blank, rewrited_texts_with_blank, intention_texts_with_blank):
    """Randomly pick, per sample, which of three text variants to use.

    For every index one ``random.random()`` draw decides the source:

    * ``< 0.20``  -> ``intention_texts_with_blank``
    * ``< 0.50``  -> ``rewrited_texts_with_blank``
    * otherwise   -> ``texts_with_blank``

    The selection is written into a copy, so none of the inputs is modified
    (previously the result aliased ``rewrited_texts_with_blank`` and
    overwrote the caller's data in place).

    :param texts_with_blank: indexable batch of the first text variant
    :param rewrited_texts_with_blank: indexable batch of rewritten texts;
        a tensor, or any container offering ``.copy()`` (list, ndarray)
    :param intention_texts_with_blank: indexable batch of intention texts
    :return: a new container of the same kind as ``rewrited_texts_with_blank``
        holding the selected entries
    """
    # Start from a copy of the rewritten texts: that is the value every entry
    # keeps when the draw lands in the [0.20, 0.50) band.
    if torch.is_tensor(rewrited_texts_with_blank):
        intent_texts_with_blank = rewrited_texts_with_blank.clone()
    else:
        intent_texts_with_blank = rewrited_texts_with_blank.copy()

    for idx in range(len(intent_texts_with_blank)):
        ran_num = random.random()
        if ran_num < 0.20:
            # use the intention text
            intent_texts_with_blank[idx] = intention_texts_with_blank[idx]
        elif ran_num >= 0.50:
            # use the entry from texts_with_blank
            intent_texts_with_blank[idx] = texts_with_blank[idx]
        # otherwise: keep the rewritten text already present in the copy
    return intent_texts_with_blank


def get_loss_img2text_rewrite(model, img2text, intent_analyser, images, texts, texts_with_blank, blank_without_text, rewirted_texts, rewrited_texts_with_blank, intention_texts, intention_texts_with_blank, loss_img, loss_txt, args, memory=None):
    """Compute the alignment loss and the intent (distillation) loss for a batch.

    Only ``images``, the three ``*_with_blank`` text batches, the loss
    callables and ``args`` are read; ``texts``, ``blank_without_text``,
    ``rewirted_texts``, ``intention_texts`` and ``memory`` are accepted but
    unused.

    :param model: CLIP-style model exposing ``encode_image`` and ``logit_scale``
    :param img2text: callable mapping image features to token features
    :param intent_analyser: callable used by ``_intent_analyser``
    :param images: batch of images passed to ``model.encode_image``
    :param loss_img: loss callable applied to image->text logits
    :param loss_txt: loss callable applied to text->image logits
    :param args: namespace providing ``distributed``, ``aggregate``,
        ``temperature`` and ``gpu``
    :return: ``(align_loss, distill_loss)``
    """
    # Image encoding is excluded from autograd; gradients reach img2text only.
    with torch.no_grad():
        image_features = model.encode_image(images)
    pseudo_tokens = img2text(image_features)

    # Augmentation: per sample, randomly choose one of the three text variants.
    sampled_texts = get_intent_texts_with_blank(
        texts_with_blank, rewrited_texts_with_blank, intention_texts_with_blank
    )

    # Pseudo intent tokens plus a gate; per the original note the tokens were
    # observed as torch.Size([256, 4, 768]) in one run.
    intent_tokens, gate = _intent_analyser(
        model, intent_analyser, sampled_texts, pseudo_tokens, args
    )

    sampled_text_features = get_text_features_replace(model, pseudo_tokens, sampled_texts, args)
    intent_branch_features = get_intent_text_features(model, intent_tokens, args)

    # Dense fusion: add the intent branch, scaled by tanh of the gate.
    fused_text_features = sampled_text_features + (intent_branch_features * gate.tanh())
    # Features of the intention texts; the intent branch is compared to these.
    intention_reference = get_text_features_replace(
        model, pseudo_tokens, intention_texts_with_blank, args
    )

    logit_scale = model.logit_scale.exp().mean()

    if not (args.distributed and args.aggregate):
        # Single-process path: contrast within the local batch only.
        local_align = calu_loss(
            loss_img, loss_txt, args, logit_scale, image_features, fused_text_features
        )
        local_distill = calu_intent_loss(
            intention_reference, intent_branch_features, args.temperature
        )
        return local_align, local_distill

    # Distributed path: contrast against features gathered from every rank.
    n_procs = dist.get_world_size()
    this_rank = dist.get_rank()
    gathered_align = calu_loss_muti(
        loss_img, loss_txt, args, n_procs, this_rank, logit_scale,
        image_features, fused_text_features,
    )
    gathered_distill = calu_intent_loss_muti(
        args, intention_reference, intent_branch_features, args.temperature,
        n_procs, this_rank,
    )
    return gathered_align, gathered_distill

def train(model, img2text, intent_analyser, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
    """Run one training epoch over ``data['train']``.

    ``model`` is switched to eval mode; the loss is produced by
    ``get_loss_img2text_rewrite`` and back-propagated through ``optimizer``.
    Both precision modes optimise the same objective
    (``align_loss + distill_loss``); AMP only changes how the forward pass
    and the gradient scaling are performed.  On the master process, metrics
    are logged every second iteration via ``logging`` and, when enabled,
    TensorBoard and Weights & Biases.

    :param model: CLIP-style model (possibly wrapped; ``.module`` is used when
        ``args.distributed`` or ``args.dp`` is set)
    :param img2text: callable mapping image features to token features
    :param intent_analyser: intent analyser passed through to the loss
    :param data: mapping whose ``'train'`` entry exposes ``dataloader`` and ``sampler``
    :param epoch: zero-based epoch index; also exported as ``WDS_EPOCH``
    :param optimizer: optimizer stepped once per batch
    :param scaler: gradient scaler, used only when ``args.precision == "amp"``
    :param scheduler: callable invoked with the global step before each batch
    :param args: run configuration namespace
    :param tb_writer: optional TensorBoard writer
    :return: None
    """
    os.environ["WDS_EPOCH"] = str(epoch)
    model.eval()
    dataloader, sampler = data['train'].dataloader, data['train'].sampler
    loss_img = nn.CrossEntropyLoss()
    loss_txt = nn.CrossEntropyLoss()

    if args.gpu is not None:
        loss_img = loss_img.cuda(args.gpu)
        loss_txt = loss_txt.cuda(args.gpu)

    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)

    num_batches_per_epoch = dataloader.num_batches
    use_amp = args.precision == "amp"

    end = time.time()
    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i
        scheduler(step)

        optimizer.zero_grad()

        (
            images,
            texts,
            texts_with_blank,
            blank_without_text,
            rewirted_texts,
            rewrited_texts_with_blank,
            intention_texts,
            intention_texts_with_blank,
        ) = (
            batch[0], batch[1], batch[2], batch[3],
            batch[4], batch[5], batch[6], batch[7],
        )

        # NOTE(review): batches carry eight fields here, so the three-field
        # debiased-sampler case cannot trigger; kept for compatibility.
        if len(batch) == 3 and args.use_debiased_sampler:
            data_identifier = torch.unique(batch[2])[0].numpy()
        else:
            data_identifier = -1
        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)

        data_time = time.time() - end

        m = model.module if args.distributed or args.dp else model

        loss_inputs = (
            m, img2text, intent_analyser,
            images, texts, texts_with_blank, blank_without_text,
            rewirted_texts, rewrited_texts_with_blank,
            intention_texts, intention_texts_with_blank,
            loss_img, loss_txt, args, data_identifier,
        )

        if use_amp:
            # Automatic mixed precision: only the forward pass and the loss
            # run under autocast; backward and the optimizer step happen
            # outside of it.
            with autocast():
                align_loss, distill_loss = get_loss_img2text_rewrite(*loss_inputs)
                total_loss = align_loss + distill_loss
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Full precision uses the same objective, so align_loss and
            # distill_loss are always defined for the logging below.
            align_loss, distill_loss = get_loss_img2text_rewrite(*loss_inputs)
            total_loss = align_loss + distill_loss
            total_loss.backward()
            optimizer.step()

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        #m.logit_scale.data = torch.clamp(m.logit_scale.data, 0, 4.6052)

        batch_time = time.time() - end
        end = time.time()

        if is_master(args) and (i % 2) == 0:
            num_samples = i * len(images) * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * i / num_batches_per_epoch
            logging.info(
                f"Train Epoch: {epoch} [{num_samples}/{samples_per_epoch} ({percent_complete:.0f}%)]\t"
                f"Total Loss: {total_loss.item():.7f}\tAlign Loss: {align_loss.item():.7f}\tDistill Loss: {distill_loss.item():.7f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}"
                f"\tLR: {optimizer.param_groups[0]['lr']:5f}\tlogit_scale {m.logit_scale.data:.3f}"
            )

            # save train loss / etc.
            timestep = epoch * num_batches_per_epoch + i
            log_data = {
                "total_loss": total_loss.item(),
                "align_loss": align_loss.item(),
                "Distill_loss": distill_loss.item(),
                "data_time": data_time,
                "batch_time": batch_time,
                "scale":  m.logit_scale.data.item(),
                "lr": optimizer.param_groups[0]["lr"]
            }

            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, timestep)
                if args.wandb:
                    wandb.log({name: val, 'step': timestep})