import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule

from melp.backbone.resnet1d import ResNet18, ResNet34, ResNet50, ResNet101
from melp.backbone.vit1d import vit_nano, vit_tiny, vit_small, vit_middle, vit_base
from melp.backbone.pooling import AttentionPool2d


class PSGModalityEncoder(nn.Module):
    """Encoder for PSG signals: backbone -> optional pooling -> projection -> L2-norm.

    Two backbone families are supported, selected by ``encoder_name``:

    * ``resnet18`` / ``resnet34`` / ``resnet50`` / ``resnet101``: the backbone
      output is reduced to ``proj_out`` channels with a 1x1 ``Conv1d`` and then
      pooled with ``AttentionPool2d``. No MLP projection head is created
      (``is_proj_head`` and ``proj_hidden`` are not used on this path).
    * ``vit_nano`` / ``vit_tiny`` / ``vit_small`` / ``vit_middle`` /
      ``vit_base``: the backbone output is optionally passed through a
      two-layer MLP projection head.

    All constructor arguments are keyword-only.

    Args:
        encoder_name: Name of the backbone (see the lists above).
        proj_out: Dimension of the returned embedding.
        proj_hidden: Hidden width of the ViT projection head.
        freq: Multiplied by ``win_sec`` to obtain the input length in samples
            (presumably the sampling rate in Hz -- confirm with the data
            pipeline).
        win_sec: Multiplied by ``freq`` to obtain the input length in samples.
        channel: Number of input channels, passed to the backbone as
            ``in_ch`` (ResNet) or ``num_leads`` (ViT).
        lead_wise: Forwarded unchanged to the ViT factory.
        patch_size: Forwarded unchanged to the ViT factory; also stored on
            the module as ``self.patch_size``.
        patch_size_ch: Forwarded unchanged to the ViT factory.
        use_lead_embedding: Forwarded unchanged to the ViT factory.
        is_proj_head: ViT only. The projection head is built when this
            compares equal to ``1``; otherwise ``self.proj_head`` is ``None``.

    Raises:
        ValueError: If ``encoder_name`` is not one of the supported names.
    """

    def __init__(self, *,
                 encoder_name: str,
                 proj_out: int = 256,
                 proj_hidden: int = 512,
                 freq: int = 64,
                 win_sec: int = 30,
                 channel: int = 11,
                 lead_wise=0,
                 patch_size=40,
                 patch_size_ch=4,
                 use_lead_embedding: bool = True,
                 is_proj_head=1):
        super().__init__()
        token_len = freq * win_sec  # [T]

        self.token_len = token_len
        self.patch_size = patch_size

        if "resnet" in encoder_name:
            # Select the constructor together with the channel count the
            # 1x1 down-projection has to consume for that variant.
            if encoder_name == "resnet18":
                build_resnet, backbone_width = ResNet18, 512
            elif encoder_name == "resnet34":
                build_resnet, backbone_width = ResNet34, 512
            elif encoder_name == "resnet50":
                build_resnet, backbone_width = ResNet50, 2048
            elif encoder_name == "resnet101":
                build_resnet, backbone_width = ResNet101, 2048
            else:
                # Without this branch an unsupported "resnet*" name would fail
                # further down with an unrelated UnboundLocalError.
                raise ValueError(f"Unknown encoder_name: {encoder_name}")

            self.backbone = build_resnet(in_ch=channel)
            self.downproj = nn.Conv1d(backbone_width, proj_out, kernel_size=1)
            # NOTE(review): token_len // 16 assumes the ResNet backbone
            # shortens the time axis by a factor of 16 -- confirm against
            # melp.backbone.resnet1d.
            self.att_pool = AttentionPool2d(spacial_dim=token_len // 16, embed_dim=proj_out,
                                            num_heads=4, output_dim=proj_out)
            self.proj_head = None

        elif "vit" in encoder_name and encoder_name != "pretrained-vit":
            if encoder_name == "vit_nano":
                build_vit = vit_nano
            elif encoder_name == "vit_tiny":
                build_vit = vit_tiny
            elif encoder_name == "vit_small":
                build_vit = vit_small
            elif encoder_name == "vit_middle":
                build_vit = vit_middle
            elif encoder_name == "vit_base":
                build_vit = vit_base
            else:
                # Without this branch an unsupported "*vit*" name would fail
                # on `self.backbone.width` with an AttributeError.
                raise ValueError(f"Unknown encoder_name: {encoder_name}")

            # Every ViT variant takes the same keyword arguments.
            self.backbone = build_vit(num_leads=channel, seq_len=token_len, patch_size=patch_size,
                                      lead_wise=lead_wise, patch_size_ch=patch_size_ch,
                                      use_lead_embedding=use_lead_embedding)

            d_model = self.backbone.width
            self.downproj = None
            self.att_pool = None
            if is_proj_head == 1:
                self.proj_head = nn.Sequential(
                    nn.Linear(d_model, proj_hidden),
                    nn.LayerNorm(proj_hidden),
                    nn.ReLU(inplace=True),
                    nn.Linear(proj_hidden, proj_out),
                    nn.LayerNorm(proj_out),
                )
            else:
                self.proj_head = None
        else:
            raise ValueError(f"Unknown encoder_name: {encoder_name}")

        self.avgpool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x, normalize=True, is_patch=False):
        """Encode ``x`` and optionally L2-normalise the result.

        Args:
            x: Input signal, ``[B, C, T]``.
            normalize: When true, the output is L2-normalised along its last
                dimension.
            is_patch: Forwarded to the backbone. When true, the pooling and
                projection stages are skipped and the backbone output is
                returned (still subject to ``normalize``).

        Returns:
            The embedding tensor. With ``is_patch=False`` this is expected to
            be ``[B, proj_out]`` when a projection is configured; with
            ``is_patch=True`` the shape is whatever the backbone returns.
        """
        h = self.backbone(x, is_patch=is_patch)  # [B, D, T'] or [B, N, D]

        if not is_patch:
            if self.downproj is not None:  # ResNet
                h = self.downproj(h)  # [B, proj_out, T']
                if self.att_pool is not None:
                    # att_pool returns a pair; only the first item is used.
                    h, _ = self.att_pool(h)  # [B, 1, proj_out]
                    h = h.squeeze(1)  # [B, proj_out]
                else:
                    # __init__ always builds att_pool for ResNet encoders, so
                    # this fallback only runs if att_pool was cleared later.
                    h = self.avgpool(h).squeeze(-1)
            elif self.proj_head is not None:  # ViT with projection head
                h = self.proj_head(h)  # [B, proj_out]

        return F.normalize(h, dim=-1) if normalize else h


class BasePretrainModel(LightningModule):
    """Common Lightning scaffolding for pre-training models.

    Stores the shared hyperparameters, configures AdamW with a linear warm-up
    followed by cosine decay, and implements the train / validation / test
    steps on top of ``shared_step``.

    ``shared_step(batch, batch_idx)`` is not defined in this class; it must be
    available on the concrete model and return ``(loss_dict, metrics_dict)``.
    ``loss_dict`` must contain the key ``'loss'``, which ``training_step``
    returns for back-propagation.

    Args:
        psg_encoder_name: Stored as ``self.psg_encoder_name``.
        text_encoder_name: Stored as ``self.text_encoder_name``.
        fusion_decoder_name: Stored as ``self.fusion_decoder_name``.
        shared_emb_dim: Stored as ``self.shared_emb_dim`` and ``self.proj_out``.
        lr: Learning rate passed to AdamW.
        weight_decay: Weight decay passed to AdamW.
        training_steps_per_epoch: Used to convert epochs into scheduler
            steps; must be greater than 1.
        max_epochs: Total number of epochs; warm-up lasts 10% of it.
        *args, **kwargs: Accepted for signature compatibility; not used by
            this class directly.
    """

    def __init__(self,
                 psg_encoder_name: str = "resnet18",
                 text_encoder_name: str = "google/flan-t5-base",
                 fusion_decoder_name: str = 'cross-attn',
                 shared_emb_dim: int = 256,
                 lr: float = 2e-4,
                 weight_decay: float = 0.2,
                 training_steps_per_epoch: int = 7000,
                 max_epochs: int = 100,
                 *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.psg_encoder_name = psg_encoder_name
        self.text_encoder_name = text_encoder_name
        self.fusion_decoder_name = fusion_decoder_name
        self.shared_emb_dim = shared_emb_dim
        self.lr = lr
        self.weight_decay = weight_decay
        self.training_steps_per_epoch = training_steps_per_epoch
        self.max_epochs = max_epochs
        # Warm-up covers the first 10% of training (may be fractional).
        self.warmup_epochs = 0.1 * self.max_epochs
        self.proj_out = shared_emb_dim
        self.proj_hidden = 256

        assert self.training_steps_per_epoch > 1, \
            "training_steps_per_epoch must be greater than 1"

    def configure_optimizers(self):
        """Build AdamW and a per-step warm-up + cosine learning-rate schedule.

        Returns:
            ``([optimizer], [scheduler_config])`` where the scheduler config
            asks for the scheduler to be stepped with ``interval="step"`` and
            ``frequency=1``.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.95),
        )

        total_steps = int(self.training_steps_per_epoch * self.max_epochs)
        warmup_steps = int(round(self.training_steps_per_epoch * self.warmup_epochs))
        warmup_steps = max(0, warmup_steps)
        # At least one decay step so that T_max is never zero.
        decay_steps = max(1, total_steps - warmup_steps)

        def build_cosine():
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_steps, eta_min=1e-8)

        if warmup_steps > 0:
            # The warm-up scheduler is created before the cosine one, as in
            # the original code, so the optimizer sees the same init sequence.
            warmup = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
            cosine = build_cosine()
            sched = torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])
        else:
            sched = build_cosine()

        return [optimizer], [{"scheduler": sched, "interval": "step", "frequency": 1}]

    def _log_step_outputs(self, stage, loss_dict, metrics_dict, on_step):
        """Log every entry of both dicts under the ``"<stage>/<key>"`` name.

        Losses are logged before metrics, matching the original order.
        """
        for values in (loss_dict, metrics_dict):
            for key, value in values.items():
                self.log(f"{stage}/{key}", value, on_step=on_step, on_epoch=True,
                         prog_bar=True, sync_dist=True)

    def training_step(self, batch, batch_idx):
        """Run ``shared_step``, log per step and per epoch, return the loss."""
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        self._log_step_outputs("train", loss_dict, metrics_dict, on_step=True)
        return loss_dict['loss']

    def validation_step(self, batch, batch_idx):
        """Run ``shared_step`` without gradients and log epoch-level values."""
        with torch.no_grad():
            loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        self._log_step_outputs("val", loss_dict, metrics_dict, on_step=False)
        return loss_dict

    def test_step(self, batch, batch_idx):
        """Run ``shared_step`` and log epoch-level values under ``test/``."""
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        self._log_step_outputs("test", loss_dict, metrics_dict, on_step=False)
        return loss_dict
