import os
import numpy as np
import random
from tqdm import tqdm
import copy
import torch
from math import inf
from scipy import stats


def generate_transition_matrix_asymm(n, target_noise):
    """
    Generate a random n×n label-noise transition matrix.

    - Every off-diagonal entry of column j equals one shared value:
      target_noise/(n-1) plus a uniform perturbation drawn from
      [-target_noise/n, target_noise/n] (one ``random.uniform`` draw per
      column, in column order).
    - Diagonal elements are set so that each row sums to 1.

    Because each column gets its own perturbation, the matrix is in general
    not symmetric. The expected off-diagonal mass of each row is target_noise.

    Parameters:
    - n: dimension of the matrix (number of classes); must be >= 2
    - target_noise: mean off-diagonal (noise) mass per row

    Returns:
    - transition_matrix: n×n numpy float array with non-negative entries and
      rows summing to 1

    Raises:
    - ValueError: if n < 2, or if the sampled matrix has a negative entry.
      A negative diagonal is only possible when target_noise > 0.5, and
      negative off-diagonals only when target_noise < 0.
    """
    if n < 2:
        raise ValueError(f"n must be >= 2, got {n}")
    target_noise = float(target_noise)
    max_deviation = target_noise / n
    base_value = target_noise / (n - 1)
    # One perturbed off-diagonal value per column, drawn in column order.
    column_values = np.array(
        [base_value + random.uniform(-max_deviation, max_deviation) for _ in range(n)]
    )
    # Every row is a copy of column_values, so matrix[i, j] == column_values[j].
    matrix = np.tile(column_values, (n, 1))
    np.fill_diagonal(matrix, 0.0)
    # Diagonal absorbs the remaining probability mass of each row.
    np.fill_diagonal(matrix, 1.0 - matrix.sum(axis=1))
    if (matrix < 0).any():
        raise ValueError(
            f"transition matrix has negative entries for n={n}, "
            f"target_noise={target_noise}; use a target_noise in [0, 0.5]"
        )
    return matrix


# for noise generation
# Generate symmetric noisy label
def generate_noisy_label_symmetric(args, label):
    """
    Corrupt labels with symmetric (uniform) noise.

    A random subset of samples is re-labelled with a class drawn uniformly
    from all ``args.n_class`` classes, the original class included. A redraw
    keeps the original class with probability 1/n_class, so the subset size is
    scaled by n_class/(n_class-1). This makes the expected fraction of labels
    that actually change equal ``args.noisy_ratio``. When
    ``noisy_ratio > (n_class-1)/n_class``, every sample is redrawn and the
    expected changed fraction saturates at (n_class-1)/n_class.

    Parameters:
    - args: object with ``noisy_ratio`` (float-convertible) and ``n_class`` (int)
    - label: sequence, numpy array or tensor of integer class labels

    Returns:
    - noisy_label: tensor copy of ``label`` with the selected entries redrawn;
      ``label`` itself is not modified
    """
    noisy_label = torch.as_tensor(label).clone()
    n_samples = len(noisy_label)
    n_ratio = float(args.noisy_ratio) * ((args.n_class) / (args.n_class - 1))
    # Clamp so the number of random draws always matches the number of
    # selected indices (n_ratio can exceed 1).
    n_noisy = min(max(int(n_samples * n_ratio), 0), n_samples)
    chg_idx = np.random.permutation(np.arange(n_samples))[:n_noisy]
    noisy_label[chg_idx] = torch.randint(args.n_class, (n_noisy,))
    return noisy_label

# Generate asymmetric noisy label
def generate_noisy_label_asymmetric(args, label):
    """
    Corrupt labels using a random asymmetric transition matrix.

    The number of classes is 10 for ``args.dataset == "CIFAR10"`` and 100 for
    ``"CIFAR100"``. For any other dataset it falls back to ``args.n_class``.
    A matrix T is built with ``generate_transition_matrix_asymm(n_class,
    args.noisy_ratio)``. Each sample's noisy label is then drawn with
    ``np.random.choice`` from row T[true_label].

    Parameters:
    - args: object with ``dataset``, ``noisy_ratio`` and (for datasets other
      than CIFAR10/CIFAR100) ``n_class``
    - label: sequence, numpy array or tensor of integer class labels

    Returns:
    - (noisy_label, T): a tensor copy of ``label`` with resampled entries, and
      the transition matrix that was used

    Raises:
    - ValueError: if the dataset is unknown and ``args`` has no ``n_class``
    """
    noisy_label = torch.as_tensor(label).clone()
    known_n_class = {"CIFAR10": 10, "CIFAR100": 100}
    if args.dataset in known_n_class:
        n_class = known_n_class[args.dataset]
    elif getattr(args, "n_class", None) is not None:
        n_class = int(args.n_class)
    else:
        raise ValueError(
            f"unsupported dataset {args.dataset!r}: set args.n_class to use asymmetric noise"
        )
    T = generate_transition_matrix_asymm(n_class, args.noisy_ratio)
    classes = range(n_class)
    # Iterate a snapshot of the clean labels, since noisy_label is written in place.
    for i, cl in enumerate(noisy_label.tolist()):
        noisy_label[i] = np.random.choice(classes, p=T[cl])
    return noisy_label, T

# Generate instance-dependent (IDN) noisy labels
def generate_noisy_label_idn(args, input, label):
    """
    Corrupt labels with instance-dependent noise (IDN).

    For each sample i:
    - a flip rate is drawn from a normal distribution with mean
      ``args.noisy_ratio`` and std 0.1, truncated to [0, 1];
    - the flattened input is projected with ``W[label_i]``, a Gaussian
      D×n_class matrix shared by all samples of that true class, where D is
      the number of features per sample;
    - the true-class logit is masked with -inf, so the softmax spreads the
      flip rate over the other classes only;
    - the true class keeps probability 1 - flip rate;
    - the noisy label is sampled from the resulting distribution.

    Parameters:
    - args: object with ``noisy_ratio`` (float-convertible) and ``n_class`` (int)
    - input: array-like or tensor of samples; the first dimension indexes samples
    - label: array-like or tensor of integer labels in [0, n_class)

    Returns:
    - 1-D int64 tensor with one noisy label per sample (empty if no samples)

    Raises:
    - ValueError: if any label lies outside [0, n_class)
    """
    input = torch.as_tensor(input)
    # scatter_ needs an int64 index; int32 labels (e.g. from numpy) would fail.
    label = torch.as_tensor(label).long()
    n_samples = len(label)
    if n_samples == 0:
        return torch.empty(0, dtype=torch.long)
    n_class = args.n_class
    if label.min() < 0 or label.max() >= n_class:
        raise ValueError(f"labels must lie in [0, {n_class}), got range "
                         f"[{int(label.min())}, {int(label.max())}]")

    n_ratio = float(args.noisy_ratio)
    # truncnorm takes its bounds in standardized units: (bound - loc) / scale.
    flip_distribution = stats.truncnorm((0 - n_ratio) / 0.1, (1 - n_ratio) / 0.1, loc=n_ratio, scale=0.1)
    flip_rate = torch.tensor(flip_distribution.rvs(n_samples))

    features = input.contiguous().view(n_samples, -1)
    W = torch.randn(n_class, features.shape[1], n_class)
    # Same dtype promotion as multiplying input by W directly.
    dtype = torch.promote_types(features.dtype, W.dtype)
    features, W = features.to(dtype), W.to(dtype)

    # Equivalent to sum(x[:, :, None] * W[label], dim=1), but without
    # materialising the (n_samples, D, n_class) tensor W[label].
    p = torch.zeros(n_samples, n_class, dtype=dtype)
    for c in range(n_class):
        mask = label == c
        if mask.any():
            p[mask] = features[mask] @ W[c]

    index = label.unsqueeze(1)
    p.scatter_(1, index, -inf)  # exclude the true class from the softmax
    p = flip_rate.unsqueeze(1) * torch.softmax(p, dim=1)
    p.scatter_(1, index, (1 - flip_rate).unsqueeze(1))  # keep-probability
    noisy_label = torch.multinomial(p, 1)
    return noisy_label.squeeze(1)

# Generate open-set noisy labels
def generate_noisy_label_open(args, input, label):
    """
    Placeholder for open-set noisy-label generation.

    Not implemented: every argument is ignored and ``None`` is returned.
    """
    return None







