import numpy as np
import os
import torch
from sbi import inference as inference

# from sbi.utils.metrics import unbiased_mmd_squared as mmd
# from sbi.utils.metrics import c2st
from sbibm.metrics import mmd, median_distance, c2st
from sbibm import get_task

# from sbibm.visualisation import fig_posterior

from margflow.datasets.dataset_abstracts import HybridDataset
from margflow.utils.training_utils import check_tuple
import matplotlib.pyplot as plt


class SimulationBasedInference(HybridDataset):
    """Dataset backed by an ``sbibm`` benchmark task.

    The task name is taken from ``args.dataset``, which must have the form
    ``"sbi_<taskname>"``. The task's simulator and prior are used to generate
    (parameter, observation) pairs, and the task's reference posterior samples
    are used to score a trained model in :meth:`evaluate_metric`.
    """

    def __init__(self, args):
        """Resolve the sbibm task named in ``args.dataset`` and cache its parts.

        Args:
            args: Configuration object. This constructor reads
                ``args.dataset`` and ``args.n_simulations`` directly and
                ``x_dim`` via ``self.args`` (set up by the parent class).
        """
        super().__init__(args)
        prefix, *task_parts = args.dataset.split("_")
        assert prefix == "sbi", (
            f"dataset name must start with 'sbi_', got {args.dataset!r}"
        )
        self.taskname = "_".join(task_parts)
        self.n_simulations = args.n_simulations
        self.dataset_suffix += f"_sbi_{self.taskname}"
        self.D = self.args.x_dim

        self.task = get_task(self.taskname)
        self.simulator = self.task.get_simulator()
        self.proposal = self.task.get_prior_dist()

    def sample(
        self, n_samples, data_type: str = "train"
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Draw parameters from the prior and simulate matching observations.

        Args:
            n_samples: Currently unused; the number of simulations is always
                ``self.n_simulations``.
                NOTE(review): confirm whether this should drive the
                simulation count instead.
            data_type: Currently unused by this implementation.

        Returns:
            A ``(theta, x)`` tuple, where ``theta`` has a singleton axis
            inserted at position 1 and ``x`` is returned as produced by
            ``simulate_for_sbi``.
        """
        theta, x = inference.simulate_for_sbi(
            self.simulator,
            self.proposal,
            num_simulations=self.n_simulations,
            show_progress_bar=False,
        )
        # mb, dim --> mb, 1, dim (implicitly assumes one datapoint per condition)
        theta = theta[:, None, :]
        return theta, x

    @staticmethod
    def _save_scatter(reference_samples, posterior_samples, path):
        """Save an overlaid scatter of the first two dims of both sample sets."""
        try:
            plt.scatter(
                reference_samples[:, 0].numpy(),
                reference_samples[:, 1].numpy(),
                alpha=0.2,
                label="reference samples",
            )
            plt.scatter(
                posterior_samples[:, 0].numpy(),
                posterior_samples[:, 1].numpy(),
                alpha=0.2,
                label="margflow samples",
            )
            plt.legend()
            plt.savefig(path)
        except Exception as err:
            # Keep ValueError as the surfaced type, but preserve the real cause.
            raise ValueError(f"Failed to save scatter plot to {path}") from err
        finally:
            # Always reset the shared pyplot figure so plots do not accumulate.
            plt.clf()

    @torch.no_grad()
    def evaluate_metric(
        self,
        model_sampler,
        task,
        n_sim,
        observations=tuple(range(1, 11)),
        fourier_sigma=0.01,
        seed=1234,
        device="cuda",
        subsample=False,
        normalized_data=False,
    ):
        """Score model posterior samples against the task's reference samples.

        For every observation id, samples are drawn from ``model_sampler``
        conditioned on that observation and compared with the reference
        posterior samples using C2ST, MMD and median distance. Scatter plots
        (for 2-D posteriors only) and a ``metrics.txt`` summary are written
        under ``./plots/{task}_ns{n_sim}_fs{fourier_sigma:.4}_seed{seed}``.

        Args:
            model_sampler: Callable invoked as
                ``model_sampler(n_samples=1, context=batch)``; the first
                element of its return value is squeezed and used as samples.
            task: Task label used only in output paths and file names.
            n_sim: Simulation budget label used only in the output path.
            observations: Observation ids passed to the sbibm task.
            fourier_sigma: Label used only in the output path.
            seed: Label used only in the output path (no RNG is seeded here).
            device: Device the observation context is moved to before sampling.
            subsample: If true, keep a random tenth of the reference samples
                (indices are drawn with replacement).
            normalized_data: If true, standardise the context with training
                statistics and de-standardise the drawn samples.

        Returns:
            Three numpy arrays ``(c2st, mmd, med_dist)`` with one entry per
            observation.

        Raises:
            AssertionError: If the number of reference samples is not
                divisible by the fixed number of batches (200).
            ValueError: If saving a scatter plot fails.
        """
        c2st_, mmd_, med_dist_ = [], [], []
        directory_name = f"./plots/{task}_ns{n_sim}_fs{fourier_sigma:.4}_seed{seed}"
        os.makedirs(directory_name, exist_ok=True)
        n_obs = len(observations)
        print("Evaluation started...")

        if normalized_data:
            train_samples, _, _ = self.load_dataset(overwrite=False)
            train_samples, train_context = check_tuple(
                train_samples, move_to_torch=True, device=self.device
            )
            # Metrics are computed on CPU, so keep the sample statistics there.
            data_mean = train_samples.mean(0).detach().cpu()
            data_std = train_samples.std(0).detach().cpu()
            # The context is normalised on `device`, which may differ from
            # `self.device`; move the statistics once, outside the loop.
            context_mean = train_context.mean(0).to(device)
            context_std = train_context.std(0).to(device)

        n_batches = 200
        for n_observation in observations:
            reference_samples = (
                self.task.get_reference_posterior_samples(num_observation=n_observation)
                .detach()
                .cpu()
            )
            if subsample:
                # NOTE(review): indices are drawn with replacement, so the
                # subsample may contain duplicates.
                idx = np.random.randint(
                    0, len(reference_samples), len(reference_samples) // 10
                )
                reference_samples = reference_samples[idx]
            n_samples = len(reference_samples)

            # One copy of the observation per requested posterior sample.
            observation = self.task.get_observation(n_observation)
            observation = observation.repeat(n_samples, 1).to(device)
            if normalized_data:
                observation -= context_mean
                observation /= context_std

            batch_size = n_samples // n_batches
            assert n_batches * batch_size == n_samples, (
                f"number of reference samples ({n_samples}) must be divisible "
                f"by n_batches ({n_batches})"
            )
            posterior_samples = [
                model_sampler(
                    n_samples=1,
                    context=observation[i * batch_size : (i + 1) * batch_size],
                )[0].squeeze()
                for i in range(n_batches)
            ]
            posterior_samples = torch.cat(posterior_samples, 0).detach().cpu()
            if normalized_data:
                # Undo the standardisation so samples live in the reference space.
                posterior_samples *= data_std
                posterior_samples += data_mean

            new_c2st = c2st(posterior_samples, reference_samples).item()
            new_mmd = mmd(posterior_samples, reference_samples).item()
            new_med_dist = median_distance(posterior_samples, reference_samples).item()

            c2st_.append(new_c2st)
            mmd_.append(new_mmd)
            med_dist_.append(new_med_dist)

            print(
                f"Obs {n_observation}/{n_obs} - "
                f"c2st: {new_c2st:.4f} "
                f"mmd: {new_mmd:.4f} "
                f"med_dist: {new_med_dist:.4f}"
            )

            if posterior_samples.shape[-1] == 2:
                self._save_scatter(
                    reference_samples,
                    posterior_samples,
                    f"{directory_name}/{task}_scatter_{n_observation}.png",
                )

        c2st_ = np.array(c2st_)
        mmd_ = np.array(mmd_)
        med_dist_ = np.array(med_dist_)

        with open(f"{directory_name}/metrics.txt", "w", encoding="utf-8") as output:
            output.write(f"c2st\t{c2st_.mean()}\t{c2st_.std()}\n")
            output.write(f"mmd\t{mmd_.mean()}\t{mmd_.std()}\n")
            output.write(f"med_dist\t{med_dist_.mean()}\t{med_dist_.std()}\n")

        return c2st_, mmd_, med_dist_
