from typing import Any, Dict, List, Sequence, Tuple

import torch
from icecream import ic
from torch import nn
from torch.nn import functional as F

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder, PositionEmbeddingRandom
from .myTransformer import MyWayTransformer


class Sam(nn.Module):
    """SAM-based model that derives its own prompts for three mask branches.

    Instead of user-supplied points/boxes, prompts are generated from the image:

    1. ``coarsehead`` predicts a 3-channel coarse map from the image embedding
       (channel 0 -> BG, 1 -> SOD, 2 -> COD).
    2. Each sigmoid-activated channel gates the image embedding; the gated
       features are 2x2 max-pooled and projected by a per-branch ``Linear``
       into 4 prompt tokens of width equal to the embedding channel count.
    3. The 12 dynamic tokens are added to 12 learned static tokens and passed,
       together with the image embedding and its positional encoding, through
       ``MyTransformer``.
    4. Each branch's 4 output tokens are used as sparse prompts for the shared
       ``mask_decoder``.

    NOTE(review): BG/SOD/COD presumably mean background / salient-object /
    camouflaged-object — confirm.
    """

    # Not referenced within this class; kept as public class attributes.
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
            self,
            OneWayTransformer: nn.Module,
            image_encoder: ImageEncoderViT,
            prompt_encoder: PromptEncoder,
            mask_decoder: MaskDecoder,
            pixel_mean: Sequence[float] = (123.675, 116.28, 103.53),
            pixel_std: Sequence[float] = (58.395, 57.12, 57.375),
    ) -> None:
        """Assemble the model.

        Args:
            OneWayTransformer: Accepted for call-site compatibility but NOT
                used; a ``MyWayTransformer`` is always built internally.
            image_encoder: Backbone producing the image embedding. Its
                ``img_size`` attribute is the padded square input size.
            prompt_encoder: Queried with no points/boxes/masks for its dense
                embedding, and for ``get_dense_pe()``.
            mask_decoder: Shared decoder, run once per branch.
            pixel_mean: Per-channel mean subtracted in ``preprocess``.
            pixel_std: Per-channel std divided out in ``preprocess``.
                Both defaults are tuples so they can never be mutated and
                shared across instances; any float sequence is accepted.
        """
        super().__init__()

        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        # Shaped (C, 1, 1) to broadcast over (..., C, H, W) images. Not stored
        # in the state_dict (persistent=False).
        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), persistent=False
        )
        self.register_buffer(
            "pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), persistent=False
        )

        # APG modules. Construction order is kept fixed because it determines
        # parameter registration order (optimizer groups, state_dict layout)
        # and the RNG sequence used for initialization.
        self.MyTransformer = MyWayTransformer(
            depth=2,
            embedding_dim=256,
            mlp_dim=2048,
            num_heads=8,
        )

        # 256-channel embedding -> 3 coarse logits (BG, SOD, COD).
        self.coarsehead = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 3, kernel_size=3, padding=1),
        )
        self.BG_prompt_downsample = nn.MaxPool2d(kernel_size=2, stride=2)
        self.COD_prompt_downsample = nn.MaxPool2d(kernel_size=2, stride=2)
        self.SOD_prompt_downsample = nn.MaxPool2d(kernel_size=2, stride=2)
        # 121 = 11 * 11: the pooled grid of a 22x22 embedding; 4 tokens per branch.
        self.BG_prompt_fc = nn.Linear(121, 4)
        self.COD_prompt_fc = nn.Linear(121, 4)
        self.SOD_prompt_fc = nn.Linear(121, 4)
        # 12 = 3 branches x 4 tokens, added to the dynamic tokens.
        self.static_token_embedding = nn.Embedding(12, 256)
        self.pe_layer = PositionEmbeddingRandom(256 // 2)

    @property
    def device(self) -> Any:
        """Device on which the model's buffers live."""
        return self.pixel_mean.device

    def forward(self, batched_input, multimask_output, image_size):
        """Alias for :meth:`forward_train`; see it for arguments and outputs."""
        return self.forward_train(batched_input, multimask_output, image_size)

    def forward_train(self, batched_input, multimask_output, image_size):
        """Run the coarse-map -> prompt-token -> mask-decoding pipeline.

        Args:
            batched_input: Image batch passed to ``preprocess`` (presumably
                raw-pixel (B, 3, H, W) — confirm with caller).
            multimask_output: Forwarded unchanged to ``mask_decoder``.
            image_size: Side length of the square ``'masks'`` returned per branch.

        Returns:
            ``(coarse_map_out, bg_outputs, sod_outputs, cod_outputs)``:
            ``coarse_map_out`` holds the 3-channel coarse-head logits (before
            sigmoid), and each ``*_outputs`` is a dict with ``'masks'``
            (resized by ``postprocess_masks``) and ``'low_res_logits'`` (raw
            decoder output).
        """
        input_images = self.preprocess(batched_input)
        image_embeddings = self.image_encoder(input_images)
        bs = image_embeddings.size(0)
        grid_h, grid_w = image_embeddings.shape[-2:]

        # Each sigmoid channel of the coarse map gates the embedding for one branch.
        coarse_map_out = self.coarsehead(image_embeddings)
        coarse_map = torch.sigmoid(coarse_map_out)
        bg_prompt = self._region_prompt(
            image_embeddings, coarse_map[:, 0:1], self.BG_prompt_downsample, self.BG_prompt_fc
        )
        sod_prompt = self._region_prompt(
            image_embeddings, coarse_map[:, 1:2], self.SOD_prompt_downsample, self.SOD_prompt_fc
        )
        cod_prompt = self._region_prompt(
            image_embeddings, coarse_map[:, 2:3], self.COD_prompt_downsample, self.COD_prompt_fc
        )

        dynamic_tokens = torch.cat((bg_prompt, sod_prompt, cod_prompt), dim=1)
        static_tokens = self.static_token_embedding.weight.unsqueeze(0).expand(bs, -1, -1)
        tokens = dynamic_tokens + static_tokens

        # Positional encoding for the embedding grid, repeated per batch item.
        image_pe = self.pe_layer([grid_h, grid_w]).unsqueeze(0)
        pos_src = torch.repeat_interleave(image_pe, bs, dim=0)

        # NOTE(review): the transformer's second output (the refined image
        # embedding) is discarded; the decoder below consumes the raw encoder
        # embedding. Confirm this is intended.
        hs, _ = self.MyTransformer(image_embeddings, pos_src, tokens)

        n = bg_prompt.shape[1]  # tokens per branch
        bg_tokens = hs[:, :n, :]
        sod_tokens = hs[:, n:2 * n, :]
        cod_tokens = hs[:, 2 * n:, :]

        # Only the dense embedding is used; the generated tokens replace the
        # sparse prompt the prompt encoder would otherwise provide.
        _, dense_embeddings = self.prompt_encoder(
            bs=bs,
            points=None, boxes=None, masks=None
        )

        bg_outputs = self._decode_branch(
            image_embeddings, bg_tokens, dense_embeddings, multimask_output, image_size
        )
        sod_outputs = self._decode_branch(
            image_embeddings, sod_tokens, dense_embeddings, multimask_output, image_size
        )
        cod_outputs = self._decode_branch(
            image_embeddings, cod_tokens, dense_embeddings, multimask_output, image_size
        )
        return coarse_map_out, bg_outputs, sod_outputs, cod_outputs

    @staticmethod
    def _region_prompt(image_embeddings, attention, downsample, fc):
        """Turn an attention-gated embedding into ``fc.out_features`` prompt tokens.

        Args:
            image_embeddings: Encoder output, (B, C, H, W).
            attention: Gate broadcastable to ``image_embeddings``; here one
                sigmoid channel of the coarse map, (B, 1, H, W).
            downsample: Pooling applied to the gated features.
            fc: Linear layer mapping the flattened pooled grid to token slots.

        Returns:
            Tensor of shape (B, fc.out_features, C).
        """
        bs, channels = image_embeddings.shape[:2]
        pooled = downsample(image_embeddings * attention)
        projected = fc(pooled.flatten(2))  # (B, C, out_features)
        # NOTE(review): this is a reshape, not a transpose, of (B, C, K) into
        # (B, K, C). As a result, token k is built from channels
        # k*C/K .. (k+1)*C/K - 1 (all K projections of each), not from all
        # channels. `.transpose(1, 2)` is likely what was meant, but switching
        # would change the outputs of already-trained checkpoints.
        return projected.reshape(bs, fc.out_features, channels)

    def _decode_branch(self, image_embeddings, sparse_prompt, dense_embeddings,
                       multimask_output, image_size):
        """Decode one branch's masks from its prompt tokens.

        Returns:
            Dict with ``'masks'`` (``image_size`` x ``image_size`` logits from
            ``postprocess_masks``) and ``'low_res_logits'`` (raw decoder output).
        """
        low_res_masks = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_prompt,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        masks = self.postprocess_masks(
            low_res_masks,
            input_size=(image_size, image_size),
            original_size=(image_size, image_size),
        )
        return {
            'masks': masks,
            'low_res_logits': low_res_masks
        }

    def postprocess_masks(
            self,
            masks: torch.Tensor,
            input_size: Tuple[int, ...],
            original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """Upscale decoder logits to the requested output resolution.

        The logits are first resized to the encoder's padded square input
        (``img_size`` x ``img_size``). The padding added by ``preprocess`` is
        then cropped off using ``input_size``, and the result is resized to
        ``original_size``.

        Args:
            masks: Decoder logits; bilinear interpolation requires 4-D input,
                so (B, N, h, w).
            input_size: (H, W) of the image before padding.
            original_size: (H, W) of the returned masks.

        Returns:
            Logits of shape (B, N, *original_size).
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        # The crop can only reach ``img_size``, so a larger ``original_size``
        # needs this final resize. When the crop already has the target size
        # (the common case), the resize is skipped and the result is unchanged.
        target = tuple(original_size)
        if tuple(masks.shape[-2:]) != target:
            masks = F.interpolate(masks, target, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input.

        ``x`` is normalized per channel with ``pixel_mean`` / ``pixel_std`` and
        zero-padded on the bottom/right up to ``image_encoder.img_size``.

        NOTE(review): an input larger than ``img_size`` yields negative padding,
        which ``F.pad`` applies as a silent crop.
        """
        x = (x - self.pixel_mean) / self.pixel_std
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x