import logging
import matplotlib.pyplot as plt
import torch
import numpy as np
from runners.Basic_runner import BasicRunner
from utils import (
    split_dataset,
    check_memory)
from sampling import SDE_sampler_manifolds


class SOnRunner(BasicRunner):
    """Runner for SO(n) data stored as flattened matrices.

    Besides the training/sampling hooks it provides for ``BasicRunner``, this
    runner summarises a batch of matrices by the traces ``tr(S^k)`` for every
    power ``k`` in ``self.power_list`` and saves histograms of those statistics.
    """

    def __init__(self, config):
        super().__init__(config)
        self.load_data()

        # ------------------------------ exhibit dataset ------------------------------
        sample_num = self.config.sample.sample_num
        # NOTE(review): the reference statistics are taken from the full shuffled
        # dataset (self.data_set), not from self.test_set, despite the attribute name.
        test_set_samples = self.data_set[:sample_num]
        # Reference statistics; plot_hist(compare=True) overlays these on later plots.
        self.test_set_statistics = self._statistics_numpy(test_set_samples)
        self.plot_hist(self.test_set_statistics, savefig="true", compare=False)

        uniform = self.manifold.uniform_sample(sample_num)
        self.plot_hist(self._statistics_numpy(uniform), savefig="uniform", compare=False)

        if self.config.if_train or self.config.if_sample:
            self._exhibit_forward_path(sample_num)

    def _exhibit_forward_path(self, sample_num):
        """Plot and save trace statistics along the precomputed forward paths."""
        # Swap the first two axes so that x_hist[i] holds every sample at step i
        # (assumes training_set_path is (num_samples, N + 1, out_dim) -- TODO confirm).
        x_hist = self.training_set_path[:sample_num].clone().transpose(0, 1)
        x_end = self.filter_sample(x_hist[-1])
        self.plot_hist(self._statistics_numpy(x_end), savefig='forward_end')

        # Every 1% up to 9%, every 10% up to 100%, plus the first five steps.
        percent_marks = set(range(10)) | set(range(10, 101, 10))
        self._plot_path_steps(x_hist, percent_marks, lambda i: i < 5, 'generating_fwd')

        statistics_path = self.get_statistics_path(x_hist.cpu().detach()).numpy()
        np.save(f"{self.samples_dir}/{self.dataset_name}_statistics_fwd.npy", statistics_path)

    def load_data(self):
        """Load, shuffle and split the dataset; build forward paths if needed.

        Reads ``./data/SOn/{dataset_name}.npy`` and reshapes it to rows of
        ``manifold.out_dim`` values. The rows are shuffled with ``torch.randperm``
        and split with ``split_dataset`` using ``config.seed``. When training or
        sampling is enabled, it also builds ``self.training_set_path`` via
        ``generate_path_dataset`` and passes it to ``check_memory``.
        """
        # Matrix powers k whose traces tr(S^k) are used as summary statistics.
        self.power_list = [1, 2, 4, 5]

        data_ori = torch.tensor(np.load(f"./data/SOn/{self.dataset_name}.npy")).reshape(-1, self.manifold.out_dim)
        self.data_set = data_ori[torch.randperm(data_ori.shape[0])].clone()
        self.training_set, self.test_set, self.val_set = split_dataset(self.data_set, self.config.seed)

        if self.config.if_train or self.config.if_sample:
            self.training_set_path = self.generate_path_dataset(self.training_set, keep_quiet=False)
            check_memory(self.training_set_path)

    def filter_sample(self, samples):
        """Keep only the samples whose matrix has a positive determinant.

        Args:
            samples: tensor whose rows reshape to ``(mat_dim, mat_dim)`` matrices.

        Returns:
            The rows of ``samples`` with ``det > 0``. The number kept is logged;
            the log message calls this the correct connected component.
        """
        all_mat = samples.reshape(-1, self.manifold.mat_dim, self.manifold.mat_dim)
        manifolds_idx = torch.where(torch.linalg.det(all_mat) > 0)[0]
        logging.info(f"The number of samples on the correct connected component: {manifolds_idx.shape[0]}/{samples.shape[0]}, the others are dropped.")
        return samples[manifolds_idx]

    def get_statistics(self, samples):
        """Return ``tr(S^k)`` for each sample ``S`` and each ``k`` in ``self.power_list``.

        Args:
            samples: tensor (or array-like) reshapeable to ``(-1, mat_dim, mat_dim)``.

        Returns:
            Tensor of shape ``(num_samples, len(self.power_list))``.
        """
        if not isinstance(samples, torch.Tensor):
            samples = torch.tensor(samples)

        mats = samples.reshape(-1, self.manifold.mat_dim, self.manifold.mat_dim)
        traces = [
            torch.matrix_power(mats, power).diagonal(dim1=-2, dim2=-1).sum(dim=-1, keepdim=True)
            for power in self.power_list
        ]
        return torch.cat(traces, dim=1)

    def get_statistics_path(self, path):
        """Apply ``get_statistics`` to every step of a path tensor.

        Args:
            path: tensor indexed as ``path[step]``; each step is passed to
                ``get_statistics`` (presumably ``(steps, num_samples, out_dim)``
                -- TODO confirm).

        Returns:
            Float32 tensor of shape ``(path.shape[0], path.shape[1], len(self.power_list))``.
        """
        statistics = torch.zeros(path.shape[0], path.shape[1], len(self.power_list))
        # One step at a time, matching the original loop, so peak memory stays per-step.
        for i, step in enumerate(path):
            statistics[i] = self.get_statistics(step)
        return statistics

    def _statistics_numpy(self, samples):
        """Return ``get_statistics(samples)`` as a detached CPU NumPy array."""
        return self.get_statistics(samples).detach().cpu().numpy()

    def _plot_path_steps(self, x_hist, percent_marks, extra_step, prefix):
        """Save histograms for selected steps of ``x_hist``.

        Step ``i`` (``0 <= i <= self.sde.N``) is plotted when ``100 * i / N`` is
        exactly one of ``percent_marks`` or when ``extra_step(i)`` is true. The
        figure is saved under ``savefig=f"{prefix}_{i}"`` and compared against
        the reference statistics.
        """
        n_steps = self.sde.N
        for i in range(n_steps + 1):
            if 100 * i / n_steps in percent_marks or extra_step(i):
                self.plot_hist(self._statistics_numpy(x_hist[i]), savefig=f'{prefix}_{i}')

    def plot_hist(self, statistics, savefig=None, compare=True):
        """Save one histogram per trace statistic to ``Hist_Statistics_{savefig}.png``.

        Args:
            statistics: array of shape ``(num_samples, len(self.power_list))``.
            savefig: suffix used in the output file name under ``self.savefig_dir``.
            compare: if True, overlay ``self.test_set_statistics`` (red) on
                ``statistics`` (green); otherwise plot ``statistics`` alone.
        """
        n_stats = len(self.power_list)
        n_rows = (n_stats + 1) // 2
        # About one bin per 100 samples, but never 0: a non-positive bin count
        # makes the histogram call raise (e.g. after filter_sample drops rows).
        bins = max(1, statistics.shape[0] // 100)

        fig = plt.figure(figsize=(10, 5 * n_rows))
        try:
            for i, power in enumerate(self.power_list):
                ax = fig.add_subplot(n_rows, 2, i + 1)
                if compare:
                    ax.hist(statistics[:, i], bins=bins, histtype='stepfilled', alpha=0.5, density=True, color='green')
                    ax.hist(self.test_set_statistics[:, i], bins=bins, histtype='stepfilled', alpha=0.5, density=True, color='red')
                else:
                    ax.hist(statistics[:, i], bins=bins, alpha=1.0, density=True)
                # Braces keep multi-digit powers inside the LaTeX superscript.
                ax.set_title(rf"Histogram of $tr(S^{{{power}}})$")

            fig.savefig(self.savefig_dir + f"/Hist_Statistics_{savefig}.png", dpi=300)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

    def validate(self, mode=None, epoch=0, **kwargs):
        """Validation hook.

        ``mode='start'`` resets ``self.best_nll_K_val`` to infinity and
        ``mode='end'`` does nothing. Any other mode draws
        ``config.sample.sample_num`` samples with the reverse SDE. It then filters
        them with ``filter_sample`` and saves their statistics histogram as
        ``val_{epoch}_generated``.
        """
        if mode == 'start':
            self.best_nll_K_val = torch.inf
            return
        if mode == 'end':
            return

        logging.info(f"-------------------------Start validating: Epoch {epoch}-------------------------")
        # sample
        init = self.manifold.uniform_sample(self.config.sample.sample_num).to(self.device)
        x, _, _ = SDE_sampler_manifolds(self.sde, self.manifold, init,
                                        reverse=True,
                                        score_net=self.network,
                                        keep_quiet=True)
        x = self.filter_sample(x)
        self.plot_hist(self._statistics_numpy(x), savefig=f'val_{epoch}_generated')
        logging.info("-------------------------End validating.-------------------------")

    def test(self):
        """No test procedure for this runner; returns None."""
        return

    def sample_on_manifolds(self):
        """Sample with the reverse SDE and save histograms plus per-step statistics.

        Starts from ``config.sample.sample_num`` uniform samples and calls
        ``calculate_constrain`` on the raw output. It then plots the filtered
        result as ``generated`` and plots selected intermediate steps as
        ``generating_bwd_{i}``. Finally it saves the statistics of the whole path
        to ``{samples_dir}/{dataset_name}_statistics_bwd.npy``.
        """
        logging.info('Start sampling on manifolds.')
        device = self.device
        if self.network is not None:
            self.network.to(device)

        # backward
        logging.info("Start sampling backward SDE.")
        init = self.manifold.uniform_sample(self.config.sample.sample_num).to(device)
        x, x_hist, _ = SDE_sampler_manifolds(self.sde, self.manifold, init,
                                             reverse=True,
                                             score_net=self.network,
                                             keep_quiet=False)
        self.calculate_constrain(x)
        x = self.filter_sample(x)
        self.plot_hist(self._statistics_numpy(x), savefig='generated')

        # Every 10% up to 90%, every 1% from 90% to 100%, plus the last steps.
        n_steps = self.sde.N
        percent_marks = set(range(0, 100, 10)) | set(range(90, 101))
        self._plot_path_steps(x_hist, percent_marks, lambda i: i > n_steps - 5, 'generating_bwd')

        statistics_path = self.get_statistics_path(x_hist.cpu().detach()).numpy()
        np.save(f"{self.samples_dir}/{self.dataset_name}_statistics_bwd.npy", statistics_path)

        return


