import datetime
import importlib
import random

import numpy as np
import configparser
import os
import torch
from causally.data.dataloader import Dataset,TorchDataLoader,SklearnDataLoader
from causally.utils.arguments import Torch_models

def get_function(model_name):
    r"""Return the object registered for ``model_name``.

    Args:
        model_name (str): either ``'QHTE'`` or ``'GANITE'``.

    Returns:
        The object bound to the name ``qhte`` (for ``'QHTE'``) or
        ``ganite`` (for ``'GANITE'``).

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    # NOTE(review): `qhte` and `ganite` are not imported in this file as seen;
    # unless they are defined elsewhere, the supported branches raise NameError.
    # They are only looked up inside the matching branch, so an unsupported
    # name still reaches the ValueError below.
    if model_name == 'QHTE':
        return qhte
    elif model_name == 'GANITE':
        return ganite
    # Raising a bare string is itself a TypeError in Python 3, which hid the
    # intended message; raise a proper exception instead.
    raise ValueError(f'Not support models!!! Got: {model_name!r}')

def create_dataset(config):
    r"""Build the project ``Dataset`` for ``config``.

    Args:
        config: configuration object, passed unchanged to ``Dataset``.

    Returns:
        Dataset: the newly constructed dataset.
    """
    dataset = Dataset(config)
    return dataset

def data_preparation(config, dataset):
    r"""Wrap the splits of ``dataset`` in ``TorchDataLoader`` objects.

    Args:
        config: configuration mapping; ``'train_batch_size'`` and
            ``'eval_batch_size'`` are read from it and it is also forwarded
            to every loader.
        dataset: object exposing the ``train``, ``val_treated``,
            ``val_control``, ``val``, ``test_treated`` and ``test_control``
            splits.

    Returns:
        tuple: ``(train_data, valid_data, my_val_data, test_treated_data,
        test_control_data)``, where ``valid_data`` is a dict with keys
        ``'treated'`` and ``'control'``. Only the training loader shuffles.
    """
    def build(split, batch_size, shuffle):
        # One place for the keyword arguments shared by every loader.
        return TorchDataLoader(config=config, dataset=split,
                               batch_size=batch_size, shuffle=shuffle)

    train_data = build(dataset.train, config['train_batch_size'], True)

    eval_batch_size = config['eval_batch_size']
    treated_loader = build(dataset.val_treated, eval_batch_size, False)
    control_loader = build(dataset.val_control, eval_batch_size, False)
    valid_data = {'treated': treated_loader, 'control': control_loader}
    my_val_data = build(dataset.val, eval_batch_size, False)

    # Each test split is served as a single batch covering the whole split.
    treated_split = dataset.test_treated
    test_treated_data = build(treated_split, len(treated_split), False)
    control_split = dataset.test_control
    test_control_data = build(control_split, len(control_split), False)

    return train_data, valid_data, my_val_data, test_treated_data, test_control_data


def getRootPath(start_path=None, marker='UITE'):
    r"""Return the project root directory.

    The root is the path up to and including the first (outermost) path
    component that is exactly ``marker``.

    Args:
        start_path (str, optional): directory to search from; defaults to the
            directory containing this file.
        marker (str, optional): name of the root directory component.

    Returns:
        str: absolute path of the root directory.

    Raises:
        ValueError: if no component of the path equals ``marker``.
    """
    if start_path is None:
        start_path = os.path.dirname(__file__)
    cur_path = os.path.abspath(start_path)
    # Match whole path components. The previous substring search returned a
    # truncated, non-existent path for e.g. '.../UITE-main/...', and the first
    # three characters of the path when the marker was absent (find() == -1).
    parts = cur_path.split(os.sep)
    if marker not in parts:
        raise ValueError(
            f'Cannot locate project root: no directory named {marker!r} in {cur_path!r}'
        )
    return os.sep.join(parts[:parts.index(marker) + 1])

def get_local_time():
    r"""Return the current local time as a timestamp string.

    The format is ``'%b-%d-%Y_%H-%M-%S'`` (e.g. ``'Jan-05-2024_13-07-42'``;
    the month abbreviation follows the current locale).

    Returns:
        str: current time
    """
    timestamp_format = '%b-%d-%Y_%H-%M-%S'
    return datetime.datetime.now().strftime(timestamp_format)


def ensure_dir(dir_path):
    r"""Make sure the directory exists, creating it (and any parents) if needed.

    Args:
        dir_path (str): directory path

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of
    # `os.path.exists` + `os.makedirs` (e.g. several processes creating the
    # same output directory at once).
    os.makedirs(dir_path, exist_ok=True)


def get_model(model_name):
    r"""Import and return the attribute ``model_name`` from its model module.

    The module ``...model.<model_name>`` is tried first. If no spec is found
    for it, ``...model.recommender.<model_name>`` is used instead. Both
    relative names are resolved against this module's ``__name__``.

    Args:
        model_name (str): name of both the module and the attribute inside it.

    Returns:
        The attribute ``model_name`` of the imported module (presumably the
        model class).

    Raises:
        ImportError: if the relative name cannot be resolved or the module
            cannot be imported.
        AttributeError: if the module does not define ``model_name``.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import `find_spec` explicitly.
    from importlib.util import find_spec

    module_path = '.'.join(['...model', model_name])
    if find_spec(module_path, __name__) is None:
        # NOTE(review): `model.recommender` may be a leftover from another
        # project layout; confirm that package still exists.
        module_path = '.'.join(['...model.recommender', model_name])
    model_module = importlib.import_module(module_path, __name__)
    return getattr(model_module, model_name)


def get_trainer(trainer_name):
    r"""Return the trainer named ``trainer_name``.

    The trainer is looked up as the attribute ``trainer_name`` of the module
    ``causally.trainer.<trainer_name>``.

    Args:
        trainer_name (str): module name and attribute name of the trainer.

    Raises:
        ImportError: if ``causally.trainer.<trainer_name>`` cannot be imported.
        AttributeError: if that module does not define ``trainer_name``.
    """
    trainer_module = importlib.import_module('causally.trainer.' + trainer_name)
    return getattr(trainer_module, trainer_name)

def early_stopping(value, best, cur_step, max_step, bigger=True):
    r"""Validation-based early stopping.

    A step counts as an improvement only when ``value`` is strictly better
    than ``best`` (greater if ``bigger``, smaller otherwise). An improvement
    resets the counter. Any other step increments it, and stopping is
    signalled once the counter exceeds ``max_step``.

    Args:
        value (float): current result
        best (float): best result so far
        cur_step (int): consecutive steps without improvement so far
        max_step (int): threshold steps for stopping
        bigger (bool, optional): whether a larger value is better

    Returns:
        tuple: ``(best, cur_step, stop_flag, update_flag)``. These are the
        best result after this step, the updated counter, whether to stop,
        and whether ``best`` was updated.
    """
    improved = (value > best) if bigger else (value < best)
    if improved:
        return value, 0, False, True

    cur_step += 1
    should_stop = bool(cur_step > max_step)
    return best, cur_step, should_stop, False


def calculate_valid_score(valid_result, valid_metric=None):
    r"""Select the validation score from ``valid_result``.

    Args:
        valid_result (dict): mapping of metric name to value.
        valid_metric (str, optional): key to read. When falsy (``None``,
            ``''``), the ``'Recall@10'`` entry is used instead.

    Returns:
        The value stored under the selected key.
    """
    # NOTE(review): 'Recall@10' looks like a recommender-system default;
    # confirm it is a metric this project actually produces.
    key = valid_metric or 'Recall@10'
    return valid_result[key]


def dict2str(result_dict):
    r"""Render ``result_dict`` as a single line of ``metric : value`` pairs.

    Each value is formatted with four decimal places. Every pair, including
    the last, is followed by four spaces.

    Args:
        result_dict (dict): mapping of metric name to numeric value.

    Returns:
        str: the concatenated pairs (``''`` for an empty dict).
    """
    return ''.join('%s : %.04f    ' % (metric, value)
                   for metric, value in result_dict.items())


def init_seed(seed, reproducibility):
    r"""Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed (int): seed passed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed``, ``torch.cuda.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
        reproducibility (bool): if true, disable cuDNN benchmarking and
            request deterministic cuDNN kernels. Otherwise, enable
            benchmarking and allow non-deterministic kernels.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

    deterministic = bool(reproducibility)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = not deterministic
    cudnn.deterministic = deterministic
