import os
import pickle
from datetime import datetime as dt

import numpy as np
import torch

from config import RESULTS_DIR
from src.experiments.cifar import CIFAR
from src.experiments.mnist import MNIST
from src.experiments.kather import KATHER
from src.experiments.derma import DERMA


def create_experiment(args_dict):
    """Build the experiment object selected by ``args_dict['dataset']``.

    Dispatch rules:
      * ``'mnist'``                      -> ``MNIST``
      * any name containing ``'cifar'``  -> ``CIFAR``
      * ``'kather'``                     -> ``KATHER``
      * ``'derma'``                      -> ``DERMA``

    Args:
        args_dict: Experiment configuration; must contain the key
            ``'dataset'``. The whole dict is forwarded to the selected
            experiment class.

    Returns:
        An instance of the matching experiment class.

    Raises:
        NotImplementedError: If the dataset name matches none of the rules.
    """
    dataset = args_dict['dataset']

    if dataset == 'mnist':
        return MNIST(args_dict)
    # Substring match so that every name containing 'cifar' maps to CIFAR.
    if 'cifar' in dataset:
        return CIFAR(args_dict)
    if dataset == 'kather':
        return KATHER(args_dict)
    if dataset == 'derma':
        return DERMA(args_dict)

    # Raise instead of returning the exception class: a returned class would
    # be mistaken for a valid experiment and only fail later, far from here.
    raise NotImplementedError(
        f"Unsupported dataset {dataset!r}. Please specify the dataset: "
        "'mnist', a name containing 'cifar', 'kather' or 'derma'."
    )


def create_folders(args):
    """Create the results folder tree for one experiment run.

    Ensures ``RESULTS_DIR`` exists, then creates a run folder inside it whose
    name is built from the current timestamp and several entries of ``args``,
    together with ``models/checkpoints`` and ``models/final`` sub-folders.

    Args:
        args: Mapping with the keys ``'loss'``, ``'adaptive'``,
            ``'num_classes'``, ``'classes'``, ``'depth'``, ``'dataset'`` and
            ``'model_name'``; additionally ``'p'`` and ``'mc_type'`` when the
            loss name contains ``'kde'``, and ``'loss_param'`` when the loss
            is neither ``'mse'`` nor ``'ce'``.

    Returns:
        Path of the run folder.
    """
    create_folder(RESULTS_DIR)

    loss = args['loss']

    # Extra suffix only for losses whose name contains 'kde'.
    kde_suffix = ""
    if 'kde' in loss:
        kde_suffix = f"_p{int(args['p'])}_{args['mc_type']}"

    # Loss parameter suffix for every loss other than 'mse' / 'ce'.
    param_suffix = ""
    if loss not in ('mse', 'ce'):
        param_suffix = f"_g{args['loss_param']}"

    # The adaptive marker takes the place of the loss parameter suffix.
    if args['adaptive']:
        loss_suffix = "_adpt"
    else:
        loss_suffix = param_suffix

    # The class selection is only encoded for 3-class runs.
    class_suffix = ""
    if args['num_classes'] == 3 and args['classes']:
        class_suffix = f"_{args['classes']}"

    depth_suffix = f"{args['depth']}" if args['depth'] else ""

    # A separate, timestamped folder for each time the experiment is run.
    stamp = dt.now().strftime('%Y-%m-%d_%H-%M-%S')
    run_name = "".join([
        f"{stamp}_{args['dataset']}{args['num_classes']}",
        f"_{args['model_name']}{depth_suffix}",
        f"_{loss}{loss_suffix}{kde_suffix}{class_suffix}",
    ])
    curr_dir = os.path.join(RESULTS_DIR, run_name)
    create_folder(curr_dir)

    models_dir = create_folder(os.path.join(curr_dir, 'models'))
    for sub_folder in ('checkpoints', 'final'):
        create_folder(os.path.join(models_dir, sub_folder))

    return curr_dir


def create_folder(path):
    """Create a single directory at ``path`` and report the outcome on stdout.

    Only the last path component is created (``os.mkdir``); missing parents
    are not. Failures are reported with a printed message rather than raised:
    both an already-existing path and any other ``OSError`` are swallowed.

    Args:
        path: Directory to create.

    Returns:
        ``path`` unchanged, whether or not the directory was created.
    """
    # Start from the success message and swap it out if mkdir fails.
    message = f"Successfully created the directory {path}."
    try:
        os.mkdir(path)
    except FileExistsError:
        message = f"Directory {path} already exists."
    except OSError:
        message = f"Creation of the directory {path} failed."

    print(message)
    return path


def load_files(path, folder, logits_name):
    """Load saved logits, targets and config from ``path/folder``.

    Reads ``<logits_name>.npy``, ``targets.npy`` and ``config.npy``. Scores
    are derived from the logits with a sigmoid when the second dimension has
    size 1, otherwise with a softmax over dimension 1.

    Args:
        path: Root directory.
        folder: Sub-folder of ``path`` holding the ``.npy`` files.
        logits_name: File name (without ``.npy``) of the logits array.

    Returns:
        Tuple ``(pred_logits, pred_scores, targets, config)``; the first
        three are tensors, ``config`` is the object stored in ``config.npy``.

    Raises:
        AssertionError: If any score lies outside ``[0, 1]``.
    """
    run_dir = os.path.join(path, folder)

    logits_array = np.load(os.path.join(run_dir, f'{logits_name}.npy'))
    pred_logits = torch.tensor(logits_array)
    if pred_logits.shape[1] == 1:
        # Single output column: element-wise sigmoid.
        pred_scores = torch.sigmoid(pred_logits)
    else:
        pred_scores = torch.softmax(pred_logits, dim=1)

    targets_array = np.load(os.path.join(run_dir, 'targets.npy'))
    targets = torch.tensor(targets_array)

    # config.npy holds a pickled Python object wrapped in a 0-d array;
    # allow_pickle executes pickle, so only load files from trusted runs.
    config_array = np.load(os.path.join(run_dir, 'config.npy'), allow_pickle=True)
    config = config_array.item()

    assert torch.min(pred_scores) >= 0
    assert torch.max(pred_scores) <= 1

    return pred_logits, pred_scores, targets, config


def load_pickle(path, **kwargs):
    """Read and return the object pickled in the file at ``path``.

    Args:
        path: File to read (opened in binary mode).
        **kwargs: Forwarded to ``pickle.load``.

    Returns:
        The unpickled object.

    Note:
        Unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle, **kwargs)
    return loaded


def dump_pickle(path, data, **kwargs):
    """Pickle ``data`` into the file at ``path``, overwriting any old content.

    Args:
        path: Destination file (opened in binary write mode).
        data: Object to serialise.
        **kwargs: Forwarded to ``pickle.dump``.
    """
    with open(path, 'wb') as handle:
        pickle.dump(data, handle, **kwargs)


def str_to_num(integers):
    """Concatenate the string form of every item in ``integers``.

    Args:
        integers: Iterable of values; each is converted with ``str``.

    Returns:
        A single string with the converted items joined without separator
        (an empty string for an empty iterable).
    """
    return "".join(map(str, integers))
