import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

from agents.basic_il import BaseImitationLearning


class VectorFieldNetwork(nn.Module):
    """
    Time-conditioned MLP that predicts the flow velocity v(x_t, t, s).

    The scalar time t is first lifted to an embedding by a small MLP; that
    embedding is concatenated with the current flow sample and the state and
    passed through a Mish-activated MLP whose output has the action dimension.
    """

    def __init__(self, state_dim, action_dim, time_dim=32, hidden_dim=256):
        """
        Build the time-embedding MLP and the main velocity MLP.

        Args:
            state_dim (int): Dimension of state space
            action_dim (int): Dimension of action space
            time_dim (int): Width of the time embedding
            hidden_dim (int): Width of every hidden layer of the main MLP
        """
        super().__init__()

        # Two Linear+Mish stages turn t (one scalar per row) into a
        # time_dim-wide embedding.
        time_layers = [
            nn.Linear(1, time_dim),
            nn.Mish(),
            nn.Linear(time_dim, time_dim),
            nn.Mish(),
        ]
        self.time_mlp = nn.Sequential(*time_layers)

        # Main MLP consumes [action, state, time_embedding]: three hidden
        # Linear+Mish stages followed by a linear projection to action_dim.
        widths = [action_dim + state_dim + time_dim, hidden_dim, hidden_dim, hidden_dim]
        trunk = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            trunk.append(nn.Linear(fan_in, fan_out))
            trunk.append(nn.Mish())
        trunk.append(nn.Linear(hidden_dim, action_dim))
        self.mlp = nn.Sequential(*trunk)

    def forward(self, x_t, t, state):
        """
        Predict the velocity for flow sample x_t at time t, conditioned on state.

        Args:
            x_t (torch.Tensor): Current flow state [batch_size, action_dim]
            t (torch.Tensor): Time parameter [batch_size, 1]
            state (torch.Tensor): Environment state [batch_size, state_dim]

        Returns:
            torch.Tensor: Predicted velocity field [batch_size, action_dim]
        """
        # Concatenation order (action, state, time embedding) must match the
        # input width computed in __init__.
        features = torch.cat([x_t, state, self.time_mlp(t)], dim=-1)
        return self.mlp(features)


class OneStepFlowNetwork(nn.Module):
    """
    One-step Flow network for fast sampling.

    Maps a noise vector and a state directly to an action with a single MLP
    forward pass, instead of integrating a velocity field over several steps.
    The final Tanh bounds every output component to [-1, 1].
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        """
        Initialize the one-step Flow network.

        Args:
            state_dim (int): Dimension of state space
            action_dim (int): Dimension of action space
            hidden_dim (int): Dimension of hidden layers
        """
        super().__init__()

        # Input is [noise, state]; the noise has the same width as an action.
        input_dim = action_dim + state_dim

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),  # squashes each action component into [-1, 1]
        )

    def forward(self, noise, state):
        """
        Forward pass: directly predict action from noise and state.

        Args:
            noise (torch.Tensor): Random noise [batch_size, action_dim]
            state (torch.Tensor): Environment state [batch_size, state_dim]

        Returns:
            torch.Tensor: Predicted action [batch_size, action_dim], each
                component in [-1, 1]
        """
        # Concatenation order (noise, state) must match input_dim in __init__.
        x_input = torch.cat([noise, state], dim=-1)
        return self.mlp(x_input)


class CriticNetwork(nn.Module):
    """
    Twin Q-value estimator.

    Holds two independently initialised MLPs with identical architecture
    (q1 and q2). Both receive the concatenated (state, action) vector and each
    emits one scalar Q-value per row.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        """
        Build both Q heads.

        Args:
            state_dim (int): Dimension of state space
            action_dim (int): Dimension of action space
            hidden_dim (int): Dimension of hidden layers
        """
        super().__init__()

        def build_q_head():
            # Three hidden Linear+Mish stages, then a scalar output layer.
            return nn.Sequential(
                nn.Linear(state_dim + action_dim, hidden_dim),
                nn.Mish(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.Mish(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.Mish(),
                nn.Linear(hidden_dim, 1),
            )

        # q1 is created before q2 so parameter initialisation order is stable.
        self.q1 = build_q_head()
        self.q2 = build_q_head()

    def forward(self, state, action):
        """
        Evaluate both Q heads on the given state-action pairs.

        Args:
            state (torch.Tensor): Environment state [batch_size, state_dim]
            action (torch.Tensor): Action [batch_size, action_dim]

        Returns:
            tuple: (q1, q2) - outputs from both Q networks, each [batch_size, 1]
        """
        state_action = torch.cat([state, action], dim=-1)
        return self.q1(state_action), self.q2(state_action)


class Flow_QL(BaseImitationLearning):
    """
    Flow Q-Learning (FQL) algorithm implementation.
    
    FQL combines Flow-based behavioral cloning, knowledge distillation, and Q-learning, including:
    1. BC Flow: Learn continuous transformations from noise to expert actions
    2. One-step Flow: Learn fast sampling through knowledge distillation
    3. Critic: Q-learning provides value guidance
    4. Joint optimization of three loss functions
    
    Key advantages:
    - Multi-modal behavior modeling (BC Flow)
    - Fast inference sampling (One-step Flow)
    - Value function guidance (Critic)
    - Robust training process
    """
    
    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 flow_steps=10,
                 lr=3e-4,
                 lr_decay=False,
                 lr_maxt=1000,
                 grad_norm=1.0,
                 hidden_dim=256,
                 time_dim=32,
                 discount=0.99,
                 tau=0.005,
                 alpha=10.0,
                 q_agg='mean',
                 normalize_q_loss=False,
                 **kwargs
                ):
        """
        Build the networks, optimizers and optional schedulers of the agent.

        Args:
            state_dim (int): Dimension of state space
            action_dim (int): Dimension of action space
            max_action (float): Maximum action value
            device (torch.device): Compute device
            flow_steps (int): Number of Euler steps used when integrating BC Flow
            lr (float): Learning rate shared by all three Adam optimizers
            lr_decay (bool): Whether to create cosine-annealing schedulers
            lr_maxt (int): T_max passed to each cosine-annealing scheduler
            grad_norm (float): Gradient clipping threshold (clipping is applied
                during training only when this is > 0)
            hidden_dim (int): Dimension of hidden layers
            time_dim (int): Dimension of time embedding
            discount (float): Discount factor
            tau (float): Soft update coefficient for target network
            alpha (float): Weight of the distillation term in the actor loss
            q_agg (str): Target Q aggregation; 'min' takes the minimum of the
                two heads, any other value averages them
            normalize_q_loss (bool): Whether to normalize Q loss
            **kwargs: Accepted and ignored.
        """
        # NOTE(review): other methods read self.device, self.max_action,
        # self.action_dim and self.step, none of which are assigned here —
        # presumably the base initializer provides them; confirm.
        super().__init__(state_dim, action_dim, max_action, device, lr)

        # --- Hyper-parameters ---
        self.flow_steps = flow_steps
        self.discount = discount
        self.tau = tau
        self.alpha = alpha
        self.q_agg = q_agg
        self.normalize_q_loss = normalize_q_loss
        self.grad_norm = grad_norm
        self.lr_decay = lr_decay

        # --- Networks (created in a fixed order: flow, one-step, critic) ---
        self.bc_flow = VectorFieldNetwork(
            state_dim=state_dim,
            action_dim=action_dim,
            time_dim=time_dim,
            hidden_dim=hidden_dim,
        ).to(device)
        self.one_step_flow = OneStepFlowNetwork(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dim=hidden_dim,
        ).to(device)
        self.critic = CriticNetwork(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dim=hidden_dim,
        ).to(device)

        # The target critic starts as an exact copy of the critic and is never
        # trained by gradient descent, so its parameters are frozen.
        self.target_critic = copy.deepcopy(self.critic).to(device)
        self.target_critic.requires_grad_(False)

        # --- One Adam optimizer per trainable network ---
        self.bc_flow_optimizer = torch.optim.Adam(self.bc_flow.parameters(), lr=lr)
        self.one_step_flow_optimizer = torch.optim.Adam(self.one_step_flow.parameters(), lr=lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        # --- Optional cosine annealing down to a learning rate of zero ---
        if lr_decay:
            def cosine_schedule(optimizer):
                return CosineAnnealingLR(optimizer, T_max=lr_maxt, eta_min=0.)

            self.bc_flow_scheduler = cosine_schedule(self.bc_flow_optimizer)
            self.one_step_flow_scheduler = cosine_schedule(self.one_step_flow_optimizer)
            self.critic_scheduler = cosine_schedule(self.critic_optimizer)
    
    def critic_loss(self, batch):
        """
        Compute the TD loss for both critic heads.

        The bootstrap action for next_state is drawn from the one-step flow
        (fresh Gaussian noise, clamped to the action bounds) and evaluated by
        the target critic; everything that forms the target is computed
        without gradients.

        Args:
            batch: Batch data (state, action, reward, next_state, done)

        Returns:
            tuple: (critic_loss, critic_info) - the loss tensor and a dict of
                Python floats for logging
        """
        state, action, reward, next_state, done = batch

        with torch.no_grad():
            # Policy action at the next state.
            next_action = self.one_step_flow(torch.randn_like(action), next_state)
            next_action = torch.clamp(next_action, -self.max_action, self.max_action)

            target_q1, target_q2 = self.target_critic(next_state, next_action)

            # 'min' is the pessimistic estimate; anything else averages.
            if self.q_agg == 'min':
                bootstrap_q = torch.min(target_q1, target_q2)
            else:
                bootstrap_q = (target_q1 + target_q2) / 2.0

            # NOTE(review): assumes reward and done broadcast against the
            # critic output without adding an axis — confirm buffer shapes.
            td_target = reward + self.discount * (1 - done) * bootstrap_q

        pred_q1, pred_q2 = self.critic(state, action)

        # Both heads regress onto the same target; the losses are summed.
        loss = F.mse_loss(pred_q1, td_target) + F.mse_loss(pred_q2, td_target)

        info = {
            'critic_loss': loss.item(),
            'q1_mean': pred_q1.mean().item(),
            'q2_mean': pred_q2.mean().item(),
            'target_q_mean': td_target.mean().item(),
        }
        return loss, info
    
    def actor_loss(self, batch):
        """
        Compute the combined actor objective.

        The returned loss is
        ``bc_flow_loss + alpha * distill_loss + q_loss`` where
        - bc_flow_loss regresses the velocity network onto (x_1 - x_0) at a
          random point of the straight line between noise x_0 and the batch
          action x_1,
        - distill_loss pulls the one-step network towards the action obtained
          by integrating the velocity network from the same noise,
        - q_loss is the negated mean of the averaged critic heads on the
          clamped one-step action (optionally rescaled).

        Args:
            batch: Batch data (state, action, reward, next_state, done)

        Returns:
            tuple: (actor_loss, actor_info) - the loss tensor and a dict of
                Python floats for logging
        """
        state, action, _reward, _next_state, _done = batch
        n_rows = state.shape[0]

        # --- 1. Flow-matching loss for the velocity network ---
        x_0 = torch.randn_like(action)
        x_1 = action
        t = torch.rand(n_rows, 1, device=self.device)
        x_t = (1 - t) * x_0 + t * x_1  # point on the straight path at time t
        velocity_pred = self.bc_flow(x_t, t, state)
        bc_flow_loss = F.mse_loss(velocity_pred, x_1 - x_0)

        # --- 2. Distillation: one-step output vs. integrated flow output ---
        noise = torch.randn_like(action)
        with torch.no_grad():
            # The integrated flow acts as a fixed teacher target.
            teacher_actions = self.compute_flow_actions(state, noise)
        one_step_actions = self.one_step_flow(noise, state)
        distill_loss = F.mse_loss(one_step_actions, teacher_actions)

        # --- 3. Value term: maximise the mean of both critic heads ---
        bounded_actions = torch.clamp(one_step_actions, -self.max_action, self.max_action)
        q1, q2 = self.critic(state, bounded_actions)
        q = (q1 + q2) / 2.0
        q_loss = -q.mean()

        if self.normalize_q_loss:
            # Rescale by the detached mean |Q|; the epsilon guards against
            # division by zero.
            with torch.no_grad():
                scale = 1.0 / (torch.abs(q).mean() + 1e-8)
            q_loss = scale * q_loss

        actor_loss = bc_flow_loss + self.alpha * distill_loss + q_loss

        # Logging-only metric: distance of a fresh one-step sample to the
        # batch action.
        with torch.no_grad():
            fresh_sample = self.one_step_flow(torch.randn_like(action), state)
            mse_with_expert = F.mse_loss(fresh_sample, action)

        actor_info = {
            'actor_loss': actor_loss.item(),
            'bc_flow_loss': bc_flow_loss.item(),
            'distill_loss': distill_loss.item(),
            'q_loss': q_loss.item(),
            'q_mean': q.mean().item(),
            'mse_with_expert': mse_with_expert.item(),
        }
        return actor_loss, actor_info
    
    def train(self, replay_buffer, iterations, batch_size=100, log_writer=None):
        """
        Run ``iterations`` gradient updates of the critic and the actor.

        Each iteration samples one batch, updates the critic, then updates
        both flow networks from the combined actor loss, soft-updates the
        target critic and increments ``self.step``. When ``lr_decay`` is
        enabled the schedulers are stepped once per call, after the loop.

        Args:
            replay_buffer: Object exposing ``sample(batch_size)``
            iterations (int): Number of training iterations
            batch_size (int): Batch size
            log_writer: Optional writer exposing ``add_scalar(tag, value, step)``

        Returns:
            dict: Per-iteration lists of training metrics
        """
        actor_metric_keys = (
            'actor_loss', 'bc_flow_loss', 'distill_loss', 'q_loss', 'mse_with_expert',
        )
        metrics = {'critic_loss': []}
        for key in actor_metric_keys:
            metrics[key] = []

        actor_networks = (self.bc_flow, self.one_step_flow)
        actor_optimizers = (self.bc_flow_optimizer, self.one_step_flow_optimizer)

        for _ in range(iterations):
            batch = replay_buffer.sample(batch_size)

            # --- Critic update ---
            critic_loss, critic_info = self.critic_loss(batch)
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if self.grad_norm > 0:
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.grad_norm)
            self.critic_optimizer.step()

            # --- Actor update: one backward pass feeds both flow optimizers ---
            actor_loss, actor_info = self.actor_loss(batch)
            for optimizer in actor_optimizers:
                optimizer.zero_grad()
            actor_loss.backward()
            if self.grad_norm > 0:
                # Each network is clipped against its own gradient norm.
                for network in actor_networks:
                    nn.utils.clip_grad_norm_(network.parameters(), self.grad_norm)
            for optimizer in actor_optimizers:
                optimizer.step()

            # --- Target critic tracks the freshly updated critic ---
            self.soft_update_target()

            self.step += 1

            # --- Bookkeeping ---
            metrics['critic_loss'].append(critic_info['critic_loss'])
            for key in actor_metric_keys:
                metrics[key].append(actor_info[key])

            if log_writer is not None:
                for key, value in critic_info.items():
                    log_writer.add_scalar(f'Critic/{key}', value, self.step)
                for key, value in actor_info.items():
                    log_writer.add_scalar(f'Actor/{key}', value, self.step)

        # Schedulers advance once per train() call, not once per iteration.
        if self.lr_decay:
            for scheduler in (self.bc_flow_scheduler,
                              self.one_step_flow_scheduler,
                              self.critic_scheduler):
                scheduler.step()

        return metrics
    
    def soft_update_target(self):
        """
        Move the target critic towards the critic by Polyak averaging.

        Every target parameter becomes
        ``tau * critic_param + (1 - tau) * target_param``, written in place.
        """
        tau = self.tau
        paired_params = zip(self.target_critic.parameters(), self.critic.parameters())
        for target_param, online_param in paired_params:
            blended = tau * online_param.data + (1 - tau) * target_param.data
            target_param.data.copy_(blended)
    
    def compute_flow_actions(self, state, noise):
        """
        Integrate the BC Flow velocity field from noise to an action.

        Uses ``self.flow_steps`` explicit Euler steps of equal size over
        t in [0, 1), evaluating the velocity at the start of each step, and
        clamps the final result to the action bounds. The input ``noise``
        tensor is not modified.

        Args:
            state (torch.Tensor): Environment state [batch_size, state_dim]
            noise (torch.Tensor): Initial noise [batch_size, action_dim]

        Returns:
            torch.Tensor: Actions computed through flow [batch_size, action_dim]
        """
        n_steps = self.flow_steps
        dt = 1.0 / n_steps
        n_rows = state.shape[0]

        # Work on a copy so the caller's noise tensor stays untouched.
        x = noise.clone()
        for k in range(n_steps):
            # Same time value for every row of the batch.
            t_k = torch.full((n_rows, 1), k * dt, device=self.device)
            x = x + self.bc_flow(x, t_k, state) * dt

        return torch.clamp(x, -self.max_action, self.max_action)
    
    def sample_action(self, state):
        """
        Sample actions using One-step Flow for fast inference.

        A single forward pass of the one-step network maps fresh Gaussian
        noise and the state to an action, which is then clamped to the action
        bounds.

        Args:
            state (np.ndarray or torch.Tensor): Current environment state.
                A NumPy array is treated as ONE state and reshaped to a single
                row. A tensor (or other array-like) is moved to the agent's
                device as float32; a 1-D input gets a leading batch axis,
                a 2-D input is used as a batch as-is.

        Returns:
            np.ndarray: Sampled action(s), flattened to 1-D
        """
        if isinstance(state, np.ndarray):
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        else:
            # Previously tensors were used untouched, so a single 1-D state or
            # a tensor on another device / dtype crashed inside the network.
            state = torch.as_tensor(state, dtype=torch.float32, device=self.device)
            if state.dim() == 1:
                state = state.unsqueeze(0)

        with torch.no_grad():
            batch_size = state.shape[0]

            # One noise vector per state row.
            noise = torch.randn(batch_size, self.action_dim, device=self.device)

            # One-step Flow directly predicts action
            action = self.one_step_flow(noise, state)

            # Clip to valid action range
            action = torch.clamp(action, -self.max_action, self.max_action)

        return action.cpu().numpy().flatten()
    
    def save_model(self, dir, id=None):
        """
        Save all FQL model components to disk.

        Writes the state dicts of the BC Flow network, the one-step network,
        the critic and the target critic to
        ``{dir}/fql_<component>{_id}.pth``. The directory is created first if
        it does not exist. Optimizer and scheduler state are not saved.

        Args:
            dir (str): Save directory
            id (str, optional): Model identifier appended to each file name
        """
        suffix = f'_{id}' if id is not None else ''

        # torch.save does not create missing directories.
        if dir:
            os.makedirs(dir, exist_ok=True)

        components = {
            'bc_flow': self.bc_flow,
            'one_step_flow': self.one_step_flow,
            'critic': self.critic,
            'target_critic': self.target_critic,
        }
        for name, module in components.items():
            torch.save(module.state_dict(), f'{dir}/fql_{name}{suffix}.pth')
    
    def load_model(self, dir, id=None, map_location=None):
        """
        Load pre-trained FQL model from disk.

        Reads the four state dicts written by ``save_model`` from
        ``{dir}/fql_<component>{_id}.pth`` and loads them into the BC Flow
        network, the one-step network, the critic and the target critic.

        Args:
            dir (str): Model file directory
            id (str, optional): Model identifier
            map_location: Device mapping for cross-device loading. When left
                as None the agent's own device is used, so a checkpoint saved
                on a GPU can be loaded on a CPU-only machine.
        """
        suffix = f'_{id}' if id is not None else ''

        # With map_location=None torch.load restores tensors onto the device
        # they were saved from, which fails if that device is unavailable.
        if map_location is None:
            map_location = self.device

        components = {
            'bc_flow': self.bc_flow,
            'one_step_flow': self.one_step_flow,
            'critic': self.critic,
            'target_critic': self.target_critic,
        }
        for name, module in components.items():
            module.load_state_dict(
                torch.load(f'{dir}/fql_{name}{suffix}.pth', map_location=map_location)
            )
