from EscapeEnv.common.base_estimator import ActorCriticEstimator
from EscapeEnv.common.scheduler import PolynomialParamScheduler
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dist
import torch.nn.functional as F

from EscapeEnv.common.optimizers import SGHMC
import numpy as np
from copy import deepcopy
from collections import deque



class LT_PPO_v2Estimator(ActorCriticEstimator):
    """Actor-critic estimator with a mixture-Gaussian prior and an SGHMC-sampled critic.

    * Actor: one optimiser step per :meth:`update` on an (unclipped)
      importance-ratio policy objective, an entropy bonus and a
      mixture-Gaussian log-prior term.
    * Critic: ``loops_per_train`` SGHMC steps per :meth:`update` on a Gaussian
      return likelihood plus the log-prior scaled by ``1 / pseudo_population``.
      Once ``progress`` reaches ``decay_burnin`` the likelihood can be
      re-weighted with importance weights estimated by an auxiliary SGHMC chain
      (:meth:`auxiliary_sampling`).
    """

    def __init__(self, actor_network, critic_network, learning_rate, gamma, optimizer_kwargs, estimator_kwargs, device) -> None:
        super().__init__(actor_network, critic_network, learning_rate, gamma, optimizer_kwargs, estimator_kwargs, device)

        self.is_importance_weight = self.estimator_kwargs['is_importance_weight']
        self.decay_burnin = self.estimator_kwargs['decay_burnin']
        self.loops_per_train = self.estimator_kwargs['loops_per_train']
        # NOTE(review): ``decay_ent_coef`` overwrites this on every ``update``,
        # so the configured value only matters before the first update.
        self.ent_coef = self.estimator_kwargs['ent_coef']

        self.max_grad_norm = self.estimator_kwargs['max_grad_norm']
        self.use_rms_prop = self.estimator_kwargs['use_rms_prop']

        # Prior / likelihood hyper-parameters.
        self.prior_sd = self.estimator_kwargs['prior_sd']
        self.sparse_sd = self.estimator_kwargs['sparse_sd']
        self.sparse_ratio = self.estimator_kwargs['sparse_ratio']
        self.obs_sd = self.estimator_kwargs['obs_sd']
        self.pseudo_population = self.optimizer_kwargs['pseudo_population']
        # obs_sd / sqrt(1 - gamma^2) is the std of sum_t gamma^t * eps_t for
        # i.i.d. eps_t with std obs_sd; it degenerates (division by zero) at gamma == 1.
        self.return_sd = self.obs_sd / np.sqrt(1 - self.gamma ** 2)

        if self.use_rms_prop:
            self.actor_optimizer = optim.RMSprop(self.actor_net.parameters(), lr=self.learning_rate, eps=1e-5, alpha=0.99, weight_decay=0.0)
        else:
            self.actor_optimizer = optim.SGD(self.actor_net.parameters(), lr=self.learning_rate, momentum=0.9)

        self.critic_optimizer = SGHMC(self.critic_net.parameters(), lr=self.learning_rate, **self.optimizer_kwargs)

        # Auxiliary critic copy, re-synchronised with the critic before each
        # auxiliary chain run (see ``auxiliary_sampling``).
        self.auxiliary_net = deepcopy(self.critic_net)
        self.auxiliary_optimizer = SGHMC(self.auxiliary_net.parameters(), lr=self.learning_rate, **self.optimizer_kwargs)

        # NOTE(review): start_value is hard-coded to 1e-3 rather than
        # ``self.learning_rate`` -- confirm this is intended.
        self.actor_lr_scheduler = PolynomialParamScheduler(self.actor_optimizer, 'lr', start_value=1e-3, **self.estimator_kwargs['actor_scheduler'])
        self.schedulers = [self.actor_lr_scheduler]
        self.mse_loss = nn.MSELoss(reduction='mean')

        self.critic_buffer = CriticBuffer(return_sd=self.return_sd, buffer_size=self.estimator_kwargs['buffer_size'])
        self.critic_buffer.add(self.critic_net)
        self.buffer_size = self.estimator_kwargs['buffer_size']

        self.gmm = GaussianMixtureModel(self.prior_sd, self.sparse_sd)
        self.return_normal = torch.distributions.Normal(0, self.return_sd)

    def mixture_gaussian_prior(self, net, sparse_ratio):
        """Return the mixture-Gaussian log-prior of ``net``'s multi-dimensional parameters.

        Parameters with a single dimension (e.g. biases) are skipped. Returns
        the int ``0`` when ``net`` has no multi-dimensional parameters.

        NOTE(review): the prior is evaluated on ``p.data``, i.e. detached from
        autograd, so adding it to a loss changes the loss value but contributes
        no gradient. If the prior is meant to regularise training, score ``p``
        instead -- and re-check step sizes, since the gradient of the
        ``sparse_sd`` component scales with ``1 / sparse_sd**2``.
        """
        log_prob = 0
        for p in net.parameters():
            if p.dim() > 1:
                log_prob += self.gmm.log_prob(p.data.flatten(), sparse_ratio).sum()
        return log_prob

    def update_sparse_ratio(self):
        """Return the current mixture weight of the ``prior_sd`` component.

        Decays linearly from 1.0 at ``progress == 0`` to ``self.sparse_ratio``
        at ``progress == decay_burnin``, then stays constant. Reads
        ``self.progress``, which :meth:`update` sets.
        """
        if self.progress < self.decay_burnin:
            return 1.0 - (1.0 - self.sparse_ratio) * (self.progress / self.decay_burnin)
        return self.sparse_ratio

    def clip_range(self, progress):
        """Linearly annealed PPO clip range: 1.0 at progress 0, 0.2 at progress 1.

        Not used by :meth:`optimize_actor`, which applies no clipping.
        """
        return 0.2 + 0.8 * (1 - progress)

    def decay_ent_coef(self, progress):
        """Set ``ent_coef`` to decay linearly from 0.01 at progress 0 to 0 at progress >= 0.2."""
        self.ent_coef = 0.01 * max(0.2 - progress, 0) / 0.2

    def update(self, buffer, progress):
        """Run one training iteration on the rollout stored in ``buffer``.

        :param buffer: rollout buffer whose ``get()`` returns a batch exposing
            ``observations``, ``actions``, ``action_mask``, ``log_probs``,
            ``advantages`` and ``returns``.
        :param progress: training progress (presumably a fraction in [0, 1]);
            compared against ``decay_burnin`` to switch on LR scheduling and
            importance weighting.
        :return: the critic loss reported by :meth:`sample_critic`.
        """
        self.progress = progress
        self.decay_ent_coef(progress)

        rollout_data = buffer.get()

        # Learning-rate schedules only start stepping after burn-in.
        if self.progress > self.decay_burnin:
            for scheduler in self.schedulers:
                scheduler.step()

        # Importance weighting of the critic likelihood is disabled during burn-in.
        if self.progress < self.decay_burnin:
            critic_loss = self.sample_critic(rollout_data, is_importance_weight=False)
        else:
            critic_loss = self.sample_critic(rollout_data, is_importance_weight=self.is_importance_weight)
        self.critic_buffer.add(self.critic_net)

        self.optimize_actor(rollout_data)

        self.n_updates += 1
        return critic_loss

    def optimize_actor(self, rollout_data):
        """Take one optimiser step on the actor.

        Minimises ``-mean(A * ratio) - ent_coef * mean(entropy) - log_prior``
        with ``ratio = exp(log_prob_new - log_prob_old)``. No PPO clipping is
        applied. Gradient-norm clipping is only applied during burn-in.
        """
        log_prob, entropy = self.actor_net.evaluate_actions(rollout_data.observations, rollout_data.actions, rollout_data.action_mask)
        ratio = torch.exp(log_prob - rollout_data.log_probs)

        # Unclipped importance-ratio surrogate; the PPO clipped variant would be
        # -min(A * ratio, A * clamp(ratio, 1 - eps, 1 + eps)).mean().
        policy_loss = -(rollout_data.advantages * ratio).mean()

        entropy_loss = -torch.mean(entropy)
        sparse_ratio = self.update_sparse_ratio()
        loss = policy_loss + self.ent_coef * entropy_loss - self.mixture_gaussian_prior(self.actor_net, sparse_ratio)

        self.actor_optimizer.zero_grad()
        loss.backward()
        if self.progress < self.decay_burnin:
            torch.nn.utils.clip_grad_norm_(self.actor_net.parameters(), self.max_grad_norm)
        self.actor_optimizer.step()

    def sample_critic(self, rollout_data, is_importance_weight=False):
        """Advance the critic's SGHMC chain by ``loops_per_train`` steps.

        Each step minimises
        ``-(mean log N(y - V(s); 0, return_sd) + log_prior / pseudo_population)``.
        With ``is_importance_weight`` the loss is multiplied by the (gradient-free)
        weight ``p / (m/(m+1) * q + p/(m+1))``, where ``p`` is the current
        critic's joint likelihood of the batch returns and ``q`` is the
        auxiliary-chain estimate from :meth:`auxiliary_sampling`.

        :return: the last step's loss divided by the batch size.
        """
        y = rollout_data.returns
        batch_size = y.shape[0]
        obs = rollout_data.observations

        sparse_ratio = self.update_sparse_ratio()

        M = 5  # auxiliary SGHMC steps used to estimate the denominator
        m = 5  # the auxiliary estimate gets mixture weight m / (m + 1)
        if is_importance_weight:
            log_denominator = self.auxiliary_sampling(M, y, obs)

        for k in range(self.loops_per_train):
            values = self.critic_net(obs.detach()).flatten()

            log_error_prob = self.return_normal.log_prob(y - values).mean()
            log_prior_prob = self.mixture_gaussian_prior(self.critic_net, sparse_ratio) / self.pseudo_population
            loss = -(log_error_prob + log_prior_prob)

            if is_importance_weight:
                # Joint log-likelihood of the batch under the current critic (no grad).
                log_normal_prob = self.return_normal.log_prob(y - values.detach()).sum()
                # Subtract the larger log-term before exponentiating to avoid overflow.
                stabilizer = torch.max(log_normal_prob, log_denominator)
                p = torch.exp(log_normal_prob - stabilizer)
                q = torch.exp(log_denominator - stabilizer)
                importance_weights = p / (q * (m / (m + 1)) + p / (m + 1))
                loss = loss * importance_weights

            self.critic_optimizer.zero_grad()
            loss.backward()
            self.critic_optimizer.step(k)

        if self.n_updates % 100 == 0 and is_importance_weight:
            print('importance weight:', importance_weights)
        return loss.item() / batch_size

    def auxiliary_sampling(self, m, y, obs):
        """Estimate ``log((1/m) * sum_i p(y | theta_i))`` with an auxiliary SGHMC chain.

        The auxiliary critic is reset to the current critic and advanced ``m``
        SGHMC steps. Each step fits it to ``B`` synthetic return vectors drawn
        from its own predictive distribution ``values + return_sd * eps``.
        After each step the summed log-likelihood of the observed returns ``y``
        is recorded, and the log-mean-exp over the ``m`` records is returned.
        """
        self.auxiliary_net.load_state_dict(self.critic_net.state_dict())
        sparse_ratio = self.update_sparse_ratio()

        log_prob = torch.zeros(m, device=y.device)
        values = self.auxiliary_net(obs).flatten()

        B = 50  # synthetic return vectors drawn per auxiliary step
        for i in range(m):
            # Synthetic targets are data: they must be detached, otherwise
            # (y_pred - values) is just the noise term and has zero gradient.
            noise = torch.randn(B, y.shape[0], device=values.device, dtype=values.dtype)
            y_pred = values.detach() + self.return_sd * noise

            log_error_prob = self.return_normal.log_prob(y_pred - values).mean()
            log_prior_prob = self.mixture_gaussian_prior(self.auxiliary_net, sparse_ratio) / self.pseudo_population
            loss = -(log_error_prob + log_prior_prob)

            self.auxiliary_optimizer.zero_grad()
            loss.backward()
            self.auxiliary_optimizer.step()

            values = self.auxiliary_net(obs).flatten()
            log_prob[i] = self.return_normal.log_prob(y - values.detach()).sum()

        return torch.logsumexp(log_prob, dim=0) - np.log(m)
            

class CriticBuffer(object):
    """Fixed-size FIFO of critic snapshots used to average return likelihoods.

    Each snapshot scores a rollout with the Gaussian density
    ``N(y - y_hat; 0, return_sd^2)``. :meth:`importance_weight` averages these
    per-sample densities over all stored snapshots.
    """

    def __init__(self, return_sd, buffer_size=10) -> None:
        self.buffer_size = buffer_size
        self.return_sd = return_sd
        # The oldest snapshot is dropped automatically once ``buffer_size`` is reached.
        self.memory = deque(maxlen=buffer_size)
        self.normal_dist = torch.distributions.Normal(0, return_sd)

    def add(self, net):
        """Store a deep copy of ``net``; later updates to ``net`` do not affect it."""
        self.memory.append(deepcopy(net))

    def importance_weight(self, rollout_data):
        """Return the per-sample density of the returns averaged over all snapshots.

        :param rollout_data: object exposing ``returns`` and ``observations``.
        :raises ValueError: if no snapshot has been added yet.
        """
        if not self.memory:
            raise ValueError('CriticBuffer is empty; call add() before importance_weight().')
        total = sum(self.calculate_prob(net, rollout_data) for net in self.memory)
        return total / len(self.memory)

    def calculate_prob(self, net, rollout_data):
        """Return the per-sample Gaussian density of the returns under ``net``, floored at 1e-3."""
        y = rollout_data.returns
        # Predictions are only scored, never differentiated: skip graph construction.
        with torch.no_grad():
            y_hat = net(rollout_data.observations).flatten()
        log_prob = self.normal_dist.log_prob(y - y_hat)
        return torch.exp(log_prob).clamp(min=1e-3)


class GaussianMixtureModel:
    """Zero-mean, two-component Gaussian mixture.

    Component 0 is ``N(0, prior_sd^2)`` and component 1 is ``N(0, sparse_sd^2)``.
    The ``sparse_ratio`` argument of :meth:`sample` and :meth:`log_prob` is the
    mixture weight of component 0 (the ``prior_sd`` component); component 1
    receives ``1 - sparse_ratio``. The last weights used are kept in
    ``self.mixture_weights``.
    """

    def __init__(self, prior_sd, sparse_sd):
        """
        :param prior_sd: standard deviation of component 0.
        :param sparse_sd: standard deviation of component 1.
        """
        self.means = torch.tensor([0.0, 0.0])
        self.stds = torch.tensor([prior_sd, sparse_sd])
        self.components = dist.Normal(self.means, self.stds)

    def sample(self, n_samples, sparse_ratio):
        """Draw ``n_samples`` i.i.d. samples from the mixture.

        :param n_samples: number of samples to generate.
        :param sparse_ratio: mixture weight of the ``prior_sd`` component.
        :return: CPU tensor of shape ``(n_samples,)``.
        """
        self.mixture_weights = torch.tensor([sparse_ratio, 1 - sparse_ratio])
        # Pick a component per sample, then draw from that component.
        mixture_indices = dist.Categorical(self.mixture_weights).sample((n_samples,))
        samples = dist.Normal(self.means[mixture_indices], self.stds[mixture_indices]).sample()
        return samples

    def log_prob(self, x, sparse_ratio):
        """Return the element-wise mixture log-density of the 1-D tensor ``x``.

        :param x: 1-D tensor of points, on any device.
        :param sparse_ratio: mixture weight of the ``prior_sd`` component.
        :return: tensor with one log-density per element of ``x``, on ``x``'s device.
        """
        self.mixture_weights = torch.tensor([sparse_ratio, 1 - sparse_ratio])
        components = self.components
        if x.device != self.means.device:
            # Component parameters live on the CPU; move them so e.g. CUDA inputs work.
            components = dist.Normal(self.means.to(x.device), self.stds.to(x.device))
        # Shape (N, 2): log-density of every point under each component.
        log_probs = components.log_prob(x.unsqueeze(1))
        log_weights = torch.log(self.mixture_weights).to(x.device)
        # log(sum_k w_k * N(x; 0, sd_k^2)), computed stably; a zero weight
        # gives log(0) = -inf, which logsumexp handles.
        return torch.logsumexp(log_probs + log_weights, dim=1)
            

if __name__ == '__main__':
    # Quick sanity check of GaussianMixtureModel: draw samples from an equal
    # mixture of N(0, 1) and N(0, 0.1^2) and print their total log-likelihood.
    gmm = GaussianMixtureModel(1.0, 0.1)

    samples = gmm.sample(1000, 0.5)
    log_probs = gmm.log_prob(samples, 0.5)

    print(log_probs.sum())