import argparse

import os
import clip
import torch
from tqdm import tqdm

import pickle
import torchvision
import numpy as np
import mytorch.utils as utils
from torch.utils.data import Dataset
from clip.classes import imagenet_classes
from clip.data_loader import data_loader
from clip.templates import imagenet_templates, concept_templates
from clip.save_predictions import save_to_file
from settings import selected_class_descriptions
from torchvision.datasets.folder import default_loader


class SimpleImgData(Dataset):
    """Unlabelled image dataset over every entry of a single directory.

    Items are the (optionally transformed) loaded images only, with no label.
    Sample order is whatever ``os.listdir`` returns.
    """

    def __init__(self, img_dir, preprocessing, loader=default_loader):
        # Callable invoked with a file path to obtain the sample.
        self.loader = loader
        # Optional transform applied to each loaded sample; skipped when None.
        self.transform = preprocessing
        entries = os.listdir(img_dir)
        self.samples = list(map(lambda entry: os.path.join(img_dir, entry), entries))
        # Echo every discovered path.
        for path in self.samples:
            print(path)

    def __getitem__(self, index):
        image = self.loader(self.samples[index])
        return image if self.transform is None else self.transform(image)

    def __len__(self):
        return len(self.samples)


class WholeConceptImgData(Dataset):
    """Image dataset over a pickled collection of file paths.

    The object unpickled from ``whole_concept_data_path`` is indexed with the
    key ``'paths'`` to obtain the sequence of image paths. Each item is
    ``(image, 0)``; the ``0`` is a dummy label with no meaning.
    """

    def __init__(self, whole_concept_data_path, preprocessing, loader=default_loader):
        self.loader = loader
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(whole_concept_data_path, 'rb') as handle:
            payload = pickle.load(handle)
        self.transform = preprocessing
        self.samples = payload['paths']

    def __getitem__(self, index):
        image = self.loader(self.samples[index])
        if self.transform is not None:
            image = self.transform(image)
        dummy_label = 0  # placeholder so the item is an (input, label) pair
        return image, dummy_label

    def __len__(self):
        return len(self.samples)


class ConceptImgData(Dataset):
    """Image dataset of positive/negative example images for a list of concepts.

    The unpickled object at ``concept_data_path`` is indexed by concept
    position ``0 .. len - 1``. Each entry provides ``'positive'`` and
    ``'negative'`` path sequences. Labels encode both the concept and the polarity:
    ``target // 2`` is the concept index, and ``target % 2`` is 0 for a positive
    sample and 1 for a negative sample.
    """

    def __init__(self, concept_data_path, preprocessing, loader=default_loader):
        self.loader = loader
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(concept_data_path, 'rb') as handle:
            concepts = pickle.load(handle)

        self.transform = preprocessing

        self.samples = []
        # Indexed positionally, so a list or an int-keyed mapping both work.
        for idx in range(len(concepts)):
            entry = concepts[idx]
            pos_paths, neg_paths = entry['positive'], entry['negative']
            pos_label = 2 * idx
            neg_label = pos_label + 1
            self.samples.extend((path, pos_label) for path in pos_paths)
            self.samples.extend((path, neg_label) for path in neg_paths)
        self.num_concepts = len(concepts)

    def __getitem__(self, index):
        image_path, label = self.samples[index]
        image = self.loader(image_path)
        if self.transform is None:
            return image, label
        return self.transform(image), label

    def __len__(self):
        return len(self.samples)


class FeatureData(Dataset):
    """In-memory dataset of precomputed feature rows and their labels.

    Both files are loaded with ``torch.load`` and converted to NumPy. Features are
    cast to float32 first. ``classes`` lists the distinct label values in sorted order.
    """

    def __init__(self, data_path, target_path):
        loaded_features = torch.load(data_path)
        self.all_features = loaded_features.to(torch.float32).numpy()
        loaded_targets = torch.load(target_path)
        self.all_target = loaded_targets.numpy()
        distinct_labels = np.unique(self.all_target)
        self.classes = list(distinct_labels)

    def __len__(self):
        # Number of rows, i.e. size of the leading axis.
        return self.all_features.shape[0]

    def __getitem__(self, idx):
        feature = self.all_features[idx]
        label = self.all_target[idx]
        return feature, label


def device():
    """Return the torch device to run on: CUDA when available, otherwise the CPU."""
    backend = "cpu"
    if torch.cuda.is_available():
        backend = "cuda"
    return torch.device(backend)


def zeroshot_classifier(model, classnames, templates):
    """Build a CLIP zero-shot classifier head from prompt templates.

    For each class name, every template is formatted with the name and encoded
    with ``model.encode_text``. Each prompt embedding is L2-normalised, the
    embeddings are averaged over templates, and the mean is re-normalised.

    Args:
        model: CLIP model exposing ``encode_text``.
        classnames: iterable of class-name strings; one output column per name.
        templates: format strings, each with a single ``{}`` placeholder.

    Returns:
        Tensor of shape (D, number of class names) on ``device()``, where D is
        the text-embedding width. Each column is the unit-norm embedding of one
        class, in the model's text-embedding dtype.
    """
    with torch.no_grad():
        class_weights = []
        for classname in tqdm(classnames):
            prompts = [template.format(classname) for template in templates]
            # Use device() rather than a hard-coded .cuda() so CPU-only runs work,
            # matching the model.to(device()) placement used in main().
            tokens = clip.tokenize(prompts).to(device())
            prompt_embeddings = model.encode_text(tokens)
            prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = prompt_embeddings.mean(dim=0)
            # The mean of unit vectors is not unit length, so normalise again.
            class_embedding /= class_embedding.norm()
            class_weights.append(class_embedding)
        zeroshot_weights = torch.stack(class_weights, dim=1).to(device())
    return zeroshot_weights


def _extract_image_features(model, loader, zeroshot_weights, header):
    """Encode every image in ``loader`` and print zero-shot Acc@1/Acc@5.

    ``loader`` yields ``(images, target, paths)`` batches; the paths are unused.
    Returns ``(features, targets)`` on the CPU in loader order. The features are
    the L2-normalised image embeddings in the model's output dtype.
    """
    # Fresh logger per pass so accuracies are never averaged across passes.
    metric_logger = utils.MetricLogger(delimiter="  ")
    softmax = torch.nn.Softmax(dim=1)
    all_features, all_targets = [], []
    with torch.no_grad():
        for images, target, _paths in tqdm(loader):
            images = images.to(device())
            target = target.to(device())

            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            all_features.append(image_features.cpu())
            all_targets.append(target.cpu())

            probs = softmax(100. * image_features @ zeroshot_weights)
            batch_size = images.shape[0]
            acc1, acc5 = utils.accuracy(probs, target, topk=(1, 5))
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return torch.cat(all_features, dim=0), torch.cat(all_targets, dim=0)


def _encode_class_prompts(model, prompts_per_class):
    """Encode per-class prompt lists with CLIP's text encoder.

    Args:
        model: CLIP model exposing ``encode_text``.
        prompts_per_class: sequence whose i-th item is the list of prompt strings
            for class i.

    Returns:
        float32 tensor of shape (D, total number of prompts) on ``device()``. Each
        column is an L2-normalised prompt embedding; columns are ordered by class,
        then by prompt.
    """
    embeddings = []
    with torch.no_grad():
        for prompts in tqdm(prompts_per_class):
            tokens = clip.tokenize(prompts).to(device())
            prompt_embeddings = model.encode_text(tokens)
            prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)
            embeddings.append(prompt_embeddings)
    # Normalise in the model's dtype first, then cast to match the float32 features.
    return torch.cat(embeddings, dim=0).permute(1, 0).to(torch.float32)


def _score_saved_features(features_path, target_path, text_weights,
                          batch_size, num_workers, header):
    """Score saved image features against text embeddings and print Acc@1/Acc@5.

    Returns a CPU tensor of raw similarities (features @ ``text_weights``), one
    row per saved image in file order. No temperature or softmax is applied.
    """
    dataset = FeatureData(features_path, target_path)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )
    # Fresh logger per pass so accuracies are never averaged across passes.
    metric_logger = utils.MetricLogger(delimiter="  ")
    all_similarities = []
    with torch.no_grad():
        for image_features, targets in tqdm(loader):
            image_features = image_features.to(device())
            targets = targets.to(device())

            similarities = image_features @ text_weights
            # Move each batch off the accelerator right away so the full matrix
            # never has to fit in device memory.
            all_similarities.append(similarities.cpu())

            n = image_features.shape[0]
            acc1, acc5 = utils.accuracy(similarities, targets, topk=(1, 5))
            metric_logger.meters["acc1"].update(acc1.item(), n=n)
            metric_logger.meters["acc5"].update(acc5.item(), n=n)
    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return torch.cat(all_similarities, dim=0)


def main(args):
    """Extract CLIP image features and score them against three text heads.

    Per-model artifacts are written under ``saved_contents/<model-name>/``:
      1. Encode every image from ``data_loader(preprocess, args)`` (default
         ``--data-dir`` is the ILSVRC2012 train split) and print zero-shot
         accuracy against the template-averaged ``imagenet_classes`` head. Then
         save the normalised image features and the labels.
      2. Build one prompt per class from
         ``saved_contents/all_wrong_max_pred_classes.pkl``. The prompt is
         "a photo of X, not Y." using the first listed class Y, or
         "a photo of X." when the list is empty. Score the saved features against
         these prompts, print the accuracy, and save the similarity matrix.
      3. Same as step 2, but the prompt is "a photo of X." where X is the
         class's entry in ``selected_class_descriptions`` if present, otherwise
         the class name.
    """
    log_suffix = ""
    header = f"Test: {log_suffix}"

    model, preprocess = clip.load(args.clip_model)
    model_name = args.clip_model.replace('/', '-')
    save_dir = os.path.join('saved_contents', model_name)
    os.makedirs(save_dir, exist_ok=True)
    model.to(device())

    loader = data_loader(preprocess, args)
    model.eval()

    features_path = os.path.join(save_dir, 'clip_model_all_train_img_features.pth')
    target_path = os.path.join(save_dir, 'clip_model_all_train_img_target.pth')

    # 1) Save CLIP image features on ImageNet.
    print('Extract CLIP image features on ImageNet')
    zeroshot_weights = zeroshot_classifier(model, imagenet_classes, imagenet_templates)
    all_features, all_target = _extract_image_features(model, loader, zeroshot_weights, header)
    torch.save(all_features, features_path)
    torch.save(all_target, target_path)
    # The later passes reload these from disk (as float32), so free the copies.
    del all_features, all_target

    # 2) Similarities between class-relative ("X, not Y") prompts and image features.
    max_pred_cls_path = 'saved_contents/all_wrong_max_pred_classes.pkl'
    with open(max_pred_cls_path, 'rb') as file:
        all_max_pred_classes = pickle.load(file)
    posi_templates = ['a photo of {}.']
    nega_templates = ['a photo of {}, not {}.']
    relative_prompts = []
    for cls_id, cls_name in enumerate(imagenet_classes):
        # Only the first listed class is used as the negative.
        nega_cls_ids = all_max_pred_classes[cls_id][:1]
        # len() rather than truthiness: the entry could be a NumPy array.
        if len(nega_cls_ids) == 0:
            relative_prompts.append([template.format(cls_name) for template in posi_templates])
        else:
            nega_cls_names = [imagenet_classes[nega_cls_id] for nega_cls_id in nega_cls_ids]
            relative_prompts.append([
                template.format(cls_name, nega_cls_name)
                for template in nega_templates
                for nega_cls_name in nega_cls_names
            ])
    relative_weights = _encode_class_prompts(model, relative_prompts)
    relative_similarities = _score_saved_features(
        features_path, target_path, relative_weights,
        args.batch_size, args.num_workers, f"{header}(cls-relative-text)")
    torch.save(relative_similarities, os.path.join(
        save_dir, 'clip_model_all_train_img_similarites(cls-relative-text).pth'))
    del relative_similarities

    # 3) Similarities between selected class-description prompts and image features.
    selected_prompts = []
    for cls_id, cls_name in enumerate(imagenet_classes):
        cls_descriptions = (selected_class_descriptions[cls_id]
                            if cls_id in selected_class_descriptions else cls_name)
        selected_prompts.append([template.format(cls_descriptions) for template in posi_templates])
    selected_weights = _encode_class_prompts(model, selected_prompts)
    selected_similarities = _score_saved_features(
        features_path, target_path, selected_weights,
        args.batch_size, args.num_workers, f"{header}(cls-selected-text)")
    torch.save(selected_similarities, os.path.join(
        save_dir, 'clip_model_all_train_img_similarites(cls-selected-text).pth'))


if __name__ == "__main__":
    # Help strings use %(default)s so the displayed defaults always match the
    # real ones (the previous hard-coded "default: ..." texts were wrong).
    parser = argparse.ArgumentParser(description='CLIP inference')
    parser.add_argument('--data-dir', default='datasets/ILSVRC2012/train', type=str,
                        help='dataset path (default: %(default)s)')
    parser.add_argument('--num-workers', default=25, type=int,
                        help='number of data-loading workers (default: %(default)s)')
    parser.add_argument('--batch-size', default=512, type=int,
                        help='batch size (default: %(default)s)')
    parser.add_argument('--clip-model', default='ViT-L/14', type=str,
                        help='CLIP model name passed to clip.load (default: %(default)s)')

    config = parser.parse_args()
    main(config)