import sys
from typing import List, Tuple, Optional, Union

import torch
from pytorch_lightning import LightningModule
import pytorch_lightning as pl

from timm.models.vision_transformer import vit_base_patch16_224
from torch import Tensor
from torch.nn import MSELoss, Parameter
from torch.optim import AdamW

from lightly.models import utils
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from lightly.transforms import MAETransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler

from model_zoo.base import BaseModel


class MAEModel(BaseModel):
    """Masked autoencoder (MAE) pre-training module on a timm ViT-B/16.

    The module wraps a ``vit_base_patch16_224`` in a
    ``MaskedVisionTransformerTIMM`` encoder, reconstructs the pixels of the
    masked patches with a ``MAEDecoderTIMM`` and, alongside, trains an
    ``OnlineLinearClassifier`` on the detached class-token features.
    """

    def __init__(
        self,
        batch_size_per_device: int = 256,
        lr: Union[float, str] = "auto",
        momentum: float = 0.9,
        weight_decay: float = 1e-6,
        temperature: float = 0.1,
        network: str = "resnet50",
        low_res: bool = False,
        weight_decay_trick: bool = True,
        datamodule: Optional[pl.LightningDataModule] = None,
    ) -> None:
        """Build the encoder, decoder, loss and online classifier.

        Args:
            batch_size_per_device: Per-device batch size; used to scale the
                learning rate in ``configure_optimizers``.
            lr: Forwarded to ``BaseModel``. ``configure_optimizers`` in this
                class computes its own learning rate and does not read it.
            momentum: Forwarded to ``BaseModel``.
            weight_decay: Forwarded to ``BaseModel``. ``configure_optimizers``
                in this class uses a fixed weight decay and does not read it.
            temperature: Accepted but not referenced by any code in this class.
            network: Forwarded to ``BaseModel``. The backbone built here is
                always ``vit_base_patch16_224`` regardless of this value.
            low_res: Forwarded to ``BaseModel``.
            weight_decay_trick: Forwarded to ``BaseModel``.
            datamodule: Forwarded to ``BaseModel``.
        """
        super().__init__(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            network=network,
            low_res=low_res,
            weight_decay_trick=weight_decay_trick,
            datamodule=datamodule,
        )
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        decoder_dim = 512
        vit = vit_base_patch16_224()

        # Fraction of tokens handed to `random_token_mask` as `mask_ratio`.
        self.mask_ratio = 0.75
        self.patch_size = vit.patch_embed.patch_size[0]
        # Patch tokens plus prefix tokens (e.g. the class token).
        self.sequence_length = vit.patch_embed.num_patches + vit.num_prefix_tokens
        # Learnable placeholder that stands in for masked tokens in the decoder.
        mask_token = Parameter(torch.zeros(1, 1, decoder_dim))
        torch.nn.init.normal_(mask_token, std=0.02)
        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=decoder_dim,
            decoder_depth=8,
            decoder_num_heads=16,
            mlp_ratio=4.0,
            proj_drop_rate=0.0,
            attn_drop_rate=0.0,
            mask_token=mask_token,
        )
        self.criterion = MSELoss()
        # Derive the classifier input width from the backbone instead of
        # hard-coding it, so it stays in sync with the `embed_dim` that is
        # also passed to the decoder above.
        # NOTE(review): `self.num_classes` is not set in this class; presumably
        # provided by BaseModel — confirm.
        self.online_classifier = OnlineLinearClassifier(
            feature_dim=vit.embed_dim, num_classes=self.num_classes
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return the backbone output for a batch of images.

        No token indices are passed, so presumably nothing is masked here —
        confirm against ``MaskedVisionTransformerTIMM.forward``.
        """
        return self.backbone(images=x)

    def forward_encoder(
        self, images: Tensor, idx_keep: Optional[Tensor] = None
    ) -> Tensor:
        """Encode ``images`` with the backbone, restricted to ``idx_keep``.

        Args:
            images: Batch of input images.
            idx_keep: Token indices handed to ``backbone.encode``; ``None`` is
                passed through unchanged.

        Returns:
            The token features returned by ``backbone.encode``.
        """
        return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(
        self, x_encoded: Tensor, idx_keep: Tensor, idx_mask: Tensor
    ) -> Tensor:
        """Predict pixel values for the masked tokens.

        Args:
            x_encoded: Encoder output for the kept tokens.
            idx_keep: Indices of the tokens that were kept by the encoder.
            idx_mask: Indices of the tokens that were masked out.

        Returns:
            The decoder's predictions at the positions given by ``idx_mask``.
        """
        # build decoder input: a full-length sequence of mask tokens with the
        # embedded encoder outputs written back at their original positions
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(
            self.decoder.mask_token, (batch_size, self.sequence_length)
        )
        # `type_as` keeps dtypes aligned (e.g. under mixed precision).
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        # decoder forward pass
        x_decoded = self.decoder.decode(x_masked)

        # predict pixel values for masked tokens
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred

    def training_step(
        self,
        batch: Tuple[Union[Tensor, List[Tensor]], Tensor, List[str]],
        batch_idx: int,
    ) -> Tensor:
        """Run one MAE reconstruction step plus the online classifier step.

        Args:
            batch: Tuple whose first element is the image tensor (or a list of
                views, of which the first is used) and whose second element is
                the classification targets. Further elements are ignored.
            batch_idx: Index of the batch, forwarded to the online classifier.

        Returns:
            Sum of the reconstruction loss and the online classifier loss.
        """
        images, targets = batch[0], batch[1]
        if isinstance(images, (list, tuple)):
            # Some pipelines wrap the image tensor in a list of views;
            # without unwrapping, `images.shape` below raises AttributeError.
            # NOTE(review): assumes the first view is the one to reconstruct
            # — confirm against the datamodule/transform in use.
            images = images[0]
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        features = self.forward_encoder(images, idx_keep)
        predictions = self.forward_decoder(features, idx_keep, idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for missing class token
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(predictions, target)
        self.log(
            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
        )

        # The classifier sees detached features, so its loss does not
        # back-propagate into the encoder.
        cls_features = features[:, 0]
        cls_loss, cls_log = self.online_classifier.training_step(
            (cls_features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def configure_optimizers(self):
        """Create the AdamW optimizer and a per-step cosine warm-up schedule.

        The learning rate is ``1.5e-4`` scaled linearly by the global batch
        size (``batch_size_per_device * world_size``) relative to 256. The
        ``lr`` and ``weight_decay`` constructor arguments are not read here.

        Returns:
            A ``([optimizer], [scheduler_config])`` pair.
        """
        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance.
        params, params_no_weight_decay = utils.get_weight_decay_parameters(
            [self.backbone, self.decoder]
        )
        optimizer = AdamW(
            [
                {"name": "mae", "params": params},
                {
                    "name": "mae_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            lr=1.5e-4 * self.batch_size_per_device * self.trainer.world_size / 256,
            weight_decay=0.05,
            betas=(0.9, 0.95),
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                # The scheduler is stepped per batch ("interval": "step"), so
                # both values are expressed in steps: warm-up spans the number
                # of steps in 40 epochs.
                warmup_epochs=(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 40
                ),
                max_epochs=self.trainer.estimated_stepping_batches,
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]

