import numpy as np

import torch
from sklearn.decomposition import PCA
import pdb
import random
from random import sample
import os
from utils import Initialize_Seed
import torch.nn as nn

from torchvision.transforms import Compose, RandomAffine, ToTensor
from torch.utils.data import DataLoader

from utils import PixelCorruption


import torch
import torchvision
from torchvision import datasets, transforms
from PIL import Image
import os
import pdb
import random
from scipy.io import loadmat
from random import shuffle

import math
import numpy as np
from numpy.random import randint
from sklearn.preprocessing import OneHotEncoder


def Normalize(data):
    """Center data by its global mean and scale it by its global value range.

    The statistics are taken over the whole array (no axis), not per feature.

    :param data: input data (array-like)
    :return: ``(data - mean) / (max - min)``; if every value is equal the range
        is zero, so the centered data (all zeros) is returned instead of NaN/inf
    """
    m = np.mean(data)
    value_range = np.max(data) - np.min(data)
    if value_range == 0:
        # Constant input: dividing by zero would produce NaN (0/0) everywhere.
        return data - m
    return (data - m) / value_range


def read_data(str_name, ratio=0.8, Normal=0):
    """Read a multi-view .mat dataset and split it into train and test sets per class.

    Each class is shuffled independently (``random.shuffle``) and its first
    ``floor(n_c * ratio)`` samples go to the training set, so both splits keep
    the class proportions. Samples are located by label, so the file does not
    need to be sorted by class.

    :param str_name: path of a .mat file holding ``X``, a 1 x V cell array with
        one (features x samples) matrix per view, and ``gt``, a column of labels
    :param ratio: fraction of each class used for training
    :param Normal: if 1, apply ``Normalize`` to every view of both splits
        (train and test are each normalized with their own statistics)
    :return: ``(X_train, X_test, labels_train, labels_test)``. ``X_train`` and
        ``X_test`` are lists with one (samples x features) array per view, and
        the labels are (n, 1) arrays. Labels are shifted by +1 when ``gt``
        starts at 0, so classes are numbered from 1.
    """
    data = loadmat(str_name)
    view_number = data['X'].shape[1]
    # Views are stored features x samples; transpose once so rows are samples.
    views = [data['X'][0, v_num].transpose() for v_num in range(view_number)]

    labels = data['gt']
    if labels.min() == 0:
        labels = labels + 1
    # int() so range() also works when gt was saved as a double matrix.
    classes = int(labels.max())
    flat_labels = labels.ravel()

    # Start with empty index arrays so concatenation works even with no classes.
    train_idx = [np.empty(0, dtype=np.intp)]
    test_idx = [np.empty(0, dtype=np.intp)]
    for c_num in range(1, classes + 1):
        class_idx = np.flatnonzero(flat_labels == c_num)
        c_length = len(class_idx)
        order = np.arange(c_length)
        shuffle(order)
        class_idx = class_idx[order]
        n_train = math.floor(c_length * ratio)
        train_idx.append(class_idx[:n_train])
        test_idx.append(class_idx[n_train:])
    train_idx = np.concatenate(train_idx)
    test_idx = np.concatenate(test_idx)

    X_train = [view[train_idx] for view in views]
    X_test = [view[test_idx] for view in views]

    if Normal == 1:
        X_train = [Normalize(view) for view in X_train]
        X_test = [Normalize(view) for view in X_test]

    return X_train, X_test, labels[train_idx].reshape(-1, 1), labels[test_idx].reshape(-1, 1)
def get_sn(view_num, alldata_len, missing_rate):
    """Randomly generate a view-availability mask that simulates incomplete multi-view data.

    Behaviour depends on ``one_rate = 1 - missing_rate``:

    * ``one_rate <= 1 / view_num``: exactly one random view per sample is kept
      (returned as a float64 one-hot matrix).
    * ``one_rate == 1``: every view of every sample is available.
    * otherwise: each sample keeps one random view plus extra random views,
      resampled until the overall fraction of 1s is within 0.005 of ``one_rate``.

    :param view_num: number of views
    :param alldata_len: number of samples
    :param missing_rate: missing rate as defined in section 3.2 of the paper
    :return: Sn, an (alldata_len, view_num) 0/1 matrix; Sn[n, v] == 1 means
        view v of sample n is available
    """
    one_rate = 1 - missing_rate
    if one_rate <= (1 / view_num):
        # np.eye indexing always yields view_num columns, even if some view
        # index is never drawn (a fitted OneHotEncoder would drop that column).
        return np.eye(view_num)[randint(0, view_num, size=(alldata_len, 1)).ravel()]
    if one_rate == 1:
        # randint(1, 2) only produces 1: every view is available.
        return randint(1, 2, size=(alldata_len, view_num))

    total = view_num * alldata_len
    # Number of 1s still needed on top of the one guaranteed view per sample.
    one_num = total * one_rate - alldata_len
    error = 1
    while error >= 0.005:
        # One guaranteed view per sample.
        view_preserve = np.eye(view_num)[randint(0, view_num, size=(alldata_len, 1)).ravel()]
        ratio = one_num / total
        # Bernoulli draw whose probability is truncated to whole percent.
        matrix_iter = (randint(0, 100, size=(alldata_len, view_num)) < int(ratio * 100)).astype(int)
        # Draws that hit an already-guaranteed view add nothing; inflate the
        # target so the final density compensates for those collisions.
        a = np.sum(((matrix_iter + view_preserve) > 1).astype(int))
        one_num_iter = one_num / (1 - a / one_num)
        ratio = one_num_iter / total
        matrix_iter = (randint(0, 100, size=(alldata_len, view_num)) < int(ratio * 100)).astype(int)
        matrix = ((matrix_iter + view_preserve) > 0).astype(int)
        ratio = np.sum(matrix) / total
        error = abs(one_rate - ratio)
    return matrix


def fill_masked_rows_with_mean(view, mask):
    """Impute the missing samples of one view with the mean of its observed samples.

    :param view: 2-D array with one row per sample; modified in place
    :param mask: per-sample indicator of length n_samples (1-D or (n, 1));
        1 marks an observed sample, 0 a missing one
    :return: ``view`` itself, with every missing row overwritten by the
        column-wise mean of the observed rows. If no row is observed there is
        no mean to use, so missing rows are filled with zeros instead of NaN.
    """
    mask = np.asarray(mask).ravel()
    observed = mask == 1
    missing = mask == 0
    if not missing.any():
        return view
    if observed.any():
        fill_row = np.mean(view[observed], axis=0)
    else:
        # np.mean over zero rows would give NaN (and a RuntimeWarning).
        fill_row = 0
    view[missing] = fill_row
    return view



if __name__ == "__main__":
    Initialize_Seed()
    # Dataset name -> .mat file inside ./CPM_NET_DATA/data/.
    data2file = {'CUB': 'cub_googlenet_doc2vec_c10.mat', 'Caltech': 'Caltech101.mat'}
    # np.save does not create missing parent directories.
    os.makedirs('./CPM_data', exist_ok=True)

    for data_name in ['Caltech', 'CUB']:
        for missing_rate in [0]:
            X_train, X_test, labels_train, labels_test = read_data(
                str_name='./CPM_NET_DATA/data/' + data2file[data_name])

            view_num = len(X_train)
            trainLen = len(X_train[0])
            testLen = len(X_test[0])

            # One availability mask for all samples: train rows first, then test rows.
            Sn = get_sn(view_num, trainLen + testLen, missing_rate)
            Sn_train = Sn[np.arange(trainLen)]  # train mask (train_size, view_num)
            Sn_test = Sn[np.arange(testLen) + trainLen]  # test mask (test_size, view_num)

            for i, view in enumerate(X_train):
                X_train[i] = fill_masked_rows_with_mean(view, Sn_train[:, i])
            for i, view in enumerate(X_test):
                X_test[i] = fill_masked_rows_with_mean(view, Sn_test[:, i])
                # The saved "test" split holds every sample: the test rows
                # followed by the (already imputed) train rows.
                X_test[i] = np.concatenate([X_test[i], X_train[i]], axis=0)

            Experiment_data = {}
            # Label order matches the test + train concatenation above.
            Experiment_data['test_label'] = np.concatenate([labels_test, labels_train], axis=0)
            Experiment_data['train'] = X_train
            Experiment_data['test'] = X_test

            # np.save pickles the dict into a 0-d object array; read it back with
            # np.load(path, allow_pickle=True).item().
            np.save("./CPM_data/{}_{}_{}_views.npy".format(data_name, view_num, missing_rate), Experiment_data)




