from typing import List, Optional

import numpy as np  # type: ignore
import torch
from matplotlib import pyplot as plt  # type: ignore

from hmc.chain import Chain, ProbModel


class Sampler:
    """Base class for samplers built on a ``ProbModel``.

    Holds the sampling configuration and checks the model's ``log_prob``
    output format at construction time. ``sample_chains``, ``sample_chain``
    and ``__str__`` raise ``NotImplementedError`` here and must be provided
    by subclasses.
    """

    def __init__(
        self,
        probmodel: ProbModel,
        step_size: float,
        num_steps: int,
        num_chains: int,
        burn_in: int,
        pretrain: bool,
        tune: bool,
    ) -> None:
        """Store the configuration and sanity-check ``probmodel.log_prob``.

        Args:
            probmodel: Model to sample from. It must expose ``dataloader``
                (iterable of batches) and ``log_prob(*batch)``; ``log_prob``
                is called once here on the first batch.
            step_size, num_steps, num_chains, burn_in, pretrain, tune:
                Sampler settings, stored unchanged on the instance under the
                same names for use by subclasses.

        Raises:
            ValueError: If ``probmodel.dataloader`` yields no batches, or if
                the first key of the dict returned by ``log_prob`` is not
                ``'log_prob'`` (including when the dict is empty).
            TypeError: If ``log_prob`` does not return a ``dict``.
        """
        self.probmodel = probmodel
        self.num_chains = num_chains
        # Declared but not assigned here; posterior_dist/trace read it, so it
        # is presumably set by subclasses (TODO confirm).
        self.chain: Chain

        self.step_size = step_size
        self.num_steps = num_steps
        self.burn_in = burn_in

        self.pretrain = pretrain
        self.tune = tune

        # Call log_prob once on the first batch to check its return format up
        # front. Explicit raises (not ``assert``) so the check still runs under
        # ``python -O``.
        try:
            first_batch = next(iter(self.probmodel.dataloader))
        except StopIteration:
            raise ValueError(
                "probmodel.dataloader yielded no batches; cannot validate log_prob"
            ) from None

        test_log_prob = self.probmodel.log_prob(*first_batch)
        if not isinstance(test_log_prob, dict):
            raise TypeError(
                f"probmodel.log_prob must return a dict, got {type(test_log_prob).__name__}"
            )
        # Order matters: the *first* key must be 'log_prob', not merely present.
        # Presumably downstream code reads the first entry - verify before relaxing.
        first_key = next(iter(test_log_prob), None)
        if first_key != 'log_prob':
            raise ValueError(
                f"the first key of probmodel.log_prob's output must be 'log_prob', got {first_key!r}"
            )

    def sample_chains(self) -> List[Chain]:
        """Sample ``self.num_chains`` chains. Must be implemented by subclasses."""
        raise NotImplementedError

    def __str__(self) -> str:
        """Human-readable description. Must be implemented by subclasses."""
        raise NotImplementedError

    def sample_chain(self, step_size: Optional[float] = None) -> None:
        """Sample a single chain. Must be implemented by subclasses.

        Args:
            step_size: Optional override of the step size; ``None`` by default.
        """
        raise NotImplementedError

    def posterior_dist(self, verbose: bool = False, plot: bool = True) -> None:
        """Plot histograms of the sampled parameter values in ``self.chain``.

        Currently disabled: always raises ``ValueError`` before doing any work.
        The code after the guard is kept for reference and is unreachable.

        Args:
            verbose: Unused.
            plot: If true, draw and show the histograms.

        Raises:
            ValueError: Always, until the implementation below is fixed.
        """
        raise ValueError("This function has not been modified, it has problems which need to be fixed before using")

        # NOTE(review): everything below is unreachable until the guard above
        # is removed. Suspected issues to resolve first:
        #   * ``self.chain`` is only annotated in __init__, never assigned.
        #   * Single-parameter branch: ``torch.cat(..., dim=0)`` followed by
        #     ``post[:, 0]`` assumes each sample is 2-D (e.g. shape (1, 2)) -
        #     confirm; 1-D samples would need ``torch.stack`` instead.
        #   * Multi-parameter branch passes a torch tensor straight to
        #     ``plt.hist``, while the other branch converts via ``.cpu().numpy()``.
        #   * ``verbose`` is unused.
        state_dict = self.probmodel.state_dict()
        num_params = len(state_dict)

        if num_params == 1:
            # Sampling from a predefined distribution (e.g. a GMM) by
            # simulating a particle: take the single parameter of each sample.
            post = torch.cat([next(iter(sd.values())) for sd in self.chain.samples], dim=0)

            if plot:
                hist2d = plt.hist2d(x=post[:, 0].cpu().numpy(), y=post[:, 1].cpu().numpy(), bins=100, range=np.array([[-3, 3], [-3, 3]]), density=True)
                plt.colorbar(hist2d[3])
                plt.show()

        elif num_params > 1:
            # More than one parameter: one histogram per parameter name.
            accepted_models = self.chain.samples
            for param_name in state_dict:
                post = torch.cat([sd[param_name] for sd in accepted_models])

                if plot:
                    plt.hist(x=post, bins=50, range=np.array([-3, 3]), density=True, alpha=0.5)
                    plt.title(param_name)
                    # Only show when plotting; showing unconditionally would
                    # pop up empty figures when plot=False.
                    plt.show()

    def trace(self, verbose: bool = False, plot: bool = True) -> None:
        """Plot each parameter's value over the accepted steps of ``self.chain``.

        Currently disabled: always raises ``ValueError`` before doing any work.
        The code after the guard is kept for reference and is unreachable.

        Args:
            verbose: Unused.
            plot: If true, draw and show one trace plot per parameter.

        Raises:
            ValueError: Always, until the implementation below is fixed.
        """
        raise ValueError("This function has not been modified, it has problems which need to be fixed before using")

        # NOTE(review): everything below is unreachable until the guard above
        # is removed. Suspected issues to resolve first:
        #   * ``self.chain`` is only annotated in __init__, never assigned.
        #   * The x-axis has ``len(accepted_models)`` points, but ``torch.cat``
        #     concatenates along dim 0, so parameters with more than one element
        #     per sample give a length mismatch, and 0-dim (scalar) tensors
        #     cannot be concatenated at all - ``torch.stack`` may be intended.
        #   * ``verbose`` is unused.
        state_dict = self.probmodel.state_dict()

        if len(state_dict) >= 1:
            # At least one parameter: one trace plot per parameter name,
            # restricted to the samples at the chain's accepted steps.
            accepted_models = [self.chain.samples[idx] for idx in self.chain.accepted_steps]
            steps = np.arange(len(accepted_models))

            for param_name in state_dict:
                post = torch.cat([sd[param_name] for sd in accepted_models])

                if plot:
                    plt.plot(steps, post)
                    plt.title(param_name)
                    plt.show()
