# Implementation of Flow Q-Learning (FQL)
# https://arxiv.org/abs/2502.02538

import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils.logger import logger

from agents.flow import FlowMatching
from agents.model import MLP
from agents.helpers import EMA


class OneStepPolicy(nn.Module):
    """
    One-step policy π_ω(s, z): maps Gaussian noise to an action in a single forward pass.

    It is trained (in Flow_QL) to match the multi-step flow policy μ_θ on the same noise
    while also maximising the critic.  Architecture: four GELU hidden layers of width
    ``hidden_dim`` followed by a linear head with ``action_dim`` outputs.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=512):
        super().__init__()

        # Input layer consumes the concatenation [state, z]; z has the same size as an action.
        layers = [nn.Linear(state_dim + action_dim, hidden_dim), nn.GELU()]
        # Three further hidden layers (four in total, as in official FQL).
        for _ in range(3):
            layers.extend((nn.Linear(hidden_dim, hidden_dim), nn.GELU()))
        # Linear output head, no squashing; callers clamp to the action range.
        layers.append(nn.Linear(hidden_dim, action_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, state, z):
        """
        Args:
            state: (batch_size, state_dim)
            z: (batch_size, action_dim) - noise from N(0, I)
        Returns:
            action: (batch_size, action_dim), unclipped
        """
        joint = torch.cat((state, z), dim=-1)
        return self.net(joint)


class Critic(nn.Module):
    """
    Twin Q-network Q_φ(s, a) made of two independent MLP heads.

    Each head has four hidden layers of width ``hidden_dim`` with GELU activations,
    optionally preceded by LayerNorm (as in official FQL), and a scalar output.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=512, layer_norm=True):
        super().__init__()
        in_dim = state_dim + action_dim
        # q1 is built before q2 so parameter initialisation consumes the RNG in the same order.
        self.q1_model = self._make_q_net(in_dim, hidden_dim, layer_norm)
        self.q2_model = self._make_q_net(in_dim, hidden_dim, layer_norm)

    @staticmethod
    def _make_q_net(in_dim, hidden_dim, layer_norm):
        """Build one Q head: 4 x [Linear, (LayerNorm), GELU] followed by Linear(hidden_dim, 1)."""
        layers = []
        width = in_dim
        for _ in range(4):
            layers.append(nn.Linear(width, hidden_dim))
            if layer_norm:
                layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.GELU())
            width = hidden_dim
        layers.append(nn.Linear(hidden_dim, 1))
        return nn.Sequential(*layers)

    def forward(self, state, action):
        """Return ``(Q1(s, a), Q2(s, a))``, each of shape (batch, 1)."""
        sa = torch.cat([state, action], dim=-1)
        return self.q1_model(sa), self.q2_model(sa)

    def q1(self, state, action):
        """Return only the first head's estimate Q1(s, a)."""
        return self.q1_model(torch.cat([state, action], dim=-1))

    def q_min(self, state, action):
        """Return the element-wise minimum of the two heads (clipped double-Q)."""
        return torch.min(*self.forward(state, action))


class Flow_QL(object):
    """
    Flow Q-Learning (FQL) agent.

    Maintains three learned components:
      * ``flow_policy``     - BC flow-matching policy μ_θ trained only on dataset actions,
      * ``one_step_policy`` - one-step policy π_ω distilled from μ_θ and pushed toward high Q,
      * ``critic``          - twin Q-network Q_φ with a Polyak-averaged ``critic_target``.
    Evaluation-time actions come from the one-step policy only.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 max_q_backup=False,
                 eta=1.0,              # Weight for Q-learning loss in one-step policy
                 alpha=1.0,            # Weight for distillation loss (tuned per env in main_fql.py)
                 normalize_q_loss=False,  # Whether to normalize the Q loss
                 layer_norm=True,      # Whether to use LayerNorm in critic
                 n_timesteps=10,       # Number of Euler steps for flow sampling
                 ema_decay=0.995,
                 step_start_ema=1000,
                 update_ema_every=5,
                 lr=3e-4,
                 lr_decay=False,
                 lr_maxt=1000,
                 grad_norm=1.0,
                 ):

        # Flow model (BC flow policy μ_θ)
        self.flow_model = MLP(state_dim=state_dim, action_dim=action_dim, device=device)
        self.flow_policy = FlowMatching(
            state_dim=state_dim,
            action_dim=action_dim,
            model=self.flow_model,
            max_action=max_action,
            n_timesteps=n_timesteps
        ).to(device)
        self.flow_optimizer = torch.optim.Adam(self.flow_policy.parameters(), lr=lr)

        # One-step policy π_ω (distilled from flow)
        self.one_step_policy = OneStepPolicy(state_dim, action_dim).to(device)
        self.one_step_optimizer = torch.optim.Adam(self.one_step_policy.parameters(), lr=lr)

        # Critic Q_φ (with LayerNorm as in official FQL)
        self.critic = Critic(state_dim, action_dim, layer_norm=layer_norm).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        # EMA copy of the flow policy
        self.step = 0
        self.step_start_ema = step_start_ema
        self.ema = EMA(ema_decay)
        self.ema_flow_policy = copy.deepcopy(self.flow_policy)
        self.update_ema_every = update_ema_every

        # Learning rate schedulers (stepped once per train() call, see end of train())
        self.lr_decay = lr_decay
        if lr_decay:
            self.flow_lr_scheduler = CosineAnnealingLR(self.flow_optimizer, T_max=lr_maxt, eta_min=0.)
            self.one_step_lr_scheduler = CosineAnnealingLR(self.one_step_optimizer, T_max=lr_maxt, eta_min=0.)
            self.critic_lr_scheduler = CosineAnnealingLR(self.critic_optimizer, T_max=lr_maxt, eta_min=0.)

        self.grad_norm = grad_norm
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.eta = eta
        self.alpha = alpha
        self.normalize_q_loss = normalize_q_loss
        self.layer_norm = layer_norm
        self.device = device
        self.max_q_backup = max_q_backup
        self.n_timesteps = n_timesteps

    def step_ema(self):
        """Update the EMA flow policy, but only once ``step_start_ema`` steps have elapsed."""
        if self.step < self.step_start_ema:
            return
        self.ema.update_model_average(self.ema_flow_policy, self.flow_policy)

    # ------------------------------------------------------------------ training helpers

    def _optimize(self, optimizer, loss, module):
        """Zero grads, backpropagate ``loss``, optionally clip ``module``'s grads, and step."""
        optimizer.zero_grad()
        loss.backward()
        if self.grad_norm > 0:
            nn.utils.clip_grad_norm_(module.parameters(), max_norm=self.grad_norm)
        optimizer.step()

    @torch.no_grad()
    def _compute_target_q(self, next_state, reward, not_done):
        """
        Compute the TD target ``r + not_done * γ * min(Q1', Q2')`` without gradients.

        Next actions are drawn from the one-step policy and evaluated by the target critic.
        With ``max_q_backup``, 10 candidate next actions are drawn per state; each Q head
        takes its max over the candidates before the two heads are combined by a min.
        """
        n = next_state.shape[0]
        if self.max_q_backup:
            n_backup = 10
            next_state_rpt = torch.repeat_interleave(next_state, repeats=n_backup, dim=0)
            z_rpt = torch.randn(next_state_rpt.shape[0], self.action_dim, device=self.device)
            next_action_rpt = self.one_step_policy(next_state_rpt, z_rpt)
            next_action_rpt = next_action_rpt.clamp(-self.max_action, self.max_action)
            target_q1, target_q2 = self.critic_target(next_state_rpt, next_action_rpt)
            target_q1 = target_q1.view(n, n_backup).max(dim=1, keepdim=True)[0]
            target_q2 = target_q2.view(n, n_backup).max(dim=1, keepdim=True)[0]
        else:
            z = torch.randn(n, self.action_dim, device=self.device)
            next_action = self.one_step_policy(next_state, z)
            next_action = next_action.clamp(-self.max_action, self.max_action)
            target_q1, target_q2 = self.critic_target(next_state, next_action)
        target_q = torch.min(target_q1, target_q2)
        return reward + not_done * self.discount * target_q

    def _update_critic(self, state, action, next_state, reward, not_done):
        """Take one gradient step on the critic and return its loss tensor."""
        current_q1, current_q2 = self.critic(state, action)
        target_q = self._compute_target_q(next_state, reward, not_done)
        # Average the two Q losses (official FQL computes mean over ensemble)
        critic_loss = (F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)) / 2
        self._optimize(self.critic_optimizer, critic_loss, self.critic)
        return critic_loss

    def _update_flow(self, state, action):
        """Take one behaviour-cloning step on the flow policy and return its loss tensor."""
        flow_loss = self.flow_policy.loss(action, state)
        self._optimize(self.flow_optimizer, flow_loss, self.flow_policy)
        return flow_loss

    def _update_one_step(self, state):
        """
        Take one step on the one-step policy with loss ``alpha * distill + eta * q_loss``.

        Returns:
            (q_loss, distill_loss, one_step_loss) tensors.
        """
        z = torch.randn(state.shape[0], self.action_dim, device=self.device)

        # a^π ← μ_ω(s, z)
        a_pi = self.one_step_policy(state, z)
        a_pi_clipped = a_pi.clamp(-self.max_action, self.max_action)

        # Distillation target from the flow policy using the SAME noise z.
        with torch.no_grad():
            a_flow = self.flow_policy.sample_from_noise(state, z)

        # Q-loss: -Q_φ(s, a^π) using the mean of both heads (as in official FQL).
        # Note: this backward also writes .grad on critic parameters; those are discarded
        # by critic_optimizer.zero_grad() before the next critic update.
        q1_pi, q2_pi = self.critic(state, a_pi_clipped)
        q_pi = (q1_pi + q2_pi) / 2
        q_loss = -q_pi.mean()
        if self.normalize_q_loss:
            # Epsilon guards against 1/0 -> inf * 0 -> NaN when all Q values are zero.
            lam = 1.0 / q_pi.abs().mean().detach().clamp_min(1e-8)
            q_loss = lam * q_loss

        # Distillation loss on the UNclipped action: ||a^π - μ_θ(s, z)||_2^2
        distill_loss = F.mse_loss(a_pi, a_flow)

        # flow_loss is optimised separately, so only alpha * distill + eta * q_loss here.
        one_step_loss = self.alpha * distill_loss + self.eta * q_loss
        self._optimize(self.one_step_optimizer, one_step_loss, self.one_step_policy)
        return q_loss, distill_loss, one_step_loss

    def _soft_update_critic_target(self):
        """Polyak-average critic weights into the target: θ' ← τ θ + (1 - τ) θ'."""
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    # ------------------------------------------------------------------ public API

    def train(self, replay_buffer, iterations, batch_size=100, log_writer=None):
        """
        Run ``iterations`` FQL updates: critic, then flow BC, then one-step policy.

        Args:
            replay_buffer: object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, not_done)`` tensors.
            iterations: number of gradient updates to perform.
            batch_size: minibatch size requested from the buffer.
            log_writer: optional writer exposing ``add_scalar(tag, value, step)``.

        Returns:
            dict mapping each loss name to a list with one float per iteration.
        """
        metric = {
            'flow_loss': [],
            'ql_loss': [],
            'distill_loss': [],
            'one_step_loss': [],
            'critic_loss': [],
        }

        for _ in range(iterations):
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            critic_loss = self._update_critic(state, action, next_state, reward, not_done)
            flow_loss = self._update_flow(state, action)
            q_loss, distill_loss, one_step_loss = self._update_one_step(state)

            # ==================== Update Target Networks ====================
            if self.step % self.update_ema_every == 0:
                self.step_ema()
            self._soft_update_critic_target()

            self.step += 1

            # ==================== Logging ====================
            scalars = (
                ('flow_loss', 'Flow Loss', flow_loss.item()),
                ('ql_loss', 'QL Loss', q_loss.item()),
                ('distill_loss', 'Distill Loss', distill_loss.item()),
                ('one_step_loss', 'One-Step Loss', one_step_loss.item()),
                ('critic_loss', 'Critic Loss', critic_loss.item()),
            )
            for key, tag, value in scalars:
                if log_writer is not None:
                    log_writer.add_scalar(tag, value, self.step)
                metric[key].append(value)

        # Schedulers advance once per train() call, not once per iteration.
        if self.lr_decay:
            self.flow_lr_scheduler.step()
            self.one_step_lr_scheduler.step()
            self.critic_lr_scheduler.step()

        return metric

    def sample_action(self, state):
        """
        Sample an action with the one-step policy π_ω.

        Draws 50 candidate actions for the given state, then picks one with probability
        proportional to softmax(min(Q1', Q2')) under the target critic.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        state_rpt = torch.repeat_interleave(state, repeats=50, dim=0)

        with torch.no_grad():
            z = torch.randn(state_rpt.shape[0], self.action_dim, device=self.device)
            action = self.one_step_policy(state_rpt, z)
            action = action.clamp(-self.max_action, self.max_action)
            q_value = self.critic_target.q_min(state_rpt, action).flatten()
            idx = torch.multinomial(F.softmax(q_value, dim=0), 1)

        return action[idx].cpu().data.numpy().flatten()

    def sample_action_deterministic(self, state):
        """
        Sample a single action deterministically (using zero noise).
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)

        with torch.no_grad():
            z = torch.zeros(1, self.action_dim, device=self.device)
            action = self.one_step_policy(state, z)
            action = action.clamp(-self.max_action, self.max_action)

        return action.cpu().data.numpy().flatten()

    def save_model(self, dir, id=None):
        """Save flow, one-step policy and critic weights as ``{name}[_{id}].pth`` in ``dir``."""
        suffix = '' if id is None else f'_{id}'
        torch.save(self.flow_policy.state_dict(), f'{dir}/flow{suffix}.pth')
        torch.save(self.one_step_policy.state_dict(), f'{dir}/one_step{suffix}.pth')
        torch.save(self.critic.state_dict(), f'{dir}/critic{suffix}.pth')

    def load_model(self, dir, id=None, map_location=None):
        """
        Load weights written by :meth:`save_model`.

        ``critic_target`` and ``ema_flow_policy`` are not stored on disk, so they are
        re-synchronised from the freshly loaded critic / flow policy.  Without this,
        ``sample_action`` (which ranks candidates with ``critic_target``) would use the
        randomly initialised target critic.
        """
        suffix = '' if id is None else f'_{id}'
        self.flow_policy.load_state_dict(torch.load(f'{dir}/flow{suffix}.pth', map_location=map_location))
        self.one_step_policy.load_state_dict(torch.load(f'{dir}/one_step{suffix}.pth', map_location=map_location))
        self.critic.load_state_dict(torch.load(f'{dir}/critic{suffix}.pth', map_location=map_location))
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.ema_flow_policy.load_state_dict(self.flow_policy.state_dict())
