import logging
from logging import getLogger
import torch
import numpy as np
from causally.utils.arguments import lib_models
from causally.utils.utils import get_model, get_trainer, init_seed,get_function
from causally.utils.logger import init_logger
from causally.utils.logger import set_color
from causally.utils.utils import create_dataset,data_preparation


def run_casual(config=None):
    """Train and evaluate a causal-effect model described by ``config``.

    If ``config['model']`` is listed in ``lib_models``, the whole run is
    delegated to the function returned by ``get_function`` and the second
    element of its return value is returned as-is.

    Otherwise a dataset is built, a model/trainer pair is created and fitted,
    and the trainer is evaluated once per order in the inclusive range
    ``config['start_order'] .. config['end_order']``. The loop stops after the
    first iteration when ``config['stop']`` is truthy.

    Args:
        config: Mapping-like configuration object. Keys read here: ``seed``,
            ``reproducibility``, ``model``, ``start_order``, ``end_order``,
            ``device``, ``stop``, ``treated_ratio`` and ``control_ratio``.
            NOTE: ``config['start_order']`` is incremented in place, so the
            caller's config object is modified.

    Returns:
        For non-library models, a dict ``{'pehe_in': [...], 'pehe_out': [...]}``
        holding the ``pehe_in`` / ``pehe_out`` test metrics of every iteration
        in which both values are strictly positive. For library models,
        whatever the delegated function returns as its second value.

    Raises:
        TypeError: if ``config`` is None.
    """
    if config is None:
        raise TypeError('run_casual() requires a config object, got None')

    # configurations initialization
    init_seed(config['seed'], config['reproducibility'])
    # logger initialization
    init_logger(config)
    logger = getLogger()
    results = {'pehe_in': [], 'pehe_out': []}

    if config['model'] in lib_models:
        func = get_function(config['model'])
        _, results = func(config)
        return results

    treated_ratios = []
    control_ratios = []
    # The current order lives in config itself (not a local counter) because
    # config is what create_dataset / data_preparation receive; presumably
    # they read 'start_order' to pick the data split — TODO confirm.
    while config['start_order'] <= config['end_order']:
        logger.info('[{}-{}]'.format(config['model'], config['start_order']))
        logger.info(config)
        dataset = create_dataset(config)
        logger.info(dataset)

        (train_data, valid_data, test_treated_data, test_control_data,
         train_treated_data, train_control_data, _test_data) = data_preparation(config, dataset)

        model = get_model(config['model'])(config, train_data).to(config['device'])
        logger.info(model)

        trainer = get_trainer(config['model'])(config, model)
        best_valid_score = trainer.fit(train_data, valid_data)

        test_result = trainer.evaluate(test_treated_data, test_control_data,
                                       train_treated_data, train_control_data)
        logger.info(set_color('best valid ', 'yellow') + f': {best_valid_score}')
        logger.info(set_color('test result', 'yellow') + f': {test_result}')

        # Both metrics must be strictly positive to be recorded. Written out
        # explicitly: `a and b > 0` would only compare b against 0.
        if test_result['pehe_out'] > 0 and test_result['pehe_in'] > 0:
            results['pehe_in'].append(test_result['pehe_in'])
            results['pehe_out'].append(test_result['pehe_out'])

        # These ratios are read back from config after data preparation;
        # presumably set by create_dataset / data_preparation — TODO confirm.
        treated_ratios.append(config['treated_ratio'])
        control_ratios.append(config['control_ratio'])

        if config['stop']:
            break
        config['start_order'] += 1

    # Guard against an empty order range: np.mean([]) warns
    # ("Mean of empty slice"); the logged value is nan either way.
    mean_treated = np.mean(treated_ratios) if treated_ratios else float('nan')
    mean_control = np.mean(control_ratios) if control_ratios else float('nan')

    logger.info('\nTreated ratios:{}\nControl ratios:{}'.format(treated_ratios, control_ratios))
    logger.info('\nThe average treated ratio is: {}, control ratio: {}'
                ''.format(mean_treated, mean_control))

    return results
