#!/usr/bin/env python
# coding=utf-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.spatial.distance import cdist


def get_all_info(feas, outputs, labels, start_test,
                 all_fea=None, all_output=None, all_label=None):
    """Append one batch of features/outputs/labels to running accumulators.

    The original version could never take its ``start_test=False`` branch: the
    accumulators it concatenated onto were not parameters, so that branch always
    raised ``UnboundLocalError``. Callers now pass the running tensors back in.

    Args:
        feas: batch features; cast to float and moved to the CPU.
        outputs: batch outputs; cast to float and moved to the CPU.
        labels: batch labels; cast to float.
        start_test: when True the batch starts new accumulators and any passed
            accumulators are ignored.
        all_fea, all_output, all_label: running tensors returned by the previous
            call; required when ``start_test`` is False.

    Returns:
        ``(all_fea, all_output, all_label)`` with this batch concatenated along dim 0.

    Raises:
        ValueError: if ``start_test`` is False but an accumulator is missing.
    """
    feas = feas.float().cpu()
    outputs = outputs.float().cpu()
    labels = labels.float()
    if start_test:
        return feas, outputs, labels
    if all_fea is None or all_output is None or all_label is None:
        raise ValueError(
            "start_test=False requires all_fea, all_output and all_label "
            "from the previous call")
    return (torch.cat((all_fea, feas), 0),
            torch.cat((all_output, outputs), 0),
            torch.cat((all_label, labels), 0))


def supervised_acc_pred(all_output, all_label):
    """Score arg-max predictions of softmaxed outputs against the given labels.

    Args:
        all_output: 2-D tensor of per-class scores; softmax is taken over dim 1.
        all_label: 1-D tensor of ground-truth class indices.

    Returns:
        ``(prob_value, predict, accuracy)``: the highest softmax probability per
        row, its class index, and the fraction of rows whose index equals the label.
    """
    probabilities = F.softmax(all_output, dim=1)
    prob_value, predict = probabilities.max(dim=1)
    num_correct = (predict.squeeze().float() == all_label).sum().item()
    accuracy = num_correct / float(all_label.size(0))
    return prob_value, predict, accuracy

def kmeans_culstering(args, all_fea, all_output, predict, num_rounds=5):
    """Refine class predictions with prototype (centroid) clustering.

    Initial centroids are the ``all_output``-weighted means of the features. Only
    classes that ``predict`` assigns more than ``args.threshold`` samples to are
    kept as candidate centroids. Every sample is then assigned to its nearest
    candidate centroid. This is followed by ``num_rounds`` rounds of recomputing
    hard-assignment means and re-assigning.

    Args:
        args: namespace read for ``distance``, a ``scipy.spatial.distance.cdist``
            metric name. ``'cosine'`` additionally appends a column of ones and
            L2-normalises each row. ``threshold`` is the minimum class count,
            exclusive.
        all_fea: (N, D) feature tensor.
        all_output: (N, K) per-class affinity weights; K is the number of classes.
        predict: (N,) integer class predictions, used only to count class frequency.
        num_rounds: number of hard re-assignment rounds (default 5).

    Returns:
        np.ndarray of shape (N,) with the refined class index of every sample.

    Raises:
        ValueError: if no class is predicted more than ``args.threshold`` times.
    """
    if args.distance == 'cosine':
        # Append a constant column, then L2-normalise each row.
        all_fea = torch.cat((all_fea, all_fea.new_ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()

    fea = all_fea.float().cpu().numpy()   # (N, D), or (N, D + 1) for cosine
    num_classes = all_output.size(1)
    aff = all_output.float().cpu().numpy()   # (N, K)

    # Soft initial centroids: per-class weighted mean of the features, shape (K, D').
    centroids = aff.T.dot(fea) / (1e-8 + aff.sum(axis=0)[:, None])

    cls_count = np.bincount(np.asarray(predict).reshape(-1), minlength=num_classes)
    labelset = np.where(cls_count > args.threshold)[0]
    if labelset.size == 0:
        raise ValueError(
            f"no class is predicted more than args.threshold={args.threshold} "
            "times; cannot select centroids")

    # argmin indexes into labelset, so map it back to real class ids.
    pred_label = labelset[cdist(fea, centroids[labelset], args.distance).argmin(axis=1)]
    for _ in range(num_rounds):
        hard_aff = np.eye(num_classes)[pred_label]   # one-hot (N, K)
        centroids = hard_aff.T.dot(fea) / (1e-8 + hard_aff.sum(axis=0)[:, None])
        pred_label = labelset[cdist(fea, centroids[labelset], args.distance).argmin(axis=1)]

    return pred_label


def dynamic_kmeans_culstering(args, all_fea, all_output, predict,
                              num_rounds=5, ema_momentum=0.9):
    """Prototype clustering whose centroids are smoothed with an exponential moving average.

    Candidate classes and the soft initial centroids are chosen the same way as in
    plain prototype k-means. Each round computes hard-assignment means and blends
    them into the EMA centroids as
    ``ema = ema_momentum * ema + (1 - ema_momentum) * hard``. Samples are then
    re-assigned to the nearest EMA centroid.

    NOTE(review): the previous body read ``initc_ema`` before ever assigning it,
    so it always raised UnboundLocalError. It also never used the EMA for
    assignment. Here the EMA is seeded with the soft initial centroids and used
    for the distance step. Confirm this matches the intended algorithm.

    Args:
        args: namespace read for ``distance``, a ``cdist`` metric name; ``'cosine'``
            also appends a ones column and L2-normalises rows. ``threshold`` is
            the minimum class count, exclusive.
        all_fea: (N, D) feature tensor.
        all_output: (N, K) per-class affinity weights; K is the number of classes.
        predict: (N,) integer class predictions, used only to count class frequency.
        num_rounds: number of EMA refinement rounds (default 5).
        ema_momentum: weight kept from the previous EMA centroids (default 0.9).

    Returns:
        np.ndarray of shape (N,) with the refined class index of every sample.

    Raises:
        ValueError: if no class is predicted more than ``args.threshold`` times.
    """
    if args.distance == 'cosine':
        # Append a constant column, then L2-normalise each row.
        all_fea = torch.cat((all_fea, all_fea.new_ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()

    fea = all_fea.float().cpu().numpy()   # (N, D), or (N, D + 1) for cosine
    num_classes = all_output.size(1)
    aff = all_output.float().cpu().numpy()   # (N, K)

    # Soft initial centroids: per-class weighted mean of the features.
    centroids = aff.T.dot(fea) / (1e-8 + aff.sum(axis=0)[:, None])

    cls_count = np.bincount(np.asarray(predict).reshape(-1), minlength=num_classes)
    labelset = np.where(cls_count > args.threshold)[0]
    if labelset.size == 0:
        raise ValueError(
            f"no class is predicted more than args.threshold={args.threshold} "
            "times; cannot select centroids")

    pred_label = labelset[cdist(fea, centroids[labelset], args.distance).argmin(axis=1)]
    ema_centroids = centroids
    for _ in range(num_rounds):
        hard_aff = np.eye(num_classes)[pred_label]   # one-hot (N, K)
        hard_centroids = hard_aff.T.dot(fea) / (1e-8 + hard_aff.sum(axis=0)[:, None])
        ema_centroids = ema_momentum * ema_centroids + (1.0 - ema_momentum) * hard_centroids
        pred_label = labelset[
            cdist(fea, ema_centroids[labelset], args.distance).argmin(axis=1)]

    return pred_label


"""
Implemented with mixed dynamic-label supervision and mixed-up labels.
"""
def _collect_branch_outputs(loader, netF, netB, netC):
    """Run both network branches over every batch of ``loader`` and stack the results.

    ``netF(inputs)`` is unpacked as ``(att_info, neck_feat)``. ``att_info[0]`` is
    row-normalised with ``F.normalize`` and ``att_info[1]`` is kept as-is.
    ``neck_feat`` is passed through ``netB`` and then ``netC``. Everything is cast
    to float and moved to the CPU. Batches are concatenated once at the end rather
    than re-concatenating the running tensors at every step.

    Returns:
        Tuple ``(att_feas, att_outputs, neck_feas, neck_outputs, labels)``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    att_feas, att_outputs, neck_feas, neck_outputs, labels = [], [], [], [], []
    with torch.no_grad():
        for data in loader:
            # Inputs are moved to CUDA explicitly, so a GPU is required here.
            att_info, neck_feat = netF(data[0].cuda())
            att_feas.append(F.normalize(att_info[0]).float().cpu())
            att_outputs.append(att_info[1].float().cpu())
            neck = netB(neck_feat)
            neck_feas.append(neck.float().cpu())
            neck_outputs.append(netC(neck).float().cpu())
            labels.append(data[1].float())
    if not labels:
        raise ValueError("loader yielded no batches")
    return tuple(torch.cat(chunks, 0)
                 for chunks in (att_feas, att_outputs, neck_feas, neck_outputs, labels))


def _relabel_disagreements_by_knn(neck_pred_label, agreed_idx, disagreed_idx, neck_feas, k):
    """Give each disagreed sample the most common label among its k most similar agreed samples.

    Similarity is the raw dot product of ``neck_feas`` rows; no explicit
    normalisation is applied here. Returns a copy of ``neck_pred_label`` with the
    disagreed positions overwritten. Requires ``len(agreed_idx) >= k``.
    """
    consistent = neck_pred_label.copy()
    if disagreed_idx.size == 0:
        return consistent
    agreed_feats = neck_feas[torch.from_numpy(agreed_idx)]
    disagreed_feats = neck_feas[torch.from_numpy(disagreed_idx)]
    similarity = disagreed_feats @ agreed_feats.T   # (num_disagreed, num_agreed)
    _, idx_near = torch.topk(similarity, k=k, dim=-1, largest=True)
    near_labels = neck_pred_label[agreed_idx[idx_near.numpy()]]   # (num_disagreed, k)
    nearest_labels, _ = torch.mode(torch.from_numpy(near_labels), dim=-1)
    consistent[disagreed_idx] = nearest_labels.numpy()
    return consistent


def clean_obtain_dynamic_labels(loader, netF, netB, netC, args, log):
    """Build per-sample pseudo-labels by reconciling the attention and neck branches.

    Each branch's outputs are refined with ``kmeans_culstering``. Samples on which
    both branches' cluster labels agree are "agreed". If more than ``args.K``
    samples agree, every disagreed sample gets the majority neck-branch label of
    its ``args.K`` most similar agreed samples (dot product of neck features).
    Otherwise the attention-branch cluster labels are returned unchanged.

    Args:
        loader: iterable of batches where ``batch[0]`` is the input and
            ``batch[1]`` the labels.
        netF: returns ``(att_info, neck_feat)``. ``att_info[0]`` is used as the
            attention features and ``att_info[1]`` as the attention outputs.
        netB: applied to ``neck_feat`` to produce the neck features.
        netC: applied to the neck features to produce the neck outputs.
        args: ``K`` is read here; ``distance`` and ``threshold`` are read by the
            clustering step.
        log: currently unused.

    Returns:
        ``(agreed_idx, disagreed_idx, labels)``. The two index arrays are int64
        (also when empty). ``labels`` is an int array with one label per sample.
    """
    (all_goan_fea, all_goan_output, all_neck_feas,
     all_neck_outputs, all_label) = _collect_branch_outputs(loader, netF, netB, netC)

    _, attent_predict, _ = supervised_acc_pred(all_goan_output, all_label)
    _, neck_predict, _ = supervised_acc_pred(all_neck_outputs, all_label)

    # NOTE(review): supervised_acc_pred softmaxes these same outputs, which suggests
    # they are logits. exp() of logits gives unnormalised weights and can overflow.
    # It only equals the probabilities if the networks emit log-probabilities.
    # Confirm which is intended.
    attent_pred_label = kmeans_culstering(
        args, all_goan_fea, torch.exp(all_goan_output), attent_predict)
    neck_pred_label = kmeans_culstering(
        args, all_neck_feas, torch.exp(all_neck_outputs), neck_predict)

    agree = attent_pred_label == neck_pred_label
    agreed_idx = np.flatnonzero(agree)
    disagreed_idx = np.flatnonzero(~agree)

    if agreed_idx.size > args.K:
        labels = _relabel_disagreements_by_knn(
            neck_pred_label, agreed_idx, disagreed_idx, all_neck_feas, args.K)
    else:
        # Too few agreed samples to draw K neighbours from.
        labels = attent_pred_label
    return agreed_idx, disagreed_idx, labels.astype('int')



if __name__ == '__main__':
    # Nothing runs when this file is executed directly; it only defines helpers.
    pass
