import copy
import json
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

from .classifier import Classifier, ClassifierConv
from .distribution import Bernoulli, NegativeBinomial, ZeroInflatedNegativeBinomial
from .module import MLP

from utils.math_utils import (
    logprob_normal, kldiv_normal,
    logprob_bernoulli_logits,
    logprob_nb_positive,
    logprob_zinb_positive
)
from utils.model_utils import total_grad_norm_, lr_lambda_exp, lr_lambda_lin, update_ema

#####################################################
#                     LOAD MODEL                    #
#####################################################

# Keys that `load_VCI` copies from the top-level `args` dictionary into
# `args["hparams"]` whenever the corresponding `args` value is not None.
overridden_hparams_keys = ("ema_rate", "max_grad_norm", "grad_skip_threshold", "discrim_step_cutoff", "linear_warmup_steps")

def load_VCI(args, state_dict=None, device="cuda"):
    """
    Build the model selected by ``args["data_name"]`` and optionally restore
    its weights.

    :param args: dictionary of run arguments. ``args["hparams"]`` may be a
        path to a JSON file (it is then replaced in place by the parsed
        dictionary) or an already-parsed dictionary. ``args`` is mutated:
        learning-rate and `overridden_hparams_keys` overrides are written
        into ``args["hparams"]``.
    :param state_dict: optional state dict loaded into the created model.
    :param device: device handed to the model constructor.
    :return: the constructed model.
    :raises ValueError: if ``args["data_name"]`` is not recognized, or if
        ``args["dist_mode"] == "classify"`` but no classifier checkpoint
        was supplied.
    """
    if isinstance(args["hparams"], str):
        # use a context manager so the file handle is always closed
        with open(args["hparams"]) as f:
            args["hparams"] = json.load(f)

    if args["lr"] is not None:
        for lr_key in ("discriminator_lr", "classifier_lr", "generator_lr"):
            args["hparams"][lr_key] = args["lr"]

    for arg_key in overridden_hparams_keys:
        if args[arg_key] is not None:
            args["hparams"][arg_key] = args[arg_key]

    state_dict_classifier, args_classifier = None, None
    if args["checkpoint_classifier"] is not None:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict_classifier, args_classifier = torch.load(
            args["checkpoint_classifier"], map_location="cpu")

    def build_classifier(classifier_cls):
        """Return the frozen pre-trained classifier, or None outside "classify" mode."""
        if args["dist_mode"] != "classify":
            return None
        if args_classifier is None:
            # previously surfaced as a NameError on an unbound local
            raise ValueError(
                "dist_mode 'classify' requires args['checkpoint_classifier']")
        built = classifier_cls(
            args_classifier["num_outcomes"],
            args_classifier["num_treatments"],
            args_classifier["num_covariates"],
            hparams=args_classifier["hparams"]
        )
        built.load_state_dict(state_dict_classifier)
        built.eval()
        return built

    if args["data_name"] == "gene":
        classifier = build_classifier(Classifier)
        model = VCI(
            args["num_outcomes"],
            args["num_treatments"],
            args["num_covariates"],
            embed_outcomes=True,
            embed_treatments=False,
            embed_covariates=True,
            omega0=args["omega0"],
            omega1=args["omega1"],
            omega2=args["omega2"],
            dist_outcomes=args["dist_outcomes"],
            dist_mode=args["dist_mode"],
            classifier=classifier,
            mc_sample_size=30,
            device=device,
            hparams=args["hparams"]
        )
    elif args["data_name"] in ("celebA", "openBHB", "morphomnist"):
        classifier = build_classifier(ClassifierConv)
        model = HVCIConv(
            args["num_outcomes"],
            args["num_treatments"],
            args["num_covariates"],
            embed_outcomes=True,
            embed_treatments=True,
            embed_covariates=False,
            omega0=args["omega0"],
            omega1=args["omega1"],
            omega2=args["omega2"],
            dist_outcomes=args["dist_outcomes"],
            dist_mode=args["dist_mode"],
            classifier=classifier,
            mc_sample_size=3,
            lr_lambda=lr_lambda_lin(total_epochs=args["max_epochs"], fixed_epochs=args["hparams"]["opt_step_size"]),
            device=device,
            hparams=args["hparams"]
        )
    elif args["data_name"] == "pendulum":
        classifier = build_classifier(ClassifierConv)
        model = VCIConv(
            args["num_outcomes"],
            args["num_treatments"],
            args["num_covariates"],
            embed_outcomes=False,
            embed_treatments=False,
            embed_covariates=False,
            omega0=args["omega0"],
            omega1=args["omega1"],
            omega2=args["omega2"],
            dist_outcomes=args["dist_outcomes"],
            dist_mode=args["dist_mode"],
            classifier=classifier,
            mc_sample_size=3,
            device=device,
            hparams=args["hparams"]
        )
    else:
        raise ValueError("data_name not recognized")

    if state_dict is not None:
        model.load_state_dict(state_dict)

    return model

#####################################################
#                     MAIN MODEL                    #
#####################################################

class VCI(nn.Module):
    def __init__(
        self,
        num_outcomes,
        num_treatments,
        num_covariates,
        embed_outcomes=True,
        embed_treatments=True,
        embed_covariates=True,
        omega0=1.0,
        omega1=2.0,
        omega2=0.1,
        dist_outcomes="normal",
        dist_mode="match",
        classifier=None,
        mc_sample_size=30,
        lr_lambda=None,
        device="cuda",
        hparams=""
    ):
        """
        Store the configuration, resolve hyper-parameters, build the generator
        (and, depending on `dist_mode`, the discriminator) and move the model
        to `device`.

        :param num_outcomes: outcome dimensionality.
        :param num_treatments: treatment dimensionality.
        :param num_covariates: sequence with one size per covariate.
        :param embed_outcomes: embed outcomes with an MLP instead of identity.
        :param embed_treatments: embed treatments with an MLP instead of identity.
        :param embed_covariates: embed covariates with MLPs instead of identity.
        :param omega0: weight of the individual-specific likelihood term.
        :param omega1: weight of the covariate-specific likelihood term.
        :param omega2: weight of the KL-divergence term.
        :param dist_outcomes: default outcome distribution name.
        :param dist_mode: one of "classify", "discriminate" or "match".
        :param classifier: classifier used when `dist_mode` is "classify".
        :param mc_sample_size: number of latent samples drawn per input in `forward`.
        :param lr_lambda: learning-rate schedule function; a default from
            `lr_lambda_exp` is used when None.
        :param device: device the model is moved to.
        :param hparams: hyper-parameter overrides forwarded to `_set_hparams`.
        """
        super().__init__()

        # problem dimensions
        self.num_outcomes = num_outcomes
        self.num_treatments = num_treatments
        self.num_covariates = num_covariates

        # which inputs get a learned embedding
        self.embed_outcomes = embed_outcomes
        self.embed_treatments = embed_treatments
        self.embed_covariates = embed_covariates

        # outcome distribution and Monte-Carlo sampling
        self.dist_outcomes = dist_outcomes
        self.mc_sample_size = mc_sample_size

        # learning-rate schedule shared by the optimizers
        if lr_lambda is None:
            self.lr_lambda = lr_lambda_exp(1, decay_rate=1.)
        else:
            self.lr_lambda = lr_lambda

        # loss weights and covariate-specific mode
        self.omega0, self.omega1, self.omega2 = omega0, omega1, omega2
        self.dist_mode = dist_mode
        self.classifier = classifier

        # early-stopping bookkeeping
        self.best_score = -np.inf
        self.patience_trials = 0

        # hyper-parameters must be resolved before any sub-model is built
        self._set_hparams(hparams)
        self._init_indiv_model()  # individual-specific model
        self._init_covar_model()  # covariate-specific model

        self.iteration = 0

        self.to_device(device)

    def _set_hparams(self, hparams):
        """
        Set hyper-parameters to default values or values fixed by user for those
        hyper-parameters specified in the JSON string `hparams`.
        """

        self.hparams = {
            "outcome_emb_dim": 256,
            "treatment_emb_dim": 64,
            "covariate_emb_dim": 16,
            "latent_dim": 128,
            "encoder_width": 128,
            "encoder_depth": 3,
            "decoder_width": 128,
            "decoder_depth": 3,
            "discriminator_width": 64,
            "discriminator_depth": 3,
            "generator_lr": 3e-4,
            "generator_wd": 4e-7,
            "discriminator_lr": 3e-4,
            "discriminator_wd": 4e-7,
            "discriminator_freq": 2,
            "opt_step_size": 900,
            "max_grad_norm": -1,
            "grad_skip_threshold": -1,
            "patience": 20,
        }

        # the user may fix some hparams
        if hparams is not None:
            if isinstance(hparams, str):
                with open(hparams) as f:
                    dictionary = json.load(f)
                self.hparams.update(dictionary)
            else:
                self.hparams.update(hparams)

        self.outcome_dim = (
            self.hparams["outcome_emb_dim"] if self.embed_outcomes else self.num_outcomes)
        self.treatment_dim = (
            self.hparams["treatment_emb_dim"] if self.embed_treatments else self.num_treatments)
        self.covariate_dim = (
            self.hparams["covariate_emb_dim"] * len(self.num_covariates) 
            if self.embed_covariates else sum(self.num_covariates)
        )

        return self.hparams

    def _init_indiv_model(self):
        """
        Build the generator (embeddings, encoder, decoder), a deep copy of the
        encoder used for evaluation-mode encoding, and the generator's Adam
        optimizer with its LambdaLR scheduler.
        """
        # dict values are evaluated in order: embeddings first, then the
        # encoder and the decoder
        self.generator = nn.ModuleDict({
            "outcomes_embeddings": self.init_outcome_emb(),
            "treatments_embeddings": self.init_treatment_emb(),
            "covariates_embeddings": self.init_covariates_emb(),
            "encoder": self.init_encoder(),
            "decoder": self.init_decoder(),
        })
        # independent copy of the encoder; it is not part of the generator
        self.encoder_eval = copy.deepcopy(self.generator["encoder"])

        # optimizer and learning-rate schedule
        generator_params = self.generator.parameters()
        self.g_opt = torch.optim.Adam(
            generator_params,
            lr=self.hparams["generator_lr"],
            weight_decay=self.hparams["generator_wd"],
        )
        self.g_sch = torch.optim.lr_scheduler.LambdaLR(
            self.g_opt, lr_lambda=self.lr_lambda
        )

    def _init_covar_model(self):
        if self.dist_mode == "classify":
            assert self.classifier is not None
        elif self.dist_mode == "discriminate":
            # embeddings
            outcomes_embeddings = self.init_outcome_emb()
            treatments_embeddings = self.init_treatment_emb()
            covariates_embeddings = self.init_covariates_emb()

            # model
            discriminator = self.init_discriminator()

            self.discriminator = nn.ModuleDict({
                "outcomes_embeddings": outcomes_embeddings,
                "treatments_embeddings": treatments_embeddings,
                "covariates_embeddings": covariates_embeddings,
                "discriminator": discriminator
            })

            # optimizer
            self.d_opt = torch.optim.Adam(
                self.discriminator.parameters(),
                lr=self.hparams["discriminator_lr"],
                weight_decay=self.hparams["discriminator_wd"],
            )
            self.d_sch = torch.optim.lr_scheduler.LambdaLR(
                self.d_opt, lr_lambda=self.lr_lambda
            )
        elif self.dist_mode == "match":
            pass
        else:
            raise ValueError("dist_mode not recognized")

    def encode(self, outcomes, treatments, covariates,
               distributionize=True, dist="normal", evaluate=False):
        outcomes = self.generator["outcomes_embeddings"](outcomes)
        treatments = self.generator["treatments_embeddings"](treatments)
        covariates = [emb(covars) for covars, emb in 
            zip(covariates, self.generator["covariates_embeddings"])
        ]

        if evaluate:
            out = self.encoder_eval(outcomes, treatments, covariates)
        else:
            out = self.generator["encoder"](outcomes, treatments, covariates)

        if distributionize:
            return self.distributionize(out, dist=dist)
        return out

    def decode(self, latents, treatments,
               distributionize=True, dist=None):
        treatments = self.generator["treatments_embeddings"](treatments)

        out = self.generator["decoder"](latents, treatments)

        if distributionize:
            return self.distributionize(out, dist=dist)
        return out

    def discriminate(self, outcomes, treatments, covariates):
        outcomes = self.discriminator["outcomes_embeddings"](outcomes)
        treatments = self.discriminator["treatments_embeddings"](treatments)
        covariates = [emb(covars) for covars, emb in 
            zip(covariates, self.discriminator["covariates_embeddings"])
        ]

        return self.discriminator["discriminator"](outcomes, treatments, covariates).squeeze()

    def reparameterize(self, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param sigma: (Tensor) Standard deviation of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        eps = torch.randn_like(sigma)
        return eps * sigma + mu

    def distributionize(self, constructions, dist=None, eps=1e-3):
        if dist is None:
            dist = self.dist_outcomes

        if dist == "nb":
            mus = F.softplus(constructions[..., 0]).add(eps)
            thetas = F.softplus(constructions[..., 1]).add(eps)
            dist = NegativeBinomial(
                mu=mus, theta=thetas
            )
        elif dist == "zinb":
            mus = F.softplus(constructions[..., 0]).add(eps)
            thetas = F.softplus(constructions[..., 1]).add(eps)
            zi_logits = constructions[..., 2].add(eps)
            dist = ZeroInflatedNegativeBinomial(
                mu=mus, theta=thetas, zi_logits=zi_logits
            )
        elif dist == "normal":
            locs = constructions[..., 0]
            scales = F.softplus(constructions[..., 1]).add(eps)
            dist = Normal(
                loc=locs, scale=scales
            )
        elif dist == "bernoulli":
            logits = constructions[..., 0]
            dist = Bernoulli(
                logits=logits
            )

        return dist

    def sample(self, mu: torch.Tensor, sigma: torch.Tensor, treatments: torch.Tensor, 
            size: int = 1, distributionize: bool = True, dist: str = None) -> torch.Tensor:
        mu = mu.repeat(size, *[1]*(mu.ndim-1))
        sigma = sigma.repeat(size, *[1]*(sigma.ndim-1))
        treatments = treatments.repeat(size, *[1]*(treatments.ndim-1))

        latents = self.reparameterize(mu, sigma)

        return self.decode(latents, treatments, distributionize=distributionize, dist=dist)

    def predict(
        self,
        outcomes,
        treatments,
        covariates,
        cf_treatments,
        return_dist=False
    ):
        if cf_treatments is None:
            cf_treatments = treatments

        with torch.autograd.no_grad():
            latents_dist = self.encode(outcomes, treatments, covariates)

            outcomes_dist = self.decode(latents_dist.mean, cf_treatments)

        if return_dist:
            return outcomes_dist
        else:
            return outcomes_dist.mean

    def generate(
        self,
        outcomes,
        treatments,
        covariates,
        cf_treatments,
        return_dist=False
    ):
        if cf_treatments is None:
            cf_treatments = treatments

        with torch.autograd.no_grad():
            latents_dist = self.encode(outcomes, treatments, covariates)

            outcomes_dist_samp = self.decode(latents_dist.sample(), cf_treatments)

        if return_dist:
            return outcomes_dist_samp
        else:
            return outcomes_dist_samp.mean

    def logprob(self, outcomes, outcomes_param, dist=None):
        """
        Compute a (optionally weighted) log likelihood of `outcomes` under the
        distribution `dist` with parameters `outcomes_param`.

        Three input layouts are handled:
          * `outcomes` is a list of tensors: each parameter tensor is repeated
            to line up with the concatenated outcomes;
          * `outcomes_param[0]` is a list of tensors: `outcomes` is repeated to
            line up with the concatenated parameters;
          * otherwise: inputs are used as given, without weights.
        In the two list layouts every row is weighted by the inverse size of
        the group it belongs to.

        :param dist: "nb", "zinb", "normal" or "bernoulli"; None selects
            `self.dist_outcomes`.
        :return: the per-row log-probabilities summed over dim 0, divided by
            the original `len(outcomes)`, then averaged.

        NOTE(review): an unrecognized `dist` leaves `logprob` unassigned and
        fails at the return statement.
        """
        if dist is None:
            dist = self.dist_outcomes

        # taken before any concatenation/repetition below changes the length
        num = len(outcomes)
        if isinstance(outcomes, list):
            # one group of observed outcomes per parameter row
            sizes = torch.tensor(
                [out.size(0) for out in outcomes],
                device=outcomes[0].device
            )
            weights = torch.repeat_interleave(1./sizes, sizes, dim=0)
            outcomes_param = [
                torch.repeat_interleave(out, sizes, dim=0) 
                for out in outcomes_param
            ]
            outcomes = torch.cat(outcomes, 0)
        elif isinstance(outcomes_param[0], list):
            # one group of parameter rows per outcome row
            sizes = torch.tensor(
                [out.size(0) for out in outcomes_param[0]],
                device=outcomes_param[0][0].device
            )
            weights = torch.repeat_interleave(1./sizes, sizes, dim=0)
            outcomes = torch.repeat_interleave(outcomes, sizes, dim=0)
            outcomes_param = [
                torch.cat(out, 0)
                for out in outcomes_param
            ]
        else:
            weights = None

        if dist == "nb":
            logprob = logprob_nb_positive(outcomes,
                mu=outcomes_param[0],
                theta=outcomes_param[1],
                weight=weights
            )
        elif dist == "zinb":
            logprob = logprob_zinb_positive(outcomes,
                mu=outcomes_param[0],
                theta=outcomes_param[1],
                zi_logits=outcomes_param[2],
                weight=weights
            )
        elif dist == "normal":
            logprob = logprob_normal(outcomes,
                loc=outcomes_param[0],
                scale=outcomes_param[1],
                weight=weights
            )
        elif dist == "bernoulli":
            logprob = logprob_bernoulli_logits(outcomes,
                loc=outcomes_param[0],
                weight=weights
            )

        return (logprob.sum(0)/num).mean()

    def forward(self, outcomes, treatments, covariates, cf_treatments,
                sample_latent=True, sample_outcome=False,
                detach_encode=False, detach_eval=True):
        """
        Execute the workflow: encode the factual inputs, reconstruct the
        factual outcomes, decode counterfactual outcomes under `cf_treatments`
        and re-encode those counterfactual outcomes.

        :param sample_latent: decode the counterfactual from a latent sample
            instead of the latent mean.
        :param sample_outcome: use a sample of the counterfactual outcome
            distribution instead of its mean.
        :param detach_encode: decode a second, separate counterfactual (from a
            non-reparameterized sample or the detached latent mean) to feed
            the re-encoding step.
        :param detach_eval: use `self.encoder_eval` for the latent
            distributions that are returned.
        :return: tuple `(outcomes_dist_samp, cf_outcomes_out,
            latents_dist_eval, cf_latents_dist_eval)`.
        """

        # q(z | y, x, t)
        latents_dist = self.encode(outcomes, treatments, covariates)
        if detach_eval:
            latents_dist_eval = self.encode(
                outcomes, treatments, covariates, evaluate=True)
        else:
            latents_dist_eval = latents_dist

        # p(y | z, t) -- `mc_sample_size` latent samples per input
        outcomes_dist_samp = self.sample(
            latents_dist.mean, latents_dist.stddev, treatments, size=self.mc_sample_size)

        # p(y' | z, t')
        if sample_latent:
            cf_outcomes_dist_out = self.decode(latents_dist.rsample(), cf_treatments)
        else:
            cf_outcomes_dist_out = self.decode(latents_dist.mean, cf_treatments)
        if sample_outcome:
            cf_outcomes_out = cf_outcomes_dist_out.rsample()
        else:
            cf_outcomes_out = cf_outcomes_dist_out.mean

        # q(z | y', x, t')
        if detach_encode:
            # sample() / detach() cut the gradient path back to the encoder
            if sample_latent:
                cf_outcomes_dist_in = self.decode(latents_dist.sample(), cf_treatments)
            else:
                cf_outcomes_dist_in = self.decode(latents_dist.mean.detach(), cf_treatments)
            if sample_outcome:
                cf_outcomes_in = cf_outcomes_dist_in.rsample()
            else:
                cf_outcomes_in = cf_outcomes_dist_in.mean
        else:
            cf_outcomes_in = cf_outcomes_out
        cf_latents_dist_eval = self.encode(
            cf_outcomes_in, cf_treatments, covariates, evaluate=detach_eval)

        return (outcomes_dist_samp, cf_outcomes_out, latents_dist_eval, cf_latents_dist_eval)

    def loss_reconstruction(self, outcomes_dist, outcomes):
        return -outcomes_dist.log_prob(outcomes).mean()

    def loss_causality(self, cf_outcomes_out, cf_treatments, covariates, cf_outcomes=None,
                       hinge_threshold=0.05, kde_kernel_std=1.0):
        if self.dist_mode == "classify":
            classifier_loss, _ = self.classifier.loss(cf_outcomes_out, cf_treatments, covariates)
            if hinge_threshold is not None:
                return F.relu(classifier_loss - hinge_threshold) + hinge_threshold
            return classifier_loss
        elif self.dist_mode == "discriminate":
            return F.softplus(-self.discriminate(cf_outcomes_out, cf_treatments, covariates)).mean()
        elif self.dist_mode == "match":
            notNone = [o != None for o in cf_outcomes]
            cf_outcomes = [o for (o, n) in zip(cf_outcomes, notNone) if n]
            cf_outcomes_out = cf_outcomes_out[notNone]

            kernel_std = [kde_kernel_std * torch.ones_like(o) 
                for o in cf_outcomes]
            return -self.logprob(cf_outcomes_out, (cf_outcomes, kernel_std), dist="normal")

    def loss_disentanglement(self, latents_dist, cf_latents_dist):
        """
        Return the mean of `kldiv_normal` evaluated on the mean and standard
        deviation of the factual and counterfactual latent distributions.
        """
        divergence = kldiv_normal(
            latents_dist.mean,
            latents_dist.stddev,
            cf_latents_dist.mean,
            cf_latents_dist.stddev,
        )
        return divergence.mean()

    def metric_reversibility(self, outcomes, treatments, covariates, cf_treatments, num_cycles=1, accum=False):
        """
        Tests the reversibility of the pipeline, as per measuring counterfactual soundness paper.
        1 cycle is defined as do(cf_treatments) then do(treatments)
        """
        accum_loss = 0.
        original_outcomes = outcomes 

        for _ in range(num_cycles):
            _, cf_outcomes_out, _, _ = self.forward(
                outcomes, treatments, covariates, cf_treatments
            )

            _, cycle_back_factual_outcomes, _, _ = self.forward(
                cf_outcomes_out, cf_treatments, covariates, treatments
            )

            if accum: 
                accum_loss += torch.mean((original_outcomes - cycle_back_factual_outcomes) ** 2)
            
            # setting up for the next loop 
            outcomes = cycle_back_factual_outcomes
        
        return accum_loss if accum else torch.mean((original_outcomes - cycle_back_factual_outcomes) ** 2)

    def metric_consistency(self, outcomes, treatments, covariates, cf_treatments, num_cycles=1, accum=False):
        """
        Tests the consistency of the pipeline, as per measuring counterfactual soundness paper.
        """
        accum_loss = 0.
        original_outcomes = outcomes 

        for _ in range(num_cycles):
            outcomes_dist_samp, cf_outcomes_out, latents_dist, cf_latents_dist = self.forward(
                outcomes, treatments, covariates, cf_treatments
            )

            if accum: 
                accum_loss += self.loss_reconstruction(
                    outcomes_dist_samp, 
                    original_outcomes.repeat(self.mc_sample_size, *[1]*(outcomes.dim()-1))
                )
            
            # setting up for the next loop 
            outcomes = outcomes_dist_samp.mean.reshape((-1, *original_outcomes.shape)).mean(axis=0)
        
        return accum_loss if accum else self.loss_reconstruction(
            outcomes_dist_samp, 
            original_outcomes.repeat(self.mc_sample_size, *[1]*(outcomes.dim()-1))
        )

    def load_oracle_model(self, oracle_model_checkpoint_path='/home/*anonymized*/Documents/variational-causal-inference/artifact/classifier/saves/celebA-classifier-infused-deep_2023.08.31_00:40:15/model_seed=None_epoch=1999.pt'):
        """
        Load (once) and return the oracle `ClassifierConv` used by the
        effectiveness metric.

        The checkpoint is expected to hold a pair
        `(model state dict, constructor arguments)`. The loaded model is
        cached on `self.oracle_model`; later calls return the cached model and
        ignore `oracle_model_checkpoint_path`.

        NOTE(review): assigning the oracle as an attribute registers it as a
        sub-module of this model, so it shows up in `state_dict()` -- confirm
        this is intended.
        """
        if hasattr(self, "oracle_model"): 
            return self.oracle_model

        # map to CPU first, as `load_VCI` does for its classifier checkpoint,
        # so checkpoints saved on another device can still be read
        oracle_weights, oracle_args = torch.load(
            oracle_model_checkpoint_path, map_location="cpu")

        oracle_model = ClassifierConv(
            num_outcomes=oracle_args['num_outcomes'], 
            num_treatments=oracle_args['num_treatments'], 
            num_covariates=oracle_args['num_covariates'], 
            hparams=oracle_args['hparams']
        )
        oracle_model.load_state_dict(oracle_weights)
        # the oracle is only used for evaluation: put it in eval mode, as
        # `load_VCI` does for its pre-trained classifiers
        oracle_model.eval()
        self.oracle_model = oracle_model

        return oracle_model
        
    def metric_effectiveness(self, outcomes, treatments, covariates, cf_treatments, num_cycles=1, accum=False, oracle_model_checkpoint_path='/home/*anonymized*/Documents/variational-causal-inference/artifact/classifier/saves/celebA-classifier-infused-deep_2023.08.31_00:40:15/model_seed=None_epoch=1999.pt'):
        """
        Tests the effectiveness of the pipeline, as per measuring counterfactual soundness paper.
        """
        oracle_model = self.load_oracle_model(oracle_model_checkpoint_path=oracle_model_checkpoint_path)
        distance_metric = lambda a, b: torch.nn.functional.mse_loss(a, b)
        
        _, cf_outcomes_out, _, _ = self.forward(
            outcomes, treatments, covariates, cf_treatments
        )

        cf_treatments_extracted = oracle_model(cf_outcomes_out, covariates)

        return distance_metric(cf_treatments, cf_treatments_extracted)

    def loss(self, outcomes, treatments, covariates, cf_treatments, cf_outcomes=None):
        """
        Compute the generator objective.

        :return: tuple `(total, log)` where `total` is the weighted sum
            `omega0 * reconstruction + omega1 * causality + omega2 * KL` and
            `log` maps each term's name to its Python float value. The log
            dictionary is also printed.
        """
        (outcomes_dist_samp, cf_outcomes_out,
         latents_dist, cf_latents_dist) = self.forward(
            outcomes, treatments, covariates, cf_treatments
        )

        # (1) individual-specific likelihood, against the outcomes tiled to
        # match the Monte-Carlo samples
        tiled_outcomes = outcomes.repeat(self.mc_sample_size, *[1]*(outcomes.dim()-1))
        indiv_spec_nllh = self.loss_reconstruction(outcomes_dist_samp, tiled_outcomes)

        # (2) covariate-specific likelihood
        covar_spec_nllh = self.loss_causality(
            cf_outcomes_out, cf_treatments, covariates, cf_outcomes=cf_outcomes
        )

        # (3) kl divergence
        kl_divergence = self.loss_disentanglement(latents_dist, cf_latents_dist)

        loss_log = {
            "Indiv-spec NLLH": indiv_spec_nllh.item(),
            "Covar-spec NLLH": covar_spec_nllh.item(),
            "KL Divergence": kl_divergence.item(),
        }
        print(loss_log)

        total = (
            self.omega0 * indiv_spec_nllh
            + self.omega1 * covar_spec_nllh
            + self.omega2 * kl_divergence
        )
        return total, loss_log

    def loss_discriminator(self, outcomes, treatments, covariates, cf_treatments):
        cf_outcomes = self.generate(outcomes, treatments, covariates, cf_treatments)

        score_real = self.discriminate(outcomes, treatments, covariates)
        score_fake = self.discriminate(cf_outcomes.detach(), cf_treatments, covariates)

        loss_real = F.softplus(-score_real).mean()
        loss_fake = F.softplus(score_fake).mean()

        return (loss_real+loss_fake)/2., {
            "Real Sample Loss": loss_real.item(),
            "Fake Sample Loss": loss_fake.item()
        }

    def update(self, batch, batch_idx, ema_model):
        """
        Run one optimisation step on `batch`.

        In "discriminate" mode the discriminator is updated every
        `discriminator_freq` batches; the generator is updated on every call.
        For each update the gradient norm is clipped (when `max_grad_norm` is
        positive) or just measured, and the optimizer step is skipped when the
        norm reaches a non-negative `grad_skip_threshold`. After every
        optimizer step `update_ema` is called with `ema_model`.

        :param batch: tuple `(outcomes, treatments, covariates, cf_treatments,
            cf_outcomes)`.
        :param batch_idx: index of the batch, used for the discriminator schedule.
        :param ema_model: model passed to `update_ema`.
        :return: dictionary of logged losses, gradient norms and skip flags.

        NOTE(review): `self.hparams["ema_rate"]` has no default in
        `_set_hparams`; it must be supplied through the hyper-parameters.
        """
        outcomes, treatments, covariates, cf_treatments, cf_outcomes = batch

        loss_log = {}

        if self.dist_mode == "discriminate":
            if (batch_idx+1) % self.hparams["discriminator_freq"] == 0:
                d_loss, d_log = self.loss_discriminator(outcomes, treatments, covariates, cf_treatments)

                self.d_opt.zero_grad()
                d_loss.backward()
                if self.hparams["max_grad_norm"] > 0:
                    d_grad_norm = nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.hparams["max_grad_norm"])
                else:
                    d_grad_norm = total_grad_norm_(self.discriminator.parameters())
                d_log["d_grad_norm"] = d_grad_norm.item()


                # a negative threshold disables step skipping
                if self.hparams["grad_skip_threshold"] < 0 or d_grad_norm < self.hparams["grad_skip_threshold"]:
                    self.d_opt.step()
                    update_ema(self, ema_model, self.hparams["ema_rate"])
                    d_log["d_did_skip_step"] = False
                else:
                    d_log["d_did_skip_step"] = True


                loss_log.update(d_log)

        # cf_outcomes from the batch is deliberately not forwarded (None is
        # passed instead), so the generator loss never sees it.
        # NOTE(review): the original comment mentioned "morpho" -- presumably
        # the morphomnist data; confirm. "match" mode needs cf_outcomes.
        g_loss, g_log = self.loss(outcomes, treatments, covariates, cf_treatments, cf_outcomes=None)

        self.g_opt.zero_grad()
        g_loss.backward()
        if self.hparams["max_grad_norm"] > 0:
            g_grad_norm = nn.utils.clip_grad_norm_(self.generator.parameters(), self.hparams["max_grad_norm"])
        else:
            g_grad_norm = total_grad_norm_(self.generator.parameters())
        g_log["g_grad_norm"] = g_grad_norm.item()

        # a negative threshold disables step skipping
        if self.hparams["grad_skip_threshold"] < 0 or g_grad_norm < self.hparams["grad_skip_threshold"]:
            self.g_opt.step()
            update_ema(self, ema_model, self.hparams["ema_rate"])
            g_log["g_did_skip_step"] = False
        else: 
            g_log["g_did_skip_step"] = True

        loss_log.update(g_log)

        self.iteration += 1

        return loss_log

    def step(self):
        for target_param, param in zip(
            self.encoder_eval.parameters(), self.generator["encoder"].parameters()
        ):
            target_param.data.copy_(param.data)

        if self.dist_mode == "discriminate":
            self.d_sch.step()
        self.g_sch.step()

    def early_stopping(self, score=None):
        if score is None:
            return False

        if score > self.best_score:
            self.best_score = score
            self.patience_trials = 0
        else:
            self.patience_trials += 1

        return self.patience_trials > self.hparams["patience"]

    def init_outcome_emb(self):
        if self.embed_outcomes:
            return MLP([self.num_outcomes, self.hparams["outcome_emb_dim"]])
        else:
            return nn.Identity()

    def init_treatment_emb(self):
        if self.embed_treatments:
            return MLP([self.num_treatments, self.hparams["treatment_emb_dim"]])
        else:
            return nn.Identity()

    def init_covariates_emb(self):
        if self.embed_covariates:
            covariates_emb = []
            for num_cov in self.num_covariates:
                covariates_emb.append(
                    MLP([num_cov, self.hparams["covariate_emb_dim"]])
                )
            return nn.ModuleList(covariates_emb)
        else:
            return nn.ModuleList([nn.Identity()]*len(self.num_covariates))

    def init_encoder(self):
        """
        Build the encoder MLP: input is the concatenated outcome, treatment
        and covariate dimensions, followed by `encoder_depth - 1` hidden
        layers of `encoder_width` and a `latent_dim` output with two heads.
        """
        input_dim = self.outcome_dim + self.treatment_dim + self.covariate_dim
        hidden = [self.hparams["encoder_width"]] * (self.hparams["encoder_depth"] - 1)
        layer_sizes = [input_dim, *hidden, self.hparams["latent_dim"]]
        return MLP(layer_sizes, heads=2, final_act="relu")

    def init_decoder(self):
        if self.dist_outcomes == "nb":
            heads = 2
        elif self.dist_outcomes == "zinb":
            heads = 3
        elif self.dist_outcomes == "normal":
            heads = 2
        elif self.dist_outcomes == "bernoulli":
            heads = 1
        else:
            raise ValueError("dist_outcomes not recognized")

        return MLP([self.hparams["latent_dim"]+self.treatment_dim]
            + [self.hparams["decoder_width"]] * (self.hparams["decoder_depth"] - 1)
            + [self.num_outcomes],
            heads=heads
        )

    def init_discriminator(self):
        """
        Build the discriminator MLP: input is the concatenated outcome,
        treatment and covariate dimensions, followed by
        `discriminator_depth - 1` hidden layers of `discriminator_width` and a
        single output unit.
        """
        input_dim = self.outcome_dim + self.treatment_dim + self.covariate_dim
        hidden = (
            [self.hparams["discriminator_width"]]
            * (self.hparams["discriminator_depth"] - 1)
        )
        layer_sizes = [input_dim, *hidden, 1]
        return MLP(layer_sizes)

    def to_device(self, device):
        self.device = device
        self.to(self.device)

    @classmethod
    def defaults(self):
        """
        Returns the list of default hyper-parameters for VCI
        """

        return self._set_hparams(self, "")

#####################################################
#                     EXTENSIONS                    #
#####################################################

from .convolution import ConvModel

from utils.model_utils import conv_1x1, conv_3x3, parse_block_string


class VCIConv(VCI):
    """
    Convolutional variant of VCI.

    Overrides the hyper-parameter defaults and the builders for the outcome
    embedding, encoder, decoder and discriminator so that they produce
    convolution-based modules.
    """

    def __init__(self, *args, **kwargs):
        # Construction is delegated unchanged to the parent class.
        super().__init__(*args, **kwargs)

    def _set_hparams(self, hparams):
        """
        Install the default hyper-parameters, apply user overrides and derive
        the outcome / treatment / covariate dimensions.

        `hparams` may be None (keep defaults), a str (treated as the path of
        a JSON file to load) or a mapping of overrides. Returns the resulting
        `self.hparams` dictionary.
        """

        self.hparams = {
            "outcome_emb_dim": 32,
            "treatment_emb_dim": 8,
            "covariate_emb_dim": 2,
            "encoder_resolution": "64,32,16,8,4,1",
            "encoder_width": "32,64,128,256,512,1024",
            "encoder_depth": "2,3,3,2,1,2",
            "decoder_resolution": "1,4,8,16,32,64",
            "decoder_width": "1024,512,256,128,64,32",
            "decoder_depth": "2,1,2,3,3,2",
            "discriminator_resolution": "64,32,16,8,4,1",
            "discriminator_width": "32,64,128,256,512,1024",
            "discriminator_depth": "2,3,3,2,1,4",
            "generator_lr": 0.0003,
            "generator_wd": 4e-07,
            "discriminator_lr": 0.0003,
            "discriminator_wd": 4e-07,
            "discriminator_freq": 2,
            "opt_step_size": 400,
            "max_grad_norm": -1,
            "grad_skip_threshold": -1,
            "patience": 20
        }

        # user-supplied values take precedence over the defaults
        if hparams is not None:
            overrides = hparams
            if isinstance(hparams, str):
                with open(hparams) as f:
                    overrides = json.load(f)
            self.hparams.update(overrides)

        # outcomes: only the leading entry is replaced by the embedding size
        if self.embed_outcomes:
            self.outcome_dim = (
                self.hparams["outcome_emb_dim"], *self.num_outcomes[1:]
            )
        else:
            self.outcome_dim = self.num_outcomes

        if self.embed_treatments:
            self.treatment_dim = self.hparams["treatment_emb_dim"]
        else:
            self.treatment_dim = self.num_treatments

        # covariates: one embedding per covariate, or the raw sizes summed
        if self.embed_covariates:
            self.covariate_dim = (
                self.hparams["covariate_emb_dim"] * len(self.num_covariates)
            )
        else:
            self.covariate_dim = sum(self.num_covariates)

        return self.hparams

    def init_outcome_emb(self):
        """
        Return a 1x1 convolution from `num_outcomes[0]` to `outcome_emb_dim`
        channels when outcomes are embedded, otherwise an identity module.
        """
        if not self.embed_outcomes:
            return nn.Identity()
        return conv_1x1(self.num_outcomes[0], self.hparams["outcome_emb_dim"])

    def init_encoder(self):
        """
        Build the convolutional encoder from the `encoder_*` block strings,
        with two heads and treatment + covariate features as side input.
        """
        block_spec = parse_block_string(
            self.hparams["encoder_resolution"],
            self.hparams["encoder_width"],
            self.hparams["encoder_depth"],
            in_size=self.outcome_dim
        )
        return ConvModel(
            *block_spec,
            num_features=self.treatment_dim+self.covariate_dim, heads=2,
            lite_blocks=True, lite_layers=True
        )

    def init_decoder(self):
        """
        Build the convolutional decoder from the `decoder_*` block strings.

        The number of heads follows `self.dist_outcomes`: 2 for "nb" and
        "normal", 3 for "zinb", 1 for "bernoulli"; any other value raises
        ValueError.
        """
        dist_name = self.dist_outcomes
        if dist_name in ("nb", "normal"):
            heads = 2
        elif dist_name == "zinb":
            heads = 3
        elif dist_name == "bernoulli":
            heads = 1
        else:
            raise ValueError("dist_outcomes not recognized")

        block_spec = parse_block_string(
            self.hparams["decoder_resolution"],
            self.hparams["decoder_width"],
            self.hparams["decoder_depth"],
            out_size=self.num_outcomes
        )
        return ConvModel(
            *block_spec,
            num_features=self.treatment_dim, heads=heads,
            lite_blocks=False, lite_layers=True
        )

    def init_discriminator(self):
        """
        Build the convolutional discriminator from the `discriminator_*`
        block strings, with a (1, 1, 1) output size and spectral norm enabled.
        """
        block_spec = parse_block_string(
            self.hparams["discriminator_resolution"],
            self.hparams["discriminator_width"],
            self.hparams["discriminator_depth"],
            in_size=self.outcome_dim,
            out_size=(1, 1, 1)
        )
        return ConvModel(
            *block_spec,
            num_features=self.treatment_dim+self.covariate_dim,
            lite_blocks=True, lite_layers=True, spectral_norm=True
        )


from .hierarchy import HConvEncoder, HConvDecoder


class HVCIConv(VCIConv):
    def __init__(self, *args, **kwargs):
        """
        Forward all positional and keyword arguments unchanged to the parent
        class constructor.
        """
        super().__init__(*args, **kwargs)

    def _set_hparams(self, hparams):
        """
        Install the default hyper-parameters (including `defuse_steps`),
        apply user overrides and derive the outcome / treatment / covariate
        dimensions.

        `hparams` may be None (keep defaults), a str (treated as the path of
        a JSON file to load) or a mapping of overrides. Returns the resulting
        `self.hparams` dictionary.
        """

        self.hparams = {
            "outcome_emb_dim": 32,
            "treatment_emb_dim": 8,
            "covariate_emb_dim": 2,
            "defuse_steps": 2,
            "encoder_resolution": "64,32,16,8,4,1",
            "encoder_width": "32,64,128,256,512,1024",
            "encoder_depth": "2,3,3,2,1,2",
            "decoder_resolution": "1,4,8,16,32,64",
            "decoder_width": "1024,512,256,128,64,32",
            "decoder_depth": "2,1,2,3,3,2",
            "discriminator_resolution": "64,32,16,8,4,1",
            "discriminator_width": "32,64,128,256,512,1024",
            "discriminator_depth": "2,3,3,2,1,4",
            "generator_lr": 0.0003,
            "generator_wd": 4e-07,
            "discriminator_lr": 0.0003,
            "discriminator_wd": 4e-07,
            "discriminator_freq": 2,
            "opt_step_size": 400,
            "max_grad_norm": -1,
            "grad_skip_threshold": -1,
            "patience": 20
        }

        # user-supplied values take precedence over the defaults
        if hparams is not None:
            overrides = hparams
            if isinstance(hparams, str):
                with open(hparams) as f:
                    overrides = json.load(f)
            self.hparams.update(overrides)

        # outcomes: only the leading entry is replaced by the embedding size
        if self.embed_outcomes:
            self.outcome_dim = (
                self.hparams["outcome_emb_dim"], *self.num_outcomes[1:]
            )
        else:
            self.outcome_dim = self.num_outcomes

        if self.embed_treatments:
            self.treatment_dim = self.hparams["treatment_emb_dim"]
        else:
            self.treatment_dim = self.num_treatments

        # covariates: one embedding per covariate, or the raw sizes summed
        if self.embed_covariates:
            self.covariate_dim = (
                self.hparams["covariate_emb_dim"] * len(self.num_covariates)
            )
        else:
            self.covariate_dim = sum(self.num_covariates)

        return self.hparams

    def encode(self, outcomes, treatments, covariates,
               distributionize=True, dist="normal", evaluate=False):
        """
        Embed the inputs and pass them through the encoder.

        When `evaluate` is true, `self.encoder_eval` is used instead of the
        generator's encoder. The encoder yields a sequence of outputs plus
        hidden states; with `distributionize` each output is wrapped through
        `self.distributionize(..., dist=dist)`.

        Returns a pair `(outputs, hiddens)`.
        """
        outcomes_emb = self.generator["outcomes_embeddings"](outcomes)
        treatments_emb = self.generator["treatments_embeddings"](treatments)
        covariates_emb = [
            embed(covar)
            for covar, embed in zip(covariates, self.generator["covariates_embeddings"])
        ]

        encoder = self.encoder_eval if evaluate else self.generator["encoder"]
        outs, hiddens = encoder(outcomes_emb, treatments_emb, covariates_emb)

        if not distributionize:
            return outs, hiddens
        return [self.distributionize(out, dist=dist) for out in outs], hiddens

    def decode(self, latents, treatments,
               distributionize=True, dist=None):
        """
        Embed the treatments and pass `latents` through the decoder.

        Returns a pair `(output, hiddens)`; with `distributionize` the output
        is wrapped through `self.distributionize(..., dist=dist)`.
        """
        treatments_emb = self.generator["treatments_embeddings"](treatments)

        decoded, hiddens = self.generator["decoder"](latents, treatments_emb)

        if not distributionize:
            return decoded, hiddens
        return self.distributionize(decoded, dist=dist), hiddens

    def sample(self, mu, sigma, treatments,
               size=1, distributionize=True, dist=None):
        """
        Draw latents from per-level (`mu`, `sigma`) pairs and decode them.

        Every tensor in `mu` and `sigma`, and `treatments`, is repeated
        `size` times along its first dimension before one latent per level
        is obtained through `self.reparameterize`. Returns what `decode`
        returns.
        """
        def tile(tensor):
            # repeat along dim 0 only; every other dim keeps its size
            return tensor.repeat(size, *([1] * (tensor.ndim - 1)))

        mu = [tile(m) for m in mu]
        sigma = [tile(s) for s in sigma]
        treatments = tile(treatments)

        latents = [
            self.reparameterize(level_mu, level_sigma)
            for level_mu, level_sigma in zip(mu, sigma)
        ]

        return self.decode(latents, treatments, distributionize=distributionize, dist=dist)

    def predict(
        self,
        outcomes,
        treatments,
        covariates,
        cf_treatments,
        return_dist=False
    ):
        """
        Predict outcomes under `cf_treatments` from the latent means.

        Inputs are encoded, the mean of every latent distribution is decoded
        with `cf_treatments` (falling back to `treatments` when None), all
        without gradient tracking. Returns the decoded distribution when
        `return_dist` is true, otherwise its mean.
        """
        if cf_treatments is None:
            cf_treatments = treatments

        with torch.no_grad():
            latents_dist, _ = self.encode(outcomes, treatments, covariates)

            latent_means = [d.mean for d in latents_dist]
            outcomes_dist, _ = self.decode(latent_means, cf_treatments)

        return outcomes_dist if return_dist else outcomes_dist.mean

    def generate(
        self,
        outcomes,
        treatments,
        covariates,
        cf_treatments,
        return_dist=False
    ):
        """
        Generate outcomes under `cf_treatments` from sampled latents.

        Inputs are encoded, one sample is drawn from every latent
        distribution and decoded with `cf_treatments` (falling back to
        `treatments` when None), all without gradient tracking. Returns the
        decoded distribution when `return_dist` is true, otherwise its mean.
        """
        if cf_treatments is None:
            cf_treatments = treatments

        with torch.no_grad():
            latents_dist, _ = self.encode(outcomes, treatments, covariates)

            latent_draws = [d.sample() for d in latents_dist]
            outcomes_dist_samp, _ = self.decode(latent_draws, cf_treatments)

        return outcomes_dist_samp if return_dist else outcomes_dist_samp.mean

    def forward(self, outcomes, treatments, covariates, cf_treatments,
                sample_latent=True, sample_outcome=False,
                detach_encode=False, detach_eval=True):
        """
        Run the full encode / decode / re-encode workflow.

        Flags:
            sample_latent: decode counterfactuals from reparameterized latent
                samples instead of latent means.
            sample_outcome: take a reparameterized sample from the
                counterfactual outcome distribution instead of its mean.
            detach_encode: decode a second counterfactual from detached
                latents and feed that one to the re-encoding step.
            detach_eval: obtain the returned latent distributions from an
                extra `encode(..., evaluate=...)` pass.

        Returns the tuple
        (outcomes_dist_samp, cf_outcomes_out, latents_dist_eval,
        cf_latents_dist_eval, hiddens_in, hiddens_out_samp, cf_hiddens_in,
        cf_hiddens_out).
        """

        # q(z | y, x, t)
        latents_dist, hiddens_in = self.encode(outcomes, treatments, covariates)
        latents_dist_eval = latents_dist
        if detach_eval:
            latents_dist_eval, _ = self.encode(
                outcomes, treatments, covariates, evaluate=True)

        # p(y | z, t)
        latent_means = [d.mean for d in latents_dist]
        latent_stds = [d.stddev for d in latents_dist]
        outcomes_dist_samp, hiddens_out_samp = self.sample(
            latent_means, latent_stds, treatments, size=self.mc_sample_size)

        # p(y' | z, t')
        if sample_latent:
            cf_latents = [d.rsample() for d in latents_dist]
        else:
            cf_latents = [d.mean for d in latents_dist]
        cf_outcomes_dist_out, cf_hiddens_out = self.decode(cf_latents, cf_treatments)
        cf_outcomes_out = (
            cf_outcomes_dist_out.rsample() if sample_outcome
            else cf_outcomes_dist_out.mean
        )

        # q(z | y', x, t')
        cf_outcomes_in = cf_outcomes_out
        if detach_encode:
            # second decode from latents that carry no gradient
            if sample_latent:
                detached_latents = [d.sample() for d in latents_dist]
            else:
                detached_latents = [d.mean.detach() for d in latents_dist]
            cf_outcomes_dist_in, _ = self.decode(detached_latents, cf_treatments)
            cf_outcomes_in = (
                cf_outcomes_dist_in.rsample() if sample_outcome
                else cf_outcomes_dist_in.mean
            )

        cf_latents_dist, cf_hiddens_in = self.encode(cf_outcomes_in, cf_treatments, covariates)
        cf_latents_dist_eval = cf_latents_dist
        if detach_eval:
            cf_latents_dist_eval, _ = self.encode(
                cf_outcomes_in, cf_treatments, covariates, evaluate=detach_eval)

        return (
            outcomes_dist_samp, cf_outcomes_out, latents_dist_eval, cf_latents_dist_eval,
            hiddens_in, hiddens_out_samp, cf_hiddens_in, cf_hiddens_out
        )

    def loss_reconstruction(self, outcomes_dist_samp, outcomes,
                            hiddens_out_samp, hiddens_in, hiddens_ratio=0.5):
        """
        Parent reconstruction loss plus a hidden-state matching penalty.

        Each entry of `hiddens_out_samp` is paired with the entries of
        `hiddens_in` taken in reverse order; the MSE between the pair (with
        the `hiddens_in` side detached), scaled by `hiddens_ratio`, is added
        to the parent loss.
        """
        total = super().loss_reconstruction(outcomes_dist_samp, outcomes)

        encoder_side = reversed(hiddens_in)
        for dec_hidden, enc_hidden in zip(hiddens_out_samp, encoder_side):
            penalty = hiddens_ratio * F.mse_loss(dec_hidden, enc_hidden.detach())
            total = total + penalty
        return total

    def loss_causality(self, cf_outcomes_out, cf_treatments, covariates,
                       cf_hiddens_in, cf_hiddens_out, cf_outcomes=None,
                       hiddens_ratio=0.5, hinge_threshold=0.05, kde_kernel_std=1):
        """
        Parent causality loss plus a hidden-state matching penalty.

        Each entry of `cf_hiddens_in` is paired with the entries of
        `cf_hiddens_out` taken in reverse order; the MSE between the pair
        (with the `cf_hiddens_out` side detached), scaled by `hiddens_ratio`,
        is added to the parent loss.
        """
        total = super().loss_causality(
            cf_outcomes_out, cf_treatments, covariates,
            cf_outcomes, hinge_threshold, kde_kernel_std
        )

        decoder_side = reversed(cf_hiddens_out)
        for enc_hidden, dec_hidden in zip(cf_hiddens_in, decoder_side):
            penalty = hiddens_ratio * F.mse_loss(enc_hidden, dec_hidden.detach())
            total = total + penalty
        return total

    def loss_disentanglement(self, latents_dist, cf_latents_dist):
        """
        Sum, over paired latent levels, of the mean `kldiv_normal` between
        the factual and counterfactual latent distributions.
        """
        return sum(
            kldiv_normal(
                factual.mean, factual.stddev,
                counterfactual.mean, counterfactual.stddev
            ).mean()
            for factual, counterfactual in zip(latents_dist, cf_latents_dist)
        )

    def loss(self, outcomes, treatments, covariates, cf_treatments, cf_outcomes=None):
        """
        Compute the weighted training loss and a dictionary of metrics.

        Returns `(total, return_dict)` where `total` is
        `omega0 * reconstruction + omega1 * causality + omega2 * KL` and
        `return_dict` holds the three terms as Python floats. When
        `cf_outcomes` is given, `return_dict` additionally holds per-sample
        MAE / MSE against it (detached) and the raw `cf_outcomes_out`.
        A summary of the metrics is printed on every call.
        """
        (
            outcomes_dist_samp, cf_outcomes_out, latents_dist, cf_latents_dist,
            hiddens_in, hiddens_out_samp, cf_hiddens_in, cf_hiddens_out
        ) = self.forward(outcomes, treatments, covariates, cf_treatments)

        # (1) individual-specific likelihood
        # targets are tiled along dim 0 by the Monte-Carlo sample count
        repeats = self.mc_sample_size
        tiled_outcomes = outcomes.repeat(repeats, *([1] * (outcomes.dim() - 1)))
        tiled_hiddens_in = [
            hidden.repeat(repeats, *([1] * (hidden.dim() - 1)))
            for hidden in hiddens_in
        ]
        indiv_spec_nllh = self.loss_reconstruction(
            outcomes_dist_samp, tiled_outcomes, hiddens_out_samp, tiled_hiddens_in
        )

        # (2) covariate-specific likelihood
        covar_spec_nllh = self.loss_causality(
            cf_outcomes_out, cf_treatments, covariates,
            cf_hiddens_in, cf_hiddens_out,
            cf_outcomes=cf_outcomes
        )

        # (3) kl divergence
        kl_divergence = self.loss_disentanglement(latents_dist, cf_latents_dist)

        return_dict = {
            "Indiv-spec NLLH": indiv_spec_nllh.item(),
            "Covar-spec NLLH": covar_spec_nllh.item(),
            "KL Divergence": kl_divergence.item(),
        }
        print_dict = dict(return_dict)

        # (4) cf outcomes out loss, only when ground truth is available
        if cf_outcomes is not None:
            non_batch_dims = tuple(range(1, cf_outcomes.ndim))
            residual = cf_outcomes - cf_outcomes_out
            gt_cf_mae_loss = torch.mean(torch.abs(residual), axis=non_batch_dims)
            gt_cf_mse_loss = torch.mean(torch.pow(residual, 2), axis=non_batch_dims)

            return_dict["GT MAE CF Loss"] = gt_cf_mae_loss.detach()
            return_dict["GT MSE CF Loss"] = gt_cf_mse_loss.detach()
            return_dict["cf_outcomes_out"] = cf_outcomes_out

            # the printed summary shows the averaged errors only
            print_dict["GT MAE CF Loss"] = return_dict["GT MAE CF Loss"].mean()
            print_dict["GT MSE CF Loss"] = return_dict["GT MSE CF Loss"].mean()

        print(print_dict)

        total = (self.omega0 * indiv_spec_nllh
            + self.omega1 * covar_spec_nllh
            + self.omega2 * kl_divergence
        )
        return total, return_dict

    def init_encoder(self):
        """
        Build the hierarchical convolutional encoder from the `encoder_*`
        block strings, with two heads, treatment + covariate features as
        side input and `defuse_steps` taken from the hyper-parameters.
        """
        block_spec = parse_block_string(
            self.hparams["encoder_resolution"],
            self.hparams["encoder_width"],
            self.hparams["encoder_depth"],
            in_size=self.outcome_dim
        )
        return HConvEncoder(
            *block_spec,
            num_features=self.treatment_dim+self.covariate_dim, heads=2,
            defuse_steps=self.hparams["defuse_steps"]
        )

    def init_decoder(self):
        """
        Build the hierarchical convolutional decoder from the `decoder_*`
        block strings.

        The number of heads follows `self.dist_outcomes`: 2 for "nb" and
        "normal", 3 for "zinb", 1 for "bernoulli"; any other value raises
        ValueError. `construct_steps` is set to the `defuse_steps`
        hyper-parameter.
        """
        dist_name = self.dist_outcomes
        if dist_name in ("nb", "normal"):
            heads = 2
        elif dist_name == "zinb":
            heads = 3
        elif dist_name == "bernoulli":
            heads = 1
        else:
            raise ValueError("dist_outcomes not recognized")

        block_spec = parse_block_string(
            self.hparams["decoder_resolution"],
            self.hparams["decoder_width"],
            self.hparams["decoder_depth"],
            out_size=self.num_outcomes
        )
        return HConvDecoder(
            *block_spec,
            num_features=self.treatment_dim, heads=heads,
            construct_steps=self.hparams["defuse_steps"]
        )
