import torch
from dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.dalle2_pytorch import l2norm
from dalle2_pytorch.optimizer import get_optimizer
from torchvision.utils import make_grid, save_image

from src.callbacks.txt2img_callbacks import generate_grid_samples
from src.datamodules.utils import split_test_full_data
from src.logger.jam_wandb import prefix_metrics_keys
from src.models.base_model import BaseModule
from src.models.loss_zoo import gradientOptimality

# pylint: disable=abstract-method,too-many-ancestors,arguments-renamed,line-too-long,arguments-differ,unused-argument


class Text2ImgModule(BaseModule):
    """Module that maps CLIP text conditioning to image-embedding space.

    The text encoder is an ``OpenAIClipAdapter``. The map network
    (``self.map_t``), the potential/critic network (``self.f_net``), the
    transport cost (``self.cost_func``), the EMA wrapper (``self.ema_map``)
    and ``self.cfg`` are not defined here; they are presumably provided by
    ``BaseModule`` -- TODO confirm.
    """

    def __init__(self, cfg) -> None:
        """Build the CLIP adapter and cache the image-embedding scale.

        Args:
            cfg: Configuration object; this constructor reads
                ``cfg.clip_model`` and ``cfg.image_embed_dim``.
        """
        super().__init__(cfg)
        self.clip = OpenAIClipAdapter(cfg.clip_model)
        # Target image embeddings are multiplied by sqrt(image_embed_dim)
        # in ``get_real_data``.
        self.image_embed_scale = cfg.image_embed_dim**0.5

    def get_real_data(self, batch):
        """Turn a (source, target) batch into text conditioning and targets.

        Args:
            batch: Pair ``(src_data, trg_data)``; each element is indexed as
                ``(image_embedding, tokenized_caption)``.

        Returns:
            Tuple ``(src_text_cond, trg_img_emb)`` where ``src_text_cond`` is
            a dict with keys ``"text_embed"`` and ``"text_encodings"`` and
            ``trg_img_emb`` is the target image embedding multiplied by
            ``self.image_embed_scale``.
        """
        src_data, trg_data = batch
        # src_data: (image_embedding, tokenized_caption)
        # trg_data: (image_embedding, tokenized_caption)
        src_tokens, trg_img_emb = src_data[1], trg_data[0]
        text_embed, text_encodings = self.clip.embed_text(src_tokens)
        src_text_cond = {"text_embed": text_embed, "text_encodings": text_encodings}
        # Scale out-of-place: an in-place ``*=`` would modify the tensor held
        # by the caller's batch, so passing the same batch in again would
        # scale the target repeatedly.
        trg_img_emb = trg_img_emb * self.image_embed_scale
        return src_text_cond, trg_img_emb

    def loss_f(self, src_text_cond, trg_img_emb, mask=None):
        """Compute the loss used to update ``self.f_net``.

        The loss is ``mean(f(T(x))) - mean(f(y)) + gradient_penalty``. The
        map output ``T(x)`` is computed under ``torch.no_grad()``, so this
        loss does not backpropagate into ``self.map_t``.

        Args:
            src_text_cond: Dict of keyword arguments for ``self.map_t``; must
                contain ``"text_embed"``.
            trg_img_emb: Target image embeddings fed to ``self.f_net``.
            mask: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(f_loss, log_info)`` where ``log_info`` is the metrics
            dict produced by ``prefix_metrics_keys`` with prefix ``"f_loss"``.
        """
        with torch.no_grad():
            tx_tensor = self.map_t(**src_text_cond)
        # Disabled sanity check kept for reference:
        # assert torch.isclose(tx_tensor.norm(dim=-1).mean(), trg_img_emb.norm(dim=-1).mean(),rtol=1e-2)
        f_tx, f_y = self.f_net(tx_tensor).mean(), self.f_net(trg_img_emb).mean()
        if self.cfg.optimal_penalty:
            gradient_penalty = gradientOptimality(
                self.f_net, tx_tensor, src_text_cond["text_embed"], self.cfg.coeff_go
            )
        else:
            # Plain float so the sum and the logged metric stay well defined.
            gradient_penalty = 0.0
        f_loss = f_tx - f_y + gradient_penalty
        log_info = prefix_metrics_keys(
            {
                "f_tx": f_tx,
                "f_y": f_y,
                "gradient_penalty": gradient_penalty,
                "f_tx-f_y": f_tx - f_y,
            },
            "f_loss",
        )
        return f_loss, log_info

    def loss_map(self, src_text_cond, mask=None):
        """Compute the loss used to update ``self.map_t``.

        The loss is ``cost(text_embed, l2norm(T(x))) - mean(f(T(x)))``. The
        cost sees the L2-normalised map output, while ``self.f_net`` receives
        the raw (un-normalised) map output.

        Args:
            src_text_cond: Dict with keys ``"text_embed"`` and
                ``"text_encodings"``, passed as keyword arguments to
                ``self.map_t``.
            mask: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(map_loss, log_info)`` where ``log_info`` is the metrics
            dict produced by ``prefix_metrics_keys`` with prefix
            ``"map_loss"``.
        """
        tx_tensor = self.map_t(**src_text_cond)
        cost_loss = self.cost_func(
            src_text_cond["text_embed"], l2norm(tx_tensor), self.cfg.coeff_mse
        )
        f_tx = self.f_net(tx_tensor).mean()
        map_loss = cost_loss - f_tx
        log_info = prefix_metrics_keys(
            {"cost_loss": cost_loss, "f_tx": f_tx}, "map_loss"
        )
        return map_loss, log_info

    def validation_step(self, batch, batch_idx):
        """Embed the captions of a validation batch and log similarities.

        Args:
            batch: Pair ``(trg_img_emb, src_tokens)``.
            batch_idx: Index of the batch; unused.
        """
        # NOTE(review): unlike ``get_real_data``, the image embedding is not
        # multiplied by ``image_embed_scale`` here -- confirm this is intended.
        trg_img_emb, src_tokens = batch
        text_embed, text_encodings = self.clip.embed_text(src_tokens)
        src_text_cond = {"text_embed": text_embed, "text_encodings": text_encodings}
        self.cos_similarity(src_text_cond, trg_img_emb)

    def cos_similarity(self, src_text_cond, trg_img_emb):
        """Log similarity metrics between text, mapped and image embeddings.

        Each similarity is the negated ``self.cost_func`` of a pair of
        embeddings. When ``self.cfg.ema`` is set, the map is evaluated inside
        ``self.ema_map.average_parameters()``; otherwise the current
        ``self.map_t`` weights are used.

        Args:
            src_text_cond: Dict with keys ``"text_embed"`` and
                ``"text_encodings"``, passed as keyword arguments to
                ``self.map_t``.
            trg_img_emb: Image embeddings paired with the captions.
        """
        if self.cfg.ema:
            with self.ema_map.average_parameters():
                tx_tensor = l2norm(self.map_t(**src_text_cond))
        else:
            # Without this branch ``tx_tensor`` would be unbound whenever EMA
            # is disabled and validation would crash.
            tx_tensor = l2norm(self.map_t(**src_text_cond))
        src_txt_emb = src_text_cond["text_embed"]
        txt_trg_sim = -self.cost_func(src_txt_emb, trg_img_emb)
        txt_pf_sim = -self.cost_func(src_txt_emb, tx_tensor)

        pf_trg_sim = -self.cost_func(tx_tensor, trg_img_emb)

        # Shuffle the captions within the batch to get a mismatched-pair
        # reference value.
        rdm_idx = torch.randperm(trg_img_emb.shape[0])
        unrelated_sim = -self.cost_func(tx_tensor, src_txt_emb[rdm_idx])
        log_info = prefix_metrics_keys(
            {
                "baseline similarity": txt_trg_sim,
                "similarity with text": txt_pf_sim,
                "difference from baseline similarity": abs(txt_trg_sim - txt_pf_sim),
                "similarity with original image": pf_trg_sim,
                "similarity with unrelated caption": unrelated_sim,
            },
            "validation_cos_sim",
        )
        self.log_dict(log_info)

    def test_step(self, batch, batch_idx):
        """Generate samples for one test batch and write them to disk.

        Writes ``img_{batch_idx}.png`` (an image grid) and
        ``raw_data_{batch_idx}.pt`` (the images and captions) using paths
        relative to the current working directory.

        Args:
            batch: Test batch forwarded to ``split_test_full_data``.
            batch_idx: Index of the batch; used in the output file names.

        Raises:
            AssertionError: If ``batch_idx`` exceeds 100, to stop the test
                loop early.
        """
        if batch_idx > 100:
            # Raised explicitly rather than via ``assert`` so the guard still
            # fires when Python runs with ``-O`` (which strips asserts).
            raise AssertionError("Too many test samples, terminate earlier.")
        test_example_data = split_test_full_data(batch, self.device)
        # TODO: this callback can have a problem, we hard code it with index.
        sampling_callback = self.trainer.callbacks[3]
        test_images, test_captions = generate_grid_samples(
            self,
            sampling_callback.decoder,
            sampling_callback.prior,
            test_example_data,
            device=self.device,
            skip_ema=True,
        )
        cherry_pick_img_grid = make_grid(test_images, nrow=1, padding=2, pad_value=0)
        save_image(cherry_pick_img_grid, f"img_{batch_idx}.png")
        torch.save(
            {"images": test_images, "captions": test_captions},
            f"raw_data_{batch_idx}.pt",
        )

    def configure_optimizers(self):
        """Create one optimizer for the map network and one for ``f_net``.

        Both optimizers share weight decay, epsilon and parameter grouping;
        only the learning rate differs (``cfg.lr_T`` vs ``cfg.lr_f``).

        Returns:
            Tuple ``(optimizer_map, optimizer_f)``.
        """
        # These parameters are from LAION pretrained prior.
        shared_kwargs = dict(wd=self.cfg.wd, eps=1e-6, group_wd_params=True)

        optimizer_map = get_optimizer(
            self.map_t.parameters(), lr=self.cfg.lr_T, **shared_kwargs
        )
        optimizer_f = get_optimizer(
            self.f_net.parameters(), lr=self.cfg.lr_f, **shared_kwargs
        )
        return optimizer_map, optimizer_f
