import numpy as np
import torch
from nlp_dataset import get_dataset, get_num_class, split_dst_by_class, read_src_index_file
from augmentation import NLPAugment
from utils import PathConfig
from feature_extractor import FeatureExtractor
import os
from utils import read_config
import tqdm
import pickle
from train_eval import setup_seed
from angle_distribution import cal_angle_to_center
import pandas as pd
from diversity_distritbution import cal_diversity_to_set

import argparse

parser = argparse.ArgumentParser(description='')
parser.add_argument(
    '--dst_name',
    choices=['reuters', 'imdb', 'newsgroups'],
    required=False,
)
# Command-line arguments are parsed once, at import time.
args = parser.parse_args()


# Uncomment to restrict which GPUs this process can see:
# os.environ["CUDA_VISIBLE_DEVICES"]='3'

# Shared PathConfig instance used by every stage below to locate data/model files.
PC = PathConfig()


def cal_entropy(logits, eps=1e-10):
    """Return the Shannon entropy, in bits, of ``softmax(logits)`` along dim 1.

    Args:
        logits: tensor of unnormalised scores; the softmax is taken over dim 1.
        eps: small constant added inside the log so zero probabilities do not
            yield ``-inf``.

    Returns:
        Tensor with dim 1 reduced, holding one entropy value per row.
    """
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    log_probabilities = torch.log2(probabilities + eps)
    return -torch.sum(probabilities * log_probabilities, dim=1)


def generate_augmented_data_main(dst_name, device='cuda:0', batch_size=32):
    """Generate augmented samples for every class and save them as the raw data pool.

    Writes into ``PC.get_data_pool_path(dst_name)``:
      * one ``<label>.csv`` per class, built from ``data_pool[label]`` (no index column);
      * ``src_index.pkl``, the ``src_index_dict`` returned by
        ``NLPAugment.generate_augmented_data``.

    Args:
        dst_name: dataset name understood by ``get_dataset`` / ``get_num_class``.
        device: device passed to ``NLPAugment`` (default keeps the previous 'cuda:0').
        batch_size: batch size passed to ``NLPAugment`` (default keeps the previous 32).
    """
    train_set, _ = get_dataset(dst_name, model_name="bert-base-uncased", is_post_process=False)
    nlp_auger = NLPAugment(device=device, batch_size=batch_size)
    data_pool, src_index_dict = nlp_auger.generate_augmented_data(
        split_dst_by_class(train_set, get_num_class(dst_name)), multi_aug_p=True)

    save_path = PC.get_data_pool_path(dst_name)
    # exist_ok avoids the check-then-create race of os.path.exists() + makedirs().
    os.makedirs(save_path, exist_ok=True)
    for label in data_pool.keys():
        result = pd.DataFrame(data_pool[label])
        result.to_csv(os.path.join(save_path, str(label) + '.csv'), index=False)

    with open(os.path.join(save_path, 'src_index.pkl'), 'wb') as fo:
        pickle.dump(src_index_dict, fo)


def process_data_pool(dst_name):
    """Drop incomplete rows from the saved data pool while keeping src_index aligned.

    For each class ``i`` in ``range(get_num_class(dst_name))``:
      * reads ``<i>.csv`` from ``PC.get_data_pool_path(dst_name)``;
      * drops every row containing a NaN;
      * rewrites the CSV in place without an index column.

    The matching entries are removed from the source-index lists, and the
    filtered lists are written back to ``src_index.pkl``. Afterwards, row ``r``
    of ``<i>.csv`` still corresponds to ``src_index[i][r]``.
    Per-class sizes before and after filtering are printed.
    """
    num_class = get_num_class(dst_name)
    folder_path = PC.get_data_pool_path(dst_name)
    origin_src_index_dict = read_src_index_file(folder_path)
    new_src_index_dict = {}
    for i in range(num_class):
        csv_path = os.path.join(folder_path, str(i) + '.csv')
        dst = pd.read_csv(csv_path)
        # Attach the source indices as a temporary column so that dropna removes
        # each index together with its row, keeping the two aligned.
        dst['src_index'] = origin_src_index_dict[i]
        print('class %d origin size %d' % (i, len(dst)))
        dst = dst.dropna(axis=0, how='any')
        print('class %d current size %d' % (i, len(dst)))
        new_src_index_dict[i] = dst['src_index'].tolist()
        dst = dst.drop(columns=['src_index'])
        # Files written with a pandas index carry it back as an 'Unnamed: 0' column.
        if 'Unnamed: 0' in dst.columns:
            dst = dst.drop(columns=['Unnamed: 0'])
        # index=False (same as the writer in generate_augmented_data_main) so no
        # 'Unnamed: 0' column is introduced on the next read.
        dst.to_csv(csv_path, index=False)

    with open(os.path.join(folder_path, 'src_index.pkl'), 'wb') as fo:
        pickle.dump(new_src_index_dict, fo)


def _extract_features_and_logits(feature_extractor, class_samples):
    """Return ``(features, logits)`` from ``feature_extractor`` for one class split.

    The ``'labels'`` column is removed from the model inputs and passed
    separately. The extractor is run twice: once for features, once with
    ``is_logits=True``.
    """
    features = feature_extractor.extractor_features_from_dst(class_samples.remove_columns(['labels']),
                                                             class_samples['labels'],
                                                             is_split_by_class=False)
    logits = feature_extractor.extractor_features_from_dst(class_samples.remove_columns(['labels']),
                                                           class_samples['labels'],
                                                           is_split_by_class=False, is_logits=True)
    return features, logits


def cal_data_pool_info_main(dst_name, bi_type='Gaussian', data_op_device='cuda:1'):
    """Compute per-sample ranking statistics for the augmented data pool and pickle them.

    Class centres are loaded from the ``'class_centers'`` entry of
    ``kfold_class_distri_info_{bi_type}.pkl`` in ``PC.get_distribution_path(dst_name)``.

    For every class in the data pool, each augmented sample ``j`` gets a record
    with these fields:
      * its angle to the class centre (``cal_angle_to_center``);
      * the absolute change of that angle relative to its source sample;
      * its entropy (``cal_entropy`` on the extractor logits);
      * the absolute change of that entropy relative to its source sample;
      * its diversity against the class's training samples (``cal_diversity_to_set``).

    Output: ``data_pool_info.pkl`` in ``PC.get_data_pool_info_path(dst_name)``. It is a
    dict with two keys:
      * ``'rank_info'``: per class, a list of
        ``(angle, angle_change, entropy, entropy_change, diversity, j)`` tuples;
      * ``'index_map'``: per class, the source-index list of the data pool.

    Args:
        dst_name: dataset name understood by ``get_dataset`` / ``get_num_class``.
        bi_type: suffix selecting which class-distribution pickle to read.
        data_op_device: device the class-centre tensor is moved to
            (default keeps the previous 'cuda:1').
    """
    num_class = get_num_class(dst_name)
    train_set, _ = get_dataset(dst_name, model_name="bert-base-uncased")
    data_pool, src_index_dict = get_dataset(dst_name + '_data_pool', model_name="bert-base-uncased")

    feature_extractor = FeatureExtractor(num_class=num_class,
                                         weight_path=os.path.join(PC.get_fe_path(dst_name), 'default'))

    distri_file = os.path.join(PC.get_distribution_path(dst_name),
                               'kfold_class_distri_info_{}.pkl'.format(bi_type))
    with open(distri_file, 'rb') as fi:
        class_center_list = pickle.load(fi)['class_centers']

    augmented_data_info = {'index_map': {}, 'rank_info': {}}

    src_class_split_dst = split_dst_by_class(train_set, num_class=num_class)
    for label in data_pool.keys():
        class_center_feat = torch.from_numpy(class_center_list[label]).to(data_op_device)

        src_class_features, src_class_logits = _extract_features_and_logits(feature_extractor,
                                                                            src_class_split_dst[label])
        src_entropy = cal_entropy(src_class_logits)
        src_angles = cal_angle_to_center(src_class_features, class_center_feat)

        aug_class_features, aug_class_logits = _extract_features_and_logits(feature_extractor,
                                                                            data_pool[label])
        aug_angles = cal_angle_to_center(aug_class_features, class_center_feat)
        aug_entropy = cal_entropy(aug_class_logits)

        # src_index_dict[label][j] is used as the position, within this class's
        # training split, of augmented sample j's source. Indexing the source
        # statistics with it aligns them element-wise with the aug_* tensors.
        src_index = src_index_dict[label]
        aug_angle_changes = torch.abs(aug_angles - src_angles[src_index]).cpu().numpy()
        aug_entropy_changes = torch.abs(aug_entropy - src_entropy[src_index]).cpu().numpy()
        aug_angles = aug_angles.cpu().numpy()
        aug_entropy = aug_entropy.cpu().numpy()

        aug_diversities = cal_diversity_to_set(aug_class_features, src_class_features,
                                               batch_size=64).cpu().numpy()

        augmented_data_info['rank_info'][label] = [(aug_angles[j], aug_angle_changes[j],
                                                    aug_entropy[j], aug_entropy_changes[j],
                                                    aug_diversities[j], j) for j in range(aug_angles.shape[0])]
        augmented_data_info['index_map'][label] = src_index_dict[label]

    save_path = PC.get_data_pool_info_path(dst_name)
    os.makedirs(save_path, exist_ok=True)

    with open(os.path.join(save_path, 'data_pool_info.pkl'), 'wb') as fo:
        pickle.dump(augmented_data_info, fo)



import functools
def cmp_func(a, b):
    """Three-way comparator for rank_info tuples, for use with ``functools.cmp_to_key``.

    Ordering:
      * primary key ``a[0]`` ascending (the angle slot of the rank_info tuples);
      * secondary key ``a[2]`` descending (the entropy slot).

    Returns -1 if ``a`` sorts first, 1 if ``b`` sorts first, and 0 when both keys
    are equal. Returning 0 keeps the comparator consistent (``cmp(a, a) == 0``);
    previously a full tie returned 1.

    NOTE: other candidate secondary keys in those tuples are slot 1 (angle change)
    and slot 4 (diversity). An earlier variant compared ``int(round(a[0]))``
    instead of exact equality.
    """
    if a[0] == b[0]:
        if a[2] > b[2]:
            return -1
        if a[2] == b[2]:
            return 0
        return 1
    return 1 if a[0] > b[0] else -1


def rank_main(dst_name):
    """Sort every class's rank_info list with ``cmp_func`` and rewrite ``data_pool_info.pkl``.

    The pickle in ``PC.get_data_pool_info_path(dst_name)`` is read and then
    overwritten. ``'index_map'`` is kept unchanged; ``'rank_info'`` is replaced
    by its sorted version.
    """
    info_file = os.path.join(PC.get_data_pool_info_path(dst_name), 'data_pool_info.pkl')
    with open(info_file, 'rb') as fi:
        info = pickle.load(fi)

    ranked_info = info['rank_info']
    # Iterate over the actual class keys instead of assuming they are exactly
    # 0..n-1. Listing the keys first keeps iteration independent of the
    # reassignments below.
    for c in tqdm.tqdm(list(ranked_info.keys())):
        ranked_info[c] = sorted(ranked_info[c], key=functools.cmp_to_key(cmp_func))

    with open(info_file, 'wb') as fo:
        pickle.dump({'index_map': info['index_map'], 'rank_info': ranked_info}, fo)

#
from val_sampler import DistriInfo, remove_duplication
import pandas as pd
from draw_utils import draw_real_syn_distribution

#
def show_aug_data_distribution(dst_name):
    """Plot, per class, angle vs. diversity of the augmented pool against the training set.

    Augmented values come from the saved ``rank_info`` after
    ``remove_duplication(..., tolerance=0)``. Angle is tuple slot 0 and diversity
    is slot 4, assuming ``remove_duplication`` keeps the rank_info tuple layout.
    Training values come from ``DistriInfo.get_angle_diversity_data(i)``, with
    row 0 = angle and row 1 = diversity.
    """
    num_class = get_num_class(dst_name)
    info = DistriInfo(PC.get_distribution_path(dst_name), is_train=True)
    print(info.get_origin_distribution())
    with open(os.path.join(PC.get_data_pool_info_path(dst_name), 'data_pool_info.pkl'), 'rb') as fi:
        augmented_data_info = pickle.load(fi)
    sample_rank_list = remove_duplication(augmented_data_info['rank_info'], tolerance=0)
    angle_pos = 0
    diversity_pos = 4
    for i in range(num_class):
        aug_samples = sample_rank_list[i]
        aug_angle_vec = np.asarray([item[angle_pos] for item in aug_samples])
        aug_diversity_vec = np.asarray([item[diversity_pos] for item in aug_samples])
        train_angle_diversity_vec = info.get_angle_diversity_data(i)
        aug_frame = pd.DataFrame({'angle': aug_angle_vec, 'diversity': aug_diversity_vec,
                                  'Type': ["aug"] * len(aug_samples)})
        train_frame = pd.DataFrame({'angle': train_angle_diversity_vec[0, :],
                                    'diversity': train_angle_diversity_vec[1, :],
                                    'Type': ["train"] * train_angle_diversity_vec.shape[1]})
        # ignore_index gives the combined frame a unique 0..n-1 index. Without it,
        # both halves restart at 0, and plotting code that reindexes can reject
        # the duplicate labels.
        train_aug_data = pd.concat([aug_frame, train_frame], ignore_index=True)
        draw_real_syn_distribution(train_aug_data, hue_key='Type')


def show_selected_aug_data_distribution(dst_name, val_info):
    """Plot, per class, angle vs. diversity of selected augmented samples against the training set.

    Args:
        dst_name: dataset name understood by ``get_num_class`` / ``PC``.
        val_info: mapping from class index to a list of per-sample tuples. The
            angle is read from slot -6 and diversity from slot -2. For a 6-field
            rank_info record these are slots 0 and 4; presumably the tuples end
            with the rank_info fields (TODO confirm against the caller).

    Training values come from ``DistriInfo.get_angle_diversity_data(i)``, with
    row 0 = angle and row 1 = diversity.
    """
    info = DistriInfo(PC.get_distribution_path(dst_name), is_train=True)
    print(info.get_origin_distribution())
    num_class = get_num_class(dst_name)
    for i in range(num_class):
        aug_samples = val_info[i]
        aug_angle_vec = np.asarray([item[-6] for item in aug_samples])
        aug_diversity_vec = np.asarray([item[-2] for item in aug_samples])
        train_angle_diversity_vec = info.get_angle_diversity_data(i)
        selected_frame = pd.DataFrame({'angle': aug_angle_vec, 'diversity': aug_diversity_vec,
                                       'Type': ["selected"] * aug_angle_vec.shape[0]})
        train_frame = pd.DataFrame({'angle': train_angle_diversity_vec[0, :],
                                    'diversity': train_angle_diversity_vec[1, :],
                                    'Type': ["train"] * train_angle_diversity_vec.shape[1]})
        # ignore_index avoids duplicate index labels (both halves start at 0).
        train_aug_data = pd.concat([selected_frame, train_frame], ignore_index=True)
        draw_real_syn_distribution(train_aug_data, hue_key='Type')


if __name__ == '__main__':
    setup_seed(PC.get_global_seed())
    # Honour --dst_name when given; otherwise fall back to the previously
    # hard-coded dataset so the default run is unchanged.
    name = args.dst_name or 'reuters'
    # generate_augmented_data_main(name)
    process_data_pool(name)
    cal_data_pool_info_main(dst_name=name, bi_type='Gaussian')
    rank_main(dst_name=name)
    show_aug_data_distribution(dst_name=name)

    # Optional: sample a validation set from the ranked pool and plot it.
    # from val_sampler import sample_val
    # val_info_dict = sample_val(dst_name=name, method='DB_ADJOINT', val_num_per_class=200, random_seed=0, is_return_info=True)
    # show_selected_aug_data_distribution(dst_name=name, val_info=val_info_dict['info'])