import torch
import torch.nn as nn
import torch.distributed as dist
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler
from slip import SLIP_VITB16, SLIP_VITL16, SLIP_VITS16
from slip.tokenizer import SimpleTokenizer
from collections import OrderedDict
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT

# Prompt templates keyed by dataset name; "{}" is filled with the class name.
# Datasets without a domain-specific phrasing share the generic template.
_GENERIC_PHOTO_TEMPLATE = "a photo of a {}."

CUSTOM_TEMPLATES = {
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": _GENERIC_PHOTO_TEMPLATE,
    "Food101": "a photo of {}, a type of food.",
    "SUN397": _GENERIC_PHOTO_TEMPLATE,
    "Caltech101": _GENERIC_PHOTO_TEMPLATE,
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": _GENERIC_PHOTO_TEMPLATE,
    "ImageNetSketch": _GENERIC_PHOTO_TEMPLATE,
    "ImageNetV2": _GENERIC_PHOTO_TEMPLATE,
    "ImageNetA": _GENERIC_PHOTO_TEMPLATE,
    "ImageNetR": _GENERIC_PHOTO_TEMPLATE,
    "MYCIFAR10": _GENERIC_PHOTO_TEMPLATE,
    "MYCIFAR100": _GENERIC_PHOTO_TEMPLATE,
}

_MODELS = {
    "ViT-B/16/ep25": "/home/renshuhuai/.cache/slip/slip_base_25ep.pt",
    "ViT-B/16/ep50": "/home/renshuhuai/.cache/slip/slip_base_50ep.pt",
    "ViT-B/16/ep100": "/home/renshuhuai/.cache/slip/slip_base_100ep.pt",
}


@TRAINER_REGISTRY.register()
class ZeroshotSLIP(TrainerX):
    """Zero-shot classifier backed by a pre-trained SLIP checkpoint.

    ``build_model`` encodes one text prompt per class name and caches the
    normalised text features; ``model_inference`` scores image features
    against that cache.
    """

    def build_model(self):
        """Load SLIP and pre-compute normalised text features for all classes.

        Sets ``self.slip_model`` and ``self.text_features``.

        Raises:
            KeyError: if ``CUSTOM_TEMPLATES`` has no entry for the dataset
                name (or, on the combined datasets, for a class's
                super-class name).
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading SLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        slip_model = self.load_slip_to_cpu(cfg)
        slip_model.to(self.device)

        tokenizer = SimpleTokenizer()

        # Combined datasets pick the template per class, via the name that
        # ``class2superclass`` maps each class to.
        combined_datasets = (
            'FLOWERS_PETS_CARS', 'DTD_EUROSAT_CARS', 'FLOWERS_PETS_FOODS',
            'CIFAR100_CALTECH101_SUN397', 'Food101_CALTECH101_UCF101', 'CIFAR10_CIFAR100_ImageNet',
        )
        if cfg.DATASET.NAME in combined_datasets:
            class2superclass = self.dm.dataset.class2superclass
            prompts = [
                CUSTOM_TEMPLATES[class2superclass[c]].format(c.replace("_", " "))
                for c in classnames
            ]
        else:
            temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
            prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = tokenizer(prompts)
        prompts = prompts.to(self.device)

        with torch.no_grad():
            text_features = slip_model.encode_text(prompts)
            text_features = text_features / (
                text_features.norm(dim=-1, keepdim=True) + 1e-10)  # adapted to declip source code

        self.text_features = text_features
        self.slip_model = slip_model

    def model_inference(self, image):
        """Return class logits for ``image``.

        The logits are ``exp(logit_scale)`` times the product of the
        L2-normalised image features and the transposed cached text features.
        """
        image_features = self.slip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logit_scale = self.slip_model.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()
        return logits

    def load_slip_to_cpu(self, cfg):
        """Load the checkpoint selected by ``cfg.MODEL.BACKBONE.NAME`` on CPU.

        Returns:
            The SLIP model built by ``build_slip_model``, in eval mode.

        Raises:
            KeyError: if the backbone name has no entry in ``_MODELS``.
        """
        backbone_name = cfg.MODEL.BACKBONE.NAME
        try:
            model_path = _MODELS[backbone_name]
        except KeyError:
            raise KeyError(
                f"unknown SLIP backbone {backbone_name!r}; available: {list(_MODELS)}"
            ) from None

        # ``build_slip_model`` needs the raw checkpoint dict ('args' and
        # 'state_dict' entries), so the file is loaded with ``torch.load``.
        # A TorchScript archive could never satisfy that contract, hence no
        # ``torch.jit.load`` attempt.
        # NOTE(review): ``torch.load`` unpickles the file, so only point
        # ``_MODELS`` at trusted checkpoints. The checkpoint also carries a
        # non-tensor 'args' object; PyTorch versions that default to
        # ``weights_only=True`` may need ``weights_only=False`` here - confirm
        # against the installed version.
        checkpoint = torch.load(model_path, map_location="cpu")

        return self.build_slip_model(checkpoint)

    def build_slip_model(self, state_dict: dict):
        """Instantiate the SLIP architecture named in a checkpoint and load it.

        Args:
            state_dict: checkpoint dict with an ``'args'`` entry (providing
                ``model``, ``ssl_mlp_dim`` and ``ssl_emb_dim`` attributes)
                and a ``'state_dict'`` entry holding the weights.

        Returns:
            The model with weights loaded, in eval mode.

        Raises:
            ValueError: if ``args.model`` is not a supported SLIP variant.
            RuntimeError: from ``load_state_dict(strict=True)`` when the
                checkpoint keys do not match the model exactly.
        """
        old_args = state_dict['args']

        # Strip only a *leading* 'module.' (the DataParallel/DDP wrapper
        # prefix); a blanket ``str.replace`` would also mangle keys that
        # merely contain 'module.' somewhere inside.
        prefix = 'module.'
        weights = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict['state_dict'].items()
        )

        builders = {
            'SLIP_VITB16': SLIP_VITB16,
            'SLIP_VITS16': SLIP_VITS16,
            'SLIP_VITL16': SLIP_VITL16,
        }
        try:
            build = builders[old_args.model]
        except KeyError:
            raise ValueError(f'not support slip model: {old_args.model}') from None
        model = build(rand_embed=False, ssl_mlp_dim=old_args.ssl_mlp_dim, ssl_emb_dim=old_args.ssl_emb_dim)

        # strict=True raises on any missing or unexpected key, so no separate
        # missing-key report is needed afterwards.
        model.load_state_dict(weights, strict=True)
        return model.eval()

    @torch.no_grad()
    def test(self, split=None):
        """Evaluate on the val or test loader and collect misclassified samples.

        Args:
            split: ``'val'`` to use the validation loader when one exists;
                anything else (or a missing val loader) evaluates the test
                loader. Defaults to ``cfg.TEST.SPLIT``.

        Returns:
            A tuple ``(first_metric, wrong_log)``: the first value of the
            evaluator's result dict, and a list of
            ``[impath, true_classname, predicted_classname]`` entries.
        """
        self.set_model_mode('eval')
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == 'val' and self.val_loader is not None:
            data_loader = self.val_loader
            # Misclassified samples must be looked up in the split that was
            # actually evaluated.
            data_source = self.dm.dataset.val
            print('Do evaluation on {} set'.format(split))
        else:
            # Keep the logged tag consistent with the loader really used
            # (e.g. when 'val' was requested but no val loader exists).
            split = 'test'
            data_loader = self.test_loader
            data_source = self.dm.dataset.test
            print('Do evaluation on test set')

        is_correct = []
        preds = []
        for batch in data_loader:
            images, label = self.parse_batch_test(batch)
            output = self.model_inference(images)
            self.evaluator.process(output, label)

            pred = output.max(1)[1]
            is_correct.extend(pred.eq(label).tolist())
            preds.extend(pred.tolist())

        results = self.evaluator.evaluate()

        # Only the main process (or a non-distributed run) writes scalars.
        is_main_process = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
        if is_main_process:
            for k, v in results.items():
                tag = '{}/{}'.format(split, k)
                self.write_scalar(tag, v, self.epoch)

        # NOTE(review): indexing ``data_source`` by running sample position
        # assumes the loader iterates it in order (no shuffling or sharding)
        # - confirm against the data manager configuration.
        classnames = self.dm.dataset.classnames
        # [impath, true_classname, wrong_classname]
        wrong_log = [
            [data_source[i].impath, data_source[i].classname, classnames[wrong_pred]]
            for i, (wrong_pred, ok) in enumerate(zip(preds, is_correct))
            if not ok
        ]

        return list(results.values())[0], wrong_log
