import os
import numpy as np
import torch
import  json
from enum import Enum
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

class ConfigEncoder(json.JSONEncoder):
    """JSON encoder that also serialises classes, enum members and functions.

    Such objects are encoded as a single-key dict holding their dotted path:
    ``{"$class": "module.Name"}``, ``{"$enum": "module.EnumClass.MEMBER"}`` or
    ``{"$function": "module.name"}``. Everything else is delegated to
    :class:`json.JSONEncoder`, which raises ``TypeError`` for unsupported
    objects.
    """

    def default(self, o):
        # Classes are callable too, so they must be recognised before the
        # generic callable branch below.
        if isinstance(o, type):
            return {'$class': o.__module__ + "." + o.__name__}
        if isinstance(o, Enum):
            return {
                '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name
            }
        # Only named callables (functions, builtins, methods) can be referenced
        # by path. Nameless callables such as ``functools.partial`` objects or
        # instances defining ``__call__`` fall through to the base class, so
        # json reports its usual TypeError instead of an AttributeError.
        if callable(o) and hasattr(o, '__name__'):
            return {
                '$function': o.__module__ + "." + o.__name__
            }
        return super().default(o)

def count_parameters(model, trainable=False):
    """Return the total number of scalar parameters in ``model``.

    When ``trainable`` is true, only parameters with ``requires_grad`` set
    are counted.
    """
    params = model.parameters()
    if trainable:
        params = (param for param in params if param.requires_grad)
    return sum(param.numel() for param in params)


def tensor2numpy(x):
    """Convert a tensor on any device to a NumPy array.

    The tensor is detached from autograd and moved to host memory when
    needed. For a CPU tensor no copy is made, so the returned array shares
    memory with ``x``.
    """
    # ``.cpu()`` is a no-op for CPU tensors and also handles non-CUDA
    # accelerators (e.g. MPS). The old ``is_cuda`` branch called ``.numpy()``
    # directly on those tensors, which fails.
    return x.detach().cpu().numpy()


def target2onehot(targets, n_classes):
    """One-hot encode integer class ``targets``.

    Returns a float tensor of shape ``(targets.shape[0], n_classes)`` on
    ``targets.device``, with a 1.0 in each row at the target's column.
    """
    encoded = torch.zeros(targets.shape[0], n_classes, device=targets.device)
    column_index = targets.long().view(-1, 1)
    encoded.scatter_(1, column_index, value=1.0)
    return encoded


def makedirs(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` avoids the race between checking for the directory and
    creating it: if another process creates it in between, the call still
    succeeds instead of raising ``FileExistsError``. A ``FileExistsError`` is
    still raised if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)


def _group_label(low, high):
    """Format a class range as ``"LL-HH"``, each bound zero-padded to two digits."""
    return "{}-{}".format(str(low).rjust(2, "0"), str(high).rjust(2, "0"))


def _group_accuracy(y_pred, y_true, idxes):
    """Percent of correct predictions at positions ``idxes`` (2 decimals); 0 if empty."""
    if len(idxes) == 0:
        return 0
    return np.around(
        (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2
    )


def accuracy(y_pred, y_true, nb_old, init_cls=10, increment=10):
    """Compute overall, per-class-group, "old" and "new" accuracies in percent.

    Args:
        y_pred: array of predicted class ids (NumPy array; it is fancy-indexed).
        y_true: array of ground-truth class ids, same length as ``y_pred``.
        nb_old: label threshold; samples with ``y_true < nb_old`` are reported
            as "old", the rest as "new".
        init_cls: size of the first class group (labels ``0..init_cls-1``).
        increment: size of every following class group.

    Returns:
        dict with keys "total", one ``"LL-HH"`` key per class group, "old" and
        "new". Values are accuracies rounded to 2 decimals. A group without
        samples reports 0.
    """
    assert len(y_pred) == len(y_true), "Data length error."
    all_acc = {}
    all_acc["total"] = np.around(
        (y_pred == y_true).sum() * 100 / len(y_true), decimals=2
    )

    # Initial classes.
    idxes = np.where(np.logical_and(y_true >= 0, y_true < init_cls))[0]
    all_acc[_group_label(0, init_cls - 1)] = _group_accuracy(y_pred, y_true, idxes)

    # Incremental class groups. range() excludes its stop, so stop at
    # max(y_true) + 1. Otherwise a group whose only present label is its first
    # class (e.g. max label 10 with init_cls=10) would be silently skipped.
    for class_id in range(init_cls, np.max(y_true) + 1, increment):
        idxes = np.where(
            np.logical_and(y_true >= class_id, y_true < class_id + increment)
        )[0]
        label = _group_label(class_id, class_id + increment - 1)
        all_acc[label] = _group_accuracy(y_pred, y_true, idxes)

    # Old / new split.
    all_acc["old"] = _group_accuracy(y_pred, y_true, np.where(y_true < nb_old)[0])
    all_acc["new"] = _group_accuracy(y_pred, y_true, np.where(y_true >= nb_old)[0])

    return all_acc


def split_images_labels(imgs):
    """Split ``(image, label)`` samples (e.g. ``ImageFolder.imgs``) into two arrays.

    Returns ``(images, labels)`` as NumPy arrays built from element 0 and
    element 1 of each sample respectively.
    """
    samples = list(imgs)
    images = [sample[0] for sample in samples]
    labels = [sample[1] for sample in samples]
    return np.array(images), np.array(labels)

def save_fc(args, model):
    """Save the classifier weight tensor and log where it was written.

    The tensor ``model._network.fc.weight.data`` is saved to
    ``<args['logfilename']>/fc.pt``. When ``args['device']`` lists a single
    device it is moved to the CPU first; otherwise it is saved as-is.
    A line ``time_str,model_name,<path> `` is then appended to
    ``./results/fc_weights/<prefix>/<csv_name>.csv``, creating the directory
    if needed.
    """
    fc_path = os.path.join(args['logfilename'], "fc.pt")
    weight = model._network.fc.weight.data
    if len(args['device']) <= 1:
        weight = weight.cpu()
    torch.save(weight, fc_path)

    record_dir = f"./results/fc_weights/{args['prefix']}"
    os.makedirs(record_dir, exist_ok=True)
    record_file = os.path.join(record_dir, f"{args['csv_name']}.csv")
    with open(record_file, "a+") as handle:
        handle.write(f"{args['time_str']},{args['model_name']},{fc_path} \n")

def save_model(args, model):
    """Pickle the whole ``model._network`` object to ``<logfilename>/model.pt``.

    Used in PODNet. With a single entry in ``args['device']`` the network's
    ``.cpu()`` result is saved; otherwise it is saved as-is.

    NOTE(review): if ``_network`` is an ``nn.Module``, ``.cpu()`` moves it in
    place, so the live network stays on the CPU after this call. Confirm that
    callers only save once training is finished.
    """
    target = os.path.join(args['logfilename'], "model.pt")
    network = model._network
    if len(args['device']) <= 1:
        network = network.cpu()
    torch.save(network, target)

def _tsne_embed(model):
    """Run the test loader through the network and embed its features in 2-D with t-SNE.

    Returns ``(X_tsne, labels)``. ``labels`` are the network's *predicted*
    classes (argmax of the logits), one per test sample. The accuracy of those
    predictions on ``model.test_loader`` (in percent) is printed.

    NOTE(review): the network is not switched to ``eval()`` here; presumably
    the caller does that. Confirm.
    """
    feature_batches = []
    pred_batches = []
    correct = 0
    total = 0
    for _, inputs, targets in model.test_loader:
        inputs = inputs.to(model._device)
        targets = targets.to(model._device)
        with torch.no_grad():
            outputs = model._network(inputs)
        _, preds = torch.max(outputs['logits'], dim=1)
        feature_batches.append(outputs['features'].cpu().numpy())
        pred_batches.append(preds.cpu().numpy())
        correct += preds.eq(targets.expand_as(preds)).sum().item()
        total += len(targets)
    if total == 0:
        raise ValueError("model.test_loader yielded no samples")
    print(np.around(correct * 100 / total, decimals=2))

    # Join batches along the sample axis. The previous np.array(...) followed
    # by np.squeeze(...) only gave an (N, D) matrix when every batch held one
    # sample, and broke when the final batch was smaller than the rest.
    features = np.concatenate(feature_batches, axis=0)
    features = features.reshape(features.shape[0], -1)
    labels = np.concatenate(pred_batches, axis=0)

    reducer = TSNE(n_components=2, random_state=42, init='pca', learning_rate='auto')
    return reducer.fit_transform(features), labels


def tsne(args, model, task, path):
    """Save a 2-D t-SNE scatter plot of the test-set features for ``task``.

    The embedding and its colouring labels (predicted classes) are cached as
    ``{path}/{task}_X_tsne_{dnm}.npy`` and ``{path}/{task}_y_tsne_{dnm}.npy``,
    where ``dnm = args["dnm"]``. If the embedding file already exists, both
    files are loaded instead of recomputing. The plot is written to
    ``{path}/task_{task}_tsne_{dnm}.png``.

    Args:
        args: config dict; only ``args["dnm"]`` is read (used in file names).
        model: object exposing ``test_loader`` (yielding
            ``(_, inputs, targets)``), ``_device`` and ``_network``. The
            network's output must be a dict with "logits" and "features".
        task: task identifier used in file names.
        path: output directory (must already exist).
    """
    x_path = f'{path}/{task}_X_tsne_{args["dnm"]}.npy'
    y_path = f'{path}/{task}_y_tsne_{args["dnm"]}.npy'
    if os.path.exists(x_path):
        # Reuse the cached embedding and labels.
        X_tsne = np.load(x_path)
        labels = np.load(y_path)
    else:
        X_tsne, labels = _tsne_embed(model)
        np.save(x_path, X_tsne)
        np.save(y_path, labels)

    # Plot on a dedicated figure and close it afterwards. Drawing on the
    # implicit current figure made successive calls (one per task) pile their
    # points onto the same axes, and kept every figure alive.
    fig = plt.figure()
    try:
        plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=labels, cmap='tab10', s=10, alpha=0.7)
        plt.xticks([])
        plt.yticks([])
        plt.savefig(f'{path}/task_{task}_tsne_{args["dnm"]}.png', format='png')
    finally:
        plt.close(fig)
def dnm():
    """Placeholder with no behavior; always returns ``None``."""
    return None



#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Dec  6 11:56:21 2023

@author: spiros
"""
import numpy as np


def calculate_sparsity(arr, threshold=0):
    """
    Fraction of inactive nodes for every input of a layer.

    Parameters
    ----------
    arr : np.ndarray
        Layer activations of shape `(N, D)`: `N` inputs by `D` nodes.
    threshold : float, optional
        A node counts as active only if its activation is strictly greater
        than this value. The default is 0.

    Returns
    -------
    sparsity : np.ndarray
        Float array of shape `(N,)` giving, per input, the share of the `D`
        nodes that are inactive.

    """
    # Anything not strictly above the threshold (including NaN) is inactive.
    inactive = np.logical_not(arr > threshold)
    inactive_counts = inactive.sum(axis=1)
    return inactive_counts / inactive.shape[1]


def compute_hit_matrix(arr, y, threshold=0.0):
    """
    Computes the hit matrix for a given layer of an ANN model,
    using a matrix of layer activations.

    Entry ``[k, j]`` counts how many samples of the k-th distinct label (labels
    in ascending order) activate node ``j``, i.e. have an activation strictly
    above ``threshold``.

    Args:
        arr (numpy.ndarray):
            A matrix of layer activations, with one row per input sample
            and one column per node.
        y (numpy.ndarray):
            The target labels, one integer per row of ``arr``. Labels need not
            start at 0 or be contiguous (e.g. the classes of a later
            incremental task).
        threshold (float):
            The threshold above which a node is considered to be active.
            Default is 0.

    Returns:
        numpy.ndarray:
            Float hit matrix of shape ``(n_distinct_labels, n_nodes)``.
    """
    arr = np.asarray(arr)
    y = np.asarray(y).squeeze()
    classes = np.unique(y)
    hit_matrix = np.zeros((classes.shape[0], arr.shape[1]))
    # Iterate over the labels actually present. Using range(n_classes) with
    # `y == i` silently produced all-zero rows whenever labels were not 0..K-1.
    for row, label in enumerate(classes):
        hit_matrix[row] = np.sum(arr[y == label] > threshold, axis=0)
    return hit_matrix


def calc_inactive_nodes(hit_matrix, theta=0):
    """Return the indices of nodes whose total hit count over all classes is <= ``theta``."""
    total_hits = hit_matrix.sum(axis=0)
    (silent_idx,) = np.nonzero(total_hits <= theta)
    return silent_idx


def zero_div(x, y):
    """
    Divide `x` by `y`, guarding against division by zero.

    Parameters
    ----------
    x : numpy.ndarray or scalar
        The numerator.
    y : numpy.ndarray or scalar
        The denominator.

    Returns
    -------
    float or numpy.ndarray
        ``0.0`` if any element of `y` is zero, otherwise ``x / y``.

    """
    # The former test `y.any() == 0` only fired when *every* element of `y`
    # was zero, so partially-zero denominators still produced inf/nan. It also
    # failed for plain Python scalars, which have no `.any()`.
    if np.any(np.asarray(y) == 0):
        return 0.0
    return x / y


def add_countless(hit_matrix, total_inputs):
    """Append a "countless" row to ``hit_matrix``.

    The extra row holds, per node, ``len(total_inputs)`` minus the node's
    total hit count, i.e. the inputs that are not counted in any class row.
    """
    n_inputs = len(total_inputs)
    misses = n_inputs - hit_matrix.sum(axis=0)
    return np.concatenate([hit_matrix, misses[np.newaxis, :]], axis=0)


def compute_entropy(arr):
    r"""
    Compute the entropy (in bits) of an array `arr`.

    Parameters
    ----------
    arr : array_like
        Probabilities or non-negative counts. A 1-D input is one distribution;
        for a 2-D input every column is a separate distribution. Columns that
        do not already sum to 1 are normalised first.

    Returns
    -------
    float or numpy.ndarray
        The entropy H[x] = -\sum p(x) log2 p(x): a scalar for 1-D input, one
        value per column for 2-D input. All-zero columns have entropy 0.

    """
    # Accept lists as well as arrays; work in float64.
    arr = np.array(arr).astype('float64')
    totals = np.sum(arr, axis=0)
    # Normalise each column to sum to 1. The old test
    # `np.sum(arr, axis=0).any() != 1` compared a *bool* with 1, so it skipped
    # normalisation whenever any column was non-zero. Unnormalised counts
    # therefore produced wrong entropies.
    if np.allclose(totals, 1.0):
        p = arr
    else:
        p = np.divide(arr, totals, out=np.zeros_like(arr), where=totals != 0)
    # Treat log2(0) as 0 so that 0 * log2(0) contributes nothing.
    logp = np.log2(p, out=np.zeros_like(arr), where=(p != 0))
    return -np.sum(p * logp, axis=0)


def compute_node_entropies(hit_matrix, true_labels):
    """
    Computes the entropy of each node from the hit matrix.

    A "countless" row (inputs not counted in any class row) is appended with
    `add_countless` before each column is normalised into a probability
    distribution, so every node's distribution accounts for all inputs.

    Args:
        hit_matrix (numpy.ndarray):
            The hit matrix, with one row per class and one column per node.
        true_labels (numpy.ndarray):
            Labels of the inputs. `add_countless` uses only its length, i.e.
            the total number of inputs.

    Returns:
        entropies (numpy.ndarray):
            The entropy per node. Low values indicate class-specific behavior,
            whereas large values indicate mixed selectivity.

    Raises:
        ValueError: if a node's probabilities do not sum to 1 (e.g. when there
            are no inputs) or if any entropy is negative.
    """
    # Append the "countless" row and turn every column into probabilities.
    hit_matrix = add_countless(hit_matrix, true_labels)
    probabilities = hit_matrix / hit_matrix.sum(axis=0)
    # The old check `probabilities.sum(axis=0).all() == 1.` compared a bool
    # with 1, so it only verified that column sums were non-zero. It also
    # let NaN through.
    if not np.allclose(probabilities.sum(axis=0), 1.0):
        raise ValueError("Node probabilities must sum to 1.")

    # Entropy of each node's distribution over classes + "countless".
    entropies = compute_entropy(probabilities)

    if np.any(entropies < 0):
        raise ValueError("Entropy can't be negative!")
    return entropies


def node_specificity(hit_matrix, theta=0):
    """
    Selectivity index: the number of classes each node responds to.

    Parameters
    ----------
    hit_matrix : numpy.ndarray
        Hit counts with one row per class and one column per node.
    theta : int, optional
        A node responds to a class when its hit count for that class is
        strictly greater than `theta`. The default is 0.

    Returns
    -------
    numpy.ndarray
        One integer per node, in the range [0, `nclasses`].

    """
    responds = hit_matrix > theta
    return responds.sum(axis=0)


def information_metrics(activation_arr, true_labels, theta=0, thershold=0.0):
    """Information-theoretic statistics for one layer's activations.

    Parameters
    ----------
    activation_arr : numpy.ndarray
        Activations with one row per input and one column per node.
    true_labels : numpy.ndarray
        Label of each input.
    theta : int, optional
        Nodes whose total hit count is <= `theta` are treated as inactive.
    thershold : float, optional
        Activation threshold passed to `compute_hit_matrix`. The parameter
        name is misspelled but kept as-is for keyword-argument compatibility.

    Returns
    -------
    tuple
        ``(entropy, sparsity, inactive_nodes, selectivity)``:
        - entropy: node entropies with the inactive nodes removed;
        - sparsity: fraction of nodes that are inactive;
        - inactive_nodes: indices of the inactive nodes;
        - selectivity: per node, the number of classes with a hit count above 0.
    """
    hits = compute_hit_matrix(activation_arr, true_labels, thershold)
    silent = calc_inactive_nodes(hits, theta=theta)
    node_entropy = compute_node_entropies(hits, true_labels)
    # Keep only the entropies of nodes that fire at all.
    entropy = np.delete(node_entropy, silent)
    n_nodes = hits.shape[1]
    sparsity = len(silent) / n_nodes

    # Selectivity uses a fixed threshold of 0, independent of `theta`.
    selectivity = node_specificity(hits, 0)
    return entropy, sparsity, silent, selectivity

