import os
from pathlib import Path

import numpy as np
import pandas as pd

import cfg


def _select_initial_indices(df_labels, n_samples, log_prefix):
    """Pick ``n_samples`` distinct row labels of ``df_labels`` for the initial training set.

    Selection is greedy: while some sample still contributes a class that is
    not yet covered, one of the samples covering the most uncovered classes is
    drawn at random. Once no sample adds a new class, the remaining picks are
    drawn uniformly from the rows that have not been selected yet.

    Randomness comes from the global NumPy generator, so the caller is
    expected to seed it (``np.random.seed``) beforehand for reproducibility.

    Args:
        df_labels: Frame holding only the label columns of the candidate rows.
            Assumed to be 0/1 indicators (positive value == class present).
        n_samples: Number of rows to select.
        log_prefix: Text put in front of every progress line that is printed.

    Returns:
        List of selected index labels, in selection order, without duplicates.

    Raises:
        ValueError: If ``n_samples`` exceeds the number of available rows.
    """
    selected = []
    for sample_nr in range(n_samples):
        candidates = None
        if df_labels.shape[1] > 0:
            row_sums = df_labels.sum(axis=1)
            best = row_sums.max()
            if best > 0:
                # rows covering the highest number of until now unseen classes
                candidates = row_sums.index[row_sums == best]
            else:
                # No row covers any of the remaining classes (or there are no
                # rows at all, in which case ``best`` is NaN). Without this
                # guard every row would tie at 0 and already selected rows
                # could be drawn again. Drop the uncoverable columns so later
                # iterations go straight to the random branch.
                df_labels = df_labels.iloc[:, :0]

        if candidates is not None:
            chosen = np.random.choice(candidates)
            # delete columns that the sample covers; ``best > 0`` guarantees
            # at least one column is removed, so the chosen row can never be
            # a candidate again
            covered = df_labels.columns[df_labels.loc[chosen] > 0]
            df_labels = df_labels.drop(columns=covered)
            print(
                f'{log_prefix} | Sample {sample_nr}: Initial set: New class included')
        else:
            # select random sample among those not picked so far
            remaining = df_labels.index.difference(selected)
            if len(remaining) == 0:
                raise ValueError(
                    f'Cannot select {n_samples} initial training samples: '
                    f'only {len(selected)} unlabelled samples are available')
            chosen = np.random.choice(remaining)
            print(f'{log_prefix} | Sample {sample_nr}: All classes included')

        selected.append(chosen)
    return selected


def create():
    """Create the per-experiment metadata CSV files for every configured dataset.

    For each combination of dataset, dataset type, random seed and initial
    training budget taken from ``cfg``, the dataset's ``metadata.csv`` is
    loaded, optionally reduced to at most 2000 unlabelled samples (type
    ``'2k'``), and ``init_train_samples`` unlabelled rows are re-tagged as
    training samples. The result is written to
    ``<path_data>/<dataset>/exp_metadata/``. Files that already exist are left
    untouched.

    Raises:
        ValueError: If a dataset type other than ``'complete'`` or ``'2k'`` is
            configured, or if a budget exceeds the number of unlabelled samples.
    """
    for dataset in cfg.all_datasets:
        # create metadata folder
        path_exp_metadata = Path(cfg.path_data, dataset, 'exp_metadata')
        os.makedirs(path_exp_metadata, exist_ok=True)

        # the source metadata is identical for all runs of a dataset, so it is
        # read lazily once and copied per run
        df_source = None

        for dataset_type in cfg.all_dataset_types:
            if dataset_type == 'complete':
                random_seeds = cfg.random_seeds_complete
            elif dataset_type == '2k':
                random_seeds = cfg.random_seeds_2k
            else:
                raise ValueError(f'Unknown dataset type: {dataset_type!r}')

            for random_seed in range(random_seeds):
                for init_train_samples in cfg.all_budgets:
                    fname_metadata = f'metadata_{dataset}_{dataset_type}_{random_seed}_{init_train_samples}.csv'
                    path_metadata = Path(path_exp_metadata, fname_metadata)
                    # if file exists, continue
                    if os.path.exists(path_metadata):
                        continue

                    # load metadata
                    if df_source is None:
                        df_source = pd.read_csv(Path(cfg.path_data, dataset, 'metadata.csv'))
                    df = df_source.copy()

                    # shorten metadata if 2k
                    if dataset_type == '2k':
                        # lower the nr of unlabelled samples to 2k
                        is_unlabelled = df['subset'] == cfg.tag_unlabelled
                        df = pd.concat([
                            df[df['subset'] == cfg.tag_evaluate],
                            df[is_unlabelled].sample(
                                n=min(is_unlabelled.sum(), 2000), random_state=random_seed
                            )])

                    # get the label columns of the unlabelled samples
                    df_unlabelled = df[df['subset'] == cfg.tag_unlabelled]
                    label_columns = df_unlabelled.columns[
                        df_unlabelled.columns.str.startswith(cfg.label_prefix)]
                    df_unlabelled = df_unlabelled[label_columns]

                    # set random state
                    np.random.seed(random_seed)

                    # select indices
                    indices_init_train = _select_initial_indices(
                        df_unlabelled, init_train_samples, f'{dataset}, {random_seed}')

                    # set training indices
                    df.loc[indices_init_train, 'subset'] = cfg.tag_train

                    # save file
                    df.index.name = 'index'
                    df.to_csv(path_metadata, index=True)

