import wandb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from collections import deque

class Discriminator(nn.Module):
    """Skill discriminator q(z | s) over ``num_latents`` one-hot latents.

    An MLP maps a state to one logit per latent. ``forward`` returns softmax
    probabilities. ``update`` trains the classifier with per-sample weights
    derived from target-critic Q-values. ``compute_intrinsic_reward`` turns
    ``log q(z | s)`` for a chosen latent into an EMA-normalised reward
    clipped to [-2, 2].

    NOTE: ``__init__`` sets ``self.training = False``. This is the
    ``nn.Module`` train/eval flag, set on this module only, not on its
    children. The same flag gates the intrinsic reward: while it is
    ``False``, ``compute_intrinsic_reward`` returns zeros.
    """

    def __init__(self, state_dim, num_latents=4, hidden_size=128, learning_rate=0.0003, layernorm=True, device=None):
        super().__init__()
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_latents = num_latents
        # Row k is the one-hot code of latent k.
        self.latents = torch.eye(self.num_latents, device=self.device)
        self.learning_rate = learning_rate
        # Starts "off": compute_intrinsic_reward returns zeros until enabled.
        self.training = False

        # Running (EMA) statistics of the raw intrinsic reward. ``None`` means
        # "cleared by reset_network(); reseed from the next batch".
        self.ema_mean = 0.0
        self.ema_std = 1.0
        self.ema_alpha = 0.99
        # update() clips its sample weights to [1 - clip_ratio, 1 + clip_ratio].
        self.clip_ratio = 0.3

        # Track discrimination accuracy of recent updates (for logging).
        self.accuracy_buffer = deque(maxlen=1000)

        def hidden_layer(in_features):
            # Linear -> [LayerNorm] -> ReLU
            layer = [nn.Linear(in_features, hidden_size)]
            if layernorm:
                layer.append(nn.LayerNorm(hidden_size))
            layer.append(nn.ReLU())
            return layer

        # Same module order as the explicit with/without-LayerNorm variants,
        # so state_dict keys are unchanged.
        self.fc = nn.Sequential(
            *hidden_layer(state_dim),
            *hidden_layer(hidden_size),
            nn.Linear(hidden_size, num_latents),
        )

        self.to(self.device)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, state):
        """Return q(z | state): softmax probabilities over the latents (last dim)."""
        logits = self.fc(state)
        return torch.softmax(logits, dim=-1)

    def compute_loss(self, states, true_z, weight_clip):
        """Per-sample weighted cross-entropy between predicted and true latents.

        Args:
            states: batch of states, ``(N, state_dim)``.
            true_z: one-hot latent codes, ``(N, num_latents)``.
            weight_clip: per-sample weights, either N elements in any shape
                (e.g. ``(N,)`` or ``(N, 1)``) or a single element. Detached
                before use.

        Returns:
            ``(loss, logits)``: the scalar weighted mean loss and the raw
            (pre-softmax) logits.
        """
        # cross_entropy applies log-softmax itself, so it must receive raw
        # logits (self(states) is already softmaxed). reduction="none" keeps
        # one loss per sample so each can be weighted individually.
        logits = self.fc(states)
        per_sample = F.cross_entropy(logits, true_z.argmax(dim=1), reduction="none")
        weights = weight_clip.detach().to(per_sample.device).reshape(-1)
        loss = torch.mean(weights * per_sample)
        return loss, logits

    def train_step(self, states, true_z, weight_clip):
        """Run one optimiser step on the weighted loss; return the loss as a float."""
        states = states.to(self.device)
        true_z = true_z.to(self.device)

        loss, _ = self.compute_loss(states, true_z, weight_clip)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def compute_intrinsic_reward(self, states, z_index):
        """Return the normalised reward ``log q(z_index | s)`` for each state.

        Args:
            states: numpy array of one state ``(state_dim,)`` or a batch
                ``(N, state_dim)``.
            z_index: index of the latent whose log-probability is the reward.

        Returns:
            numpy float32 array of shape ``(-1, 1)``. It is all zeros while
            ``self.training`` is False. Otherwise the raw rewards are
            normalised with EMA mean/std (updated from this batch) and clipped
            to [-2, 2].
        """
        states = torch.from_numpy(states).float().to(self.device)

        with torch.no_grad():
            probs = self(states)
            log_q_z = torch.log(probs + 1e-8)

            # Index the latent axis (last dim); works for a single state and
            # for a batch without touching the batch axis.
            intrinsic_rewards = log_q_z[..., z_index]

            # Return 0 if not training (same (-1, 1) shape as the normal path).
            if not self.training:
                return torch.zeros_like(intrinsic_rewards).reshape(-1, 1).cpu().numpy()

            # Update running statistics
            curr_mean = intrinsic_rewards.mean().item()
            curr_std = intrinsic_rewards.std(unbiased=False).item()
            if curr_std < 1e-8:
                curr_std = 1.0
            if self.ema_mean is None or self.ema_std is None:
                # Statistics were cleared by reset_network(); reseed them.
                self.ema_mean, self.ema_std = curr_mean, curr_std
            else:
                alpha = self.ema_alpha
                self.ema_mean = alpha * self.ema_mean + (1 - alpha) * curr_mean
                self.ema_std = alpha * self.ema_std + (1 - alpha) * curr_std

            # Normalize using EMA stats
            if abs(self.ema_std) < 1e-8:
                normalized = intrinsic_rewards - self.ema_mean
            else:
                normalized = (intrinsic_rewards - self.ema_mean) / (self.ema_std + 1e-8)

            normalized = torch.clamp(normalized.reshape(-1, 1), -2, 2)
            return normalized.cpu().numpy()

    def update(self, states, actions, critics_target, step):
        """Train the discriminator on one batch and log metrics every 500 steps.

        Assumes ``states`` holds ``num_latents`` equally sized contiguous
        blocks, block k collected under latent k. That is how ``z`` is built
        below; confirm against the caller.

        Args:
            states: ``(N, state_dim)`` tensor, N divisible by ``num_latents``.
            actions: ``(N, action_dim)`` tensor.
            critics_target: iterable of critics. Each is called as
                ``critic(states, actions)`` and returns ``(Q1, Q2)``.
            step: global step counter, used only for logging cadence.

        Raises:
            ValueError: on a batch-size mismatch or indivisible batch.
        """
        if states.shape[0] != actions.shape[0]:
            raise ValueError(
                f"states and actions batch sizes differ: {states.shape[0]} vs {actions.shape[0]}"
            )
        if states.shape[0] % self.num_latents != 0:
            raise ValueError(
                f"batch size {states.shape[0]} is not divisible by num_latents={self.num_latents}"
            )
        batch_size = states.shape[0] // self.num_latents
        z = self.latents.repeat_interleave(batch_size, dim=0)

        with torch.no_grad():
            all_min_qs = []
            for critic in critics_target:
                q1, q2 = critic(states, actions)
                all_min_qs.append(torch.min(q1, q2))
            avg_min_q = torch.mean(torch.stack(all_min_qs), dim=0)
            # exp(Q) normalised to mean 1. Subtracting the max keeps exp() from
            # overflowing and cancels in the ratio.
            exp_q = torch.exp(avg_min_q - torch.max(avg_min_q))
            weight = exp_q / torch.mean(exp_q)
            weight_clip = torch.clamp(weight, 1 - self.clip_ratio, 1 + self.clip_ratio)

        loss = self.train_step(states, z, weight_clip)

        # Post-update accuracy, from a single forward pass on the model's device.
        with torch.no_grad():
            probs = self(states.to(self.device))
            accuracy = (probs.argmax(dim=1) == z.argmax(dim=1)).float().mean().item()
            self.accuracy_buffer.append(accuracy)

        if step % 500 == 0:
            wandb.log({
                'discriminator/loss': loss,
                'discriminator/mean_accuracy': np.mean(self.accuracy_buffer),
                'discriminator/reward_mean': self.ema_mean,
                'discriminator/reward_std': self.ema_std,
                'discriminator/z_prediction_accuracy': accuracy,
                'discriminator/entropy': -torch.mean(torch.sum(probs * torch.log(probs + 1e-8), dim=1)).item()
            })

    def reset_network(self):
        """Reinitialise all layers, clear reward/accuracy statistics, rebuild the optimiser.

        Linear layers get Xavier-uniform weights and zero bias. LayerNorm
        layers get their default init.
        """
        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                m.reset_parameters()

        # None = "reseed from the next batch" (see compute_intrinsic_reward).
        self.ema_mean = None
        self.ema_std = None
        self.accuracy_buffer.clear()

        self.fc.apply(init_weights)
        self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)

    def adaptive_update(self, adapt_rate=0.1):
        """Partially reset the network by blending weights toward a fresh init.

        adapt_rate: Float between 0 and 1, controls adaptation rate
            0 does nothing
            1 is almost full reset

        Each parameter becomes ``(1 - adapt_rate) * old + adapt_rate * fresh``.
        For Linear layers, fresh means Xavier-uniform weights and zero bias.
        For LayerNorm, fresh means gain 1 and bias 0, so a full blend does not
        zero the gains. This also lowers ``ema_alpha`` (floored at 0.9) when
        statistics exist, and rebuilds Adam with lr scaled by
        ``1 + adapt_rate``.
        """
        with torch.no_grad():
            for module in self.fc.modules():
                if isinstance(module, nn.Linear):
                    fresh = torch.empty_like(module.weight)
                    nn.init.xavier_uniform_(fresh)
                    module.weight.lerp_(fresh, adapt_rate)
                    if module.bias is not None:
                        module.bias.mul_(1 - adapt_rate)  # blend toward 0
                elif isinstance(module, nn.LayerNorm):
                    if module.weight is not None:
                        module.weight.lerp_(torch.ones_like(module.weight), adapt_rate)
                    if module.bias is not None:
                        module.bias.mul_(1 - adapt_rate)  # blend toward 0

        if self.ema_mean is not None:
            self.ema_alpha = max(0.9, self.ema_alpha * (1 - adapt_rate))

        self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate * (1 + adapt_rate))