import random
import torch
import numpy as np
import argparse
import pdb
import logging


def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    The three generators are seeded in that order with the same value.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


def seed_worker(worker_id):
    """Re-seed NumPy and ``random`` from PyTorch's current initial seed.

    The signature matches the ``worker_init_fn(worker_id)`` hook of
    ``torch.utils.data.DataLoader``; ``worker_id`` itself is not used.
    The torch seed is reduced modulo 2**32 so it fits NumPy's seed range.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)


def grad_param(net, mode=True):
    """Set ``requires_grad = mode`` on every parameter except the last ``fc`` layer.

    Parameters named exactly ``fc.weight`` or ``fc.bias`` are left untouched.
    ``grad_param(net, False)`` therefore freezes everything but that layer,
    and ``grad_param(net, True)`` (the default) makes the rest trainable again.
    """
    fc_names = ('fc.weight', 'fc.bias')
    for pname, tensor in net.named_parameters():
        if pname in fc_names:
            continue
        tensor.requires_grad = mode


def count_param(model):
    """Count the scalar entries in a model's parameters.

    Args:
        model: A module exposing ``parameters()``.

    Returns:
        Tuple ``(param_all, param_train)`` of plain ints: the total number of
        elements over all parameters, and the number belonging to parameters
        with ``requires_grad=True``.
    """
    param_all = 0
    param_train = 0
    for p in model.parameters():
        # numel() is an exact int, also for 0-d parameters; np.prod of an
        # empty shape would return the float 1.0 instead.
        n = p.numel()
        param_all += n
        if p.requires_grad:
            param_train += n
    return param_all, param_train


def print_param(model):
    """Print the name of every trainable parameter of ``model``.

    A header line is printed first. For trainable parameters whose name
    contains ``'bias'``, the first element (``param.data[0]``) is printed
    right after the name.
    """
    print('\n--- trainable param ---\n')
    trainable = ((pname, tensor) for pname, tensor in model.named_parameters()
                 if tensor.requires_grad)
    for pname, tensor in trainable:
        print(pname)
        if 'bias' in pname:
            print(tensor.data[0])


class AverageMeter(object):
    """Track the latest value of a metric and its running weighted mean.

    Args:
        name: Label shown first by ``str(meter)``.
        fmt: Format spec, including the leading colon (e.g. ``':.4f'``),
            applied to both the current value and the average.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the current value, the average, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted with weight ``n``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Produces e.g. '{name} {val:.4f} ({avg:.4f})', filled from the
        # instance attributes.
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**vars(self))


class ProgressMeter(object):
    """Emit one tab-separated progress line via ``logging.info`` per ``display`` call.

    Args:
        num_batches: Total number of batches; fixes the width of the batch
            counter in ``[  i/total]``.
        meters: Objects whose ``str()`` is appended to each line.
        prefix: Text placed before the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Log the counter for ``batch`` followed by every meter, tab-separated."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        logging.info('\t'.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        """Build a template like ``'[{:3d}/250]'`` sized to ``num_batches``."""
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[%s/%s]' % (field, field.format(num_batches))


def upper_triangularize(entries, size):
    """Scatter ``entries`` into the strict upper triangle of ``size x size`` matrices.

    The last axis of ``entries`` fills the positions above the diagonal in
    the order given by ``torch.triu_indices(size, size, offset=1)`` (row by
    row), so its length must be ``size * (size - 1) // 2``. Leading axes are
    kept as batch dimensions. The diagonal and lower triangle are zero, and
    the result has the same dtype and device as ``entries``.
    """
    rows, cols = torch.triu_indices(size, size, offset=1)
    out = entries.new_zeros(tuple(entries.shape[:-1]) + (size, size))
    out[..., rows, cols] = entries
    return out


def topological_sort(adjacency_matrix):
    """Return a topological ordering of a directed graph's nodes (Kahn's algorithm).

    A nonzero ``adjacency_matrix[j, i]`` is read as an edge ``j -> i``. A node
    becomes a root once its column, restricted to edges from still-unplaced
    nodes, is all zero. Nodes that become roots in the same round are emitted
    in increasing index order.

    Args:
        adjacency_matrix: Square ``(dim, dim)`` tensor. It is copied and cast
            to ``torch.int`` first, so the caller's tensor is not modified and
            fractional entries with magnitude below 1 truncate to "no edge".

    Returns:
        List of node indices in topological order.

    Raises:
        ValueError: If the graph contains a cycle (including a self-loop), so
            no topological order exists.
    """
    # Work on an integer copy so edges can be removed in place.
    a = adjacency_matrix.clone().to(torch.int)
    dim = a.shape[0]
    assert a.shape == (dim, dim)

    ordering = []
    to_do = list(range(dim))

    while to_do:
        root_nodes = [i for i in to_do if not torch.sum(torch.abs(a[:, i]))]
        if not root_nodes:
            # Every remaining node still has an incoming edge from another
            # remaining node, so they contain a cycle. Without this check the
            # loop would never terminate.
            raise ValueError(
                "adjacency matrix contains a cycle among nodes {}".format(to_do)
            )
        for i in root_nodes:
            ordering.append(i)
            a[i, :] = 0  # drop the outgoing edges of the placed node
        placed = set(root_nodes)
        to_do = [i for i in to_do if i not in placed]

    return ordering


def clean_and_clamp(inputs, min_=-1.0e12, max_=1.0e12):
    """Replace non-finite values with ``torch.nan_to_num``, then clamp to ``[min_, max_]``.

    With ``nan_to_num``'s defaults, NaN becomes 0 and +/-inf become the
    dtype's extreme finite values, which the clamp then caps at ``max_``/``min_``.
    """
    finite = torch.nan_to_num(inputs)
    return finite.clamp(min_, max_)


def mask(data, mask_, mask_data=None, concat_mask=True):
    """Apply ``mask_`` to ``data``, optionally filling masked entries and appending the mask.

    Args:
        data: Tensor to mask.
        mask_: Multiplicative mask. Entries of ``data`` are kept where it is 1.
        mask_data: If given, contributes ``(1 - mask_) * mask_data`` where the
            mask is off; otherwise those entries are zero.
        concat_mask: If true, ``mask_`` is concatenated to the result along dim 1.
    """
    out = mask_ * data
    if mask_data is not None:
        out = out + (1 - mask_) * mask_data
    return torch.cat((out, mask_), dim=1) if concat_mask else out


def logmeanexp(x, dim):
    """Compute ``log(mean(exp(x)))`` along ``dim``.

    Uses ``logsumexp`` for numerical stability, then subtracts the log of the
    number of elements along ``dim`` to turn the sum into a mean.
    """
    n = x.size(dim)
    return torch.logsumexp(x, dim=dim) - np.log(n)


def inverse_softplus(x, beta=1.0):
    """Inverse of ``softplus(z) = log(1 + exp(beta * z)) / beta``.

    Evaluates ``log(exp(beta * x) - 1) / beta`` in the algebraically equal form
    ``x + log(-expm1(-beta * x)) / beta``. The naive form overflows to inf for
    ``beta * x`` above roughly 709 and loses precision for small ``x``.

    Args:
        x: Positive scalar or array-like (converted with ``np.asarray``).
            Non-positive values give ``-inf`` (at 0) or NaN, as in the naive form.
        beta: Softplus sharpness parameter.

    Returns:
        NumPy scalar or array with the pre-image of ``x`` under softplus.
    """
    x = np.asarray(x)
    return x + np.log(-np.expm1(-beta * x)) / beta


def freeze(freez, models):
    """Set ``requires_grad`` on every parameter of several models at once.

    Args:
        freez: If truthy, parameters are frozen (``requires_grad=False``);
            otherwise they are made trainable.
        models: Iterable of modules. ``None`` entries are skipped, so optional
            sub-models can be passed directly.
    """
    requires_grad = not freez
    for model in models:
        if model is None:
            continue
        for p in model.parameters():
            p.requires_grad = requires_grad


def correlation(x, y):
    """Pearson correlation coefficient between ``x`` and ``y``.

    Means are taken over all elements of each tensor, so the inputs are
    treated as flat samples of equal size. Returns NaN if either input is
    constant (zero variance).
    """
    x_centered = x - torch.mean(x)
    y_centered = y - torch.mean(y)

    # Population (1/N) covariance
    cov_xy = torch.mean(x_centered * y_centered)

    # Population (1/N) standard deviations, consistent with the covariance
    # above. torch.std defaults to the unbiased 1/(N-1) estimator, which would
    # shrink the coefficient by (N-1)/N, so identical inputs would not give 1.
    std_x = torch.sqrt(torch.mean(x_centered ** 2))
    std_y = torch.sqrt(torch.mean(y_centered ** 2))

    return cov_xy / (std_x * std_y)


def check_linear_dependence(vectors):
    """Check whether a set of vectors is linearly dependent or independent.

    Args:
        vectors (list of list of numbers): The vectors to check, one per row.

    Returns:
        str: A message stating whether the vectors are linearly dependent or
        independent. An empty set is reported as independent.
    """
    matrix = np.array(vectors)
    rank = np.linalg.matrix_rank(matrix)

    # k vectors are independent exactly when they span a k-dimensional space,
    # so compare the rank with the number of vectors, not their length.
    if rank < len(vectors):
        return "The vectors are linearly dependent."
    return "The vectors are linearly independent."
