import ast
import logging

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.loggers import TensorBoardLogger  
from src.models.utils import AlphaRise
import torch

# All tensors created without an explicit dtype use 64-bit floats.
torch.set_default_dtype(torch.float64)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Debugging aid: uncomment to enable autograd anomaly detection (slow).
# torch.autograd.set_detect_anomaly(True)

def _logging_callbacks(args):
    """Return the callbacks that are only added when experiment logging is on.

    Args:
        args: run configuration (DictConfig); reads ``args.exp.logging``.

    Returns:
        A new list holding a per-epoch ``LearningRateMonitor`` when logging is
        enabled, otherwise an empty list.
    """
    if args.exp.logging:
        return [LearningRateMonitor(logging_interval='epoch')]
    return []


def _make_trainer(args, tb_logger, callbacks, max_epochs):
    """Build a ``Trainer`` with the settings shared by every training stage.

    Args:
        args: run configuration (DictConfig); reads ``args.exp.gpus``.
        tb_logger: Lightning logger, or ``None`` when logging is disabled.
        callbacks: list of Lightning callbacks for this stage.
        max_epochs: epoch budget for ``fit`` (the zero-shot path passes ``0``).

    Returns:
        A configured ``pytorch_lightning.Trainer``.
    """
    return Trainer(
        # str() + literal_eval turns the config value (an int, or an OmegaConf
        # list such as [0]) into a plain Python int / list.
        gpus=ast.literal_eval(str(args.exp.gpus)),
        logger=tb_logger,
        max_epochs=max_epochs,
        callbacks=callbacks,
        terminate_on_nan=True,
        num_sanity_val_steps=0,
    )


@hydra.main(config_name='config.yaml', config_path='../config/')
def main(args):
    """Training / evaluation for Causal CPC.

    Stages, each switched by the Hydra config:
      1. build the dataset collection and derive the model input dimensions;
      2. optionally tune and/or pretrain the representation encoder
         (``args.model.rep_encoder``);
      3. train the context decoder when ``args.model.train_context_decoder``
         is set, otherwise train and evaluate the estimation head when
         ``args.model.train_head`` is set.

    Args:
        args: arguments of the run as a DictConfig (supplied by Hydra).

    Returns:
        dict mapping metric names to normalised test RMSEs (one- and
        multiple-step-ahead). It stays empty when only the context decoder is
        trained, because that stage records no metrics.
    """
    results = {}
    # Allow new keys (the model dimensions below) to be written into args.
    OmegaConf.set_struct(args, False)
    # `sum` interpolation resolver; replace=True tolerates re-registration.
    OmegaConf.register_new_resolver('sum', lambda x, y: x + y, replace=True)
    logger.info('%s', '\n' + OmegaConf.to_yaml(args, resolve=True))

    seed_everything(args.exp.seed)
    dataset_collection = instantiate(args.dataset, _recursive_=True)
    dataset_collection.process_data_rep_est()

    # Model input sizes come from the last axis of the training arrays.
    train_data = dataset_collection.train_f.data
    args.model.dim_outcomes = train_data['outputs'].shape[-1]
    args.model.dim_treatments = train_data['current_treatments'].shape[-1]
    dim_vitals = (
        train_data['vitals'].shape[-1] if dataset_collection.has_vitals else 0
    )
    if args.exp.test_robustness:
        # NOTE(review): robustness runs use two fewer vital channels — confirm
        # this matches what the dataset actually drops.
        dim_vitals -= 2
    args.model.dim_vitals = dim_vitals
    args.model.dim_static_features = train_data['static_features'].shape[-1]

    # Experiment logger shared by every trainer (None disables logging).
    if args.exp.logging:
        tb_logger = TensorBoardLogger(
            save_dir='.', name=f'{args.model.name}/{args.dataset.name}'
        )
    else:
        tb_logger = None

    # Representation encoder: optional hyper-parameter tuning and pretraining.
    rep_callbacks = _logging_callbacks(args) + [
        AlphaRise(rate=args.exp.alpha_rate),
        EarlyStopping(**args.exp.rep_encoder.early_stopping),
    ]
    logger.info("Initialisation & Training of rep")
    rep = instantiate(
        args.model.rep_encoder, args, dataset_collection, _recursive_=False
    )

    if args.model.rep_encoder.tune_hparams:
        rep.finetune(
            resources_per_trial=args.model.rep_encoder.resources_per_trial
        )

    if args.model.pretrain_rep_encoder:
        rep_trainer = _make_trainer(
            args, tb_logger, rep_callbacks, args.exp.max_epochs
        )
        rep_trainer.fit(rep)

    def train_eval_head(
        dataset_collection, init_weight=None, dataset_name='', zero_shot=False
    ):
        """Train the estimation head on top of ``rep`` and record test RMSEs.

        Args:
            dataset_collection: datasets used for training and evaluation.
            init_weight: optional state dict loaded into the head before fit.
            dataset_name: prefix for the head and for every metric key.
            zero_shot: if True, fit with ``max_epochs=0`` (no training epochs).

        Returns:
            The head module. Its test RMSEs are merged into ``results``.
        """
        head_callbacks = _logging_callbacks(args) + [
            AlphaRise(rate=args.exp.alpha_rate),
            EarlyStopping(**args.exp.est_head.early_stopping),
        ]

        logger.info("Instantiate est_head ")
        head = instantiate(
            args.model.est_head,
            args,
            rep,
            dataset_collection,
            prefix=dataset_name,
            _recursive_=False,
        )
        if init_weight is not None:
            head.load_state_dict(init_weight)

        max_epochs = 0 if zero_shot else args.exp.max_epochs
        _make_trainer(args, tb_logger, head_callbacks, max_epochs).fit(head)

        test_rmses = {}
        if hasattr(dataset_collection, 'test_cf_one_step'):
            test_rmses[f'{dataset_name}-encoder_test_rmse_last'] = (
                head.get_normalised_1_step_rmse_syn(
                    dataset_collection.test_cf_one_step,
                    prefix=f'{dataset_name}-test_cf_one_step',
                )
            )

        if hasattr(dataset_collection, 'test_cf_treatment_seq'):
            rmses = head.get_normalised_n_step_rmses_syn(
                dataset_collection.test_cf_treatment_seq,
                prefix=f'{dataset_name}-test_cf_treatment_seq',
            )
            # The loop must stay inside this branch: `rmses` only exists when
            # the split does. The first element is labelled as the 2-step RMSE.
            for step, rmse in enumerate(rmses, start=2):
                test_rmses[f'{dataset_name}-decoder_test_rmse_{step}-step'] = rmse

        if hasattr(dataset_collection, 'test_f'):
            rmses = head.get_normalised_n_step_rmses_real(
                dataset_collection.test_f, prefix=f'{dataset_name}-test_f'
            )
            for step, rmse in enumerate(rmses, start=1):
                test_rmses[f'{dataset_name}-decoder_test_rmse_{step}-step'] = rmse

        logger.info('Test normalised RMSE (n-step prediction): %s', test_rmses)
        results.update(test_rmses)
        return head

    def train_context_decoder(
        dataset_collection, init_weight=None, dataset_name='', zero_shot=False
    ):
        """Train the context decoder on top of ``rep``, then run validation.

        Args:
            dataset_collection: datasets used for training and validation.
            init_weight: optional state dict loaded into the decoder before fit.
            dataset_name: prefix passed to the decoder.
            zero_shot: if True, fit with ``max_epochs=0`` (no training epochs),
                matching ``train_eval_head``.

        Returns:
            The context decoder module. No metrics are added to ``results``.
        """
        ct_callbacks = _logging_callbacks(args) + [
            EarlyStopping(**args.exp.context_decoder.early_stopping),
        ]

        decoder = instantiate(
            args.model.context_decoder,
            args,
            rep,
            dataset_collection,
            prefix=dataset_name,
            _recursive_=False,
        )
        if init_weight is not None:
            decoder.load_state_dict(init_weight)

        max_epochs = 0 if zero_shot else args.exp.max_epochs
        trainer = _make_trainer(args, tb_logger, ct_callbacks, max_epochs)
        trainer.fit(decoder)
        trainer.validate(decoder)
        return decoder

    if args.model.train_context_decoder:
        train_context_decoder(
            dataset_collection, init_weight=None, dataset_name='src'
        )
    elif args.model.train_head:
        train_eval_head(dataset_collection, init_weight=None, dataset_name='src')
    return results


if __name__ == '__main__':
    # Called with no arguments: the @hydra.main decorator on `main` builds the
    # run configuration and supplies it as `args`.
    main()


