import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


def init_weights(m):
    """Re-initialize the weight of an ``nn.Linear`` module with Kaiming-normal values.

    Intended to be passed to ``nn.Module.apply``. Modules that are not
    ``nn.Linear`` are left untouched, and the bias of a linear layer is not
    modified.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight)


class RunningNormalizer:
    """Track per-feature running mean/variance and (de)normalize tensors with them.

    Statistics are accumulated one batch at a time via ``update`` and are then
    used by ``normalize`` / ``denormalize``. Both directions share the same
    standard-deviation estimate, so ``denormalize`` inverts ``normalize`` for
    values that were not clipped.

    Args:
        dim: Number of features; the running statistics have shape ``(dim,)``.
        device: Device on which the running statistics are allocated.
        clip_range: Normalized values are clamped to ``[-clip_range, clip_range]``.
        eps: Small constant added to the standard deviation to avoid division
            by zero for constant features.
    """

    def __init__(self, dim, device, clip_range=5, eps=1e-8):
        self.dim = dim
        self.clip_range = clip_range
        self.eps = eps

        # Running mean and variance, one entry per feature.
        self.running_mean = torch.zeros(dim, device=device)
        self.running_var = torch.ones(dim, device=device)
        # Tiny pseudo-count so the very first update never divides by zero.
        self.count = 1e-4

    def _std(self):
        """Return the standard deviation used by both normalize and denormalize."""
        return torch.sqrt(self.running_var) + self.eps

    def update(self, x):
        """Merge the statistics of batch ``x`` into the running mean and variance.

        Uses the batched (parallel) form of Welford's update (Chan et al.):
        the batch mean/variance are combined with the running values weighted
        by their respective sample counts. The reduction is over dimension 0.

        Empty batches are ignored, because their mean/variance would be NaN and
        would poison the running statistics.
        """
        # Statistics must not keep the caller's autograd graph alive.
        x = x.detach()
        batch_count = x.shape[0]
        if batch_count == 0:
            return

        batch_mean = x.mean(dim=0)
        batch_var = x.var(dim=0, unbiased=False)
        total_count = self.count + batch_count

        delta = batch_mean - self.running_mean
        self.running_mean = self.running_mean + delta * batch_count / total_count

        # Combine the sums of squared deviations of both populations; the last
        # term accounts for the shift between the two means.
        m_a = self.running_var * self.count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + torch.square(delta) * self.count * batch_count / total_count
        self.running_var = m2 / total_count

        self.count = total_count

    def normalize(self, x):
        """Standardize ``x`` with the running statistics and clamp to ``±clip_range``."""
        normalized_x = (x - self.running_mean) / self._std()
        return torch.clamp(normalized_x, -self.clip_range, self.clip_range)

    def denormalize(self, x):
        """Map standardized values back to the original scale (inverse of ``normalize``).

        Clipping performed by ``normalize`` cannot be undone.
        """
        return x * self._std() + self.running_mean


class EnsembleModel(nn.Module):
    """Ensemble of independent MLPs, each predicting a Gaussian over state output and reward.

    Every member maps a concatenated ``[state, action]`` vector to
    ``(state_dim + 1) * 2`` values: a mean and a log-variance for each of the
    ``state_dim`` state outputs plus one reward. All members share a single
    AdamW optimizer.

    Args:
        state_dim: Number of state features.
        action_dim: Number of action features.
        device: Device on which the normalizer statistics are allocated.
        hidden_units: Width of every hidden layer.
        ensemble_size: Number of independent networks.
        lr: Initial learning rate of the optimizer.
        min_logvar: Lower clamp applied to predicted log-variances.
        max_logvar: Upper clamp applied to predicted log-variances.
        lr_decay: Multiplicative factor applied by ``decay_lr``.
    """

    def __init__(self, state_dim, action_dim, device, hidden_units=200, ensemble_size=10, lr=1e-3,
                 min_logvar=-20, max_logvar=5, lr_decay=0.999):
        super().__init__()

        self.state_dim = state_dim
        self.ensemble_size = ensemble_size
        self.lr = lr
        self.lr_decay = lr_decay
        self.min_logvar = min_logvar
        self.max_logvar = max_logvar

        input_dim = state_dim + action_dim
        # +1 for the reward prediction, *2 for mean and log-variance.
        output_dim = (state_dim + 1) * 2
        self.models = nn.ModuleList(
            [self._build_member(input_dim, hidden_units, output_dim) for _ in range(ensemble_size)]
        )

        # Replace the default Linear weight initialization via init_weights.
        for model in self.models:
            model.apply(init_weights)

        self.model_optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=lr,
            weight_decay=0.0001,
            eps=1e-8  # Numerical stability
        )

        # Not read anywhere in this class; kept so existing attribute access keeps working.
        self.mean = None
        self.std = None

        # Running-statistics normalizers. They are plain Python objects, so they
        # are not part of this module's parameters or state_dict.
        self.state_normalizer = RunningNormalizer(state_dim, device)
        self.diff_normalizer = RunningNormalizer(state_dim, device)  # for state differences
        self.reward_normalizer = RunningNormalizer(1, device)

    @staticmethod
    def _build_member(input_dim, hidden_units, output_dim):
        """Build one ensemble member: four LeakyReLU hidden layers and a linear head."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_units),
            nn.LeakyReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.LeakyReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.LeakyReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.LeakyReLU(),
            nn.Linear(hidden_units, output_dim),
        )

    def forward(self, x):
        """Run every ensemble member on ``x``.

        Args:
            x: Tensor whose last dimension is ``[state, action]``; the first
                ``state_dim`` entries are treated as the state. Any number of
                leading batch dimensions (including none) is accepted.

        Returns:
            ``(means, logvars)``: two lists with one tensor per ensemble
            member. The last dimension of each tensor holds the state outputs
            followed by the reward. Means are denormalized; log-variances are
            only clamped to ``[min_logvar, max_logvar]``.
        """
        # Only the state part of the input is normalized; actions pass through.
        states = x[..., :self.state_dim]
        actions = x[..., self.state_dim:]
        normalized_states = self.state_normalizer.normalize(states)
        normalized_x = torch.cat([normalized_states, actions], dim=-1)

        means, logvars = [], []
        for ensemble_member in self.models:
            output = ensemble_member(normalized_x)
            mean, logvar = torch.chunk(output, 2, dim=-1)

            # Clip logvar for numerical stability.
            logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)

            # The last output column is the reward, the rest are state outputs.
            # NOTE(review): state means are denormalized with state_normalizer
            # while diff_normalizer is never used here — confirm whether the
            # network is meant to predict next states or state differences.
            state_mean = self.state_normalizer.denormalize(mean[..., :-1])
            reward_mean = self.reward_normalizer.denormalize(mean[..., -1:])

            # NOTE(review): log-variances are returned without the matching
            # rescaling applied to the means — verify against the loss.
            means.append(torch.cat([state_mean, reward_mean], dim=-1))
            logvars.append(logvar)

        return means, logvars

    def update_normalizer(self, states, next_states, rewards):
        """Update the state, state-difference and reward normalizers with a new batch."""
        self.state_normalizer.update(states)
        self.diff_normalizer.update(next_states - states)
        self.reward_normalizer.update(rewards)

    def decay_lr(self):
        """Multiply the learning rate of every optimizer parameter group by ``lr_decay``."""
        for param_group in self.model_optimizer.param_groups:
            param_group['lr'] *= self.lr_decay

