import torch
import torch.nn as nn
import os
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler
from collections import defaultdict
from clip import clip
from clip.model import convert_weights
import torch.distributed as dist
from .coop import load_clip_to_cpu
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT, CIFAR10_TEMPLATES, CIFAR100_TEMPLATES, \
    StanfordCars_TEMPLATES, Caltech101_TEMPLATES, DescribableTextures_TEMPLATES, EuroSAT_TEMPLATES, \
    Flowers102_TEMPLATES, Food101_TEMPLATES, SUN397_TEMPLATES, OxfordPets_TEMPLATES, UCF101_TEMPLATES

# One hand-written prompt per dataset; "{}" receives the class name (the
# trainers below replace "_" with " " before formatting).
CUSTOM_TEMPLATES = {
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    "Food101": "a photo of {}, a type of food.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}.",
    "MYCIFAR10": "a photo of a {}.",
    "MYCIFAR100": "a photo of a {}.",
}

# Root directory of the pre-extracted data. Set the OPENNESS_DATA_ROOT
# environment variable to relocate it instead of editing this file; the
# default reproduces the original hard-coded paths exactly.
_DATA_ROOT = os.environ.get('OPENNESS_DATA_ROOT', '/home/openness/data')

# Per-dataset directory that holds the retrieval feature files; the trainers
# join it with a filename from IMAGE_FEATURES / TEXT_FEATURES.
DATASET_PATH = {
    'MYCIFAR10': f'{_DATA_ROOT}/cifar10/images/test',
    'MYCIFAR100': f'{_DATA_ROOT}/cifar100/images/test',
    'ImageNet': f'{_DATA_ROOT}/imagenet/images',
}

# Backbone name -> file of retrieved image features (grouped by class name).
IMAGE_FEATURES = {
    'ViT-B/16': 'clip_vit_b16_features_grouped.pt',
    'ViT-B/32': 'clip_vit_b32_pretrained_image_features_grouped.pt',
    # Alternative for ViT-B/32: 'clip_vit_b32_features_grouped.pt'
    'ViT-B/32/DeClip': 'declip_vit_b32_features_grouped.pt',
}

# Backbone name -> file of retrieved text features (grouped by class name).
TEXT_FEATURES = {
    # Alternatives for ViT-B/32: 'clip_vit_b32_pretrained_text_features_grouped.pt',
    # 'clip_vit_b32_features_grouped.pt'
    'ViT-B/32': 'clip_vit_b32_pretrained_text_features_grouped_filtered.pt',
}


@TRAINER_REGISTRY.register()
class ZeroshotCLIP(TrainerX):
    """Zero-shot CLIP classifier evaluated through the Dassl trainer interface.

    ``build_model`` encodes one prompt per class (from ``CUSTOM_TEMPLATES``)
    into L2-normalised text features, optionally shifted toward pre-extracted
    retrieval features. Images are classified by scaled cosine similarity
    against those fixed features. The class defines no training step.
    """

    # Composite benchmarks whose classes come from several source datasets;
    # for these the prompt template is chosen per class through
    # ``self.dm.dataset.class2superclass``.
    _COMPOSITE_DATASETS = (
        'FLOWERS_PETS_CARS', 'DTD_EUROSAT_CARS', 'FLOWERS_PETS_FOODS',
        'CIFAR100_CALTECH101_SUN397', 'Food101_CALTECH101_UCF101', 'CIFAR10_CIFAR100_ImageNet',
    )

    # ------------------------------------------------------------------ model
    def build_model(self):
        """Load CLIP and precompute ``self.text_features`` (one row per class).

        Also sets ``self.clip_model`` and, when ``cfg.CROSS_MODAL_MIXUP`` is on,
        the retrieval-bank attributes used by the mixup helpers.

        Raises:
            ValueError: if ``IMAGE_ENHANCED_SHIFT`` and ``TEXT_ENHANCED_SHIFT``
                are both enabled.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames
        # Explicit check (an ``assert`` would vanish under ``python -O``), done
        # before the slow model load.
        if cfg.IMAGE_ENHANCED_SHIFT and cfg.TEXT_ENHANCED_SHIFT:
            raise ValueError("IMAGE_ENHANCED_SHIFT and TEXT_ENHANCED_SHIFT cannot both be enabled")

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        if cfg.DATASET.NAME in self._COMPOSITE_DATASETS:
            superclass = self.dm.dataset.class2superclass
            templates = [CUSTOM_TEMPLATES[superclass[c]] for c in classnames]
        else:
            templates = [CUSTOM_TEMPLATES[cfg.DATASET.NAME]] * len(classnames)
        prompts = [temp.format(c.replace("_", " ")) for temp, c in zip(templates, classnames)]
        print(f"Prompts: {prompts}")
        tokens = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)

        with torch.no_grad():
            text_features = clip_model.encode_text(tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            if cfg.IMAGE_ENHANCED_SHIFT:
                feature_dir = os.path.join(DATASET_PATH[cfg.DATASET.NAME], IMAGE_FEATURES[cfg.MODEL.BACKBONE.NAME])
                text_features = self.shift_text_features_with_image(cfg.RETRIEVED_NUM, feature_dir, classnames,
                                                                    text_features)
            if cfg.TEXT_ENHANCED_SHIFT:
                feature_dir = os.path.join(DATASET_PATH[cfg.DATASET.NAME], TEXT_FEATURES[cfg.MODEL.BACKBONE.NAME])
                text_features = self.shift_text_features_with_text(cfg.RETRIEVED_NUM, feature_dir, classnames,
                                                                   text_features)
            if cfg.CROSS_MODAL_MIXUP:
                self._load_retrieval_bank(cfg.PRETRAINED_FEATURE_DIR)
                # Text side uses the intra-modal (text-to-text) variant;
                # ``cross_modal_mixup_at_text`` is the cross-modal alternative.
                text_features = self.intra_modal_mixup_at_text(classnames, text_features)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

    def _load_retrieval_bank(self, path):
        """Load the retrieval bank stored at *path* and cache its stacked features.

        The bank is a dict whose values unpack as
        ``(classname, caption, text_feature, image_feature)``. Sets
        ``id_cn_caption_tf_if`` (the raw dict), ``tf``/``norm_tf`` and
        ``if_``/``norm_if`` (concatenated raw / L2-normalised text and image
        features, in dict order) and ``cn`` (classnames in the same order).
        """
        # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
        self.id_cn_caption_tf_if = torch.load(path)
        entries = list(self.id_cn_caption_tf_if.values())
        self.tf = torch.cat([entry[-2] for entry in entries])
        self.norm_tf = self.tf / self.tf.norm(dim=-1, keepdim=True)
        self.if_ = torch.cat([entry[-1] for entry in entries])
        self.norm_if = self.if_ / self.if_.norm(dim=-1, keepdim=True)
        self.cn = [entry[0] for entry in entries]

    # ------------------------------------------------------- text-side shifts
    def _shift_row_toward(self, text_features, row, retrieved):
        """In place, blend ``text_features[row]`` toward the mean of the L2-normalised *retrieved* rows.

        Computes ``(1 - SHIFT_LAMBDA) * text + SHIFT_LAMBDA * mean``; the
        result is not renormalised here.
        """
        retrieved = retrieved / retrieved.norm(dim=-1, keepdim=True)
        # Features read from disk may live on another device than the text features.
        avg_retrieved = retrieved.mean(dim=0).to(text_features.device)
        lam = self.cfg.SHIFT_LAMBDA
        text_features[row] = (1 - lam) * text_features[row] + lam * avg_retrieved

    def shift_text_features_with_image(self, retrieved_num, feature_dir, classnames, text_features):
        """Shift each class text feature toward its retrieved image features.

        Args:
            retrieved_num: number of leading stored features per class to use.
            feature_dir: path of a ``torch.load``-able dict mapping classname to
                a tensor of retrieved features (rows presumably ranked — confirm).
            classnames: class names aligned with the rows of *text_features*.
            text_features: per-class text features; modified in place.

        Returns:
            *text_features* (same tensor), not renormalised.
        """
        retrieved_features = torch.load(feature_dir)
        for i, cn in enumerate(classnames):
            self._shift_row_toward(text_features, i, retrieved_features[cn][:retrieved_num])
        return text_features

    def shift_text_features_with_text(self, retrieved_num, feature_dir, classnames, text_features):
        """Shift each class text feature toward its retrieved text features.

        Same contract as ``shift_text_features_with_image`` except that every
        stored feature of a class is averaged.
        """
        # NOTE(review): retrieved_num is ignored here, unlike the image variant;
        # confirm whether the stored text features are already filtered.
        retrieved_features = torch.load(feature_dir)
        for i, cn in enumerate(classnames):
            self._shift_row_toward(text_features, i, retrieved_features[cn])
        return text_features

    def _caption_matched_features(self, use_image):
        """Group retrieval-bank features by classname, keeping caption-consistent entries.

        An entry is kept when any ``_``/space-separated token of its classname
        occurs (case-insensitively) as a substring of its caption. Returns a
        ``defaultdict(list)`` mapping classname to feature tensors: image
        features when *use_image* is true, text features otherwise.
        """
        grouped = defaultdict(list)
        for classname, caption, text_feature, image_feature in self.id_cn_caption_tf_if.values():
            caption_lower = caption.lower()
            tokens = classname.replace(' ', '_').split('_')
            # Skip empty tokens (from "a__b" or a trailing "_"): "" matches every caption.
            if any(tok and tok.lower() in caption_lower for tok in tokens):
                grouped[classname].append(image_feature if use_image else text_feature)
        return grouped

    def _shift_toward_groups(self, classnames, text_features, grouped):
        """Shift rows of *text_features* whose class has entries in *grouped*; return it."""
        for i, cn in enumerate(classnames):
            if cn in grouped:
                self._shift_row_toward(text_features, i, torch.cat(grouped[cn], dim=0))
        return text_features

    def cross_modal_mixup_at_text(self, classnames, text_features):
        """Shift class text features toward caption-matched bank IMAGE features (in place)."""
        return self._shift_toward_groups(classnames, text_features, self._caption_matched_features(use_image=True))

    def intra_modal_mixup_at_text(self, classnames, text_features):
        """Shift class text features toward caption-matched bank TEXT features (in place)."""
        return self._shift_toward_groups(classnames, text_features, self._caption_matched_features(use_image=False))

    # ------------------------------------------------------ image-side mixups
    def _knn_mixup_at_image(self, retrieved_num, image_features, bank):
        """Blend each image feature with the mean of *bank* rows at its nearest bank images.

        Neighbours are the ``retrieved_num`` bank entries whose normalised image
        features (``self.norm_if``) score highest against each query. Returns
        ``(1 - SHIFT_LAMBDA) * image + SHIFT_LAMBDA * mean(bank[neighbours])``,
        renormalised.
        """
        # A positive scale: mathematically rank-preserving, kept for parity.
        image_sim = self.clip_model.logit_scale.exp() * image_features @ self.norm_if.t()
        # topk raises if k exceeds the bank size, so clamp it.
        k = min(retrieved_num, self.norm_if.size(0))
        _, knn_indices = image_sim.topk(k, dim=-1)
        avg_selected = bank[knn_indices].mean(dim=-2)
        lam = self.cfg.SHIFT_LAMBDA
        mixed = (1 - lam) * image_features + lam * avg_selected
        return mixed / mixed.norm(dim=-1, keepdim=True)

    def cross_modal_mixup_at_image(self, retrieved_num, image_features):
        """Mix image features with the TEXT features of their nearest bank images."""
        return self._knn_mixup_at_image(retrieved_num, image_features, self.norm_tf)

    def intra_modal_mixup_at_image(self, retrieved_num, image_features):
        """Mix image features with the IMAGE features of their nearest bank images."""
        return self._knn_mixup_at_image(retrieved_num, image_features, self.norm_if)

    # -------------------------------------------------------------- inference
    def model_inference(self, image):
        """Return logits: scaled cosine similarity of each image to each class text feature."""
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        if self.cfg.CROSS_MODAL_MIXUP:
            # ``cross_modal_mixup_at_image`` is the cross-modal alternative.
            image_features = self.intra_modal_mixup_at_image(self.cfg.RETRIEVED_NUM, image_features)
        logit_scale = self.clip_model.logit_scale.exp()
        return logit_scale * image_features @ self.text_features.t()

    # ------------------------------------------------------------- evaluation
    @staticmethod
    def _is_main_process():
        """True when not running distributed, or on rank 0."""
        return not dist.is_initialized() or dist.get_rank() == 0

    def _select_eval_split(self, split):
        """Resolve *split* into ``(split, data_loader, datums)``.

        Uses the validation loader only when ``split == 'val'`` and one exists,
        otherwise the test loader. ``datums`` is the dataset list backing the
        chosen loader, used to map predictions back to image paths.
        """
        if split is None:
            split = self.cfg.TEST.SPLIT
        if split == 'val' and self.val_loader is not None:
            print('Do evaluation on {} set'.format(split))
            return split, self.val_loader, self.dm.dataset.val
        print('Do evaluation on test set')
        return split, self.test_loader, self.dm.dataset.test

    def _log_results(self, split, results):
        """Write every evaluator metric as ``<split>/<metric>`` (main process only)."""
        if not self._is_main_process():
            return
        for k, v in results.items():
            self.write_scalar('{}/{}'.format(split, k), v, self.epoch)

    @staticmethod
    def _wrong_log(datums, preds, matches, label_names):
        """Rows ``[impath, true_classname, predicted_classname]`` for each mismatch.

        Assumes the loader yields samples in the same order as *datums*
        (i.e. no shuffling/sharding) — TODO confirm for distributed runs.
        """
        return [
            [datum.impath, datum.classname, label_names[pred]]
            for datum, pred, ok in zip(datums, preds, matches)
            if not ok
        ]

    @torch.no_grad()
    def test(self, split=None):
        """Evaluate on the val or test split and collect misclassifications.

        Returns:
            ``(first_metric, wrong_log)``: the first value reported by
            ``self.evaluator.evaluate()`` and
            ``[impath, true_classname, predicted_classname]`` rows.
        """
        self.set_model_mode('eval')
        self.evaluator.reset()
        split, data_loader, datums = self._select_eval_split(split)

        matches, preds = [], []
        for batch in data_loader:
            image, label = self.parse_batch_test(batch)
            output = self.model_inference(image)
            self.evaluator.process(output, label)

            pred = output.max(1)[1]
            matches.extend(pred.eq(label).tolist())
            preds.extend(pred.tolist())

        results = self.evaluator.evaluate()
        self._log_results(split, results)

        wrong_log = self._wrong_log(datums, preds, matches, self.dm.dataset.classnames)
        return list(results.values())[0], wrong_log

    @torch.no_grad()
    def test_with_reassigned_adv_cn(self, split=None, reassigned_adv_cn=None):
        """Evaluate with an optional extra (adversarial) class name added to the label set.

        When *reassigned_adv_cn* is given, its prompt (from the dataset's
        ``CUSTOM_TEMPLATES`` entry) is encoded once and appended as class index
        ``len(classnames)``; with ``None`` the plain class set is used. Unlike
        ``model_inference``, no image-side mixup is applied here.

        Returns:
            ``(first_metric, wrong_log)`` with rows
            ``[impath, true_classname, predicted_classname]``; the predicted
            name may be *reassigned_adv_cn*.
        """
        self.set_model_mode('eval')
        self.evaluator.reset()
        split, data_loader, datums = self._select_eval_split(split)

        label_names = list(self.dm.dataset.classnames)
        class_text_features = self.text_features
        if reassigned_adv_cn is not None:
            # Loop-invariant: encode the extra class once rather than per batch.
            temp = CUSTOM_TEMPLATES[self.cfg.DATASET.NAME]
            adv_prompt = clip.tokenize(temp.format(reassigned_adv_cn.replace("_", " "))).to(self.device)
            adv_text_features = self.clip_model.encode_text(adv_prompt)
            adv_text_features = adv_text_features / adv_text_features.norm(dim=-1, keepdim=True)
            class_text_features = torch.cat([class_text_features, adv_text_features])
            label_names.append(reassigned_adv_cn)

        logit_scale = self.clip_model.logit_scale.exp()
        matches, preds = [], []
        for batch in data_loader:
            image, label = self.parse_batch_test(batch)
            image_features = self.clip_model.encode_image(image)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            output = logit_scale * image_features @ class_text_features.t()
            self.evaluator.process(output, label)

            pred = output.max(1)[1]
            matches.extend(pred.eq(label).tolist())
            preds.extend(pred.tolist())

        results = self.evaluator.evaluate()
        self._log_results(split, results)

        wrong_log = self._wrong_log(datums, preds, matches, label_names)
        return list(results.values())[0], wrong_log


@TRAINER_REGISTRY.register()
class ZeroshotCLIP2(ZeroshotCLIP):
    """Zero-shot CLIP with prompt ensembling.

    Each class embedding is the renormalised average of the normalised text
    embeddings produced by every template registered in ``template_map`` for
    the dataset (for composite benchmarks: for the class's source dataset).
    Evaluation is inherited from ``ZeroshotCLIP``.
    """
    # Dataset name -> prompt templates (format strings taking the class name).
    template_map = {'ImageNet': IMAGENET_TEMPLATES_SELECT, 'MYCIFAR10': CIFAR10_TEMPLATES,
                    'MYCIFAR100': CIFAR100_TEMPLATES, 'StanfordCars': StanfordCars_TEMPLATES,
                    'Caltech101': Caltech101_TEMPLATES, 'DescribableTextures': DescribableTextures_TEMPLATES,
                    'EuroSAT': EuroSAT_TEMPLATES, 'OxfordFlowers': Flowers102_TEMPLATES,
                    'Food101': Food101_TEMPLATES, 'SUN397': SUN397_TEMPLATES, 'OxfordPets': OxfordPets_TEMPLATES,
                    'UCF101': UCF101_TEMPLATES}

    # Class-level default kept for backward compatibility; ``build_model``
    # overwrites ``self.templates`` for single-source datasets.
    templates = IMAGENET_TEMPLATES_SELECT

    # Composite benchmarks whose classes come from several source datasets.
    _COMPOSITE_DATASETS = (
        'FLOWERS_PETS_CARS', 'DTD_EUROSAT_CARS', 'FLOWERS_PETS_FOODS',
        'CIFAR100_CALTECH101_SUN397', 'Food101_CALTECH101_UCF101', 'CIFAR10_CIFAR100_ImageNet',
    )

    def build_model(self):
        """Load a frozen CLIP model and build prompt-ensembled class text features.

        Sets ``self.clip_model`` and ``self.text_features`` (one L2-normalised
        row per class). With ``cfg.CROSS_MODAL_MIXUP`` it also loads the
        retrieval bank read by the inherited mixup helpers.

        Raises:
            ValueError: if ``IMAGE_ENHANCED_SHIFT`` and ``TEXT_ENHANCED_SHIFT``
                are both enabled.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames
        # Explicit check (an ``assert`` would vanish under ``python -O``).
        if cfg.IMAGE_ENHANCED_SHIFT and cfg.TEXT_ENHANCED_SHIFT:
            raise ValueError("IMAGE_ENHANCED_SHIFT and TEXT_ENHANCED_SHIFT cannot both be enabled")

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        # Inference only: freeze every CLIP weight.
        for param in clip_model.parameters():
            param.requires_grad_(False)

        if cfg.CROSS_MODAL_MIXUP:
            # Loaded once, before branching, so the image-side mixup in
            # ``model_inference`` finds the bank for every dataset type.
            # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
            self.id_cn_caption_tf_if = torch.load(cfg.PRETRAINED_FEATURE_DIR)
            entries = list(self.id_cn_caption_tf_if.values())
            self.tf = torch.cat([entry[-2] for entry in entries])
            self.norm_tf = self.tf / self.tf.norm(dim=-1, keepdim=True)
            self.if_ = torch.cat([entry[-1] for entry in entries])
            self.norm_if = self.if_ / self.if_.norm(dim=-1, keepdim=True)
            self.cn = [entry[0] for entry in entries]

        with torch.no_grad():
            if cfg.DATASET.NAME in self._COMPOSITE_DATASETS:
                mean_text_features = self._ensemble_per_class(clip_model, classnames)
                if cfg.CROSS_MODAL_MIXUP:
                    # Same text-side mixup the base class applies to composite datasets.
                    mean_text_features = self.intra_modal_mixup_at_text(classnames, mean_text_features)
                    mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)
            else:
                mean_text_features = self._ensemble_per_template(clip_model, classnames)

        self.text_features = mean_text_features
        self.clip_model = clip_model

    def _ensemble_per_class(self, clip_model, classnames):
        """Composite datasets: each class averages over its source dataset's templates.

        Returns one L2-normalised row per class, in ``classnames`` order.
        """
        superclass = self.dm.dataset.class2superclass
        per_class = []
        for c in classnames:
            label = c.replace("_", " ")
            templates = self.template_map[superclass[c]]
            tokens = torch.cat([clip.tokenize(temp.format(label)) for temp in templates]).to(self.device)
            feats = clip_model.encode_text(tokens)
            feats = feats / feats.norm(dim=-1, keepdim=True)
            per_class.append(feats.mean(dim=0, keepdim=True))
        stacked = torch.cat(per_class)
        return stacked / stacked.norm(dim=-1, keepdim=True)

    def _ensemble_per_template(self, clip_model, classnames):
        """Single-source datasets: average per-template class features over all templates.

        Each template's features are normalised, then optionally shifted
        (image or text retrieval) and mixed up, renormalising after each step,
        before being accumulated. Also sets ``self.templates``.
        """
        cfg = self.cfg
        self.templates = self.template_map[cfg.DATASET.NAME]
        num_temp = len(self.templates)
        print(f"Prompt ensembling (n={num_temp})")

        # Feature-file paths do not depend on the template; resolve them once.
        # (The shift helpers still re-read their file on each call.)
        image_feature_file = text_feature_file = None
        if cfg.IMAGE_ENHANCED_SHIFT:
            image_feature_file = os.path.join(DATASET_PATH[cfg.DATASET.NAME], IMAGE_FEATURES[cfg.MODEL.BACKBONE.NAME])
        if cfg.TEXT_ENHANCED_SHIFT:
            text_feature_file = os.path.join(DATASET_PATH[cfg.DATASET.NAME], TEXT_FEATURES[cfg.MODEL.BACKBONE.NAME])

        labels = [c.replace("_", " ") for c in classnames]
        summed = 0
        for temp in self.templates:
            tokens = torch.cat([clip.tokenize(temp.format(label)) for label in labels]).to(self.device)
            feats = clip_model.encode_text(tokens)
            feats = feats / feats.norm(dim=-1, keepdim=True)
            if image_feature_file is not None:
                feats = self.shift_text_features_with_image(cfg.RETRIEVED_NUM, image_feature_file, classnames, feats)
                feats = feats / feats.norm(dim=-1, keepdim=True)
            if text_feature_file is not None:
                feats = self.shift_text_features_with_text(cfg.RETRIEVED_NUM, text_feature_file, classnames, feats)
                feats = feats / feats.norm(dim=-1, keepdim=True)
            if cfg.CROSS_MODAL_MIXUP:
                feats = self.intra_modal_mixup_at_text(classnames, feats)
                feats = feats / feats.norm(dim=-1, keepdim=True)
            summed = summed + feats
        mean = summed / num_temp
        return mean / mean.norm(dim=-1, keepdim=True)
