import os

import torch
from util import *
from training import VAEMetrics
from parse_args import parse_arguments
import logging
from experiment_utils import *
import pandas as pd
from main import set_logger, set_model, set_loader
from visualize import *
import datetime
import cv2


def _metrics_to_matrix(metrics, prefix, dim, decimals=None):
    """Rebuild a ``dim`` x ``dim`` array from entries keyed ``f"{prefix}_{i}_{j}"``.

    Args:
        metrics: Mapping holding one entry per matrix cell.
        prefix: Key prefix shared by all cells of the matrix.
        dim: Number of rows/columns to read.
        decimals: If given, each value is passed through ``round(value, decimals)``
            before being stored.

    Returns:
        A ``numpy`` array of shape ``(dim, dim)``.

    Raises:
        KeyError: If any expected cell key is missing from ``metrics``.
    """
    matrix = np.zeros((dim, dim))
    for i in range(dim):
        for j in range(dim):
            value = metrics[f"{prefix}_{i}_{j}"]
            matrix[i, j] = value if decimals is None else round(value, decimals)
    return matrix


@torch.no_grad()
def evaluate(args, model, criteria, train_loader, test_loader, model_name):
    """Run the test-time evaluation, log the results and dump them to disk.

    The model is switched to eval mode, metrics are computed on ``test_loader``
    and then extended depending on ``args.dataset``:

    * ``"synthetic"``: DCI scores are added; for ``args.model`` in
      ``("ilcm", "icrl")`` with ``args.dim_z <= 5`` the implicit graph is
      evaluated, and for ``"icrl"`` additionally the ENCO graph and the latent
      correlations computed on ``<args.path_data>/test.pt``.
    * anything else: action/object accuracies are added.

    All metrics are finally written to
    ``<args.expdir>/<model_name>/test/metrics/test_metrics.tsv``.

    Args:
        args: Parsed experiment options; reads ``dataset``, ``model``, ``dim_z``,
            ``path_data`` and ``expdir``.
        model: Model under evaluation.
        criteria: Forwarded unchanged to ``compute_metrics_on_dataset``.
        train_loader: Only forwarded to ``eval_dci_scores``.
        test_loader: Loader the metrics are computed on.
        model_name: Sub-directory of ``args.expdir`` the results are stored under.

    Returns:
        The metrics mapping (the same object that was logged and saved).
    """

    model.eval()

    logging.info("Starting evaluation")

    # Compute metrics
    _, _, test_metrics = compute_metrics_on_dataset(args, model, criteria, test_loader)

    if args.dataset == "synthetic":
        # DCI disentanglement scores
        dci = eval_dci_scores(args, model, train_loader, test_loader, partition="test")
        test_metrics.update(dci)

        if args.model in ["ilcm", "icrl"] and args.dim_z <= 5:
            results = eval_implicit_graph(args, model, partition="test")
            test_metrics.update(results)
            causal_effect = _metrics_to_matrix(test_metrics, "implicit_graph", model.dim_z)

            logging.info(f"\nCausal effects =\n {causal_effect}")

        if args.model == "icrl" and args.dim_z <= 5:
            enco_graph = eval_enco_graph(args, model, partition="test")
            test_metrics.update(enco_graph)

            filename = Path(args.path_data) / "test.pt"
            # NOTE: torch.load unpickles the file; only point path_data at trusted data.
            data = torch.load(filename)
            x_test, x_test_tilde, *_ = TensorDataset(*data).tensors
            e_e_corr, n_e_corr, shift_e_corr, scales_e_corr = model.latents_correlations(
                x_test.cuda(), x_test_tilde.cuda()
            )
            logging.info(f"\nNoise and Exogenous Correlations =\n {n_e_corr}")
            logging.info(f"\nExogenous and Exogenous Correlations =\n {e_e_corr}")
            logging.info(f"\nShifts and Exogenous Correlations =\n {shift_e_corr}")
            logging.info(f"\nScales and Exogenous Correlations =\n {scales_e_corr}")

            graph = _metrics_to_matrix(test_metrics, "enco_graph", model.dim_z)

            logging.info(f"\nENCO graph =\n {graph}")

        # These keys are expected to be present once the DCI scores were merged in
        # (presumably provided by eval_dci_scores -- verify against that helper).
        logging.info(
            f"\ncausal disentanglement = {test_metrics['causal_disentanglement']:.2f}"
        )
        logging.info(
            f"\ncausal completeness = {test_metrics['causal_completeness']:.2f}"
        )

        causal_importance_matrix = _metrics_to_matrix(
            test_metrics, "causal_importance_matrix", args.dim_z, decimals=4
        )
        logging.info('\n------- Causal Importance Matrix ------')
        logging.info(causal_importance_matrix)

    else:
        test_metrics.update(eval_accuracy(args, model, test_loader))
        logging.info(f"Action accuracy = {test_metrics['action_accuracy']:.2f}")
        logging.info(f"Object accuracy = {test_metrics['object_accuracy']:.2f}")

    # Store results on disk.
    # pandas needs iterables rather than scalars, so every value is wrapped in a
    # one-element list; CUDA tensors are moved to the CPU first.
    test_metrics_ = {
        key: [val.cpu() if isinstance(val, torch.Tensor) and val.is_cuda else val]
        for key, val in test_metrics.items()
    }
    df = pd.DataFrame.from_dict(test_metrics_)

    folder_name = os.path.join(args.expdir, f"{model_name}/test/metrics")
    # exist_ok avoids the check-then-create race when several runs share a folder.
    os.makedirs(folder_name, exist_ok=True)
    # NOTE: despite the .tsv extension the file is comma-separated (to_csv default);
    # the format is kept as-is so existing readers of this file keep working.
    df.to_csv(f'{folder_name}/test_metrics.tsv')

    return test_metrics


def _dump_sample_pairs(dataset, dict_verb_class, dict_noun_class, tag, savedir, last_idx=10):
    """Render the first samples of ``dataset`` with ``show_pair``.

    Samples with index ``0..last_idx`` (inclusive) are rendered; each sample is
    unpacked as a 6-tuple whose first four entries are the two items of the pair
    and the verb/noun keys looked up in the class dictionaries.

    Args:
        dataset: Dataset to iterate; must support ``len``.
        dict_verb_class: Lookup applied to each sample's verb entry.
        dict_noun_class: Lookup applied to each sample's noun entry.
        tag: Short label embedded in the file suffix (``f"_{tag}_{idx}"``).
        savedir: Directory handed to ``show_pair``.
        last_idx: Index of the last sample to render.
    """
    for idx, (first, second, verb, noun, _, _) in enumerate(dataset):
        show_pair(first, second, dict_verb_class[verb], dict_noun_class[noun], suffix=f'_{tag}_{idx}',
                  savedir=savedir)
        if idx % 10 == 0:
            # Report progress against the dataset actually being iterated.
            print(f'# {idx} / {len(dataset)}')
        if idx >= last_idx:
            break


def main(args):
    """Entry point of the test script.

    ``args.ckpt`` is split as ``<expdir>/<model_name>/<subdir>/<file>``: the
    experiment directory is stored in ``args.expdir`` and ``model_name`` is used
    for logging and for the output folder of the metrics.

    If ``args.model`` is ``None`` no model is evaluated; instead a few sample
    pairs of the test and OOD datasets are rendered into ``fig/pairs/``.
    Otherwise the model is built, the checkpoint is loaded when ``args.ckpt`` is
    set, everything is moved to the GPU and ``evaluate`` is run.

    Args:
        args: Parsed command-line options; ``args.expdir`` is set as a side effect.
    """
    model_name = os.path.basename(os.path.dirname(os.path.dirname(args.ckpt)))
    args.expdir = os.path.dirname(os.path.dirname(os.path.dirname(args.ckpt)))
    set_seed(args.seed)
    set_logger(args, model_name, stage='test')

    # set_loader returns nine values; only some of them are needed here.
    (loader_train, _loader_valid, loader_test, loader_ood, dict_noun_class, dict_verb_class,
     _symmetric_verb_index, _verb_block, _bool_verb_noun) = set_loader(args)

    if args.model is None:
        output_path = 'fig/pairs/'
        os.makedirs(output_path, exist_ok=True)
        _dump_sample_pairs(loader_test.dataset, dict_verb_class, dict_noun_class, 'iid', output_path)
        _dump_sample_pairs(loader_ood.dataset, dict_verb_class, dict_noun_class, 'ood', output_path)
    else:
        model = set_model(args, args.num_actions, args.num_objects)
        optim_model, optim_disc, scheduler = create_optimizer_and_scheduler(args, model)

        if args.ckpt:
            # The returned best loss is not needed for evaluation.
            load_all_model(args, model, optim_model, optim_disc, lr_scheduler=scheduler)

        model.cuda()

        # Keep the restored optimizer state on the same device as the model.
        for state in optim_model.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

        criterion = VAEMetrics(args.dim_z)

        # Test
        evaluate(args, model, criterion, loader_train, loader_test, model_name)

    print('\nWell done')


if __name__ == "__main__":
    # Script entry: parse the options, report the visible GPUs, then run main().
    args = parse_arguments()
    gpu_count = torch.cuda.device_count()
    print("Number of GPUs:", gpu_count)
    main(args)
