import logging
from logging import getLogger
from numpy.testing import measure
from pandas._libs.tslibs import add_overflowsafe
import torch
import numpy as np

from causally.trainer.CARD_trainer import CARD_trainer
from causally.utils.utils import get_model, get_trainer, init_seed,get_function
from causally.utils.logger import init_logger
from causally.utils.logger import set_color
from causally.utils.utils import create_dataset,data_preparation
import causally.start.autoencoder as ae
import causally.start.TabDDPMdiff as TabDiff
from causally.start.TabDDPMdiff import MLPDiffusion
import causally.start.diffusion as diff
from causally.utils.arguments import Torch_models
from causally.start.autoencoder import DeapStack
from torch.utils.tensorboard import SummaryWriter

def _init_ratio_results(ratios):
    """Create an empty result store for one robustness sweep.

    Args:
        ratios: Perturbation levels of the sweep.

    Returns:
        dict mapping ``'ate_<r>'`` and ``'pehe_<r>'`` (for every ``r`` in
        *ratios*, in order) to a fresh empty list.
    """
    store = {}
    for r in ratios:
        store['ate_' + str(r)] = []
        store['pehe_' + str(r)] = []
    return store


def _append_ratio_results(store, test_result, ratios):
    """Append one run's sweep metrics to *store*, cast to ``float``.

    Args:
        store: Dict built by ``_init_ratio_results`` for the same *ratios*.
        test_result: Mapping returned by one of the trainer's
            ``adversarial_*_evaluate`` methods; must provide an ``'ate_<r>'``
            and a ``'pehe_<r>'`` entry for every ratio ``r``.
        ratios: The perturbation levels used to build *store*.
    """
    for r in ratios:
        # 'ate_' before 'pehe_' for each ratio, matching the store's key order.
        for prefix in ('ate_', 'pehe_'):
            key = prefix + str(r)
            store[key].append(float(test_result[key]))


def _build_autodiff_components(config, dataset, writer):
    """Prepare the autoencoder and diffusion pieces consumed by ``CARD_trainer``.

    There are two modes:

    * ``config['pretrain_then_optimize'] == True``: train an autoencoder on the
      training covariates, then train a diffusion score model on its latent
      features. This also sets ``config['latent_features_shape1']``.
    * Otherwise: restore the autoencoder decoder, latent features and score
      network from ``pretrain/ACIC-StandardScalerX_<start_order>.pth``.

    Args:
        config: Run configuration (read for ``autodiff_*`` hyper-parameters,
            ``n_steps``, ``device``, ``start_order``, ``pretrain_then_optimize``).
        dataset: Dataset object exposing a ``train`` DataFrame.
        writer: Optional summary writer forwarded to the training helpers.

    Returns:
        tuple ``(score, ds, real_df)``:

        * ``score``: the diffusion score model.
        * ``ds``: the autoencoder bundle in the form the training helpers and
          checkpoint produce.
        * ``real_df``: the training covariate frame.
    """
    threshold = config['autodiff_threshold']

    print(dataset.train.columns)
    # NOTE(review): assumes the first 6 columns of dataset.train are not
    # covariates (e.g. treatment/outcome columns) — confirm against the dataset.
    real_df = dataset.train.iloc[:, 6:]

    # Strict `== True` comparison kept on purpose: truthy non-bool values
    # (e.g. strings) must keep taking the checkpoint path as before.
    if config['pretrain_then_optimize'] == True:  # noqa: E712
        lr = config['autodiff_lr']
        weight_decay = config['autodiff_weight_decay']
        batch_size = config['autodiff_batch_size']

        ae_n_epochs = config['autodiff_ae_n_epochs']
        hidden_size = config['autodiff_ae_hidden_size']
        num_layers = config['autodiff_ae_num_layers']

        diff_n_epochs = config['autodiff_diff_n_epochs']
        eps = config['autodiff_eps']
        sigma = config['autodiff_sigma']
        num_batches_per_epoch = config['autodiff_num_batches_per_epoch']
        maximum_learning_rate = config['autodiff_maximum_learning_rate']
        n_steps = config['n_steps']

        ds = ae.train_autoencoder(real_df, hidden_size, num_layers, lr, weight_decay,
                                  ae_n_epochs, batch_size, threshold, config, writer)
        latent_features = ds[1].detach()
        config['latent_features_shape1'] = latent_features.shape[1]
        score = TabDiff.train_diffusion(latent_features, n_steps, eps, sigma, lr,
                                        num_batches_per_epoch, maximum_learning_rate,
                                        weight_decay, diff_n_epochs, batch_size,
                                        config, writer)
        return score, ds, real_df

    # Restore everything from the per-order checkpoint.
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects; only load checkpoints from a trusted source.
    loaded_checkpoint = torch.load(
        f'pretrain/ACIC-StandardScalerX_{config["start_order"]}.pth',
        map_location=torch.device(config["device"]),
        weights_only=False,
    )

    # Rebuild the autoencoder architecture from the data, then load its weights.
    (n_bins, n_cats, n_nums, cards, d_in,
     hidden_size, bottleneck_size, num_layers) = ae.get_ae_recover_param(real_df, config, threshold)
    deap_stack = DeapStack(n_bins, n_cats, n_nums, cards, d_in, hidden_size,
                           bottleneck_size, num_layers).to(config['device'])
    deap_stack.load_state_dict(loaded_checkpoint['DS_decoder'])
    ds = (deap_stack.decoder,
          loaded_checkpoint['latent_features'],
          loaded_checkpoint['num_min_values'],
          loaded_checkpoint['num_max_values'],
          deap_stack.featurize,
          loaded_checkpoint['parser'])

    latent_features = loaded_checkpoint['latent_features']
    print('fineture latent feature shape = ', latent_features.shape)
    latent_dim = latent_features.shape[1]
    rtdl_params = {
        'd_in': latent_dim,
        'd_layers': [256, 256],
        'dropout': 0.0,
        'd_out': latent_dim,
    }
    score = MLPDiffusion(latent_dim, rtdl_params).to(config['device'])
    # Set to False to skip restoring the score-network weights from the checkpoint.
    is_load_dm_checkpoint = True
    if is_load_dm_checkpoint:
        score.load_state_dict(loaded_checkpoint['score'])

    if config['device'] == 'cpu':
        score = score.to(config['device'])
    else:
        # Any non-CPU device (intended for CUDA) wraps the model in DataParallel.
        score = torch.nn.DataParallel(score).to(config['device'])
    return score, ds, real_df


def run_casual(config=None):
    """Train and evaluate a causal-effect model for a range of dataset orders.

    The function iterates ``config['start_order']`` up to
    ``config['end_order']`` (inclusive). It stops early after the first run if
    ``config['stop']`` is truthy. Each run does the following:

    * re-initialises logging;
    * builds the dataset and the model;
    * trains, using ``CARD_trainer`` with autoencoder/diffusion components when
      ``config['auto_diff']`` is truthy, otherwise the trainer named by
      ``config['trainer']``;
    * runs a normal evaluation plus measurement-noise, missing-data and
      hidden-confounder sweeps.

    Args:
        config: Run configuration. Required despite the ``None`` default.
            Mutated in place: ``logfilename`` and ``start_order`` are updated,
            and ``latent_features_shape1`` may be set.

    Returns:
        tuple ``(results, measure_results, missing_results, hidden_results)``:

        * ``results`` maps ``'ate'``/``'pehe'`` to one float per run.
        * Each sweep dict maps ``'ate_<r>'``/``'pehe_<r>'`` to one float per run,
          for every ratio ``r`` of that sweep.

    Raises:
        TypeError: if ``config`` is ``None``.
        ValueError: if ``config['model']`` is not in ``Torch_models``.
    """
    if config is None:
        raise TypeError('run_casual() requires a config, got None')

    init_seed(config['seed'], config['reproducibility'])

    # Perturbation levels for the three robustness sweeps.
    measure_ratio = [0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
    missing_ratio = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
    hidden_ratio = [1, 5, 10, 15, 20, 25, 30]

    results = {'ate': [], 'pehe': []}
    measure_results = _init_ratio_results(measure_ratio)
    missing_results = _init_ratio_results(missing_ratio)
    hidden_results = _init_ratio_results(hidden_ratio)

    treated_ratios = []
    control_ratios = []
    logger = None  # set on the first iteration; stays None if nothing runs
    while config['start_order'] <= config['end_order']:
        # A fresh log file / logger per dataset order.
        logfilename = init_logger(config)
        config['logfilename'] = logfilename
        logger = getLogger(logfilename)
        logger.info(config)
        dataset = create_dataset(config)
        logger.info(dataset)

        # No SummaryWriter is created here; None is forwarded to the trainers.
        writer = None

        if config['trainer'] == 'CARD_trainer':
            logger.info('[{}-startorder={}-{}-{}-{}]'.format('RL', config['start_order'], config['model'],
                                                             config['dataset'], config['v_rl_algo']))
        elif config['trainer'] == 'standard_trainer':
            logger.info('[{}-startorder={}-{}-{}]'.format('IL', config['start_order'], config['model'],
                                                          config['dataset']))

        (train_data, valid_data, my_val_data,
         test_treated_data, test_control_data) = data_preparation(config, dataset)

        if config['model'] not in Torch_models:
            logger.info('model not in lib_models')
            # Previously this fell through to an UnboundLocalError on `model`.
            raise ValueError(f"model {config['model']!r} is not in Torch_models")
        model = get_model(config['model'])(config, train_data).to(config['device'])
        logger.info(model)

        if config['auto_diff']:
            score, ds, real_df = _build_autodiff_components(config, dataset, writer)
            trainer = CARD_trainer(config, model, score, ds, real_df, writer)
            best_valid_score = trainer.fit(train_data, valid_data, my_val_data,
                                           test_treated_data, test_control_data)
            normal_test_result = trainer.evaluate(test_treated_data, test_control_data)
            measure_test_result = trainer.adversarial_measure_evaluate(test_treated_data, test_control_data,
                                                                       measure_ratio)
            missing_test_result = trainer.adversarial_missing_evaluate(test_treated_data, test_control_data,
                                                                       missing_ratio)
            hidden_test_result = trainer.adversarial_hidden_evaluate(test_treated_data, test_control_data,
                                                                     hidden_ratio)
        else:  # standard trainer
            trainer = get_trainer(config['trainer'])(config, model)
            best_valid_score = trainer.fit(train_data, valid_data, my_val_data,
                                           test_treated_data, test_control_data)
            normal_test_result = trainer.evaluate(test_treated_data, test_control_data)
            # NOTE(review): sweep order differs from the auto_diff branch; kept
            # unchanged in case the evaluators share random state.
            hidden_test_result = trainer.adversarial_hidden_evaluate(test_treated_data, test_control_data,
                                                                     hidden_ratio)
            measure_test_result = trainer.adversarial_measure_evaluate(test_treated_data, test_control_data,
                                                                       measure_ratio)
            missing_test_result = trainer.adversarial_missing_evaluate(test_treated_data, test_control_data,
                                                                       missing_ratio)

        logger.info(set_color('best valid ', 'yellow') + f': {best_valid_score}')
        logger.info(set_color('test result', 'yellow') + f': {normal_test_result}')
        logger.info(set_color('measure test result', 'yellow') + f': {measure_test_result}')
        logger.info(set_color('missing test result', 'yellow') + f': {missing_test_result}')
        logger.info(set_color('hidden test result', 'yellow') + f': {hidden_test_result}')

        results['ate'].append(float(normal_test_result['ate']))
        results['pehe'].append(float(normal_test_result['pehe']))
        _append_ratio_results(measure_results, measure_test_result, measure_ratio)
        _append_ratio_results(missing_results, missing_test_result, missing_ratio)
        _append_ratio_results(hidden_results, hidden_test_result, hidden_ratio)

        # presumably populated on config during dataset preparation — confirm.
        treated_ratios.append(config['treated_ratio'])
        control_ratios.append(config['control_ratio'])

        if config['stop']:
            break
        config['start_order'] += 1

    if logger is None:
        # start_order > end_order: nothing ran, no logger was configured and
        # there are no ratios to average.
        getLogger(__name__).warning('run_casual: start_order (%s) > end_order (%s); nothing was run',
                                    config['start_order'], config['end_order'])
        return results, measure_results, missing_results, hidden_results

    logger.info('\nTreated ratios:{}\nControl ratios:{}'.format(treated_ratios, control_ratios))
    logger.info('\nThe average treated ratio is: {}, control ratio: {}'
                ''.format(np.mean(treated_ratios), np.mean(control_ratios)))

    return results, measure_results, missing_results, hidden_results

