import numpy as np
import torch
from tabular_dataset import get_dataset, get_num_class
from augmentation import DataGenerator
from utils import TabularConfig
from feature_extractor import FeatureExtractor
import os
from utils import read_config
import tqdm
import pickle
from globa_utils import setup_seed
from angle_distribution import cal_angle_to_center
from diversity_distritbution import cal_diversity_to_set
from args import get_dst_name_args
# os.environ["CUDA_VISIBLE_DEVICES"]='7'

# Module-level TabularConfig instance; the functions below use it to resolve the
# dataset-config, scaler, distribution, data-pool-info and data-pool paths.
TC = TabularConfig()


def _extract_appended_features(feature_extractor, train_x, samples_x, num_class, col_info):
    """Extract features for ``samples_x`` with the training rows prepended; return only the ``samples_x`` rows."""
    # NOTE(review): the training rows are presumably prepended so the extractor
    # processes ``samples_x`` in the context of the full training set; confirm
    # against FeatureExtractor.extractor_features_from_dst.
    features = feature_extractor.extractor_features_from_dst(
        (pd.concat([train_x, samples_x]), None), num_class, col_info, is_split_by_class=False)
    return features[len(train_x):, :]


def generate_augmented_data_main(dst_name, device='cuda:0'):
    """Generate an augmented data pool for ``dst_name``, score it, and pickle the results.

    Steps:
      1. Read ``<dataset_config_path><dst_name>.yaml`` and load the training split.
      2. Build the augmented pool with ``DataGenerator.generate_data_to_pool`` using
         ``cfg['aug_method']``.
      3. For each label, compute for every augmented row:
         - its angle to that class centre (loaded from
           ``kfold_class_distri_info_Gaussian.pkl``);
         - the absolute change of that angle versus the source rows selected by
           ``src_index_dict[label]``;
         - its diversity with respect to the source class features.
      4. Write ``data_pool_info.pkl`` under ``<data_pool_info_path>/<dst_name>/``.
         It holds ``{'index_map': {label: src_index_dict[label]}, 'rank_info':
         {label: [(angle, angle_change, diversity, row_idx), ...]}}``.
      5. Write ``data_pool.pkl`` under ``<data_pool_path>/<dst_name>/``.
         It maps each label to ``generator.ndarray_to_pd(data_pool[label])``.

    Args:
        dst_name: dataset name used to locate configs, scaler, distributions and outputs.
        device: torch device the class-centre tensors are moved to (default ``'cuda:0'``,
            the value previously hard-coded).
    """
    cfg = read_config(cfg_path=TC.get_dataset_config_path() + dst_name + '.yaml')
    train_set, col_info = get_dataset(dst_name, split='train', rand_number=cfg['split_seed'], test_ratio=cfg['test_ratio'])
    generator = DataGenerator(train_set, random_seed=cfg['generation_seed'], categorical_col=col_info.cate_name,
                              numerical_col=col_info.cont_name)
    data_pool, src_index_dict = generator.generate_data_to_pool(method=cfg['aug_method'])
    print("generating finish!")

    # Use a context manager so the scaler file handle is closed (it previously leaked).
    with open(os.path.join(TC.get_scaler_save_path(), dst_name + '_scaler.pkl'), 'rb') as fi:
        scaler = pickle.load(fi)
    feature_extractor = FeatureExtractor(scaler=scaler)

    with open(os.path.join(TC.get_distribution_path(), dst_name, 'kfold_class_distri_info_Gaussian.pkl'), 'rb') as fi:
        class_center_list = pickle.load(fi)['class_centers']

    augmented_data_info = {'index_map': {}, 'rank_info': {}}
    train_x = train_set[0]
    src_class_split_dst = generator.split_set_by_class(dataset=train_set)
    num_class = len(data_pool)
    data_pool_save_dict = {}
    for label in data_pool:
        class_center_feat = torch.from_numpy(class_center_list[label]).to(device)

        # Features/angles of the original training rows of this class.
        src_class_x = src_class_split_dst[label].drop(['label'], axis=1)
        src_class_features = _extract_appended_features(feature_extractor, train_x, src_class_x, num_class, col_info)
        src_angles = cal_angle_to_center(src_class_features, class_center_feat)

        # Features/angles of the augmented rows generated for this class.
        aug_class_samples = generator.ndarray_to_pd(data_pool[label])
        data_pool_save_dict[label] = aug_class_samples
        aug_class_features = _extract_appended_features(
            feature_extractor, train_x, aug_class_samples[0], num_class, col_info)
        aug_angles = cal_angle_to_center(aug_class_features, class_center_feat)

        # NOTE(review): src_index_dict[label] presumably gives, per augmented row, the
        # index of its source row within this class -- confirm with DataGenerator.
        aug_angle_changes = torch.abs(aug_angles - src_angles[src_index_dict[label]]).cpu().numpy()
        aug_angles = aug_angles.cpu().numpy()

        aug_diversities = cal_diversity_to_set(aug_class_features, src_class_features, batch_size=512)
        aug_diversities = aug_diversities.cpu().numpy()

        augmented_data_info['rank_info'][label] = [(aug_angles[j], aug_angle_changes[j], aug_diversities[j], j)
                                                   for j in range(aug_angles.shape[0])]
        augmented_data_info['index_map'][label] = src_index_dict[label]

    info_dir = os.path.join(TC.get_data_pool_info_path(), dst_name)
    pool_dir = os.path.join(TC.get_data_pool_path(), dst_name)
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(info_dir, exist_ok=True)
    os.makedirs(pool_dir, exist_ok=True)

    with open(os.path.join(info_dir, 'data_pool_info.pkl'), 'wb') as fo:
        pickle.dump(augmented_data_info, fo)

    with open(os.path.join(pool_dir, 'data_pool.pkl'), 'wb') as fo:
        pickle.dump(data_pool_save_dict, fo)


import functools
def cmp_func(a, b):
    """Three-way comparator for rank tuples ``(angle, angle_change, diversity, index)``.

    Orders by angle (field 0) ascending. Ties on angle are broken by diversity
    (field 2) descending. Returns -1 if ``a`` sorts before ``b`` and 1 if after.
    Returns 0 when both fields tie. Previously a full tie returned 1, which made
    ``cmp(a, b)`` and ``cmp(b, a)`` both positive and violated the comparator
    contract expected by ``functools.cmp_to_key``.
    """
    if a[0] != b[0]:
        return 1 if a[0] > b[0] else -1
    # Secondary key: field 1 would rank by angle change, field 2 ranks by diversity.
    if a[2] != b[2]:
        return -1 if a[2] > b[2] else 1
    return 0


def rank_main(dst_name):
    """Sort every class's rank list in the saved data-pool info file.

    Reads ``data_pool_info.pkl`` under ``<data_pool_info_path>/<dst_name>/``. Sorts
    each ``rank_info[label]`` list with ``cmp_func`` (angle ascending, then
    diversity descending). Writes the file back containing only the
    ``index_map`` and ``rank_info`` entries.
    """
    info_path = os.path.join(TC.get_data_pool_info_path(), dst_name, 'data_pool_info.pkl')
    with open(info_path, 'rb') as fi:
        info = pickle.load(fi)

    ranked_info = info['rank_info']
    # Iterate the labels actually stored instead of range(num_class). The keys come
    # from the data pool's labels, and range() raised KeyError when they were not
    # exactly 0..n-1.
    for label in tqdm.tqdm(list(ranked_info)):
        ranked_info[label] = sorted(ranked_info[label], key=functools.cmp_to_key(cmp_func))

    with open(info_path, 'wb') as fo:
        pickle.dump({'index_map': info['index_map'], 'rank_info': ranked_info}, fo)


from val_sampler import DistriInfo, remove_duplication
import pandas as pd
from draw_utils import draw_real_syn_distribution


def show_aug_data_distribution(dst_name):
    """Plot the angle/diversity of all augmented samples against the training data.

    Prints the original distribution reported by ``DistriInfo``. Loads
    ``data_pool_info.pkl`` for ``dst_name`` and de-duplicates its rank lists with
    ``remove_duplication(..., tolerance=0)``. Then, for each class index, calls
    ``draw_real_syn_distribution`` once on a long-format frame with columns
    ``angle``, ``diversity`` and ``Type`` ("aug" / "train"), using ``Type`` as the hue.
    """
    n_classes = get_num_class(dst_name)
    train_distri = DistriInfo(os.path.join(TC.get_distribution_path(), dst_name), is_train=True)
    print(train_distri.get_origin_distribution())

    info_file = os.path.join(TC.get_data_pool_info_path(), dst_name, 'data_pool_info.pkl')
    with open(info_file, 'rb') as fi:
        pool_info = pickle.load(fi)
    deduped = remove_duplication(pool_info['rank_info'], tolerance=0)

    for cls in range(n_classes):
        samples = deduped[cls]
        # Assumes remove_duplication keeps the (angle, angle_change, diversity, index)
        # tuple layout, so field 0 is the angle and field 2 the diversity -- confirm.
        angles = np.asarray([s[0] for s in samples])
        diversities = np.asarray([s[2] for s in samples])
        train_vec = train_distri.get_angle_diversity_data(cls)
        aug_frame = pd.DataFrame({'angle': angles, 'diversity': diversities, 'Type': ["aug"] * len(samples)})
        train_frame = pd.DataFrame({'angle': train_vec[0, :], 'diversity': train_vec[1, :],
                                    'Type': ["train"] * train_vec.shape[1]})
        draw_real_syn_distribution(pd.concat([aug_frame, train_frame]), hue_key='Type')


def show_selected_aug_data_distribution(dst_name, val_info):
    """Plot the angle/diversity of selected augmented samples against the training data.

    Prints the original distribution reported by ``DistriInfo``. Then, for each
    class index, calls ``draw_real_syn_distribution`` once on a long-format frame
    with columns ``angle``, ``diversity`` and ``Type`` ("selected" / "train"),
    using ``Type`` as the hue.

    Args:
        dst_name: dataset name used to locate the training distribution.
        val_info: indexable by class index. Each entry is a sequence of tuples whose
            4th-from-last field is read as the angle and whose 2nd-from-last field is
            read as the diversity.
    """
    train_distri = DistriInfo(os.path.join(TC.get_distribution_path(), dst_name), is_train=True)
    print(train_distri.get_origin_distribution())
    n_classes = get_num_class(dst_name)

    for cls in range(n_classes):
        picked = val_info[cls]
        angles = np.asarray([entry[-4] for entry in picked])
        diversities = np.asarray([entry[-2] for entry in picked])
        train_vec = train_distri.get_angle_diversity_data(cls)
        frames = [
            pd.DataFrame({'angle': angles, 'diversity': diversities,
                          'Type': ["selected"] * angles.shape[0]}),
            pd.DataFrame({'angle': train_vec[0, :], 'diversity': train_vec[1, :],
                          'Type': ["train"] * train_vec.shape[1]}),
        ]
        draw_real_syn_distribution(pd.concat(frames), hue_key='Type')
