import torch
import pytorch_lightning as pl
import hydra
import torch_optimizer as torch_optim
import torchmetrics
import torch.nn.functional as F


class TuabPretrainTask(pl.LightningModule):
    """Masked-reconstruction pre-training task wrapped as a LightningModule.

    Every step embeds the input with ``self.model.patch_embedder``, samples a
    random mask, calls ``self.model(X, mask)`` to obtain a reconstruction, and
    scores that reconstruction against the embedded tokens with MSE, computed
    separately over the masked and the unmasked entries.

    Args:
        hparams: Configuration stored through ``save_hyperparameters``. This
            class reads ``model`` (instantiated via Hydra), ``mask_ratio``,
            ``optimizer.optim``, ``optimizer.lr``, ``scheduler`` and, only for
            the ``'AdamW_finetune'`` optimizer, ``multi_gpu``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.model = hydra.utils.instantiate(self.hparams.model)
        self.mask_ratio = self.hparams.mask_ratio

    def mask_tokens(self, tokens, mask_ratio=0.1):
        """Sample a random boolean mask shaped like ``tokens``.

        Args:
            tokens: 4-D tensor whose size is unpacked as ``(B, S, C, D)``.
            mask_ratio: Probability that a given ``(B, S)`` position is masked.

        Returns:
            Boolean tensor of shape ``(B, S, C, D)`` on ``tokens.device``.
            One value is drawn per ``(B, S)`` position and broadcast over the
            last two dimensions, so a position is masked as a whole or not
            at all.
        """
        B, S, C, D = tokens.size()
        mask = torch.rand(B, S, device=tokens.device) < mask_ratio
        # expand() returns a broadcast view; no per-element copy is made.
        return mask.unsqueeze(2).unsqueeze(3).expand(-1, -1, C, D)

    @staticmethod
    def _masked_mse(prediction, target, selection):
        """Return the MSE over the entries picked by ``selection`` (0 if none)."""
        # F.mse_loss averages over the selected elements; with an empty
        # selection that mean is NaN, which would poison the summed loss and
        # every logged metric. An empty side simply contributes nothing.
        if not selection.any():
            return prediction.new_zeros(())
        return F.mse_loss(prediction[selection], target[selection])

    def reconstruction_loss(self, reconstructed, original, mask):
        """Compute the MSE on masked and on unmasked entries separately.

        Args:
            reconstructed: Model output, indexable by ``mask``.
            original: Reconstruction target, indexable by ``mask``.
            mask: Boolean tensor; ``True`` marks masked entries.

        Returns:
            Tuple ``(masked_loss, unmasked_loss)`` of scalar tensors. A side
            with no selected entries yields ``0`` instead of NaN.
        """
        masked_loss = self._masked_mse(reconstructed, original, mask)
        unmasked_loss = self._masked_mse(reconstructed, original, ~mask)
        return masked_loss, unmasked_loss

    def _shared_step(self, batch):
        """Run the forward pass shared by the train/val/test steps.

        Args:
            batch: Pair whose first element is the model input; the second
                element is ignored.

        Returns:
            Tuple ``(loss, masked_loss, unmasked_loss)`` where ``loss`` is the
            sum of the two partial losses.
        """
        X, _ = batch
        # The reconstruction target is the patch embedding of the raw input.
        original_tokens = self.model.patch_embedder(X)
        mask = self.mask_tokens(original_tokens, mask_ratio=self.mask_ratio)
        reconstructed = self.model(X, mask)
        masked_loss, unmasked_loss = self.reconstruction_loss(
            reconstructed, original_tokens, mask
        )
        return masked_loss + unmasked_loss, masked_loss, unmasked_loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses; return the total loss."""
        loss, masked_loss, unmasked_loss = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('masked_loss', masked_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('unmasked_loss', unmasked_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses; return the total loss.

        The mask is re-sampled with ``torch.rand`` on every call, so the
        reported value depends on the random state.
        """
        loss, masked_loss, unmasked_loss = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True, logger=True)
        self.log('val_masked_loss', masked_loss, prog_bar=True, logger=True)
        self.log('val_unmasked_loss', unmasked_loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test losses; return the total loss.

        The mask is re-sampled with ``torch.rand`` on every call, so the
        reported value depends on the random state.
        """
        loss, masked_loss, unmasked_loss = self._shared_step(batch)
        self.log('test_loss', loss, prog_bar=True, logger=True)
        self.log('test_masked_loss', masked_loss, prog_bar=True, logger=True)
        self.log('test_unmasked_loss', unmasked_loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        """Build the optimizer named by ``hparams.optimizer.optim`` and its scheduler.

        Supported names: ``'SGD'``, ``'Adam'``, ``'AdamW'``,
        ``'AdamW_finetune'`` and ``'LAMB'``. ``'AdamW_finetune'`` trains
        ``linear_out`` at ``hparams.optimizer.lr`` and every other parameter
        at a tenth of that rate.

        Returns:
            Dict with the optimizer and a scheduler config (instantiated via
            Hydra from ``hparams.scheduler``) using a per-step interval.

        Raises:
            NotImplementedError: If the optimizer name is not recognised.
        """
        optim_name = self.hparams.optimizer.optim
        lr = self.hparams.optimizer.lr

        if optim_name == "SGD":
            optimizer = torch.optim.SGD(self.model.parameters(), lr=lr, momentum=0.9)
        elif optim_name == 'Adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=0.01)
        elif optim_name == 'AdamW':
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        elif optim_name == 'AdamW_finetune':
            # With multi_gpu set, linear_out is reached through `.module`.
            head_owner = self.model.module if self.hparams.multi_gpu else self.model
            # parameters() returns a one-shot generator. It must be
            # materialised before being used twice; otherwise collecting the
            # ids drains it and the head's param group ends up empty, leaving
            # linear_out untrained.
            linear_out_params = list(head_owner.linear_out.parameters())
            head_ids = {id(p) for p in linear_out_params}
            base_params = [p for p in self.model.parameters() if id(p) not in head_ids]

            optimizer = torch.optim.AdamW([
                {'params': base_params},
                {'params': linear_out_params, 'lr': lr},
            ], lr=lr * 0.1)
        elif optim_name == 'LAMB':
            optimizer = torch_optim.Lamb(self.model.parameters(), lr=lr)
        else:
            raise NotImplementedError("No valid optim name")

        scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer)

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        }

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
