import logging

import torch
from tqdm import tqdm

from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \
    IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES, CIFAR100_CLASSNAMES, CALTECH101_CLASSNAMES, FLOWERS_CLASSNAMES, FOOD_CLASSNAMES, STANFORD_CLASSNAMES
from open_clip_train.precision import get_autocast


def accuracy(output, target, topk=(1,)):
    """Count correct top-k predictions for each ``k`` in ``topk``.

    Args:
        output: 2-D score tensor, one row per sample and one column per class
            (``topk`` is taken along dim 1 and the result is transposed).
        target: Tensor of ground-truth class indices, one per sample.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        List of Python floats, in ``topk`` order: the number (not the
        fraction) of samples whose target is among the ``k`` highest-scoring
        classes. A ``k`` larger than the number of classes counts every sample.
    """
    # Never request more candidates than there are classes (e.g. top-5 on a
    # 3-class task), which would make Tensor.topk raise.
    maxk = min(max(topk), output.size(1))
    pred = output.topk(maxk, 1, True, True)[1].t()  # (maxk, batch) class indices
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() gives a Python scalar directly; float() on a 1-element ndarray
    # is deprecated since NumPy 1.25.
    return [correct[:k].reshape(-1).float().sum().item() for k in topk]


def run(model, classifier, dataloader, args):
    """Return ``(top1, top5)`` zero-shot accuracy of ``model`` over ``dataloader``.

    Each batch is encoded by ``model``. The image features are scaled by 100
    and multiplied by ``classifier`` to form per-class logits, which are scored
    against the batch targets with ``accuracy``.

    Args:
        model: Callable accepting ``image=<tensor>``. Its output is either a
            dict holding ``'image_features'`` or a sequence whose first element
            is the image features.
        classifier: Zero-shot weight matrix right-multiplied onto the image
            features (presumably ``(embed_dim, num_classes)`` — confirm).
        dataloader: Iterable yielding ``(images, target)`` batches.
        args: Namespace providing ``device``, ``precision`` and ``batch_size``.

    Returns:
        Tuple of fractions: correct top-1 and top-5 counts divided by the
        number of images seen.
    """
    device = torch.device(args.device)
    autocast = get_autocast(args.precision, device_type=device.type)
    input_dtype = get_input_dtype(args.precision)

    hits_at_1 = hits_at_5 = seen = 0.

    with torch.inference_mode():
        # unit_scale scales the progress counter by batch_size, so the bar
        # reports (approximately) images rather than batches.
        batches = tqdm(dataloader, unit_scale=args.batch_size)
        for images, target in batches:
            images = images.to(device=device, dtype=input_dtype)
            target = target.to(device)

            with autocast():
                # SigLIP experiments called model(iteration=<batch index>, image=images) here.
                output = model(image=images)
                if isinstance(output, dict):
                    image_features = output['image_features']
                else:
                    image_features = output[0]
                logits = 100. * image_features @ classifier

            batch_hits_1, batch_hits_5 = accuracy(logits, target, topk=(1, 5))
            hits_at_1 += batch_hits_1
            hits_at_5 += batch_hits_5
            seen += images.size(0)

    return hits_at_1 / seen, hits_at_5 / seen


def zero_shot_eval(model, data, epoch, args, tokenizer=None):
    """Run zero-shot classification on every supported data set found in ``data``.

    Evaluation is skipped (``{}`` is returned) when ``args.zeroshot_frequency``
    is 0, or when ``epoch`` is neither a multiple of it nor the final epoch
    (``args.epochs``).

    The zero-shot classifier is either loaded from a cached ``.pt`` file
    (when ``args.imagenet_val`` or ``args.cifar_100_val`` /
    ``args.cifar_100_train`` is set) or built from the class names of the
    first matching data set among caltech-101, flowers-102, food-101 and
    stanford. Cached files live under ``args.zeroshot_cache_dir`` if that
    attribute is set; otherwise the historical default root is used.

    Args:
        model: The CLIP-style model; unwrapped via ``.module`` when running
            distributed without horovod.
        data: Mapping from data set key (e.g. ``'imagenet-val'``,
            ``'food-101'``) to an object exposing ``.dataloader``.
        epoch: Current epoch number.
        args: Run configuration namespace.
        tokenizer: Optional tokenizer; defaults to ``get_tokenizer(args.model)``.

    Returns:
        Dict mapping ``'<name>-zeroshot-val-top1'`` / ``'-top5'`` to the
        accuracy fractions returned by ``run``.

    Raises:
        ValueError: If a classifier must be built but no class names are
            known for any data set in ``data``.
    """
    if args.zeroshot_frequency == 0:
        return {}
    if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
        return {}
    if args.distributed and not args.horovod:
        # Unwrap the distributed wrapper (presumably DistributedDataParallel).
        model = model.module

    logging.info('Starting zero-shot imagenet.')
    if tokenizer is None:
        tokenizer = get_tokenizer(args.model)

    logging.info('Building zero-shot classifier')
    device = torch.device(args.device)
    autocast = get_autocast(args.precision, device_type=device.type)

    # Short names used in the cached classifier file names; any other
    # (model, pretrained) pair falls back to "<model>-<pretrained>".
    model_name_mapping = {
        ('ViT-B-16-quickgelu', 'openai'): 'clip',
        ('EVA01-g-14', 'laion400m_s11b_b41k'): 'eva01clip',
        ('EVA02-B-16', 'merged2b_s8b_b131k'): 'eva02clip',
        ('ViT-B-16', 'laion2b_s34b_b88k'): 'openclip',
        ('ViT-B-16-SigLIP', 'webli'): 'siglip',
        ('ViT-B-16-SigLIP2', 'webli'): 'siglip2',
    }
    model_name = model_name_mapping.get(
        (args.model, args.pretrained), f"{args.model}-{args.pretrained}")

    # Class names for building a classifier on the fly; the first data set
    # present (in this order) wins. Stays None for data sets without a list.
    classes = None
    for data_key, classnames in (
        ('caltech-101', CALTECH101_CLASSNAMES),
        ('flowers-102', FLOWERS_CLASSNAMES),
        ('food-101', FOOD_CLASSNAMES),
        ('stanford', STANFORD_CLASSNAMES),
    ):
        if data_key in data:
            classes = classnames
            break

    cache_root = getattr(args, 'zeroshot_cache_dir', None) or '/home/user/regcache/cache'

    with autocast():
        if args.imagenet_val:
            classifier = torch.load(
                f'{cache_root}/imagenet1k/zero_shot_classifier4{model_name}.pt',
                map_location=device)
        elif args.cifar_100_val or args.cifar_100_train:
            classifier = torch.load(
                f'{cache_root}/cifar100/zero_shot_classifier4{model_name}.pt',
                map_location=device)
        else:
            if classes is None:
                raise ValueError(
                    f"Cannot build a zero-shot classifier: no class names are known "
                    f"for the data sets {list(data)}. Expected one of 'caltech-101', "
                    f"'flowers-102', 'food-101', 'stanford', or set args.imagenet_val / "
                    f"args.cifar_100_val / args.cifar_100_train to use a cached classifier.")
            classifier = build_zero_shot_classifier(
                model,
                tokenizer=tokenizer,
                classnames=classes,
                templates=OPENAI_IMAGENET_TEMPLATES,
                num_classes_per_batch=10,
                device=device,
                use_tqdm=True,
            )

    logging.info('Using classifier')
    # (data key, metric-name prefix), in the order results are reported.
    # NOTE(review): the single classifier above is used for every data set
    # present, so results are only meaningful for the data set whose label
    # space it was built for.
    eval_sets = (
        ('imagenet-val', 'imagenet'),
        ('imagenet-v2', 'imagenetv2'),
        ('cifar-100', 'cifar-100'),
        ('caltech-101', 'caltech-101'),
        ('aircraft', 'aircraft'),
        ('bird', 'bird'),
        ('stanford', 'stanford'),
        ('DTD', 'DTD'),
        ('EuroSAT', 'EuroSAT'),
        ('flowers-102', 'flowers-102'),
        ('food-101', 'food-101'),
        ('pet', 'pet'),
        ('sun397', 'sun397'),
        ('ucf101', 'ucf101'),
    )
    results = {}
    for data_key, prefix in eval_sets:
        if data_key not in data:
            continue
        top1, top5 = run(model, classifier, data[data_key].dataloader, args)
        results[f'{prefix}-zeroshot-val-top1'] = top1
        results[f'{prefix}-zeroshot-val-top5'] = top5

    logging.info('Finished zero-shot imagenet.')

    return results
