import torch
import pytorch_lightning as pl
from rsl_rl.rsl_rl.modules.actor_critic import P4RLHamburgerActor

class ActorZeroOutputPretrainLightning(pl.LightningModule):
    """Pretrain an actor so that its outputs are pushed towards zero.

    The actor's ``pretrained_module`` is frozen for each ``fit`` run, and only
    the actor parameters that still require gradients are given to the
    optimizer. The loss is the mean absolute value of the actor output
    (L1 distance to a zero tensor of the same shape). When training ends, the
    ``requires_grad`` flags of ``pretrained_module`` are restored to the values
    captured before freezing.

    Args:
        actor: Actor whose ``pretrained_module`` is kept fixed during training.
        lr: Learning rate for the Adam optimizer.

    Attributes:
        train_losses: Per-step training losses, as Python floats.
        requires_grad_state: Parameter name in ``actor.pretrained_module`` ->
            its ``requires_grad`` flag, as captured at the most recent freeze.
    """

    def __init__(self, actor: P4RLHamburgerActor, lr=1e-3):
        super().__init__()
        self.actor = actor
        self.lr = lr
        self.train_losses = []
        # True while this module holds pretrained_module frozen; makes
        # freeze/restore idempotent across repeated fit() calls.
        self._pretrained_frozen = False
        # Freeze eagerly so the pretrained module is frozen right after
        # construction (original behaviour), not only once a Trainer starts.
        self._freeze_pretrained_module()

    def _freeze_pretrained_module(self):
        """Snapshot the pretrained module's ``requires_grad`` flags, then freeze it."""
        if self._pretrained_frozen:
            # Re-snapshotting now would record the frozen (False) flags and
            # lose the original trainable state.
            return
        pretrained = self.actor.pretrained_module
        self.requires_grad_state = {
            name: param.requires_grad
            for name, param in pretrained.named_parameters()
        }
        for param in pretrained.parameters():
            param.requires_grad = False
        self._pretrained_frozen = True

    def _restore_pretrained_module(self):
        """Restore the ``requires_grad`` flags captured by the last freeze.

        Returns:
            True if flags were restored, False if nothing was frozen.
        """
        if not self._pretrained_frozen:
            return False
        for name, param in self.actor.pretrained_module.named_parameters():
            param.requires_grad = self.requires_grad_state[name]
        self._pretrained_frozen = False
        return True

    def setup(self, stage=None):
        """Re-freeze the pretrained module at the start of a fit run.

        In Lightning's fit flow ``setup`` runs before ``configure_optimizers``.
        Freezing here means a second ``fit()`` on this module (after
        ``on_train_end`` restored the flags) still excludes the pretrained
        parameters from the optimizer.
        """
        if stage in (None, "fit"):
            self._freeze_pretrained_module()
        return super().setup(stage)

    def training_step(self, batch, batch_idx):
        """Compute the mean absolute actor output for one batch.

        Args:
            batch: Indexable batch whose first element is the actor input
                (e.g. the tuple yielded by a ``TensorDataset``-backed loader).
            batch_idx: Index of the batch (unused).

        Returns:
            Scalar L1 loss between the actor output and zeros.
        """
        x = batch[0]  # TensorDataset yields tuples; the input is the first item
        actions = self.actor(x)
        loss = torch.nn.functional.l1_loss(
            actions,
            torch.zeros_like(actions),
            reduction='mean'
        )
        # .item() syncs with the device every step; kept so that
        # train_losses remains a plain list of floats for callers.
        self.train_losses.append(loss.item())
        return loss

    def on_train_end(self):
        """Undo the training-time freeze of the pretrained module."""
        if self._restore_pretrained_module():
            print("Restored the trainable state of the pretrained module in P4RLHamburgerActor.")
        return super().on_train_end()

    def configure_optimizers(self):
        """Create Adam over the actor parameters that currently require gradients.

        The pretrained module is frozen at this point, so its parameters are
        excluded.
        """
        trainable_params = [p for p in self.actor.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable_params, lr=self.lr)