import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from abc import ABC, abstractmethod
from agents.basic_bc import BaseImitationLearning


class ValueDice(BaseImitationLearning):
    """
    Simplified, fully offline ValueDICE imitation learner.

    ValueDICE (Kostrikov et al., "Imitation Learning via Off-Policy
    Distribution Matching") matches the policy's state-action occupancy to the
    expert's by solving the saddle-point problem

        max_pi min_nu  J(pi, nu) =
            log E_D[exp(nu(s, a) - gamma * nu(s', pi(s')))]
            - (1 - gamma) * E_{s0}[nu(s0, pi(s0))]

    where D is the expert data. ``value_net`` plays the role of nu and is
    trained to minimise J; ``actor`` is trained to maximise J. Intuitively, nu
    learns low values on expert pairs and high values on policy pairs, and
    maximising J moves the policy towards low-nu (expert-like) actions.

    Simplifications relative to the paper:
      * only expert transitions are used (no replay-buffer regularisation);
      * the sampled expert states stand in for the initial-state samples s0;
      * the policy is deterministic (tanh output scaled by ``max_action``);
      * the paper's gradient penalty / orthogonal regularisation are omitted.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 hidden_dim=256,
                 lr=3e-4,
                 lr_decay=False,
                 lr_maxt=1000,
                 grad_norm=1.0,
                 discount=0.99):
        """
        Initialize the ValueDice agent.

        Args:
            state_dim (int): Dimension of the state space
            action_dim (int): Dimension of the action space
            max_action (float): Scale applied to the actor's tanh output
            device (torch.device): Computation device
            hidden_dim (int): Hidden layer width for both networks
            lr (float): Learning rate for both Adam optimizers
            lr_decay (bool): Whether to attach cosine-annealing schedulers
            lr_maxt (int): ``T_max`` of the schedulers; they are stepped once
                per ``train`` call, so this counts ``train`` calls
            grad_norm (float): Gradient-norm clipping threshold (<= 0 disables)
            discount (float): Discount factor gamma in the DICE objective
        """
        super().__init__(state_dim, action_dim, max_action, device, lr)

        self.grad_norm = grad_norm
        self.lr_decay = lr_decay
        self.discount = discount

        # Deterministic policy: state -> action in [-1, 1] (scaled later).
        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh()  # Tanh output for bounded actions
        ).to(device)

        # nu(s, a): scalar critic over concatenated state-action input.
        self.value_net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        ).to(device)

        # Separate optimizers: nu minimises J, the actor maximises it.
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.value_optimizer = torch.optim.Adam(self.value_net.parameters(), lr=lr)

        # Optional learning rate scheduling
        if lr_decay:
            self.actor_lr_scheduler = CosineAnnealingLR(self.actor_optimizer, T_max=lr_maxt, eta_min=0.)
            self.value_lr_scheduler = CosineAnnealingLR(self.value_optimizer, T_max=lr_maxt, eta_min=0.)

    def _dice_objective(self, state, expert_action, next_state,
                        policy_action, policy_next_action):
        """
        Evaluate the simplified offline ValueDICE objective J on one batch.

            J = log mean_i exp(nu(s_i, a_i) - gamma * nu(s'_i, pi(s'_i)))
                - (1 - gamma) * mean_i nu(s_i, pi(s_i))

        The sampled expert states ``state`` double as the initial-state samples
        of the linear term.

        Args:
            state, expert_action, next_state: Expert batch tensors.
            policy_action: Policy action at ``state`` (already scaled).
            policy_next_action: Policy action at ``next_state`` (already scaled).

        Returns:
            torch.Tensor: Scalar J. Gradients flow to whichever inputs and
            networks currently require grad.
        """
        expert_nu = self.value_net(torch.cat([state, expert_action], dim=-1))
        policy_next_nu = self.value_net(torch.cat([next_state, policy_next_action], dim=-1))
        policy_nu = self.value_net(torch.cat([state, policy_action], dim=-1))

        diff = (expert_nu - self.discount * policy_next_nu).flatten()
        # log(mean(exp(diff))) computed via logsumexp to avoid overflow.
        non_linear_term = torch.logsumexp(diff, dim=0) - float(np.log(diff.numel()))
        linear_term = (1.0 - self.discount) * policy_nu.mean()
        return non_linear_term - linear_term

    def train(self, replay_buffer, iterations, batch_size=100, log_writer=None):
        """
        Run ``iterations`` alternating ValueDICE updates.

        Each iteration samples one expert batch, takes one gradient step on
        ``value_net`` (minimising J with the policy held fixed), then one step
        on ``actor`` (maximising J with ``value_net`` held fixed).

        Args:
            replay_buffer: Buffer of expert demonstrations. ``sample`` must
                return a 5-tuple whose first three entries are the state,
                action and next-state tensors (fed directly to the networks);
                the last two entries are ignored.
            iterations (int): Number of update iterations.
            batch_size (int): Transitions sampled per iteration.
            log_writer: Optional writer exposing ``add_scalar`` (e.g. TensorBoard).

        Returns:
            dict: ``{'actor_loss': [...], 'value_loss': [...]}`` with one float
            per iteration; ``value_loss`` is J and ``actor_loss`` is -J.
        """
        metric = {'actor_loss': [], 'value_loss': []}

        for _ in range(iterations):
            state, action, next_state, _, _ = replay_buffer.sample(batch_size)

            # ---- nu (value_net) step: minimise J with the policy frozen ----
            with torch.no_grad():
                policy_action = self.actor(state) * self.max_action
                policy_next_action = self.actor(next_state) * self.max_action
            value_loss = self._dice_objective(
                state, action, next_state, policy_action, policy_next_action)

            self.value_optimizer.zero_grad()
            value_loss.backward()
            if self.grad_norm > 0:
                nn.utils.clip_grad_norm_(self.value_net.parameters(),
                                         max_norm=self.grad_norm, norm_type=2)
            self.value_optimizer.step()

            # ---- policy step: maximise J with value_net frozen ----
            # Freezing value_net keeps the actor's backward pass from writing
            # gradients into value_net parameters; gradients still flow
            # through value_net to the policy actions.
            self.value_net.requires_grad_(False)
            try:
                policy_action = self.actor(state) * self.max_action
                policy_next_action = self.actor(next_state) * self.max_action
                actor_loss = -self._dice_objective(
                    state, action, next_state, policy_action, policy_next_action)

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
            finally:
                self.value_net.requires_grad_(True)
            if self.grad_norm > 0:
                nn.utils.clip_grad_norm_(self.actor.parameters(),
                                         max_norm=self.grad_norm, norm_type=2)
            self.actor_optimizer.step()

            # self.step is presumably initialised by BaseImitationLearning.
            self.step += 1

            actor_loss_value = actor_loss.item()
            value_loss_value = value_loss.item()
            if log_writer is not None:
                log_writer.add_scalar('Actor Loss', actor_loss_value, self.step)
                log_writer.add_scalar('Value Loss', value_loss_value, self.step)

            metric['actor_loss'].append(actor_loss_value)
            metric['value_loss'].append(value_loss_value)

        # Schedulers advance once per train() call, not once per iteration.
        if self.lr_decay:
            self.actor_lr_scheduler.step()
            self.value_lr_scheduler.step()

        return metric

    def sample_action(self, state):
        """
        Return the deterministic policy action for one state.

        Args:
            state: A ``torch.Tensor`` (passed to the actor unchanged), or any
                array-like (``np.ndarray``, list, tuple) holding a single
                state, which is converted to a float32 tensor of shape
                (1, state_dim) on ``self.device``.

        Returns:
            np.ndarray: Flattened action scaled by ``max_action``.
        """
        if not isinstance(state, torch.Tensor):
            # float32 matches the torch.FloatTensor conversion used previously.
            state = torch.as_tensor(
                np.asarray(state, dtype=np.float32).reshape(1, -1),
                device=self.device)

        with torch.no_grad():
            action = self.actor(state) * self.max_action

        return action.cpu().numpy().flatten()

    @staticmethod
    def _checkpoint_paths(dir, id=None):
        """Return ``(actor_path, value_path)`` for a directory and optional id."""
        suffix = '' if id is None else f'_{id}'
        return (f'{dir}/valuedice_actor{suffix}.pth',
                f'{dir}/valuedice_value{suffix}.pth')

    def save_model(self, dir, id=None):
        """
        Save the actor and value-network state dicts.

        Args:
            dir (str): Directory to save into
            id (str, optional): Identifier appended to the file names
        """
        actor_path, value_path = self._checkpoint_paths(dir, id)
        torch.save(self.actor.state_dict(), actor_path)
        torch.save(self.value_net.state_dict(), value_path)

    def load_model(self, dir, id=None, map_location=None):
        """
        Load the actor and value-network state dicts.

        Args:
            dir (str): Directory containing the checkpoints
            id (str, optional): Identifier used when the models were saved
            map_location: Forwarded to ``torch.load`` for device remapping
        """
        actor_path, value_path = self._checkpoint_paths(dir, id)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.actor.load_state_dict(torch.load(actor_path, map_location=map_location))
        self.value_net.load_state_dict(torch.load(value_path, map_location=map_location))