import argparse
import json
import os
import sys

parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
import open_clip
import torch
import torch.nn as nn
import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

from evaluate.elevater_metrics import get_metric
from evaluate.cvinw_zeroshot_templates import (
    openai_templates,
    flower_templates,
    food_templates,
    aircraft_templates,
    eurosat_templates,
    country211_templates,
)
from evaluate.dataset import ImgFolderDataset

# Dataset name -> prompt templates used to build that dataset's text classifier.
# Datasets without an entry fall back to ``openai_templates`` at lookup time;
# "resisc45_clip" shares the templates of "eurosat_clip".
PROMPTs = dict(
    [
        ("fgvc-aircraft-2013b-variants102", aircraft_templates),
        ("food-101", food_templates),
        ("oxford-flower-102", flower_templates),
        ("eurosat_clip", eurosat_templates),
        ("resisc45_clip", eurosat_templates),
        ("country211", country211_templates),
        ("openai", openai_templates),
    ]
)

# Dataset name -> metric name handed to ``get_metric``.
# Datasets without an entry are evaluated with "accuracy" at lookup time.
METRICS = {
    **dict.fromkeys(
        (
            "caltech-101",
            "fgvc-aircraft-2013b-variants102",
            "oxford-flower-102",
            "oxford-iiit-pets",
        ),
        "mean-per-class",
    ),
    "hateful-memes": "roc_auc",
    "voc-2007-classification": "11point_mAP",
}


class ModelWrapper(nn.Module):
    """Route ``forward`` to a named method of the wrapped model.

    ``nn.DataParallel`` only dispatches through ``forward``; wrapping a model
    with this class lets another method (selected by name) be used for
    multi-GPU testing instead.

    Args:
        model: The module whose method should be invoked.
        forward_func: Name of the method on ``model`` that ``forward`` calls.
    """

    def __init__(self, model: nn.Module, forward_func: str = 'forward'):
        super().__init__()
        self.model = model
        self.forward_func = forward_func

    def forward(self, **kwargs):
        """Call the selected method with ``kwargs`` and return its result."""
        # Resolved on every call, so the attribute is looked up at call time.
        target = getattr(self.model, self.forward_func)
        return target(**kwargs)


class DefaultDataLoader(DataLoader):
    """``DataLoader`` with this script's default settings.

    Differences from constructing ``DataLoader`` directly:

    * ``shuffle`` is forced to ``False`` whenever a ``sampler`` is supplied,
      because ``DataLoader`` rejects the combination of both options.
    * ``prefetch_factor`` is only forwarded when ``num_workers > 0``.
      ``DataLoader`` raises ``ValueError`` if a prefetch factor is specified
      without worker processes (in PyTorch releases whose own default is
      ``None`` this includes the value ``2``), so forwarding it
      unconditionally made ``num_workers=0`` unusable.

    Args:
        dataset: Dataset to load from.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the data; ignored when ``sampler`` is set.
        sampler: Optional sampler defining the iteration order.
        num_workers: Number of worker processes (``0`` loads in the main process).
        pin_memory: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        prefetch_factor: Batches prefetched per worker; only used when
            ``num_workers > 0``.
        persistent_workers: Forwarded to ``DataLoader``.
    """

    def __init__(self,
                 dataset,
                 batch_size=64,
                 shuffle=False,
                 sampler=None,
                 num_workers=4,
                 pin_memory=False,
                 drop_last=False,
                 prefetch_factor=2,
                 persistent_workers=False):
        if sampler is not None:
            shuffle = False

        loader_kwargs = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            persistent_workers=persistent_workers,
        )
        # Prefetching only exists for worker processes; passing the option
        # with num_workers == 0 is rejected by DataLoader.
        if num_workers > 0:
            loader_kwargs['prefetch_factor'] = prefetch_factor

        super().__init__(**loader_kwargs)


def zero_shot_classifier(text_model, tokenizer, classnames, templates):
    """Build zero-shot classifier weights from prompt-ensembled text embeddings.

    For each class name every template is rendered, tokenized and encoded by
    ``text_model``; the L2-normalised embeddings are averaged over templates
    and the mean is L2-normalised again.

    Args:
        text_model: Callable invoked as ``text_model(text=<batched tokens>)``.
        tokenizer: Callable turning one prompt string into a tensor; its
            leading dimension is squeezed away before batching
            (presumably shape ``(1, context_length)`` -- TODO confirm).
        classnames: Iterable of class-name strings.
        templates: Iterable of callables mapping a class name to a prompt.

    Returns:
        CUDA tensor with one column per class (per-class embeddings stacked
        along ``dim=1``).
    """
    per_class = []
    with torch.no_grad():
        for name in tqdm.tqdm(classnames):
            # One {'text': tokens} sample per template, collated into a batch.
            samples = [
                {'text': tokenizer(make_prompt(name)).squeeze(0)}
                for make_prompt in templates
            ]
            batch = default_collate(samples)

            embeddings = text_model(**batch)
            # Normalise each prompt embedding in place before averaging.
            embeddings.div_(embeddings.norm(dim=-1, keepdim=True))

            mean_embedding = embeddings.mean(dim=0)
            per_class.append(mean_embedding / mean_embedding.norm())

        weights = torch.stack(per_class, dim=1).cuda()

    print(f'classifier shape: {weights.shape}')
    return weights


def eval_all_zero_shot_classification(model, preprocess, tokenizer, elevater_path, batch_size=256,
                                      datasets=None):
    """Run zero-shot classification on ELEVATER datasets and collect the scores.

    For every dataset the function reads the class names from
    ``<elevater_path>/<dataset>/label_cn.txt``, builds a text classifier with
    the dataset's prompt templates, scores the images found under
    ``<elevater_path>/<dataset>/test`` and evaluates the dataset's metric.
    If ``<elevater_path>/<dataset>/index.json`` exists, outputs and targets
    are re-indexed with its content before the metric is computed.

    Args:
        model: Model exposing ``encode_text`` and ``encode_image``; it is put
            into eval mode and moved to CUDA.
        preprocess: Image preprocessing handed to ``ImgFolderDataset``.
        tokenizer: Tokenizer handed to ``zero_shot_classifier``.
        elevater_path: Directory containing one sub-directory per dataset.
        batch_size: Image batch size used for inference.
        datasets: Optional iterable of dataset names (or a single name) to
            evaluate. ``None`` evaluates the built-in list of 20 datasets.

    Returns:
        dict mapping each dataset name to its metric result, plus the key
        ``'mean_zs'`` holding the arithmetic mean over the evaluated datasets.

    Raises:
        ValueError: If ``datasets`` is given but contains no dataset name.
    """
    model.eval()
    model.cuda()

    device_ids = list(range(torch.cuda.device_count()))

    text_model = ModelWrapper(model, forward_func='encode_text')
    text_model = nn.DataParallel(text_model.cuda(), device_ids=device_ids)

    img_model = ModelWrapper(model, forward_func='encode_image')
    img_model = nn.DataParallel(img_model.cuda(), device_ids=device_ids)

    if datasets is None:
        all_datasets = ['cifar-10', 'cifar-100', 'dtd', 'eurosat_clip', 'fer-2013',
                        'fgvc-aircraft-2013b-variants102', 'kitti-distance',
                        'mnist', 'patch-camelyon', 'voc-2007-classification', 'caltech-101',
                        'country211', 'food-101', 'gtsrb', 'hateful-memes',
                        'oxford-flower-102', 'oxford-iiit-pets', 'rendered-sst2',
                        'resisc45_clip', 'stanford-cars']
    elif isinstance(datasets, str):
        # A bare string would otherwise be iterated character by character.
        all_datasets = [datasets]
    else:
        all_datasets = list(datasets)
    if not all_datasets:
        # Without this guard the mean below would divide by zero.
        raise ValueError('datasets must contain at least one dataset name')

    all_results = {}
    for data_name in all_datasets:
        print("Dataset: ", data_name)
        dataset_root = os.path.join(elevater_path, data_name)

        # Compute ensembled class embeddings
        print("Build classifier...")
        label_file = os.path.join(dataset_root, "label_cn.txt")
        with open(label_file, "r", encoding="utf8") as f:
            classnames = [line.strip() for line in f]
        templates = PROMPTs.get(data_name, openai_templates)

        classifier = zero_shot_classifier(text_model, tokenizer, classnames, templates)

        img_dataset = ImgFolderDataset(os.path.join(dataset_root, 'test'), preprocess=preprocess)
        img_loader = DefaultDataLoader(img_dataset, batch_size=batch_size)

        total_logits = []
        total_targets = []
        with torch.no_grad():
            for data in tqdm.tqdm(img_loader):
                # Targets are removed so only model inputs reach encode_image.
                target = data.pop("target")
                total_targets.append(target)

                # predict
                image_features = img_model(**data)
                image_features /= image_features.norm(dim=-1, keepdim=True)
                logits = (100.0 * image_features @ classifier).softmax(dim=-1)
                # Offload each batch so a whole dataset's scores do not pile
                # up in GPU memory; the metric is computed on CPU anyway.
                total_logits.append(logits.cpu())

        outputs = torch.cat(total_logits, dim=0)
        targets = torch.cat(total_targets, dim=0)

        index_path = os.path.join(dataset_root, "index.json")
        if os.path.isfile(index_path):
            print("Use index to rearrange the logits...")
            with open(index_path, "r", encoding="utf-8") as f:
                index = json.load(f)
            outputs = outputs[index]
            targets = targets[index]

        # Datasets without a dedicated entry are scored with plain accuracy.
        metric_name = METRICS.get(data_name, 'accuracy')
        print(f'Metric: {metric_name}')
        metric = get_metric(metric_name)
        result = metric(targets.squeeze().cpu().detach().numpy(), outputs.cpu().detach().numpy())
        print(result)
        all_results[data_name] = result

    for key in all_results:
        print(f'{key}: {all_results[key]}')
    print('Finished.')

    # Computed before 'mean_zs' is inserted so the mean covers datasets only.
    mean = sum(all_results.values()) / len(all_results)
    all_results['mean_zs'] = mean

    return all_results


def parse_args():
    """Parse the command-line options of the zero-shot classification script.

    Returns:
        argparse.Namespace with the attributes ``model_name``,
        ``model_checkpoint`` and ``elevater_path``.
    """
    # The description previously advertised image-text retrieval, although
    # this script evaluates zero-shot classification.
    parser = argparse.ArgumentParser(description='Zero-shot Image Classification. ')
    parser.add_argument('--model_name', default='YouCLIP-Base',
                        choices=['YouCLIP-Base', 'YouCLIP-Base-CN-ENG', 'YouCLIP-Base-512', 'YouCLIP-Base-512-CN-ENG',
                                 'YouCLIP-Large', 'YouCLIP-Large-CN-ENG', 'YouCLIP-Huge', 'YouCLIP-Huge-CN-ENG'],
                        help='Model size. ')
    parser.add_argument('--model_checkpoint', default=None, type=str, help='checkpoint path. ')
    parser.add_argument('--elevater_path',
                        default='/xxx', type=str,
                        help='ELEVATER path. ')
    args = parser.parse_args()
    return args


# Script entry point: load the selected model, evaluate it, print the scores.
if __name__ == '__main__':
    cli_args = parse_args()
    print('build model...')
    clip_model, image_preprocess, text_tokenizer = open_clip.load_YouCLIP(
        model_name=cli_args.model_name,
        model_file_path=cli_args.model_checkpoint,
    )
    scores = eval_all_zero_shot_classification(
        model=clip_model,
        tokenizer=text_tokenizer,
        preprocess=image_preprocess,
        elevater_path=cli_args.elevater_path,
    )
    print(scores)
