# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import json
import torchvision
import numpy as np
import math

from torchvision import transforms
from .datasetbase import BasicDataset
from semilearn.datasets.augmentation import RandAugment, RandomResizedCropAndInterpolation
from semilearn.datasets.utils import split_ssl_data
import pandas as pd
import random
from sklearn.model_selection import KFold
import scipy.io as sio
import torch
# Normalization statistics (three values each, presumably one per image
# channel), keyed by dataset name.
# NOTE(review): the cifar10 and kadid10k entries use the widely used ImageNet
# mean/std values rather than dataset-specific ones — confirm this is intended.
mean = {
    'cifar10': [0.485, 0.456, 0.406],
    'cifar100': [v / 255 for v in [129.3, 124.1, 112.4]],
    'kadid10k': [0.485, 0.456, 0.406],
}

std = {
    'cifar10': [0.229, 0.224, 0.225],
    'cifar100': [v / 255 for v in [68.2, 65.4, 70.4]],
    'kadid10k': [0.229, 0.224, 0.225],
}

# Random noise augmentation
class AddRandomNoise:
    """Add Gaussian noise to a tensor.

    Noise is sampled from a standard normal distribution, multiplied by
    ``noise_level`` and then by ``r``, and added element-wise to the input.

    Args:
        noise_level: multiplier applied to the standard-normal samples.
        r: additional multiplier applied to the already-scaled noise.
    """

    def __init__(self, noise_level=0.1, r=1):
        self.noise_level = noise_level
        self.r = r

    def __call__(self, x):
        """Return ``x`` plus freshly sampled noise of the same shape.

        The noise is created on ``x``'s device so the transform also works for
        tensors that are not on the CPU. It keeps the default float dtype, and
        the addition follows PyTorch's type-promotion rules.
        """
        noise = torch.randn(x.shape, device=x.device) * self.noise_level
        return x + self.r * noise

# Feature scaling
class ScaleFeatures:
    """Multiply the input by a fixed, non-random factor.

    Args:
        scale_factor: value the input is multiplied by (default 0.5).
    """

    def __init__(self, scale_factor=0.5):
        self.scale_factor = scale_factor

    def __call__(self, x):
        """Return ``x * scale_factor``."""
        factor = self.scale_factor
        scaled = x * factor
        return scaled

# Feature shifting (additive random offset)
class ShiftFeatures:
    """Shift every element of a tensor by an independent uniform random amount.

    Each element gets an offset of ``U[0, 1) * shift_amount``. With a positive
    ``shift_amount`` the offset is therefore always non-negative, so this is a
    biased shift, not zero-mean noise.

    Args:
        shift_amount: upper scale of the uniform offset (default 0.2).
    """

    def __init__(self, shift_amount=0.2):
        self.shift_amount = shift_amount

    def __call__(self, x):
        """Return ``x`` plus a random offset tensor of the same shape.

        The offset is sampled on ``x``'s device, so non-CPU tensors work too.
        """
        shift = torch.rand(x.shape, device=x.device) * self.shift_amount
        return x + shift

def _load_mat_features(data_dir):
    """Load the feature matrix and integer class labels from a ``.mat`` file.

    Reads the ``features`` and ``labels`` variables. ``labels`` is collapsed
    with ``argmax`` over axis 1, i.e. it is expected to hold one column per
    class (one-hot / per-class scores).
    NOTE(review): a single-column label vector would collapse to all zeros
    here — confirm the file format.
    """
    mat_file = sio.loadmat(data_dir)
    all_data = mat_file['features']
    all_target = np.argmax(mat_file['labels'], axis=1)
    return all_data, all_target


def _kfold_train_test(all_data, all_target, seed, n_splits=5):
    """Return ``(data, target, test_data, test_target)`` for one K-fold split.

    A shuffled ``KFold`` seeded with ``seed`` partitions the samples. Fold
    ``seed % n_splits`` is selected, so any non-negative integer seed maps to a
    valid fold; seeds ``0..n_splits-1`` pick the same fold as before.
    """
    kf = KFold(n_splits=n_splits, random_state=seed, shuffle=True)
    fold = seed % n_splits
    train_index, test_index = list(kf.split(all_data))[fold]
    data = [all_data[i] for i in train_index]
    target = [all_target[i] for i in train_index]
    test_data = [all_data[i] for i in test_index]
    test_target = [all_target[i] for i in test_index]
    return data, target, test_data, test_target


def get_sjasffe(args, alg, name, num_labels, num_classes, data_dir='/home/ubuntu/wwc/dataset/DataSets/Movie.mat', include_lb_to_ulb=True):
    """Build labeled, unlabeled and evaluation datasets from a ``.mat`` feature file.

    The samples are split 80/20 into train/test with 5-fold ``KFold``. The
    training part is then divided into labeled and unlabeled subsets by
    ``split_ssl_data``.

    Args:
        args: experiment config. This function reads ``seed``, ``ulb_num_labels``,
            ``lb_imb_ratio`` and ``ulb_imb_ratio``; ``args`` is also passed to
            ``split_ssl_data``, which may read further fields.
        alg: algorithm name. For ``'fullysupervised'`` every training sample is
            used as labeled data.
        name: dataset name (unused here; kept for interface compatibility).
        num_labels: forwarded to ``split_ssl_data`` as ``lb_num_labels``.
        num_classes: number of classes.
        data_dir: path to the ``.mat`` file (a file path, despite the name).
        include_lb_to_ulb: forwarded to ``split_ssl_data``.

    Returns:
        ``(lb_dset, ulb_dset, eval_dset)`` as ``BasicDataset`` instances.
    """
    all_data, all_target = _load_mat_features(data_dir)
    data, target, test_data, test_target = _kfold_train_test(all_data, all_target, args.seed)
    print("train size: {}, test size: {}".format(len(data), len(test_data)))

    # Weak view: one round of Gaussian noise on the raw features.
    transform_weak = transforms.Compose([
        transforms.ToTensor(),
        AddRandomNoise(noise_level=1, r=1),
    ])

    # Strong view: two independent rounds of stronger Gaussian noise.
    transform_strong = transforms.Compose([
        transforms.ToTensor(),
        AddRandomNoise(noise_level=2, r=1),
        AddRandomNoise(noise_level=2, r=1),
    ])

    # Evaluation must be deterministic, so no random noise is applied here.
    transform_val = transforms.Compose([
        transforms.ToTensor(),
    ])

    lb_data, lb_targets, ulb_data, ulb_targets = split_ssl_data(args, data, target, num_classes,
                                                                lb_num_labels=num_labels,
                                                                ulb_num_labels=args.ulb_num_labels,
                                                                lb_imbalance_ratio=args.lb_imb_ratio,
                                                                ulb_imbalance_ratio=args.ulb_imb_ratio,
                                                                include_lb_to_ulb=include_lb_to_ulb)

    # Per-class sample counts of the semi-supervised split (for logging only).
    lb_count = [0] * num_classes
    ulb_count = [0] * num_classes
    for c in lb_targets:
        lb_count[c] += 1
    for c in ulb_targets:
        ulb_count[c] += 1
    print("lb count: {}".format(lb_count))
    print("ulb count: {}".format(ulb_count))

    if alg == 'fullysupervised':
        # Fully supervised training uses the whole training fold as labeled data.
        lb_data = data
        lb_targets = target

    lb_dset = BasicDataset(alg, lb_data, lb_targets, num_classes, transform_weak, False, None, False)

    ulb_dset = BasicDataset(alg, ulb_data, ulb_targets, num_classes, transform_weak, True, transform_strong, False)

    eval_dset = BasicDataset(alg, test_data, test_target, num_classes, transform_val, False, None, False)

    return lb_dset, ulb_dset, eval_dset
