import copy
import time
from collections import MutableSequence, OrderedDict
from itertools import compress
from typing import Any

import numpy as np  # type: ignore
import torch
from tqdm import tqdm  # type: ignore

from layers.mcmc.metrics import RunningAverageMeter
from layers.mcmc.model import ProbModel

# Short alias for torch.Tensor.
T = torch.Tensor


class Chain(MutableSequence):
    """
    Append-only container for the history of an MCMC chain.

    Three parallel lists hold one entry per proposal:
    state_dicts: deep copy of the parameters of an accepted proposal,
        or the placeholder False for a rejected proposal
    log_probs: deep copy of the log_prob dict of an accepted proposal,
        or the placeholder False for a rejected proposal
    accepts: the accept decision of every proposal

    Further attributes:
    last_accepted_idx: index of the most recent accepted entry, a.k.a. the
        state of the chain; None as long as nothing has been accepted
    running_avgs: dict with one RunningAverageMeter per log_prob key
    running_accepts: RunningAverageMeter fed with the accept decisions

    @property
    samples: the accepted state_dicts only
    accept_ratio: fraction of accepted proposals
    state: state_dict and log_prob of the last accepted entry
    state_idx: cross-checked index of the last accepted entry

    Deleting, assigning and inserting items is not supported.
    """

    def __init__(self, probmodel: "ProbModel" = None):
        """
        Create an empty chain, or a chain whose first (accepted) entry is the
        current state of probmodel.

        :param probmodel: optional ProbModel; its state_dict is deep-copied and
            its log_prob is evaluated on the first batch of probmodel.dataloader
        :raises AssertionError: if probmodel is given but is not a ProbModel
        """
        super().__init__()

        # Every attribute exists on every chain, so an empty chain can be
        # appended to, concatenated and printed as well.
        self.state_dicts = []
        self.log_probs = []
        self.accepts = []
        self.last_accepted_idx = None
        self.running_avgs = {}
        self.running_accepts = RunningAverageMeter(0.999)

        if probmodel is not None:
            # Initialize the chain with the given model as first accepted sample
            assert isinstance(probmodel, ProbModel)

            self.state_dicts.append(copy.deepcopy(probmodel.state_dict()))
            log_prob = probmodel.log_prob(*next(iter(probmodel.dataloader)))
            log_prob['log_prob'].detach_()
            self.log_probs.append(copy.deepcopy(log_prob))
            self.accepts.append(True)
            self.last_accepted_idx = 0

            self.running_avgs = {key: RunningAverageMeter(0.99) for key in log_prob}

    @staticmethod
    def _last_accepted(accepts):
        """Return the index of the last truthy entry of accepts, or None if there is none."""
        for idx in range(len(accepts) - 1, -1, -1):
            if accepts[idx]:
                return idx
        return None

    @staticmethod
    def _snapshot(probmodel):
        """Return a deep copy of the parameters, given either as ProbModel or as state_dict."""
        if isinstance(probmodel, ProbModel):
            return copy.deepcopy(probmodel.state_dict())
        if isinstance(probmodel, OrderedDict):
            return copy.deepcopy(probmodel)
        raise TypeError(f"Expected a ProbModel or an OrderedDict state_dict, got {type(probmodel).__name__}")

    def __len__(self) -> int:
        """Number of stored proposals, accepted and rejected."""
        return len(self.state_dicts)

    def __iter__(self) -> Any:
        """Iterate over (state_dict, log_prob, accept) triples."""
        return zip(self.state_dicts, self.log_probs, self.accepts)

    def __repr__(self) -> str:
        return f"MCMC Chain: Length:{len(self)} Accept:{self.accept_ratio:.2f}"

    def __delitem__(self, i: int) -> None:  # type: ignore
        raise NotImplementedError

    def insert(self, i: int, v: Any) -> None:
        raise NotImplementedError

    def __setitem__(self, i: int, v: Any = None) -> None:  # type: ignore
        # The value parameter is required by the item-assignment protocol;
        # without it `chain[i] = v` fails with a TypeError instead.
        raise NotImplementedError

    def __getitem__(self, i: int) -> Any:  # type: ignore
        """
        Return a sub-chain holding the selected entries.

        All three lists are cut with the same index, so they stay aligned.
        An integer index yields a chain of length one.

        :param i: integer or slice over the stored proposals
        :return: new Chain; the selected entries are shared with this chain,
            the running meters are independent copies
        :raises IndexError: for an integer index that is out of range
        """
        if isinstance(i, slice):
            window = i
        else:
            position = range(len(self))[i]  # normalizes negative indices, raises IndexError
            window = slice(position, position + 1)

        chain = copy.copy(self)  # shallow: avoids deep-copying the whole history
        chain.state_dicts = self.state_dicts[window]
        chain.log_probs = self.log_probs[window]
        chain.accepts = self.accepts[window]
        chain.last_accepted_idx = self._last_accepted(chain.accepts)

        # Updating the meters of the sub-chain must not leak into this chain
        for name in ('running_avgs', 'running_accepts'):
            if hasattr(self, name):
                setattr(chain, name, copy.deepcopy(getattr(self, name)))
        return chain

    def __add__(self, other: Any) -> Any:
        """Same as `+=`: extends this chain in place and returns it."""
        return self.__iadd__(other)

    def __iadd__(self, other):
        """
        Extend the chain in place.

        :param other: (probmodel, log_prob, accept) tuple/list forwarded to
            append(), or another Chain forwarded to cat_chains()
        :return: self, or NotImplemented for any other operand type
        """
        if isinstance(other, (tuple, list)):
            assert len(other) == 3, f"Invalid number of information pieces passed: {len(other)} vs len(Iterable(model, log_prob, accept))==3"
            self.append(*other)
        elif isinstance(other, Chain):
            self.cat_chains(other)
        else:
            return NotImplemented

        return self

    @property
    def state_idx(self):
        """
        Returns the index of the last accepted sample a.k.a. the state of the chain.

        The index is recomputed from self.accepts. If the chain has no cached
        last_accepted_idx yet, the result is cached; otherwise the cached value
        is checked against it.

        :raises IndexError: if no sample has been accepted
        :raises AssertionError: if the cached index disagrees with self.accepts
        """
        last_accepted_sample_ = self._last_accepted(self.accepts)
        if last_accepted_sample_ is None:
            raise IndexError("The chain contains no accepted sample")

        if getattr(self, 'last_accepted_idx', None) is None:
            self.last_accepted_idx = last_accepted_sample_
        else:
            assert last_accepted_sample_ == self.last_accepted_idx
        return self.last_accepted_idx

    @property
    def samples(self):
        """Filters the list of state_dicts with the list of bools from self.accepts :return: list of accepted state_dicts"""
        return list(compress(self.state_dicts, self.accepts))

    @property
    def accept_ratio(self):
        """Sum the boolean list (=total number of Trues) and divides it by its length :return: float valued accept ratio, 0.0 for an empty chain"""
        if not self.accepts:
            return 0.0
        return sum(self.accepts) / len(self.accepts)

    @property
    def state(self):
        """
        :return: dict with the 'state_dict' and the 'log_prob' of the last accepted sample
        :raises IndexError: if no sample has been accepted
        """
        idx = getattr(self, 'last_accepted_idx', None)
        if idx is None:
            raise IndexError("The chain contains no accepted sample")
        return {'state_dict': self.state_dicts[idx], 'log_prob': self.log_probs[idx]}

    def cat_chains(self, other):
        """
        Append all entries of another chain to this one in place.

        The state of this chain moves to the last accepted entry of other, if
        other has one. Running averages present in both chains are blended
        with equal weight; meters only other has are copied over.

        :param other: the Chain to append
        :raises AssertionError: if other is not a Chain
        """
        assert isinstance(other, Chain)

        # Read everything that depends on the current lengths before the lists
        # grow, which also keeps `chain += chain` well defined.
        offset = len(self.state_dicts)
        other_last = self._last_accepted(other.accepts)
        other_avgs = getattr(other, 'running_avgs', {})

        self.state_dicts += other.state_dicts
        self.log_probs += other.log_probs
        self.accepts += other.accepts

        if other_last is not None:
            self.last_accepted_idx = offset + other_last

        own_avgs = self.__dict__.setdefault('running_avgs', {})
        for key, meter in other_avgs.items():
            if key in own_avgs:
                own_avgs[key].avg = 0.5 * own_avgs[key].avg + 0.5 * meter.avg
            else:
                own_avgs[key] = copy.deepcopy(meter)

    def append(self, probmodel, log_prob, accept):
        """
        Record the outcome of one proposal.

        An accepted proposal stores deep copies of the parameters and of
        log_prob, becomes the new state of the chain and updates the running
        averages. A rejected proposal only stores False placeholders.

        :param probmodel: ProbModel or OrderedDict state_dict of the proposal;
            only read when accept is truthy
        :param log_prob: dict whose 'log_prob' entry is a single-element
            torch.Tensor; that tensor is detached in place
        :param accept: accept decision of the proposal
        :raises TypeError: if the proposal is accepted and probmodel is neither
            a ProbModel nor an OrderedDict
        :raises AssertionError: if log_prob is malformed
        """
        # Snapshot before touching any list so a failure leaves the chain
        # consistent; rejected proposals need no (expensive) deep copy at all.
        params_state_dict = self._snapshot(probmodel) if accept else False

        assert isinstance(log_prob, dict)
        assert isinstance(log_prob['log_prob'], torch.Tensor)
        assert log_prob['log_prob'].numel() == 1

        log_prob['log_prob'].detach_()

        self.accepts.append(accept)
        self.running_accepts.update(1 * accept)

        if accept:
            self.state_dicts.append(params_state_dict)
            self.log_probs.append(copy.deepcopy(log_prob))
            self.last_accepted_idx = len(self.state_dicts) - 1
            for key, value in log_prob.items():
                # Meters are created on demand, e.g. for a chain that started empty
                if key not in self.running_avgs:
                    self.running_avgs[key] = RunningAverageMeter(0.99)
                self.running_avgs[key].update(value.item())
        else:
            self.state_dicts.append(False)
            self.log_probs.append(False)


class ChainSampler:
    """
    Base class for MCMC samplers that advance a Chain of a ProbModel.

    Subclasses implement propose(), sample_chain() and __repr__().
    tune_step_size() additionally uses two members this class does not define:
    self.acceptance(proposal_log_prob, state_log_prob), returning the pair
        (accept, log_ratio)
    self.optim, providing param_groups[0]["step_size"] and
        dual_average_tune(accepts, step, accept_prob)
    """

    def __init__(self, probmodel, step_size, num_steps, burn_in, pretrain, tune):
        """
        Store the sampler configuration and start the chain from probmodel.

        :param probmodel: model to sample; becomes the first entry of self.chain
        :param step_size: stored as is, not used by this base class
        :param num_steps: stored as is, not used by this base class
        :param burn_in: number of iterations run by tune_step_size()
        :param pretrain: stored as is, not used by this base class
        :param tune: stored as is, not used by this base class
        """
        self.probmodel = probmodel
        self.chain = Chain(probmodel=self.probmodel)

        self.step_size = step_size
        self.num_steps = num_steps
        self.burn_in = burn_in

        self.pretrain = pretrain
        self.tune = tune

    def propose(self):
        """Hook for subclasses; tune_step_size() unpacks the result as (sample_log_prob, sample)."""
        raise NotImplementedError

    def __repr__(self) -> str:
        raise NotImplementedError

    def tune_step_size(self):
        """
        Tune the step size of self.optim on a throw-away chain.

        The model parameters are reset and a fresh tuning chain is started from
        them. For self.burn_in iterations a proposal is drawn, accepted or
        rejected against the current state of the tuning chain, and recorded.
        From the third iteration on, the most recent accept decisions are
        passed to self.optim.dual_average_tune. After a rejection the model is
        rolled back to the state of the tuning chain. self.chain is not
        modified.

        :raises RuntimeError: if a rejected proposal has a NaN log-probability
        """
        # Number of most recent accept decisions handed to the tuner
        tune_interval_length = 100

        print(f'Tuning: Init Step Size: {self.optim.param_groups[0]["step_size"]:.5f}')

        self.probmodel.reset_parameters()
        tune_chain = Chain(probmodel=self.probmodel)
        tune_chain.running_accepts.momentum = 0.5

        progress = tqdm(range(self.burn_in))
        for tune_step in progress:
            sample_log_prob, sample = self.propose()

            # The reference is the state of the chain that is actually being
            # advanced. Comparing against self.chain would pin the reference to
            # its initial entry, because accepted samples go to tune_chain.
            current = tune_chain.state
            accept, log_ratio = self.acceptance(sample_log_prob['log_prob'], current['log_prob']['log_prob'])
            tune_chain += (self.probmodel, sample_log_prob, accept)

            if tune_step > 1:
                self.optim.dual_average_tune(tune_chain.accepts[-tune_interval_length:], tune_step, np.exp(log_ratio.item()))

            if not accept:
                if torch.isnan(sample_log_prob['log_prob']):
                    print(current)
                    # Fail loudly; exit() would end the process with status 0
                    raise RuntimeError(f"Proposal log-probability is NaN at tuning step {tune_step}")
                # A rejection leaves the chain state unchanged, so `current` is still valid
                self.probmodel.load_state_dict(current['state_dict'])

            desc = f'Tuning: Accept: {tune_chain.running_accepts.avg:.2f}/{tune_chain.accept_ratio:.2f} StepSize: {self.optim.param_groups[0]["step_size"]:.5f}'
            progress.set_description(desc=desc)

        time.sleep(0.1)  # for cleaner printing in the console

    def sample_chain(self):
        """Hook for subclasses."""
        raise NotImplementedError
