import os

import torch
from util import *
from training import VAEMetrics
from parse_args import parse_arguments
import logging
from experiment_utils import *
import pandas as pd
from main import set_logger, set_model, set_loader
from visualize import *
import datetime
import cv2

def _metrics_to_matrix(metrics, prefix, dim, decimals=None):
    """Assemble the flat entries ``metrics[f"{prefix}_{i}_{j}"]`` into a (dim, dim) array.

    When ``decimals`` is given, each entry is passed through ``round`` first.
    """
    matrix = np.zeros((dim, dim))
    for row in range(dim):
        for col in range(dim):
            value = metrics[f"{prefix}_{row}_{col}"]
            matrix[row, col] = value if decimals is None else round(value, decimals)
    return matrix


@torch.no_grad()
def evaluate(args, model, criteria, test_loader, model_name):
    """Compute test metrics, log them and store them on disk.

    The model is switched to eval mode and the base metrics are taken from the
    third return value of ``compute_metrics_on_dataset``. Depending on
    ``args.dataset`` and ``args.model`` further entries are merged in:

    * ``"synthetic"``: results of ``eval_dci_scores``; for ``ilcm``/``softilcm``
      with ``args.dim_z <= 5`` the implicit graph; for ``softilcm`` with
      ``args.dim_z <= 5`` the ENCO graph and the latent correlations computed
      on ``<args.path_data>/test.pt``.
    * any other dataset: results of ``eval_accuracy``.

    The metrics are written as a tab-separated table to
    ``<args.expdir>/<model_name>/test/metrics/test_metrics.tsv``.

    Args:
        args: parsed run configuration (reads ``dataset``, ``model``, ``dim_z``,
            ``path_data`` and ``expdir``).
        model: model under evaluation.
        criteria: metric object forwarded to ``compute_metrics_on_dataset``.
        test_loader: loader over the test split.
        model_name: sub-directory of ``args.expdir`` the results are stored in.

    Returns:
        The dictionary of test metrics.
    """

    model.eval()

    logging.info("Starting evaluation")

    # Compute metrics
    _, _, test_metrics = compute_metrics_on_dataset(args, model, criteria, test_loader)

    if args.dataset == "synthetic":
        # DCI disentanglement scores
        dci = eval_dci_scores(args, model, partition="test")
        test_metrics.update(dci)

        if args.model in ["ilcm", "softilcm"] and args.dim_z <= 5:
            results = eval_implicit_graph(args, model, partition="test")
            test_metrics.update(results)
            causal_effect = _metrics_to_matrix(test_metrics, "implicit_graph", model.dim_z)

            logging.info(f"\nCausal effects =\n {causal_effect}")

        if args.model == "softilcm" and args.dim_z <= 5:
            enco_graph = eval_enco_graph(args, model, partition="test")
            test_metrics.update(enco_graph)

            filename = Path(args.path_data) / "test.pt"
            data = torch.load(filename)
            # Only the first two tensors of the stored dataset are needed here.
            x_test, x_test_tilde, *_ = TensorDataset(*data).tensors
            e_e_corr, n_e_corr, shift_e_corr, scales_e_corr = model.latents_correlations(x_test.cuda(), x_test_tilde.cuda())
            logging.info(f"\nNoise and Exogenous Correlations =\n {n_e_corr}")
            logging.info(f"\nExogenous and Exogenous Correlations =\n {e_e_corr}")
            logging.info(f"\nShifts and Exogenous Correlations =\n {shift_e_corr}")
            logging.info(f"\nScales and Exogenous Correlations =\n {scales_e_corr}")

            graph = _metrics_to_matrix(test_metrics, "enco_graph", model.dim_z)

            logging.info(f"\nENCO graph =\n {graph}")

        # NOTE(review): these keys are presumably provided by eval_dci_scores -- confirm.
        logging.info(
            f"\ncausal disentanglement = {test_metrics['causal_disentanglement']:.2f}"
        )
        logging.info(
            f"\ncausal completeness = {test_metrics['causal_completeness']:.2f}"
        )

        causal_importance_matrix = _metrics_to_matrix(
            test_metrics, "causal_importance_matrix", args.dim_z, decimals=4
        )
        logging.info('\n------- Causal Importance Matrix ------')
        logging.info(causal_importance_matrix)

    else:
        test_metrics.update(eval_accuracy(args, model, test_loader))
        logging.info(f"Action accuracy = {test_metrics['action_accuracy']:.2f}")
        logging.info(f"Object accuracy = {test_metrics['object_accuracy']:.2f}")

    # Store results on disk.
    # Pandas does not like scalar values, they have to be wrapped in iterables.
    test_metrics_ = {
        key: [val.cpu() if isinstance(val, torch.Tensor) and val.is_cuda else val]
        for key, val in test_metrics.items()
    }
    df = pd.DataFrame.from_dict(test_metrics_)

    folder_name = os.path.join(args.expdir, f"{model_name}/test/metrics")
    os.makedirs(folder_name, exist_ok=True)
    # The file carries a .tsv extension, so it is written tab-separated.
    df.to_csv(f'{folder_name}/test_metrics.tsv', sep='\t')

    return test_metrics


@torch.no_grad()
def evaluate_images(args, model, test_loader):
    """Display, for every test batch, an intervened reconstruction next to its input.

    For each batch the pair is encoded with ``model.encode_to_causal``, the
    latent selected by ``label`` is set to zero for every sample, and the
    result is decoded with ``model.decode_causal``. The first decoded image of
    the batch and the first input image are then de-normalised and shown.

    Args:
        args: parsed run configuration (reads ``dataset``).
        model: model under evaluation; it is switched to eval mode.
        test_loader: loader whose batches start with ``(x1, x2, label, ...)``.

    Returns:
        ``{}`` when ``args.dataset`` is neither ``"epickitchens"`` nor
        ``"procthor"``; otherwise ``None`` once all batches have been shown.
    """

    model.eval()

    if args.dataset not in ["epickitchens", "procthor"]:
        return {}

    # Per-channel statistics used to undo the input normalisation.
    # NOTE(review): presumably the values the data pipeline normalised with -- confirm.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    for batch in test_loader:
        # Only the image pair and the label are used below.
        x1, x2, label, *_ = batch

        x1 = x1.cuda()
        x2 = x2.cuda()
        label = label.cuda()

        # NOTE(review): the outputs of this call are not used anywhere; the call
        # is kept because it is not visible here whether it has side effects.
        model.encode_decode_pair(x1, x2, label)

        z = model.encode_to_causal(x1, x2)
        z_modified = z.clone()
        # Zero, for every sample, the latent indexed by that sample's label.
        z_modified[torch.arange(len(label)), label] = 0.0
        x1_rec_int_z = model.decode_causal(z_modified)

        imshow(denormalize(x1_rec_int_z[0], mean, std))
        imshow(denormalize(x1[0], mean, std))


def denormalize(tensor, mean, std):
    """Return a copy of ``tensor`` with ``x * std + mean`` applied per leading slice.

    Slices along the first dimension are paired with the entries of ``mean``
    and ``std``; slices beyond the shorter of the two sequences are left
    unchanged. The input tensor itself is not modified.
    """
    restored = tensor.clone()
    for channel, offset, scale in zip(restored, mean, std):
        # Each ``channel`` is a view into ``restored``, so the in-place ops stick.
        channel.mul_(scale)
        channel.add_(offset)
    return restored

def imshow(tensor):
    """Show a tensor as an image with the axes hidden.

    The tensor is detached, moved to the CPU and converted to a numpy array
    whose axes are reordered from (0, 1, 2) to (1, 2, 0) before plotting.
    """
    # Channels-first to channels-last, as expected by plt.imshow.
    channels_last = np.transpose(tensor.detach().cpu().numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    plt.axis('off')
    plt.show()

def _save_example_pairs(dataset, dict_verb_class, dict_noun_class, tag, savedir, last_idx=10):
    """Save the samples with index 0..``last_idx`` of ``dataset`` through ``show_pair``."""
    total = len(dataset)
    for idx, (first, second, verb, noun, _, _) in enumerate(dataset):
        show_pair(first, second, dict_verb_class[verb], dict_noun_class[noun], suffix=f'_{tag}_{idx}',
                  savedir=savedir)
        if idx % 10 == 0:
            # Progress is reported against the dataset actually being iterated.
            print(f'# {idx} / {total}')
        if idx >= last_idx:
            break


def main(args):
    """Run the test stage for the checkpoint given in ``args.ckpt``.

    The model name and ``args.expdir`` are derived from the checkpoint path,
    which is read as ``<expdir>/<model_name>/<dir>/<file>``. When ``args.model``
    is ``None`` only a few example pairs of the test and OOD sets are saved
    under ``fig/pairs/``. Otherwise the model is built, restored from the
    checkpoint and evaluated: quantitatively when ``args.eval`` equals
    ``"quantitative"``, through image display otherwise.

    Args:
        args: parsed command-line arguments; ``args.expdir`` is overwritten.
    """
    # NOTE(review): os.path.dirname needs a path here, so args.ckpt must not be None.
    model_name = os.path.basename(os.path.dirname(os.path.dirname(args.ckpt)))
    args.expdir = os.path.dirname(os.path.dirname(os.path.dirname(args.ckpt)))
    set_seed(args.seed)
    set_logger(args, model_name, stage='test')

    # Only the test/OOD loaders and the two class dictionaries are used below.
    _, _, loader_test, loader_ood, dict_noun_class, dict_verb_class, *_ = set_loader(args)

    if args.model is None:
        output_path = 'fig/pairs/'
        os.makedirs(output_path, exist_ok=True)
        _save_example_pairs(loader_test.dataset, dict_verb_class, dict_noun_class,
                            tag='iid', savedir=output_path)
        _save_example_pairs(loader_ood.dataset, dict_verb_class, dict_noun_class,
                            tag='ood', savedir=output_path)
    else:

        model = set_model(args, args.num_actions, args.num_objects)
        # The optimizers and scheduler are only created so load_all_model can restore them.
        optim_model, optim_disc, scheduler = create_optimizer_and_scheduler(args, model)

        if args.ckpt:
            load_all_model(args, model, optim_model, optim_disc, lr_scheduler=scheduler)

        model.cuda()

        # Move the restored optimizer state onto the same device as the model.
        for state in optim_model.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

        criterion = VAEMetrics(args.dim_z)

        # Test
        if args.eval == "quantitative":
            evaluate(args, model, criterion, loader_test, model_name)
        else:
            evaluate_images(args, model, loader_test)

    print('\nWell done')


def read_tsv(tsv_file):
    """Load the tab-separated file ``tsv_file`` into a pandas DataFrame."""
    return pd.read_csv(tsv_file, sep='\t', low_memory=False)


def main_avg(avg_seed=False, exp_root='/home/zahra/Desktop/soft_ilcm/ILCM-A/experiments/'):
    """Collect the ``*_89.pt`` checkpoint paths of the synthetic soft-ILCM runs.

    For every ``var`` in 4..15 and every seed in 1..12 except 3 and 4, the
    directory ``<exp_root>/synthetic/<var>/soft/<seed>/softilcm`` is searched
    for sub-directories containing a ``weights`` folder.

    Args:
        avg_seed: when false, nothing is done and ``None`` is returned.
        exp_root: root directory of the experiments.

    Returns:
        ``None`` when ``avg_seed`` is false; otherwise a dict mapping
        ``str(var)`` to a list with one entry per processed seed, each entry
        being the list of matching checkpoint paths of that seed.

    Raises:
        FileNotFoundError: if an expected experiment directory does not exist.
    """

    if not avg_seed:
        return None

    # One result list per variable count; built in one go so every key exists.
    model_path_dict = {f'{var}': [] for var in range(4, 16)}

    for var in range(4, 16):
        for seed in range(1, 13):
            if seed in (3, 4):
                continue
            exp_path = os.path.join(exp_root, 'synthetic', f'{var}', 'soft', f'{seed}', 'softilcm')
            weights_path = [os.path.join(exp_path, run, 'weights') for run in os.listdir(exp_path)
                            if os.path.isdir(os.path.join(exp_path, run, 'weights'))]
            if len(weights_path) > 1:
                # Ambiguous setup: the remaining seeds of this variable count are skipped.
                print('2 experiments results are available')
                break
            if not weights_path:
                # Nothing to collect for this seed; report it instead of failing on [0].
                print(f'No weights directory found in {exp_path}')
                continue
            model_path = [os.path.join(weights_path[0], m) for m in sorted(os.listdir(weights_path[0]))
                          if os.path.splitext(m)[1] == '.pt' and os.path.splitext(m)[0].endswith('_89')]

            model_path_dict[f'{var}'].append(model_path)
    return model_path_dict

if __name__ == "__main__":
    # Script entry point: parse the command line, report the GPU count, then run.
    args = parse_arguments()
    gpu_count = torch.cuda.device_count()
    print("Number of GPUs:", gpu_count)
    main(args)

    # model_path_dict = main_avg(avg_seed=True)
