""" Evaluates a generative model on a split of a specified data set """

import argparse
from functools import partial
from collections import defaultdict
import logging
import os
from typing import Any, Callable, Dict, List, Optional

from accelerate import Accelerator
from accelerate.tracking import WandBTracker
import numpy as np
import torch
import yaml
from tqdm import tqdm

from text2graph.models.base_model import BaseModel
from text2graph.models.model_loader import init_and_load_models
from text2graph.data.data_loader import init_dataloader
from text2graph.training.losses_and_metrics import MetricFactory
from text2graph.training.trainer import read_wandb_secret

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Set when this module is imported. Presumably read by Hugging Face `transformers`
# to silence its advisory warnings -- TODO confirm against the installed version.
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'


def process_batch(
    sample_batched: Dict[str, Any],
    model: BaseModel,
    model_info: Dict[str, Any],
    eval_function: Callable,
    metric_names: List[str],
    device: torch.device
) -> Dict[str, float]:
    """ Runs the model on one batch and computes the configured evaluation metrics

    The batch is copied into a ``'dataset'`` entry (tensors moved to ``device``). The
    model's forward outputs, the ground-truth graphs and the generated graphs are stored
    under ``model.name``. Every entry of ``model_info['evaluations']`` is then computed
    from values looked up in those two sources.

    Args:
        sample_batched: collated batch; tensor values are moved to ``device``, all
            other values are passed through unchanged.
        model: model to evaluate; it is called on the batch and queried through
            ``name``, ``tokenizer``, ``inputs2graph`` and ``generate_graph``.
        model_info: model configuration. ``model_info['gen']['calls']`` maps a call
            name to extra keyword arguments for ``model.generate_graph``;
            ``model_info['evaluations']`` maps an evaluation name to a dict with the
            keys ``inputs``, ``is_loss``, ``logging_name`` and, optionally, ``weight``.
        eval_function: stored in the batch under the key ``'eval_function'``.
        metric_names: every name is stored in the batch as ``batch[name] = name``.
        device: device for the batch tensors and the loss accumulator.

    Returns:
        Mapping from each evaluation's ``logging_name`` to its mean value on this
        batch, plus ``'aggregate_loss'``: the weighted sum of all evaluations whose
        ``is_loss`` flag is set.

    Raises:
        AssertionError: if an evaluation lists an input that is not available from
            the source it names.
    """
    batch = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in sample_batched.items()
    }
    batch['eval_function'] = eval_function
    for name in metric_names:
        batch[name] = name
    model_outputs = {'dataset': batch}

    loss = torch.zeros(1).float().to(device)
    model_outputs[model.name] = model(batch)
    predictions = model_outputs[model.name]
    predictions['ground_truths'] = model.inputs2graph(
        batch,
        model.tokenizer,
        file_paths=batch.get('file_path', None)
    )
    for call_name, call_kwargs in model_info['gen']['calls'].items():
        # Call-specific keyword arguments override batch entries with the same key.
        generated_graphs, _ = model.generate_graph(**{**batch, **call_kwargs})
        predictions[f"{call_name}_graphs_gen"] = generated_graphs

    metrics = {}
    for eval_name, eval_args in model_info['evaluations'].items():
        # Only the text before the first comma selects the MetricFactory member.
        metric_name = eval_name.split(',')[0]
        requested = eval_args['inputs']
        missing = [
            f"{source}.{name}"
            for name, source in requested.items()
            if source not in model_outputs or name not in model_outputs[source]
        ]
        if missing:
            # Raised explicitly (not via `assert`) so the check survives `python -O`;
            # the exception type is kept so existing handlers keep working.
            raise AssertionError(
                f"Evaluation '{eval_name}' is missing inputs: {', '.join(missing)}"
            )
        inputs = [model_outputs[source][name] for name, source in requested.items()]
        metric_value = MetricFactory[metric_name].value(*inputs).mean()
        if eval_args['is_loss']:
            loss += metric_value * eval_args.get('weight', 1.0)
        metrics[eval_args['logging_name']] = metric_value.item()
    metrics['aggregate_loss'] = loss.item()
    return metrics


def get_evaluation_config(model_path: str, dataset_path: Optional[str] = None) -> Dict[str, Any]:
    """ Builds the evaluation configuration for a trained model

    The evaluation template ``configs/<dataset type>_evaluation.example.yml`` (two
    directories above this file) is loaded and then specialised with the data set
    and model settings stored in ``<model_path>/metadata.yaml``. When ``dataset_path``
    is given it replaces the data set path recorded at training time.
    """
    metadata_file = os.path.join(model_path, 'metadata.yaml')
    with open(metadata_file, "r", encoding='utf-8') as stream:
        training_metadata = yaml.safe_load(stream)
    trained_dataset = training_metadata['dataset_config']
    dataset_type = trained_dataset['type']

    this_file = os.path.realpath(__file__)
    repository_root = os.path.dirname(os.path.dirname(this_file))
    template_file = os.path.join(
        repository_root, "configs", f"{dataset_type}_evaluation.example.yml"
    )
    with open(template_file, "r", encoding='utf-8') as stream:
        cfg = yaml.safe_load(stream)

    # Data set settings come from training unless the caller overrides the path.
    dataset_cfg = cfg['dataset_config']
    dataset_cfg['type'] = dataset_type
    dataset_cfg['name'] = trained_dataset['name']
    if dataset_path is None:
        dataset_cfg['path'] = trained_dataset['path']
    else:
        dataset_cfg['path'] = dataset_path

    # Exactly one model is expected in both the template and the training metadata.
    template_flow = cfg['info_flow']
    assert len(template_flow) == 1 and len(training_metadata['info_flow']) == 1
    trained_flow = training_metadata['info_flow']
    default_model_type = next(iter(template_flow))
    model_type = next(iter(trained_flow))

    # Re-key the template's single entry under the trained model's type and fill it in.
    model_entry = template_flow.pop(default_model_type)
    template_flow[model_type] = model_entry
    trained_entry = trained_flow[model_type]
    model_entry['model_dir'] = trained_entry['metadata']['model_dir']
    model_entry['metadata'] = trained_entry['metadata']
    model_entry['inputs'] = trained_entry['inputs']
    return cfg


def evaluate(
    model_path: str, split_name: str = 'test', dataset_path: Optional[str] = None
) -> None:
    """ Evaluates a generative model on a specific split of the data set it was trained on

    Metrics are computed batch by batch. On the main process the running mean of every
    metric is logged to Weights & Biases and shown in the progress bar after each batch.
    Once all batches are processed, the per-process means are gathered across processes
    and logged once more under ``mean_<metric name>``.

    Args:
        model_path: directory containing the trained model's ``metadata.yaml``.
        split_name: name of the data set split to evaluate on.
        dataset_path: optional replacement for the data set path recorded at training time.
    """
    cfg = get_evaluation_config(model_path=model_path, dataset_path=dataset_path)
    accelerator = Accelerator()
    model_type = list(cfg['info_flow'].keys())[0]
    model_dict = init_and_load_models(
        info_flow=cfg['info_flow'],
        saving_directory=cfg['info_flow'][model_type]['model_dir'],
        accelerator=accelerator
    )
    model_name = list(model_dict.keys())[0]
    model = model_dict[model_name]
    model.to(accelerator.device)
    model_info = cfg['info_flow'][model_name]
    model.eval()
    test_dataloader = init_dataloader(
        split_name=split_name,
        dataset_config=cfg['dataset_config'],
        collate_function=model.get_collate_fn(),
        shuffle=False,
        multiprocessing_flag=False
    )
    # NOTE(review): `prepare` may return a wrapped model in multi-process runs. The
    # attribute accesses below (`model.name`, and those made inside `process_batch`)
    # assume the returned object still exposes them -- confirm for distributed runs.
    test_dataloader, model = accelerator.prepare(test_dataloader, model)
    read_wandb_secret()
    # Only the main process owns a tracker; every use below is guarded the same way.
    tracker = None
    if accelerator.is_main_process:
        tracker = WandBTracker(
            run_name=f"text2graph-{split_name}-{test_dataloader.dataset.name}",
            name=model.name,
            config=cfg
        )
    aggregate_metrics = defaultdict(list)
    process_batch_w_arguments = partial(
        process_batch,
        model=model,
        model_info=model_info,
        eval_function=test_dataloader.dataset.calculate_metrics_batch,
        metric_names=test_dataloader.dataset.metric_names(),
        device=accelerator.device
    )
    with torch.no_grad():
        with tqdm(
            test_dataloader,
            unit=" batch",
            disable=not accelerator.is_main_process
        ) as tepoch:
            for sample_batched in tepoch:
                batch_metrics = process_batch_w_arguments(sample_batched)
                for key, val in batch_metrics.items():
                    aggregate_metrics[key].append(val)
                if accelerator.is_main_process:
                    # Running mean over all batches seen so far, ignoring NaN entries.
                    running_means = {
                        key: np.round(np.nanmean(values), 3)
                        for key, values in aggregate_metrics.items()
                    }
                    tracker.log({key: value.item() for key, value in running_means.items()})
                    tepoch.set_postfix(**running_means)
        # One single-element tensor per metric so the values can be gathered across processes.
        metrics = {
            key: torch.from_numpy(np.nanmean(values).reshape(-1)).float().to(accelerator.device).detach()
            for key, values in aggregate_metrics.items()
        }
        accelerator.wait_for_everyone()
        metrics = accelerator.gather_for_metrics(metrics)
    if accelerator.is_main_process:
        tracker.log({f"mean_{key}": value.mean().item() for key, value in metrics.items()})
        # The tracker is created by hand and never registered with the accelerator, so
        # close the run here instead of relying on interpreter exit. This is reached
        # only on success, so a failed evaluation is not recorded as a clean finish.
        tracker.finish()


def read_default_model_path(file_path: Optional[str] = None) -> str:
    """ Returns the model path stored on the first line of a text file

    Args:
        file_path: file to read; defaults to ``model_path.txt`` in the directory
            that contains this script.

    Returns:
        The first line of the file with surrounding whitespace removed. Stripping
        matters because a trailing newline would otherwise become part of the path.

    Raises:
        ValueError: if the first line is empty or contains only whitespace.
    """
    if file_path is None:
        evaluation_path = os.path.dirname(os.path.realpath(__file__))
        file_path = os.path.join(evaluation_path, 'model_path.txt')
    with open(file_path, 'r', encoding='utf-8') as text_file:
        model_path = text_file.readline().strip()
    if not model_path:
        raise ValueError(f"No model path found on the first line of '{file_path}'")
    return model_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluating text graph generators')
    parser.add_argument('--model-path', type=str, required=False)
    parser.add_argument('--split-name', type=str, default='test')
    parser.add_argument('--dataset-path', type=str, required=False)
    args = parser.parse_args()
    if args.model_path is None:
        # Without --model-path, fall back to the path stored next to this script.
        args.model_path = read_default_model_path()

    evaluate(
        model_path=args.model_path,
        split_name=args.split_name,
        dataset_path=args.dataset_path
    )
