import torch
import clip
from PIL import Image
import numpy as np
from trainers.coop import CustomCLIP
from dassl.engine import build_trainer
from train import setup_cfg
import argparse


def evaluation(args, custom=False):
    """Rank a fixed list of candidate answers for one hard-coded image.

    The image is scored against every entry of a built-in ``classnames``
    list, the per-class probabilities are printed, and the class names are
    printed sorted from most to least probable. The first entry of
    ``classnames`` is treated as the reference answer: its 1-based position
    in the sorted list is printed as ``gt rank``.

    Args:
        args: Parsed command-line namespace. It is only read when ``custom``
            is true, where it is handed to ``setup_cfg`` and supplies
            ``model_dir`` and ``load_epoch`` for ``trainer.load_model``.
        custom: When true, the logits come from the model of a trainer built
            with ``build_trainer``; otherwise they come from the CLIP "RN50"
            model loaded via ``clip.load``.

    Returns:
        None. All results are written to stdout.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # ``preprocess`` is needed by both branches; ``model`` only by the
    # non-custom branch.
    model, preprocess = clip.load("RN50", device=device)

    image_path = "/home/openness/data/vqa/train2014/COCO_train2014_000000375043.jpg"
    # Open the file in a context manager so the handle is released as soon as
    # the pixel data has been turned into a tensor.
    with Image.open(image_path) as raw_image:
        # unsqueeze(0) adds the leading batch dimension.
        image = preprocess(raw_image).unsqueeze(0).to(device)

    classnames = ['casual', 'drums', 'mustard', 'blue and green', 'robe', 'shelter', 'burgers', 'bnsf', '17', 'highway',
                  'dump truck', 'antenna', 'parsley', 'skiers', 'pillows', 'dodgers', 'pastries', '1 foot', 'commercial',
                  'ups', 'desserts', 'all way', 'on rock', 'sweater', '11:15', 'mother and child', 'doubles',
                  'baseball player', 'zucchini', 'accident', 'camping', 'deck', 'crib', '600', 'silver and black', 'goatee',
                  'serious', 'veggie', 'pitcher', 'happy birthday', 'red bull', 'chopsticks', 'high heels', 'squirrel',
                  'savannah', 'toilet', 'picture', 'life jacket', 'tags', 'desert', 'downhill', 'eggs', 'not sure', 'oar',
                  'parmesan cheese', 'bathing', 'arm', 'poop', 'holding phone', 'bikers', 'donut shop', 'in corner',
                  'wheelie', 'lilies']
    print('len(classnames) = ', len(classnames))
    # Alternative prompt prefixes (currently disabled):
    # prompt = 'What seems to be the appropriate attire for this event? '
    # prompt = 'The appropriate attire for this event is '
    # prompt = 'The appropriate dressing code is '
    prompt = ''

    with torch.no_grad():
        if custom:
            cfg = setup_cfg(args)
            trainer = build_trainer(cfg)
            trainer.load_model(args.model_dir, epoch=args.load_epoch)
            # NOTE(review): the third argument is created on the CPU while
            # ``image`` lives on ``device``; its meaning and required device
            # depend on the trainer's model, which is not visible here --
            # confirm against its forward() signature.
            logits_per_image = trainer.model(image, prompt, torch.tensor([0]))
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()
        else:
            # One text per class: "<prompt><classname>."
            text = clip.tokenize([prompt + c + '.' for c in classnames]).to(device)

            # model(image, text) encodes both inputs itself, so separate
            # encode_image / encode_text calls would only repeat the work.
            logits_per_image, _ = model(image, text)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    print('prompt: ', prompt)
    print("Label probs:", probs)

    # Negate so that argsort (ascending) yields the most probable class first.
    index = np.argsort(-probs)
    # index[0] selects the ranking for the single image in the batch.
    sorted_classnames = [classnames[i] for i in index[0]]
    print(sorted_classnames)
    print('gt rank: ', sorted_classnames.index(classnames[0]) + 1)


def _build_arg_parser():
    """Build the command-line parser for this evaluation script.

    Apart from ``--custom``, ``--model-dir`` and ``--load-epoch``, which this
    file reads directly, the options are only forwarded to ``setup_cfg`` via
    the parsed namespace (presumably mirroring the training script's options
    -- verify against ``train.setup_cfg``).

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="~/CoOp/data/", help="path to dataset")
    parser.add_argument(
        "--output-dir",
        type=str,
        default="output/evaluation/CoOp/rn50_ep50_1shots/nctx16_cscFalse_ctpend/vqa/seed1",
        help="output directory",
    )
    parser.add_argument(
        "--resume",
        type=str,
        default="",
        help="checkpoint directory (from which the training resumes)",
    )
    parser.add_argument(
        "--seed", type=int, default=1, help="only positive value enables a fixed seed"
    )
    parser.add_argument(
        "--source-domains", type=str, nargs="+", help="source domains for DA/DG"
    )
    parser.add_argument(
        "--target-domains", type=str, nargs="+", help="target domains for DA/DG"
    )
    parser.add_argument(
        "--transforms", type=str, nargs="+", help="data augmentation methods"
    )
    parser.add_argument(
        "--config-file",
        type=str,
        default="configs/trainers/CoOp/rn50_ep50.yaml",
        help="path to config file",
    )
    parser.add_argument(
        "--dataset-config-file",
        type=str,
        default="configs/datasets/vqa.yaml",
        help="path to config file for dataset setup",
    )
    parser.add_argument("--trainer", type=str, default="CoOp", help="name of trainer")
    parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
    parser.add_argument("--head", type=str, default="", help="name of head")
    parser.add_argument("--eval-only", action="store_true", help="evaluation only")
    parser.add_argument(
        "--model-dir",
        type=str,
        default="",
        help="load model from this directory for eval-only mode",
    )
    parser.add_argument(
        "--load-epoch",
        type=int,
        default=50,
        help="load model weights at this epoch for evaluation",
    )
    parser.add_argument(
        "--no-train", action="store_true", help="do not call trainer.train()"
    )
    # Without this switch the trainer-based branch of evaluation() could not
    # be selected from the command line. It is off by default, so existing
    # invocations keep using the plain CLIP model.
    parser.add_argument(
        "--custom",
        action="store_true",
        help="score with the trainer's model loaded from --model-dir "
             "instead of the plain CLIP model",
    )
    # Must stay last: REMAINDER swallows every argument that follows.
    parser.add_argument(
        "opts",
        default=None,
        nargs=argparse.REMAINDER,
        help="modify config options using the command-line",
    )
    return parser


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    evaluation(args, custom=args.custom)
