import datetime
import json
import logging
import os
import time
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn.functional as F
import webdataset as wds
from lavis.common.dist_utils import (
    download_cached_file,
    get_rank,
    get_world_size,
    is_main_process,
    main_process,
)
from lavis.common.registry import registry
from lavis.common.utils import is_url
from lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split, prepare_sample
from lavis.datasets.datasets.dataloader_utils import (
    IterLoader,
    MultiIterLoader,
    PrefetchLoader,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data.dataset import ChainDataset
from lavis.runners.runner_base import RunnerBase
from lavis.common.logger import MetricLogger, SmoothedValue

import torchattacks
from torch import nn
from torchvision.utils import save_image


@registry.register_runner("runner_attack")
class RunnerAttack(RunnerBase):
    """Runner that crafts adversarial images over the ``test`` dataloader.

    ``train`` does not optimise the model: it runs a ``MultimodalPGD`` attack
    on every batch of the ``test`` split and writes the collected tensors and
    annotations to ``./output``.
    """

    def __init__(self, cfg, task, model, datasets, job_id):
        super().__init__(cfg, task, model, datasets, job_id)

    def train(self):
        """Attack every test batch and dump the results to ``./output``.

        Returns:
            dict mapping each meter name to its global average formatted
            with three decimals.

        Raises:
            ValueError: if no ``test`` dataloader is configured.
        """
        # TODO: all implements based on test dataloader which samples in order.
        self.model.eval()
        atk = MultimodalPGD(self.model, steps=30)

        data_loader = self.dataloaders.get('test', None)
        if data_loader is None:
            raise ValueError("RunnerAttack requires a 'test' dataloader.")
        image_names = data_loader.loader.dataset.image
        texts = data_loader.dataset.text

        # Create the output directory up front so a missing folder fails
        # immediately instead of after the whole attack has been computed.
        os.makedirs("./output", exist_ok=True)

        metric_logger = MetricLogger(delimiter="  ")
        metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
        log_freq = 1
        header = "Attack: "

        # Map image file names to their annotation ids.
        with open("./data/flickr30k/annotations/train.json", "r") as f:
            anns = json.load(f)
        img2id = {ann['image']: ann['image_id'] for ann in anns}

        with torch.no_grad():
            sims, text_ids, text_atts = self.get_sims_matrix(self.model, data_loader)

        adv_embeds = []
        adv_anns = []
        adv_logits = []
        realign_logits = []
        adv_scores = []
        # Number of images seen before the current batch. A running offset
        # stays correct when the final batch is smaller than the others.
        seen = 0
        for i, samples in enumerate(metric_logger.log_every(data_loader, log_freq, header)):

            samples = prepare_sample(samples, cuda_enabled=True)  # TODO: make this configable
            samples.update(
                {
                    "epoch": 0,
                    "num_iters_per_epoch": len(data_loader),
                    "iters": i,
                }
            )

            (
                adv_images,
                costs,
                logits,
                adv_image_ebds,
                adv_match_dict,
                adv_df_logits,
                adv_score_logits,
            ) = atk(samples, sims=sims, texts=texts, text_ids=text_ids, text_atts=text_atts)

            metric_logger.update(loss=costs[-1].item())

            adv_embeds.append(adv_image_ebds.cpu())
            adv_logits.append(adv_df_logits.cpu())
            realign_logits.append(logits[-1][:, :5].detach().cpu())
            adv_scores.append(adv_score_logits.cpu())
            # One annotation per (image, adversarially matched caption) pair.
            for k, v in adv_match_dict.items():
                image_name = image_names[k]
                for l in v:
                    adv_anns.append({"image": image_name, "caption": texts[l], "image_id": img2id[image_name]})

            print(adv_match_dict)
            batch_size = samples["image"].shape[0]
            for j in range(len(adv_images)):
                idx = seen + j
                # logits is indexed as [adv_step, img_idx, text_idx]
                print(f"{idx}: {image_names[idx]}: {logits[0][j][:5]} -> {logits[-1][j][:5]}")
            seen += batch_size

        adv_embeds = torch.cat(adv_embeds)
        adv_logits = torch.cat(adv_logits)
        realign_logits = torch.cat(realign_logits)
        adv_scores = torch.cat(adv_scores)
        torch.save(adv_embeds, "./output/dynamic_adv_img_embeds_sim-sort_k128_rank0-4.pt")
        torch.save(adv_logits, "./output/dynamic_adv_logits_sim-sort_k128_rank0-4.pt")
        torch.save(realign_logits, "./output/dynamic_realign_logits_sim-sort_k128_rank0-4.pt")
        torch.save(adv_scores, "./output/dynamic_adv_scores_sim-sort_k128_rank0-4.pt")
        # Use a context manager so the JSON file is flushed and closed.
        with open("./output/dynamic_adv_anns_sim-sort_k128_rank0-4.json", "w") as f:
            json.dump(adv_anns, f)

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        logging.info("Averaged stats: " + str(metric_logger.global_avg()))
        return {
            k: "{:.3f}".format(meter.global_avg)
            for k, meter in metric_logger.meters.items()
        }

    @torch.no_grad()
    def get_sims_matrix(self, model, data_loader):
        """Compute the image-to-text similarity matrix for a dataloader.

        All captions in ``data_loader.dataset.text`` are tokenised and
        embedded in chunks of 128, every image yielded by ``data_loader`` is
        embedded, and the two sets of normalised embeddings are multiplied.
        Runs without autograd since the result is only used as a lookup.

        Args:
            model: model exposing ``tokenizer``, ``text_encoder``,
                ``text_proj``, ``visual_encoder``, ``vision_proj`` and
                ``device``.
            data_loader: loader yielding dicts with an ``"image"`` entry and
                exposing ``dataset.text``.

        Returns:
            Tuple ``(sims_matrix, text_ids, text_atts)``: the
            ``image_embeds @ text_embeds.t()`` product, the concatenated
            token ids and the concatenated attention masks.
        """
        texts = data_loader.dataset.text
        num_text = len(texts)
        text_bs = 128
        text_ids = []
        text_embeds = []
        text_atts = []
        for start in range(0, num_text, text_bs):
            # Progress message every fifth chunk.
            if (start // text_bs) % 5 == 0:
                print(f"text_ebd: {start}/{num_text}")
            text_input = model.tokenizer(
                texts[start:start + text_bs],
                padding="max_length",
                truncation=True,
                max_length=35,
                return_tensors="pt",
            ).to(model.device)
            text_output = model.text_encoder.forward_text(text_input)
            text_embed = F.normalize(
                model.text_proj(text_output.last_hidden_state[:, 0, :])
            )
            text_embeds.append(text_embed)
            text_ids.append(text_input.input_ids)
            text_atts.append(text_input.attention_mask)

        text_embeds = torch.cat(text_embeds, dim=0)
        text_ids = torch.cat(text_ids, dim=0)
        text_atts = torch.cat(text_atts, dim=0)
        # Overwrite the first token with the tokenizer's enc_token_id when it
        # defines one; the ids are later fed to the cross-attention text
        # encoder by the attack.
        if hasattr(model.tokenizer, "enc_token_id"):
            text_ids[:, 0] = model.tokenizer.enc_token_id

        image_embeds = []
        for samples in data_loader:
            image = samples["image"].to(model.device)
            image_feat = model.visual_encoder.forward_features(image)
            image_embed = F.normalize(model.vision_proj(image_feat[:, 0, :]), dim=-1)

            image_embeds.append(image_embed)
            if len(image_embeds) % 50 == 0:
                print(f"image_ebd: {len(image_embeds)}")

        image_embeds = torch.cat(image_embeds, dim=0)
        print(f"image_ebd: {image_embeds.shape}")

        sims_matrix = image_embeds @ text_embeds.t()
        return sims_matrix, text_ids, text_atts


class MultimodalPGD(torchattacks.PGD):
    """PGD attack on the image input of an image-text matching model.

    The code assumes that the captions of image ``i`` occupy positions
    ``5 * i .. 5 * i + 4`` of ``texts`` (see the ``// 5`` and ``idx * 5``
    arithmetic below) -- TODO confirm against the dataset.
    """

    # Number of re-ranked candidate captions scored with gradients per image.
    _NUM_CANDIDATES = 12

    @staticmethod
    def _itm_scores(model, encoder_output, cand_ids, cand_atts):
        """Return the ITM head's class-1 output for each (image, caption) pair.

        ``encoder_output`` holds one copy of the image tokens per candidate
        caption; it is attended to with an all-ones attention mask.
        """
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
            model.device
        )
        output = model.text_encoder(
            cand_ids,
            attention_mask=cand_atts,
            encoder_hidden_states=encoder_output,
            encoder_attention_mask=encoder_att,
            return_dict=True,
        )
        return model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]

    def get_logits_and_labels(self, inputs, adv_labels=None, sim_only=False, is_targeted=True, *args, **kwargs):
        """Score each image against candidate captions and build attack labels.

        Args:
            inputs: batch of (adversarial) images, iterated one by one.
            adv_labels: optional per-image lists of caption indices chosen by
                a previous step; only used in targeted mode.
            sim_only: if true, use only the embedding similarity of the
                ``k_test`` most similar captions and pick a single target.
            is_targeted: if true, labels mark the captions to push towards;
                otherwise they mark the image's own captions.
            **kwargs: must contain ``model``, ``image_idx``, ``sims``,
                ``texts``, ``text_ids``, ``text_atts`` and ``k_test``.

        Returns:
            Tuple ``(logits, labels, adv_match_dict)`` where ``logits`` and
            ``labels`` are stacked per image and ``adv_match_dict`` maps
            ``int(image index)`` to the caption index/indices selected for it.

        Raises:
            KeyError: if one of the required keyword arguments is missing.
        """
        model = kwargs.pop("model")
        image_idxs = kwargs.pop("image_idx")
        sims = kwargs.pop("sims")
        texts = kwargs.pop("texts")
        text_ids = kwargs.pop("text_ids")
        text_atts = kwargs.pop("text_atts")
        k_test = kwargs.pop("k_test")
        num_cand = self._NUM_CANDIDATES

        logits = []
        labels = []
        adv_match_dict = {}
        for adv_i, (idx, image) in enumerate(zip(image_idxs, inputs)):
            # Candidate captions: the k_test most similar ones according to
            # the precomputed similarity matrix.
            topk_idx = sims[idx].topk(k=k_test, largest=True, dim=0)[1]

            # forward img and topk similar texts with grad
            img_encoder_output = model.visual_encoder.forward_features(image.unsqueeze(dim=0))[0]
            image_embed = F.normalize(model.vision_proj(img_encoder_output[0, :]), dim=-1)
            topk_texts_input = model.tokenizer(
                [texts[t] for t in topk_idx],
                padding="max_length",
                truncation=True,
                max_length=35,
                return_tensors="pt",
            ).to(model.device)
            topk_texts_output = model.text_encoder.forward_text(topk_texts_input)
            topk_texts_embed = F.normalize(model.text_proj(topk_texts_output.last_hidden_state[:, 0, :]))
            topk_sim = topk_texts_embed @ image_embed.t()

            if sim_only:
                # TODO: multiple target support & dynamic support
                logits.append(topk_sim)

                # Positions of candidates that are not one of this image's
                # own captions; the best-ranked one becomes the target.
                not_match_idx = [i for i, index in enumerate(topk_idx) if index // 5 != idx]
                target_idx = not_match_idx[0]
                adv_match_dict[int(idx)] = int(topk_idx[target_idx])
                atked_label = torch.zeros(k_test, device=model.device)
                atked_label[target_idx] = 1.
                labels.append(atked_label)
                continue

            # Re-rank all k_test candidates by ITM score + similarity. This
            # is only used for ranking, so no graph is built.
            with torch.no_grad():
                score = self._itm_scores(
                    model,
                    img_encoder_output.unsqueeze(dim=0).repeat(k_test, 1, 1),
                    text_ids[topk_idx],
                    text_atts[topk_idx],
                )
                topk_s_idx = (score + topk_sim).topk(k=k_test, dim=0)[1]
                ranked_s_idx = topk_idx[topk_s_idx]

            tmp_rank_s_idx = ranked_s_idx[:num_cand]
            # Labels are created on the model device in every branch so they
            # can be indexed with device tensors and stacked consistently.
            atked_label = torch.zeros(num_cand, device=model.device)

            if is_targeted and adv_labels is None:
                # First step: target the five best-ranked candidates that do
                # not belong to this image.
                select_rank = [0, 1, 2, 3, 4]
                not_match_idx = torch.tensor(
                    [i for i, index in enumerate(tmp_rank_s_idx) if index // 5 != idx]
                )
                target_idxs = not_match_idx[select_rank]
                adv_match_dict[int(idx)] = list(tmp_rank_s_idx[target_idxs])
                atked_label[target_idxs] = 1.
            elif is_targeted:
                # Later steps: keep the previously chosen targets. A target
                # that dropped out of the candidates is written back into
                # slot i so it still receives a logit.
                adv_match_dict[int(idx)] = adv_labels[adv_i]
                for i, l in enumerate(adv_labels[adv_i]):
                    if l not in tmp_rank_s_idx:
                        tmp_rank_s_idx[i] = l
                        atked_label[i] = 1.
                    else:
                        target_idx = torch.where(tmp_rank_s_idx == l)[0]
                        atked_label[target_idx] = 1.
            else:
                # not targeted, return df labels
                df_mask = ((tmp_rank_s_idx >= idx*5) & (tmp_rank_s_idx < idx*5+5))
                atked_label[df_mask] = 1
            labels.append(atked_label)

            # Score the final candidate set with gradients w.r.t. the image.
            score = self._itm_scores(
                model,
                img_encoder_output.unsqueeze(dim=0).repeat(num_cand, 1, 1),
                text_ids[tmp_rank_s_idx],
                text_atts[tmp_rank_s_idx],
            )
            logit = score + sims[idx][tmp_rank_s_idx]
            logits.append(logit)
            if not is_targeted:
                adv_match_dict[int(idx)] = tmp_rank_s_idx[logit.topk(k=5)[1]]

            if idx == 0:
                print(f"{idx}: {tmp_rank_s_idx}  ->   {adv_match_dict}")

        logits = torch.stack(logits)
        labels = torch.stack(labels)

        return logits, labels, adv_match_dict

    def forward(self, samples, _=None, is_targeted=False, *args, **kwargs):
        """Run the PGD attack on ``samples["image"]``.

        Args:
            samples: dict with ``"image"`` (images normalised with the mean
                and std hard-coded below) and ``"index"`` (per-image indices
                into ``sims``).
            _: unused placeholder for the positional label argument.
            is_targeted: selects the targeted or untargeted objective.
            **kwargs: must contain ``sims``, ``texts``, ``text_ids`` and
                ``text_atts``; may contain ``sim_only``.

        Returns:
            Tuple ``(adv_images, costs, inter_logits, adv_image_embeds,
            adv_match_dict, adv_df_logits, score_logits)``. The last two are
            ``None`` when ``sim_only`` is true.
        """
        sims = kwargs.pop("sims")
        texts = kwargs.pop("texts")
        text_ids = kwargs.pop("text_ids")
        text_atts = kwargs.pop("text_atts")
        sim_only = kwargs.pop("sim_only", False)

        images = samples["image"].clone().detach().to(self.device)
        image_idx = samples['index']  # TODO: maybe index != image_id

        loss = nn.CrossEntropyLoss()

        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).to(self.device)[:, None, None]
        std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).to(self.device)[:, None, None]

        # eps and alpha are applied in [0, 1] pixel space, so keep a
        # pixel-space copy of the clean images to project against.
        clean_images = images * std + mean  # [-1.79, 1.93] -> [0, 1]
        adv_images = clean_images.clone()

        if self.random_start:
            # Starting at a uniformly random point
            adv_images = adv_images + torch.empty_like(adv_images).uniform_(
                -self.eps, self.eps
            )
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        adv_images = (adv_images - mean) / std  # [0, 1] -> [-1.79, 1.93]

        costs = []
        inter_logits = []
        adv_match_dict = None
        for s in range(self.steps):
            adv_images.requires_grad = True

            # Reuse the targets selected in the previous step, if any.
            adv_labels = None
            if adv_match_dict is not None:
                matched = [
                    adv_match_dict[int(idx)]
                    for idx in image_idx
                    if int(idx) in adv_match_dict
                ]
                adv_labels = matched or None

            # NOTE(review): sim_only is deliberately left hard-coded to False
            # here, as before; the caller's sim_only only affects the final
            # evaluation below.
            logits, target_labels, adv_match_dict = self.get_logits_and_labels(
                adv_images,
                adv_labels=adv_labels,
                sim_only=False,
                image_idx=image_idx,
                model=self.model,
                sims=sims,
                texts=texts,
                text_ids=text_ids,
                text_atts=text_atts,
                k_test=128,
                is_targeted=is_targeted,
            )

            if is_targeted:
                target_labels = target_labels.clone().detach().to(self.device)
                cost = -loss(logits, target_labels)
            else:
                # Stop early once no label is set in the first five candidate
                # slots of any image.
                if target_labels[:, :5].sum() == 0 and s > 0:
                    print(f"Attacked {adv_images.shape[0]} img(s) at step {s}.")
                    break
                cost = loss(logits, target_labels)

            # Update adversarial images
            grad = torch.autograd.grad(
                cost, adv_images, retain_graph=False, create_graph=False
            )[0]

            # Step and project in pixel space. The projection must be taken
            # around the pixel-space clean images, not the normalised input.
            adv_pixels = adv_images.detach() * std + mean + self.alpha * grad.sign()
            delta = torch.clamp(adv_pixels - clean_images, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(clean_images + delta, min=0, max=1).detach()
            adv_images = (adv_images - mean) / std  # [0, 1] -> [-1.79, 1.93]
            costs.append(cost)
            inter_logits.append(logits)

        # Defaults so the return below is well defined when sim_only is set.
        adv_df_logits = None
        score_logits = None
        with torch.no_grad():
            adv_image_embeds = self.model.visual_encoder.forward_features(adv_images)

            # Captions belonging to each attacked image (5 per image).
            df_texts = []
            for idx in image_idx:
                df_texts += texts[idx*5:idx*5+5]
            df_texts_input = self.model.tokenizer(
                df_texts,
                padding="max_length",
                truncation=True,
                max_length=35,
                return_tensors="pt",
            ).to(self.model.device)
            df_texts_output = self.model.text_encoder.forward_text(df_texts_input)
            df_texts_embed = F.normalize(self.model.text_proj(df_texts_output.last_hidden_state[:, 0, :]))

            if not sim_only:
                # Repeat each image's tokens five times so that row r pairs
                # with caption r of df_texts.
                encoder_output = torch.cat(
                    [adv_img_embed.unsqueeze(dim=0).repeat(5, 1, 1) for adv_img_embed in adv_image_embeds]
                ).to(adv_image_embeds.device)
                encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
                    self.model.device
                )
                image_embeds = self.model.vision_proj(encoder_output[:, 0, :])
                image_embeds = F.normalize(image_embeds, dim=-1)
                df_sims_all = image_embeds @ df_texts_embed.t()
                # Keep, for each image, only the similarities to its own
                # five captions.
                df_sims = [
                    row[i*5:i*5+5] for i, row in enumerate(df_sims_all[::5])
                ]
                df_sims = torch.cat(df_sims).to(df_sims_all.device)
                output = self.model.text_encoder(
                    df_texts_input.input_ids,
                    attention_mask=df_texts_input.attention_mask,
                    encoder_hidden_states=encoder_output,
                    encoder_attention_mask=encoder_att,
                    return_dict=True,
                )
                score_logits = self.model.itm_head(output.last_hidden_state[:, 0, :])
                score = score_logits[:, 1]
                adv_df_logits = score + df_sims

        return adv_images, costs, inter_logits, adv_image_embeds, adv_match_dict, adv_df_logits, score_logits