import json
import os
import sys
import time

import numpy as np
import open_clip
import torch
import torch.nn.functional as F
from torchvision.transforms import Resize
from torchvision import transforms
from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
import wandb
import argparse
from robustbench import benchmark
from robustbench.data import load_clean_dataset
from autoattack import AutoAttack
from robustbench.model_zoo.enums import BenchmarkDataset
from CLIP_eval.eval_utils import compute_accuracy_no_dataloader, load_clip_model
from train.utils import str2bool

# Command-line interface of the evaluation script. Boolean flags are parsed
# with str2bool, so they are passed as e.g. `--wandb False`.
parser = argparse.ArgumentParser(description="Script arguments")

# Model selection: either a wandb training run id, or an explicit
# (clip_model_name, pretrained) pair.
parser.add_argument(
    '--clip_model_name', type=str, default='none',
    help="ViT-L-14, ViT-B-32, don't use if wandb_id is set",
)
parser.add_argument(
    '--pretrained', type=str, default='none',
    help="Pretrained model ckpt path, don't use if wandb_id is set",
)
parser.add_argument(
    '--wandb_id', type=str, default='none',
    help="Wandb id of training run, don't use if clip_model_name and pretrained are set",
)

# Evaluation mode.
parser.add_argument(
    '--logit_scale', type=str2bool, default=True,
    help='Whether to scale logits',
)
parser.add_argument(
    '--full_benchmark', type=str2bool, default=False,
    help='Whether to run full RB benchmark',
)

# Dataset choice, dataset locations and evaluation size.
parser.add_argument('--dataset', type=str, default='imagenet')
parser.add_argument(
    '--imagenet_root', type=str, default='/home/datasets/imagenet',
    help='Imagenet dataset root directory',
)
parser.add_argument(
    '--cifar10_root', type=str, default='/home/datasets/CIFAR10',
    help='CIFAR10 dataset root directory',
)
parser.add_argument(
    '--cifar100_root', type=str, default='/home/datasets/CIFAR100',
    help='CIFAR100 dataset root directory',
)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument(
    '--n_samples_imagenet', type=int, default=5000,
    help='Number of samples from ImageNet for benchmark',
)
parser.add_argument(
    '--n_samples_cifar', type=int, default=1000,
    help='Number of samples from CIFAR for benchmark',
)

# Text prompts used to build the zero-shot classifier.
parser.add_argument(
    '--template', type=str, default='ensemble',
    help='Text template type; std, ensemble',
)

# Attack configuration and model interpolation.
parser.add_argument(
    '--norm', type=str, default='linf',
    help='Norm for attacks; linf, l2',
)
parser.add_argument(
    '--eps', type=float, default=4.,
    help='Epsilon for attack',
)
parser.add_argument(
    '--beta', type=float, default=0.,
    help='Model interpolation parameter',
)
parser.add_argument(
    '--alpha', type=float, default=2.,
    help='APGD alpha parameter',
)

# Logging, outputs and device selection.
parser.add_argument(
    '--experiment_name', type=str, default='',
    help='Experiment name for logging',
)
parser.add_argument(
    '--blackbox_only', type=str2bool, default=False,
    help='Run blackbox attacks only',
)
parser.add_argument(
    '--save_images', type=str2bool, default=False,
    help='Save images during benchmarking',
)
parser.add_argument(
    '--wandb', type=str2bool, default=True,
    help='Use Weights & Biases for logging',
)
parser.add_argument(
    '--devices', type=str, default='',
    help='Device IDs for CUDA',
)


# Class names substituted into the text template when evaluating on CIFAR-10.
# NOTE(review): presumably ordered by CIFAR-10 class index - confirm against the dataset loader.
CIFAR10_LABELS = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)

class ClassificationModel(torch.nn.Module):
    """Zero-shot classifier: image encoder followed by a product with fixed text embeddings.

    The forward pass resizes (optionally) and normalizes the input, encodes it
    with ``model.encode_image(..., normalize=True)`` and multiplies the result
    with ``text_embedding`` to obtain one logit per column of ``text_embedding``.

    Args:
        model: Object providing ``encode_image(x, normalize=...)`` and a
            ``logit_scale`` tensor/parameter (its ``exp()`` scales the logits).
        text_embedding: Matrix the image embedding is right-multiplied with.
            When it is a tensor it is registered as a non-persistent buffer so
            that it follows the module across devices (``.to()``, ``.cuda()``,
            ``DataParallel`` replicas) without appearing in ``state_dict``.
        args: Parsed script arguments; stored but not used by this class.
        input_normalize: Callable applied to the (resized) input before encoding.
        resizer: Optional callable applied to the raw input first; identity if None.
        logit_scale: Whether to multiply logits by ``model.logit_scale.exp()``.
    """

    def __init__(self, model, text_embedding, args, input_normalize, resizer=None, logit_scale=True):
        super().__init__()
        self.model = model
        self.args = args
        self.input_normalize = input_normalize
        self.resizer = resizer if resizer is not None else torch.nn.Identity()
        if isinstance(text_embedding, torch.Tensor):
            # A plain tensor attribute is not copied to the other devices by
            # DataParallel; a buffer is. persistent=False keeps state_dict unchanged.
            self.register_buffer('text_embedding', text_embedding, persistent=False)
        else:
            self.text_embedding = text_embedding
        self.logit_scale = logit_scale

    def forward(self, vision, output_normalize=True):
        """Return classification logits for the image batch ``vision``.

        ``output_normalize`` exists for interface compatibility only; it must be
        truthy because the image embedding is always normalized here.
        """
        assert output_normalize, 'only normalized image embeddings are supported'
        image_embedding = self.model.encode_image(
            self.input_normalize(self.resizer(vision)),
            normalize=True,
        )
        logits = image_embedding @ self.text_embedding
        if self.logit_scale:
            # out-of-place so the matmul result is never modified in place
            logits = logits * self.model.logit_scale.exp()
        return logits

def interpolate_state_dict(m1, beta=0.2, m2=None, m2_path="/path/to/ckpt.pt"):
    """Linearly interpolate two state dicts: ``beta * m1 + (1 - beta) * m2``.

    Args:
        m1: Mapping from parameter name to tensor; its keys define the result.
        beta: Weight given to ``m1``; ``m2`` gets ``1 - beta``.
        m2: Optional second mapping. When None it is loaded from ``m2_path``.
        m2_path: Checkpoint path used when ``m2`` is not given. The default is
            the placeholder path of the original implementation.

    Returns:
        A new dict with one interpolated entry per key of ``m1``.

    Raises:
        KeyError: If a key of ``m1`` is missing from ``m2``.
    """
    if m2 is None:
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        m2 = torch.load(m2_path, map_location='cpu')

    missing = [k for k in m1 if k not in m2]
    if missing:
        raise KeyError(f'keys missing from second state dict: {missing}')

    return {k: beta * v + (1 - beta) * m2[k] for k, v in m1.items()}


if __name__ == '__main__':
    # Script entry point: builds a zero-shot CLIP classifier for the chosen
    # dataset, evaluates clean and adversarial accuracy (RobustBench benchmark
    # or a custom AutoAttack run) and optionally writes the results back to the
    # wandb training run and/or saves the adversarial images.

    # set seeds
    torch.manual_seed(0)
    np.random.seed(0)

    # Parse command-line arguments
    args = parser.parse_args()
    # print args
    print(f"Arguments:\n{'-' * 20}", flush=True)
    for arg, value in vars(args).items():
        print(f"{arg}: {value}")
    print(f"{'-' * 20}")

    # eps is given in units of 1/255 on the command line
    args.eps /= 255
    # make sure there is no string in args that should be a bool
    assert not any(
        [isinstance(x, str) and x in ['True', 'False'] for x in args.__dict__.values(
        )])

    # values that mean "not set" for the string-typed model selection arguments
    unset_values = [None, 'none', 'None']

    if args.dataset == 'imagenet':
        num_classes = 1000
        data_dir = args.imagenet_root
        n_samples = args.n_samples_imagenet
        resizer = None
    elif args.dataset == 'cifar100':
        num_classes = 100
        data_dir = args.cifar100_root
        n_samples = args.n_samples_cifar
        resizer = Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=False)
    elif args.dataset == 'cifar10':
        num_classes = 10
        data_dir = args.cifar10_root
        n_samples = args.n_samples_cifar
        resizer = Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=False)
    else:
        # fail early with a clear message instead of a NameError further down
        raise ValueError(f'Unknown dataset: {args.dataset}')
    eps = args.eps

    # Identifier of the evaluated model for the run name: the wandb id when one
    # is given, otherwise the checkpoint passed on the command line.
    model_tag = args.pretrained if args.wandb_id in unset_values else args.wandb_id
    # keep the CLI value; args.pretrained is deleted below but needed for the image save path
    cli_pretrained = args.pretrained

    # init wandb
    os.environ['WANDB__SERVICE_WAIT'] = '300'
    wandb_user, wandb_project = None, None
    # retry until the wandb backend is reachable
    while True:
        try:
            run_eval = wandb.init(
                project=wandb_project,
                job_type='eval',
                name=f'{"rb" if args.full_benchmark else "aa"}-clip-{args.dataset}-{args.norm}-{eps:.2f}'
                     f'-{model_tag}-{args.blackbox_only}-{args.beta}',
                save_code=True,
                config=vars(args),
                mode='online' if args.wandb else 'disabled'
            )
            break
        except wandb.errors.CommError as e:
            print('wandb connection error', file=sys.stderr)
            print(f'error: {e}', file=sys.stderr)
            time.sleep(1)
            print('retrying..', file=sys.stderr)

    if args.devices != '':
        # set cuda visible devices
        os.environ["CUDA_VISIBLE_DEVICES"] = args.devices
    main_device = 0
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        print(f"Number of GPUs available: {num_gpus}")
    else:
        print("No multiple GPUs available.")

    if not args.blackbox_only:
        attacks_to_run = ['apgd-ce', 'apgd-t']
    else:
        attacks_to_run = ['square']
    print(f'[attacks_to_run] {attacks_to_run}')


    # Resolve model name and checkpoint either from the wandb training run or
    # from the explicit command-line arguments (mutually exclusive).
    if args.wandb_id not in unset_values:
        assert args.pretrained in unset_values
        assert args.clip_model_name in unset_values
        api = wandb.Api()
        run_train = api.run(f'{wandb_user}/{wandb_project}/{args.wandb_id}')
        clip_model_name = run_train.config['clip_model_name']
        print(f'clip_model_name: {clip_model_name}')
        pretrained = run_train.config["output_dir"]
        if pretrained.endswith('_temp'):
            pretrained = pretrained[:-5]
        pretrained += "/checkpoints/final.pt"
    else:
        clip_model_name = args.clip_model_name
        pretrained = args.pretrained
        run_train = None
    # from here on only the resolved local variables are valid
    del args.clip_model_name, args.pretrained

    print(f'[loading pretrained clip] {clip_model_name} {pretrained}')

    model, preprocessor_without_normalize, normalize = load_clip_model(clip_model_name, pretrained, args.beta)

    if args.dataset != 'imagenet':
        # make sure we don't resize outside the model as this influences threat model
        preprocessor_without_normalize = transforms.ToTensor()
    print(f'[resizer] {resizer}')
    print(f'[preprocessor] {preprocessor_without_normalize}')

    model.eval()
    model.to(main_device)
    with torch.no_grad():
        # Get text label embeddings of all classes of the dataset
        if not args.template == 'ensemble':
            if args.template == 'std':
                template = 'This is a photo of a {}'
            else:
                raise ValueError(f'Unknown template: {args.template}')
            print(f'template: {template}')
            if args.dataset == 'imagenet':
                texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
            elif args.dataset == 'cifar10':
                texts = [template.format(c) for c in CIFAR10_LABELS]
            else:
                # no label list is defined in this script for other datasets
                raise ValueError(f'No class labels available for dataset: {args.dataset}')
            text_tokens = open_clip.tokenize(texts)
            embedding_text_labels_norm = []
            text_batches = [text_tokens[:500], text_tokens[500:]] if args.dataset == 'imagenet' else [text_tokens]
            for el in text_batches:
                # we need to split the text tokens into two batches because otherwise we run out of memory
                # note that we are accessing the model directly here, not the CustomModel wrapper
                # thus its always normalizing the text embeddings
                embedding_text_labels_norm.append(
                    model.encode_text(el.to(main_device), normalize=True).detach().cpu()
                )
            # NOTE(review): the model is moved to CPU here and back to GPU by the
            # .cuda() call on the wrapper below - presumably to free GPU memory.
            model.cpu()
            # transpose so that columns correspond to classes
            embedding_text_labels_norm = torch.cat(embedding_text_labels_norm).T.to(main_device)
        else:
            assert args.dataset == 'imagenet', 'ensemble only implemented for imagenet'
            with open('CLIP_eval/zeroshot-templates.json', 'r') as f:
                templates = json.load(f)
            templates = templates['imagenet1k']
            print(f'[templates] {templates}')
            embedding_text_labels_norm = []
            for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values():
                texts = [template.format(c=c) for template in templates]
                text_tokens = open_clip.tokenize(texts).to(main_device)
                class_embeddings = model.encode_text(text_tokens)
                # average the normalized per-template embeddings, then renormalize
                class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
                class_embedding /= class_embedding.norm()
                embedding_text_labels_norm.append(class_embedding)
            embedding_text_labels_norm = torch.stack(embedding_text_labels_norm, dim=1).to(main_device)

        # sanity check: every class column has unit norm
        assert torch.allclose(
            F.normalize(embedding_text_labels_norm, dim=0),
            embedding_text_labels_norm
        )
        if clip_model_name == 'ViT-B-32':
            assert embedding_text_labels_norm.shape == (512, num_classes), embedding_text_labels_norm.shape
        elif clip_model_name == 'ViT-L-14':
            assert embedding_text_labels_norm.shape == (768, num_classes), embedding_text_labels_norm.shape
        else:
            raise ValueError(f'Unknown model: {clip_model_name}')

    # get model
    model = ClassificationModel(
        model=model,
        text_embedding=embedding_text_labels_norm,
        args=args,
        resizer=resizer,
        input_normalize=normalize,
        logit_scale=args.logit_scale,
    )

    if num_gpus > 1:
        model = torch.nn.DataParallel(model)
    model = model.cuda()
    model.eval()

    model_name = None
    # device = [torch.device(el) for el in range(num_gpus)]  # currently only single gpu supported
    device = torch.device(main_device)
    torch.cuda.empty_cache()

    dataset_short = (
        'img' if args.dataset == 'imagenet' else
        'c10' if args.dataset == 'cifar10' else
        'c100' if args.dataset == 'cifar100' else
        'unknown'
    )

    # eps expressed in 1/255 units again; rounding first guards against float
    # round-off in eps / 255 * 255 truncating to the next lower integer
    eps_255_descr = str(int(round(eps * 255, 6)))

    start = time.time()
    if args.full_benchmark:
        clean_acc, robust_acc = benchmark(
            model, model_name=model_name, n_examples=n_samples,
            batch_size=args.batch_size,
            dataset=args.dataset, data_dir=data_dir,
            threat_model=args.norm.replace('l', 'L'), eps=eps,
            preprocessing=preprocessor_without_normalize,
            device=device, to_disk=False
            )
        clean_acc *= 100
        robust_acc *= 100
        duration = time.time() - start
        print(f"[Model] {pretrained}")
        print(
            f"[Clean Acc] {clean_acc:.2f}% [Robust Acc] {robust_acc:.2f}% [Duration] {duration / 60:.2f}m"
            )
        if run_train is not None:
            # reload the run to make sure we have the latest summary
            del api, run_train
            api = wandb.Api()
            run_train = api.run(f'{wandb_user}/{wandb_project}/{args.wandb_id}')
            eps_descr = eps_255_descr if args.norm == 'linf' else str(eps)
            run_train.summary.update({f'rb/acc-{dataset_short}': clean_acc})
            run_train.summary.update({f'rb/racc-{dataset_short}-{args.norm}-{eps_descr}': robust_acc})
            run_train.update()
    else:
        adversary = AutoAttack(
            model, norm=args.norm.replace('l', 'L'), eps=eps, version='custom', attacks_to_run=attacks_to_run,
            alpha=args.alpha, verbose=True
        )

        x_test, y_test = load_clean_dataset(
            BenchmarkDataset(args.dataset), n_examples=n_samples, data_dir=data_dir,
            prepr=preprocessor_without_normalize,)

        acc = compute_accuracy_no_dataloader(model, data=x_test, targets=y_test, device=device, batch_size=args.batch_size) * 100
        print(f'[acc] {acc:.2f}%', flush=True)
        x_adv, y_adv = adversary.run_standard_evaluation(x_test, y_test, bs=args.batch_size, return_labels=True)  # y_adv are preds on x_adv
        racc = compute_accuracy_no_dataloader(model, data=x_adv, targets=y_test, device=device, batch_size=args.batch_size) * 100
        print(f'[acc] {acc:.2f}% [racc] {racc:.2f}%')

        # save adv images
        if args.save_images:
            # args.pretrained no longer exists at this point, so use the saved CLI
            # value; path separators are flattened to keep a single directory level
            pretrained_tag = str(cli_pretrained).strip(os.sep).replace(os.sep, '_')
            # save the adversarial images
            img_save_path = (f'/path/to/save/dir/'
                             f'{args.dataset}/{args.wandb_id}-{pretrained_tag}-{args.norm}-{eps:.3f}-'
                             f'alph{args.alpha:.3f}-{n_samples}smpls-{time.strftime("%Y-%m-%d_%H-%M-%S")}')
            os.makedirs(img_save_path, exist_ok=True)

            print(f'[saving images to] {img_save_path}')
            x_adv = x_adv.detach().cpu()
            y_adv = y_adv.detach().cpu()
            x_clean = x_test.detach().cpu()
            y_clean = y_test.detach().cpu()
            torch.save(x_adv, f'{img_save_path}/x_adv.pt')
            torch.save(y_adv, f'{img_save_path}/y_adv.pt')
            torch.save(x_clean, f'{img_save_path}/x_clean.pt')
            torch.save(y_clean, f'{img_save_path}/y_clean.pt')
            with open(f'{img_save_path}/args.json', 'w') as f:
                json.dump(vars(args), f)
            with open(f'{img_save_path}/results.json', 'w') as f:
                # valid JSON (accuracies in percent); float() also covers numpy/tensor scalars
                json.dump({'acc': round(float(acc), 2), 'racc': round(float(racc), 2)}, f)

        # write to wandb
        if run_train is not None:
            # reload the run to make sure we have the latest summary
            del api, run_train
            api = wandb.Api()
            run_train = api.run(f'{wandb_user}/{wandb_project}/{args.wandb_id}')
            if args.dataset == 'imagenet':
                assert args.norm == 'linf'
                eps_descr = eps_255_descr
                if eps_descr == '4':
                    descr = dataset_short
                else:
                    descr = f'{dataset_short}-eps{eps_descr}'
                if n_samples != 5000:
                    # mark results obtained with a non-default sample count
                    acc = f'{acc:.2f}*'
                    racc = f'{racc:.2f}*'
            elif args.dataset == 'cifar10':
                if args.norm == 'linf':
                    descr = dataset_short
                else:
                    descr = f'{dataset_short}-{args.norm}'
                if n_samples != 10000:
                    # mark results obtained on a subset
                    acc = f'{acc:.2f}*'
                    racc = f'{racc:.2f}*'
            else:
                raise ValueError(f'Unknown dataset: {args.dataset}')
            run_train.summary.update({f'aa/acc-{dataset_short}': acc})
            run_train.summary.update({f'aa/racc-{descr}': racc})
            run_train.summary.update()
            run_train.update()
    run_eval.finish()













