import logging
import os
from logging import getLogger

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

import causally.start.autoencoder as ae
import causally.start.diffusion as diff
import causally.start.TabDDPMdiff as TabDiff
from causally.trainer.CARD_trainer import CARD_trainer
from causally.utils.arguments import Torch_models
from causally.utils.logger import init_logger
from causally.utils.logger import set_color
from causally.utils.utils import create_dataset,data_preparation
from causally.utils.utils import get_model, get_trainer, init_seed,get_function

def run_pretrain(config=None, writer=None):
    """Pre-train the autoencoder and latent diffusion model for a range of orders.

    For every order from ``config['start_order']`` up to and including
    ``config['end_order']``, this trains an autoencoder on the training split,
    trains a diffusion model on the resulting latent features, and saves a
    checkpoint to ``pretrain/<dataset>_<order>.pth``.

    Args:
        config: Configuration object indexed with ``[]``. It is mutated in
            place: ``'latent_features_shape1'`` is written on every iteration
            and ``'start_order'`` is incremented until it exceeds
            ``'end_order'``.
        writer: Optional TensorBoard ``SummaryWriter`` forwarded to the
            training routines. If omitted, a new writer is created for this
            run and closed before the function returns.

    Raises:
        TypeError: If ``config`` is ``None``.
    """
    if config is None:
        raise TypeError("run_pretrain() requires a config; got None")

    # configurations initialization
    init_seed(config['seed'], config['reproducibility'])
    # logger initialization
    logfilename = init_logger(config)
    logger = getLogger('pretrain_' + logfilename)

    logger.info(config)
    dataset = create_dataset(config)
    logger.info(dataset)
    # Loop-invariant, so log it once instead of print()-ing it every iteration.
    logger.info(dataset.train.columns)

    # The training helpers need a writer. Only close the one we create here,
    # never a writer handed in by the caller.
    owns_writer = writer is None
    if owns_writer:
        writer = SummaryWriter()

    checkpoint_dir = 'pretrain'
    # torch.save does not create missing parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    try:
        while config['start_order'] <= config['end_order']:
            logger.info('[{}-{}-{}]'.format(config['model'], config['trainer'], config['start_order']))

            # Hyper-parameters are re-read on each iteration because `config` is
            # passed into the training helpers, which could modify it.
            lr = config['autodiff_lr']
            weight_decay = config['autodiff_weight_decay']
            batch_size = config['autodiff_batch_size']

            # Autoencoder settings.
            n_epochs = config['autodiff_ae_n_epochs']
            hidden_size = config['autodiff_ae_hidden_size']
            num_layers = config['autodiff_ae_num_layers']

            # Diffusion settings.
            diff_n_epochs = config['autodiff_diff_n_epochs']
            eps = config['autodiff_eps']
            sigma = config['autodiff_sigma']
            num_batches_per_epoch = config['autodiff_num_batches_per_epoch']
            maximum_learning_rate = config['autodiff_maximum_learning_rate']
            threshold = config['autodiff_threshold']
            T = config['n_steps']

            # NOTE(review): the first 6 columns are dropped before training;
            # presumably they are non-feature columns. Confirm against the dataset schema.
            real_df = dataset.train.iloc[:, 6:]

            ds = ae.train_autoencoder(real_df, hidden_size, num_layers, lr, weight_decay,
                                      n_epochs, batch_size, threshold, config, writer)

            # ds[1] holds the latent features (saved below as 'latent_features').
            # Detach them so diffusion training does not backprop into the autoencoder.
            latent_features = ds[1].detach()

            # Exposed through config, presumably so downstream code can size the
            # latent space. Verify against the consumers.
            config['latent_features_shape1'] = latent_features.shape[1]

            score = TabDiff.train_diffusion(latent_features, T, eps, sigma, lr,
                                            num_batches_per_epoch, maximum_learning_rate,
                                            weight_decay, diff_n_epochs, batch_size,
                                            config, writer)

            pretrain_checkpoint = {
                'DS_decoder': ds[4].state_dict(),
                'latent_features': ds[1],
                'num_min_values': ds[2],
                'num_max_values': ds[3],
                'score': score.state_dict(),
                'parser': ds[5],
            }

            checkpoint_path = os.path.join(
                checkpoint_dir, '{}_{}.pth'.format(config['dataset'], config['start_order']))
            torch.save(pretrain_checkpoint, checkpoint_path)
            config['start_order'] += 1
    finally:
        if owns_writer:
            writer.close()