import torch
import torch.nn as nn

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.model import convert_weights

from .coop import load_clip_to_cpu
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT

from zsrobust.utils import clip_img_preprocessing as preprocessing
from attack.pgd import attack_pgd
from dassl.evaluation import build_evaluator
from tqdm import tqdm

# Prompt template keyed by dataset name (cfg.DATASET.NAME). The "{}" slot is
# filled with each class name, with underscores replaced by spaces, when the
# text prompts are built.
CUSTOM_TEMPLATES = dict(
    OxfordPets="a photo of a {}, a type of pet.",
    OxfordFlowers="a photo of a {}, a type of flower.",
    FGVCAircraft="a photo of a {}, a type of aircraft.",
    DescribableTextures="{} texture.",
    EuroSAT="a centered satellite photo of {}.",
    StanfordCars="a photo of a {}.",
    Food101="a photo of {}, a type of food.",
    SUN397="a photo of a {}.",
    Caltech101="a photo of a {}.",
    UCF101="a photo of a person doing {}.",
    ImageNet="a photo of a {}.",
    ImageNetSketch="a photo of a {}.",
    ImageNetV2="a photo of a {}.",
    ImageNetA="a photo of a {}.",
    ImageNetR="a photo of a {}.",
    TinyImageNet="a photo of a {}.",
)



class ZeroshotCLIP(TrainerX):
    """Zero-shot CLIP classifier using one hand-written prompt per class.

    Class text embeddings are computed once in ``build_model``; classification
    is a scaled dot product between normalised image and text embeddings.
    This class carries no ``TRAINER_REGISTRY`` decorator in this file.
    """

    def build_model(self):
        """Load CLIP onto ``self.device`` and cache normalised class text features."""
        config = self.cfg
        names = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {config.MODEL.BACKBONE.NAME})")
        model = load_clip_to_cpu(config)
        model.to(self.device)

        # Turn every class name into a natural-language prompt.
        template = CUSTOM_TEMPLATES[config.DATASET.NAME]
        sentences = [template.format(name.replace("_", " ")) for name in names]
        print(f"Prompts: {sentences}")
        tokens = torch.cat([clip.tokenize(sentence) for sentence in sentences]).to(self.device)

        with torch.no_grad():
            embeddings = model.encode_text(tokens)
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

        self.text_features = embeddings
        self.clip_model = model

    def model_inference(self, image):
        """Return logit-scaled cosine similarities between ``image`` and every class prompt."""
        feats = self.clip_model.encode_image(image)
        feats = feats / feats.norm(dim=-1, keepdim=True)
        scale = self.clip_model.logit_scale.exp()
        return scale * feats @ self.text_features.t()
    
@TRAINER_REGISTRY.register()
class AdvZeroshotCLIP(TrainerX):
    """Zero-shot CLIP evaluated on both clean and adversarially perturbed inputs.

    Class text features are computed once from ``CUSTOM_TEMPLATES`` prompts.
    ``test`` records clean predictions in ``self.evaluator`` and predictions
    on attacked images in a separate ``self.evaluator_adv``.
    """

    def build_model(self):
        """Load CLIP, optionally cast it to fp32, and cache normalised class text features."""
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        if cfg.TRAINER.ADVZSCLIP.PREC in ["fp32", "amp"]:
            # CLIP's default precision is fp16
            clip_model.float()

        temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        # Image normalisation is kept as a separate callable. ``test`` applies
        # it explicitly, and the attacks receive the raw (un-normalised) batch.
        self.preprocessing = preprocessing

        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

    def test(self, split=None):
        """Evaluate clean and adversarial accuracy on a data split.

        Args:
            split: ``"val"`` or ``"test"``. Defaults to ``cfg.TEST.SPLIT``, and
                falls back to ``"test"`` when no validation loader exists.

        Returns:
            Tuple ``(clean, adversarial)``, where each entry is the first value
            of the corresponding evaluator's ``evaluate()`` result.
        """
        self.set_model_mode("eval")

        self.evaluator.reset()
        self.evaluator_adv = build_evaluator(self.cfg, lab2cname=self.lab2cname)
        self.evaluator_adv.reset()
        torch.cuda.empty_cache()

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
        else:
            split = "test"  # in case val_loader is None
            data_loader = self.test_loader

        print(f"Evaluate on the *{split}* set --Adversary")

        use_autoattack = self.cfg.ATTACK.TEST == 'aa'
        if use_autoattack:
            # Imported lazily, and once rather than per batch, so the
            # AutoAttack dependency is only needed when it is selected.
            # When running under a debugger, autoattack.checks.check_dynamic
            # may need to be patched to a no-op to avoid conflicts.
            from attack.auto import attack_auto
            aa_eps = self.cfg.ATTACK.AA.EPS / 255.

        for batch in tqdm(data_loader):
            images, label = self.parse_batch_test(batch)

            # Clean (natural) evaluation.
            with torch.no_grad():
                output = self.model_inference(self.preprocessing(images))
                self.evaluator.process(output, label)
            torch.cuda.empty_cache()

            # Craft adversarial images outside no_grad (the attacks presumably
            # need input gradients). Both attacks work on the raw batch.
            if use_autoattack:
                images_adv = attack_auto(self.model_inference_with_normalization, images, label,
                                         text_tokens=None, prompter=None, add_prompter=None,
                                         device=images.device, attacks_to_run=['apgd-ce', 'apgd-dlr'],
                                         epsilon=aa_eps)
            else:
                delta = attack_pgd(self.model_inference, self.preprocessing, images, label,
                                   alpha=self.cfg.ATTACK.PGD.ALPHA,
                                   attack_iters=self.cfg.ATTACK.PGD.TEST_ITER,
                                   epsilon=self.cfg.ATTACK.PGD.EPS)
                images_adv = images + delta

            torch.cuda.empty_cache()
            # Adversarial evaluation. Preprocessing runs under no_grad too, so
            # no autograd graph is kept over the perturbed batch.
            with torch.no_grad():
                output_adv = self.model_inference(self.preprocessing(images_adv))
                self.evaluator_adv.process(output_adv, label)

        results = self.evaluator.evaluate()
        results_adv = self.evaluator_adv.evaluate()

        for k, v in results.items():
            self.write_scalar(f"{split}/{k}", v, self.epoch)

        # Robust metrics get their own tag prefix. Reusing f"{split}/{k}"
        # would make them collide with the clean metrics logged at the same step.
        for k, v in results_adv.items():
            self.write_scalar(f"{split}_adv/{k}", v, self.epoch)

        return list(results.values())[0], list(results_adv.values())[0]

    def model_inference(self, image):
        """Return logit-scaled cosine similarities of already-normalised ``image`` to each class."""
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logit_scale = self.clip_model.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()
        return logits

    def model_inference_with_normalization(self, image):
        """Apply ``self.preprocessing`` to a raw ``image`` batch, then run ``model_inference``."""
        image = self.preprocessing(image)
        logits = self.model_inference(image)
        return logits


