import os.path as osp
from collections import OrderedDict
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.autograd import Variable
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
import os
import copy
# Module-level CLIP tokenizer; BayesianPromptLearner uses `_tokenizer.encode(name)`
# to count how many tokens each class name occupies in a prompt.
_tokenizer = _Tokenizer()

def Sinkhorn(K, u, v, max_iter=100, thresh=1e-2):
    """Batched Sinkhorn scaling of a kernel matrix into a transport plan.

    Alternately rescales rows and columns of ``K`` so that the plan's marginals
    approach ``u`` (rows) and ``v`` (columns).

    Args:
        K: (B, M, N) kernel matrix; entries must be non-zero along every row/column
            sum for the divisions below to be finite.
        u: (B, M) target row marginals.
        v: (B, N) target column marginals.
        max_iter: maximum number of row/column scaling rounds.
        thresh: stop early once the mean absolute change of the row scaling
            vector drops below this value (default keeps the previous 1e-2).

    Returns:
        (B, M, N) plan ``T = diag(r) @ K @ diag(c)``.
    """
    r = torch.ones_like(u, dtype=K.dtype, device=K.device)
    c = torch.ones_like(v, dtype=K.dtype, device=K.device)
    # The transpose does not change between iterations; build it once.
    K_t = K.permute(0, 2, 1).contiguous()
    for _ in range(max_iter):
        r_prev = r
        r = u / torch.matmul(K, c.unsqueeze(-1)).squeeze(-1)
        c = v / torch.matmul(K_t, r.unsqueeze(-1)).squeeze(-1)
        # .item() forces a host sync; acceptable since it gates the loop.
        if (r - r_prev).abs().mean().item() < thresh:
            break

    T = torch.matmul(r.unsqueeze(-1), c.unsqueeze(-2)) * K

    return T

def load_clip_to_cpu(cfg):
    """Load the configured CLIP backbone on the CPU and rebuild it for PBPrompt.

    The downloaded checkpoint is first opened as a TorchScript archive; if that
    raises ``RuntimeError`` it is read as a plain state dict instead. The model
    is then rebuilt by ``clip.build_model`` with PBPrompt design details.
    """
    checkpoint_path = clip._download(clip._MODELS[cfg.MODEL.BACKBONE.NAME])

    try:
        # loading JIT archive
        jit_model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
    except RuntimeError:
        weights = torch.load(checkpoint_path, map_location="cpu")
    else:
        weights = None

    design_details = {
        "trainer": 'PBPrompt',
        "vision_depth": 0,
        "language_depth": 0,
        "language_ctx": 0,
        "ctx_length": cfg.TRAINER.PBPROMPT.N_CTX,
    }
    return clip.build_model(weights or jit_model.state_dict(), design_details)


class TextEncoderPBPrompt(nn.Module):
    """Encode prompt embeddings with CLIP's text tower and project the EOT token.

    All submodules/tensors are borrowed from ``clip_model``; no new weights are created.
    """

    def __init__(self, clip_model):
        super().__init__()
        # Registration order is kept as in CLIP so state_dict key order is stable.
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts, compound_prompts_deeper_text):
        """Return projected text features, one row per prompt.

        Args:
            prompts: (batch, n_tkn, width) embedded prompt tokens (batch-first).
            tokenized_prompts: (batch, n_tkn) token ids; the per-row argmax is used
                as the end-of-text position.
            compound_prompts_deeper_text: extra prompts handed to the transformer
                for its deeper layers.
        """
        hidden = prompts + self.positional_embedding.type(self.dtype)
        # The transformer is sequence-first and, being an nn.Sequential, receives a
        # single list: [hidden states, deep prompts, depth counter starting at 0].
        packed = [hidden.permute(1, 0, 2), compound_prompts_deeper_text, 0]
        hidden = self.transformer(packed)[0].permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        rows = torch.arange(hidden.shape[0])
        eot_positions = tokenized_prompts.argmax(dim=-1)
        return hidden[rows, eot_positions] @ self.text_projection


class BayesianPromptLearner(nn.Module):
    """Learn CLIP text prompts with a per-class Gaussian latent token.

    The token embeddings of every class name are mapped by ``infer_net`` to
    per-token ``(mu, logsigma)``; these are averaged per class and a latent
    ``z`` is drawn by reparameterization for every prompt depth and every
    Monte-Carlo sample. ``z`` is placed in front of the learnable context
    tokens and mixed by self-attention. Depth 0 becomes the input prompt;
    depths ``1..D-1`` are returned as deep (compound) prompts.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.PBPROMPT.N_CTX
        ctx_init = cfg.TRAINER.PBPROMPT.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg.TRAINER.PBPROMPT.TEXT_PROMPT_DEPTH >= 1
        self.compound_prompts_depth = cfg.TRAINER.PBPROMPT.TEXT_PROMPT_DEPTH  # max=12, but will create 11 such shared prompts.
        # Monte-Carlo prompt samples per class. Train and test counts are kept apart so
        # that evaluating does not permanently overwrite the training value.
        self.train_prompt_sample_number = cfg.TRAINER.PBPROMPT.TEXT_PROMPT_NUMBER
        self.test_prompt_sample_number = getattr(cfg.TRAINER.PBPROMPT, "TEST_PROMPT_NUMBER", 10)
        # Count used by the most recent forward(); CustomCLIP reads it afterwards to reshape.
        self.prompt_sample_number = self.train_prompt_sample_number

        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        # Learnable context for layers 1..depth-1: (depth - 1, n_ctx, ctx_dim)
        self.compound_prompts_text = nn.Parameter(torch.empty(self.compound_prompts_depth - 1, n_ctx, ctx_dim, dtype=dtype))
        nn.init.normal_(self.compound_prompts_text, std=0.02)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")
        print(f'TEST prompt sample number: "{self.prompt_sample_number}"')

        # Learnable context for layer 0: (n_ctx, ctx_dim)
        self.ctx = nn.Parameter(ctx_vectors)

        # Maps a token embedding to concatenated (mu, logsigma), each ctx_dim wide.
        self.infer_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(ctx_dim, 2 * ctx_dim, bias=True)),
        ]))

        self.attn = nn.MultiheadAttention(ctx_dim, 8, batch_first=True)

        if cfg.TRAINER.PBPROMPT.PREC == "fp16":
            self.infer_net.half()
            self.attn.half()

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]

        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])  # (n_cls, n_tkn)

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)    # (n_cls, n_tkn, model_dim)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS  (n_cls, 1, model_dim)
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS

        # Embeddings of every class-name token (the tokens right after the context
        # slots), concatenated over classes in class order: (sum(name_lens), model_dim).
        name_start = 1 + n_ctx
        contextual_embedding = torch.cat(
            [emb[name_start:name_start + length] for emb, length in zip(embedding, name_lens)]
        )
        self.register_buffer("contextual_embedding", contextual_embedding)

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.ctx_dim = ctx_dim
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens

    @property
    def prior_mu(self):
        """Prior mean for the KL term: the frozen class-name token embeddings.

        A property rather than an attribute aliasing the buffer, so it follows the
        buffer through ``.to()`` / ``.cuda()`` / ``.half()`` (those replace buffer
        tensors, which left a plain alias pointing at the stale CPU copy).
        """
        return self.contextual_embedding

    def _active_sample_number(self):
        """Number of Monte-Carlo prompt samples to draw in the current mode."""
        return self.train_prompt_sample_number if self.training else self.test_prompt_sample_number

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        """Concatenate [prefix | ctx | suffix] along the token axis.

        Args:
            ctx: (dim0, n_ctx, ctx_dim) context tokens.
            prefix: (n, 1, ctx_dim) SOS token embeddings.
            suffix: (n, *, ctx_dim) remaining token embeddings.
            label: optional index used to select rows of ``prefix``/``suffix``.
        """
        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        prompts = torch.cat(
            [
                prefix,  # (dim0, 1, dim)
                ctx,  # (dim0, n_ctx, dim)
                suffix,  # (dim0, *, dim)
            ],
            dim=1,
        )

        return prompts

    def reparameterize(self, mu, logvar, layer_depth=1, sample_num=10, keep_dim=False):
        """Draw Gaussian samples ``mu + exp(0.5 * logvar) * eps`` (reparameterization trick).

        Args:
            mu, logvar: tensors of identical shape.
            layer_depth: number of independent draws per sample (one per prompt depth).
            sample_num: number of Monte-Carlo samples.
            keep_dim: if True return (sample_num, layer_depth, *mu.shape);
                otherwise average over the sample axis -> (layer_depth, *mu.shape).
        """
        shape = (sample_num, layer_depth) + tuple(mu.shape)
        logvar = logvar[None, None].expand(shape)
        mu = mu[None, None].expand(shape)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        sample = eps * std + mu
        if keep_dim:
            return sample
        return sample.mean(0)

    def forward(self):
        """Build sampled prompts for every class.

        Returns:
            prompt: (S * n_cls, n_tkn, ctx_dim) input prompt embeddings (depth 0).
            deep_prompts: (depth - 1, S * n_cls, n_ctx, ctx_dim) prompts for deeper layers.
            mu_t, logsigma_t: (sum(name_lens), ctx_dim) per-name-token posterior parameters.
        """
        depth = self.compound_prompts_depth
        n_cls, n_ctx, ctx_dim = self.n_cls, self.n_ctx, self.ctx_dim

        mu_t, logsigma_t = torch.chunk(self.infer_net(self.contextual_embedding), 2, dim=-1)
        # Pool the per-token posterior over each class's name tokens -> (n_cls, ctx_dim).
        mu = torch.stack([part.mean(0) for part in torch.split(mu_t, self.name_lens)])
        logsigma = torch.stack([part.mean(0) for part in torch.split(logsigma_t, self.name_lens)])

        # Publish the active count: CustomCLIP reads it after this call to reshape outputs.
        self.prompt_sample_number = self._active_sample_number()
        n_samples = self.prompt_sample_number

        # (S, depth, n_cls, ctx_dim) -> (depth * S * n_cls, 1, ctx_dim), depth-major.
        z = self.reparameterize(mu, logsigma, layer_depth=depth, sample_num=n_samples, keep_dim=True)
        z = z.permute(1, 0, 2, 3).reshape(depth * n_samples * n_cls, 1, ctx_dim)

        # Repeat the fixed SOS / name+EOS tokens for every sample: (S * n_cls, *, dim).
        prefix = self.token_prefix.unsqueeze(0).expand(n_samples, -1, -1, -1).flatten(0, 1)
        suffix = self.token_suffix.unsqueeze(0).expand(n_samples, -1, -1, -1).flatten(0, 1)

        # Learnable context for every depth, shared over samples and classes:
        # (depth, n_ctx, ctx_dim) -> (depth * S * n_cls, n_ctx, ctx_dim).
        ctx_all = torch.cat((self.ctx.unsqueeze(0), self.compound_prompts_text), dim=0)
        ctx_all = ctx_all.unsqueeze(1).expand(-1, n_samples * n_cls, -1, -1)
        ctx_all = ctx_all.reshape(depth * n_samples * n_cls, n_ctx, ctx_dim)

        # Mix the latent token with the context tokens via residual self-attention.
        tokens = F.normalize(torch.cat([z, ctx_all], dim=1), dim=-1)  # (depth*S*n_cls, n_ctx+1, ctx_dim)
        attended, _ = self.attn(tokens, tokens, tokens)
        prompt_list = (attended + tokens).reshape(depth, n_samples * n_cls, n_ctx + 1, ctx_dim)

        # Drop the latent slot (index 0) and keep only the mixed context tokens.
        prompt = self.construct_prompts(prompt_list[0, :, 1:, :], prefix, suffix)

        return prompt, prompt_list[1:, :, 1:, :], mu_t, logsigma_t


def kl_gaussian(mu, prior_mu, logsigma):
    """Mean KL( N(mu, diag(exp(logsigma))) || N(prior_mu, I) ) over rows.

    ``logsigma`` enters the formula as a log-variance. The computation runs in
    float64. ``prior_mu`` is moved to ``mu``'s device first, so a prior tensor
    living elsewhere (e.g. a stale CPU copy) does not raise a device error.

    Args:
        mu, prior_mu, logsigma: (n, d) tensors.

    Returns:
        0-dim float64 tensor.
    """
    n_dim = mu.shape[1]
    mu = mu.type(dtype=torch.float64)
    logsigma = logsigma.type(dtype=torch.float64)
    prior_mu = prior_mu.to(device=mu.device, dtype=torch.float64)
    kl_per_row = 0.5 * ((mu - prior_mu).pow(2).sum(1) + logsigma.exp().sum(1) - logsigma.sum(1) - n_dim)

    return kl_per_row.mean()


class CustomCLIP(nn.Module):
    """CLIP with Bayesian sampled text prompts plus a patch/class transport regularizer."""

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        self.prompt_learner = BayesianPromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoderPBPrompt(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.n_cls = self.prompt_learner.n_cls
        self.eps = 0.1
        self.backbone = cfg.MODEL.BACKBONE.NAME

    def cost_ct_mmp(self, inner_p, cost_c, score):
        """Forward (patch->class) and backward (class->patch) transport costs.

        Args:
            inner_p: (B, M, N) patch/class similarities, B = bs * sample_number.
            cost_c: (B, M, N) transport cost.
            score: (B, N) per-row class weights.

        Returns:
            (forward_cost, backward_cost): scalars averaged over B.
        """
        dis_d = torch.exp(inner_p)  # [B, M, N]

        # patch -> class plan: similarity weighted by class score, normalised over classes.
        forward_dis = dis_d * score.unsqueeze(1)
        forward_pi = forward_dis / torch.sum(forward_dis, dim=2, keepdim=True)
        # class -> patch plan: similarity normalised over patches.
        backward_pi = dis_d / torch.sum(dis_d, dim=1, keepdim=True)

        forward_cost = torch.sum(cost_c * forward_pi, dim=(1, 2)).mean()
        backward_cost = torch.sum(cost_c * backward_pi, dim=(1, 2)).mean()

        return forward_cost, backward_cost

    def _ct_loss(self, patch_emb, text_features, logits_score):
        """Conditional-transport regularizer between patches and sampled class features.

        Args:
            patch_emb: (bs, M, dim) patch features.
            text_features: (S, n_cls, dim) class features for each prompt sample.
            logits_score: (bs, n_cls) class probabilities per image.

        Returns:
            Scalar ``0.5 * forward_cost + 0.5 * backward_cost``.
        """
        bs, n_patches = patch_emb.shape[:2]
        n_samples, n_cls = text_features.shape[:2]

        # sim[b, s, m, c] = <patch m of image b, class c of sample s>
        sim = torch.einsum('mbd,ncd->mnbc', patch_emb, text_features).contiguous()
        sim = sim.view(bs * n_samples, n_patches, n_cls)
        # Repeat each image's scores over its S samples in the same (image, sample)
        # row order as `sim`; expanding along a leading sample axis would pair rows
        # with the wrong image whenever bs > 1 and S > 1.
        score = logits_score.to(sim.dtype).unsqueeze(1).expand(-1, n_samples, -1)
        score = score.reshape(bs * n_samples, n_cls)

        cost = torch.exp(-sim)
        forward_cost, backward_cost = self.cost_ct_mmp(sim, cost, score)
        return 0.5 * forward_cost + 0.5 * backward_cost

    def forward(self, image, label=None):
        """Return logits in eval mode; (CE loss, 0.01*KL, 0.01*CT) in training mode."""
        logit_scale = self.logit_scale.exp()

        image_features = self.image_encoder(image.type(self.dtype), patch=True)
        image_features = F.normalize(image_features, dim=-1)

        if self.backbone == "ViT-B/16":
            image_features = image_features.permute(1, 0, 2)
        # Below, image_features[0] is the global feature and image_features[1:] the
        # patch tokens, i.e. presumably (1 + M, bs, dim) — TODO confirm per backbone.

        prompt, deep_compound_prompts_text, mu_t, logsigma_t = self.prompt_learner()
        # Read after the call: the prompt learner sets its active sample count in forward().
        n_samples = self.prompt_learner.prompt_sample_number
        n_cls = self.prompt_learner.n_cls

        tokenized_prompts = self.tokenized_prompts.unsqueeze(0).expand(n_samples, -1, -1)
        tokenized_prompts = tokenized_prompts.contiguous().view(n_samples * n_cls, -1)

        text_features = self.text_encoder(prompt, tokenized_prompts, deep_compound_prompts_text)
        text_features = F.normalize(text_features, dim=-1)
        text_features = text_features.view(n_samples, n_cls, -1)  # [sample_number, n_cls, dim]
        text_features_pool = text_features.mean(0)  # [n_cls, dim]

        logits = logit_scale * image_features[0] @ text_features_pool.t()

        if not self.prompt_learner.training:
            return logits

        kl_loss = kl_gaussian(mu_t, self.prompt_learner.prior_mu, logsigma_t)

        # CT regularization
        logits_score = F.softmax(logits, dim=-1)  # [bs, n_cls]
        patch_emb = image_features[1:].permute(1, 0, 2)  # [bs, patch_number, dim]
        total_cost = self._ct_loss(patch_emb, text_features, logits_score)

        return F.cross_entropy(logits, label), 0.01 * kl_loss, 0.01 * total_cost


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


@TRAINER_REGISTRY.register()
class PBPrompt(TrainerX):
    """Dassl trainer for PBPrompt.

    Only ``model.prompt_learner`` is optimised; every other CLIP parameter is frozen.
    """

    def check_cfg(self, cfg):
        assert cfg.TRAINER.PBPROMPT.PREC in ["fp16", "fp32", "amp"]

    def build_model(self):
        """Build CustomCLIP, freeze non-prompt parameters, place on device, set up optim."""
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if cfg.TRAINER.PBPROMPT.PREC == "fp32" or cfg.TRAINER.PBPROMPT.PREC == "amp":
            # CLIP's default precision is fp16
            clip_model.float()

        print("Building custom CLIP")
        self.model = CustomCLIP(cfg, classnames, clip_model)

        print("Turning off gradients in both the image and the text encoder")
        name_to_update = "prompt_learner"

        for name, param in self.model.named_parameters():
            if name_to_update not in name:
                param.requires_grad_(False)

        # Double check
        enabled = {name for name, param in self.model.named_parameters() if param.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)

        # Only these large datasets wrap the text encoder in DataParallel; a bare
        # `== "ImageNet" or "SUN397"` test would be true for every dataset.
        if cfg.DATASET.NAME in ("ImageNet", "SUN397"):
            self.device = torch.device("cuda:0")
            device1 = torch.device("cuda")
            self.model.to(self.device)
            self.model.text_encoder.to(device1)
            # NOTE(review): DataParallel scatters every tensor argument along dim 0.
            # The deep compound prompts are layer-first (depth - 1, S * n_cls, ...),
            # so with more than one GPU they may be split across layers rather than
            # prompts — verify before running multi-GPU.
            self.model.text_encoder = nn.DataParallel(self.model.text_encoder)
        else:
            self.model.to(self.device)

        # NOTE: only give prompt_learner to the optimizer
        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)

        self.scaler = GradScaler() if cfg.TRAINER.PBPROMPT.PREC == "amp" else None

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel

    def forward_backward(self, batch):
        """Run one optimisation step and return scalar losses for logging.

        In training mode the model returns (cross-entropy, weighted KL, weighted CT);
        all three are summed in every precision mode.
        """
        image, label = self.parse_batch_train(batch)

        model = self.model
        optim = self.optim
        scaler = self.scaler

        prec = self.cfg.TRAINER.PBPROMPT.PREC
        if prec == "amp":
            with autocast():
                loss, kl_loss, ct_loss = model(image, label)
            loss_total = loss + kl_loss + ct_loss
            optim.zero_grad()
            scaler.scale(loss_total).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss, kl_loss, ct_loss = model(image, label)
            loss_total = loss + kl_loss + ct_loss
            optim.zero_grad()
            loss_total.backward()
            optim.step()

        loss_summary = {"loss": loss.item(), "kl": kl_loss.item(), "ct loss": ct_loss.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move the batch's images and labels to ``self.device``."""
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label

    def load_model(self, directory, epoch=None):
        """Load saved prompt-learner weights, skipping class-name-dependent buffers."""
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint["epoch"]

            # Ignore fixed token vectors: they are recomputed from the current class names.
            for key in ("token_prefix", "token_suffix", "contextual_embedding", "prior_mu"):
                state_dict.pop(key, None)

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
            # set strict=False
            self._models[name].load_state_dict(state_dict, strict=False)
