import os
import re
import sys
import math
import time
import pickle

import errno
import signal
from functools import wraps, partial
import torch

# from utils.logger import create_logger

# Module-wide switch read by to_cuda(): when False, to_cuda() returns its
# arguments without moving anything to the GPU.
CUDA = True


def count_cuda_devices():
    """
    Return the number of CUDA devices visible to this process.

    If ``CUDA_VISIBLE_DEVICES`` is set, its comma-separated entries are
    counted, ignoring blank entries (so an empty value means zero devices and
    a trailing comma is not counted). Otherwise ``torch.cuda.device_count()``
    is returned.

    NOTE(review): entries are counted as written and are not validated
    (an invalid id such as "-1" still counts as one device).
    """
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")

    if cuda_visible_devices is None:
        # Variable not set: let torch count all available devices.
        return torch.cuda.device_count()

    # "".split(",") == [""], so blank entries must be dropped explicitly;
    # otherwise CUDA_VISIBLE_DEVICES="" (no GPUs visible) would count as 1.
    return len([dev for dev in cuda_visible_devices.split(",") if dev.strip()])


def write_to_file(filename, string_list):
    """
    Write every string in ``string_list`` to ``filename``, one per line.

    The file is opened in text mode "w" (truncated if it exists), and each
    entry is followed by a newline, including the last one.
    """
    with open(filename, "w") as out:
        out.writelines(entry + "\n" for entry in string_list)


def initialize_exp(params):
    """
    Initialize the experiment:
    - dump parameters to ``<save_path>/params.pkl``
    - record the running command in ``params.command``
    - create a logger writing to ``<save_path>/train.log``

    Args:
        params: namespace-like object (``vars(params)`` must work). Reads
            ``save_path``, ``exp_id``, ``exp_name`` and, if present,
            ``global_rank`` (defaults to 0). ``save_path`` must already
            exist; this function does not create it.

    Returns:
        The logger returned by ``create_logger``.

    Raises:
        AssertionError: if ``exp_name`` is blank, or a command-line argument
            contains a quote character that cannot be re-quoted safely
            (these checks are plain asserts and vanish under ``python -O``).
    """
    # NOTE(review): ``create_logger`` is neither defined nor imported in this
    # file (``from utils.logger import create_logger`` is commented out at the
    # top), so the call below raises NameError unless the name is supplied
    # some other way -- confirm.

    # dump parameters; ``with`` guarantees the pickle is flushed and closed
    params_file = os.path.join(params.save_path, "params.pkl")
    with open(params_file, "wb") as f:
        pickle.dump(params, f)

    # rebuild the running command; positional values that are not plain
    # identifiers are wrapped in single quotes so the line can be re-run
    command = ["python", sys.argv[0]]
    for x in sys.argv[1:]:
        if x.startswith("--"):
            assert '"' not in x and "'" not in x
            command.append(x)
        else:
            assert "'" not in x
            if re.match("^[a-zA-Z0-9_]+$", x):
                command.append(x)
            else:
                command.append("'%s'" % x)
    command = " ".join(command)
    params.command = command + ' --exp_id "%s"' % params.exp_id

    # check experiment name
    assert len(params.exp_name.strip()) > 0

    # create a logger
    logger = create_logger(os.path.join(params.save_path, "train.log"), rank=getattr(params, "global_rank", 0))
    logger.info("============ Initialized logger ============")
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(params)).items())))
    logger.info("The experiment will be stored in %s\n" % params.save_path)
    logger.info("Running command: %s" % command)
    logger.info("")
    return logger


# def get_dump_path(params):
#     """
#     Create a directory to store the experiment.
#     """
#     assert len(params.exp_name) > 0

#     # create the sweep path if it does not exist
#     sweep_path = os.path.join(params.save_path, params.exp_name)
#     if not os.path.exists(sweep_path):
#         subprocess.Popen("mkdir -p %s" % sweep_path, shell=True).wait()

#     # create an ID for the job if it is not given in the parameters.
#     # if we run on the cluster, the job ID is the one of Chronos.
#     # otherwise, it is randomly generated
#     if params.exp_id == '':
#         chronos_job_id = os.environ.get('CHRONOS_JOB_ID')
#         slurm_job_id = os.environ.get('SLURM_JOB_ID')
#         assert chronos_job_id is None or slurm_job_id is None
#         exp_id = chronos_job_id if chronos_job_id is not None else slurm_job_id
#         if exp_id is None:
#             chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
#             while True:
#                 exp_id = ''.join(random.choice(chars) for _ in range(10))
#                 if not os.path.isdir(os.path.join(sweep_path, exp_id)):
#                     break
#         else:
#             assert exp_id.isdigit()
#         params.exp_id = exp_id

#     # create the dump folder / update parameters
#     params.save_path = os.path.join(sweep_path, params.exp_id)
#     if not os.path.isdir(params.save_path):
#         subprocess.Popen("mkdir -p %s" % params.save_path, shell=True).wait()


def to_cuda(*args):
    """
    Move tensors to CUDA.

    Each positional argument is moved with ``.cuda()``; ``None`` entries are
    passed through unchanged. When the module-level ``CUDA`` flag is False,
    the arguments are returned without being moved.

    Returns:
        list: one entry per positional argument, in the original order.
    """
    if not CUDA:
        # Return a list here too, so both branches give callers the same type.
        return list(args)
    return [None if x is None else x.cuda() for x in args]


class TimeoutError(BaseException):
    """
    Error raised by the ``timeout`` decorator when the wrapped call overruns.

    Derives from ``BaseException`` rather than ``Exception``, so generic
    ``except Exception`` handlers do not catch it (presumably intentional, so
    a timeout cannot be silently swallowed -- confirm).

    NOTE(review): this name shadows the builtin ``TimeoutError`` inside this
    module and for any code that imports it from here.
    """


def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
    """
    Decorator factory: abort the wrapped call with ``TimeoutError`` after
    ``seconds`` seconds, using ``SIGALRM``.

    Because it relies on ``signal.alarm``/``signal.signal``, it only works on
    platforms with ``SIGALRM`` and when called from the main thread.

    Nested use: if an alarm is already pending and would fire sooner than
    ``seconds``, that earlier deadline is honoured. When the wrapped call
    ends, the previous handler is restored and the outer alarm is re-armed
    with its remaining time.

    Args:
        seconds (int): timeout handed to ``signal.alarm`` (must be an int).
        error_message (str): message of the raised ``TimeoutError``.
    """

    def decorator(func):

        def _handle_timeout(repeat_id, signum, frame):
            # Re-install the handler and re-arm the alarm before raising, so the
            # timeout fires again every ``seconds`` if the error gets swallowed
            # inside ``func``; ``repeat_id`` counts how often it has fired.
            signal.signal(signal.SIGALRM, partial(_handle_timeout, repeat_id + 1))
            signal.alarm(seconds)
            raise TimeoutError(error_message)

        @wraps(func)
        def wrapper(*args, **kwargs):
            old_signal = signal.signal(signal.SIGALRM, partial(_handle_timeout, 0))
            old_time_left = signal.alarm(seconds)
            assert type(old_time_left) is int and old_time_left >= 0
            if 0 < old_time_left < seconds:  # do not exceed previous timer
                signal.alarm(old_time_left)
            start_time = time.time()
            try:
                return func(*args, **kwargs)
            finally:
                # Cancel our alarm first so it cannot fire into the restored
                # (outer) handler.
                signal.alarm(0)
                # Always restore the previous handler, whether or not an outer
                # alarm was pending. ``signal.signal`` returns None for handlers
                # not installed from Python; those cannot be re-installed.
                if old_signal is not None:
                    signal.signal(signal.SIGALRM, old_signal)
                if old_time_left > 0:
                    elapsed = time.time() - start_time
                    # Re-arm the outer alarm with its remaining time, but never
                    # with 0: ``alarm(0)`` would *cancel* an outer deadline that
                    # has already passed instead of delivering it.
                    signal.alarm(max(1, math.ceil(old_time_left - elapsed)))

        return wrapper

    return decorator


def read_file(filename):
    """
    Return the lines of ``filename`` with leading/trailing whitespace
    (including the newline) removed from each line.
    """
    with open(filename, "r") as handle:
        return [raw.strip() for raw in handle]


def read_file_split(filename):
    """
    Read ``filename`` and split each line on ':'.

    Each line is stripped of surrounding whitespace and then split on every
    ':'. The resulting pieces are NOT stripped individually, so a line
    "source : target" becomes ["source ", " target"], and a blank line
    becomes [""].
    """
    rows = []
    with open(filename, "r") as handle:
        for raw in handle.readlines():
            rows.append(raw.strip().split(":"))
    return rows


def compute_attention_sparsity(attn_matrix):
    """
    Compute an entropy-based sparsity measure of an attention matrix.

    For every row along the last dimension, the value is
    -sum(p * log(p + 1e-4)), i.e. the row entropy with 1e-4 added inside the
    log to avoid log(0). Lower values mean more concentrated (sparser) rows.

    Args:
        attn_matrix (torch.Tensor): The attention matrix.

    Returns:
        dict: "sparsity" (per-row entropy tensor, last dimension reduced),
        plus "mean_sparsity" and "std_sparsity" as Python floats.
    """
    eps = 1e-4
    log_probs = torch.log(attn_matrix + eps)
    entropy = (attn_matrix * log_probs).sum(dim=-1).neg()

    return {
        "sparsity": entropy,
        "mean_sparsity": entropy.mean().item(),
        "std_sparsity": entropy.std().item(),
    }


def compute_attention_sparsity2(attn_matrix):
    """
    Compute Hoyer's sparseness index along the last dimension of an attention matrix.

    For a vector x of length n the index is
        (sqrt(n) - ||x||_1 / ||x||_2) / (sqrt(n) - 1),
    which is 1 for a one-hot vector and 0 for a constant vector. Small epsilons
    guard both divisions. Note that an all-zero row yields
    sqrt(n) / (sqrt(n) - 1), not a value in [0, 1].

    Args:
        attn_matrix (torch.Tensor): Attention matrix (usually assumes
            non-negative values); the index is taken over the last dimension.

    Returns:
        dict: "sparsity" (per-row index, shape ``attn_matrix.shape[:-1]``),
        "mean_sparsity" and "std_sparsity" (Python floats). The same keys are
        returned for every input, including the degenerate n <= 1 case, and
        they match the keys of ``compute_attention_sparsity``.
    """
    epsilon = 1e-8  # A small value to prevent division by zero

    # Number of elements in the last dimension
    n = attn_matrix.shape[-1]

    if n <= 1:
        # Hoyer's denominator sqrt(n) - 1 is 0 here, so the index is undefined;
        # by convention report 0 for every row.
        sparsity = torch.zeros(attn_matrix.shape[:-1], device=attn_matrix.device, dtype=attn_matrix.dtype)
    else:
        sqrt_n = torch.sqrt(torch.tensor(n, dtype=attn_matrix.dtype, device=attn_matrix.device))

        # L1 norm (sum of absolute values) and L2 norm of each row
        sum_abs_x = torch.sum(torch.abs(attn_matrix), dim=-1)
        l2_norm_x = torch.linalg.norm(attn_matrix, ord=2, dim=-1)

        numerator = sqrt_n - (sum_abs_x / (l2_norm_x + epsilon))
        # Non-zero because n >= 2; epsilon only guards against float round-off.
        denominator = sqrt_n - 1
        sparsity = numerator / (denominator + epsilon)

    return {
        "sparsity": sparsity,
        "mean_sparsity": torch.mean(sparsity).item(),
        "std_sparsity": torch.std(sparsity).item(),
    }


def compute_attention_ratio(attn_matrix, input_len):
    """
    Compute, for each output token, the share of its attention that goes to
    output tokens rather than input tokens.

    The first ``input_len`` positions (as queries and as keys) are treated as
    input tokens; all later positions are output tokens. Only query rows after
    the input are considered. For each such row the ratio is
    sum(attn to output keys) / (sum(attn to input keys) + sum(attn to output keys) + 1e-6).
    If ``input_len >= seq_len`` there are no output rows and the statistics
    are computed over empty tensors.

    Args:
        attn_matrix (torch.Tensor): Attention matrix [batch_size, num_heads, seq_len, seq_len].
        input_len (int): The length of the input sequence.

    Returns:
        dict:
            "output_ratio": [batch_size, num_heads, output_len]
            "mean_ratio": [num_heads], mean over batch and output tokens
            "std_ratio": [num_heads], std over batch and output tokens
            "per_token_ratio": [output_len], mean over batch and heads
    """
    # Unpacking the shape also enforces a 4-D input (ValueError otherwise).
    _batch, _heads, _seq, _keys = attn_matrix.shape

    # Query rows belonging to output tokens
    output_rows = attn_matrix[:, :, input_len:, :]

    # Total attention each output row pays to input keys vs. output keys
    to_input = output_rows[..., :input_len].sum(dim=-1)
    to_output = output_rows[..., input_len:].sum(dim=-1)

    ratio = to_output / (to_input + to_output + 1e-6)  # 1e-6 avoids 0/0

    return {
        "output_ratio": ratio,
        "mean_ratio": ratio.mean(dim=[0, 2]),
        "std_ratio": ratio.std(dim=[0, 2]),
        "per_token_ratio": ratio.mean(dim=[0, 1]),
    }