import torch
import numpy as np
import time
import json

from datetime import timedelta
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from margflow.abstract_model import AbstractModel
from margflow.utils.training_utils import check_tuple, ConditionalDataset

from fff.loss import volume_change_surrogate
from fff.other_losses.exact_nll import exact_nll


class SkipConnection(torch.nn.Module):
    """Residual wrapper: adds the wrapped module's output onto its input."""

    def __init__(self, inner):
        """Keep a reference to ``inner``, the callable applied in ``forward``."""
        super().__init__()
        self.inner = inner

    def forward(self, x, *args, **kwargs):
        """Return ``x + inner(x, *args, **kwargs)``.

        Extra positional and keyword arguments are forwarded to ``inner``
        untouched; only ``x`` takes part in the residual sum.
        """
        residual = self.inner(x, *args, **kwargs)
        return x + residual


def build_layers(input_dim, hidden_dim, output_dim, n_layers, activation):
    """Assemble the layer list of a fully connected network.

    The returned list is laid out as::

        Linear(input_dim, hidden_dim), activation,
        [Linear(hidden_dim, hidden_dim), activation] * n_layers,
        Linear(hidden_dim, output_dim)

    There is no activation after the final linear layer.

    Args:
        input_dim: Number of input features of the first linear layer.
        hidden_dim: Width of every hidden linear layer.
        output_dim: Number of output features of the last linear layer.
        n_layers: Number of hidden ``Linear(hidden_dim, hidden_dim)`` layers
            inserted between the input and output layers.
        activation: Object placed after every layer except the last. The
            *same* object is reused at every position rather than copied.

    Returns:
        A plain Python list of ``2 * n_layers + 3`` entries, suitable for
        unpacking into ``torch.nn.Sequential``.
    """
    layers = [torch.nn.Linear(input_dim, hidden_dim), activation]
    for _ in range(n_layers):
        layers.extend((torch.nn.Linear(hidden_dim, hidden_dim), activation))
    layers.append(torch.nn.Linear(hidden_dim, output_dim))
    return layers


class FreeFormFlow(AbstractModel):
    """Free-form flow built from an unconstrained encoder/decoder MLP pair.

    The encoder maps ``x_dim`` inputs to ``z_dim`` latents and the decoder maps
    back. Both are plain ``torch.nn.Sequential`` stacks produced by
    ``build_layers`` with SiLU activations. The latent base distribution is a
    factorised standard normal.

    Attributes such as ``self.device`` and ``self.x_dim`` and the helpers
    ``existing_trained_model``, ``load_trained_model``, ``save_trained_model``,
    ``count_parameters``, ``evaluate_metrics`` and ``set_model_signature`` are
    not defined here; presumably they come from ``AbstractModel``.
    """

    # First epoch index at which validation may run.
    _VAL_WARMUP_EPOCHS = 500
    # Early stopping triggers once the number of consecutive validation rounds
    # without improvement exceeds this value.
    _EARLY_STOP_PATIENCE = 50

    def __init__(
        self,
        x_dim,
        z_dim,
        n_layers,
        signature=None,
        script_path=None,
        hid_dim=128,
        device="cuda",
        dtype=torch.float32,
    ):
        """Build the encoder/decoder networks and register them as trainable.

        Args:
            x_dim: Dimensionality of the data space.
            z_dim: Dimensionality of the latent space.
            n_layers: Number of hidden layers in both encoder and decoder.
            signature: Forwarded unchanged to ``AbstractModel.__init__``.
            script_path: Forwarded unchanged to ``AbstractModel.__init__``.
            hid_dim: Width of the hidden layers.
            device: Forwarded to ``AbstractModel.__init__``; the networks and
                the base distribution are placed on ``self.device``.
            dtype: Forwarded unchanged to ``AbstractModel.__init__``.
        """
        super().__init__(
            model_name="freeform_flow",
            x_dim=x_dim,
            script_path=script_path,
            signature=signature,
            device=device,
            dtype=dtype,
        )
        self.latent_dim = z_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers

        self._build_flow()
        self.trainable_params = {"encoder": self.encoder, "decoder": self.decoder}
        self.set_model_signature()

    def _build_flow(self):
        """Create the base distribution, the encoder and the decoder."""
        # Base distribution: independent standard normal over the latent dims.
        loc = torch.zeros(self.latent_dim, device=self.device)
        scale = torch.ones(self.latent_dim, device=self.device)
        normal = torch.distributions.Normal(loc=loc, scale=scale)
        self.base_distribution = torch.distributions.Independent(normal, 1)

        # Encoder: x_dim -> latent_dim.
        encoder_layers = build_layers(
            input_dim=self.x_dim,
            hidden_dim=self.hid_dim,
            output_dim=self.latent_dim,
            n_layers=self.n_layers,
            activation=torch.nn.SiLU(),
        )
        self.encoder = torch.nn.Sequential(*encoder_layers).to(self.device)

        # Decoder: latent_dim -> x_dim.
        decoder_layers = build_layers(
            input_dim=self.latent_dim,
            hidden_dim=self.hid_dim,
            output_dim=self.x_dim,
            n_layers=self.n_layers,
            activation=torch.nn.SiLU(),
        )
        self.decoder = torch.nn.Sequential(*decoder_layers).to(self.device)

    def sample(self, n_samples, context=None, **kwargs):
        """Draw ``n_samples`` latents from the base distribution and decode them.

        ``context`` and ``**kwargs`` are accepted but not used by this method.
        """
        samples_z = self.base_distribution.sample(torch.Size([n_samples]))
        samples = self.decoder(samples_z)

        return samples

    def log_prob(self, x, context=None, **kwargs):
        """Return the negated ``nll`` field of ``exact_nll`` evaluated at ``x``.

        ``context`` and ``**kwargs`` are accepted but not used by this method.
        """
        log_lik = -exact_nll(x, self.encoder, self.decoder, self.base_distribution).nll

        return log_lik

    def sample_and_log_prob(self, n_samples, context=None, **kwargs):
        """Sample from the model and evaluate ``log_prob`` on those samples.

        Returns:
            Tuple ``(samples, log_prob)``.
        """
        samples = self.sample(n_samples=n_samples, context=context, **kwargs)
        log_prob = self.log_prob(x=samples, context=context, **kwargs)

        return samples, log_prob

    def _training_step(self, data, beta, optim, loss_dict):
        """Run one optimisation step on ``data`` and record its losses.

        Writes ``log_lik``, ``total_loss`` and ``loss_reconstruction`` into
        ``loss_dict`` and appends the total loss to
        ``self.training_log["loss_train"]``.

        Returns:
            The context that ``check_tuple`` split off ``data``.
        """
        optim.zero_grad()
        data, context = check_tuple(data, move_to_torch=False, device=self.device)
        surrogate = volume_change_surrogate(data, self.encoder, self.decoder)
        loss_reconstruction = ((data - surrogate.x1) ** 2).sum(-1).mean(-1)
        loss_nll = -self.base_distribution.log_prob(surrogate.z) - surrogate.surrogate
        total_loss = (beta * loss_reconstruction + loss_nll).mean()
        loss_dict["log_lik"] = loss_nll.mean().item()
        loss_dict["total_loss"] = total_loss.item()
        loss_dict["loss_reconstruction"] = loss_reconstruction.mean().item()
        self.training_log["loss_train"].append(total_loss.item())
        total_loss.backward()
        optim.step()
        return context

    def train_forward(
        self,
        dataset,
        beta,
        n_epochs,
        batch_size,
        fixed_datapoints=True,
        metrics=None,
        save_best_val=False,
        overwrite=False,
        lr=1e-3,
    ):
        """Train encoder and decoder with the surrogate NLL plus reconstruction loss.

        If ``self.existing_trained_model(overwrite)`` is truthy the stored model
        is loaded and no training takes place.

        Args:
            dataset: Object providing ``sample(batch_size, "train")`` and
                ``load_dataset(overwrite=False)``; the latter is unpacked as
                ``(train, val, test)``.
            beta: Weight of the reconstruction term in the total loss.
            n_epochs: Number of epochs. With ``fixed_datapoints`` an epoch is
                one pass over the training loader, otherwise one sampled batch.
            batch_size: Batch size for loaders, for freshly sampled batches and
                for ``n_samples`` in metric evaluation.
            fixed_datapoints: Train on the stored dataset (``True``) or on
                batches freshly drawn from ``dataset.sample`` (``False``).
                Validation only runs in the former case.
            metrics: Forwarded to ``self.evaluate_metrics``. Metrics are only
                evaluated when the last training batch carried no context.
            save_best_val: Checkpoint on every validation improvement, stop
                early after too many rounds without improvement, and reload the
                best checkpoint at the end.
            overwrite: Forwarded to ``self.existing_trained_model``.
            lr: Adam learning rate for both networks.
        """
        if fixed_datapoints:
            # NOTE(review): the returned samples were never used; the call is
            # kept in case `dataset.sample` has side effects that the later
            # `load_dataset` relies on -- confirm and remove.
            dataset.sample(batch_size, "train")
        if self.existing_trained_model(overwrite):
            self.load_trained_model()
            return

        optim = torch.optim.Adam(
            [
                {"params": self.encoder.parameters(), "lr": lr},
                {"params": self.decoder.parameters(), "lr": lr},
            ]
        )

        if fixed_datapoints:
            train_samples, val_samples, _ = dataset.load_dataset(overwrite=False)
            train_samples, train_context = check_tuple(
                train_samples, move_to_torch=True, device=self.device
            )
            val_samples, val_context = check_tuple(
                val_samples, move_to_torch=True, device=self.device
            )
            if train_context is not None:
                train_samples = ConditionalDataset(train_samples, train_context)
                val_samples = ConditionalDataset(val_samples, val_context)
            train_dataset = DataLoader(train_samples, batch_size=batch_size, shuffle=True)
            val_dataset = DataLoader(val_samples, batch_size=batch_size, shuffle=False)

        print("+++++ Training freeform flow ++++++")
        print(f"Model has {self.count_parameters()} trainable parameters")
        self.training_log = dict(loss_train=[], loss_val=[], logging_dict={})

        # Any finite validation NLL counts as an improvement on the first round.
        best_val_nll = float("inf")
        no_improv = 0
        checkpoint_saved = False
        batch_context = None

        # Validate at most ~100 times over the run; guarded so that
        # n_epochs == 0 does not divide by zero now that this is hoisted.
        n_val_steps = min(100, n_epochs)
        val_interval = max(1, n_epochs // max(1, n_val_steps))

        # Timers are set before the `try` so they exist even if the run is
        # interrupted immediately.
        start_time_fixed = time.monotonic()
        start_time = start_time_fixed
        training_time = 0.0
        self.encoder.train()
        self.decoder.train()
        try:
            for epoch in range(n_epochs):
                loss_dict = {}
                metrics_dict = {}
                if fixed_datapoints:
                    for batch in train_dataset:
                        batch_context = self._training_step(batch, beta, optim, loss_dict)
                else:
                    batch = dataset.sample(batch_size, "train")
                    batch_context = self._training_step(batch, beta, optim, loss_dict)

                # Without a validation loader there is nothing to evaluate, so
                # the networks are never switched to eval mode in that case.
                if not fixed_datapoints:
                    continue
                if epoch % val_interval != 0 or epoch < self._VAL_WARMUP_EPOCHS:
                    continue

                # Exclude validation time from the reported training runtime.
                training_time += time.monotonic() - start_time
                runtime_df = {"runtime": training_time}
                self.encoder.eval()
                self.decoder.eval()

                if metrics is not None and batch_context is None:
                    _, _, test_samples = dataset.load_dataset(overwrite=False)
                    test_samples = torch.from_numpy(test_samples).float().to(self.device)
                    metrics_dict = self.evaluate_metrics(
                        metrics,
                        val_samples=test_samples,
                        n_samples=batch_size,
                        dataset=dataset,
                    )
                    logs_dict = runtime_df | loss_dict | metrics_dict
                    logging_dict = self.training_log["logging_dict"]
                    for key, value in logs_dict.items():
                        logging_dict.setdefault(key, []).append(value)

                val_nll_batches = []
                val_recon_batches = []
                for batch in val_dataset:
                    batch, _ = check_tuple(batch)
                    nll_out = exact_nll(
                        batch, self.encoder, self.decoder, self.base_distribution
                    )
                    surrogate = volume_change_surrogate(batch, self.encoder, self.decoder)
                    error = ((batch - surrogate.x1) ** 2).sum(-1).mean(-1).unsqueeze(0)
                    # Detach so accumulated results do not keep autograd
                    # graphs of every validation batch alive.
                    val_nll_batches.append(nll_out.nll.detach())
                    val_recon_batches.append(error.detach())
                nll_val = torch.cat(val_nll_batches, -1).mean().item()
                recon_val = torch.cat(val_recon_batches, -1).mean().item()

                loss_dict["Val loss (NLL)"] = nll_val
                self.training_log["loss_val"].append(nll_val)
                loss_dict["Val reconstruction loss"] = recon_val
                print(
                    f"Epoch {epoch} losses: "
                    + ", ".join(f"{key}:{value:.3f}" for key, value in loss_dict.items())
                )
                if metrics is not None:
                    print(
                        f"Epoch {epoch} other metrics: "
                        + ", ".join(
                            f"{key}:{value:.3f}" for key, value in metrics_dict.items()
                        )
                    )
                self.encoder.train()
                self.decoder.train()

                if save_best_val:
                    if nll_val < best_val_nll:
                        print("saved model")
                        best_val_nll = nll_val
                        self.save_trained_model()
                        checkpoint_saved = True
                        no_improv = 0
                    else:
                        no_improv += 1
                    if no_improv > self._EARLY_STOP_PATIENCE:
                        break

                start_time = time.monotonic()

        except KeyboardInterrupt:
            print("interrupted...")

        end_time = time.monotonic()
        time_diff = timedelta(seconds=end_time - start_time_fixed)
        print(f"Training freeform flow took {time_diff} seconds")

        loss_train = np.array(self.training_log["loss_train"])
        loss_val = np.array(self.training_log["loss_val"])
        curves = [(loss_train, "train")]
        if len(loss_val) > 0:
            curves.append((loss_val, "val"))
        for losses, name in curves:
            plt.plot(range(len(losses)), losses)
            plt.title(name)
            plt.show()

        if hasattr(self, "model_path"):
            with open(f"{self.model_path}.json", "w") as file:
                json.dump(self.training_log["logging_dict"], file, indent=4)

        # Reload the best checkpoint only if this run actually wrote one;
        # otherwise persist the final weights so a model is always stored.
        if save_best_val and checkpoint_saved:
            self.load_trained_model()
        else:
            self.save_trained_model()
