from EscapeEnv.common.base_estimator import ActorCriticEstimator
from EscapeEnv.common.scheduler import PolynomialParamScheduler
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dist

from EscapeEnv.common.optimizers import SGHMC
import numpy as np
from copy import deepcopy
from collections import deque


class LT_A2CEstimator(ActorCriticEstimator):
    """A2C estimator whose critic is updated with SGHMC under a mixture-Gaussian prior.

    The actor is trained with SGD (momentum 0.9) or Adam on the policy-gradient
    loss. The critic is updated ``loops_per_train`` times per :meth:`update`
    with the project's ``SGHMC`` optimizer. After the burn-in period the critic
    loss can optionally be re-weighted by an importance weight estimated from an
    auxiliary SGHMC chain (see :meth:`auxiliary_sampling`).
    """

    def __init__(self, actor_network, critic_network, learning_rate, gamma, optimizer_kwargs, estimator_kwargs, device) -> None:
        super().__init__(actor_network, critic_network, learning_rate, gamma, optimizer_kwargs, estimator_kwargs, device)

        # Hyper-parameters; every key below is required in estimator_kwargs.
        self.is_importance_weight = self.estimator_kwargs['is_importance_weight']
        self.decay_burnin = self.estimator_kwargs['decay_burnin']
        self.loops_per_train = self.estimator_kwargs['loops_per_train']
        # NOTE: decay_ent_coef() overwrites this with 0.0 on every update().
        self.ent_coef = self.estimator_kwargs['ent_coef']

        self.prior_sd = self.estimator_kwargs['prior_sd']
        self.sparse_sd = self.estimator_kwargs['sparse_sd']
        self.sparse_ratio = self.estimator_kwargs['sparse_ratio']
        self.obs_sd = self.estimator_kwargs['obs_sd']
        self.pseudo_population = self.optimizer_kwargs['pseudo_population']
        # obs_sd / sqrt(1 - gamma^2) == obs_sd * sqrt(sum_k gamma^(2k)).
        self.return_sd = self.obs_sd / np.sqrt(1 - self.gamma ** 2)

        # The actor uses a fixed initial lr of 1e-3; `learning_rate` only
        # applies to the critic and auxiliary SGHMC optimizers.
        if self.estimator_kwargs['use_sgd']:
            self.actor_optimizer = optim.SGD(self.actor_net.parameters(), lr=1e-3, momentum=0.9)
        else:
            self.actor_optimizer = optim.Adam(self.actor_net.parameters(), lr=1e-3)

        self.critic_optimizer = SGHMC(self.critic_net.parameters(), lr=self.learning_rate, **self.optimizer_kwargs)

        # Independent copy of the critic, used only by auxiliary_sampling().
        self.auxiliary_net = deepcopy(self.critic_net)
        self.auxiliary_optimizer = SGHMC(self.auxiliary_net.parameters(), lr=self.learning_rate, **self.optimizer_kwargs)

        self.actor_lr_scheduler = PolynomialParamScheduler(self.actor_optimizer, 'lr', start_value=1e-3, **self.estimator_kwargs['actor_scheduler'])
        # Only the actor lr is scheduled; the critic lr stays at `learning_rate`.
        self.schedulers = [self.actor_lr_scheduler]
        self.mse_loss = nn.MSELoss(reduction='mean')

        self.critic_buffer = CriticBuffer(return_sd=self.return_sd, buffer_size=self.estimator_kwargs['buffer_size'])
        self.critic_buffer.add(self.critic_net)
        self.buffer_size = self.estimator_kwargs['buffer_size']

        self.gmm = GaussianMixtureModel(self.prior_sd, self.sparse_sd)
        self.return_normal = torch.distributions.Normal(0, self.return_sd)

    def mixture_gaussian_prior(self, net, sparse_ratio):
        """Return the mixture-Gaussian log-prior summed over ``net``'s multi-dim parameters.

        Parameters with a single dimension (e.g. biases) are skipped. Returns the
        int ``0`` if ``net`` has no multi-dimensional parameters.

        NOTE(review): the prior is evaluated on ``p.data``, which is detached
        from autograd. The term therefore changes the *value* of the losses it is
        added to, but contributes no gradient. If it is meant to regularise the
        weights, use ``p`` instead. Note that ``optimize_actor`` does not
        divide it by ``pseudo_population`` the way the critic losses do.
        """
        log_prob = 0
        for p in net.parameters():
            if len(p.data.shape) > 1:
                log_prob += self.gmm.log_prob(p.data.flatten(), sparse_ratio).sum()
        return log_prob

    def update_sparse_ratio(self):
        """Return the current mixture weight of the ``prior_sd`` component.

        The weight decays linearly from 1.0 (at ``progress == 0``) to
        ``self.sparse_ratio`` (at ``progress == decay_burnin``), then stays
        constant. Reads ``self.progress``, which is set by :meth:`update`; it
        does not modify any state despite its name.
        """
        if self.progress < self.decay_burnin:
            return 1.0 - (1.0 - self.sparse_ratio) * (self.progress/self.decay_burnin)
        else:
            return self.sparse_ratio

    def decay_ent_coef(self, progress):
        """Set the entropy coefficient for this update.

        Currently this always sets it to 0.0, ignoring both ``progress`` and the
        configured ``ent_coef``.
        """
        self.ent_coef = 0.0

    def update(self, buffer, progress):
        """Perform one training update from a rollout buffer.

        Steps the LR schedulers once ``progress`` exceeds ``decay_burnin``. It
        then runs critic sampling (importance weighting is enabled only after
        burn-in), snapshots the critic into ``critic_buffer``, and takes one
        actor step.

        :param buffer: rollout buffer. ``buffer.get()`` must return an object
            exposing ``observations``, ``actions``, ``action_mask``,
            ``returns`` and ``advantages``.
        :param progress: training progress compared against ``decay_burnin``
            (presumably a fraction in [0, 1] -- confirm with the caller).
        :return: the critic loss reported by :meth:`sample_critic`.
        """
        self.progress = progress
        self.decay_ent_coef(progress)

        rollout_data = buffer.get()

        if self.progress > self.decay_burnin:
            for schedule in self.schedulers:
                schedule.step()

        if self.progress < self.decay_burnin:
            critic_loss = self.sample_critic(rollout_data, is_importance_weight=False)
        else:
            critic_loss = self.sample_critic(rollout_data, is_importance_weight=self.is_importance_weight)

        self.critic_buffer.add(self.critic_net)
        self.optimize_actor(rollout_data)

        # n_updates is presumably initialised by ActorCriticEstimator.
        self.n_updates += 1
        return critic_loss

    def optimize_actor(self, rollout_data):
        """Take one optimizer step on the actor.

        The loss is ``-mean(advantages * log_prob)``, plus ``ent_coef`` times
        ``-mean(entropy)``, minus the actor's mixture-Gaussian log-prior.
        """
        log_prob, entropy = self.actor_net.evaluate_actions(rollout_data.observations, rollout_data.actions, rollout_data.action_mask)
        advantages = rollout_data.advantages

        policy_loss = - (advantages * log_prob).mean()
        entropy_loss = - torch.mean(entropy)
        sparse_ratio = self.update_sparse_ratio()

        loss = policy_loss + self.ent_coef * entropy_loss - self.mixture_gaussian_prior(self.actor_net, sparse_ratio)

        self.actor_optimizer.zero_grad()
        loss.backward()
        self.actor_optimizer.step()

    def sample_critic(self, rollout_data, is_importance_weight=False):
        """Run ``loops_per_train`` SGHMC steps on the critic for one rollout batch.

        Per step the loss is ``-(mean log N(returns - values; 0, return_sd^2)
        + log_prior / pseudo_population)``. When ``is_importance_weight`` is
        true, that loss is multiplied by the scalar

            w = p / (m/(m+1) * q + p/(m+1)),

        where ``p`` is the current critic's batch likelihood and ``q`` is the
        auxiliary-chain estimate from :meth:`auxiliary_sampling`. Both are
        shifted by their max before exponentiating, to avoid overflow.

        :return: the last step's loss value divided by the batch size.
        """
        y = rollout_data.returns
        batch_size = y.shape[0]
        obs = rollout_data.observations

        sparse_ratio = self.update_sparse_ratio()

        n_aux_samples = 5  # auxiliary SGHMC steps used to estimate q
        m = 5              # mixing constant of the importance-weight formula
        if is_importance_weight:
            log_denominator = self.auxiliary_sampling(n_aux_samples, y, obs)

        for k in range(self.loops_per_train):
            values = self.critic_net(obs.detach()).flatten()

            log_error_prob = self.return_normal.log_prob(y - values).mean()
            log_prior_prob = self.mixture_gaussian_prior(self.critic_net, sparse_ratio) / self.pseudo_population
            loss = -(log_error_prob + log_prior_prob)

            if is_importance_weight:
                # Weight is computed from detached values, so it carries no gradient.
                log_normal_prob = self.return_normal.log_prob(y - values.detach()).sum()
                stabilizer = torch.max(log_normal_prob, log_denominator)
                p = torch.exp(log_normal_prob - stabilizer)
                q = torch.exp(log_denominator - stabilizer)
                importance_weights = p / (q * (m / (m + 1)) + p / (m + 1))
                loss = loss * importance_weights

            self.critic_optimizer.zero_grad()
            loss.backward()
            # The project's SGHMC.step takes the inner-loop index.
            self.critic_optimizer.step(k)

        if self.n_updates % 100 == 0 and is_importance_weight:
            print('importance weight:', importance_weights)
        return loss.item()/batch_size

    def auxiliary_sampling(self, m, y, obs):
        """Estimate log-mean-exp of the batch log-likelihood along an auxiliary SGHMC chain.

        ``auxiliary_net`` is reset to the current critic weights and advanced
        ``m`` steps with ``auxiliary_optimizer``. After each step, the
        log-likelihood ``sum log N(y - values; 0, return_sd^2)`` is recorded
        without gradient.

        :param m: number of auxiliary steps (and likelihood samples).
        :param y: observed returns for the batch.
        :param obs: observations for the batch.
        :return: 0-dim tensor ``logsumexp(log_prob) - log(m)``.
        """
        self.auxiliary_net.load_state_dict(self.critic_net.state_dict())

        values = self.auxiliary_net(obs).flatten()
        # Allocate on the network's device so the estimator also works off-CPU.
        log_prob = torch.zeros(m, device=values.device)
        # The mixture weight is constant for the whole chain.
        sparse_ratio = self.update_sparse_ratio()

        n_draws = 50  # synthetic return draws per auxiliary step
        for i in range(m):
            y_pred = values + self.return_sd * torch.randn(n_draws, y.shape[0], device=values.device)
            # NOTE(review): y_pred - values is exactly the injected noise, so this
            # term has zero gradient w.r.t. auxiliary_net. If the intent is to fit
            # synthetic targets, y_pred should probably be built from
            # values.detach().
            log_error_prob = self.return_normal.log_prob(y_pred - values).mean()
            # Prior of the chain actually being sampled (the auxiliary net).
            log_prior_prob = self.mixture_gaussian_prior(self.auxiliary_net, sparse_ratio) / self.pseudo_population
            loss = - (log_error_prob + log_prior_prob)

            self.auxiliary_optimizer.zero_grad()
            loss.backward()
            self.auxiliary_optimizer.step()

            values = self.auxiliary_net(obs).flatten()
            log_prob[i] = self.return_normal.log_prob(y - values.detach()).sum()

        return torch.logsumexp(log_prob, dim=0) - np.log(m)
            
            
        


class CriticBuffer(object):
    """Fixed-size FIFO of critic snapshots used to average return likelihoods.

    :param return_sd: std of the zero-mean Gaussian used to score
        ``returns - net(observations)``.
    :param buffer_size: maximum number of snapshots kept; when full, the
        oldest snapshot is dropped first.
    """

    def __init__(self, return_sd, buffer_size=10) -> None:
        self.buffer_size = buffer_size
        self.return_sd = return_sd
        self.memory = deque(maxlen=buffer_size)
        self.normal_dist = torch.distributions.Normal(0, return_sd)

    def add(self, net):
        """Store a deep copy of ``net``, so later training does not alter the snapshot."""
        self.memory.append(deepcopy(net))

    def importance_weight(self, rollout_data):
        """Return the per-sample likelihood averaged over all stored snapshots.

        :param rollout_data: object exposing ``observations`` and ``returns``.
        :return: tensor of per-sample likelihoods (see :meth:`calculate_prob`).
        :raises ValueError: if no snapshot has been added yet.
        """
        if not self.memory:
            raise ValueError("CriticBuffer.importance_weight called on an empty buffer")
        # sum() starts from 0 and adds in insertion order, like the original loop.
        total = sum(self.calculate_prob(net, rollout_data) for net in self.memory)
        return total / len(self.memory)

    def calculate_prob(self, net, rollout_data):
        """Return ``N(returns - net(obs); 0, return_sd^2)`` per sample, clamped below at 1e-3."""
        y = rollout_data.returns
        # The snapshot is only evaluated, never trained: skip building a graph.
        with torch.no_grad():
            y_hat = net(rollout_data.observations).flatten()
        log_prob = self.normal_dist.log_prob(y - y_hat)
        return torch.exp(log_prob).clamp(min=1e-3)


class GaussianMixtureModel:
    """Two-component, zero-mean Gaussian mixture.

    Component 0 has std ``prior_sd`` and component 1 has std ``sparse_sd``.
    Both methods take ``sparse_ratio``, the mixture weight given to component 0
    (``prior_sd``); component 1 gets ``1 - sparse_ratio``.
    """

    def __init__(self, prior_sd, sparse_sd):
        """
        :param prior_sd: standard deviation of component 0.
        :param sparse_sd: standard deviation of component 1.
        """
        self.means = torch.tensor([0.0, 0.0])
        self.stds = torch.tensor([prior_sd, sparse_sd])
        self.components = dist.Normal(self.means, self.stds)

    def sample(self, n_samples, sparse_ratio):
        """Draw ``n_samples`` i.i.d. samples from the mixture.

        :param n_samples: number of samples to generate.
        :param sparse_ratio: weight of the ``prior_sd`` component, in [0, 1].
        :return: 1-D tensor of shape ``(n_samples,)``.
        """
        # Kept as an attribute for backward compatibility with external readers.
        self.mixture_weights = torch.tensor([sparse_ratio, 1-sparse_ratio])
        mixture_indices = dist.Categorical(self.mixture_weights).sample((n_samples,))
        samples = dist.Normal(self.means[mixture_indices], self.stds[mixture_indices]).sample()
        return samples

    def log_prob(self, x, sparse_ratio):
        """Return the per-element mixture log-density of a 1-D tensor ``x``.

        The mixture parameters are moved to ``x.device``, so ``x`` may live on
        any device. Gradients flow back to ``x``.

        :param x: 1-D tensor of points.
        :param sparse_ratio: weight of the ``prior_sd`` component, in [0, 1];
            a weight of 0 simply removes that component (log 0 = -inf is
            handled by logsumexp).
        :return: 1-D tensor of the same length as ``x``.
        """
        self.mixture_weights = torch.tensor([sparse_ratio, 1-sparse_ratio])
        components = self.components
        if components.loc.device != x.device:
            components = dist.Normal(self.means.to(x.device), self.stds.to(x.device))
        # Shape [N, 2]: log-density of each point under each component.
        log_probs = components.log_prob(x.unsqueeze(1))
        log_weighted_probs = log_probs + torch.log(self.mixture_weights.to(x.device))
        # log(sum_k w_k N_k(x)) computed stably.
        return torch.logsumexp(log_weighted_probs, dim=1)
            

if __name__ == '__main__':
    # Smoke check for GaussianMixtureModel: draw samples from a 50/50 mixture of
    # N(0, 1) and N(0, 0.1^2), then print the total log-density of the samples.
    gmm = GaussianMixtureModel(1.0, 0.1)
    samples = gmm.sample(1000, 0.5)

    # To inspect the samples visually (requires matplotlib):
    #   import matplotlib.pyplot as plt
    #   plt.hist(samples, bins=50, density=True); plt.show()

    log_probs = gmm.log_prob(samples, 0.5)
    print(log_probs.sum())