import torch
import torch.nn as nn

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler
from declip.declip import declip_res50, declip_vitb32
from clip.model import convert_weights
from collections import OrderedDict
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT
import torch.distributed as dist

CUSTOM_TEMPLATES = {
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    "Food101": "a photo of {}, a type of food.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}.",
    "MYCIFAR10": "a photo of a {}.",
    "MYCIFAR100": "a photo of a {}.",
}

_MODELS = {
    "ViT-B/32/DeClip": "/home/renshuhuai/.cache/declip/vitb32.pth.tar",
    "RN50/DeClip": "/home/renshuhuai/.cache/declip/r50.pth.tar"
}


@TRAINER_REGISTRY.register()
class ZeroshotDeCLIP(TrainerX):
    """Zero-shot DeCLIP classifier registered as a DASSL trainer.

    ``build_model`` loads a pretrained DeCLIP checkpoint and encodes one text
    prompt per class; ``model_inference`` scores images against those fixed
    text features. No parameters are trained by this class.
    """

    # Dataset names whose classes are mapped, via ``class2superclass``, to a
    # key of CUSTOM_TEMPLATES (presumably the source dataset of each class).
    _MULTI_DATASET_NAMES = frozenset({
        'FLOWERS_PETS_CARS', 'DTD_EUROSAT_CARS', 'FLOWERS_PETS_FOODS',
        'CIFAR100_CALTECH101_SUN397', 'Food101_CALTECH101_UCF101', 'CIFAR10_CIFAR100_ImageNet',
    })
    # BPE vocabulary handed to the DeCLIP text encoder for both backbones.
    _BPE_PATH = '/home/renshuhuai/.cache/declip/bpe_simple_vocab_16e6.txt.gz'
    # Leading prefix that (Distributed)DataParallel adds to parameter names.
    _DDP_PREFIX = 'module.'

    def build_model(self):
        """Load DeCLIP and precompute normalised text features for every class.

        Sets ``self.declip_model`` and ``self.text_features``; row ``i`` of the
        text features corresponds to ``self.dm.dataset.classnames[i]``.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        declip_model = self.load_declip_to_cpu(cfg)
        declip_model.to(self.device)

        if cfg.DATASET.NAME in self._MULTI_DATASET_NAMES:
            # Each class uses the template of the dataset it is mapped to.
            superclass_of = self.dm.dataset.class2superclass
            templates = [CUSTOM_TEMPLATES[superclass_of[c]] for c in classnames]
        else:
            templates = [CUSTOM_TEMPLATES[cfg.DATASET.NAME]] * len(classnames)
        prompts = [temp.format(c.replace("_", " ")) for temp, c in zip(templates, classnames)]
        print(f"Prompts: {prompts}")

        with torch.no_grad():
            text_features = declip_model.encode_text(prompts)
            # The epsilon follows the normalisation used in the DeCLIP source code.
            text_features = text_features / (text_features.norm(dim=-1, keepdim=True) + 1e-10)

        self.text_features = text_features
        self.declip_model = declip_model

    def model_inference(self, image):
        """Return scaled cosine-similarity logits of ``image`` against every class.

        The result is ``logit_scale.exp() * normalised_image_features @ text_features.T``,
        i.e. one column per class in ``self.text_features``. Presumably ``image``
        is a batch already on ``self.device`` (from ``parse_batch_test``).
        """
        image_features = self.declip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logit_scale = self.declip_model.logit_scale.exp()
        return logit_scale * image_features @ self.text_features.t()

    def load_declip_to_cpu(self, cfg):
        """Load the checkpoint for ``cfg.MODEL.BACKBONE.NAME`` on CPU and build the model.

        A TorchScript archive is tried first; if ``torch.jit.load`` raises
        ``RuntimeError`` the file is read as a regular ``torch.load`` checkpoint.

        Raises:
            KeyError: if the backbone name is not a key of ``_MODELS``.
        """
        backbone_name = cfg.MODEL.BACKBONE.NAME
        try:
            model_path = _MODELS[backbone_name]
        except KeyError:
            raise KeyError(
                f"Unknown DeCLIP backbone {backbone_name!r}; expected one of {sorted(_MODELS)}"
            ) from None

        try:
            jit_model = torch.jit.load(model_path, map_location="cpu")
        except RuntimeError:
            # Not a TorchScript archive: a plain pickled checkpoint.
            checkpoint = torch.load(model_path, map_location="cpu")
        else:
            # Flat {param_name: tensor} mapping; build_declip_model accepts it.
            checkpoint = jit_model.eval().state_dict()

        return self.build_declip_model(checkpoint)

    def build_declip_model(self, state_dict: dict):  # TODO ugly, should be loaded from cfg
        """Instantiate the DeCLIP architecture matching ``state_dict`` and load it.

        Args:
            state_dict: either a training checkpoint whose weights live under
                the ``'model'`` key, or a flat ``{param_name: tensor}`` mapping.

        Returns:
            The model in eval mode: ViT-B/32 when the weights contain
            ``visual.proj``, RN50 otherwise. Missing keys are printed, not raised.
        """
        weights = state_dict['model'] if 'model' in state_dict else state_dict

        # Strip only the *leading* wrapper prefix. ``str.replace`` would also cut
        # "module." out of the middle of a key (e.g. "x.submodule.w" -> "x.subw"),
        # silently dropping that parameter under strict=False.
        cleaned = OrderedDict()
        for key, value in weights.items():
            while key.startswith(self._DDP_PREFIX):
                key = key[len(self._DDP_PREFIX):]
            cleaned[key] = value

        text_encode = {'bpe_path': self._BPE_PATH,
                       'text_encode_type': 'Transformer',
                       'text_model_utils': {'random': False, 'freeze': False}}
        text_proj = cleaned.get('encode_text.text_projection.weight')

        if 'visual.proj' in cleaned:
            feature_dim = cleaned['visual.proj'].size(1)
            image_encode = {'embed_dim': feature_dim}
            text_encode['embed_dim'] = cleaned['encode_text.text_projection.weight'].size(0)
            clip = {'use_allgather': True, 'text_mask_type': 'MLM', 'return_nn_bank': True,
                    'feature_dim': feature_dim}
            builder = declip_vitb32
        else:
            image_encode = {'bn_group_size': 1, 'bn_sync_stats': True, 'embed_dim': 1024}  # TODO 32 in config?
            # Derive the text embedding size from the checkpoint, as the ViT branch
            # does. NOTE(review): the 2014 fallback looks like a typo for 1024;
            # kept only for checkpoints that lack the projection weight.
            text_encode['embed_dim'] = text_proj.size(0) if text_proj is not None else 2014
            clip = {'use_allgather': True, 'text_mask_type': 'MLM', 'return_nn_bank': True}
            builder = declip_res50

        model = builder(image_encode=image_encode, text_encode=text_encode, clip=clip)
        model.load_state_dict(cleaned, strict=False)

        # strict=False tolerates absent parameters, so report them explicitly.
        for k in sorted(set(model.state_dict().keys()) - set(cleaned.keys())):
            print(f'missing key: {k}')
        return model.eval()

    @torch.no_grad()
    def test(self, split=None):
        """A generic testing pipeline.

        Args:
            split: ``'val'`` or ``'test'``; defaults to ``cfg.TEST.SPLIT``. Falls
                back to the test set (and the ``test/`` tag) when no val loader exists.

        Returns:
            ``(first_metric, wrong_log)`` where ``wrong_log`` holds
            ``[impath, true_classname, predicted_classname]`` for each
            misclassified sample. Assumes the loader is unshuffled and iterates
            the split's data list in order (predictions are matched by index).
        """
        self.set_model_mode('eval')
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == 'val' and self.val_loader is not None:
            data_loader = self.val_loader
            data_source = self.dm.dataset.val
            print('Do evaluation on {} set'.format(split))
        else:
            split = 'test'  # also covers a requested 'val' with no val loader
            data_loader = self.test_loader
            data_source = self.dm.dataset.test
            print('Do evaluation on test set')

        preds = []
        correct = []
        for batch in data_loader:
            images, labels = self.parse_batch_test(batch)
            output = self.model_inference(images)
            self.evaluator.process(output, labels)

            batch_preds = output.max(1)[1]
            preds.extend(batch_preds.tolist())
            correct.extend(batch_preds.eq(labels).tolist())

        results = self.evaluator.evaluate()

        is_main_process = (not (dist.is_available() and dist.is_initialized())
                           or dist.get_rank() == 0)
        if is_main_process:
            for k, v in results.items():
                self.write_scalar(f'{split}/{k}', v, self.epoch)

        classnames = self.dm.dataset.classnames
        # [impath, true_classname, wrong_classname]
        wrong_log = []
        for idx, (pred, ok) in enumerate(zip(preds, correct)):
            if not ok:
                datum = data_source[idx]
                wrong_log.append([datum.impath, datum.classname, classnames[pred]])

        return list(results.values())[0], wrong_log
