import datetime
import importlib
import random

import numpy as np
import configparser
import os
import torch
from causally.data.dataloader import Dataset,TorchDataLoader,SklearnDataLoader
from causally.utils.arguments import Torch_models,SK_models,DF_models,recommender_models
from causally.model.qhte.qhte import qhte
from causally.model.ganite.ganite import ganite
from causally.data.recommender_dataloader import Recommender_Dataset,RecommenderDataLoader
def get_function(model_name):
    r"""Return the object imported for a function-style model.

    Args:
        model_name (str): either ``'QHTE'`` or ``'GANITE'``.

    Returns:
        The object imported as ``qhte`` (for ``'QHTE'``) or ``ganite``
        (for ``'GANITE'``).

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    if model_name == 'QHTE':
        return qhte
    if model_name == 'GANITE':
        return ganite
    # Raising a plain string is itself a TypeError in Python 3
    # ("exceptions must derive from BaseException"), which masked this message.
    raise ValueError('Not support models!!! (got %r)' % (model_name,))

def create_dataset(config):
    r"""Build the dataset object for ``config['model']``.

    Models listed in ``recommender_models`` get a ``Recommender_Dataset``;
    every other model gets a plain ``Dataset``. Both are constructed from
    ``config`` alone.
    """
    is_recommender = config['model'] in recommender_models
    dataset_cls = Recommender_Dataset if is_recommender else Dataset
    return dataset_cls(config)

def data_preparation(config, dataset):
    r"""Wrap the splits of ``dataset`` in data loaders matching the model family.

    The loader class is picked from ``config['model']``:
    ``TorchDataLoader`` for models in ``Torch_models`` or ``DF_models``,
    ``SklearnDataLoader`` for ``SK_models`` and ``RecommenderDataLoader`` for
    ``recommender_models``.

    Args:
        config: configuration mapping. Always reads ``'model'``. The torch and
            recommender branches also read ``'train_batch_size'`` and
            ``'eval_batch_size'``. The torch branch reads ``'robustness'`` and
            reads/writes ``'testing'``.
        dataset: dataset object (e.g. from ``create_dataset``) exposing the
            split attributes used below (``train``, ``val`` or
            ``val_treated``/``val_control``, ``test_treated``,
            ``test_control``, ``train_treated``, ``train_control``).

    Returns:
        tuple: ``(train_data, valid_data, test_treated_data,
        test_control_data, train_treated_data, train_control_data)``.
        For recommender models ``valid_data`` is a dict with ``'treated'``
        and ``'control'`` loaders, and the last two entries are ``None``.

    Raises:
        ValueError: if ``config['model']`` belongs to no supported model list.
    """
    train_treated_data, train_control_data = None, None
    model = config['model']

    if model in Torch_models or model in DF_models:
        # Only the training loader is shuffled. It is built before 'testing'
        # may be switched on below.
        train_data = TorchDataLoader(config=config, dataset=dataset.train,
                                     batch_size=config['train_batch_size'], shuffle=True)
        # With robustness enabled, all evaluation loaders below are constructed
        # while config['testing'] is True.
        if config['robustness']:
            config['testing'] = True

        valid_data = TorchDataLoader(config=config, dataset=dataset.val,
                                     batch_size=config['eval_batch_size'], shuffle=False)
        test_treated_data = TorchDataLoader(config=config, dataset=dataset.test_treated,
                                            batch_size=config['eval_batch_size'], shuffle=False)
        test_control_data = TorchDataLoader(config=config, dataset=dataset.test_control,
                                            batch_size=config['eval_batch_size'], shuffle=False)
        train_treated_data = TorchDataLoader(config=config, dataset=dataset.train_treated,
                                             batch_size=config['eval_batch_size'], shuffle=False)
        train_control_data = TorchDataLoader(config=config, dataset=dataset.train_control,
                                             batch_size=config['eval_batch_size'], shuffle=False)

        # Reset the flag once the loaders exist.
        # NOTE(review): this also clears a 'testing' value that was already True
        # before the call (even with robustness off) -- confirm that is intended.
        if config['testing']:
            config['testing'] = False

    elif model in SK_models:
        train_data = SklearnDataLoader(config, dataset.train)
        valid_data = SklearnDataLoader(config, dataset.val)
        test_treated_data = SklearnDataLoader(config, dataset.test_treated)
        test_control_data = SklearnDataLoader(config, dataset.test_control)
        train_treated_data = SklearnDataLoader(config, dataset.train_treated)
        train_control_data = SklearnDataLoader(config, dataset.train_control)

    elif model in recommender_models:
        train_data = RecommenderDataLoader(config=config, dataset=dataset.train,
                                           batch_size=config['train_batch_size'], shuffle=True)

        val_treated_data = RecommenderDataLoader(config=config, dataset=dataset.val_treated,
                                                 batch_size=config['eval_batch_size'], shuffle=False)
        val_control_data = RecommenderDataLoader(config=config, dataset=dataset.val_control,
                                                 batch_size=config['eval_batch_size'], shuffle=False)
        valid_data = {
            'treated': val_treated_data,
            'control': val_control_data
        }
        test_treated_data = RecommenderDataLoader(config=config, dataset=dataset.test_treated,
                                                  batch_size=config['eval_batch_size'], shuffle=False)
        test_control_data = RecommenderDataLoader(config=config, dataset=dataset.test_control,
                                                  batch_size=config['eval_batch_size'], shuffle=False)
        # train_treated_data / train_control_data stay None for recommender models.

    else:
        # Raising a plain string is itself a TypeError in Python 3, which
        # masked this message.
        raise ValueError('No supported model! (got %r)' % (model,))

    return (train_data, valid_data, test_treated_data, test_control_data,
            train_treated_data, train_control_data)


def getRootPath(marker='UITE'):
    r"""Return the project root, located by name within this module's directory path.

    The root is the prefix of this module's absolute directory path up to and
    including the first occurrence of ``marker``.

    Args:
        marker (str, optional): substring naming the root directory.
            Defaults to ``'UITE'``.

    Returns:
        str: the root path. If ``marker`` does not occur in the path, this
        module's own directory is returned instead.
    """
    curPath = os.path.abspath(os.path.dirname(__file__))
    idx = curPath.find(marker)
    if idx == -1:
        # str.find returns -1 when absent. Slicing with -1 + len(marker) would
        # silently produce a meaningless short prefix such as '/ho'.
        return curPath
    return curPath[:idx + len(marker)]

def get_local_time():
    r"""Format the current local time as a filesystem-friendly string.

    Returns:
        str: current local time, formatted like ``'Jan-05-2024_13-07-42'``.
    """
    timestamp_format = '%b-%d-%Y_%H-%M-%S'
    return datetime.datetime.now().strftime(timestamp_format)


def ensure_dir(dir_path):
    r"""Make sure the directory exists, creating it and any missing parents if needed.

    Args:
        dir_path (str): directory path

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory.
    """
    # exist_ok=True replaces the check-then-create sequence, which could fail
    # when another process created the directory between the two calls.
    os.makedirs(dir_path, exist_ok=True)


def get_model(model_name):
    r"""Import and return the model class named ``model_name``.

    The class is looked up as attribute ``model_name`` of the module
    ``...model.<model_name>``. If that module is not found, the module
    ``...model.recommender.<model_name>`` is used instead. The leading
    ``...`` is resolved relative to this module's ``__name__``.

    Args:
        model_name (str): name of both the model module and the class in it.

    Returns:
        type: the attribute ``model_name`` of the imported module.

    Raises:
        ImportError: if the fallback recommender module cannot be imported.
        AttributeError: if the imported module has no attribute ``model_name``.
    """
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule is loaded, so import it explicitly before using find_spec.
    import importlib.util

    module_path = '.'.join(['...model', model_name])
    if importlib.util.find_spec(module_path, __name__):
        model_module = importlib.import_module(module_path, __name__)
    else:
        module_path = '.'.join(['...model.recommender', model_name])
        model_module = importlib.import_module(module_path, __name__)
    return getattr(model_module, model_name)


def get_trainer(model):
    r"""Return the trainer class from ``causally.trainer`` for ``model``.

    Args:
        model (str): model name, looked up in ``Torch_models``, ``SK_models``,
            ``DF_models`` and ``recommender_models`` in that order.

    Returns:
        type: ``TorchTrainer``, ``SKTrainer``, ``DFTrainer`` or
        ``RecommenderTrainer`` respectively.

    Raises:
        ValueError: if ``model`` is in none of the model lists.
    """
    if model in Torch_models:
        trainer_name = 'TorchTrainer'
    elif model in SK_models:
        trainer_name = 'SKTrainer'
    elif model in DF_models:
        trainer_name = 'DFTrainer'
    elif model in recommender_models:
        trainer_name = 'RecommenderTrainer'
    else:
        # Raising a plain string is itself a TypeError in Python 3, which
        # masked this message.
        raise ValueError('Not support trainer! (got %r)' % (model,))
    return getattr(importlib.import_module('causally.trainer'), trainer_name)

def early_stopping(value, best, cur_step, max_step, bigger=False):
    r"""Update early-stopping state from one validation result.

    Args:
        value (float): result of the current validation step
        best (float): best result seen so far
        cur_step (int): consecutive steps so far without improving on ``best``
        max_step (int): stop once the non-improving streak exceeds this value
        bigger (bool, optional): ``True`` if larger results are better,
            ``False`` (default) if smaller results are better

    Returns:
        tuple:
        - float, the best result after this step
        - int, the non-improving streak after this step (0 on improvement)
        - bool, whether training should stop
        - bool, whether ``value`` became the new best
    """
    improved = (value > best) if bigger else (value < best)
    if improved:
        # A strict improvement resets the streak and never requests a stop.
        return value, 0, False, True

    streak = cur_step + 1
    return best, streak, bool(streak > max_step), False


def calculate_valid_score(valid_result, valid_metric=None):
    r"""Pick the score used for model selection from a validation result dict.

    Args:
        valid_result (dict): metric name -> value
        valid_metric (str, optional): metric to read. Any falsy value
            (``None``, ``''``) selects ``'Recall@10'``.

    Returns:
        The value stored under the chosen metric key.
    """
    metric_key = valid_metric or 'Recall@10'
    return valid_result[metric_key]


def dict2str(result_dict):
    r"""Render a metric dict as one line of ``'name : value'`` entries.

    Each value is formatted with four decimals, and every entry (including the
    last) is followed by four spaces. An empty dict yields ``''``.
    """
    entries = [
        str(metric) + ' : ' + '%.04f' % value + '    '
        for metric, value in result_dict.items()
    ]
    return ''.join(entries)


def init_seed(seed, reproducibility):
    r"""Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and set cuDNN flags.

    Args:
        seed (int): seed passed to every RNG
        reproducibility (bool): if truthy, disable cuDNN benchmarking and
            request deterministic cuDNN kernels. Otherwise enable
            benchmarking and allow non-deterministic kernels.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    deterministic = bool(reproducibility)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = not deterministic
    cudnn.deterministic = deterministic
