#!/usr/bin/env python
"""
SFE with information gain limit
"""

# %%
import math
import sys
import os
import argparse
import vti.utils.logging as logging
from vti.utils.logging import _bm, _dm


import torch
from torch import optim
from torch.nn import functional as F

from torch.optim.lr_scheduler import (
    ChainedScheduler,
    CosineAnnealingWarmRestarts,
    ExponentialLR,
    ConstantLR,
    SequentialLR,
)

# from torch.distributions import Categorical

from vti.optim.rm_lr import RobbinsMonroScheduler
from vti.utils.plots import plot_fit_marginals

from vti.model_samplers.sfedag import ScoreFunctionEstimatorMADEDAG
from vti.dgp.dgp_factory import create_dgp
from vti.utils.debug import check_for_nans, tonp
from vti.utils.torch_nn_helpers import print_gradient_l2_norm
from vti.distributions.uniform_dag import PermutationDAGUniformDistribution

# Module-level debug switches read by SFEDAGProblem.
# PRINTIG: when True, run_optimizer computes the entropy of the SFE model
# probabilities before and after each optimizer step and debug_log prints the
# difference together with the IG threshold.
PRINTIG = False
# PRINTGRAD: when True, debug_log prints the value returned by l2_vargradients.
PRINTGRAD = False
PRINTPROBS = False  # will not work, always keep as False
# DOPLOTS: not referenced anywhere in the visible part of this file.
DOPLOTS = False
# CHECKNANS: when True, loss() runs check_for_nans on intermediate tensors.
CHECKNANS = True
# torch.autograd.set_detect_anomaly(True)


def parse_args():
    """
    Parse CLI args in a way that is robust to Jupyter notebooks and scripts.

    Unrecognised arguments are ignored (``parse_known_args`` is used), and
    option abbreviations are disabled.

    Returns:
        dict: option name -> parsed value. ``--dgp`` is stored under the key
        ``dgp_key``; ``--plot`` / ``--no-plot`` share the key ``plot``, which
        defaults to True.
    """
    parser = argparse.ArgumentParser(description="experiment args", allow_abbrev=False)
    parser.add_argument(
        "--num-nodes",
        type=int,
        default=3,
        help="Number of graph nodes (dimension of problem, default: 3)",
    )
    parser.add_argument(
        "--num-iterations",
        type=int,
        default=10000,
        help="Number of iterations (default: 10000)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=32, help="Batch size (default: 32)"
    )
    parser.add_argument(
        "--ig-threshold",
        type=float,
        default=1e-3,
        help="Information gain threshold (default: 1e-3). If >0 information will be capped",
    )
    parser.add_argument(
        "--sfe-lr",
        type=float,
        default=1e-3,
        help="Score function estimator learning rate (default: 1e-3)",
    )
    parser.add_argument(
        "--flow-type",
        type=str,
        choices=["diagnorm", "affine2", "affine5", "affine7", "spline"],
        default="affine2",
        help="Variational density type. Choose from 'affine2', 'affine5', 'affine7', 'spline', and 'diagnorm'. (default: 'affine2')",
    )
    parser.add_argument(
        "--dgp",
        type=str,
        choices=["lineardag"],
        default="lineardag",
        dest="dgp_key",
        help="Data generating process. Only choice is lineardag.",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cuda", "cpu"],
        default=None,
        help="Device to run on (default: 'cuda' if available, else 'cpu')",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["float32", "float64"],
        default="float64",
        help="Dtype to use (default: 'float64')",
    )
    parser.add_argument(
        "--resume", type=str, default=None, help="Filename to resume from (optional)"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="output",
        help="Output directory (default: 'output')",
    )
    # Plotting is on by default, so a lone store_true flag could never switch
    # it off; --no-plot writes False to the same destination.
    parser.add_argument(
        "--plot", dest="plot", action="store_true", help="Plot target (default)"
    )
    parser.add_argument(
        "--no-plot", dest="plot", action="store_false", help="Disable plotting"
    )
    parser.set_defaults(plot=True)

    known_args = parser.parse_known_args()[0]
    return vars(known_args)


class SFEDAGProblem(object):
    """
    SFE strictly for permutation formulation of DAG DGP
    using the MADEPlus model for the model indicator
    """

    @staticmethod
    def createDir(output_dir):
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir, exist_ok=True)

    def __init__(
        self,
        dgp,
        ig_threshold,
        flow_type="affine",  # 'diagnorm' or 'spline'
        job_id=0,
        output_dir="output",
        device=None,
        dtype=None,
    ):
        """
        Bind a DGP to a parameter flow, a model prior and a score function estimator.

        Args:
            dgp: data generating process object; the methods and attributes
                used here are ``num_inputs``, ``num_context_features``,
                ``construct_param_transform``, ``flow_lr``, ``get_sfe_lr``
                and ``num_dag_nodes``.
            ig_threshold: stored on the instance; forwarded to the
                estimator's ``set_gradients`` by ``loss``.
            flow_type: forwarded to ``dgp.construct_param_transform``.
            job_id: becomes part of the checkpoint file name.
            output_dir: directory for outputs; created if missing.
            device: forwarded to the prior and the estimator.
            dtype: forwarded to the prior and the estimator.
        """
        # plain configuration
        self.dgp = dgp
        self.ig_threshold = ig_threshold
        self.job_id = job_id
        self.output_dir = output_dir
        self.device = device
        self.dtype = dtype

        # problem sizes as reported by the DGP
        self.num_inputs = dgp.num_inputs()
        self.num_context_inputs = dgp.num_context_features()

        # output location and checkpoint path
        self.createDir(output_dir)
        self.model_save_path = f"{self.output_dir}/best_model_{self.job_id}.pt"

        # setup_optimizer uses this as the length of the initial phase in
        # which the SFE learning rate is multiplied by zero
        self.delaysfe = 10

        # parameter flow and the learning rates, all supplied by the DGP
        self.param_transform = self.dgp.construct_param_transform(flow_type)
        self.flow_lr = self.dgp.flow_lr
        self.sfe_lr = self.dgp.get_sfe_lr()
        logging.info("SFE LR = ", self.sfe_lr)

        self.num_nodes = self.dgp.num_dag_nodes()  # will fail if DGP not a DAG
        logging.info("num nodes = ", self.num_nodes)

        # prior over model indicators
        self.prior_mk_dist = PermutationDAGUniformDistribution(
            num_nodes=self.num_nodes, dtype=self.dtype, device=self.device
        )

        # variational distribution over model indicators
        self.sfe_mk_dist = ScoreFunctionEstimatorMADEDAG(
            self.num_nodes, device=self.device, dtype=self.dtype
        )

    def sample_reference_dist(self, batch_size):
        base_samples, base_log_prob = self.dgp.reference_dist_sample_and_log_prob(
            batch_size
        )
        return base_samples, base_log_prob

    def setup_optimizer(self, num_iterations):
        """
        Create the optimizers and learning-rate schedules for both model parts.

        The parameter flow gets Adam with a cosine-warm-restart schedule
        chained with an exponential decay. The score function estimator gets
        SGD whose learning rate is multiplied by zero for the first
        ``self.delaysfe`` scheduler steps and is then driven by
        ``RobbinsMonroScheduler``. Also sets ``self.save_interval``.

        Args:
            num_iterations: accepted but not used by this method.
        """
        logging.info("setting up optimizer")

        flow_groups = [
            {"params": self.param_transform.parameters(), "lr": self.flow_lr},
        ]
        self.flow_optimizer = optim.Adam(flow_groups, lr=self.flow_lr)

        sfe_groups = [
            {"params": self.sfe_mk_dist.parameters(), "lr": self.sfe_lr},
        ]
        self.sfe_optimizer = optim.SGD(sfe_groups, lr=self.sfe_lr)

        # Flow schedule: the two schedulers are built in the order in which
        # they are chained.
        warm_restarts = CosineAnnealingWarmRestarts(
            self.flow_optimizer,
            T_0=100,
            T_mult=1,
            eta_min=1e-7,  # optionally, you can set a minimum lr
        )
        slow_decay = ExponentialLR(self.flow_optimizer, gamma=1 - 1e-3)
        self.flow_scheduler = ChainedScheduler([warm_restarts, slow_decay])

        # SFE schedule: zero learning rate first, Robbins-Monro afterwards.
        hold_at_zero = ConstantLR(
            self.sfe_optimizer, factor=0, total_iters=self.delaysfe
        )
        robbins_monro = RobbinsMonroScheduler(
            self.sfe_optimizer,
            alpha=1e-3,
        )
        self.sfe_scheduler = SequentialLR(
            self.sfe_optimizer,
            [hold_at_zero, robbins_monro],
            milestones=[
                self.delaysfe,
            ],
        )

        # interval (in iterations) for writing log or saving output
        self.save_interval = 50

    def loss(self, batch_size, i):
        """
        Draw one Monte Carlo batch, set the SFE gradients, and return the loss.

        The per-sample loss is
        ``-target_log_prob + params_log_prob - mk_prior_log_prob + mk_log_prob``
        where ``params_log_prob`` is the reference log-prob minus the log-prob
        term returned by the flow's ``inverse``. Samples whose loss is not
        finite are dropped before averaging. A running average of the batch
        means is subtracted from the per-sample losses before they are handed
        to ``self.sfe_mk_dist.set_gradients``.

        Args:
            batch_size: number of samples to draw.
            i: iteration index; ``i == 0`` (re)initialises the running average
                and uses a zero baseline for that call.

        Returns:
            The mean of the finite per-sample losses (NaN if none is finite),
            still attached to the autograd graph.
        """
        base_samples, base_log_prob = self.sample_reference_dist(batch_size)

        mk_samples, mk_log_prob = self.sfe_mk_dist.sample_and_log_prob(batch_size)

        # TODO figure out prior
        mk_prior_log_prob = self.prior_mk_dist.log_prob(mk_samples)

        if CHECKNANS:
            check_for_nans(mk_log_prob)
            check_for_nans(mk_prior_log_prob)

        logging.info(f"tensor dtypes {base_samples.type()} {mk_samples.type()}")
        params, params_tf_log_prob = self.param_transform.inverse(
            base_samples, context=mk_samples
        )

        if CHECKNANS:
            check_for_nans(params)

        params_log_prob = base_log_prob - params_tf_log_prob
        target_log_prob = self.dgp.log_prob(mk_samples, params)

        loss_hat1 = -target_log_prob + params_log_prob
        loss_hat2 = -mk_prior_log_prob + mk_log_prob
        loss_hat = loss_hat1 + loss_hat2

        # Report and drop the samples whose loss is not finite.
        finite_idx = torch.isfinite(loss_hat)
        if not finite_idx.all():
            logging.info(
                "WARNING: Non-finite loss:\n",
                target_log_prob,
                "\n",
                params_log_prob,
                "\n",
            )
        loss_hat = loss_hat[finite_idx]

        with torch.no_grad():
            # test gradient induced forget
            forget = 0.9
            running_forget = forget**i
            detach_loss_hat = loss_hat.detach()
            detach_mean = detach_loss_hat.nanmean().item()
            # A batch with no finite sample has a NaN mean; letting it into
            # the running average would make every later baseline NaN.
            mean_is_finite = math.isfinite(detach_mean)
            if i == 0:
                self.avg_loss_hat_biased = detach_mean if mean_is_finite else 0.0
                avg_loss_hat_unbiased = 0
            else:
                if mean_is_finite:
                    self.avg_loss_hat_biased = (
                        forget * self.avg_loss_hat_biased
                        + (1 - forget) * detach_mean
                    )
                # NOTE(review): the average is seeded with the first batch
                # mean at i == 0 but corrected with 1 - forget**i as if it had
                # been seeded with zero, which scales the baseline by
                # 1 / (1 - forget**i) in early iterations. Left unchanged
                # here; confirm whether that is intended.
                avg_loss_hat_unbiased = self.avg_loss_hat_biased / (1 - running_forget)

        # The samples must be masked like the losses and log-probs so that
        # all three per-sample arguments stay aligned when something was
        # dropped above.
        self.sfe_mk_dist.set_gradients(
            mk_samples[finite_idx],
            detach_loss_hat - avg_loss_hat_unbiased,
            mk_log_prob[finite_idx],
            lr=self.sfe_scheduler.get_last_lr()[0],
            ig_threshold=self.ig_threshold,
        )

        return loss_hat.nanmean()

    def run_optimizer(
        self, batch_size, num_iterations, store_loss_history=True, resume=None
    ):
        """
        Run the training loop.

        Each iteration evaluates ``self.loss``, backpropagates, clips the flow
        gradients to a norm of 20, steps both optimizers (skipped when the
        flow gradient norm is not finite) and steps both schedulers. A
        checkpoint is written whenever the loss reaches a new minimum. Every
        ``self.save_interval`` iterations a log line is emitted and, if
        requested, the loss buffer is written to ``self.output_dir``.

        ``setup_optimizer`` must have been called first.

        Args:
            batch_size: number of samples per iteration.
            num_iterations: exclusive upper bound of the iteration index.
            store_loss_history: keep the most recent ``self.save_interval``
                losses in ``self.loss_history`` and save them periodically.
            resume: optional checkpoint file to continue from. Only what
                ``load_training_checkpoint`` restores is recovered; the
                schedulers are not restored by this method.

        Returns:
            ``self.loss_history`` (the rolling buffer of the most recent
            losses) when ``store_loss_history`` is true, otherwise None.
        """
        if store_loss_history:
            self.loss_history = torch.zeros(
                self.save_interval,
                requires_grad=False,
                dtype=self.dtype,
                device=self.device,
            )

        minimum_loss = float("inf")
        start_iter = -1

        if resume:
            logging.info(f"restarting from file {resume}")
            minimum_loss, start_iter = self.load_training_checkpoint(resume)

        logging.info("starting optimization")

        self.avg_loss_hat_biased = 0

        for i in range(start_iter + 1, num_iterations):
            self.flow_optimizer.zero_grad()
            self.sfe_optimizer.zero_grad()

            loss = self.loss(batch_size, i)

            # Backpropagation
            loss.backward()

            if PRINTIG:
                mk_prob_before = self.sfe_mk_dist.probabilities()
                check_for_nans(mk_prob_before)

            grad_norm = torch.nn.utils.clip_grad_norm_(
                parameters=self.param_transform.parameters(),
                max_norm=20.0,
                error_if_nonfinite=False,
            )

            if torch.isfinite(grad_norm).all():
                # Optimizer step
                self.flow_optimizer.step()
                self.sfe_optimizer.step()
            else:
                # Clear gradients to avoid compounding the issue
                self.flow_optimizer.zero_grad()
                self.sfe_optimizer.zero_grad()
                logging.info("Skipped update due to non-finite gradients")

            if PRINTIG:
                mk_prob_after = self.sfe_mk_dist.probabilities()
                check_for_nans(mk_prob_after)

                # compute change in information
                # H(before) - H(after)
                Hbefore = -(mk_prob_before * mk_prob_before.log()).sum()
                Hafter = -(mk_prob_after * mk_prob_after.log()).sum()
                self.IG = Hbefore - Hafter

            # Scheduler step
            self.flow_scheduler.step()
            self.sfe_scheduler.step()

            # Read the scalar once; it is reused for the checkpoint test and
            # for the history buffer.
            loss_value = loss.item()

            # A NaN loss compares False here, so it never becomes the minimum.
            if loss_value < minimum_loss:
                minimum_loss = loss_value
                self.save_training_checkpoint(minimum_loss, i)

            # Logging or printing the loss value, every few iterations
            if i % self.save_interval == 0:
                self.debug_log(i, loss)
                if store_loss_history and i > 0:
                    # The buffer is saved before this iteration's loss is
                    # written into it, so it holds the previous
                    # save_interval iterations; hence the index in the name.
                    torch.save(
                        self.loss_history,
                        "{}/loss_history_M{}_job{}_i{}.pt".format(
                            self.output_dir,
                            self.num_nodes,
                            self.job_id,
                            i - self.save_interval,
                        ),
                    )

            if store_loss_history:
                self.loss_history[i % self.save_interval] = loss_value

        return self.loss_history if store_loss_history else None

    def debug_log(self, i, loss):
        """
        Assemble one progress report for iteration ``i`` and log it.

        The report always contains the iteration index and the loss; the
        module-level switches PRINTPROBS, PRINTGRAD and PRINTIG each append a
        further section. Standard output is flushed afterwards.

        Args:
            i: iteration index shown in the report.
            loss: object whose ``item()`` gives the scalar loss to print.
        """
        report = [f"Iter {i}, ", f"Loss: {loss.item():3f}, \n     "]

        if PRINTPROBS:
            modelprobs = self.dgp.modelprobs
            ## set up consistent printing of probs
            # This mutates global state, but cosmetically.
            torch.set_printoptions(precision=6, sci_mode=False)
            est_probs = self.sfe_mk_dist.probabilities().detach()
            report.append(f"q: {est_probs[self.topkidx]}, \n     ")
            report.append(f"p: {self.topkprobs}, \n     ")
            # KL loss
            kld_loss = F.kl_div(est_probs.log(), modelprobs, log_target=False)
            report.append(f"KL: {kld_loss:3f}, \n     ")
            # MSE
            mse_loss = F.mse_loss(est_probs, modelprobs)
            report.append(f"MSE: {mse_loss:3f}, \n     ")
            # absolute deviation loss/ wasserstein
            abs_error = torch.sum(torch.abs(modelprobs - est_probs))
            report.append(f"AbsError: {abs_error:3f}, \n     ")

        if PRINTGRAD:
            L2VG = self.l2_vargradients(epsilon=1e-6)
            last_invgrad = math.exp(-(1e-2 * L2VG)) * 0.5
            report.append(f"Grad: {L2VG:1f}, ")
            report.append(f"GradKernel: {last_invgrad:3f},\n     ")

        if PRINTIG:
            report.append(f"IG: {self.IG}, ")
            report.append(f"IG thres: {self.ig_threshold}, \n     ")
            report.append(f"sfestep: {self.sfe_scheduler.get_last_lr()[0]:3f}, ")

        logging.info("".join(report))
        sys.stdout.flush()

    def l2_vargradients(self, epsilon=None, return_var=False):
        # Initialize a list to collect all ratios
        all_m = []
        all_v = []

        for group in self.flow_optimizer.param_groups:
            for p in group["params"]:
                param_state = self.flow_optimizer.state[p]
                if "exp_avg" in param_state and "exp_avg_sq" in param_state:
                    # Use the group's epsilon if none is specified
                    if epsilon is None:
                        eps = group["eps"]
                    else:
                        eps = epsilon

                    # Retrieve the first moment (m_t) and the second moment (v_t)
                    m_t = param_state["exp_avg"]
                    v_t = param_state["exp_avg_sq"]

                    # Retrieve the step to correct the bias in m_t and v_t
                    step = param_state["step"]
                    beta1, beta2 = group["betas"]

                    # Bias correction for the first and second moments
                    m_hat_t = m_t / (1 - beta1**step)
                    v_hat_t = v_t / (1 - beta2**step)

                    # Flatten the ratio tensor and add to the list
                    all_m.append(m_hat_t.view(-1))
                    all_v.append(v_hat_t.view(-1))

        # Concatenate all ratios into a single tensor
        all_m = torch.cat(all_m)
        all_v = torch.cat(all_v)

        # Calculate the ratio m_hat_t / (sqrt(v_hat_t) + eps)
        ratio = all_m / (all_v.sqrt() + eps)

        # Compute and return the L2 norm of the concatenated ratios
        l2norm = torch.norm(ratio, p=2)
        if return_var:
            return l2norm, all_v.mean()
        else:
            return l2norm

    def save_training_checkpoint(self, loss, iteration):
        torch.save(
            {
                "mk_dist_state_dict": self.sfe_mk_dist.state_dict(),
                "param_flow_state_dict": self.param_transform.state_dict(),
                "flow_optimizer_state_dict": self.flow_optimizer.state_dict(),
                "sfe_optimizer_state_dict": self.sfe_optimizer.state_dict(),
                "loss": loss,
                "iteration": iteration,
            },
            self.model_save_path,
        )  # Save the model and optimizer state

    def load_training_checkpoint(self, filename=None):
        if filename is None:
            filename = self.model_save_path
        checkpoint = torch.load(filename)
        self.param_transform.load_state_dict(checkpoint["param_flow_state_dict"])
        self.sfe_mk_dist.load_state_dict(checkpoint["mk_dist_state_dict"])
        self.flow_optimizer.load_state_dict(checkpoint["flow_optimizer_state_dict"])
        self.sfe_optimizer.load_state_dict(checkpoint["sfe_optimizer_state_dict"])
        return checkpoint["loss"], checkpoint["iteration"]

    def plot_q(self, q_mk_probs, num_samples):
        """
        Plot fitted parameter marginals for every model the DGP enumerates.

        Same procedure as ``plot_q_s``, with the model identifiers taken from
        ``self.dgp.mk_identifiers()``.

        Args:
            q_mk_probs: per-model probabilities, indexed in the order in which
                the DGP enumerates its identifiers.
            num_samples: total sample budget; model ``k`` receives
                ``int(num_samples * q_mk_probs[k])`` draws.
        """
        self.plot_q_s(self.dgp.mk_identifiers(), q_mk_probs, num_samples)

    def plot_q_s(self, q_mk, q_mk_probs, num_samples=32768):
        """
        Plot fitted parameter marginals for the given models.

        Each model ``q_mk[k]`` receives ``int(num_samples * q_mk_probs[k])``
        reference draws, which are pushed through the parameter flow with the
        model as context and multiplied by the DGP's mask for that model.
        Models that receive no draws are skipped. The first remaining sample
        set and the list of the others are passed to ``plot_fit_marginals``.

        Args:
            q_mk: iterable of model identifiers (each must support ``view``).
            q_mk_probs: per-model probabilities, indexed like ``q_mk``.
            num_samples: total sample budget shared between the models.
        """
        q_theta = []
        for k, mk in enumerate(q_mk):
            N = int(num_samples * q_mk_probs[k])
            if N <= 0:
                continue
            context = mk.view(1, -1)
            b_samples, _ = self.dgp.reference_dist_sample_and_log_prob(N)
            qt, _ = self.param_transform.inverse(b_samples, context=context)
            qt = qt * self.dgp.mk_to_mask(context)
            q_theta.append(tonp(qt))

        if not q_theta:
            # No model received a single draw; indexing q_theta[0] would
            # raise an IndexError.
            logging.info("plot_q_s: no model received any samples, nothing to plot")
            return
        plot_fit_marginals(q_theta[0], q_theta[1:])

    def plot_state(self, batch_size):
        base_samples, base_log_prob = self.sample_reference_dist(batch_size)
        mk_catsamples, mk_log_prob = self.sfe_mk_dist.sample_and_log_prob(batch_size)
        mk_samples = self.dgp.mk_cat_to_identifier(mk_catsamples)
        params, params_tf_log_prob = self.param_transform.inverse(
            base_samples, context=mk_samples
        )
        self.dgp.plot_state(mk_samples, params)


def main(
    dgp_key="lineardag",
    num_nodes=3,  # number of models / categories
    num_iterations=10000,
    batch_size=32,
    num_inputs=20,
    job_id=0,
    resume=None,
    output_dir="output_sfemade",
    flow_type="affine",
    sfe_lr=1e-3,
    plot=True,
    ig_threshold=1e-3,
    device=None,
    dtype=torch.float64,
    **kwargs,
):
    """
    Set up and run one inference experiment end to end.

    Creates the DGP, builds an ``SFEDAGProblem`` around it, trains it, and
    reloads the checkpoint with the lowest loss. For ``num_nodes <= 3`` the
    model probabilities are printed as a smoke test and, if ``plot`` is
    true, the fitted marginals are plotted.

    Args:
        dgp_key: which data generating process ``create_dgp`` should build.
        num_nodes: forwarded to ``create_dgp``; also gates the smoke test.
        num_iterations: exclusive upper bound of the training iteration index.
        batch_size: samples per training iteration.
        num_inputs: forwarded to ``create_dgp``.
        job_id: forwarded to the problem, which uses it in output file names.
        resume: optional checkpoint file to continue training from.
        output_dir: directory for checkpoints and loss histories.
        flow_type: variational density type for the parameter flow.
        sfe_lr: accepted but not used by this function (the problem takes its
            SFE learning rate from ``dgp.get_sfe_lr()``).
        plot: whether to plot in the smoke test.
        ig_threshold: information gain threshold forwarded to the problem.
        device: torch default device and device of the problem.
        dtype: torch default dtype and dtype of the problem.
        **kwargs: forwarded to ``create_dgp``.
    """
    torch.set_default_dtype(dtype)
    torch.set_default_device(device)

    dgp = create_dgp(
        dgp_key=dgp_key,
        device=device,
        dtype=dtype,
        num_nodes=num_nodes,
        num_inputs=num_inputs,
        **kwargs,
    )
    logging.info("constructing problem...")
    problem = SFEDAGProblem(
        dgp,
        ig_threshold=ig_threshold,
        flow_type=flow_type,
        # Without this the problem always used its own default job id, so
        # runs with different job ids wrote to the same output files.
        job_id=job_id,
        output_dir=output_dir,
        device=device,
        dtype=dtype,
    )
    logging.info("done!")

    problem.setup_optimizer(num_iterations=num_iterations)

    problem.run_optimizer(
        batch_size=batch_size,
        num_iterations=num_iterations,
        store_loss_history=True,
        resume=resume,
    )

    # Reload the checkpoint written at the lowest loss seen so far.
    problem.load_training_checkpoint()

    if num_nodes <= 3:
        # smoke test for low node count, plot results
        sampled_mk, sampled_mk_probs = problem.sfe_mk_dist.print_probabilities()
        if plot:
            problem.plot_q_s(sampled_mk, sampled_mk_probs)

def resolve_device(requested=None):
    """
    Turn the ``--device`` option into a ``torch.device``.

    Args:
        requested: ``"cuda"``, ``"cpu"``, or None to choose automatically.

    Returns:
        torch.device: the requested device; for None, ``"cuda"`` if
        ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
    """
    if requested is None:
        requested = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(requested)


if (
    __name__ == "__main__"  # if running as a script
    and "get_ipython" not in dir()  # and not in jupyter notebook
):
    args = parse_args()

    # Translate the dtype name into the torch dtype object.
    dtype_name = args.get("dtype")
    args["dtype"] = (
        getattr(torch, dtype_name) if dtype_name else torch.get_default_dtype()
    )

    # An explicit --device must be honoured, and the resolved device has to be
    # stored back into args, otherwise main() never receives it.
    args["device"] = resolve_device(args.get("device"))

    # set log file to output dir
    SFEDAGProblem.createDir(args["output_dir"])
    logging.set_log_directory(args["output_dir"])
    logging.info("args", args)
    main(**args)

# %%
