import torch
import torch.nn as nn
from einops import rearrange
from transformers import (
    Dinov2Model,
    Dinov2Config,
    SegformerForSemanticSegmentation,
    SegformerConfig,
    ResNetConfig,
    ResNetModel,
)
from sfno import SphericalFourierNeuralOperatorNet as SFNO
from crossattn import CrossAttn, CrossAttnImage
import pytorch_lightning as pl
from ..losses import AsymmetricLoss, Ralloss, MEfullloss, ANfullloss


class RangeModel(nn.Module):
    """SFNO encoder fused with text features, decoded to a one-channel map.

    ``forward`` runs the ``SFNO`` encoder, passes its output together with
    ``text`` through ``CrossAttn``, folds the resulting token sequence back
    into a 2-D grid, bilinearly upsamples by ``scale_factor`` and projects to
    a single channel with a bias-free 1x1 convolution.

    Args:
        filter_type, spectral_transform, operator_type, in_chans, out_chans,
        num_layers, encoder_layers, spectral_layers, env_cov:
            Forwarded unchanged to ``SFNO``. Note that ``out_chans`` only
            reaches the encoder; the final projection always has 1 channel.
        img_size: Two-element size; index 0 is used as the grid height axis
            when reshaping tokens. Both entries must be divisible by
            ``scale_factor``.
        scale_factor: Forwarded to ``SFNO`` and ``CrossAttn`` and used as the
            upsampling factor of the decoder.
        embed_dim: Forwarded to ``SFNO``; also the ``dim`` of ``CrossAttn``
            and the input channel count of the output convolution.
        attn_heads, attn_dim_head: ``heads`` / ``dim_head`` of ``CrossAttn``.
        text_dim: ``dim_text`` of ``CrossAttn``.

    Raises:
        AssertionError: If ``scale_factor`` does not divide both entries of
            ``img_size``.
    """

    def __init__(
        self,
        filter_type="non-linear",
        spectral_transform="sht",
        operator_type="vector",
        img_size=(900, 1800),
        scale_factor=4,
        in_chans=20,
        out_chans=1,
        embed_dim=128,
        num_layers=2,
        encoder_layers=1,
        spectral_layers=2,
        env_cov=False,
        attn_heads=16,
        attn_dim_head=64,
        text_dim=4096,
    ):
        super().__init__()
        self.img_size = img_size
        self.scale_factor = scale_factor
        assert (
            img_size[0] % scale_factor == 0
        ), "scale_factor must be a factor of img_size"
        assert (
            img_size[1] % scale_factor == 0
        ), "scale_factor must be a factor of img_size"

        self.encoder = SFNO(
            filter_type=filter_type,
            spectral_transform=spectral_transform,
            operator_type=operator_type,
            img_size=img_size,
            scale_factor=scale_factor,
            in_chans=in_chans,
            out_chans=out_chans,
            embed_dim=embed_dim,
            num_layers=num_layers,
            encoder_layers=encoder_layers,
            spectral_layers=spectral_layers,
            env_cov=env_cov,
        )

        self.cross_attn = CrossAttn(
            dim=embed_dim,
            dim_text=text_dim,
            scale_factor=scale_factor,
            heads=attn_heads,
            dim_head=attn_dim_head,
        )

        self.upsample = nn.Upsample(
            scale_factor=scale_factor, mode="bilinear", align_corners=False
        )

        # Bias-free 1x1 projection from embed_dim channels to a single map.
        # (Previously spelled ``nn.Covn2d``, which does not exist in torch.nn.)
        self.out = nn.Conv2d(embed_dim, 1, 1, bias=False)

    def forward(self, text, image=None):
        """Return the single-channel map predicted from ``text`` and ``image``.

        Args:
            text: Text features handed to ``CrossAttn`` as its second input.
            image: Input handed to the ``SFNO`` encoder. Defaults to ``None``.
                NOTE(review): whether the encoder accepts ``None`` depends on
                ``SFNO`` (presumably tied to ``env_cov``) -- confirm.

        Returns:
            Output of the 1x1 convolution applied to the upsampled grid.
        """
        x = self.encoder(image)
        x = self.cross_attn(x, text)
        # Axis lengths must be integers; true division would yield a float.
        # Divisibility is guaranteed by the checks in __init__.
        grid_height = self.img_size[0] // self.scale_factor
        x = rearrange(x, "b (h w) d -> b d h w", h=grid_height)
        x = self.upsample(x)
        x = self.out(x)
        return x


class DinoV2Model(nn.Module):
    """Randomly initialised DINOv2 backbone fused with text features.

    ``forward`` takes the backbone's patch tokens, lays them out as a 2-D
    grid, passes them with ``text`` through ``CrossAttnImage``, reshapes the
    result back to a grid, upsamples by ``patch_size`` and projects to one
    channel with a bias-free 1x1 convolution.

    Args:
        patch_size: Patch size given to ``Dinov2Config``; also the decoder's
            upsampling factor and the divisor used to obtain the token grid
            height.
        img_size: Passed to ``Dinov2Config`` as the ``img_size`` keyword.
            NOTE(review): the Hugging Face config documents ``image_size``,
            not ``img_size`` -- confirm this keyword has the intended effect.
            Left unchanged because renaming it would alter the shapes of the
            position-embedding parameters.
        num_channels, num_hidden_layers, num_attention_heads, hidden_size:
            Passed to ``Dinov2Config``. ``hidden_size`` is also the first
            argument of ``CrossAttnImage`` and the input channel count of the
            output convolution.
    """

    def __init__(
        self,
        patch_size=18,
        img_size=(900, 1800),
        num_channels=20,
        num_hidden_layers=12,
        num_attention_heads=6,
        hidden_size=384,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.config = Dinov2Config(
            patch_size=patch_size,
            img_size=img_size,
            num_channels=num_channels,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            hidden_size=hidden_size,
        )
        # Built from the config only (no pretrained weights) and put in
        # training mode.
        self.model = Dinov2Model(self.config).train()
        self.cross_attn = CrossAttnImage(hidden_size, 8192)
        self.upsample = nn.Upsample(scale_factor=patch_size, mode="bilinear")
        self.out = nn.Conv2d(hidden_size, 1, 1, bias=False)

    def forward(self, image, text):
        """Return the single-channel map predicted from ``image`` and ``text``.

        Args:
            image: Backbone input; its second-to-last dimension is treated as
                the image height.
            text: Text features handed to ``CrossAttnImage``.

        Returns:
            Output of the 1x1 convolution applied to the upsampled grid.
        """
        # Rows of the patch grid. Derived from the actual input instead of the
        # former hard-coded 50 (= 900 // 18), so other sizes reshape correctly.
        # Assumes the backbone emits image_height // patch_size rows of patches.
        grid_height = image.shape[-2] // self.patch_size

        # Drop the first token (the class token) and keep only patch tokens.
        x = self.model(image).last_hidden_state[:, 1:, :]
        x = rearrange(x, "b (h w) d -> b d h w", h=grid_height)
        x = self.cross_attn(x, text)
        # The second reshape implies cross_attn returns a flat token sequence
        # with the same number of tokens as the grid it was given.
        x = rearrange(x, "b (h w) d -> b d h w", h=grid_height)
        x = self.upsample(x)
        x = self.out(x)

        return x


class SegformerModel(nn.Module):
    """Randomly initialised SegFormer fused with text features.

    ``forward`` feeds the segmentation logits and ``text`` through
    ``CrossAttnImage``, reshapes the result to a 2-D grid, upsamples by 4 and
    projects to one channel with a bias-free 1x1 convolution.

    Args:
        num_channels: Input channel count given to ``SegformerConfig``.
        num_labels: Number of logit channels produced by SegFormer; also the
            first argument of ``CrossAttnImage`` and the input channel count
            of the output convolution.
    """

    def __init__(self, num_channels=20, num_labels=128):
        super().__init__()
        # Honour the caller's num_labels. It was previously fixed at 128 here
        # while the layers below used the argument, so any other value
        # produced mismatched channel counts.
        self.config = SegformerConfig(
            num_channels=num_channels, num_labels=num_labels
        )
        # Built from the config only (no pretrained weights) and put in
        # training mode.
        self.model = SegformerForSemanticSegmentation(self.config).train()
        self.cross_attn = CrossAttnImage(num_labels, 8192)
        self.upsample = nn.Upsample(scale_factor=4, mode="bilinear")
        self.out = nn.Conv2d(num_labels, 1, 1, bias=False)

    def forward(self, image, text):
        """Return the single-channel map predicted from ``image`` and ``text``.

        Args:
            image: Input passed to the SegFormer model.
            text: Text features handed to ``CrossAttnImage``.

        Returns:
            Output of the 1x1 convolution applied to the upsampled grid.
        """
        logits = self.model(image).logits
        # Grid height taken from the logits themselves rather than the former
        # hard-coded 225 (the logit height for 900-row inputs). Assumes
        # cross_attn returns one token per logit position -- the reshape below
        # relies on that, as it did before.
        grid_height = logits.shape[-2]
        x = self.cross_attn(logits, text)
        x = rearrange(x, "b (h w) d -> b d h w", h=grid_height)
        x = self.upsample(x)
        x = self.out(x)

        return x


# class ResNet50Model(nn.Module):
#     def __init__(self, img_size=(900, 1800), num_channels=20, embedding_size=128):
#         super(ResNet50Model, self).__init__()
#         self.config = ResNetConfig(num_channels=20, embedding_size=128, hidden_sizes=[128, 128, 128, 128])
#         self.model = ResNetModel(self.config).train()
#         self.cross_attn = CrossAttnImage(128, 8192)
#         self.upsample = nn.Upsample(img_size, mode='bilinear')
#         self.out = nn.Conv2d(128, 1, 1, bias=False)

#     def forward(self, image, text):
#         x = self.model(image).last_hidden_state
#         x = self.cross_attn(x, text)
#         x = rearrange(x, 'b (h w) d -> b d h w', h=225)
#         x = self.upsample(x)
#         x = self.out(x)

#         return x


class LightningRangeModel(pl.LightningModule):
    """Lightning wrapper that trains a ``RangeModel`` with a selectable loss.

    Args:
        filter_type, spectral_transform, operator_type, img_size,
        scale_factor, in_chans, out_chans, embed_dim, num_layers,
        encoder_layers, spectral_layers, env_cov, attn_heads, attn_dim_head,
        text_dim:
            Forwarded unchanged to ``RangeModel``. ``env_cov`` additionally
            controls how batches are unpacked in ``shared_step``.
            NOTE(review): ``in_chans`` defaults to 12 here but to 20 in
            ``RangeModel`` -- confirm which is intended.
        loss_type: One of ``"RAL"``, ``"ASL"``, ``"ME-full"``, ``"AN-full"``.
        gamma_neg, gamma_pos: Passed to the ``"RAL"`` / ``"ASL"`` losses.
        alpha: Passed to the ``"ME-full"`` / ``"AN-full"`` losses.
        mask: Passed to the loss as ``mask=`` on every step.
        lr: Learning rate for the Adam optimizer.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
    """

    def __init__(
        self,
        filter_type="non-linear",
        spectral_transform="sht",
        operator_type="vector",
        img_size=(900, 1800),
        scale_factor=4,
        in_chans=12,
        out_chans=1,
        embed_dim=128,
        num_layers=2,
        encoder_layers=1,
        spectral_layers=2,
        env_cov=False,
        attn_heads=16,
        attn_dim_head=64,
        text_dim=4096,
        loss_type="RAL",
        gamma_neg=4,
        gamma_pos=2,
        alpha=10,
        mask=None,
        lr=1e-5,
    ):
        super().__init__()
        self.model = RangeModel(
            filter_type=filter_type,
            spectral_transform=spectral_transform,
            operator_type=operator_type,
            img_size=img_size,
            scale_factor=scale_factor,
            in_chans=in_chans,
            out_chans=out_chans,
            embed_dim=embed_dim,
            num_layers=num_layers,
            encoder_layers=encoder_layers,
            spectral_layers=spectral_layers,
            env_cov=env_cov,
            attn_heads=attn_heads,
            attn_dim_head=attn_dim_head,
            text_dim=text_dim,
        )
        self.env_cov = env_cov
        self.mask = mask
        self.lr = lr
        if loss_type == "RAL":
            self.loss = Ralloss(gamma_neg=gamma_neg, gamma_pos=gamma_pos)
        elif loss_type == "ASL":
            self.loss = AsymmetricLoss(gamma_neg=gamma_neg, gamma_pos=gamma_pos)
        elif loss_type == "ME-full":
            self.loss = MEfullloss(alpha=alpha)
        elif loss_type == "AN-full":
            self.loss = ANfullloss(alpha=alpha)
        else:
            # Fail at construction: otherwise self.loss stays undefined and
            # the first training step dies with an unrelated AttributeError.
            raise ValueError(
                f"Unknown loss_type {loss_type!r}; expected one of "
                "'RAL', 'ASL', 'ME-full', 'AN-full'"
            )

    def forward(self, text, image=None):
        """Delegate to the wrapped ``RangeModel`` and return its output."""
        return self.model(text, image)

    def shared_step(self, batch, batch_idx):
        """Compute the loss for one batch.

        Args:
            batch: ``(image, text, target)`` when ``env_cov`` is truthy,
                otherwise ``(text, target)``.
            batch_idx: Unused.

        Returns:
            The value returned by the configured loss.
        """
        if self.env_cov:
            image, text, target = batch
        else:
            # Without environmental covariates the batch carries no image.
            text, target = batch
            image = None
        output = self.forward(text, image)
        loss = self.loss(output, target, mask=self.mask)
        return loss

    def training_step(self, batch, batch_idx):
        """Run ``shared_step``, log the result as ``train_loss`` and return it."""
        loss = self.shared_step(batch, batch_idx)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with rate ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
