import enum
import logging
import matplotlib.pyplot as plt
import mergedeep
import numpy as np
import os
import pandas as pd
import random
import torch
import yaml
from collections import OrderedDict
from collections.abc import Iterable
from pathlib import Path
from sklearn.model_selection import train_test_split
import csv


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class RunPhase(enum.Enum):
    """Enumeration with two members, ``TRAIN`` and ``TEST``.

    The member values are the strings ``'train'`` and ``'test'``.
    """

    TRAIN = "train"
    TEST = "test"


def set_gpu(params):
    """Select the visible GPU(s) and record CUDA availability in ``params``.

    Sets the ``CUDA_DEVICE_ORDER`` and ``CUDA_VISIBLE_DEVICES`` environment
    variables, then stores ``torch.cuda.is_available()`` under
    ``params['cuda']`` and logs it.

    Args:
        params: Mutable mapping. ``params['device_config']['gpu']`` is the
            device selection; it may be a string (``"0"``, ``"0,1"``), an
            int (``0``) or a list/tuple of those. ``params['cuda']`` is
            written by this function.

    Raises:
        KeyError: If ``params['device_config']['gpu']`` is missing.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

    gpu = params['device_config']['gpu']
    # os.environ only accepts str values; a YAML entry such as ``gpu: 0`` is
    # parsed as an int and ``gpu: [0, 1]`` as a list, so normalise both.
    if isinstance(gpu, (list, tuple)):
        gpu = ','.join(str(device) for device in gpu)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

    params['cuda'] = torch.cuda.is_available()
    is_cuda = params['cuda']
    logger.info(f'cuda is available?: {is_cuda}')


def load_yaml(filename: str):
    """Parse the YAML file at ``filename`` and return the loaded object."""
    # NOTE(review): FullLoader resolves more tags than the safe loader;
    # consider yaml.safe_load if these files can come from untrusted sources.
    with open(filename) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


def yaml_load(yaml_path, override_yaml_path):
    """Load a YAML configuration, optionally deep-merged with an override file.

    Args:
        yaml_path: Path of the base configuration file; must exist.
        override_yaml_path: Path of an override file, or a falsy value to
            skip overriding. When given, it must exist and its content is
            merged into the base configuration with ``mergedeep.merge``.

    Returns:
        The loaded (and possibly merged) configuration.

    Raises:
        AssertionError: If a required file does not exist.
    """
    assert os.path.isfile(yaml_path), f"No yaml configuration file found at {yaml_path}"
    base_config = load_yaml(yaml_path)

    # Nothing to merge: hand back the base configuration untouched.
    if not override_yaml_path:
        return base_config

    assert os.path.isfile(override_yaml_path), f"No yaml configuration file found at {override_yaml_path}"
    overrides = load_yaml(override_yaml_path)
    return mergedeep.merge(base_config, overrides)


def set_seed(params):
    """Seed Python, NumPy and PyTorch RNGs for reproducible experiments.

    Args:
        params: Mapping providing ``params['config']['seed']`` (the seed) and
            ``params['cuda']`` (truthy when the CUDA generator should be
            seeded as well).
    """
    seed = params['config']['seed']

    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)

    use_cuda = params['cuda']
    if use_cuda:
        torch.cuda.manual_seed(seed)


def make_save_dir(directory):
    """Create ``directory`` (including missing parents) if it does not exist.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race when several processes
    # start at the same time and try to create the same directory.
    os.makedirs(directory, exist_ok=True)


def save_checkpoint(state, checkpoint, filename):
    """Serialize ``state`` with ``torch.save`` to ``<checkpoint>/<filename>``.

    Args:
        state: Object to serialize.
        checkpoint: Directory to save into; created (with any missing
            parents) when it does not exist.
        filename: Name of the file inside ``checkpoint``.
    """
    filepath = os.path.join(checkpoint, filename)
    # makedirs handles nested checkpoint paths (os.mkdir fails when the
    # parent is missing) and exist_ok tolerates concurrent creation.
    os.makedirs(checkpoint, exist_ok=True)
    torch.save(state, filepath)


def save_yaml(path, params):
    """Dump ``params`` as YAML into the file at ``path``, overwriting it."""
    with open(path, 'w') as handle:
        yaml.dump(params, stream=handle)


def save_df(params, outputs, labels, names, task, met, save_fname, test_save_dir):
    """Write names, labels and outputs to ``<test_save_dir>/<save_fname>.csv``.

    The CSV has a ``name`` column, one column per entry of
    ``params['preproc']['task']`` and one ``<entry>_prob`` column per entry,
    in that order.

    Args:
        params: Mapping; ``params['preproc']['task']`` is the list of label
            column names.
        outputs: Array concatenated column-wise after ``labels``; written
            under the ``*_prob`` columns. Expected to be 2-D with one column
            per task (TODO confirm against callers).
        labels: Array concatenated column-wise after ``names``; written
            under the label columns. Same shape expectation as ``outputs``.
        names: Array reshaped into a single column; written under ``name``.
        task: Unused; kept so existing callers keep working.
        met: Unused; kept so existing callers keep working.
        save_fname: Output file name without the ``.csv`` extension.
        test_save_dir: Output directory, created if missing.
    """
    # exist_ok avoids the exists-then-create race between processes.
    os.makedirs(test_save_dir, exist_ok=True)

    names_column = names.reshape([-1, 1])
    table = np.concatenate([names_column, labels, outputs], 1)

    label_cols = params['preproc']['task']
    prob_cols = [col + '_prob' for col in label_cols]
    frame = pd.DataFrame(table, columns=['name'] + label_cols + prob_cols)

    save_path = os.path.join(test_save_dir, save_fname)
    frame.to_csv(save_path + '.csv', index=False)


def wirte_log_csv(path, name, summary_dict, epoch="", itr="", val_set_name=""):
    """Append one row of metrics to the CSV log ``<path>/<name>``.

    A header row is written when the file does not exist yet. The misspelled
    function name is kept because it is the public interface.

    Args:
        path: Directory containing the log file.
        name: File name of the log.
        summary_dict: Mapping of column name to value for this row.
        epoch: Epoch number; together with ``itr`` it adds leading ``epoch``
            and ``itr`` columns. Leave as ``""`` to omit both columns.
        itr: Iteration number; see ``epoch``.
        val_set_name: Unused; kept so existing callers keep working.
    """
    file_path = os.path.join(path, name)

    # Compare against the "not provided" sentinels instead of relying on
    # truthiness, so that epoch 0 / iteration 0 still get their columns and
    # rows stay aligned with the header.
    has_step = epoch not in ("", None) and itr not in ("", None)
    if has_step:  # for validation log
        row = dict(epoch=epoch, itr=itr, **summary_dict)
    else:
        row = dict(summary_dict)

    needs_header = not os.path.exists(file_path)
    # newline='' is required by the csv module so that no extra blank lines
    # are produced on platforms that translate line endings.
    with open(file_path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(row))
        if needs_header:
            writer.writeheader()
        writer.writerow(row)


# https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_classification.py#L1287
def prf_divide(numerator, denominator, zero_divide_return=0.):
    """Divide ``numerator`` by ``denominator``, substituting a value where it is zero.

    Args:
        numerator: Scalar or array dividend.
        denominator: Scalar or array divisor. An array divisor is copied
            before being modified, so the caller's array is left untouched.
        zero_divide_return: Value placed wherever ``denominator == 0``.

    Returns:
        The quotient, with ``zero_divide_return`` at zero-denominator
        positions.
    """
    zero_mask = denominator == 0.

    if isinstance(zero_mask, Iterable):
        # Array case: divide by a patched copy (zeros replaced by 1) so the
        # division itself never hits a zero, then overwrite those slots.
        safe_denominator = denominator.copy()
        safe_denominator[zero_mask] = 1
        quotient = numerator / safe_denominator
        if np.any(zero_mask):
            quotient[zero_mask] = zero_divide_return
        return quotient

    # Scalar case: the comparison produced a single boolean.
    if zero_mask:
        return np.ones_like(numerator) * zero_divide_return
    return numerator / denominator
