import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import sys, os
from random import sample
from numpy.random import uniform
from sklearn.neighbors import NearestNeighbors
import numpy as np 
import math
import os.path as osp 

# Device used for tensors in this module: the GPU when CUDA is available,
# otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def uniformity(weight):
    """Measure how far a set of weight vectors is from an equiangular layout.

    Every row of ``weight`` is treated as one vector (tensors with more than
    two dimensions are flattened to ``(weight.size(0), -1)``).  The rows are
    L2-normalised, the angle between every ordered pair of distinct rows is
    computed, and the mean squared difference between those angles and the
    reference angle ``acos(-1 / (n - 1))`` is returned, where ``n`` is the
    number of rows.  That reference is the pairwise angle of ``n`` directions
    forming a regular simplex, so a value of 0 means the rows are maximally
    and evenly spread.

    Args:
        weight: tensor with at least two dimensions and at least two rows
            along dimension 0.

    Returns:
        The mean squared angular deviation as a Python ``float``.

    Raises:
        ValueError: if ``weight`` has fewer than two rows (the reference
            angle is undefined in that case).
    """
    if weight.dim() > 2:
        # reshape (not view) so that non-contiguous tensors are accepted too.
        weight = weight.reshape(weight.size(0), -1)

    n = weight.size(0)
    if n < 2:
        raise ValueError(
            "uniformity() needs at least two weight vectors, got %d" % n
        )

    # Only a scalar is returned, so there is no need to record a graph.
    with torch.no_grad():
        unit_rows = F.normalize(weight, p=2, dim=1)
        cosine = torch.matmul(unit_rows, unit_rows.t())

        # Remove the diagonal (self-similarity): after flattening and dropping
        # the last element, viewing the data as (n-1, n+1) puts every diagonal
        # entry in column 0, which is then sliced away.
        cosine = cosine.flatten()[:-1].view(n - 1, n + 1)[:, 1:]

        # Rounding can push a cosine marginally outside [-1, 1] (e.g. for
        # duplicated rows); acos would then return NaN and poison the mean.
        cosine = cosine.clamp(-1.0, 1.0)

        theta = torch.acos(cosine)
        theta0 = torch.acos(
            torch.tensor(-1.0 / (n - 1), dtype=theta.dtype, device=theta.device)
        )
        deviation = theta - theta0
        u2_loss = torch.mean(deviation * deviation)

    return u2_loss.item()

def _hopkins_statistic(X, sample_ratio=0.05):
    """Estimate the Hopkins statistic of the rows of the 2-D array ``X``.

    ``u`` distances are measured from points drawn uniformly inside the
    bounding box of ``X`` to their nearest row of ``X``; ``w`` distances are
    measured from randomly chosen rows of ``X`` to their nearest *other* row.
    The statistic is ``sum(u) / (sum(u) + sum(w))``.

    Args:
        X: NumPy array of shape ``(n_samples, n_features)`` with at least two
            rows (two neighbours are queried).
        sample_ratio: fraction of the rows used as probe points.

    Returns:
        The Hopkins statistic as a NumPy floating-point scalar.
    """
    n_samples = X.shape[0]
    # 0.05 (5%) based on the paper by Lawson and Jurs.  At least one probe is
    # always drawn so that inputs with fewer than 20 rows do not end up with
    # an empty sample.
    sample_size = max(1, int(n_samples * sample_ratio))

    X_uniform_random_sample = uniform(
        X.min(axis=0), X.max(axis=0), (sample_size, X.shape[1])
    )
    random_indices = sample(range(n_samples), sample_size)
    X_sample = X[random_indices]

    nbrs = NearestNeighbors(n_neighbors=2).fit(X)

    # Synthetic points are not part of X, so their closest neighbour is the
    # first column.
    u_distances, _ = nbrs.kneighbors(X_uniform_random_sample, n_neighbors=1)
    u_distances = u_distances[:, 0]

    # Real points are part of X, so the second column is taken to skip the
    # query point itself.
    w_distances, _ = nbrs.kneighbors(X_sample, n_neighbors=2)
    w_distances = w_distances[:, 1]

    u_sum = np.sum(u_distances)
    w_sum = np.sum(w_distances)
    return u_sum / (u_sum + w_sum)


def Score(train_target_iter, classifier, num_classes):
    """Print (and return) an unsupervised score for every saved checkpoint.

    For each file found in ``weight/<args.data>/<S>2<T>`` (``S`` and ``T``
    being the first elements of ``args.source`` and ``args.target``) the
    checkpoint is loaded into ``classifier`` and three quantities are
    combined::

        score = H - mean(im_loss) / log(num_classes) - u

    * ``u``: ``uniformity`` of the last parameter of ``classifier.head``
      whose name contains ``"weight"`` (0 if there is none).
    * ``im_loss``: per batch, mean prediction entropy minus the entropy of
      the batch-averaged prediction.
    * ``H``: Hopkins statistic of the features collected over all batches.

    NOTE(review): ``args`` and ``entropy`` are not defined in the visible
    part of this file; they are assumed to be provided at module level
    elsewhere -- confirm.

    Args:
        train_target_iter: iterator supporting ``len()`` and ``next()``;
            element 0 of every item is the input batch.
        classifier: module that exposes ``head`` and, when called with
            ``require_feature=True``, returns ``(logits, features)``.
            ``features`` is presumably 2-D ``(batch, dim)`` -- TODO confirm.
        num_classes: number of classes, used to normalise the entropy term.

    Returns:
        Dict mapping each checkpoint file name to its score as a ``float``.
        Each score is also printed, one per line, in sorted file-name order.

    Raises:
        ValueError: if ``train_target_iter`` reports a length of zero.
    """
    weight_path = osp.join('weight', args.data, args.source[0]+'2'+args.target[0])
    scores = {}

    # Sorted so that runs are reproducible: os.listdir order is arbitrary and
    # the data iterator keeps advancing from one checkpoint to the next.
    for weight in sorted(os.listdir(weight_path)):
        save_path = osp.join(weight_path, weight)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location lets GPU-saved checkpoints load on CPU-only hosts.
        classifier.load_state_dict(torch.load(save_path, map_location=device))
        classifier.eval()

        iter_num = len(train_target_iter)
        if iter_num == 0:
            raise ValueError("train_target_iter yields no batches")

        im_loss = []
        feature_batches = []
        u = 0

        ###calculate u###
        for k, v in classifier.head.named_parameters():  # fc_weight fc_bias
            if "weight" in k:
                u = uniformity(v)

        ###calculate m###
        with torch.no_grad():
            for _ in range(iter_num):
                x_t = next(train_target_iter)[0]
                x_t = x_t.to(device)
                y, output_f = classifier(x_t, require_feature=True)
                # Collected in a list and concatenated once: no size-based
                # sentinel (which dropped a first batch of size 1) and no
                # quadratic re-copying.  Moved to CPU to free GPU memory.
                feature_batches.append(output_f.detach().cpu())

                softmax_out = nn.Softmax(dim=1)(y)
                entropy_loss = torch.mean(entropy(softmax_out))

                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + 1e-6))
                entropy_loss -= gentropy_loss
                im_loss.append(entropy_loss.item())

        ###calculate h###
        X = torch.cat(feature_batches, dim=0).numpy()
        H = _hopkins_statistic(X)

        ###score###
        score = H-sum(im_loss)/len(im_loss)/math.log(num_classes)-u
        print(score)
        scores[weight] = float(score)

    return scores
