"""
Utility functions for the attributors package.
"""
import random
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import default_data_collator
from typing import List

def set_seed(seed=42):
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) and
    switch cuDNN to deterministic, non-benchmarking mode.

    :param seed: Integer seed. ``None`` (or ``False``) disables seeding and the
        function does nothing. ``0`` is a valid seed and IS applied.
    """
    # `if seed:` would silently ignore seed 0; only None/False mean "don't seed".
    if seed is None or seed is False:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def array_in_list(array: np.ndarray, list_of_arrays: List[np.ndarray]) -> bool:
    """
    Return whether ``array`` occurs in ``list_of_arrays``.

    Membership is tested with ``np.array_equal`` (same shape and same elements),
    because ``array in list_of_arrays`` would raise on multi-element arrays.
    Stops at the first match.
    """
    return any(np.array_equal(candidate, array) for candidate in list_of_arrays)

def convert_to_list(x):
    """
    Convert x to a list if it is not already a list.

    - A list is returned unchanged (same object).
    - A ``torch.Tensor`` / ``np.ndarray`` is converted with ``.tolist()``. A 0-d
      tensor/array yields a one-element list instead of a bare scalar.
    - Anything else is wrapped as ``[x]``.

    :param x: Value to convert.
    :return: A list.
    """
    if isinstance(x, list):
        return x
    if isinstance(x, (torch.Tensor, np.ndarray)):
        converted = x.tolist()
        # .tolist() on a 0-d tensor/array returns a plain Python scalar, not a list.
        return converted if isinstance(converted, list) else [converted]
    return [x]

def area_over_curve(y):
    """
    Compute normalized AOC for given top-k property vals (length may vary).

    The values are placed at evenly spaced x positions on ``[0, 1]`` and the
    area under them is integrated with the trapezoidal rule. This is the same
    formula ``sklearn.metrics.auc`` / ``numpy.trapz`` apply for increasing x.
    The result is ``1 - area``.

    :param y: 1-D sequence (list, array, column vector) of at least two y values.
    :return: AOC as a Python float.
    :raises ValueError: if ``y`` is not 1-D (or a single column) or has fewer
        than two values.
    """
    values = np.asarray(y, dtype=float)
    if values.ndim == 2 and values.shape[1] == 1:
        # Accept column vectors, as sklearn's column_or_1d did.
        values = values[:, 0]
    if values.ndim != 1:
        raise ValueError(f"y must be 1-D, got shape {values.shape}")
    n = values.size
    if n < 2:
        # Otherwise x = arange(n) / (n - 1) divides by zero (n == 1) or is empty.
        raise ValueError(f"At least 2 points are needed to compute the AOC, got {n}")
    x = np.arange(n) / (n - 1)
    area = (np.diff(x) * (values[1:] + values[:-1]) / 2.0).sum()
    return 1 - area.item()

def set_attr(obj, names, val):
    """
    Assign to a nested attribute given its path.

    ``set_attr(m, ["a", "b"], v)`` performs ``m.a.b = v``.

    :param obj: Root object.
    :param names: Sliceable sequence of attribute names. All but the last are
        traversed with ``getattr``, and the last one is assigned.
    :param val: Value to assign.
    """
    owner = obj
    for attr_name in names[:-1]:
        owner = getattr(owner, attr_name)
    setattr(owner, names[-1], val)

def del_attr(obj, names):
    """
    Delete a nested attribute given its path.

    ``del_attr(m, ["a", "b"])`` performs ``del m.a.b``.

    :param obj: Root object.
    :param names: Sliceable sequence of attribute names. All but the last are
        traversed with ``getattr``, and the last one is deleted.
    """
    parent = obj
    for attr_name in names[:-1]:
        parent = getattr(parent, attr_name)
    delattr(parent, names[-1])

def flatten_params(model, gradients=False):
    """
    Concatenate all parameters (or their gradients) of ``model`` into one 1-D tensor.

    Parameters are visited in ``model.parameters()`` order, so
    ``flatten_params(model)`` and ``flatten_params(model, gradients=True)``
    line up element-for-element.

    :param model: Module whose parameters are flattened.
    :param gradients: If True, flatten ``p.grad`` instead of ``p``. A parameter
        whose ``.grad`` is None (e.g. no backward pass reached it yet)
        contributes zeros, so the output length still equals the parameter count.
    :return: 1-D tensor.
    """
    vec = []
    for p in model.parameters():
        tensor = p
        if gradients:
            # Missing grad -> zeros, keeping alignment with the parameter vector
            # instead of crashing on None.view.
            tensor = p.grad if p.grad is not None else torch.zeros_like(p)
        # reshape (unlike view) also works for non-contiguous tensors.
        vec.append(tensor.reshape(-1))
    return torch.cat(vec)

def test_accuracy(model: nn.Module, test_dataset: Dataset, device: str) -> torch.Tensor:
    """
    Example property function (convention: higher is better)
    Requirement: must be differentiable with autograd

    Runs ``model`` on all of ``test_dataset.data`` in a single call and returns
    ``exp(-MSE)`` between the predictions and ``test_dataset.targets``. If the
    model outputs more than one column, column 1 is used as the prediction. No
    ``torch.no_grad`` is used, so gradients can flow back into the model.

    :param model: ``nn.Module`` or any callable taking a float tensor. A module
        is moved to ``device`` and switched to eval mode in place.
    :param test_dataset: Object exposing ``data`` and ``targets`` tensors.
    :param device: Torch device string, e.g. ``"cpu"``.
    :return: Scalar tensor ``exp(-mean((y_pred - y)^2))``.
    """
    if isinstance(model, nn.Module):
        model = model.to(device).eval()
    y_pred = model(test_dataset.data.float().to(device))
    if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:
        y_pred = y_pred[:, 1]
    targets = test_dataset.targets.float().to(device)
    # Drop a trailing singleton column so predictions and targets are compared
    # element-wise: mixing (N, 1) with (N,) would broadcast to an (N, N) matrix.
    if y_pred.dim() == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred[:, 0]
    if targets.dim() == 2 and targets.shape[1] == 1:
        targets = targets[:, 0]
    err = y_pred - targets
    return torch.exp(-1 * torch.pow(err, 2).mean())


# The name starts with "test_": keep pytest from collecting this helper as a
# test when it is imported into a test module.
test_accuracy.__test__ = False
