import gc
import multiprocessing
import os
import pickle
import random
import threading
import time
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input


import cfg
from src.experiments import active_learning_methods, result, train_model, load_data

# Lock held around status prints in conduct_al_experiment so that a message is
# emitted as one piece.
# NOTE(review): this is a threading lock; worker processes of a
# multiprocessing.Pool presumably each hold their own copy, so it does not
# serialise output across processes - confirm whether that is intended.
print_lock = threading.Lock()


def conduct(all_tasks):
    """Run every task in ``all_tasks``, in a process pool or one by one.

    A pool of ``cfg.nr_processing_pool`` workers is used when the machine
    reports more than 10 CPUs and the configured pool size is greater than 1;
    in that case each task goes through ``wrapper_multiprocessing``. Otherwise
    the tasks are executed sequentially by calling ``conduct_al_experiment``
    directly.

    Args:
        all_tasks: Iterable of task tuples as accepted by
            ``conduct_al_experiment``.
    """
    # the pool size is only looked up when enough CPUs are available
    run_parallel = multiprocessing.cpu_count() > 10 and cfg.nr_processing_pool > 1

    if not run_parallel:
        # serial execution
        for single_task in all_tasks:
            conduct_al_experiment(single_task)
        return

    # parallel execution; map blocks until all tasks are finished
    with multiprocessing.Pool(cfg.nr_processing_pool) as worker_pool:
        worker_pool.map(wrapper_multiprocessing, all_tasks)


def wrapper_multiprocessing(task):
    """Run one experiment inside a pool worker without letting it fail the pool.

    Any ``Exception`` raised by ``conduct_al_experiment`` is reported together
    with its traceback and then suppressed, so the remaining tasks still run.
    ``KeyboardInterrupt`` and ``SystemExit`` are not caught and propagate.

    Args:
        task: Task tuple passed on to ``conduct_al_experiment``; its last
            element is used to identify the task in the error message.
    """
    try:
        conduct_al_experiment(task)
    except Exception:
        # local import keeps this block self-contained
        import traceback

        # print message and traceback in one call so they stay together
        print(f'#########################################################Error in wrapper: {task[-1]}\n'
              f'{traceback.format_exc()}')


def conduct_al_experiment(args):
    """Run one active-learning experiment and pickle its result.

    The experiment is skipped when its result file already exists. Otherwise
    the metadata CSV of the dataset/seed is read, one tag column per
    iteration is prepared, and for every tag column samples are selected by
    the query method, a model is trained and predictions for all samples are
    recorded. Finally a ``result.Result`` object is pickled to
    ``<cfg.path_results(dataset, dataset_type)>/<fname_result>.pkl``.

    Args:
        args: Tuple ``(dataset, dataset_type, random_seed, init_train_samples,
            add_train_samples, max_train_samples, train_paradigm, weight_init,
            training, qm, fname_result)``. The complete tuple is also passed
            on unchanged to ``train_model.do``.

    Returns:
        None. The outcome is the pickle file written to disk.
    """
    # unpack variables
    (dataset, dataset_type, random_seed, init_train_samples, add_train_samples, max_train_samples, train_paradigm,
     weight_init, training, qm, fname_result) = args

    # create file name; an existing result file marks the experiment as done
    fname = f'{fname_result}.pkl'
    result_dir = Path(cfg.path_results(dataset, dataset_type))
    result_file = Path(result_dir, fname)
    if os.path.exists(result_file):
        return

    # set random seeds
    random.seed(random_seed)
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)

    # load df metadata
    df = pd.read_csv(Path(cfg.path_data, dataset, 'exp_metadata',
                          f'metadata_{dataset}_{dataset_type}_{random_seed}_{init_train_samples}.csv'), index_col='index')

    # define number of iterations
    nr_iterations = int(1 + (max_train_samples - init_train_samples) / add_train_samples)

    # initialise dataframe: the 'subset' column becomes iteration 0 and every
    # further iteration gets an (initially empty) tag column
    df = df.rename(columns={'subset': cfg.get_iteration_col(0)})
    df_data = {col: df[col] for col in df.columns}
    for nr_iteration in range(nr_iterations):
        tag_col_iter = cfg.get_iteration_col(nr_iteration)
        if tag_col_iter not in df.columns:
            df_data[tag_col_iter] = pd.NA
    # add the extra 'all' column only for the 'complete' dataset type with frozen training
    if dataset_type == 'complete' and training == 'frozen':
        df_data[cfg.get_iteration_col('all')] = pd.NA
    df = pd.DataFrame(df_data)
    tag_columns = [col for col in df.columns if col.startswith('iteration_')]

    # load x_data and y_data
    x_data_original, y_data_original, df = load_data.do(df, dataset, weight_init, training, train_paradigm)

    # initialize loop variables
    model = None
    y_pred = None

    # initialize result variables
    y_pred_all = np.empty((len(tag_columns), y_data_original.shape[0], y_data_original.shape[1]))
    # query time is only measured for budget iterations, hence nr_iterations entries
    processing_time = np.empty(nr_iterations)
    processing_time_incl_training = np.empty(len(tag_columns))

    # active learning loop
    for index, tag_col in enumerate(tag_columns):
        # start time measure for the iteration
        iteration_al_start = time.time()

        # sample selection via query methods
        if cfg.get_iteration_col(index) == tag_col:
            # columns with budget
            df = active_learning_methods.sample(df, qm, index, model, y_pred, y_data_original, x_data_original,
                                                add_train_samples)
            iteration_al_end = time.time()
            processing_time[index] = iteration_al_end - iteration_al_start
        else:
            # handle last 'all' column (set all not evaluation samples to training)
            df[tag_col] = df[cfg.get_iteration_col(0)]
            df.loc[df[tag_col] != cfg.tag_evaluate, tag_col] = cfg.tag_train

        # status display
        nr_training_samples = len(df[(df[tag_col] == cfg.tag_train)])
        with print_lock:
            print(f'{fname_result}: Training Samples: {nr_training_samples} / {max_train_samples}')

        # train model
        model = train_model.do(df, x_data_original, y_data_original, tag_col, args)

        # save time
        iteration_al_incl_training_end = time.time()
        processing_time_incl_training[index] = iteration_al_incl_training_end - iteration_al_start

        # batchwise for large data, as tensor allocation is a problem
        if dataset in ['mscoco', 'cifar10'] and x_data_original.ndim > 3:
            batch_size = 256
            y_pred_list = []

            for i in range(0, len(x_data_original), batch_size):
                with print_lock:
                    print(f'{fname_result}: Training Samples: {nr_training_samples} / {max_train_samples}, '
                          f'process samples {i} / {len(x_data_original)}')
                # preprocess_input may overwrite its input when the dtype is
                # compatible; work on a copy so x_data_original is not
                # preprocessed again in every iteration
                batch = preprocess_input(np.copy(x_data_original[i:i + batch_size]))
                # convert right away so only one batch of prediction tensors is kept alive
                y_pred_list.append(model(batch).numpy())
            y_pred = np.concatenate(y_pred_list, axis=0)
        else:
            # predict on all samples in a single call
            y_pred = model(x_data_original)
            y_pred = y_pred.numpy()

        # store the predictions of this iteration
        y_pred_all[index, :, :] = y_pred

    # compute and save results
    result_ = result.Result(y_data_original, y_pred_all, df, qm, random_seed, processing_time,
                            processing_time_incl_training, dataset, init_train_samples, add_train_samples,
                            max_train_samples)

    result_dir.mkdir(parents=True, exist_ok=True)
    # write to a temporary file first: the final file doubles as the "already
    # done" marker, so it must never exist in a partially written state
    tmp_file = Path(result_dir, f'{fname}.tmp')
    with open(tmp_file, 'wb') as f:
        pickle.dump(result_, f)
    os.replace(tmp_file, result_file)
