import torch
import os 
import numpy as np
from calibration import *
from baseline_util import *
from torch.utils.data import DataLoader
from collections import Counter
import ot
import random
def baseline(net, trainloader, testloader):
    """Estimate target accuracy with the ATC (MC / NE), DoC and DoE baselines.

    ``net`` is run over the labelled source loader to fit two ATC thresholds:
    one on the max class probability ("MC") and one on ``get_entropy`` scores
    ("NE"). It is then run over the target loader, whose labels are ignored,
    to predict target accuracy. DoC and DoE are computed by ``get_DoC`` /
    ``get_DoE`` from the source and target probability matrices.

    Args:
        net: Model passed to ``get_probs``.
        trainloader: Source loader; ``get_probs`` must return its labels.
        testloader: Target loader; only the probabilities are used.

    Returns:
        Tuple ``(ATC_MC, ATC_NE, DoC, DoE)``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Training (source) data.
    probs, labels = get_probs(net, trainloader, device)
    acc = np.mean(np.argmax(probs, axis=-1) == labels) * 100.

    # Calibrate the output probabilities (on training data).
    try:
        calibrator = TempScaling()
        calibrator.fit(inverse_softmax(probs), labels)
    except Exception:
        # Best-effort: fall back to an identity calibrator if fitting fails.
        class Calibration: pass
        calibrator = Calibration()
        calibrator.calibrate = lambda x: x

    # NOTE(review): `calibrator` is fitted above but never applied; the line
    # below is softmax(inverse_softmax(probs)), i.e. essentially the raw
    # probabilities. Confirm whether `calibrator.calibrate(...)` was intended
    # here (and for the target data) before changing it.
    calib_probs = softmax(inverse_softmax(probs))
    calib_pred_idx = np.argmax(calib_probs, axis=-1)
    calib_pred_probs = np.max(calib_probs, axis=-1)
    calib_entropy = get_entropy(calib_probs)
    calib_correct = calib_pred_idx == labels

    # ATC thresholds, fitted on source correctness.
    _, calib_entropy_thres_balance = find_ATC_threshold(calib_entropy, calib_correct)
    _, calib_thres_balance = find_ATC_threshold(calib_pred_probs, calib_correct)

    if acc == 100:  # degrade: no training errors to balance against
        calib_entropy_thres_balance = calib_entropy.min()
        calib_thres_balance = calib_pred_probs.min()
        print("ATC degrade")

    # Test (target) data; labels are not used.
    probs_new, _ = get_probs(net, testloader, device)
    calib_probs_new = softmax(inverse_softmax(probs_new))
    calib_pred_probs_new = np.max(calib_probs_new, axis=-1)
    calib_entropy_new = get_entropy(calib_probs_new)

    # Predicted target accuracy via ATC.
    calib_entropy_pred_balance = get_ATC_acc(calib_entropy_thres_balance, calib_entropy_new)
    calib_pred_balance = get_ATC_acc(calib_thres_balance, calib_pred_probs_new)

    # DoC and DoE.
    DoC = get_DoC(calib_probs, calib_probs_new)
    DoE = get_DoE(calib_probs, calib_probs_new)

    print("DOC:", DoC)
    print("DOE:", DoE)
    print("Predicted Acc by ATC_MC: ", calib_pred_balance)
    print("Predicted Acc by ATC_NE: ", calib_entropy_pred_balance)

    # ATC_MC, ATC_NE, DOC, DOE
    return calib_pred_balance, calib_entropy_pred_balance, DoC, DoE
	

def NeiborhoodInviriance(algorithm, test_dataset, device, n_samples=10, degree=1, batch_size=256):
    """Measure how stable ``algorithm``'s predictions are under perturbation.

    The dataset is perturbed ``n_samples`` times via ``test_dataset.pertur``
    (Monte Carlo). Each sample's score is the fraction of runs that agree
    with its most frequent predicted class. The mean score over all samples
    is printed and returned.

    Args:
        algorithm: Model exposing ``predict(batch_x)``. Its output is
            arg-maxed over dim 1 to get class predictions.
        test_dataset: Object supporting ``len()`` and
            ``pertur(i, degree=...)``. Element ``[1]`` of each collated
            batch is fed to the model (presumably the inputs; confirm
            against the dataset implementation).
        device: Device the inputs are moved to before prediction.
        n_samples: Number of Monte Carlo perturbation runs (default 10).
        degree: Perturbation strength forwarded to ``pertur`` (default 1).
        batch_size: DataLoader batch size (default 256).

    Returns:
        Mean per-sample invariance, in ``[1/n_samples, 1]``, or ``nan``
        when the dataset is empty.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    runs = []  # one list of predicted class ids per Monte Carlo run
    with torch.no_grad():
        for _ in range(n_samples):
            # Rebuilt every run so a stochastic `pertur` yields a new draw.
            perturbed = [test_dataset.pertur(i, degree=degree) for i in range(len(test_dataset))]
            loader = DataLoader(perturbed, batch_size=batch_size, shuffle=False)
            preds = []
            for item in loader:
                batch_x = item[1].to(device)
                result = algorithm.predict(batch_x)
                # Plain ints instead of 0-d tensors: cheap and hashable.
                preds.extend(torch.argmax(result, dim=1).cpu().tolist())
            runs.append(preds)

    if not runs[0]:
        # Empty dataset: avoid NumPy's mean-of-empty warning.
        mean_mu = float("nan")
    else:
        # zip(*runs) yields, for each sample, its predictions across runs.
        # The count of the most common prediction gives the agreement rate.
        mu_list = [
            Counter(column).most_common(1)[0][1] / n_samples
            for column in zip(*runs)
        ]
        mean_mu = np.mean(mu_list)
    print("Neighborhood Invariance: ", mean_mu)
    return mean_mu