import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from utils.diffusion import IMMDiffusion
from torch.optim.lr_scheduler import CosineAnnealingLR
from agents.basic_il import BaseImitationLearning

# Import SiLU compatibility fix
from utils.helpers import get_silu

# ============================================================================
# Q-Learning Integration: MMD Diffusion Q-Learning
# ============================================================================

class MMDCritic(nn.Module):
    """
    Twin Q-network critic.

    Holds two separately parameterised MLPs (``q1_model`` and ``q2_model``).
    Each one maps the concatenation of a state and an action along the last
    dimension to a single scalar per row.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=512):
        super().__init__()
        input_dim = state_dim + action_dim
        # Built in q1 -> q2 order so parameter initialisation consumes the
        # random number generator in the same sequence for both heads.
        self.q1_model = self._make_q_net(input_dim, hidden_dim)
        self.q2_model = self._make_q_net(input_dim, hidden_dim)

    @staticmethod
    def _make_q_net(input_dim, hidden_dim):
        """Build one Q head: three (Linear, activation) stages and a 1-unit output layer."""
        layers = []
        in_features = input_dim
        for _ in range(3):
            layers.append(nn.Linear(in_features, hidden_dim))
            layers.append(get_silu())
            in_features = hidden_dim
        layers.append(nn.Linear(hidden_dim, 1))
        return nn.Sequential(*layers)

    def forward(self, state, action):
        """Return the ``(q1, q2)`` estimates for the given state/action batch."""
        joint = torch.cat((state, action), dim=-1)
        q1_value = self.q1_model(joint)
        q2_value = self.q2_model(joint)
        return q1_value, q2_value

    def q1(self, state, action):
        """Return only the first head's estimate."""
        return self.q1_model(torch.cat((state, action), dim=-1))

    def q_min(self, state, action):
        """Return the element-wise minimum of the two heads' estimates."""
        return torch.min(*self.forward(state, action))


class MoMa_QL(BaseImitationLearning):
    """
    Q-learning algorithm based on MMD loss and diffusion models
    
    This algorithm combines:
    1. IMMPrecond: Conditional diffusion model as policy network
    2. SimpleMMDLoss: Distribution matching loss based on Maximum Mean Discrepancy
    3. Standard Q-learning framework: Dual Q-networks + Bellman updates
    
    Core idea:
    - Train diffusion policy network with MMD loss to match expert action distribution
    - Simultaneously optimize cumulative reward using Q-learning
    """
    
    def __init__(self,
                state_dim,
                action_dim,
                max_action,
                device="cuda" if torch.cuda.is_available() else "cpu",
                model="MLP",
                discount=0.99,
                tau=0.005,

                # Q-learning related parameters
                max_q_backup=False,
                eta=2.0,  # Q-learning weight
                # Candidate sampling settings
                backup_candidate_num=10,     # for max-q backup
                eval_candidate_num=10,       # for action selection during evaluation/inference
                action_select="softmax",      # "greedy" | "softmax"
                softmax_temperature=1.0,     # temperature for softmax selection
                softmax_topk=None,           # optionally apply top-k filter before softmax

                # Diffusion model related parameters
                noise_schedule="fm",
                sigma_data=0.5,
                f_type="euler_fm",
                T=0.994,
                eps=0.001,
                temb_type='identity',
                time_scale=1000.,

                # MMD loss related parameters
                mmd_sigma=1,
                sample_t_mode="lognormal",
                P_mean=-1.1,
                P_std=2.0,
                matrix_size=512,
                sample_repeat=1,
                k=12,
                a=2,
                b=4,
                min_tr_gap=None,

                # Training related parameters
                ema_decay=0.995,
                step_start_ema=1000,
                update_ema_every=5,
                lr=3e-4,
                lr_decay=False,
                lr_maxt=1000,
                grad_norm=1.0,
                q_norm=False,

                # Adam optimizer parameters
                adam_beta1=0.9,              # Adam beta1: exponential decay rate for 1st moment
                adam_beta2=0.999,            # Adam beta2: exponential decay rate for 2nd moment
                adam_eps=1e-8,               # Adam epsilon: numerical stability term
                adam_weight_decay=0.0,       # Adam weight decay (L2 regularization)
                adam_amsgrad=False,          # Use AMSGrad variant of Adam

                # Learning rate decay parameters
                lr_decay_steps=None,
                lr_min_factor=0.1,

                # CFG parameters
                cfg_scale=None,
                cfg_dropout_prob=0.1,

                **kwargs
                ):
        """
        Build the actor, its EMA copy, the twin critic and target, and their optimizers.

        The diffusion and MMD-loss arguments (``noise_schedule`` through
        ``min_tr_gap``, plus ``cfg_scale``) are forwarded unchanged to
        ``IMMDiffusion``. The ``adam_*`` arguments and ``lr`` configure both
        Adam optimizers identically. When ``lr_decay`` is truthy, a cosine
        annealing scheduler is attached to each optimizer with
        ``T_max = lr_decay_steps or lr_maxt`` and
        ``eta_min = lr * lr_min_factor``. Extra ``**kwargs`` are accepted and
        ignored.
        """
        # --- Core settings ---
        self.state_dim, self.action_dim = state_dim, action_dim
        self.max_action = max_action
        self.device = device
        self.discount, self.tau = discount, tau
        self.eta = eta
        self.max_q_backup = max_q_backup
        self.lr_decay = lr_decay
        self.grad_norm = grad_norm
        self.q_norm = q_norm

        # --- Candidate sampling (None falls back to the built-in default) ---
        self.backup_candidate_num = 10 if backup_candidate_num is None else int(backup_candidate_num)
        self.eval_candidate_num = 10 if eval_candidate_num is None else int(eval_candidate_num)
        self.action_select = action_select
        self.softmax_temperature = 1.0 if softmax_temperature is None else float(softmax_temperature)
        self.softmax_topk = None if softmax_topk is None else int(softmax_topk)

        # --- Optimizer settings ---
        self.adam_beta1, self.adam_beta2 = adam_beta1, adam_beta2
        self.adam_eps = adam_eps
        self.adam_weight_decay = adam_weight_decay
        self.adam_amsgrad = adam_amsgrad

        # --- Learning-rate schedule settings ---
        # A falsy lr_decay_steps (None or 0) falls back to lr_maxt.
        self.lr_decay_steps = lr_decay_steps or lr_maxt
        self.lr_min_factor = lr_min_factor

        # --- CFG settings ---
        self.cfg_scale = cfg_scale
        self.cfg_dropout_prob = cfg_dropout_prob

        # --- Diffusion policy network ---
        precond_kwargs = dict(
            noise_schedule=noise_schedule,
            sigma_data=sigma_data,
            f_type=f_type,
            T=T,
            eps=eps,
            temb_type=temb_type,
            time_scale=time_scale,
        )
        loss_kwargs = dict(
            mmd_sigma=mmd_sigma,
            sample_t_mode=sample_t_mode,
            P_mean=P_mean,
            P_std=P_std,
            matrix_size=matrix_size,
            sample_repeat=sample_repeat,
            k=k,
            a=a,
            b=b,
            min_tr_gap=min_tr_gap,
        )
        policy = IMMDiffusion(
            model,
            state_dim=state_dim,
            action_dim=action_dim,
            max_action=max_action,
            device=device,
            cfg_scale=cfg_scale,
            **precond_kwargs,
            **loss_kwargs,
        )
        self.actor = policy.to(device)

        # Both optimizers share the same Adam configuration.
        adam_options = dict(
            lr=lr,
            betas=(self.adam_beta1, self.adam_beta2),
            eps=self.adam_eps,
            weight_decay=self.adam_weight_decay,
            amsgrad=self.adam_amsgrad,
        )
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), **adam_options)

        # --- EMA copy of the actor ---
        self.step = 0
        self.step_start_ema = step_start_ema
        from utils.helpers import EMA
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.actor)
        self.update_ema_every = update_ema_every

        # --- Critic and its target copy ---
        self.critic = MMDCritic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), **adam_options)

        # --- Optional cosine learning-rate decay (same floor for both) ---
        if lr_decay:
            eta_min = lr * self.lr_min_factor
            self.actor_lr_scheduler = CosineAnnealingLR(
                self.actor_optimizer,
                T_max=self.lr_decay_steps,
                eta_min=eta_min,
            )
            self.critic_lr_scheduler = CosineAnnealingLR(
                self.critic_optimizer,
                T_max=self.lr_decay_steps,
                eta_min=eta_min,
            )

    def step_ema(self):
        """Update EMA model"""
        if self.step < self.step_start_ema:
            return
        self.ema.update_model_average(self.ema_model, self.actor)
    
    
    def train_with_balanced_sampling(self, offline_sampler, online_sampler, iterations, batch_size=256,
                                     balanced_ratio=0.5, log_writer=None, use_grad=False):
        """
        Run critic/actor updates on mini-batches mixed from two samplers.

        Each iteration draws ``batch_size - int(batch_size * balanced_ratio)``
        transitions from ``offline_sampler`` and
        ``int(batch_size * balanced_ratio)`` from ``online_sampler``,
        concatenates them (offline rows first) and performs one critic
        update, one actor update, a periodic EMA update and a soft update of
        the target critic. The learning-rate schedulers (when enabled) are
        advanced once per call, after all iterations.

        Args:
            offline_sampler: object whose ``sample(n)`` returns the tuple
                ``(state, action, next_state, reward, not_done)``.
            online_sampler: object with the same ``sample(n)`` interface.
            iterations: number of update iterations to run.
            batch_size: total number of transitions per iteration.
            balanced_ratio: share of the batch taken from the online sampler
                (0.0 = pure offline, 1.0 = pure online).
            log_writer: optional writer exposing
                ``add_scalar(tag, value, step)``.
            use_grad: kept for backward compatibility. The actions fed to the
                Q-loss are always sampled with gradients enabled, so this
                flag currently has no effect.

        Returns:
            dict with the per-iteration lists ``bc_loss``, ``ql_loss``,
            ``actor_loss`` and ``critic_loss``.

        Raises:
            ValueError: if neither sampler would contribute any transition.
        """
        metric = {'bc_loss': [], 'ql_loss': [], 'actor_loss': [], 'critic_loss': []}

        # Calculate batch sizes for each sampler
        online_batch_size = int(batch_size * balanced_ratio)
        offline_batch_size = batch_size - online_batch_size

        for _ in tqdm(range(iterations), desc="Online Training (Balanced Sampling)", leave=False):
            # Sample from every buffer that contributes to the batch (offline first)
            batches = []
            if offline_batch_size > 0:
                batches.append(offline_sampler.sample(offline_batch_size))
            if online_batch_size > 0:
                batches.append(online_sampler.sample(online_batch_size))
            if not batches:
                raise ValueError(
                    "batch_size and balanced_ratio leave no transitions to sample "
                    f"(offline={offline_batch_size}, online={online_batch_size})"
                )

            # Merge batches field by field
            if len(batches) == 1:
                state, action, next_state, reward, not_done = batches[0]
            else:
                state, action, next_state, reward, not_done = (
                    torch.cat(parts, dim=0) for parts in zip(*batches)
                )

            # The training logic below is the same as in the train method.
            """ Q-network training """
            current_q1, current_q2 = self.critic(state, action)

            # The bootstrap target is a constant for the critic loss, so no
            # autograd graph is needed while computing it.
            with torch.no_grad():
                if self.max_q_backup:
                    repeats = self.backup_candidate_num
                    # Use the real row count rather than the batch_size argument
                    # so the reshape stays valid for any sampler output size.
                    num_transitions = next_state.shape[0]
                    next_state_rpt = torch.repeat_interleave(next_state, repeats=repeats, dim=0)
                    next_action_rpt = self.ema_model.sample(next_state_rpt)
                    target_q1, target_q2 = self.critic_target(next_state_rpt, next_action_rpt)
                    target_q1 = target_q1.view(num_transitions, repeats).max(dim=1, keepdim=True)[0]
                    target_q2 = target_q2.view(num_transitions, repeats).max(dim=1, keepdim=True)[0]
                else:
                    next_action = self.ema_model.sample(next_state)
                    target_q1, target_q2 = self.critic_target(next_state, next_action)
                target_q = torch.min(target_q1, target_q2)
                target_q = reward + not_done * self.discount * target_q

                # The standard deviation of a single element is NaN, which would
                # poison the critic loss, so normalise only with >= 2 targets.
                if self.q_norm and target_q.numel() > 1:
                    target_q = (target_q - target_q.mean()) / (target_q.std() + 1e-6)

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if self.grad_norm > 0:
                critic_grad_norms = nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.grad_norm, norm_type=2)
            self.critic_optimizer.step()

            """ Policy network training """
            mmd_loss_value, _ = self.actor.loss(
                action, state, cfg_dropout_prob=self.cfg_dropout_prob
            )

            # Sampled with gradients enabled so the Q-loss can reach the actor.
            generated_actions = self.actor.sample(state)
            q1_new_action, q2_new_action = self.critic(state, generated_actions)

            # Randomly pick which head drives the loss; the other one only
            # supplies a detached scale.
            if np.random.uniform() > 0.5:
                q_loss = - q1_new_action.mean() / q2_new_action.abs().mean().detach()
            else:
                q_loss = - q2_new_action.mean() / q1_new_action.abs().mean().detach()

            actor_loss = mmd_loss_value + self.eta * q_loss

            self.actor_optimizer.zero_grad()
            actor_loss.backward()

            if self.grad_norm > 0:
                actor_grad_norms = nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.grad_norm, norm_type=2)
            self.actor_optimizer.step()

            """ Update target networks and EMA """
            if self.step % self.update_ema_every == 0:
                self.step_ema()

            # Soft update of the target critic
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            self.step += 1

            """ Record metrics """
            if log_writer is not None:
                if self.grad_norm > 0:
                    log_writer.add_scalar('Actor Grad Norm', actor_grad_norms.max().item(), self.step)
                    log_writer.add_scalar('Critic Grad Norm', critic_grad_norms.max().item(), self.step)
                log_writer.add_scalar('Actor Loss', actor_loss.item(), self.step)
                log_writer.add_scalar('MMD Loss', mmd_loss_value.item(), self.step)
                log_writer.add_scalar('QL Loss', q_loss.item(), self.step)
                log_writer.add_scalar('Critic Loss', critic_loss.item(), self.step)
                log_writer.add_scalar('Target_Q Mean', target_q.mean().item(), self.step)

            metric['actor_loss'].append(actor_loss.item())
            metric['bc_loss'].append(mmd_loss_value.item())
            metric['ql_loss'].append(q_loss.item())
            metric['critic_loss'].append(critic_loss.item())

        # Schedulers advance once per call, not once per iteration.
        if self.lr_decay:
            self.actor_lr_scheduler.step()
            self.critic_lr_scheduler.step()

        return metric
    
    def train(self, env, iterations, batch_size=1, log_writer=None, use_grad=False):
        """
        Train with direct environment interaction and no replay buffer.

        Each iteration collects ``batch_size`` fresh transitions by acting
        with ``sample_action`` in ``env``, then performs one critic update,
        one actor update, a periodic EMA update and a soft update of the
        target critic. The learning-rate schedulers (when enabled) are
        advanced once per call, after all iterations.

        Args:
            env: environment whose ``reset()`` returns a state and whose
                ``step(action)`` returns ``(next_state, reward, done, info)``.
            iterations: number of update iterations to run.
            batch_size: number of transitions collected per iteration.
            log_writer: optional writer exposing
                ``add_scalar(tag, value, step)``.
            use_grad: kept for backward compatibility. The actions fed to the
                Q-loss are always sampled with gradients enabled, so this
                flag currently has no effect.

        Returns:
            dict with the per-iteration lists ``bc_loss``, ``ql_loss``,
            ``actor_loss`` and ``critic_loss``.
        """
        metric = {'bc_loss': [], 'ql_loss': [], 'actor_loss': [], 'critic_loss': []}

        # Initialize environment state
        current_state = env.reset()

        for _ in tqdm(range(iterations), desc="Online Training", leave=False):
            # Collect a batch of transitions by interacting with the environment
            states, actions, next_states, rewards, dones = [], [], [], [], []

            for _ in range(batch_size):
                # Act with the current policy
                action = self.sample_action(current_state)
                next_state, reward, done, _ = env.step(action)

                states.append(current_state)
                actions.append(action)
                next_states.append(next_state)
                rewards.append(reward)
                dones.append(done)

                # Continue from the new state, restarting when the episode ended
                current_state = env.reset() if done else next_state

            # Convert to tensors
            state = torch.FloatTensor(np.array(states)).to(self.device)
            action = torch.FloatTensor(np.array(actions)).to(self.device)
            next_state = torch.FloatTensor(np.array(next_states)).to(self.device)
            reward = torch.FloatTensor(np.array(rewards)).reshape(-1, 1).to(self.device)
            not_done = torch.FloatTensor(1.0 - np.array(dones, dtype=np.float32)).reshape(-1, 1).to(self.device)

            """ Q-network training """
            current_q1, current_q2 = self.critic(state, action)

            # The bootstrap target is a constant for the critic loss, so no
            # autograd graph is needed while computing it.
            with torch.no_grad():
                if self.max_q_backup:
                    repeats = self.backup_candidate_num
                    # Use the real row count rather than the batch_size argument.
                    num_transitions = next_state.shape[0]
                    next_state_rpt = torch.repeat_interleave(next_state, repeats=repeats, dim=0)
                    next_action_rpt = self.ema_model.sample(next_state_rpt)
                    target_q1, target_q2 = self.critic_target(next_state_rpt, next_action_rpt)
                    target_q1 = target_q1.view(num_transitions, repeats).max(dim=1, keepdim=True)[0]
                    target_q2 = target_q2.view(num_transitions, repeats).max(dim=1, keepdim=True)[0]
                else:
                    next_action = self.ema_model.sample(next_state)
                    target_q1, target_q2 = self.critic_target(next_state, next_action)
                target_q = torch.min(target_q1, target_q2)
                target_q = reward + not_done * self.discount * target_q

                # The standard deviation of a single element is NaN (the case
                # for the default batch_size=1), which would poison the critic
                # loss, so normalise only with >= 2 targets.
                if self.q_norm and target_q.numel() > 1:
                    target_q = (target_q - target_q.mean()) / (target_q.std() + 1e-6)

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if self.grad_norm > 0:
                critic_grad_norms = nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.grad_norm, norm_type=2)
            self.critic_optimizer.step()

            """ Policy network training """
            # Distribution-matching loss on the collected actions
            # (cfg_dropout_prob is passed through for CFG training)
            mmd_loss_value, _ = self.actor.loss(
                action, state, cfg_dropout_prob=self.cfg_dropout_prob
            )

            # Sampled with gradients enabled so the Q-loss can reach the actor.
            generated_actions = self.actor.sample(state)

            # Q-learning loss: encourage generating high Q-value actions
            q1_new_action, q2_new_action = self.critic(state, generated_actions)

            # Randomly pick which head drives the loss; the other one only
            # supplies a detached scale.
            if np.random.uniform() > 0.5:
                q_loss = - q1_new_action.mean() / q2_new_action.abs().mean().detach()
            else:
                q_loss = - q2_new_action.mean() / q1_new_action.abs().mean().detach()

            # Total policy loss
            actor_loss = mmd_loss_value + self.eta * q_loss

            # Update policy network
            self.actor_optimizer.zero_grad()
            actor_loss.backward()

            if self.grad_norm > 0:
                actor_grad_norms = nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.grad_norm, norm_type=2)
            self.actor_optimizer.step()

            """ Update target networks and EMA """
            if self.step % self.update_ema_every == 0:
                self.step_ema()

            # Soft update of the target critic
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            self.step += 1

            """ Record metrics """
            if log_writer is not None:
                if self.grad_norm > 0:
                    log_writer.add_scalar('Actor Grad Norm', actor_grad_norms.max().item(), self.step)
                    log_writer.add_scalar('Critic Grad Norm', critic_grad_norms.max().item(), self.step)
                log_writer.add_scalar('Actor Loss', actor_loss.item(), self.step)
                log_writer.add_scalar('MMD Loss', mmd_loss_value.item(), self.step)
                log_writer.add_scalar('QL Loss', q_loss.item(), self.step)
                log_writer.add_scalar('Critic Loss', critic_loss.item(), self.step)
                log_writer.add_scalar('Target_Q Mean', target_q.mean().item(), self.step)

            metric['actor_loss'].append(actor_loss.item())
            metric['bc_loss'].append(mmd_loss_value.item())
            metric['ql_loss'].append(q_loss.item())
            metric['critic_loss'].append(critic_loss.item())

        # Schedulers advance once per call, not once per iteration.
        if self.lr_decay:
            self.actor_lr_scheduler.step()
            self.critic_lr_scheduler.step()

        return metric
    
    def sample_action(self, state, cfg_scale=None):
        """
        Sample action for environment interaction with optional CFG
        Use Q-value weighted sampling to select best action
        
        Args:
            state: input state
            cfg_scale: CFG guidance scale (overrides self.cfg_scale if provided)
        
        Returns:
            sampled action
        """
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device).view(1, -1)
        repeats = max(1, int(self.eval_candidate_num))

        # Sample multiple candidate actions
        state_rpt = torch.repeat_interleave(state, repeats=repeats, dim=0)

        # Determine guidance scale
        guidance_scale = cfg_scale if cfg_scale is not None else None

        with torch.no_grad():
            # Sample with CFG if enabled via guidance scale
            if (guidance_scale is not None) and (guidance_scale > 0):
                action_candidates = self.actor.sample(state_rpt, cfg_scale=guidance_scale)
            else:
                action_candidates = self.actor.sample(state_rpt)

            # Clamp actions to valid range
            action_candidates = torch.clamp(action_candidates, -self.max_action, self.max_action)

            # Compute Q-value for each candidate action
            q_values = self.critic_target.q_min(state_rpt, action_candidates).flatten()

            # Handle non-finite values
            if not torch.isfinite(q_values).all():
                finite_mask = torch.isfinite(q_values)
                if finite_mask.any():
                    q_values = q_values.where(finite_mask, q_values[finite_mask].min())
                else:
                    # fallback to uniform if all invalid
                    probs = torch.ones(repeats, device=self.device) / repeats
                    idx = torch.multinomial(probs, 1)
                    return action_candidates[idx].cpu().data.numpy().flatten()

            if self.action_select == "softmax":
                # optional top-k before softmax
                if self.softmax_topk is not None and self.softmax_topk < repeats:
                    k = max(1, int(self.softmax_topk))
                    topk_vals, topk_idx = torch.topk(q_values, k=k, largest=True, sorted=False)
                    # safe softmax with temperature
                    temp = max(1e-6, float(self.softmax_temperature))
                    logits = (topk_vals - topk_vals.max()) / temp
                    probs = F.softmax(logits, dim=0)
                    pick_rel = torch.multinomial(probs, 1)
                    idx = topk_idx[pick_rel]
                else:
                    temp = max(1e-6, float(self.softmax_temperature))
                    logits = (q_values - q_values.max()) / temp
                    probs = F.softmax(logits, dim=0)
                    idx = torch.multinomial(probs, 1)
            else:
                # greedy selection by default
                idx = torch.argmax(q_values, dim=0).view(1)

        return action_candidates[idx].cpu().data.numpy().flatten()
    
    def save_model(self, dir, id=None):
        """Save model"""
        import os
        os.makedirs(dir, exist_ok=True)
        if id is not None:
            torch.save(self.actor.state_dict(), f'{dir}/actor_{id}.pth')
            torch.save(self.critic.state_dict(), f'{dir}/critic_{id}.pth')
        else:
            torch.save(self.actor.state_dict(), f'{dir}/actor.pth')
            torch.save(self.critic.state_dict(), f'{dir}/critic.pth')
    
    def load_model(self, dir, id=None):
        """Load model"""
        if id is not None:
            self.actor.load_state_dict(torch.load(f'{dir}/actor_{id}.pth', weights_only=True))
            self.ema_model.load_state_dict(torch.load(f'{dir}/actor_{id}.pth', weights_only=True))
            self.critic.load_state_dict(torch.load(f'{dir}/critic_{id}.pth', weights_only=True))
        else:
            self.actor.load_state_dict(torch.load(f'{dir}/actor.pth', weights_only=True))
            self.ema_model.load_state_dict(torch.load(f'{dir}/actor.pth', weights_only=True))
            self.critic.load_state_dict(torch.load(f'{dir}/critic.pth', weights_only=True))
