import random
import os
import torch
import numpy as np


# ------------------- utils -------------------
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    A seed of ``0`` is treated as "do not seed": nothing is touched and the
    value is simply returned.

    Args:
        seed: Integer seed; ``0`` disables seeding.

    Returns:
        The ``seed`` argument, unchanged.
    """
    if seed == 0:
        return seed

    os.environ["PYTHONHASHSEED"] = str(seed)

    # Apply the same seed to every generator this module relies on.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Request deterministic cuDNN kernels and turn off the auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    return seed

def produce_noise(labels, epsilon=0.2, num_classes=10):
    """Return a copy of ``labels`` with some entries flipped to another class.

    Each label is independently replaced, with probability ``epsilon``, by a
    class drawn uniformly from ``range(num_classes)`` excluding the label's
    current value. Randomness comes from NumPy's global generator
    (``np.random``), so seed it beforehand for reproducible noise.

    Args:
        labels: Sequence of integer class labels providing ``.copy()``
            (e.g. a NumPy array or a list). It is not modified.
        epsilon: Probability of flipping each individual label.
        num_classes: Number of classes; replacements are drawn from
            ``0 .. num_classes - 1``.

    Returns:
        A copy of ``labels`` with the noisy labels written in.
    """
    noisy_labels = labels.copy()

    for idx, current in enumerate(labels):
        if np.random.uniform(0, 1) < epsilon:
            # Use a loop variable distinct from the sample index: reusing the
            # same name inside the comprehension would shadow it and compare
            # each candidate against the wrong label.
            candidates = [c for c in range(num_classes) if c != current]
            noisy_labels[idx] = np.random.choice(candidates)

    return noisy_labels


# ------------------- Non-conformity score function -------------------
class THR(object):
    """Non-conformity score equal to one minus the softmax probability.

    Calling an instance on logits returns ``1 - softmax(logits)``, either for
    every class or only for the class selected by ``label``.
    """

    def __call__(self, logits, label=None):
        """Compute non-conformity scores from logits.

        Args:
            logits: Tensor with at most two dimensions. A 1-D tensor is
                treated as a single sample and promoted to a batch of one.
            label: Optional class indices, one per row of ``logits``. A
                scalar (Python int or 0-dim tensor) is accepted for a single
                sample. When ``None``, scores for all classes are returned.

        Returns:
            A tensor of per-class scores with the same shape as the (possibly
            promoted) logits when ``label`` is ``None``; otherwise a 1-D
            tensor holding the score of the given label for each sample.
        """
        assert len(logits.shape) <= 2, "The dimension of logits must be at most 2."
        if len(logits.shape) == 1:
            logits = logits.unsqueeze(0)
        probs = torch.softmax(logits, dim=-1)
        if label is None:
            return self.__calculate_all_label(probs)
        return self.__calculate_single_label(probs, label)

    def __calculate_single_label(self, temp_values, label):
        """Return ``1 - p[i, label[i]]`` for every row ``i`` of ``temp_values``."""
        # 1-D logits are promoted to a batch of one by __call__, so promote a
        # scalar label the same way instead of failing on ``label.shape[0]``.
        label = torch.as_tensor(label, device=temp_values.device)
        if label.dim() == 0:
            label = label.unsqueeze(0)
        rows = torch.arange(label.shape[0], device=temp_values.device)
        return 1 - temp_values[rows, label]

    def __calculate_all_label(self, temp_values):
        """Return ``1 - p`` for every entry of ``temp_values``."""
        return 1 - temp_values
    

# ------------------- Adaptive conformal inference with pinball loss -------------------
def ACI(scores, q_1, etas, alpha):
    """Run the adaptive conformal inference threshold recursion.

    Starting from ``q[0] = q_1``, each step records whether the observed
    score exceeded the current threshold and updates the next threshold:

        err_t    = 1 if scores[t] > q[t] else 0
        q[t + 1] = q[t] - etas[t] * (alpha - err_t)

    Args:
        scores: 1-D array-like of length ``T`` with the observed scores.
        q_1: Initial threshold.
        etas: Step sizes, indexable by ``t`` for ``t = 0 .. T - 2``.
        alpha: Target error rate used in the update.

    Returns:
        A float NumPy array of length ``T`` with the threshold used at each
        step. An empty array is returned when ``scores`` is empty.
    """
    scores = np.asarray(scores)
    T = scores.shape[0]
    q = np.zeros(T)
    if T == 0:
        # Nothing to track; avoids indexing into an empty array below.
        return q

    q[0] = q_1
    # The last score never influences the output, so stop one step early.
    for t in range(T - 1):
        err_t = (scores[t] > q[t]).astype(int)
        q[t + 1] = q[t] - etas[t] * (alpha - err_t)
    return q


# ------------------- Adaptive conformal inference with robust pinball loss -------------------
def rACI(scores, scores_all, q_1, etas, alpha, noise_rate):
    """Run the adaptive conformal inference recursion with the robust update.

    Starting from ``q[0] = q_1``, each step combines the error indicator of
    the observed score with the number of entries of ``scores_all[t]`` that
    exceed the current threshold:

        err_t    = 1 if scores[t] > q[t] else 0
        errs_t   = number of k with scores_all[t, k] > q[t]
        n1       = 1 / (1 - noise_rate)
        n2       = noise_rate / (K * (1 - noise_rate))
        q[t + 1] = q[t] - etas[t] * n1 * (alpha - err_t)
                        + etas[t] * n2 * (K * alpha - errs_t)

    With ``noise_rate == 0`` this reduces to the plain update
    ``q[t + 1] = q[t] - etas[t] * (alpha - err_t)``.

    Args:
        scores: 1-D array-like of length ``T`` with the observed scores
            (presumably computed from the possibly noisy labels -- confirm
            against the caller).
        scores_all: 2-D array-like of shape ``(T, K)``; row ``t`` holds the
            scores of all ``K`` candidates at step ``t``.
        q_1: Initial threshold.
        etas: Step sizes, indexable by ``t`` for ``t = 0 .. T - 2``.
        alpha: Target error rate used in the update.
        noise_rate: Noise rate entering the correction terms; must differ
            from 1, otherwise the coefficients divide by zero.

    Returns:
        A float NumPy array of length ``T`` with the threshold used at each
        step. An empty array is returned when ``scores`` is empty.
    """
    scores = np.asarray(scores)
    T = scores.shape[0]
    q = np.zeros(T)
    if T == 0:
        # Nothing to track; avoids indexing into an empty array below.
        return q

    scores_all = np.asarray(scores_all)
    K = scores_all.shape[1]
    q[0] = q_1
    if T == 1:
        # No update step is performed, so noise_rate is never consulted.
        return q

    # Loop-invariant correction coefficients.
    n1 = (1 / (1 - noise_rate))
    n2 = (noise_rate / (K * (1 - noise_rate)))

    # The last score never influences the output, so stop one step early.
    for t in range(T - 1):
        err_t = (scores[t] > q[t]).astype(int)
        errs_t = (scores_all[t, :] > q[t]).sum()
        q[t + 1] = q[t] - etas[t] * n1 * (alpha - err_t) + etas[t] * n2 * (K * alpha - errs_t)

    return q
    