import os
import copy
import pandas as pd
import numpy as np
import torch
import torchvision
from tqdm import tqdm
import dill

from robustness.train import eval_model
from robustness.model_utils import make_and_restore_model
from robustness.tools.vis_tools import show_image_row
from robustness.tools.label_maps import CLASS_DICT


def get_perturbed_images(args, model, testloader, device):
    """Perturb every batch of ``testloader`` and record the model's predictions on it.

    The perturbation is selected by ``args.perturb_method``:

    * ``'adv'``     -- adversarial inputs from ``model(inp, label, make_adv=True, ...)``
      using ``args.constraint``, ``args.eps``, ``args.attack_lr`` (step size) and
      ``args.attack_steps`` (iterations). These attributes are only read for ``'adv'``.
    * ``'blur'``    -- ``torchvision.transforms.GaussianBlur`` with a 9x9 kernel and
      sigma sampled from [0.5, 2.5].
    * ``'elastic'`` -- ``torchvision.transforms.ElasticTransform(alpha=80.0)``.
    * anything else, including ``None`` -- each batch is *replaced* by standard-normal
      noise (``torch.randn``) of the same shape; the noise is not clamped.

    Args:
        args: namespace holding ``perturb_method`` and, for ``'adv'``, the attack settings.
        model: callable returning ``(predictions, _)`` for ``model(inp)``; for ``'adv'``
            it must also support ``make_adv=True`` (robustness ``AttackerModel`` style).
        testloader: iterable of ``(inputs, labels)`` batches that supports ``len()``.
        device: device the inputs/labels are moved to before perturbation/prediction.

    Returns:
        ``(perturbed_images, labels, predictions)``: each is the per-sample tensors
        stacked along dim 0 and moved to CPU. ``predictions`` is the argmax over dim 1
        of the model output on the perturbed images.
    """
    method = args.perturb_method

    # Build the image transform once, and only for the methods that use it.
    transform = None
    if method == 'blur':
        transform = torchvision.transforms.GaussianBlur((9, 9), sigma=(0.5, 2.5))
    elif method == 'elastic':
        transform = torchvision.transforms.ElasticTransform(alpha=80.0)

    # Attack settings are only parsed when an adversarial attack is requested, so the
    # other methods do not require eps/attack_lr/attack_steps to be set on ``args``.
    attack_kwargs = None
    if method == 'adv':
        attack_kwargs = {
            'constraint': args.constraint,
            'eps': float(args.eps),
            'step_size': float(args.attack_lr),
            # int(float(...)) accepts values such as "1e2" as well as plain integers.
            'iterations': int(float(args.attack_steps)),
            'do_tqdm': True,
        }

    all_perturbed_imgs, all_labels, all_preds = [], [], []
    for inp, label in tqdm(testloader, total=len(testloader)):
        inp = inp.to(device)
        label = label.to(device)

        if method == 'adv':
            _, inp = model(inp, label, make_adv=True, **attack_kwargs)
        elif transform is not None:
            inp = transform(inp)
        else:  # random noise (created on CPU, moved to ``device`` below)
            inp = torch.randn(inp.shape)

        all_perturbed_imgs.extend(inp.cpu().detach())
        inp = inp.to(device)

        # Prediction only: no autograd graph is needed here.
        with torch.no_grad():
            pred, _ = model(inp)
        label_pred = torch.argmax(pred, dim=1)

        all_labels.extend(label.cpu().detach())
        all_preds.extend(label_pred.cpu().detach())

    all_perturbed_imgs = torch.stack(all_perturbed_imgs)
    all_labels = torch.stack(all_labels)
    all_preds = torch.stack(all_preds)
    return all_perturbed_imgs, all_labels, all_preds


def get_features(args, testloader, model, device, model2=None):
    """Collect features from ``model`` (and optionally ``model2``) for every test batch.

    Each model is called as ``model(inp, label=label, with_image=False, layer_num=...)``
    and element ``[1]`` of its output is taken as the batch of features. ``model`` uses
    ``args.layer_num`` and ``model2`` uses ``args.layer_num2``. Everything runs under
    ``torch.no_grad``.

    Returns:
        ``(features_x, features_y, labels, inputs)``, each stacked along dim 0.
        ``features_y`` is ``None`` when ``model2`` is not given. ``inputs`` are moved to
        CPU; features and labels stay on the device that produced them.
    """
    primary_feats, secondary_feats = [], []
    seen_inputs, seen_labels = [], []

    with torch.no_grad():
        for batch_inp, batch_label in tqdm(testloader, total=len(testloader)):
            batch_inp = batch_inp.to(device).requires_grad_(False)
            batch_label = batch_label.to(device).requires_grad_(False)

            primary_out = model(batch_inp, label=batch_label, with_image=False,
                                layer_num=args.layer_num)
            primary_feats.extend(primary_out[1])

            if model2 is not None:
                secondary_out = model2(batch_inp, label=batch_label, with_image=False,
                                       layer_num=args.layer_num2)
                secondary_feats.extend(secondary_out[1])

            seen_inputs.extend(batch_inp.cpu().detach())
            seen_labels.extend(batch_label)

    features_x = torch.stack(primary_feats)
    inputs = torch.stack(seen_inputs)
    labels = torch.stack(seen_labels)
    features_y = torch.stack(secondary_feats) if model2 is not None else None
    return features_x, features_y, labels, inputs

def load_model(args, arch, resume_path, dataset, testloader=None):
    """Restore a robustness-package model, freeze it, and optionally evaluate it.

    The model is built with ``make_and_restore_model`` (non-parallel), has all gradients
    disabled via ``disable_gradients`` and is put in eval mode.

    Returns:
        ``(model, result)`` where ``result`` is the value returned by
        ``eval_model(args, model, testloader, store=None)`` when ``testloader`` is
        given, and ``None`` otherwise.
    """
    restored, _ = make_and_restore_model(
        arch=arch, dataset=dataset, resume_path=resume_path, parallel=False)
    restored = disable_gradients(restored)
    restored.eval()

    evaluation = None
    if testloader is not None:
        evaluation = eval_model(args, restored, testloader, store=None)
    return restored, evaluation


def get_features_normal_models(args, model, testloader, device):
    '''Function to get features for models that were not trained using the robustness package.

    The features are the tensor fed into ``model.fc``. They are captured with a
    temporary forward hook that is removed before returning (even on error), so
    repeated calls do not leave hooks accumulating on the caller's model.

    Args:
        args: not used by this function; kept for signature parity with the other
            feature extractors.
        model: module exposing an ``fc`` submodule that is called during forward.
        testloader: iterable of ``(inputs, labels)`` batches that supports ``len()``.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(features, labels, inputs)`` stacked along dim 0. ``features`` and ``inputs``
        are on CPU; ``labels`` are kept exactly as yielded by ``testloader``.

    Raises:
        KeyError: if a forward pass does not call ``model.fc`` (no features captured).
    '''
    captured = {}

    def _capture_fc_input(module, module_inputs, output):
        # Forward hooks receive the positional inputs as a tuple; [0] is fc's input.
        captured['pre_head'] = module_inputs[0].detach()

    handle = model.fc.register_forward_hook(_capture_fc_input)

    features, inputs, labels = [], [], []
    try:
        for inp, label in tqdm(testloader, total=len(testloader)):
            with torch.no_grad():
                inp = inp.to(device)

                # Clear first so a batch where fc did not run raises KeyError
                # instead of silently reusing the previous batch's features.
                captured.clear()
                model(inp)
                features.extend(captured['pre_head'].cpu().detach())

                labels.extend(label)
                inputs.extend(inp.cpu().detach())
    finally:
        handle.remove()

    features = torch.stack(features)
    inputs = torch.stack(inputs)
    labels = torch.stack(labels)
    return features, labels, inputs

def disable_gradients(model):
    """Freeze ``model`` in place: no parameter requires grad and every module is in eval mode.

    Returns the same ``model`` object for call chaining.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)
    for submodule in model.modules():
        submodule.eval()
    return model

# ###### NLP

def get_per_embeddings_type(embedding_repr, embeddings_type=None):
    """Select the part of an embedding vector that corresponds to ``embeddings_type``.

    Args:
        embedding_repr: 1-D indexable vector (numpy array or torch tensor).
        embeddings_type: one of

            * ``None``       -- return ``embedding_repr`` unchanged;
            * ``'debiased'`` -- drop the last entry (the gender dimension);
            * ``'random'``   -- drop the first entry (the random dimension);
            * ``'gender'``   -- keep only the last entry (the gender dimension).

            Any other value is ignored and the input is returned unchanged.

    Returns:
        The selected slice. ``'gender'`` returns a single element (a numpy scalar for
        numpy input, a 0-d tensor for torch input), not a vector.
    """
    if embeddings_type == 'debiased':
        return embedding_repr[:-1]  # remove gender dimension
    if embeddings_type == 'random':
        return embedding_repr[1:]  # remove random dimension
    if embeddings_type == 'gender':
        return embedding_repr[-1]  # get gender dimension only
    return embedding_repr

def get_word_embeddings(model, words_path, embeddings_type=None, embeddings_op=None,
    norm_embeddings_op=False, return_words=False):
    """Load the words listed in a CSV file and return their embeddings as one tensor.

    Args:
        model: mapping from a word to its 1-D numpy embedding (``model[word]``).
        words_path: CSV file with a header row and a ``word`` column.
        embeddings_type: slice applied to every vector, see ``get_per_embeddings_type``.
        embeddings_op: ``None`` to embed each entry as a single word. Otherwise every
            entry must be ``"word1:word2"`` and the feature is
            ``emb(word1) - emb(word2)``; with ``'gender_association'`` the gender
            direction ``emb('he') - emb('she')`` is subtracted as well.
        norm_embeddings_op: when an ``embeddings_op`` is given, L2-normalise each
            vector (after slicing) before taking differences.
        return_words: also return the ``word`` column read from the CSV.

    Returns:
        Stacked features (dim 0 = CSV row), plus the ``word`` column when
        ``return_words`` is true.
    """
    def _prepare(vector):
        """Slice per ``embeddings_type``, optionally L2-normalise, and convert to a tensor."""
        vector = get_per_embeddings_type(vector, embeddings_type)
        if norm_embeddings_op:
            vector = vector / np.linalg.norm(vector)
        # np.asarray: the 'gender' slice is a numpy scalar, which torch.from_numpy rejects.
        return torch.from_numpy(np.asarray(vector))

    print(f'Loading {words_path}')
    words_csv = pd.read_csv(words_path, index_col=False, header=0)
    print(embeddings_type, embeddings_op, norm_embeddings_op)
    print(words_csv)

    # Loop-invariant: computed once, and sliced/normalised exactly like the word
    # vectors so its shape matches theirs for every ``embeddings_type``.
    gender_vector = None
    if embeddings_op == 'gender_association':
        gender_vector = _prepare(model['he'] - model['she'])

    features = []
    for word in words_csv['word']:
        if embeddings_op is None:
            embedding_repr = torch.from_numpy(model[word])
            features.append(get_per_embeddings_type(embedding_repr, embeddings_type))
            continue

        word1, word2 = word.split(':')
        feature = _prepare(model[word1]) - _prepare(model[word2])
        if gender_vector is not None:
            feature = feature - gender_vector
        features.append(feature)

    features = torch.stack(features)
    if return_words:
        return features, words_csv['word']
    return features

# # def get_embedding_associations(model, words_path):
# #     words_csv = pd.read_csv(words_path, index_col=False, header=0)
# #     print(words_csv)
# #     features = []
# #     for word_pair in words_csv['words']:
# #         word1, word2 = word_pair.split(':')
# #         embedding_repr1 = torch.from_numpy(model[word1])[:-1]
# #         embedding_repr2 = torch.from_numpy(model[word2])[:-1]
# #         features.append(embedding_repr1 - embedding_repr2)
# #     features = torch.stack(features)
# #     return features
