import logging
from logging import getLogger

import numpy as np
from causally.utils.arguments import lib_models
from causally.utils.utils import get_model, get_trainer, init_seed,get_function
from causally.utils.logger import init_logger
from causally.utils.logger import set_color
from causally.utils.utils import create_dataset,data_preparation


def _train_and_evaluate_order(config, logger):
    """Build, train and evaluate one model for the current ``config['start_order']``.

    Args:
        config: Run configuration. Read here for ``'model'``, ``'start_order'``
            and ``'device'``, and passed through to ``create_dataset``,
            ``data_preparation``, the model constructor and the trainer.
        logger: Logger used for progress output.

    Returns:
        Whatever ``trainer.evaluate`` returns; the caller indexes it by key
        (e.g. ``'in_pehe'``).
    """
    dataset = create_dataset(config)
    logger.info(dataset)
    logger.info('[%s-%s]', config['model'], config['start_order'])

    # dataset splitting
    (train_data, valid_data,
     test_treated_data, test_control_data,
     train_treated_data, train_control_data) = data_preparation(config, dataset)

    # model loading and initialization
    model = get_model(config['model'])(config, train_data).to(config['device'])
    logger.info(model)
    trainer = get_trainer(config['model'])(config, model)

    best_valid_score = trainer.fit(train_data, valid_data)
    test_result = trainer.evaluate(test_treated_data, test_control_data,
                                   train_treated_data, train_control_data)

    # set_color output is concatenated (not passed as a %-format string) so any
    # '%' it might contain cannot be misinterpreted by the logging formatter.
    logger.info(set_color('best valid ', 'yellow') + f': {best_valid_score}')
    logger.info(set_color('test result', 'yellow') + f': {test_result}')
    return test_result


def run_casual(config=None):
    """Run the experiment described by ``config`` and return its metrics.

    Seeds the run and initialises logging. If ``config['model']`` is in
    ``lib_models``, it dispatches to the function returned by ``get_function``
    and returns that function's result. Otherwise it trains and evaluates one
    model per dataset order, from ``config['start_order']`` up to
    ``config['end_order']`` inclusive.

    Args:
        config: Run configuration mapping. Keys read directly here: ``'seed'``,
            ``'reproducibility'``, ``'model'``, ``'start_order'``,
            ``'end_order'``, ``'dataset'``, ``'stop'``, ``'treated_ratio'``,
            ``'control_ratio'`` and ``'device'``. The project helpers may read
            more.

    Returns:
        For library models, the return value of the library function.
        Otherwise, a dict mapping each key of the per-order test result to a
        list of that key's values. Only orders whose ``'in_pehe'`` is > 0
        contribute to this dict.

    Raises:
        TypeError: If ``config`` is None.

    Note:
        ``config['start_order']`` is advanced in place while iterating, so the
        caller's mapping is modified.
    """
    if config is None:
        raise TypeError('run_casual requires a config mapping, got None')

    # configurations initialization
    init_seed(config['seed'], config['reproducibility'])
    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    if config['model'] in lib_models:
        # Library-backed models run their own training/evaluation pipeline.
        return get_function(config['model'])(config)

    results = {}
    treated_ratios = []
    control_ratios = []
    while config['start_order'] <= config['end_order']:
        if config['start_order'] == 4 and config['dataset'] == 'ACIC':
            # NOTE(review): ACIC order 4 is deliberately skipped; the reason is
            # not visible in this file — confirm with the dataset setup.
            config['start_order'] += 1
            continue

        test_result = _train_and_evaluate_order(config, logger)
        # NOTE(review): presumably drops degenerate runs (non-positive PEHE).
        if test_result['in_pehe'] > 0:
            for key in test_result:
                results.setdefault(key, []).append(test_result[key])

        # Presumably set on config by data_preparation for this order — confirm.
        treated_ratios.append(config['treated_ratio'])
        control_ratios.append(config['control_ratio'])

        # Stop conditions are checked before advancing, so start_order is left
        # at the last order that actually ran when one of them fires.
        if config['start_order'] >= 100:
            break
        if config['dataset'] == 'ACIC' and config['start_order'] >= 10:
            break
        if config['stop']:
            break
        config['start_order'] += 1

    logger.info('\nTreated ratios:%s\nControl ratios:%s',
                treated_ratios, control_ratios)
    if treated_ratios:
        logger.info('\nThe average treated ratio is: %s, control ratio: %s',
                    np.mean(treated_ratios), np.mean(control_ratios))
    else:
        # np.mean([]) would emit a RuntimeWarning and log nan; report clearly instead.
        logger.warning('\nNo dataset orders were run; average treated/control '
                       'ratios are undefined.')

    return results
