import argparse
from model import *
import random
import numpy as np
import os
# NOTE(review): this blanks CURL_CA_BUNDLE for the whole process as soon as
# the module is imported.  It is presumably a workaround so that HTTP
# downloads made through ``requests`` skip TLS certificate verification (an
# empty value has had that effect in some ``requests`` versions) -- confirm it
# is still needed, since it weakens transport security for anything in this
# process that consults the variable.
os.environ.update(CURL_CA_BUNDLE='')
import scipy
import time
import faiss
import numpy as np
from model import *
from utils.options import parser

# Command-line options, parsed once when this module is imported; helpers
# below read hyper-parameters from this global (e.g. ``args.tau``).
# NOTE(review): argparse exits the process on unrecognised flags, so importing
# this module from a script with a different CLI will terminate it.
args = parser.parse_args(args=None)

def eval_quality(features, partial_label_matrix, true_label, epoch, ks=(5, 20, 50, 100, 150, 200)):
    """Print k-nearest-neighbour label-consistency statistics of ``features``.

    Neighbours come from an exact L2 FAISS index over the L2-normalised
    features; the first hit of every query (normally the query itself) is
    dropped.  For every k in ``ks``:

    * ``delta_k``: mean over samples of the fraction of the k nearest
      neighbours whose true label equals the sample's true label.
    * ``rho_k``: per sample, the per-class count of neighbours whose candidate
      row is set for that class, multiplied by
      ``partial_label_matrix - one_hot(true_label)``; the row maximum is
      averaged over samples and divided by k.

    Args:
        features: 2-D tensor of embeddings, one row per sample.
        partial_label_matrix: (num_data, num_class) candidate-label matrix;
            presumably 0/1 -- TODO confirm.
        true_label: integer tensor of ground-truth class indices (it is passed
            to ``F.one_hot``).
        epoch: 0-based epoch index, printed as ``epoch + 1``.
        ks: neighbourhood sizes to evaluate.  Sizes that are not smaller than
            the number of samples are skipped, because FAISS would pad the
            result with ``-1`` indices that silently wrap around when used
            for indexing.

    Returns:
        ``(deltas, rhos)``: two dicts mapping each evaluated k to a float.
    """
    # detach/cpu so the function also works on GPU tensors or tensors that
    # require grad; FAISS needs a contiguous float32 array.
    feats = F.normalize(features.detach(), dim=1).float().cpu().contiguous().numpy()
    num_data = feats.shape[0]
    ks = [k for k in ks if k < num_data]
    if not ks:
        return {}, {}

    true_label_matrix = F.one_hot(true_label, partial_label_matrix.size(1))
    false_partial_label_matrix = partial_label_matrix - true_label_matrix

    index = faiss.IndexFlatL2(feats.shape[1])
    index.add(feats)
    # A single search with the largest k suffices: hits are returned sorted by
    # distance, so the first k columns are the k nearest.  Column 0 is the
    # query's own hit and is discarded.
    _, I = index.search(feats, max(ks) + 1)
    all_neighbours = torch.from_numpy(I[:, 1:]).to(partial_label_matrix.device)

    deltas, rhos = {}, {}
    for k in ks:
        neighbours = all_neighbours[:, :k]
        same_label = (true_label_matrix[neighbours].sum(dim=1) * true_label_matrix).sum(dim=1)
        deltas[k] = (same_label.float().mean() / k).item()
        false_support = (partial_label_matrix[neighbours].sum(dim=1) * false_partial_label_matrix).max(dim=1)[0]
        rhos[k] = (false_support.float().mean() / k).item()

    header = 'Epoch[{:d}]: '.format(epoch + 1)
    print(header + ' '.join('delta_{}:{:.4f}'.format(k, v) for k, v in deltas.items()))
    print(header + ' '.join('rho_{}:{:.4f}'.format(k, v) for k, v in rhos.items()))
    return deltas, rhos

def knn_voting(args, features, pseudo_labels, partial_labels, distance_metric='cosine'):
    """Count neighbour "down votes" against each candidate label of every sample.

    Each of an anchor's k nearest neighbours votes against every class that is
    NOT set in that neighbour's candidate row.  Votes are summed per class and
    kept only where the anchor's own candidate row is set.

    Args:
        args: namespace providing ``k``, the number of neighbours.  It is
            clamped to ``num_data - 1`` because a sample is never its own
            neighbour.
        features: 2-D tensor of embeddings, one row per sample; it is
            L2-normalised here.
        pseudo_labels: not used by the computation; kept so existing callers
            keep working.
        partial_labels: (num_data, num_class) candidate-label matrix;
            presumably 0/1 -- TODO confirm.
        distance_metric: ``'L2'`` (exact FAISS search) or ``'cosine'``
            (dense similarity matrix computed in torch).

    Returns:
        (num_data, num_class) float tensor on ``partial_labels.device``.
        Entry (i, c) is the number of i's neighbours whose candidate row lacks
        class c, multiplied by ``partial_labels[i, c]``.

    Raises:
        ValueError: if ``distance_metric`` is not ``'L2'`` or ``'cosine'``.
    """
    if distance_metric not in ('L2', 'cosine'):
        raise ValueError("distance_metric must be 'L2' or 'cosine', got {!r}".format(distance_metric))

    partial = partial_labels.float()
    num_data = features.shape[0]
    k = max(min(args.k, num_data - 1), 0)
    if k == 0:
        # No neighbours available (or requested): nobody casts a vote.
        return torch.zeros_like(partial)

    feats = F.normalize(features.detach(), dim=1).float().cpu()

    if distance_metric == 'L2':
        feats_np = feats.contiguous().numpy()  # FAISS needs contiguous float32
        index = faiss.IndexFlatL2(feats_np.shape[1])
        index.add(feats_np)
        # The closest hit is normally the query itself: search k + 1, drop column 0.
        _, I = index.search(feats_np, k + 1)
        n_indices = torch.from_numpy(I[:, 1:k + 1])
    else:
        # Rows are unit-norm, so the dot product is the cosine similarity.
        similarity = feats @ feats.T
        # Exclude self-matches so both metrics return k *other* samples.
        similarity.fill_diagonal_(float('-inf'))
        n_indices = similarity.topk(k, dim=1, largest=True, sorted=True).indices

    n_indices = n_indices.to(partial.device)
    # 1 where a class is NOT a candidate (for 0/1 input this equals 1 - partial).
    reversed_partial = (partial + 1) % 2
    # (num_data, k, num_class) -> per-class vote totals, restricted to the
    # anchor's own candidates by the elementwise product.
    down_voting_counts = reversed_partial[n_indices].sum(dim=1) * partial
    return down_voting_counts

    #     #extend n_values[anchor_idx] from (k,) to (k, num_class)
    #     neighbor_weights = n_values[anchor_idx].unsqueeze(dim=1).repeat(1, num_class) # (k, num_class)
    #     # neighbor_voting_labels[anchor_idx] weighted sum of partial_labels[n_idx] and keep the shape (num_class,)
    #     neighbor_voting_labels[anchor_idx] = (partial_labels[n_idx] * neighbor_weights).sum(dim=0) # (k, num_class) * (k, num_class)
    #     # transform the pseudo labels pseudo_labels[n_idx] to one-hot representation
    #     pseudo_labels_onehot = F.one_hot(pseudo_labels[anchor_idx], num_class).float() # (k, num_class)
    #     # neighbor_pseu
    #     neighbor_pseudo_labels[anchor_idx] = 
    #     # normalize the voting results
    #     neighbor_voting_labels[anchor_idx] = neighbor_voting_labels[anchor_idx] / neighbor_voting_labels[anchor_idx].sum(dim=0, keepdim=True) # (num_class,)
    #     # neighbor_candid_labels[anchor_idx] = neighbor_candid_labels[anchor_idx] / neighbor_candid_labels[anchor_idx].sum(dim=0, keepdim=True) # (num_class,)
    
    # voting_pseudo_labels = torch.max(neighbor_voting_labels, dim=1)[1] # (batch_size,)
    
    
    # score_n_matrix = ((1-n_values).unsqueeze(dim=2).repeat(1, 1, num_class) * partial_labels[n_indices]).sum(dim=1) * partial_labels
    #score_n_matrix = partial_label_matrix[n_indices].sum(dim=1) * partial_label_matrix
    #score_n_matrix = reversed_partial_label_matrix[n_indices].sum(dim=1) * partial_label_matrix
    # down_notes = (reversed_partial_label_matrix[n_indices] * partial_label_matrix.unsqueeze(dim=1)).sum(dim=1)
    # down_notes2 = ((reversed_partial_label_matrix[n_indices] * n_values.unsqueeze(dim=2).repeat(1, 1, num_class)) * partial_label_matrix.unsqueeze(dim=1)).sum(dim=1)
    #down_notes -= reversed_partial_label_matrix
    #reversed_score_n_matrix = ((1-n_values).unsqueeze(dim=2).repeat(1, 1, num_class) * partial_label_matrix[n_indices]).sum(dim=1) * reversed_partial_label_matrix #for detect noise
    #score_n_matrix = score_n_matrix / score_n_matrix.sum(dim=1, keepdim=True) # whether normalize




def label_pruning(down_voting_counts, partial_labels, true_label, tau=None):
    """Prune the highest-scoring candidate labels of every sample.

    For a sample with c candidate labels, ``n = ceil((c - 1) * tau)`` is
    computed and every entry whose count is strictly greater than the
    (n + 1)-th largest count of its row is deleted.  At most n entries are
    deleted, and fewer when there are ties at the cut-off.  n is clamped to
    ``[0, num_class - 1]`` so empty rows or ``tau > 1`` cannot index outside
    the row.

    Args:
        down_voting_counts: (num_data, num_class) scores (tensor or
            array-like); entries with the highest counts are pruned.
        partial_labels: (num_data, num_class) candidate-label matrix;
            presumably 0/1 -- TODO confirm.
        true_label: ground truth, either as a (num_data, num_class) one-hot
            matrix or as (num_data,) class indices.
        tau: fraction of removable candidates to prune.  Defaults to the
            module-level ``args.tau``.

    Returns:
        ``(pruned_partial_labels, pruning_accuracy)``.  ``pruning_accuracy`` is
        ``1 - (deleted true-label entries) / num_data``.
    """
    if tau is None:
        tau = args.tau
    down_voting_counts = torch.as_tensor(down_voting_counts, device=partial_labels.device).float()
    num_data, num_class = partial_labels.shape
    if true_label.dim() == 1:
        true_label = F.one_hot(true_label.long(), num_class)

    # Candidates beyond the first one are the only ones eligible for removal.
    num_removable = partial_labels.sum(dim=1) - 1
    num_del = (num_removable * tau).ceil().long().clamp(0, num_class - 1)
    sorted_counts, _ = down_voting_counts.sort(dim=1, descending=True)
    # Per-row cut-off: the (num_del + 1)-th largest count.
    threshold = sorted_counts.gather(1, num_del.unsqueeze(1))
    del_matrix = (down_voting_counts > threshold).float()

    pruned_partial_labels = partial_labels * (1 - del_matrix)
    pruning_error = (true_label * del_matrix).sum().item() / max(num_data, 1)
    return pruned_partial_labels, 1 - pruning_error

def detect_noise(score_n_matrix, reversed_score_n_matrix, partial_label_matrix, true_labels, epoch=None):
    """Print and return precision / recall / F1 of a noisy-sample detector.

    A sample is *selected* as noisy when the row maximum of
    ``score_n_matrix`` is smaller than the row maximum of
    ``reversed_score_n_matrix``.  It is *truly* noisy when its true label is
    not set in ``partial_label_matrix``.

    Args:
        score_n_matrix: (num_data, num_class) scores.
        reversed_score_n_matrix: (num_data, num_class) scores, compared
            row-max against ``score_n_matrix``.
        partial_label_matrix: (num_data, num_class) candidate-label matrix.
        true_labels: integer tensor of ground-truth class indices.
        epoch: 0-based epoch index printed as ``epoch + 1``.  When omitted, the
            module-level loop counter ``i`` is used (legacy behaviour).

    Returns:
        ``(precision, recall, F1_score)`` as floats.  Any ratio whose
        denominator is zero is reported as 0.0 instead of raising.
    """
    if epoch is None:
        epoch = i  # legacy: loop counter of the driver script at module level
    num_class = partial_label_matrix.size(1)
    true_label_matrix = F.one_hot(true_labels, num_class)
    selected_noise_mask = score_n_matrix.max(dim=1)[0] < reversed_score_n_matrix.max(dim=1)[0]
    true_noise_mask = (partial_label_matrix * true_label_matrix).sum(dim=1) == 0

    num_hit = int((selected_noise_mask & true_noise_mask).sum())
    num_selected = int(selected_noise_mask.sum())
    num_true = int(true_noise_mask.sum())
    precision = num_hit / num_selected if num_selected else 0.0
    recall = num_hit / num_true if num_true else 0.0
    F1_score = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    print("Epoch[{:d}]: Precision:{:.4f} recall:{:.4f} F1_score:{:.4f}".format(epoch+1, precision, recall, F1_score))
    return precision, recall, F1_score

def eval_clc(score_n_matrix, del_matrix, partial_label_matrix, true_labels, epoch=None):
    """Print and return quality metrics of a candidate-label-cleaning step.

    * overall_accuracy: fraction of samples whose lowest-scoring *candidate*
      label (by ``score_n_matrix``) equals the true label.
    * del_ratio: sum of ``del_matrix`` divided by the number of removable
      candidates (total candidates minus one per sample).
    * error_rate: fraction of samples whose true label is marked in
      ``del_matrix``.
    * F_beta1/2/3: ``F1_beta_scroe`` with beta 0.5 / 0.2 / 0.1, using
      ``1 - error_rate`` as precision and ``del_ratio`` as recall.

    Args:
        score_n_matrix: (num_data, num_class) scores; lower is preferred.
        del_matrix: (num_data, num_class) deletion mask.
        partial_label_matrix: (num_data, num_class) candidate-label matrix;
            presumably 0/1 -- TODO confirm.
        true_labels: integer tensor of ground-truth class indices.
        epoch: 0-based epoch index printed as ``epoch + 1``.  When omitted, the
            module-level loop counter ``i`` is used (legacy behaviour).

    Returns:
        dict with the printed metrics as floats.
    """
    if epoch is None:
        epoch = i  # legacy: loop counter of the driver script at module level
    num_data = partial_label_matrix.size(0)
    num_class = partial_label_matrix.size(1)
    true_label_matrix = F.one_hot(true_labels, num_class)

    # Non-candidates get +inf so they can never win the arg-min.  The former
    # fixed +100 penalty could be beaten when scores exceed 100.
    masked_scores = score_n_matrix.float().masked_fill(partial_label_matrix == 0, float('inf'))
    predicted = masked_scores.min(dim=1)[1]
    overall_accuracy = (predicted == true_labels).sum().item() / num_data
    error_rate = ((del_matrix * true_label_matrix).sum(dim=1) != 0).sum().item() / num_data
    del_ratio = del_matrix.sum() / (partial_label_matrix.sum() - num_data)

    F_beta_score1 = F1_beta_scroe(0.5, 1-error_rate, del_ratio)
    F_beta_score2 = F1_beta_scroe(0.2, 1-error_rate, del_ratio)
    F_beta_score3 = F1_beta_scroe(0.1, 1-error_rate, del_ratio)

    print("Epoch[{:d}]: overall_accuracy:{:.4f} del_ratio:{:.4f}  error_rate:{:.4f} F_beta1:{:.4f} F_beta2:{:.4f} F_beta3:{:.4f}".format(epoch+1, overall_accuracy, del_ratio , error_rate, F_beta_score1,F_beta_score2,F_beta_score3))
    return {
        'overall_accuracy': float(overall_accuracy),
        'del_ratio': float(del_ratio),
        'error_rate': float(error_rate),
        'F_beta1': float(F_beta_score1),
        'F_beta2': float(F_beta_score2),
        'F_beta3': float(F_beta_score3),
    }


def F1_beta_scroe(beta, precision, recall):
    """Return the F-beta score ``(1 + b^2) * P * R / (b^2 * P + R)``.

    Smaller ``beta`` weights precision more heavily.  For plain Python numbers
    with ``precision == recall == 0`` the score is defined as 0.0 instead of
    raising ZeroDivisionError.  Tensor or array inputs keep their native
    semantics (nan on 0/0).
    """
    try:
        return (1 + beta ** 2) * precision * recall / ((beta ** 2) * precision + recall)
    except ZeroDivisionError:
        return 0.0


# Correctly spelled alias; the original name is kept for existing callers.
F1_beta_score = F1_beta_scroe

# feature_dim_map = {'resnet18_s': 512, 'resnet18_c': 512, 'resnet18_i': 512}


# data_loader, partial_label_matrix, true_labels = loader.run('train')
# reversed_partial_label_matrix = (partial_label_matrix.float() + 1) % 2

# num_data = partial_label_matrix.size(0)
# feature_dim = feature_dim_map[args.model_name]

# all_score_matrix = []
# all_reversed_score_matrix = []
# all_del_matrix = []
# false_partial_label_matrix = partial_label_matrix-true_label_matrix

# for i in range(args.epochs):
#     features = get_feature(model, data_loader)
#     eval_quality(features)
#     score_n_matrix = cal_knn(features)
#     del_matrix = clc(score_n_matrix)
#     eval_clc(score_n_matrix, del_matrix)









