'''Extracts model representations (CV layer features, word embeddings, or fair tabular transforms) and saves them to disk.'''
import os
import sys
import argparse
import random
import timm
import time
import torch
import logging
import torchvision
import numpy as np
from tqdm import tqdm
from gensim.models import KeyedVectors

from models.nlp_embedding_models import Encoder
from utils.data import SpecialCIFAR10, SpecialCIFAR100, ImageNetV2Dataset
from utils.model import load_model, get_features, get_word_embeddings, load_simclr_model, get_simclr_features, get_features_normal_models
from utils.helper import create_logger

from robustness import defaults
from robustness.datasets import DATASETS
from robustness.data_augmentation import TEST_TRANSFORMS_IMAGENET
from robustness.main import setup_args


def get_args(argv=None):
    """Parse the command-line options for feature extraction.

    Script-specific options are registered first; the robustness package's
    CONFIG, MODEL_LOADER, TRAINING and PGD argument sets are then appended via
    ``defaults.add_args_to_parser``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like a bare ``parse_args()`` call.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--out-dir-complement', default=None, type=str,
        help='Folder name to be added to the out_dir. If None, will be automatically generated.')

    parser.add_argument('--dataset_for_features', default=None, type=str,
        help='If None, get the representations from the dataset. Else get it from this dataset.')
    parser.add_argument('--data_for_features', default='./data', type=str,
        help='Path to the dataset for features.')

    parser.add_argument('--layer_num', default=17, type=int,
        help='Which layer to get the representation from model')
    parser.add_argument('--out_filename', default='', type=str,
        help='Output filename to save -- for tabular data')
    parser.add_argument('--standard_model', default='robustness',
        choices=['robustness', 'simclr', 'timm'], type=str,
        help='If robustness, model was trained using the robustness package. If simclr, model was trained with simclr loss. '
             'If timm, model is a timm architecture restored from a checkpoint state_dict.')

    parser.add_argument('--task', default='cv', type=str, choices=['cv', 'nlp', 'tabular'],
        help='Which task is it: Computer Vision (cv), Natural Language Processing (nlp) or tabular data (tabular)')
    parser.add_argument('--data_split', default='', type=str, choices=['', 'train'],
        help="The split from which to get the images ('' = test/validation split, 'train' = training split)")

    # NLP-Related
    parser.add_argument('--words_path', default=None, type=str,
        help='Path to words')
    parser.add_argument('--embeddings_type', default=None, type=str,
        choices=['debiased', 'gender', 'random'], help='Types of embeddings to get from the embeddings (for GN-Glove)')
    parser.add_argument('--embeddings_op', default=None, type=str,
        choices=['association', 'gender_association'],
        help='Embedding operation (association: emb(w1) - emb(w2); gender_association: [emb(w1) - emb(w2)] - [emb(he) - emb(she)]).')
    parser.add_argument('--norm_embeddings_op', default=False, action='store_true',
        help='Whether to normalize the individual embeddings')

    # CV-Related
    parser.add_argument('--seed', default=0, type=int, help='Seed')
    parser.add_argument('--load_dataset', default=None, type=str)
    parser.add_argument('--sim_method', default='cka', type=str,
        choices=['cka', 'cos'], help='Methods for calculating the similarity')
    parser.add_argument('--perturb_method', default=None, type=str,
        choices=['blur', 'adv', 'noise', 'elastic'],
        help='Method for perturbing instances of dataset (if None, original images are used).')
    parser.add_argument('--perturbed_imgs_path', default=None, type=str,
        help='Path for loading adversarially perturbed images')
    parser.add_argument('--num_classes', default=10, type=int, help='Number of classes the model was trained on')
    parser.add_argument('--num_samples', default=None, type=int, help='Number of instances to get from the data split. If None gets all of them')

    # Options shared with the robustness package (dataset, arch, resume, out_dir, batch_size, ...).
    parser = defaults.add_args_to_parser(defaults.CONFIG_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.MODEL_LOADER_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.TRAINING_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.PGD_ARGS, parser)

    return parser.parse_args(argv)

def get_path(args):
    """Build the output directory for this run, create it if missing, and return it.

    When ``args.out_dir_complement`` is set, the directory is
    ``args.out_dir/args.out_dir_complement``. Otherwise it is assembled as
    ``args.out_dir/<model_name>/<load_dataset>/<data_split>/<perturb_path>``.
    Empty components are skipped by ``os.path.join``.

    - ``model_name`` is the parent folder of ``args.resume``. It is empty for the
      ``nlp`` task or when no checkpoint is given.
    - If the checkpoint path contains ``'everyepoch'``, the part of its file name
      before the first ``_`` is appended as ``<prefix>_checkpoint``.
    - ``perturb_path`` is ``args.perturb_method``. For ``'adv'`` it is
      ``adv/<last component>/<second-to-last component>`` of
      ``args.perturbed_imgs_path``.

    Args:
        args: Parsed namespace. Reads out_dir, out_dir_complement, resume, task,
            perturb_method, perturbed_imgs_path, load_dataset and data_split.

    Returns:
        The output directory path; it exists on return.

    Raises:
        ValueError: if ``perturb_method`` is ``'adv'`` but ``perturbed_imgs_path``
            is not set.
    """
    if args.out_dir_complement is not None:
        path = os.path.join(args.out_dir, args.out_dir_complement)
    else:
        model_name = ''
        if args.resume is not None:
            # NLP "checkpoints" are embedding files, so they get no model sub-folder.
            if args.task != 'nlp':
                model_name = args.resume.split('/')[-2]
            if 'everyepoch' in args.resume:
                # Presumably an epoch prefix in the file name (e.g. '5_...'); confirm naming scheme.
                model_name += f"/{args.resume.split('/')[-1].split('_')[0]}_checkpoint"

        perturb_path = ''
        if args.perturb_method is not None:
            perturb_path = args.perturb_method
            if args.perturb_method == 'adv':
                if args.perturbed_imgs_path is None:
                    raise ValueError("--perturbed_imgs_path is required when --perturb_method is 'adv'")
                components = args.perturbed_imgs_path.split('/')
                perturbed_model, perturbed_params = components[-1], components[-2]
                perturb_path = os.path.join(args.perturb_method, perturbed_model, perturbed_params)

        special_dataset = args.load_dataset if args.load_dataset is not None else ''
        path = os.path.join(args.out_dir, model_name, special_dataset, args.data_split, perturb_path)

    print('path to save', path)
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(path, exist_ok=True)
    return path

def _run_tabular_task(args, path):
    """Transform tabular data with a saved model and save the result.

    The original code marks this branch as iFair. Loads a model object from
    ``args.resume`` and the input from ``args.data``; when the loaded input is a
    dict, its ``'X'`` entry is used. Calls ``model.transform`` on the input and
    saves the output to ``path/args.out_filename``.

    Raises:
        ValueError: if ``args.out_filename`` is empty. Otherwise the save target
            would be the directory ``path`` itself.
    """
    if not args.out_filename:
        raise ValueError('--out_filename is required for --task tabular')
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    model = torch.load(args.resume)
    x = torch.load(args.data)
    if isinstance(x, dict):
        x = x['X']
    fair_x = model.transform(x)
    print('x', x.shape)
    print('fair_x', fair_x.shape)
    torch.save(fair_x, os.path.join(path, args.out_filename))


def _run_nlp_task(args, path, logger):
    """Compute word-embedding features from a word2vec-format text model and save them.

    The output file is ``path/<embeddings_op>/<feature_type>_features.pt``:
    - ``embeddings_op`` is empty when no operation is requested.
    - ``feature_type`` is ``args.embeddings_type``, or ``'all'`` when unset.
    - ``feature_type`` is prefixed with ``norm_`` when an operation is requested
      together with ``--norm_embeddings_op``.
    """
    model = KeyedVectors.load_word2vec_format(args.resume, binary=False)
    logger.info(f'Loaded model from {args.resume}')
    features = get_word_embeddings(model, args.words_path, args.embeddings_type, args.embeddings_op, args.norm_embeddings_op)
    logger.info(f'features shape {features.shape}')

    feature_type = args.embeddings_type if args.embeddings_type is not None else 'all'
    embeddings_op = args.embeddings_op if args.embeddings_op is not None else ''
    if args.embeddings_op is not None and args.norm_embeddings_op:
        feature_type = f'norm_{feature_type}'

    out_dir = os.path.join(path, embeddings_op)
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir, f'{feature_type}_features.pt')
    torch.save(features.cpu().detach(), out_file)
    print(f"Saved features in {out_file}")


def _make_cv_loader(args):
    """Build the image loader and the robustness dataset object passed to the model loader.

    Returns:
        A ``(dataset, loader)`` pair:
        - ``dataset`` is ``DATASETS[args.dataset](args.data)``.
        - ``loader`` iterates ImageNet-v2 when ``'imagenet-v2'`` occurs in
          ``args.data``.
        - Otherwise ``loader`` iterates the robustness dataset's train
          (``--data_split train``) or test split.
        - In both cases the loader is unshuffled and built without augmentation.

    Raises:
        ValueError: if ImageNet-v2 is combined with ``--data_split train``. Only a
            single ImageNet-v2 loader is built, so there is no train split to
            select.
    """
    if 'imagenet-v2' in args.data:
        if args.data_split == 'train':
            raise ValueError('--data_split train is not supported for imagenet-v2 data')
        print('Loading imagenet-v2 dataset...')
        v2_dataset = ImageNetV2Dataset("matched-frequency", location=args.data, transform=TEST_TRANSFORMS_IMAGENET)
        loader = torch.utils.data.DataLoader(
            v2_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)
        # The model loader still receives a robustness dataset object, not the ImageNet-v2 one.
        dataset = DATASETS[args.dataset](args.data)
        return dataset, loader

    dataset = DATASETS[args.dataset](args.data)
    trainloader, testloader = dataset.make_loaders(0, args.batch_size,
        shuffle_train=False, shuffle_val=False, data_aug=False, val_batch_size=args.batch_size)
    loader = trainloader if args.data_split == 'train' else testloader
    return dataset, loader


def _apply_image_overrides(args, dataset, loader):
    """Swap in pre-computed perturbed images or the CIFAR-10.1 arrays when requested.

    Returns:
        The possibly replaced ``(dataset, loader)`` pair.
    """
    # NOTE(review): `dataset` is replaced by a SpecialCIFAR* object here and later
    # passed to load_model; confirm load_model accepts it (behavior kept as-is).
    if args.perturb_method is not None:
        print(f'=> Using perturbed images from {args.perturbed_imgs_path}')
        perturbed_images = torch.load(os.path.join(args.perturbed_imgs_path, 'images.pt'))
        labels = torch.load(os.path.join(args.perturbed_imgs_path, 'labels.pt'))
        special_cls = SpecialCIFAR10 if args.dataset == 'cifar' else SpecialCIFAR100
        dataset = special_cls(perturbed_images, labels, data_path='data/')
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    if args.load_dataset == 'cifar10_1':
        print(f'=> Using images from {args.perturbed_imgs_path}')
        images = np.load(os.path.join(args.perturbed_imgs_path, 'cifar10.1_v6_data.npy'))
        labels = np.load(os.path.join(args.perturbed_imgs_path, 'cifar10.1_v6_labels.npy'))
        dataset = SpecialCIFAR10(images, labels, data_path='data/')
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    return dataset, loader


def _load_cv_model(args, dataset, device, logger):
    """Load the vision model selected by ``args.standard_model``.

    - robustness: uses ``load_model``.
    - simclr: uses ``load_simclr_model``.
    - timm: builds ``args.arch`` with timm and loads ``args.resume['state_dict']``.
      The ``'model.'`` substring is stripped from keys, and the robustness
      normalizer buffers are dropped.
    """
    if args.standard_model == 'robustness':
        model, result_model = load_model(
            args, args.arch, args.resume, dataset, testloader=None)
        logger.info(f'Eval model 1: {result_model}')
    elif args.standard_model == 'simclr':
        model = load_simclr_model(args, device)
    else:  # 'timm' (argparse restricts the choices)
        model = timm.create_model(args.arch, num_classes=args.num_classes, pretrained=False)
        state_dict = torch.load(args.resume)['state_dict']
        new_state_dict = {key.replace('model.', ''): value for key, value in state_dict.items()}
        new_state_dict.pop("normalizer.new_mean")
        new_state_dict.pop("normalizer.new_std")
        model.load_state_dict(new_state_dict)
        model.to(device)
    return model


def _run_cv_task(args, path, logger, device):
    """Extract layer features from a vision model and save them as ``path/l<layer_num>.pt``."""
    dataset, loader = _make_cv_loader(args)
    dataset, loader = _apply_image_overrides(args, dataset, loader)
    model = _load_cv_model(args, dataset, device, logger)

    if args.dataset_for_features is not None:
        # Features may be taken from a different dataset than the one used to load the model.
        feature_dataset = DATASETS[args.dataset_for_features](args.data_for_features)
        trainloader, testloader = feature_dataset.make_loaders(0, args.batch_size,
            shuffle_train=False, shuffle_val=False, data_aug=False, val_batch_size=args.batch_size)
        loader = trainloader if args.data_split == 'train' else testloader

    if args.standard_model == 'robustness':
        features_x, _, _, _ = get_features(args, loader, model, device)
    elif args.standard_model == 'simclr':
        features_x, _, _ = get_simclr_features(args, model, loader, device)
    else:
        features_x, _, _ = get_features_normal_models(args, model, loader, device)

    logger.info(f'features_x shape {features_x.shape}')
    torch.save(features_x.cpu().detach(), os.path.join(path, f'l{args.layer_num}.pt'))


def main():
    """Script entry point.

    Parses arguments, seeds torch, ``random`` and numpy, creates the output
    directory and logger, then dispatches on ``args.task`` (tabular, nlp, or cv).
    """
    args = get_args()
    args = setup_args(args)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    path = get_path(args)
    logger = create_logger(path)

    if args.task == 'tabular':
        _run_tabular_task(args, path)
    elif args.task == 'nlp':
        _run_nlp_task(args, path, logger)
    else:
        _run_cv_task(args, path, logger, device)

if __name__ == "__main__":
    # Run feature extraction only when executed as a script, not on import.
    main()
