import logging
from pathlib import Path
import traceback

from dotenv import load_dotenv
import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
import wandb

from text_ood import EmbeddingType
from text_ood.ood import OODEvaluator
from text_ood.utils import set_seed
from text_ood.utils.dataset_util import load_dataset_embeddings


# Per-module logger: records carry this module's dotted name so they can be
# filtered/configured independently of other modules.
logger = logging.getLogger(name=__name__)


def _build_evaluator(id_dataset, out_datasets, logger_config, embedding_type, config):
    """Construct an ``OODEvaluator`` that shares the run-wide settings in ``config``.

    Each call instantiates its own metrics (from ``config.metrics``) and its own
    logger (from ``logger_config``), matching how each evaluator was built before.

    Args:
        id_dataset: Dataset passed to the evaluator as ``id_dataset``.
        out_datasets: Datasets passed to the evaluator as ``out_datasets``.
        logger_config: Hydra config node instantiated as the evaluator's logger.
        embedding_type: Embedding type forwarded to the evaluator.
        config: Full run config; supplies metrics, batch size and device.

    Returns:
        The newly constructed ``OODEvaluator``.
    """
    return OODEvaluator(
        id_dataset=id_dataset,
        out_datasets=out_datasets,
        metrics=instantiate(config.metrics),
        logger=instantiate(logger_config),
        batch_size=config.batch_size,
        device=config.device,
        embedding_type=embedding_type,
        config=config,
    )


def run_method(method, id_dataset_train, id_dataset_train_exclude_indices, id_dataset_test, aux_dataset, ood_datasets, embedding_type, config):
    """Fit (when required) and evaluate a single OOD-detection method.

    The ``method`` config node is instantiated with Hydra. If the resulting
    object reports ``requires_id``, it is fitted on the ID training embeddings,
    plus auxiliary embeddings when it also reports ``requires_aux``. The
    method's ``predict`` is then scored by up to three evaluators, each gated
    by a config flag:

    * ``config.eval_train``: ID = ``id_dataset_train``; out = ID test, aux and
      all OOD datasets.
    * ``config.eval_aux``: ID = ``id_dataset_test``; out = the aux dataset.
    * ``config.eval_test``: ID = ``id_dataset_test``; out = the OOD datasets.

    An evaluator, along with its metrics and logger, is only built when its
    flag is enabled.

    Args:
        method: Hydra config node describing the OOD method to instantiate.
        id_dataset_train: ID training dataset spec. It is used both for fitting
            and as the ID set of the train evaluator.
        id_dataset_train_exclude_indices: Forwarded as ``exclude_indices`` when
            loading the ID training embeddings for fitting.
        id_dataset_test: ID test dataset spec.
        aux_dataset: Auxiliary dataset spec. It is used for fitting when the
            method requires it, and as an out set in evaluation.
        ood_datasets: Sequence of OOD dataset specs.
        embedding_type: Embedding type forwarded to loaders and evaluators.
        config: Full run config.

    Returns:
        The value returned by the test evaluator's ``evaluate`` when
        ``config.eval_test`` is set, otherwise ``None``.
    """
    method = instantiate(method)

    if method.requires_id:
        id_dataset = load_dataset_embeddings(
            id_dataset_train,
            embedding_type,
            config,
            exclude_indices=id_dataset_train_exclude_indices,
        )

        logger.info(f'Fitting method {method}')
        if method.requires_aux:
            logger.info(f'{method} requires auxiliary data. Loading...')
            aux_dataset_ = load_dataset_embeddings(
                aux_dataset,
                embedding_type,
                config,
                n_data=config.n_aux_data,
            )
            method.fit(id_dataset, aux_dataset_)
            # Drop our reference before evaluation. Memory is only released if
            # `method` did not keep its own reference to the data.
            del aux_dataset_
        else:
            method.fit(id_dataset)

        del id_dataset

    if config.eval_train:
        train_evaluator = _build_evaluator(
            id_dataset_train,
            [id_dataset_test] + [aux_dataset] + ood_datasets,
            config.train_logger,
            embedding_type,
            config,
        )
        train_evaluator.evaluate(method.predict, epoch=None)

    if config.eval_aux:
        aux_evaluator = _build_evaluator(
            id_dataset_test,
            [aux_dataset],
            config.aux_logger,
            embedding_type,
            config,
        )
        aux_evaluator.evaluate(method.predict, epoch=None)

    if config.eval_test:
        test_evaluator = _build_evaluator(
            id_dataset_test,
            ood_datasets,
            config.test_logger,
            embedding_type,
            config,
        )
        return test_evaluator.evaluate(method.predict, epoch=None)

    return None


@hydra.main(config_path='config_run_methods', config_name='summarization-pegasus-xsum-input', version_base='1.2')
def main(config):
    """Hydra entry point: run one OOD method and track it in Weights & Biases.

    Opens a W&B run that records the fully resolved Hydra config, then seeds
    RNGs with ``config.seed + config.addtl_seed``. It prints the config, then
    fits and evaluates the configured method via ``run_method``.

    The W&B run is always finished. It gets exit code 0 on success and 1 if
    anything raises, so failed runs are marked as failed rather than left
    open. Exceptions are logged with their traceback and re-raised unchanged.

    Args:
        config: Hydra-composed run configuration.
    """
    exit_code = 1  # Assume failure until the whole pipeline completes.
    try:
        wandb.init(
            project='token-based-ood-run-methods-2.0',
            config={
                'hydra': OmegaConf.to_container(
                    config,
                    resolve=True,
                    throw_on_missing=True,
                ),
            },
        )
        set_seed(config.seed + config.addtl_seed)
        print(OmegaConf.to_yaml(config))

        embedding_type: EmbeddingType = EmbeddingType[config.embedding_type]

        run_method(
            method=config.method,
            id_dataset_train=config.data.id.train.path,
            id_dataset_train_exclude_indices=config.data.id.train.exclude_indices,
            id_dataset_test=config.data.id.test,
            aux_dataset=config.data.aux,
            ood_datasets=config.data.ood,
            embedding_type=embedding_type,
            config=config,
        )
        exit_code = 0
    except Exception:
        logger.exception('Run failed')
        raise  # Bare raise keeps the original traceback intact.
    finally:
        # Runs on success and on failure, so the W&B run is never left open.
        wandb.finish(exit_code=exit_code)


if __name__ == '__main__':
    # Load variables from a local `.env` file into the process environment
    # before `main()` runs (and therefore before it calls `wandb.init`).
    # Presumably this supplies credentials/settings such as W&B config -- confirm.
    load_dotenv()
    main()
