import os
import random
import time

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, Subset, DataLoader, TensorDataset
from tools.data_process import feature_reduc, feature_redc_test


def load_dataset(name, dir, reduction, structure, resize=False, class_idx=[0, 1], scale=1.0, batch_size=64):
    """Build train/test DataLoaders over a class subset of MNIST or Fashion-MNIST.

    The raw ``.data`` / ``.targets`` tensors of the torchvision dataset are
    sub-sampled with ``sample_data`` (no shuffling, equal weight per class),
    optionally passed through ``feature_reduc``, divided by 255.0, and the
    labels are remapped with ``transform_labels`` so that each entry of
    ``class_idx`` becomes its position in that list.

    Args:
        name: Dataset name, either ``'mnist'`` or ``'fashion_mnist'``.
        dir: Root directory handed to the torchvision dataset (data is
            downloaded there when missing).
        reduction: Forwarded to ``feature_reduc`` as ``f_type``; only used
            when ``resize`` is true.
        structure: ``'drnn'`` selects a reduction target size of 8, any other
            value selects 16. Only used when ``resize`` is true.
        resize: Whether to run ``feature_reduc`` on the sampled images.
        class_idx: Original labels to keep. Only read, never mutated here.
        scale: Forwarded to ``sample_data`` (an ``int`` is a per-class count,
            anything else is treated as a fraction of the selected classes).
        batch_size: Batch size of both returned loaders.

    Returns:
        Tuple ``(train_loader, test_loader, sample_shape)`` where
        ``sample_shape`` is the shape of the first training sample. Only the
        training loader shuffles.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    if name == 'mnist':
        dataset_cls = datasets.MNIST
    elif name == 'fashion_mnist':
        dataset_cls = datasets.FashionMNIST
    else:
        # Previously an unknown name fell through and crashed further down
        # with an UnboundLocalError on ``train_set``.
        raise ValueError(f"unsupported dataset {name!r}; expected 'mnist' or 'fashion_mnist'")

    train_set = dataset_cls(dir, train=True, download=True, transform=transforms.ToTensor())
    test_set = dataset_cls(dir, train=False, download=True, transform=transforms.ToTensor())

    # Every selected class gets the same sampling weight.
    scaled_train_idx, scaled_train_images, scaled_train_labels = sample_data(
        train_set.data, train_set.targets, class_idx, scale, shuffle=False,
        bias_pro=[1] * len(class_idx))
    scaled_test_idx, scaled_test_images, scaled_test_labels = sample_data(
        test_set.data, test_set.targets, class_idx, scale, shuffle=False,
        bias_pro=[1] * len(class_idx))

    print(f'load {name} data for {len(class_idx)} classification, classes: {class_idx}, scale: {scale * 100}%, '
          f'num of training data: {len(scaled_train_idx)}, num of testing data: {len(scaled_test_idx)}')

    if resize:
        to_size = 8 if structure == 'drnn' else 16
        scaled_train_images, scaled_test_images = feature_reduc(
            scaled_train_images, scaled_test_images, f_type=reduction, d_name=name, to_size=to_size)
        print(f'feature reduction {reduction} has completed: {scaled_train_images.shape}')

    # Scale pixel values by 1/255 and map original labels to 0..len(class_idx)-1.
    scaled_train_images = scaled_train_images / 255.0
    scaled_train_labels = transform_labels(scaled_train_labels, class_idx)
    scaled_test_images = scaled_test_images / 255.0
    scaled_test_labels = transform_labels(scaled_test_labels, class_idx)

    train_set = TensorDataset(scaled_train_images, scaled_train_labels)
    test_set = TensorDataset(scaled_test_images, scaled_test_labels)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size)
    return train_loader, test_loader, scaled_train_images[0].shape


def transform_labels(y, class_idx=[0, 1]):
    """Remap labels in ``y`` so each value in ``class_idx`` becomes its list position.

    ``y`` is modified in place and the same tensor object is returned, so
    callers that need the original labels must clone them beforehand. Labels
    that do not appear in ``class_idx`` are left untouched.

    Args:
        y: Label tensor (expected to be 1-D) holding original class labels.
        class_idx: Original labels; ``class_idx[k]`` is rewritten to ``k``.
            Only read, never mutated here.

    Returns:
        The same tensor ``y`` with remapped labels.
    """
    # Match against a snapshot of the incoming labels. Comparing against the
    # tensor being rewritten would let an already-remapped value be matched
    # again by a later entry (e.g. class_idx=[1, 0] turned every label into 1).
    original = y.clone()
    for new_label, old_label in enumerate(class_idx):
        y[torch.where(original == old_label)[0]] = new_label

    return y


def sample_data(x, y, class_idx=[0, 1], scale=0.1, shuffle=True, bias_pro=[1,1]):
    """Select a leading slice of samples for each requested class.

    Args:
        x: Sample tensor, indexed along its first dimension.
        y: Original labels aligned with ``x``.
        class_idx: Labels to draw samples for, in output order.
        scale: If exactly of type ``int``, the total budget is
            ``scale * len(class_idx)``; otherwise it is ``scale`` times the
            number of entries of ``y`` that belong to ``class_idx``.
        shuffle: When true, ``x`` and ``y`` are permuted first. This reseeds
            torch's global RNG from the current wall-clock second.
        bias_pro: Per-class weights; class ``i`` receives at most
            ``int(budget / sum(bias_pro)) * bias_pro[i]`` samples.

    Returns:
        Tuple ``(idx, x[idx], y[idx])`` where ``idx`` is a Python list of the
        chosen positions (positions refer to the permuted order when
        ``shuffle`` is true), grouped class by class.
    """
    if shuffle:
        torch.manual_seed(int(time.time()))
        order = torch.randperm(x.shape[0])
        x, y = x[order], y[order]

    if type(scale) is int:
        budget = scale * len(class_idx)
    else:
        wanted = torch.isin(y, torch.tensor(class_idx))
        budget = scale * torch.sum(wanted).item()

    unit = int(budget / sum(bias_pro))

    # Start from an empty long tensor so the concatenation is valid even
    # when no class contributes any index.
    chunks = [torch.tensor([], dtype=torch.long)]
    for pos, label in enumerate(class_idx):
        members = torch.where(y == label)[0][:unit * bias_pro[pos]]
        print(f'{len(members)} data from class {label}')
        chunks.append(members)

    idx = torch.cat(chunks).tolist()
    return idx, x[idx], y[idx]


def load_correct_data(conf, model, train=False, n=None):
    """Sample MNIST / Fashion-MNIST examples that ``model`` classifies correctly.

    A fixed 20% (``scale=0.2``), unshuffled, equally weighted subset of the
    classes in ``conf.class_idx`` is drawn with ``sample_data``, scaled by
    1/255, optionally reduced with ``feature_redc_test``, and run through
    ``model.predict``. Only samples whose arg-max prediction equals the
    remapped label are kept; from those a shuffled subset is drawn.

    Args:
        conf: Configuration object; reads ``dataset``, ``data_dir``,
            ``class_idx``, ``num_test``, ``structure``, ``resize`` and
            ``reduction``.
        model: Object exposing ``predict(x)`` that returns per-class scores
            with classes along dimension 1.
        train: Whether to read the training split instead of the test split.
        n: ``scale`` argument for the final ``sample_data`` call (an ``int``
            means samples per class). Defaults to ``conf.num_test``.

    Returns:
        Tuple ``(x, y)`` of the sampled correctly-predicted inputs and their
        labels remapped to positions in ``conf.class_idx``.

    Raises:
        ValueError: If ``conf.dataset`` is not a supported dataset.
    """
    # Select the data that the model predicts correctly.
    name = conf.dataset
    data_dir = conf.data_dir
    class_idx = conf.class_idx
    if n is None:
        n = conf.num_test
    to_size = 8 if conf.structure == 'drnn' else 16

    if name == 'mnist':
        d_set = datasets.MNIST(data_dir, train=train, download=True)
    elif name == 'fashion_mnist':
        d_set = datasets.FashionMNIST(data_dir, train=train, download=True)
    else:
        # Previously an unknown name crashed just below with an
        # UnboundLocalError on ``d_set``.
        raise ValueError(f"unsupported dataset {name!r}; expected 'mnist' or 'fashion_mnist'")

    # Clone so the dataset's own tensors are never touched by later in-place edits.
    _, x, y = sample_data(d_set.data.clone(), d_set.targets.clone(), class_idx, 0.2,
                          shuffle=False, bias_pro=[1] * len(class_idx))

    x = x.float() / 255.0
    y_ = y.clone()  # original labels (transform_labels rewrites ``y`` in place)
    y_uni = transform_labels(y, class_idx)  # transformed labels

    if conf.resize:
        x = feature_redc_test(x, f_type=conf.reduction, d_name=name, to_size=to_size)
        print(f'feature reduction {conf.reduction} has completed: {x.shape}')

    with torch.no_grad():
        y_preds = model.predict(x).argmax(1)
        correct_idx = torch.where(y_uni == y_preds)[0]
        total = y_uni.shape[0]
        x, y_ = x[correct_idx], y_[correct_idx]
        print(f'acc of {total} data: {correct_idx.shape[0] / total * 100}%')

    # Sample on the ORIGINAL labels (sample_data matches against class_idx),
    # then remap the sampled labels to 0..len(class_idx)-1.
    _, scaled_x, scaled_y = sample_data(x, y_, class_idx, n, shuffle=True,
                                        bias_pro=[1] * len(class_idx))
    y_uni = transform_labels(scaled_y, class_idx)

    return scaled_x, y_uni


def load_ori_data(conf, p_c):
    """Load grayscale PNG images named ``<idx>_<label>...png`` from a directory.

    Files are visited in ascending order of their leading integer index and
    at most ``conf.num_test`` images are kept per label.

    Args:
        conf: Configuration object; reads ``class_idx`` (only its length) and
            ``num_test``.
        p_c: Directory containing the ``.png`` files.

    Returns:
        Tuple ``(idx_list, ori_list, ori_y_list)``: a tensor of image indices
        parsed from the file names, the stacked image tensors, and a tensor
        of the labels parsed from the file names.

    Raises:
        RuntimeError: If no image was loaded from ``p_c``.
    """
    idx_list = []
    ori_y_list = []
    ori_list = []
    transform = transforms.Compose([transforms.ToTensor()])
    mode = 'L'

    # NOTE(review): the per-class counter is indexed directly by the label
    # parsed from the file name, so labels are assumed to lie in
    # range(len(conf.class_idx)) -- confirm against whatever writes these files.
    num_per_class = [0] * len(conf.class_idx)
    files = [f for f in os.listdir(p_c) if f.endswith('.png')]
    files = sorted(files, key=lambda f: int(f.split('.')[0].split('_')[0]))

    for file in files:
        name_parts = file.split('.')[0].split('_')
        idx = int(name_parts[0])
        ori_y = int(name_parts[1])
        if num_per_class[ori_y] == conf.num_test:
            continue

        # Context manager guarantees the file handle is released even if
        # decoding fails; convert() returns an independent in-memory image.
        with Image.open(os.path.join(p_c, file)) as raw:
            img = raw.convert(mode)
        ori_list.append(transform(img))
        idx_list.append(idx)  # img idx
        ori_y_list.append(ori_y)
        num_per_class[ori_y] += 1

    if not ori_list:
        # torch.stack cannot stack an empty list; fail with a message that
        # names the directory instead.
        raise RuntimeError(f'no .png images were loaded from {p_c}')

    idx_list, ori_list, ori_y_list = torch.tensor(idx_list), torch.stack(ori_list), torch.tensor(ori_y_list)
    print(ori_y_list)
    print(f'load {len(ori_list)} ori images from {p_c}')
    return idx_list, ori_list, ori_y_list


