import flowcon.nn.nets
import torch
import time
import json
import numpy as np
import matplotlib.pyplot as plt
from datetime import timedelta
from flowcon import transforms, flows
from flowcon.distributions import StandardNormal
from flowcon.transforms.coupling import PiecewiseRationalQuadraticCouplingTransform, PiecewiseCubicCouplingTransform
import flowcon.nn.nets as nets
from torch.utils.data import DataLoader

from margflow.other_models.mog_base import MOG
from margflow.trainer import gen_cooling_schedule
from margflow.utils.training_utils import check_tuple, ConditionalDataset
from margflow.abstract_model import AbstractModel

class _UnconditionalMLP(torch.nn.Module):
    """MLP conditioner that accepts, and ignores, a ``context`` argument.

    The forward signature takes ``(inputs, context)`` but only ``inputs`` is
    passed on to the wrapped MLP. The class lives at module level rather than
    inside a factory closure so that instances can be pickled (locally defined
    classes cannot).
    """

    def __init__(self, in_features, out_features, hidden_sizes):
        super().__init__()
        self.net = nets.MLP(
            in_shape=[in_features],
            out_shape=[out_features],
            hidden_sizes=list(hidden_sizes),
        )

    def forward(self, inputs, context=None):
        # The context is deliberately dropped: this is the unconditional case.
        return self.net(inputs)


def create_spline_coupling_layer(
    input_dim,
    num_bins=8,
    hidden_sizes=(128, 128, 128),
    tails="linear",
    tail_bound=3.0,
):
    """Build an unconditional rational-quadratic spline coupling transform.

    Args:
        input_dim: Number of features the transform acts on.
        num_bins: Number of spline bins.
        hidden_sizes: Hidden layer widths of the MLP conditioner.
        tails: Tail behaviour forwarded to the coupling transform.
        tail_bound: Tail bound forwarded to the coupling transform.

    Returns:
        A ``PiecewiseRationalQuadraticCouplingTransform`` whose conditioner is
        an MLP that ignores any context.
    """
    # Alternating binary mask over the features: 0, 1, 0, 1, ...
    mask = torch.arange(0, input_dim) % 2

    def transform_net(in_features, out_features):
        return _UnconditionalMLP(in_features, out_features, hidden_sizes)

    return PiecewiseRationalQuadraticCouplingTransform(
        mask=mask,
        transform_net_create_fn=transform_net,
        num_bins=num_bins,
        tails=tails,
        tail_bound=tail_bound,
    )

class NormalizingFlow(AbstractModel):
    """Normalizing flow built from ActNorm and invertible residual (iResBlock) layers.

    The flow is assembled in one of two orientations, selected by ``direction``:

    * ``"reverse"`` wraps every layer in ``transforms.InverseTransform``. This
      orientation is required by :meth:`train_reverse`, which samples from the
      flow and scores the samples against a target log-density.
    * ``"forward"`` uses the layers unwrapped. This orientation is required by
      :meth:`train_forward`, which minimises the negative log-likelihood of data.

    When ``cond_dim`` is given, a residual embedding network maps the context
    to an embedding that is handed to the flow.
    """

    def __init__(
        self,
        x_dim,
        n_layers,
        direction="reverse",
        signature=None,
        script_path=None,
        cond_dim=None,
        device="cuda",
        dtype=torch.float32,
    ):
        """Create the model and build the underlying flow.

        Args:
            x_dim: Dimensionality of the modelled variable.
            n_layers: Number of (ActNorm, iResBlock) layer pairs.
            direction: ``"reverse"`` or ``"forward"``; see the class docstring.
            signature: Forwarded to ``AbstractModel``.
            script_path: Forwarded to ``AbstractModel``.
            cond_dim: Dimensionality of the context, or ``None`` for an
                unconditional flow.
            device: Device the flow is moved to.
            dtype: Forwarded to ``AbstractModel``.

        Raises:
            ValueError: If ``direction`` is neither ``"forward"`` nor ``"reverse"``.
        """
        super().__init__(
            model_name="normalizing_flow",
            x_dim=x_dim,
            script_path=script_path,
            signature=signature,
            device=device,
            dtype=dtype,
        )
        self.n_layers = n_layers
        self.cond_dim = cond_dim
        self.direction = direction
        self._build_flow(direction=direction)

        self.trainable_params = {"flow": self.flow}
        self.set_model_signature()

    def _build_flow(self, direction):
        """Assemble ``self.flow`` and move it to ``self.device``.

        Args:
            direction: ``"reverse"`` or ``"forward"``.

        Raises:
            ValueError: If ``direction`` is not one of the two supported values.
        """
        if direction not in ("reverse", "forward"):
            raise ValueError("Direction must be either forward or reverse")

        is_conditional = self.cond_dim is not None
        n_shared_embeddings = 128
        base_distribution = StandardNormal([self.x_dim])

        densenet_factory = (
            transforms.lipschitz.iResBlock.Factory()
            # NOTE(review): brute_force=False was chosen by the original author
            # with the remark "expensive to compute inverse transformation".
            .set_logabsdet_estimator(brute_force=False)
            .set_densenet(
                dimension=self.x_dim,
                densenet_depth=5,
                densenet_growth=64,
                activation_function=nets.CSin(10),
                condition_input=is_conditional,
                condition_multiplicative=is_conditional,
                context_features=n_shared_embeddings if is_conditional else None,
            )
        )

        transforms_list = []
        if direction == "reverse":
            # Each layer is wrapped in InverseTransform so that the efficient
            # direction of the layer becomes the inverse of the composed flow.
            for _ in range(self.n_layers):
                transforms_list.append(
                    transforms.InverseTransform(transforms.ActNorm(features=self.x_dim))
                )
                transforms_list.append(transforms.InverseTransform(densenet_factory.build()))
        else:
            for _ in range(self.n_layers):
                transforms_list.append(densenet_factory.build())
                transforms_list.append(transforms.ActNorm(features=self.x_dim))

        embedding_net = None
        if is_conditional:
            embedding_net = nets.ResidualNet(
                in_features=self.cond_dim,
                out_features=n_shared_embeddings,
                hidden_features=256,
                num_blocks=5,
                activation=torch.nn.functional.silu,
            )

        self.flow = flows.Flow(
            transform=transforms.CompositeTransform(transforms_list),
            distribution=base_distribution,
            embedding_net=embedding_net,
        )
        self.flow.to(self.device)

    def sample(self, n_samples, context=None, **kwargs):
        """Draw ``n_samples`` samples from the flow; extra kwargs are ignored."""
        return self.flow.sample(n_samples, context=context)

    def log_prob(self, x, context=None, **kwargs):
        """Return the flow's log-probability of ``x``; extra kwargs are ignored."""
        return self.flow.log_prob(x, context=context)

    def sample_and_log_prob(self, n_samples, context=None, **kwargs):
        """Return ``(samples, log_prob)`` from the flow; extra kwargs are ignored."""
        return self.flow.sample_and_log_prob(num_samples=n_samples, context=context)

    @staticmethod
    def _record_log_entry(logging_dict, entry):
        """Append each value of ``entry`` to the list stored under its key."""
        for key, value in entry.items():
            logging_dict.setdefault(key, []).append(value)

    @staticmethod
    def _print_epoch_summary(epoch, loss_dict, metrics_dict, metrics):
        """Print the losses and, if metrics were requested, the metric values."""
        print(
            f"Epoch {epoch} losses: "
            + ", ".join(f"{key}:{value:.3f}" for key, value in loss_dict.items())
        )
        if metrics is not None:
            print(
                f"Epoch {epoch} other metrics: "
                + ", ".join(f"{key}:{value:.3f}" for key, value in metrics_dict.items())
            )

    def _dump_training_log(self, logging_dict):
        """Write the training log to ``<model_path>.json`` when a model path is set."""
        if hasattr(self, "model_path"):
            with open(f"{self.model_path}.json", "w", encoding="utf-8") as file:
                json.dump(logging_dict, file, indent=4)

    def _nll_step(self, optimizer, batch):
        """Run one optimizer step on the mean negative log-likelihood of ``batch``.

        Returns:
            The loss tensor of this step.
        """
        data, context = check_tuple(batch, move_to_torch=True, device=self.device)
        optimizer.zero_grad()
        nll = -torch.mean(self.log_prob(data, context=context))
        nll.backward()
        optimizer.step()
        return nll

    def train_reverse(
        self,
        dataset,
        lr=1e-3,
        batch_size=1000,
        n_epochs=5000,
        metrics=None,
        overwrite=False,
    ):
        """Train by minimising a tempered reverse KL against ``dataset.log_prob``.

        Each epoch draws ``batch_size`` samples from the flow and minimises
        ``mean(logprob_flow - logprob_target / T)``, where ``T`` follows the
        schedule returned by ``gen_cooling_schedule`` (``T0=5`` to ``Tn=1``).
        The untempered value (``T = 1``) is tracked as ``kl_div``.

        If ``existing_trained_model(overwrite)`` is true, the stored model is
        loaded and no training happens. A ``KeyboardInterrupt`` stops training
        early; the model and log are still saved afterwards.

        Args:
            dataset: Object providing ``log_prob(samples)``. When ``metrics`` is
                given it must also provide ``sample`` (or ``sample_estimator``
                if it has a ``logp_estimator`` attribute).
            lr: Adam learning rate.
            batch_size: Number of flow samples per epoch, and number of
                validation samples drawn for metric evaluation.
            n_epochs: Number of optimisation steps.
            metrics: Passed to ``evaluate_metrics``; ``None`` disables metrics.
            overwrite: Forwarded to ``existing_trained_model``.
        """
        assert (
            self.direction == "reverse"
        ), "flow must be built in reverse direction for efficient training"

        if self.existing_trained_model(overwrite):
            self.load_trained_model()
            return

        print("+++++ Training normalizing flow ++++++")
        print(f"Model has {self.count_parameters()} trainable parameters")
        optimizer = torch.optim.Adam(self.flow.parameters(), lr=lr)
        loss = []
        logging_dict = {}
        training_time = 0.0
        temperature = gen_cooling_schedule(
            T0=5, Tn=1, n_epochs=n_epochs, share_active_epochs=0.5, scheme="exp_mult"
        )

        val_samples = None
        if metrics is not None:
            if hasattr(dataset, "logp_estimator"):
                val_samples = dataset.sample_estimator(n_samples=batch_size)
            else:
                val_samples = dataset.sample(n_samples=batch_size)

        # Evaluate/log at most 100 times over the run, and at least every epoch
        # when there are fewer than 100 epochs.
        log_interval = max(1, n_epochs // 100)

        # run_start measures the whole run; segment_start is reset after every
        # evaluation so that training_time excludes evaluation time.
        run_start = time.monotonic()
        segment_start = run_start
        try:
            self.flow.train()
            for epoch in range(n_epochs):
                loss_dict = {}
                metrics_dict = {}
                T = temperature(epoch) if temperature is not None else 1.0

                optimizer.zero_grad()
                samples, logprob_flow = self.sample_and_log_prob(
                    n_samples=batch_size, context=None
                )
                logprob_target = dataset.log_prob(samples)

                kl_div = torch.mean(logprob_flow - logprob_target / T)
                kl_div.backward()
                kl_div_orig = torch.mean(logprob_flow - logprob_target).detach()
                optimizer.step()

                loss_dict["temp"] = T
                loss_dict["kl_div_T"] = kl_div.item()
                loss_dict["kl_div"] = kl_div_orig.item()
                loss.append(loss_dict["kl_div"])

                if epoch % log_interval == 0:
                    training_time += time.monotonic() - segment_start
                    with torch.no_grad():
                        if metrics is not None and self.cond_dim is None:
                            metrics_dict = self.evaluate_metrics(
                                metrics,
                                val_samples=val_samples,
                                n_samples=batch_size,
                                dataset=dataset,
                            )
                        self._record_log_entry(
                            logging_dict,
                            {"runtime": training_time} | loss_dict | metrics_dict,
                        )
                        self._print_epoch_summary(epoch, loss_dict, metrics_dict, metrics)
                    segment_start = time.monotonic()

        except KeyboardInterrupt:
            print("interrupted...")

        time_diff = timedelta(seconds=time.monotonic() - run_start)
        print(f"Training normalizing flow took {time_diff} seconds")

        loss = np.array(loss)
        plt.plot(range(len(loss)), loss)
        plt.show()

        self._dump_training_log(logging_dict)
        self.save_trained_model()

    def train_forward(
        self,
        dataset,
        batch_size,
        fixed_datapoints=True,
        lr=1e-3,
        n_epochs=1000,
        metrics=None,
        save_best_val=False,
        overwrite=False,
    ):
        """Train by minimising the negative log-likelihood of data.

        If ``existing_trained_model(overwrite)`` is true, the stored model is
        loaded and no training happens. A ``KeyboardInterrupt`` stops training
        early; the model and log are still saved afterwards.

        Args:
            dataset: With ``fixed_datapoints`` it must provide
                ``load_dataset(overwrite=False)`` returning three splits
                (train, validation, test); otherwise ``sample(batch_size,
                "train")``. The test split is also loaded when metrics are
                evaluated and is converted with ``torch.from_numpy``.
            batch_size: Mini-batch size (or sample count per epoch when
                ``fixed_datapoints`` is false).
            fixed_datapoints: Iterate over a stored train split each epoch
                instead of drawing a fresh batch.
            lr: Adam learning rate.
            n_epochs: Number of epochs.
            metrics: Passed to ``evaluate_metrics``; ``None`` disables metrics.
                Metrics are only evaluated for unconditional flows.
            save_best_val: Checkpoint whenever the validation NLL improves,
                stop after more than 50 evaluations without improvement, and
                restore the best checkpoint at the end. Only effective with
                ``fixed_datapoints``, since that is the only mode that
                computes a validation NLL.
            overwrite: Forwarded to ``existing_trained_model``.
        """
        assert (
            self.direction == "forward"
        ), "flow must be built in forward (normalizing) direction for efficient training"

        if self.existing_trained_model(overwrite):
            self.load_trained_model()
            return

        best_val_nll = float("inf")
        no_improv = 0
        # Whether a best-validation checkpoint was written during this run.
        checkpoint_saved = False

        if fixed_datapoints:
            train_samples, val_samples, _ = dataset.load_dataset(overwrite=False)
            train_samples, train_context = check_tuple(
                train_samples, move_to_torch=True, device=self.device
            )
            val_samples, val_context = check_tuple(
                val_samples, move_to_torch=True, device=self.device
            )
            if train_context is not None:
                train_samples = ConditionalDataset(train_samples, train_context)
                val_samples = ConditionalDataset(val_samples, val_context)
            train_dataset = DataLoader(train_samples, batch_size=batch_size, shuffle=True)
            val_dataset = DataLoader(val_samples, batch_size=batch_size, shuffle=False)

        print("+++++ Training normalizing flow ++++++")
        n_parameters = self.count_parameters()
        print(f"Model has {n_parameters} trainable parameters")
        optimizer = torch.optim.Adam(self.flow.parameters(), lr=lr)

        loss = []
        loss_val = []
        logging_dict = {}
        training_time = 0.0
        # Test split used for metrics; loaded on first use and then reused.
        test_samples = None

        # Evaluate/log at most 100 times over the run, and at least every epoch
        # when there are fewer than 100 epochs.
        log_interval = max(1, n_epochs // 100)

        # run_start measures the whole run; segment_start is reset after every
        # evaluation so that training_time excludes evaluation time.
        run_start = time.monotonic()
        segment_start = run_start
        try:
            self.flow.train()
            for epoch in range(n_epochs):
                loss_dict = {}
                metrics_dict = {}
                if fixed_datapoints:
                    for batch in train_dataset:
                        nll = self._nll_step(optimizer, batch)
                else:
                    nll = self._nll_step(optimizer, dataset.sample(batch_size, "train"))

                # With fixed datapoints this is the loss of the last mini-batch.
                train_nll = nll.item()
                loss_dict["Train loss (NLL)"] = train_nll
                loss.append(train_nll)

                if epoch % log_interval == 0:
                    training_time += time.monotonic() - segment_start
                    with torch.no_grad():
                        if metrics is not None and self.cond_dim is None:
                            if test_samples is None:
                                _, _, test_split = dataset.load_dataset(overwrite=False)
                                test_samples = (
                                    torch.from_numpy(test_split).float().to(self.device)
                                )
                            metrics_dict = self.evaluate_metrics(
                                metrics,
                                val_samples=test_samples,
                                n_samples=batch_size,
                                dataset=dataset,
                            )

                        val_nll = None
                        if fixed_datapoints:
                            val_log_probs = []
                            for batch in val_dataset:
                                data, context = check_tuple(batch)
                                val_log_probs.append(self.log_prob(data, context=context))
                            val_nll = (-torch.cat(val_log_probs, -1).mean()).item()
                            loss_dict["log_lik_test"] = val_nll
                            loss_val.append(val_nll)

                        # Recorded after validation so the validation NLL is
                        # part of the persisted log.
                        self._record_log_entry(
                            logging_dict,
                            {"runtime": training_time} | loss_dict | metrics_dict,
                        )

                        if save_best_val and val_nll is not None:
                            if val_nll < best_val_nll:
                                print("saved model")
                                best_val_nll = val_nll
                                self.save_trained_model()
                                checkpoint_saved = True
                                no_improv = 0
                            else:
                                no_improv += 1
                            if no_improv > 50:
                                break

                        self._print_epoch_summary(epoch, loss_dict, metrics_dict, metrics)
                    segment_start = time.monotonic()

        except KeyboardInterrupt:
            print("interrupted...")

        time_diff = timedelta(seconds=time.monotonic() - run_start)
        print(f"Training took {time_diff} seconds")
        for losses, name in [(np.array(loss), "train"), (np.array(loss_val), "val")]:
            plt.plot(range(len(losses)), losses)
            plt.title(name)
            plt.show()

        self._dump_training_log(logging_dict)

        # Restore the best checkpoint only if one was written in this run;
        # otherwise persist the current weights so the run is not lost.
        if save_best_val and checkpoint_saved:
            self.load_trained_model()
        else:
            self.save_trained_model()