import pytorch_lightning as pl
import torch

class BaseProteinModel(pl.LightningModule):
    """Base Lightning module for token-level protein sequence models.

    The class wires up hyperparameter handling, a masked cross-entropy
    loss, an Adam optimizer, and a per-iteration learning-rate schedule.

    NOTE(review): ``self.base_model`` and ``self.alphabet`` are read by the
    methods below but are never assigned in this class; presumably a
    subclass sets them before training -- confirm against the subclasses.
    """

    def __init__(self, args, **kwargs):
        """Store ``args`` as hyperparameters, with ``kwargs`` as overrides.

        Args:
            args: Namespace-like object holding the configuration. It is
                mutated in place: each keyword argument is set on it as an
                attribute before the hyperparameters are saved.
            **kwargs: Extra or overriding settings copied onto ``args``.
        """
        super().__init__()
        for key, value in kwargs.items():
            setattr(args, key, value)
        self.save_hyperparameters(args)
        self.args = args

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Register this model's command-line options on ``parent_parser``.

        Args:
            parent_parser: The ``argparse`` parser to extend.

        Returns:
            The same ``parent_parser``, with a ``BaseProteinModel``
            argument group added.
        """
        group = parent_parser.add_argument_group('BaseProteinModel')
        group.add_argument('--model_name_or_path', type=str, default="")
        group.add_argument('--freeze_before_layer', type=int, default=-1)
        # Training
        group.add_argument('--lr', type=float, default=5e-4)
        group.add_argument('--warmup', type=int, default=50)
        group.add_argument('--early_step', action="store_true")
        group.add_argument('--generation_steps', type=int, default=1)
        group.add_argument('--delta_t', type=float, default=30)
        group.add_argument('--particle_num', type=int, default=1000)
        group.add_argument('--min_particle_num', type=int, default=10)
        group.add_argument('--sampling_method', type=str, default=None)
        return parent_parser

    def _calculate_loss(self, batch, mode="train", reduce=True):
        """Compute and log the cross-entropy loss over non-padding positions.

        Args:
            batch: A ``(src_tokens, tgt_tokens)`` pair.
            mode: Prefix of the logged metric name (``"<mode>_loss"``).
            reduce: Currently ignored; the loss is always the mean over the
                selected positions. Kept for interface compatibility.

        Returns:
            The scalar loss tensor.
        """
        src_tokens, tgt_tokens = batch
        # Flattened boolean mask selecting positions whose source token is
        # not the padding symbol.
        # NOTE(review): assumes ``self.alphabet.pad`` is the padding index
        # (an int), not a callable -- confirm against the alphabet class.
        keep = (src_tokens != self.alphabet.pad).reshape(-1)  # [B * L]

        logits = self.forward(src_tokens)  # [B, L, V] per the original notes
        # ``reshape`` rather than ``view`` so non-contiguous inputs work too.
        flat_logits = logits.reshape(-1, logits.size(-1))[keep]
        flat_targets = tgt_tokens.reshape(-1)[keep]
        # Referenced through ``torch`` because the file imports neither
        # ``F`` nor ``torch.nn.functional`` by name.
        loss = torch.nn.functional.cross_entropy(flat_logits, flat_targets)

        # Logging
        self.log(f"{mode}_loss", loss)
        return loss

    def forward(self, src_tokens):
        """Return the ``'logits'`` entry of the base model's output."""
        return self.base_model(src_tokens)['logits']

    def configure_optimizers(self):
        """Create the Adam optimizer and attach a warmup scheduler.

        The scheduler is stored on ``self.lr_scheduler`` instead of being
        returned, because it is stepped once per iteration in
        ``optimizer_step`` rather than once per epoch.

        NOTE(review): ``CosineWarmupScheduler`` is not defined or imported
        in the visible part of this file, and ``max_iters`` is not among
        the options registered by ``add_argparse_args``; both must be
        supplied elsewhere -- confirm.

        Returns:
            The optimizer.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        # We don't return the lr scheduler because we need to apply it per
        # iteration, not per epoch.
        self.lr_scheduler = CosineWarmupScheduler(
            optimizer,
            warmup=self.hparams.warmup,
            max_iters=self.hparams.max_iters,
        )
        return optimizer

    def optimizer_step(self, *args, **kwargs):
        """Run the default optimizer step, then advance the LR schedule."""
        super().optimizer_step(*args, **kwargs)
        self.lr_scheduler.step()  # Step per iteration

    def training_step(self, train_batch, batch_idx):
        """Return the training loss for one batch (logged as ``train_loss``)."""
        return self._calculate_loss(train_batch, mode="train")

    def validation_step(self, val_batch, batch_idx):
        """Return the validation loss for one batch (logged as ``val_loss``)."""
        # Returned for parity with ``training_step``; the metric itself is
        # logged inside ``_calculate_loss``.
        return self._calculate_loss(val_batch, mode="val")

    def test_step(self, test_batch, batch_idx):
        """Not implemented in the base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()
