from itertools import chain
from torch.nn import functional
import torch.optim as optim
import torch
import logging
import numpy as np
from collections import defaultdict
from causal import compute_implicit_causal_effects, run_enco
from pathlib import Path
from torch.utils.data import TensorDataset
from dci_metrics import compute_dci
import os
import torch.distributed as dist
from causal import fix_topological_order
from model.encoder import GaussianEncoder, ImageEncoder, ImageDecoder, CoordConv2d
from dataset import *


def create_optimizer_and_scheduler(args, model, betas=(0.9, 0.999),
                                   separate_param_groups=False):
    """Initializes optimizer and scheduler

    Builds an Adam optimizer for ``model``. If the model has a ``mine_model``
    submodule, that submodule gets its own Adam optimizer and is excluded from the
    main one.

    Args:
        args: run configuration. Reads ``lr`` and ``lr_schedule`` plus the fields of
            the selected schedule (``epochs``, ``lr_schedule_minimal``,
            ``lr_schedule_restart_every_epochs``, the optional
            ``lr_schedule_increase_period_by_factor``, ``lr_schedule_step_every_epochs``,
            ``lr_schedule_step_gamma``).
        model: torch module whose parameters are optimized.
        betas: Adam betas, shared by both optimizers.
        separate_param_groups: currently unused; kept for backward compatibility.

    Returns:
        ``(optim_model, optimizer_discriminator, scheduler)``.
        ``optimizer_discriminator`` is None when the model has no ``mine_model``.
        ``scheduler`` is None for the ``"constant"`` schedule.

    Raises:
        ValueError: for an unknown ``args.lr_schedule``.
    """

    if hasattr(model, 'mine_model'):
        optimizer_discriminator = optim.Adam(model.mine_model.parameters(), lr=args.lr, betas=betas)

        # Create another optimizer for the rest of the model excluding mi_discriminator
        model_params_except_discriminator = [p for name, p in model.named_parameters() if 'mine_model' not in name]
        optim_model = optim.Adam(model_params_except_discriminator, lr=args.lr, betas=betas)

    else:
        optim_model = optim.Adam(model.parameters(), lr=args.lr, betas=betas)
        optimizer_discriminator = None

    if args.lr_schedule == "constant":
        scheduler = None
    elif args.lr_schedule == "cosine":
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optim_model, args.epochs, eta_min=args.lr_schedule_minimal
        )
    elif args.lr_schedule in ["cosine_restarts", "cosine_restarts_reset"]:
        # Optional setting: without it, every restart period has the same length.
        t_mult = getattr(args, "lr_schedule_increase_period_by_factor", None)
        if t_mult is None:
            t_mult = 1
        # CosineAnnealingWarmRestarts only accepts an *integer* T_mult >= 1; a float
        # such as 1.0 or 2.0 (common in YAML configs) would make it raise ValueError.
        if isinstance(t_mult, float) and t_mult.is_integer():
            t_mult = int(t_mult)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optim_model,
            args.lr_schedule_restart_every_epochs,
            eta_min=args.lr_schedule_minimal,
            T_mult=t_mult,
        )
    elif args.lr_schedule == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(
            optim_model,
            step_size=args.lr_schedule_step_every_epochs,
            gamma=args.lr_schedule_step_gamma,
        )
    else:
        raise ValueError(
            f"Unknown value in cfg.training.lr_schedule: {args.lr_schedule}"
        )

    return optim_model, optimizer_discriminator, scheduler


def determine_graph_learning_settings(args, epoch, model):
    """Put together kwargs for graph parameterization

    As a side effect, applies the epoch-based freeze / unfreeze schedule of the
    learned adjacency matrix:

    * ``epoch < args.graph_sampling_initial_unfreeze_epoch``: the graph is frozen at
      epoch 0 and sampled with the regular ``graph_sampling_*`` settings.
    * ``epoch >= args.graph_sampling_final_freeze_epoch``: the graph is frozen when
      ``epoch == graph_sampling_final_freeze_epoch`` and sampled with the
      ``graph_sampling_final_*`` settings.
    * otherwise: the graph is unfrozen when
      ``epoch == graph_sampling_initial_unfreeze_epoch`` and sampled with the regular
      settings.

    Args:
        args: run configuration providing the ``graph_sampling_*`` fields above.
        epoch: current training epoch.
        model: model exposing ``scm.graph``, either directly or through a ``.module``
            wrapper (e.g. DistributedDataParallel).

    Returns:
        dict with ``graph_mode``, ``graph_temperature`` and ``graph_samples``, or an
        empty dict when the model has no graph (``scm.graph is None``).
    """

    # Locate the SCM on a wrapped (``model.module``) or plain model. The fallback is
    # confined to this lookup. Wrapping the whole body would let any later
    # AttributeError re-run the freeze/unfreeze logic a second time on ``model.scm``.
    try:
        scm = model.module.scm
    except AttributeError:
        scm = model.scm

    graph = scm.graph
    if graph is None:
        return {}

    # NOTE(review): the initial freeze only happens at epoch 0; confirm that runs
    # resumed at 0 < epoch < graph_sampling_initial_unfreeze_epoch restore the frozen state.
    if epoch < args.graph_sampling_initial_unfreeze_epoch:
        # Initial phase: graph held fixed
        if epoch == 0:
            logging.info(
                f"Freezing adjacency matrix initially. Value:\n{graph.adjacency_matrix}"
            )
            graph.freeze()
        use_final_settings = False
    elif epoch >= args.graph_sampling_final_freeze_epoch:
        # Final phase: graph frozen again, sampled with the final settings
        if epoch == args.graph_sampling_final_freeze_epoch:
            logging.info(
                f"Freezing adjacency matrix after epoch {epoch}. Value:\n"
                f"{graph.adjacency_matrix}"
            )
            graph.freeze()
        use_final_settings = True
    else:
        # Intermediate phase: graph unfrozen
        if epoch == args.graph_sampling_initial_unfreeze_epoch:
            logging.info(
                f"Unfreezing adjacency matrix after epoch {epoch}. Value:\n"
                f"{graph.adjacency_matrix}"
            )
            graph.unfreeze()
        use_final_settings = False

    if use_final_settings:
        return dict(
            graph_mode=args.graph_sampling_final_mode,
            graph_temperature=args.graph_sampling_final_temperature,
            graph_samples=args.graph_sampling_final_samples,
        )
    return dict(
        graph_mode=args.graph_sampling_mode,
        graph_temperature=args.graph_sampling_temperature,
        graph_samples=args.graph_sampling_samples,
    )


def epoch_schedules(args, model, epoch, optim, val_loader):
    """Evaluate the epoch-based training switches for ``epoch``.

    Logs when a phase boundary is reached. Also performs two one-off actions at
    their configured epoch: zeroing the learning rate of ``optim``'s first param
    group (``args.freeze_encoder_epoch``) and calling ``fix_topological_order`` on
    the validation loader (``args.fix_topological_order_epoch``).

    Returns:
        ``(model_interventions, pretrain, deterministic_intervention_encoder,
        model_noise)`` as booleans.
    """
    # Pretraining runs for the first ``pretrain_epochs`` epochs (never if unset)
    pretrain_limit = args.pretrain_epochs
    pretrain = pretrain_limit is not None and epoch < pretrain_limit
    if epoch == pretrain_limit:
        logging.info(f"Stopping pretraining at epoch {epoch}")

    # Intervention distributions are modelled from the start when unset
    interventions_start = args.model_interventions_after_epoch
    model_interventions = interventions_start is None or epoch >= interventions_start
    if epoch == interventions_start:
        logging.info(f"Beginning to model intervention distributions at epoch {epoch}")

    freeze_epoch = args.freeze_encoder_epoch
    if freeze_epoch is not None and epoch == freeze_epoch:
        logging.info(f"Freezing encoder and decoder at epoch {epoch}")
        # NOTE(review): only param group 0 is touched. With a single-group optimizer
        # (as built in create_optimizer_and_scheduler) this zeroes the learning rate
        # of every parameter in that optimizer, not just the encoder/decoder — confirm.
        optim.param_groups[0]["lr"] = 0.0

    topo_epoch = args.fix_topological_order_epoch
    if topo_epoch is not None and epoch == topo_epoch:
        logging.info(f"Determining topological order at epoch {epoch}")
        fix_topological_order(args, model, partition="val", dataloader=val_loader)

    noise_start = args.model_noise_after_epoch
    model_noise = noise_start is not None and epoch >= noise_start

    # Deterministic intervention encoders are never used when unset
    deterministic_start = args.deterministic_intervention_encoder_after_epoch
    deterministic_intervention_encoder = (
        deterministic_start is not None and epoch >= deterministic_start
    )
    if epoch == deterministic_start:
        logging.info(f"Switching to deterministic intervention encoder at epoch {epoch}")

    return model_interventions, pretrain, deterministic_intervention_encoder, model_noise


def step_schedules(args, model, fractional_epoch):
    """Evaluate every step-based schedule at ``fractional_epoch``.

    First updates the model's manifold thickness. Then reads each schedule from the
    five ``args`` fields that share its prefix: ``<prefix>_schedule_initial``,
    ``_final``, ``_initial_constant_epochs``, ``_decay_epochs`` and
    ``<prefix>_schedule``.

    Returns:
        ``(beta, beta_intervention, consistency, cyclicity, edge,
        inverse_consistency, z, intervention_entropy, intervention_encoder_offset,
        cov)`` — note the return order differs from the evaluation order.
    """
    set_manifold_thickness(args, model, fractional_epoch)

    def scheduled(prefix, fallback):
        # All schedules share the same five-field configuration layout
        return generic_scheduler(
            args,
            getattr(args, f"{prefix}_schedule_initial"),
            getattr(args, f"{prefix}_schedule_final"),
            getattr(args, f"{prefix}_schedule_initial_constant_epochs"),
            getattr(args, f"{prefix}_schedule_decay_epochs"),
            getattr(args, f"{prefix}_schedule"),
            fractional_epoch,
            default_value=fallback,
        )

    beta = scheduled("beta", 1.0)
    beta_intervention = args.increase_intervention_beta * beta

    consistency = scheduled("consistency_regularization", 0.0)
    inverse_consistency = scheduled("inverse_consistency_regularization", 0.0)
    z_reg = scheduled("z_regularization", 0.0)
    edge_reg = scheduled("edge_regularization", 0.0)
    cyclicity_reg = scheduled("cyclicity_regularization", 0.0)
    intervention_entropy_reg = scheduled("intervention_entropy_regularization", 0.0)
    encoder_offset = scheduled("intervention_encoder_offset", 0.0)
    cov_reg = scheduled("cov_regularization", 0.0)

    return (
        beta,
        beta_intervention,
        consistency,
        cyclicity_reg,
        edge_reg,
        inverse_consistency,
        z_reg,
        intervention_entropy_reg,
        encoder_offset,
        cov_reg,
    )


def generic_scheduler(args, initial, final, initial_constant_epochs, decay_epochs, schedule_cfg, epoch,
                      default_value=None):
    """Dispatch to the schedule named by ``schedule_cfg`` and return its value at ``epoch``.

    Supported schedules:
      * ``"constant"``: always ``final``.
      * ``"constant_constant"``: ``initial`` before ``initial_constant_epochs``, then ``final``.
      * ``"exponential"``: log-linear from ``initial`` to ``final`` over ``args.epochs``.
      * ``"exponential_constant"``: log-linear over ``decay_epochs``, then ``final``.
      * ``"constant_exponential_constant"`` / ``"constant_linear_constant"``: hold
        ``initial`` for ``initial_constant_epochs``, then decay (log-linear / linear)
        over ``decay_epochs``, then hold ``final``.

    Returns ``default_value`` when ``epoch`` is None, whatever the schedule.

    Raises:
        ValueError: for an unknown ``schedule_cfg``.
    """
    if epoch is None:
        return default_value

    if schedule_cfg == "constant":
        return final
    if schedule_cfg == "constant_constant":
        return initial if epoch < initial_constant_epochs else final

    # The remaining schedules interpolate between ``initial`` and ``final``
    if schedule_cfg == "exponential":
        return exponential_scheduler(epoch, args.epochs, initial, final)
    if schedule_cfg == "exponential_constant":
        return exponential_scheduler(epoch, decay_epochs, initial, final)
    if schedule_cfg == "constant_exponential_constant":
        return exponential_scheduler(epoch - initial_constant_epochs, decay_epochs, initial, final)
    if schedule_cfg == "constant_linear_constant":
        return linear_scheduler(epoch - initial_constant_epochs, decay_epochs, initial, final)

    raise ValueError(f"Unknown scheduler type: {schedule_cfg}")


def exponential_scheduler(step, total_steps, initial, final):
    """Interpolate geometrically (linearly in log space) from ``initial`` to ``final``.

    Returns ``final`` once ``step >= total_steps`` and ``initial`` for ``step <= 0``.
    Returns ``final`` when ``total_steps <= 1``. Otherwise the log of the value moves
    linearly with ``step / (total_steps - 1)``. Both endpoints must be positive for
    the logarithm to be meaningful.
    """
    if step < total_steps and step <= 0:
        return initial
    if step >= total_steps or total_steps <= 1:
        return final

    frac = step / (total_steps - 1)
    return np.exp((1.0 - frac) * np.log(initial) + frac * np.log(final))


def linear_scheduler(step, total_steps, initial, final):
    """Interpolate linearly from ``initial`` to ``final``.

    Returns ``final`` once ``step >= total_steps`` and ``initial`` for ``step <= 0``.
    Returns ``final`` when ``total_steps <= 1``. In between, the weight of ``final``
    is ``step / (total_steps - 1)``.
    """
    finished = step >= total_steps
    if not finished and step <= 0:
        return initial
    if finished or total_steps <= 1:
        return final

    fraction = step / (total_steps - 1)
    return (1.0 - fraction) * initial + fraction * final


def set_manifold_thickness(args, model, epoch):
    """Set ``scm.manifold_thickness`` on the model from its schedule.

    The value comes from the ``args.manifold_thickness_schedule*`` settings. If the
    scheduler yields None (e.g. ``epoch is None`` at evaluation time), it falls back
    to ``args.scm_manifold_thickness``. Wrapped models (``model.module``) are tried
    first.
    """
    thickness = generic_scheduler(
        args,
        args.manifold_thickness_schedule_initial,
        args.manifold_thickness_schedule_final,
        args.manifold_thickness_schedule_initial_constant_epochs,
        args.manifold_thickness_schedule_decay_epochs,
        args.manifold_thickness_schedule,
        epoch,
    )
    thickness = args.scm_manifold_thickness if thickness is None else thickness

    try:
        model.module.scm.manifold_thickness = thickness
    except AttributeError:  # not wrapped: set it on the model itself
        model.scm.manifold_thickness = thickness


def check_for_nan_gradients(model):
    """Return the names of parameters whose gradient contains at least one NaN.

    Parameters without a gradient (``grad is None``) are skipped. Only NaN is
    detected; infinite gradients are not reported.
    """
    return [
        param_name
        for param_name, tensor in model.named_parameters()
        if tensor.grad is not None and torch.isnan(tensor.grad).any()
    ]


def _clip_gradients_of(args, optimizer):
    """Clip in place the gradients of all parameters managed by ``optimizer`` (no-op if ``args.clip_grad_norm`` is None)."""
    if args.clip_grad_norm is None:
        return
    params = [p for group in optimizer.param_groups for p in group["params"]]
    torch.nn.utils.clip_grad_norm_(params, args.clip_grad_norm, error_if_nonfinite=True)


def optimizer_step(args, vae_loss, disc_loss, model, optim, optim_discriminator, x1, x2, interventions, model_noise):
    """Optimizer step, plus some logging and NaN handling

    Without a discriminator optimizer (or while ``model_noise`` is false) this
    backpropagates ``vae_loss`` and steps ``optim``. Otherwise it first trains the
    discriminator on ``disc_loss``. It then backpropagates
    ``vae_loss + get_model_adversarial_loss(x1, x2, interventions)`` and steps
    ``optim``. Before each step, gradients are (optionally) clipped to
    ``args.clip_grad_norm``. NaN gradients in ``model`` are printed before the final
    step.

    Returns:
        ``torch.isfinite(vae_loss)`` (always true when the function returns).

    Raises:
        RuntimeError: if ``vae_loss`` is not finite, or — through ``clip_grad_norm_``
            with ``error_if_nonfinite=True`` — if clipped gradients are not finite.
    """
    finite = torch.isfinite(vae_loss)
    if not finite:
        logging.info("NaN in loss")
        raise RuntimeError("Non-finite VAE loss; aborting optimizer step")

    if optim_discriminator is not None and model_noise:
        # === Discriminator training ===
        optim_discriminator.zero_grad()
        # retain_graph keeps the graph alive; presumably vae_loss shares part of it
        disc_loss.backward(retain_graph=True)
        _clip_gradients_of(args, optim_discriminator)
        optim_discriminator.step()

        # === Encoder (VAE) adversarial training ===
        # Resolve the method once (plain model vs. ``.module`` wrapper). That way an
        # AttributeError raised *during* the update cannot trigger a second backward.
        try:
            adversarial_loss_fn = model.get_model_adversarial_loss
        except AttributeError:
            adversarial_loss_fn = model.module.get_model_adversarial_loss
        optim.zero_grad()
        loss_adv = adversarial_loss_fn(x1, x2, interventions)
        total_vae_loss = vae_loss + loss_adv
        total_vae_loss.backward()
    else:
        optim.zero_grad()
        vae_loss.backward()

    # Inspect and clip only now, after backward(). Before it, the gradients are the
    # previous step's leftovers, which zero_grad() discards, so clipping would never
    # affect the update.
    nan_grads = check_for_nan_gradients(model)
    if nan_grads:
        print("Found NaN gradients in the following parameters:", nan_grads)
    _clip_gradients_of(args, optim)
    optim.step()

    return finite


def reset_optimizer_state(optimizer):
    """Discard the optimizer's per-parameter state (such as Adam's moment estimates).

    Only the ``"state"`` entry is handed to ``__setstate__``, replaced by an empty
    ``defaultdict(dict)``; param groups and hyperparameters are not passed in.
    """
    empty_state = defaultdict(dict)
    optimizer.__setstate__({"state": empty_state})


def compute_metrics_on_dataset(args, model, criteria, data_loader, symmetric_index=None, reverse=False):
    """Computes loss and metrics averaged over all batches of ``data_loader``.

    Args:
        args: run configuration; reads ``dataset``, ``mask`` and the ``eval_*`` options.
        model: callable returning ``(log_prob, model_outputs)``; ``model_outputs`` is
            forwarded as keyword arguments to ``criteria``.
        criteria: callable returning ``(loss, _, metrics_dict)`` for one batch.
        data_loader: iterable of batches. For ``"epickitchens"``/``"procthor"`` a batch
            is ``(x1, x2, label, noun, s1, s2)``. Otherwise it is a 10-tuple
            ``(x1, x2, z1, shifts1, scales1, z2, shifts2, scales2,
            intervention_labels, true_interventions)``.
        symmetric_index: mapping from an action label to its symmetric counterpart;
            used only when ``reverse`` is True.
        reverse: for every sample whose label is a key of ``symmetric_index``, swap
            the two observations (and masks) and replace the label by its counterpart.

    Returns:
        ``(loss, None, metrics)``. ``loss`` and every metric are plain means over
        batches (not weighted by batch size) and are detached from autograd.

    Raises:
        ValueError: if ``data_loader`` yields no batches, or if ``reverse`` is set
            without a non-empty ``symmetric_index``.

    This function does not call ``model.eval()`` or disable autograd itself.
    """
    # At test time, always use the canonical manifold thickness
    set_manifold_thickness(args, model, None)

    paired_dataset = args.dataset in ["epickitchens", "procthor"]
    loss = 0.0
    metrics = None
    batches = 0

    for batch in data_loader:
        batches += 1

        if paired_dataset:
            x1, x2, label, noun, s1, s2 = batch

            if reverse:
                if not symmetric_index:
                    raise ValueError("reverse=True requires a non-empty symmetric_index")
                # Select samples by their *original* label, before remapping below.
                # (The mask must come from the labels, not from a zero tensor.)
                flip = torch.tensor(
                    [lab in symmetric_index for lab in label.flatten().tolist()],
                    dtype=torch.bool,
                ).reshape(label.shape)
                label.apply_(lambda x: symmetric_index[x] if x in symmetric_index else x)
                x1[flip], x2[flip] = x2[flip], x1[flip]
                s1[flip], s2[flip] = s2[flip], s1[flip]

            x1, x2, label, noun, s1, s2 = (t.cuda() for t in (x1, x2, label, noun, s1, s2))
            true_interventions = None
        else:
            # Entries 2-7 (z1, shifts1, scales1, z2, shifts2, scales2) are not used here
            (x1, x2, _, _, _, _, _, _, intervention_labels, true_interventions) = [
                t.cuda() for t in batch
            ]
            label = intervention_labels
            noun = None
            s1 = None
            s2 = None

        # With masking enabled, the two masks are passed positionally after x1, x2
        mask_args = (s1, s2) if args.mask else ()
        log_prob, model_outputs = model(
            x1,
            x2,
            *mask_args,
            beta=args.eval_beta,
            true_action=label,
            true_object=noun,
            full_likelihood=args.eval_full_likelihood,
            likelihood_reduction=args.eval_likelihood_reduction,
            graph_mode=args.eval_graph_sampling,
            graph_temperature=args.eval_graph_sampling_temperature,
            graph_samples=args.eval_graph_sampling_samples,
        )

        # Ground-truth interventions only exist for the non-paired datasets
        criteria_kwargs = dict(true_intervention_labels=label)
        if not paired_dataset:
            criteria_kwargs["true_interventions"] = true_interventions
        batch_loss, _, batch_metrics = criteria(log_prob, **criteria_kwargs, **model_outputs)

        # TODO: also accumulate an importance-weighted NLL estimate via
        # ``log_likelihood(..., n_latent_samples=args.iwae_samples)`` as metrics["nll"].

        # Detach before accumulating: the running sums would otherwise keep every
        # batch's autograd graph alive until the loop finishes.
        if torch.is_tensor(batch_loss):
            batch_loss = batch_loss.detach()
        loss = loss + batch_loss
        batch_metrics = {
            key: val.detach() if torch.is_tensor(val) else val
            for key, val in batch_metrics.items()
        }
        if metrics is None:
            metrics = batch_metrics
        else:
            for key in metrics:
                metrics[key] = metrics[key] + batch_metrics[key]

    if batches == 0:
        raise ValueError("compute_metrics_on_dataset received an empty data_loader")

    # Average over batches
    loss = loss / batches
    metrics = {key: val / batches for key, val in metrics.items()}

    return loss, None, metrics


@torch.no_grad()
def eval_accuracy(args, model, loader, symmetric_index=None, reverse=False):
    """Evaluates action and object classification accuracy over ``loader``.

    Each batch is ``(first_img, second_img, label, noun, first_mask, second_mask)``.
    With ``reverse``, every sample whose action label is a key of
    ``symmetric_index`` has its images (and masks) swapped, and its label replaced by
    the symmetric counterpart. Predictions come from ``predict_classes`` (on
    ``model.module`` when ``args.distributed``). The masks are passed only when
    ``args.mask`` is set.

    Returns:
        ``{"action_accuracy": float, "object_accuracy": float}``, each the fraction
        of correctly classified samples over the whole loader.

    Raises:
        ValueError: if ``loader`` is empty, or if ``reverse`` is set without a
            non-empty ``symmetric_index``.
    """

    model.eval()
    net = model.module if args.distributed else model

    n_act = n_obj = 0
    correct_act = correct_obj = 0
    for first_img, second_img, label, noun, first_mask, second_mask in loader:
        if reverse:
            if not symmetric_index:
                raise ValueError("reverse=True requires a non-empty symmetric_index")
            # Select samples by their *original* label, before remapping below
            flip = torch.tensor(
                [lab in symmetric_index for lab in label.flatten().tolist()],
                dtype=torch.bool,
            ).reshape(label.shape)
            label.apply_(lambda x: symmetric_index[x] if x in symmetric_index else x)
            first_img[flip], second_img[flip] = second_img[flip], first_img[flip]
            first_mask[flip], second_mask[flip] = second_mask[flip], first_mask[flip]

        if torch.cuda.is_available():
            first_img, second_img, label, noun, first_mask, second_mask = (
                t.cuda() for t in (first_img, second_img, label, noun, first_mask, second_mask)
            )

        masks = (first_mask, second_mask) if args.mask else ()
        action_prob, object_prob = net.predict_classes(first_img, second_img, label, *masks)

        _, predict_action = action_prob.max(1)
        _, predict_object = object_prob.max(1)

        # Compare in int64: casting labels to uint8 would wrap class ids >= 256
        correct_act += (predict_action == label.long()).sum().item()
        correct_obj += (predict_object == noun.long()).sum().item()
        n_act += label.size(0)
        n_obj += noun.size(0)

    if n_act == 0 or n_obj == 0:
        raise ValueError("eval_accuracy received an empty loader")

    # Sample-weighted accuracy (a mean of per-batch accuracies over-weights a short last batch)
    metrics = {"action_accuracy": correct_act / n_act, "object_accuracy": correct_obj / n_obj}

    return metrics


@torch.no_grad()
def eval_dci_scores(args, model, train_loader, test_loader, partition="test", full_importance_matrix=True):
    """Compute DCI scores of the model's causal latents against the ground truth.

    Only ``args.dataset == "synthetic"`` is supported. Ground-truth latents are read
    from ``dci_train.pt`` (fit split) and ``{partition}.pt`` (evaluation split) under
    ``args.path_data``. The matching observations are encoded deterministically with
    ``encode_to_causal``. ``train_loader`` and ``test_loader`` are accepted but
    unused.

    Returns:
        dict mapping ``"causal_<metric>"`` to each value returned by ``compute_dci``.

    Raises:
        NotImplementedError: for any dataset other than ``"synthetic"``.
    """
    model.eval()

    if args.dataset != "synthetic":
        raise NotImplementedError

    def load_split(file_name):
        # Returns (x, x_tilde, true_z) from the first three stored tensors
        path = Path(args.path_data) / file_name
        logging.debug(f"Loading data from {path}")
        x, x_tilde, true_z, _, _, _, *_ = TensorDataset(*torch.load(path)).tensors
        return x, x_tilde, true_z

    x_train, x_train_tilde, train_true_z = load_split("dci_train.pt")
    x_test, x_test_tilde, test_true_z = load_split(f"{partition}.pt")

    encoder = model.module if args.distributed else model
    train_model_z = encoder.encode_to_causal(x_train.cuda(), x_train_tilde.cuda(), deterministic=True)
    test_model_z = encoder.encode_to_causal(x_test.cuda(), x_test_tilde.cuda(), deterministic=True)

    scores = compute_dci(
        train_true_z,
        train_model_z,
        test_true_z,
        test_model_z,
        return_full_importance_matrix=full_importance_matrix,
    )
    return {f"causal_{name}": value for name, value in scores.items()}


def eval_enco_graph(args, model, partition="train"):
    """Estimate the causal graph post hoc by running ENCO on the model's latents.

    Only applies to ``args.model`` in ``"ilcm"`` / ``"icrl"``; other models get ``{}``.
    Pairs from ``{partition}.pt`` under ``args.path_data`` are encoded with
    ``encode_decode_pair``. The noise encodings are mapped to causal variables via
    ``model.scm.noise_to_causal``, then passed to ``run_enco`` with the inferred
    interventions.

    Returns:
        dict mapping ``"enco_graph_{i}_{j}"`` to entry ``(i, j)`` of ENCO's adjacency
        matrix, for ``i, j`` in ``range(model.dim_z)``.
    """
    if args.model not in ("ilcm", "icrl"):
        return {}

    model.eval()

    # Loading and encoding run without autograd; the ENCO call below does not
    with torch.no_grad():
        data_file = Path(args.path_data) / f"{partition}.pt"
        logging.debug(f"Loading data from {data_file}")
        loaded = torch.load(data_file)
        x0, x1, _, _, _, _, _, _, interventions, _ = TensorDataset(*loaded).tensors
        # NOTE(review): ``interventions`` stays on the CPU while x0 / x1 are moved to
        # the GPU; confirm ``encode_decode_pair`` handles the mixed devices.
        _, _, _, _, _, e0, e1, intervention = model.encode_decode_pair(
            x0.cuda(), x1.cuda(), interventions
        )
        z0 = model.scm.noise_to_causal(e0)
        z1 = model.scm.noise_to_causal(e1)

    enco_result = run_enco(z0, z1, intervention, lambda_sparse=args.enco_lambda, device="cuda:0")
    adjacency_matrix = enco_result.cpu().detach()

    results = {}
    for i in range(model.dim_z):
        for j in range(model.dim_z):
            results[f"enco_graph_{i}_{j}"] = adjacency_matrix[i, j].item()
    return results


@torch.no_grad()
def eval_implicit_graph(args, model, partition="val"):
    """Measure the implicit causal effects between latent variables.

    Only applies to noise-centric models (``args.model`` in ``"ilcm"`` / ``"icrl"``),
    and is skipped when ``args.dim_z > 5``; otherwise ``{}`` is returned. The first
    tensor in ``{partition}.pt`` is encoded deterministically to noise and passed to
    ``compute_implicit_causal_effects``. The topological order it also returns is
    discarded.

    Returns:
        dict mapping ``"implicit_graph_{i}_{j}"`` to the causal-effect entry ``(i, j)``
        for ``i, j`` in ``range(model.dim_z)``.
    """
    # ``or`` short-circuits, so ``args.dim_z`` is only read for ILCM-style models
    if args.model not in ("ilcm", "icrl") or args.dim_z > 5:
        return {}

    model.eval()

    data_file = Path(args.path_data) / f"{partition}.pt"
    logging.debug(f"Loading data from {data_file}")
    loaded = torch.load(data_file)
    x, *_ = TensorDataset(*loaded).tensors
    noise = model.encode_to_noise(x.cuda(), deterministic=True).detach()

    causal_effects, _ = compute_implicit_causal_effects(model, noise)

    dim = model.dim_z
    return {
        f"implicit_graph_{i}_{j}": causal_effects[i, j].item()
        for i in range(dim)
        for j in range(dim)
    }


def save_all_model(args, model, model_name, optimizer, optim_discriminator, epoch, best_loss, save_best=False):
    """Save model weights, optimizer state(s), epoch counter and best loss.

    The checkpoint is written to ``{args.expdir}/{model_name}/weights/``, as either
    ``best_model.pt`` (``save_best=True``) or ``{model_name}_{epoch}.pt``.

    Args:
        args: reads ``distributed`` (when true, ``model.module``'s state dict is
            stored) and ``expdir``.
        model: model to save.
        model_name: used for the sub-directory and the per-epoch file name.
        optimizer: main optimizer; stored under ``'optimizers'``.
        optim_discriminator: optional discriminator optimizer; stored under
            ``'optim_discriminator'`` when not None.
        epoch: current epoch; ``epoch + 1`` is stored as the epoch to resume from.
        best_loss: stored under ``'best_loss'``.
        save_best: choose the ``best_model.pt`` file name.
    """
    net = model.module if args.distributed else model
    checkpoint = {
        'epoch': epoch + 1,
        'state_dicts': net.state_dict(),
        'optimizers': optimizer.state_dict(),
        'best_loss': best_loss,
    }
    if optim_discriminator is not None:
        checkpoint['optim_discriminator'] = optim_discriminator.state_dict()

    filefolder = f'{args.expdir}/{model_name}/weights'
    # exist_ok avoids the check-then-create race when several processes save at once
    os.makedirs(filefolder, exist_ok=True)

    if save_best:
        filename = f'{filefolder}/best_model.pt'
    else:
        filename = f'{filefolder}/{model_name}_{epoch}.pt'
    torch.save(checkpoint, filename)
    logging.info(f" --> Model Saved in {filename}")


def load_all_model(args, model, optimizers, optim_discriminator, lr_scheduler=None, num_batches=0):
    """Restore model/optimizer state from the checkpoint at ``args.ckpt``.

    Side effects: sets ``args.start_epoch`` to the checkpoint's ``'epoch'``
    entry and, when given, moves ``lr_scheduler.last_epoch`` to
    ``start_epoch - 1``.

    Args:
        args: Run configuration; reads ``args.ckpt`` and writes ``args.start_epoch``.
        model: Model to load weights into. A DataParallel /
            DistributedDataParallel wrapper is unwrapped first, because
            ``save_all_model`` stores the unwrapped network's state dict.
        optimizers: Main optimizer to restore, or ``None`` to skip.
        optim_discriminator: Optional second optimizer. It is restored only if
            the checkpoint contains ``'optim_discriminator'``, and a warning is
            logged otherwise.
        lr_scheduler: Optional scheduler whose ``last_epoch`` is realigned.
        num_batches: Unused; kept for interface compatibility.

    Returns:
        The ``'best_loss'`` value stored in the checkpoint.

    Raises:
        FileNotFoundError: If ``args.ckpt`` is not an existing file.
    """
    model_path = args.ckpt

    if not os.path.isfile(model_path):
        logging.info('model {} not found'.format(model_path))
        # Previously this fell through to an UnboundLocalError on `checkpoint`.
        raise FileNotFoundError(f"checkpoint not found: {model_path}")

    checkpoint = torch.load(model_path, map_location='cpu')
    args.start_epoch = checkpoint['epoch']

    # Checkpoints hold the unwrapped network's keys (see save_all_model), so load
    # into the inner module when the model is wrapped for (distributed) data parallelism.
    parallel_wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    target = model.module if isinstance(model, parallel_wrappers) else model
    target.load_state_dict(checkpoint['state_dicts'])

    if lr_scheduler is not None:
        lr_scheduler.last_epoch = args.start_epoch - 1

    if optimizers is not None:
        optimizers.load_state_dict(checkpoint['optimizers'])

    if optim_discriminator is not None:
        # save_all_model only writes this key when a discriminator optimizer existed.
        if 'optim_discriminator' in checkpoint:
            optim_discriminator.load_state_dict(checkpoint['optim_discriminator'])
        else:
            logging.warning(
                "Checkpoint %s has no 'optim_discriminator' state; keeping the current one",
                model_path,
            )

    logging.info("=> loaded checkpoint '{}' (epoch {})".format(model_path, checkpoint["epoch"]))

    return checkpoint['best_loss']


def create_encoder_decoder(args):
    """Build the encoder/decoder pair and the matching noise encoder/decoder pair.

    If ``args.architecture == "mlp"``, all four networks are ``GaussianEncoder``
    instances. Otherwise they are ``ImageEncoder`` / ``ImageDecoder``
    instances. Each noise network is constructed with exactly the same
    hyperparameters as its main counterpart, but as a separate object.

    Returns:
        tuple: ``(encoder, decoder, noise_encoder, noise_decoder)``.
    """
    logging.info(f"Creating {args.encoder} encoder / decoder")

    enc_widths = [args.encoder_hidden_units] * args.encoder_hidden_layers
    dec_widths = [args.decoder_hidden_units] * args.decoder_hidden_layers

    if args.architecture != "mlp":
        enc_conv = CoordConv2d if args.encoder_coordinate_embeddings else torch.nn.Conv2d
        dec_conv = CoordConv2d if args.decoder_coordinate_embeddings else torch.nn.Conv2d

        image_enc_cfg = dict(
            in_features=args.dim_x,
            out_features=args.dim_z,
            in_resolution=args.resolution,
            hidden_features=args.encoder_hidden_units,
            conv_class=enc_conv,
            batchnorm=False,
            min_std=args.encoder_min_std,
            mlp_layers=args.encoder_extra_mlp_layers,
            mlp_hidden=args.encoder_extra_mlp_hidden_units,
            elementwise_hidden=args.encoder_elementwise_hidden_units,
            elementwise_layers=args.encoder_elementwise_layers,
        )
        image_dec_cfg = dict(
            in_features=args.dim_z,
            out_features=args.dim_x,
            out_resolution=args.resolution,
            hidden_features=args.decoder_hidden_units,
            conv_class=dec_conv,
            batchnorm=False,
            min_std=args.decoder_min_std,
            fix_std=args.decoder_fix_std,
            mlp_layers=args.decoder_extra_mlp_layers,
            mlp_hidden=args.decoder_extra_mlp_hidden_units,
            elementwise_hidden=args.decoder_elementwise_hidden_units,
            elementwise_layers=args.decoder_elementwise_layers,
        )

        # Keep this construction order: building modules may draw from the
        # global RNG, so reordering could change initial weights under a fixed seed.
        encoder = ImageEncoder(**image_enc_cfg)
        noise_encoder = ImageEncoder(**image_enc_cfg)
        decoder = ImageDecoder(**image_dec_cfg)
        noise_decoder = ImageDecoder(**image_dec_cfg)
        return encoder, decoder, noise_encoder, noise_decoder

    mlp_enc_cfg = dict(
        encoder_type=args.encoder,
        encoder_decoder="encoder",
        hidden=enc_widths,
        input_features=args.dim_x,
        output_features=args.dim_z,
        fix_std=args.encoder_fix_std,
        init_std=args.encoder_std,
        min_std=args.encoder_min_std,
        amin=args.amin,
        resolution=args.resolution,
    )
    mlp_dec_cfg = dict(
        encoder_type=args.encoder,
        encoder_decoder="decoder",
        hidden=dec_widths,
        input_features=args.dim_z,
        output_features=args.dim_x,
        fix_std=args.decoder_fix_std,
        init_std=args.decoder_std,
        min_std=args.decoder_min_std,
        amin=args.amin,
        resolution=args.resolution,
    )

    # Same ordering caveat as above: encoder, decoder, noise encoder, noise decoder.
    encoder = GaussianEncoder(**mlp_enc_cfg)
    decoder = GaussianEncoder(**mlp_dec_cfg)
    noise_encoder = GaussianEncoder(**mlp_enc_cfg)
    noise_decoder = GaussianEncoder(**mlp_dec_cfg)
    return encoder, decoder, noise_encoder, noise_decoder


def save_configs(datetime, args, exp_name):
    """Write a plain-text record of the run configuration.

    The file ``<expdir>/<exp_name>/<exp_name>_args.txt`` contains the following,
    each terminated by a newline:
    - the ``datetime`` value;
    - ``args["description"]``;
    - one ``key:value`` line per entry of ``args``, in dict iteration order.

    The ``<expdir>/<exp_name>`` directory is created if it does not exist.

    Args:
        datetime: Timestamp (any printable value) written on the first line.
        args: Mapping of configuration values; must contain ``"description"``
            and ``"expdir"``.
        exp_name: Experiment name; used as sub-directory and file-name prefix.
    """
    lines = [f'{datetime}', f'{args["description"]}']
    lines.extend(f'{item}:{args[item]}' for item in args.keys())
    info = '\n'.join(lines) + '\n'

    folder = os.path.join(args["expdir"], exp_name)
    # Create the experiment folder on demand, like save_all_model does for weights.
    os.makedirs(folder, exist_ok=True)
    # The `with` block closes the file; an explicit close() is unnecessary.
    with open(os.path.join(folder, f'{exp_name}_args.txt'), 'w', encoding='utf-8') as fp:
        fp.write(info)
