
import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error,f1_score,accuracy_score
from sklearn.metrics import normalized_mutual_info_score
from sklearn.metrics import adjusted_rand_score
from sklearn import linear_model
import pdb

from sklearn.preprocessing import normalize
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
from random import sample
import random
import os
from sklearn.svm import LinearSVC
#from sklearnex import patch_sklearn, unpatch_sklearn

from torch.autograd import Variable


"""Function used for Orthogonal Regularization"""
def l2_reg_ortho(mdl):
    """Compute an orthogonality penalty over the weight tensors of ``mdl``.

    Every parameter with at least two dimensions is flattened into a
    ``(rows, cols)`` matrix. The smaller of its two Gram matrices is compared
    with the identity, the residual is applied twice to a random vector
    (with a normalisation in between), and the squared 2-norm of the result
    is accumulated over all such parameters.

    All helper tensors are placed on the device of the parameter they are
    combined with, so the function works on CPU as well as on CUDA models.

    Args:
        mdl: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        A scalar tensor holding the accumulated penalty, or ``None`` when
        ``mdl`` has no parameter with two or more dimensions.
    """
    l2_reg = None
    for W in mdl.parameters():
        # Biases and other 1-D parameters carry no orthogonality constraint.
        if W.ndimension() < 2:
            continue

        cols = W[0].numel()
        rows = W.shape[0]
        w1 = W.view(-1, cols)
        wt = torch.transpose(w1, 0, 1)

        # Use the smaller Gram matrix: (cols x cols) for tall matrices,
        # (rows x rows) otherwise.
        if rows > cols:
            m = torch.matmul(wt, w1)
            size = cols
        else:
            m = torch.matmul(w1, wt)
            size = rows

        ident = torch.eye(size, size, device=W.device)
        w_tmp = m - ident

        # Sample on the CPU generator (as before) and move to W's device.
        b_k = torch.rand(w_tmp.shape[1], 1).to(W.device)

        v1 = torch.matmul(w_tmp, b_k)
        norm1 = torch.norm(v1, 2)
        v2 = torch.div(v1, norm1)
        v3 = torch.matmul(w_tmp, v2)

        penalty = torch.norm(v3, 2) ** 2
        l2_reg = penalty if l2_reg is None else l2_reg + penalty
    return l2_reg


def write_csv(data,output,normal,half):
    """Dump per-key result rows to a text file in a simple CSV-like layout.

    For each key in ``data`` the key itself is written on its own line.
    ``data[key][0]`` is written as a comma-separated row under a ``mean``
    header when ``normal`` is truthy, and ``data[key][1]`` under a ``std``
    header when ``half`` is truthy.

    Args:
        data: Mapping from string keys to indexable pairs of iterables.
        output: Path of the file to (over)write.
        normal: Whether to emit the ``mean`` section.
        half: Whether to emit the ``std`` section.
    """
    with open(output, "w") as out:
        for key in data:
            out.write(key + "\n")
            # Both rows are always built, even when a section is skipped.
            mean_row = ",".join(map(str, data[key][0]))
            std_row = ",".join(map(str, data[key][1]))
            if normal:
                out.write("mean \n" + mean_row + "\n")
            if half:
                out.write("std \n" + std_row + "\n")



def load_dataset(file_name):
    """Load an experiment bundle stored as a pickled object in a ``.npy`` file.

    The file is read with ``np.load(..., allow_pickle=True)`` and unwrapped
    with ``.item()``; the result is indexed by the keys ``'God_Embedding'``,
    ``'Tasks'`` and ``'Mlt_Data_Cases'``.

    Args:
        file_name: Path (or file-like object) accepted by ``np.load``.

    Returns:
        Tuple ``(God_Embedding, Tasks, Mlt_Data_Cases)``. ``God_Embedding``
        is ``None`` when that key is absent.

    Raises:
        KeyError: If ``'Tasks'`` or ``'Mlt_Data_Cases'`` is missing.
    """
    # SECURITY: allow_pickle=True unpickles arbitrary objects and can execute
    # code embedded in the file; only load files from trusted sources.
    Experiment_Data = np.load(file_name, allow_pickle=True).item()

    # 'God_Embedding' is optional; only a missing key is tolerated so that
    # unrelated failures are not silently hidden.
    try:
        God_Embedding = Experiment_Data['God_Embedding']
    except KeyError:
        God_Embedding = None

    Tasks = Experiment_Data['Tasks']
    Mlt_Data_Cases = Experiment_Data['Mlt_Data_Cases']

    return God_Embedding, Tasks, Mlt_Data_Cases





def Initialize_Seed(seed=2):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also disables cuDNN benchmark mode and stores the seed (as a string) in
    the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Value handed to every seeding routine (default ``2``).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    os.environ.update(PYTHONHASHSEED=str(seed))


def evaluation_metrics_reg(y_pred, y_test):
    """Score regression predictions against ground truth.

    Args:
        y_pred: Predicted values.
        y_test: Ground-truth values.

    Returns:
        Tuple ``(mae, rmse, r2)``: mean absolute error, square root of the
        mean squared error, and the R^2 score.
    """
    scorers = (mean_absolute_error, mean_squared_error, r2_score)
    abs_err, sq_err, r2 = [scorer(y_test, y_pred) for scorer in scorers]
    return abs_err, np.sqrt(sq_err), r2


def regression(X_train, Y_train, X_test):
    """Fit a ridge regressor (``alpha=1``) and predict on held-out features.

    Args:
        X_train: Training features.
        Y_train: Training targets.
        X_test: Features to predict for.

    Returns:
        Predictions of the fitted model for ``X_test``.
    """
    # An earlier note in this file mentions shapes such as 1000x48 train and
    # 705x48 test features -- presumably typical embedding sizes; not enforced.
    model = linear_model.Ridge(alpha=1)
    return model.fit(X_train, Y_train).predict(X_test)



def kf_regression(X, Y):
    """Run 5-fold cross-validated regression and collect out-of-fold outputs.

    The folds are shuffled with a fixed ``random_state=0``. ``X`` and ``Y``
    are indexed with the index arrays produced by the splitter
    (assumes NumPy-style fancy indexing -- TODO confirm against callers).

    Args:
        X: Feature matrix.
        Y: Target vector aligned with ``X``.

    Returns:
        Tuple ``(predictions, truths)``, each the concatenation of the
        per-fold arrays in fold order (not the original sample order).
    """
    folds = KFold(n_splits=5, shuffle=True, random_state=0).split(X)
    per_fold = [
        (regression(X[fit_idx], Y[fit_idx], X[eval_idx]), Y[eval_idx])
        for fit_idx, eval_idx in folds
    ]
    predicted, actual = zip(*per_fold)
    return np.concatenate(predicted), np.concatenate(actual)

def predict_regression(emb, label):
    """Evaluate embeddings on a regression task and return the R^2 score.

    Re-seeds the random generators, flattens ``label`` to one dimension,
    runs the cross-validated regression and scores its predictions.

    Args:
        emb: Feature matrix passed to the cross-validated regression.
        label: Targets; converted with ``np.array`` and flattened.

    Returns:
        The R^2 score of the out-of-fold predictions.
    """
    Initialize_Seed()
    targets = np.array(label).reshape(-1)
    predicted, actual = kf_regression(emb, targets)
    _, _, r2 = evaluation_metrics_reg(predicted, actual)
    return r2


def evaluation_metrics_cls(y_pred, y_test):
    """Return the unweighted mean of the per-class F1 scores.

    Args:
        y_pred: Predicted class labels.
        y_test: Ground-truth class labels.

    Returns:
        Mean of the F1 scores computed separately for each class.
    """
    # average=None yields one F1 score per class; they are averaged below.
    per_class_f1 = f1_score(y_true=y_test, y_pred=y_pred, average=None)
    return np.mean(per_class_f1)


def classification(X_train, Y_train, X_test):
    """Fit a linear SVM classifier and predict labels for held-out features.

    The classifier is created with ``random_state=42``.

    Args:
        X_train: Training features.
        Y_train: Training labels.
        X_test: Features to predict for.

    Returns:
        Predicted labels for ``X_test``.
    """
    model = LinearSVC(random_state=42)
    return model.fit(X_train, Y_train).predict(X_test)


def kf_cls(X, Y):
    """Run 5-fold cross-validated classification and collect the outputs.

    The folds are shuffled with a fixed ``random_state=0``. The wall-clock
    time spent on each fold is printed to stdout.

    Args:
        X: Feature matrix (indexed with the splitter's index arrays).
        Y: Label vector aligned with ``X``.

    Returns:
        Tuple ``(predictions, truths)``, each the concatenation of the
        per-fold arrays in fold order (not the original sample order).
    """
    import time

    splitter = KFold(n_splits=5, shuffle=True, random_state=0)
    predicted, actual = [], []
    for fit_idx, eval_idx in splitter.split(X):
        started = time.time()
        fold_pred = classification(X[fit_idx], Y[fit_idx], X[eval_idx])
        predicted.append(fold_pred)
        actual.append(Y[eval_idx])
        # Report how long this fold took.
        print(time.time() - started)
    return np.concatenate(predicted), np.concatenate(actual)

def predict_cls(emb, label):
    """Evaluate embeddings on a classification task and return the mean F1.

    Re-seeds the random generators, flattens ``label`` to one dimension,
    runs the cross-validated classification and scores its predictions.

    Args:
        emb: Feature matrix passed to the cross-validated classifier.
        label: Class labels; converted with ``np.array`` and flattened.

    Returns:
        The class-averaged F1 score of the out-of-fold predictions.
    """
    Initialize_Seed()
    targets = np.array(label).reshape(-1)
    predicted, actual = kf_cls(emb, targets)
    return evaluation_metrics_cls(predicted, actual)

from PIL import Image
import torchvision.transforms.functional as F
# Transform which randomly corrupts pixels: each pixel is corrupted with probability 1 - p (p >= 1 leaves the image unchanged)
class PixelCorruption(object):
    """Transform that randomly corrupts pixels of an image.

    A boolean mask is drawn over two axes of the image and, for inputs with
    more than two dimensions, repeated along the leading axis so every index
    of that axis is corrupted at the same positions. Each position is
    selected with probability ``1 - p``; when ``p >= 1`` nothing is selected
    and the image is returned unchanged.

    Selected pixels are either set to ``min`` (mode ``'drop'``) or replaced
    by ``max - value`` (mode ``'flip'``).

    Note: tensor inputs are modified in place and also returned.

    Args:
        p: Controls the corruption rate; pixels are corrupted with
            probability ``1 - p``.
        min: Value written to dropped pixels.
        max: Value from which flipped pixels are subtracted.
        mode: One of ``MODALITIES`` (``'flip'`` or ``'drop'``).
    """
    MODALITIES = ['flip', 'drop']

    def __init__(self, p, min=0, max=1, mode='drop'):
        super(PixelCorruption, self).__init__()

        assert mode in self.MODALITIES

        self.p = p
        self.min = min
        self.max = max
        self.mode = mode

    def __call__(self, im):
        """Corrupt ``im`` and return it as a tensor.

        Args:
            im: A tensor, or a PIL image / NumPy array (converted with
                ``to_tensor`` first).

        Returns:
            The corrupted tensor.
        """
        # Tensors need no conversion; only PIL images and ndarrays do.
        if not isinstance(im, torch.Tensor) and isinstance(im, (Image.Image, np.ndarray)):
            im = F.to_tensor(im)

        has_leading_axis = len(im.size()) > 2

        # For inputs with a leading axis the mask spans axes 1 and 2; a plain
        # 2-D input is masked over its own two axes (previously this case
        # raised IndexError before the rank check could be reached).
        if has_leading_axis:
            mask_shape = (im.size(1), im.size(2))
        else:
            mask_shape = (im.size(0), im.size(1))

        if self.p < 1.0:
            # Each position is selected with probability 1 - p.
            mask = torch.bernoulli(torch.zeros(*mask_shape) + 1. - self.p).bool()
        else:
            mask = torch.zeros(*mask_shape).bool()

        if has_leading_axis:
            # Share the same mask across the leading axis.
            mask = mask.unsqueeze(0).repeat(im.size(0), 1, 1)

        if self.mode == 'flip':
            im[mask] = self.max - im[mask]
        elif self.mode == 'drop':
            im[mask] = self.min

        return im