from functools import partial
import numpy as np
import torch
import torch.nn as nn
import math
import random
from collections import defaultdict, Counter
from einops import rearrange, repeat
from torch import tensor

from vit import Block
from util.pos_embed import get_2d_sincos_pos_embed
from utils import trunc_normal_
import random
from aim import adaptive_inter_channel_masking
from torch.nn import functional as F


class PatchEmbedPerChannel(nn.Module):
    """Turn single-channel images into patch tokens, optionally tagged by channel id.

    ``proj`` is a ``Conv2d`` with exactly one input channel, so every sample fed
    to :meth:`forward` must be a one-channel image; callers fold the channel
    axis of a multi-channel batch into the batch axis beforehand.

    Args:
        img_size: Side length of the square input image (used for ``num_patches``).
        patch_size: Side length of each square patch; also the conv stride.
        in_chans: Number of distinct channel ids, i.e. rows of ``channel_embed``.
        mapper: Stored as ``self.mapper``; not read inside this class.
        embed_dim: Dimension of each output token.
        enable_sample: Stored as ``self.enable_sample``; not read inside this class.
        use_channelvit_channels: When True, ``channel_embed(channel_indices)`` is
            added to every patch token of the corresponding sample.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 8,
        mapper: dict = None,
        embed_dim: int = 768,
        enable_sample: bool = False,
        use_channelvit_channels: bool = True,
    ):
        super().__init__()

        self.img_size = img_size
        self.mapper = mapper
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.in_chans = in_chans

        # NOTE(review): channel_scale and channel_emb_proxies are never read by
        # forward(); they are kept (in this creation order) so existing
        # checkpoints load and parameter initialisation stays reproducible.
        self.channel_scale = np.sqrt(1.0 / 1000)
        self.channel_emb_proxies = torch.nn.Parameter(torch.randn(in_chans, embed_dim) / 8)
        nn.init.orthogonal_(self.channel_emb_proxies)

        # Non-overlapping patch projection of a single-channel image.
        self.proj = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)

        self.channel_embed = nn.Embedding(in_chans, embed_dim)
        nn.init.orthogonal_(self.channel_embed.weight)

        self.use_channelvit_channels = use_channelvit_channels
        self.enable_sample = enable_sample

    def forward(
        self,
        x,
        channel_indices,
        chunk_name: str = "train",
        training_chunks=None,
        new_channel_init=None,
        extra_tokens=None,
        **kwargs,
    ):
        """Embed ``x`` into a sequence of patch tokens.

        Args:
            x: 4-D tensor ``(N, 1, H, W)``; the single channel is required by ``proj``.
            channel_indices: Integer tensor of length ``N`` holding the channel id of
                each sample (looked up in ``channel_embed``).
            chunk_name, training_chunks, new_channel_init, extra_tokens, **kwargs:
                Accepted for interface compatibility; unused here.

        Returns:
            ``(tokens, 2, 0)`` where ``tokens`` has shape
            ``(N, num_patches, embed_dim)``. The constant ``2`` and ``0`` occupy the
            channel-count and auxiliary-loss slots that callers unpack.

        Raises:
            ValueError: If ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D (N, 1, H, W) input, got shape {tuple(x.shape)}")

        tokens = self.proj(x)  # (N, embed_dim, H // p, W // p)

        if self.use_channelvit_channels:
            # Broadcast one embedding per sample over all spatial positions.
            tokens = tokens + self.channel_embed(channel_indices)[:, :, None, None]

        tokens = tokens.flatten(2).transpose(1, 2)  # (N, num_patches, embed_dim)
        return tokens, 2, 0


class MaskedAutoencoderViT(nn.Module):
    """Masked-autoencoder ViT whose encoder sequences are pairs of image channels.

    For an input batch ``(B, C, H, W)`` every (sample, channel) pair becomes one
    sequence: a CLS token, the patch tokens of that channel, then the patch
    tokens of a different, randomly drawn channel of the same sample (see
    :meth:`prepare_tokens`). The decoder predicts single-channel patches for
    both halves and :meth:`forward_loss` scores them on masked positions.

    Args:
        img_size: Sequence whose first element is the square input side length.
        patch_size: Patch side length.
        in_chans: Number of distinct channel ids (rows of the channel embeddings).
        mapper: Forwarded to ``PatchEmbedPerChannel``.
        embed_dim, depth, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate,
        attn_drop_rate, norm_layer: Encoder configuration passed to ``Block``.
        enable_sample, use_channelvit_channels: Forwarded to ``PatchEmbedPerChannel``.
        decoder_embed_dim, decoder_depth, decoder_num_heads: Decoder configuration.
        norm_pix_loss: Normalise each target patch by its own mean/variance.
        **kwargs: Extra keyword arguments passed to every encoder ``Block``.
    """

    def __init__(
        self,
        img_size=(224,),
        patch_size=16,
        in_chans=8,
        mapper: dict = None,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        norm_layer=nn.LayerNorm,
        enable_sample=False,
        use_channelvit_channels=True,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        norm_pix_loss=False,
        **kwargs,
    ):
        super().__init__()
        # drop_path_rate is fixed at 0.0, so every encoder block gets drop_path=0.0.
        drop_path_rate = 0.0
        self.num_features = self.embed_dim = self.out_dim = embed_dim
        self.in_chans = in_chans

        # ---------------- encoder ----------------
        self.patch_embed = PatchEmbedPerChannel(
            img_size=img_size[0],
            patch_size=patch_size,
            mapper=mapper,
            in_chans=in_chans,
            embed_dim=embed_dim,
            enable_sample=enable_sample,
            use_channelvit_channels=use_channelvit_channels,
        )
        num_patches = self.patch_embed.num_patches  # patches of ONE channel
        self.patch_size = patch_size

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One CLS slot followed by a full patch grid for each of the two paired channels.
        self.pos_embed = nn.Parameter(torch.zeros(1, 2 * num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    **kwargs,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # ---------------- decoder ----------------
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_channel_embed = nn.Embedding(in_chans, decoder_embed_dim)
        nn.init.orthogonal_(self.decoder_channel_embed.weight)

        # Frozen positional table (no CLS slot); filled by initialize_mae_weights().
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, 2 * num_patches, decoder_embed_dim),
            requires_grad=False,
        )
        # Each decoder token predicts one single-channel patch of patch_size**2 pixels.
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * 1, bias=True)

        self.decoder_blocks = nn.ModuleList(
            [
                Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
                for _ in range(decoder_depth)
            ]
        )
        self.decoder_norm = norm_layer(decoder_embed_dim)

        self.norm_pix_loss = norm_pix_loss

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)
        self.initialize_mae_weights()

    def _init_weights(self, m):
        """Initialise Linear (trunc-normal weight, zero bias) and LayerNorm (1, 0) modules."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def initialize_mae_weights(self):
        """Fill the decoder positional table with 2-D sin-cos embeddings and init the mask token.

        The same single-channel table is used for both halves of the decoder sequence.
        """
        patches_per_channel = (self.patch_embed.img_size // self.patch_size) ** 2

        single_channel_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1],
            int(patches_per_channel**.5),
            cls_token=False,
        )
        decoder_pos_embed = np.concatenate([single_channel_pos_embed, single_channel_pos_embed], axis=0)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        torch.nn.init.normal_(self.mask_token, std=.02)

    def interpolate_pos_encoding(self, x, w, h):
        """Return ``pos_embed`` resized to the token grid implied by a ``w`` x ``h`` image.

        Each channel's half of the positional table is bicubically resampled on its
        own square grid; the CLS slot is kept as is. The stored table is returned
        unchanged when the token count already matches and ``w == h``.
        """
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed

        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]

        dim = x.shape[-1]
        # +0.1 guards against floating-point rounding in the scale factor below.
        w0 = w // self.patch_embed.patch_size + 0.1
        h0 = h // self.patch_embed.patch_size + 0.1

        patches_per_channel = N // 2
        sqrt_N = int(math.sqrt(patches_per_channel))

        def _resize(grid):
            # (1, sqrt_N*sqrt_N, dim) -> (1, dim, sqrt_N, sqrt_N) -> resample -> (1, n, dim)
            return nn.functional.interpolate(
                grid.reshape(1, sqrt_N, sqrt_N, dim).permute(0, 3, 1, 2),
                scale_factor=(w0 / sqrt_N, h0 / sqrt_N),
                mode="bicubic",
            ).permute(0, 2, 3, 1).reshape(1, -1, dim)

        patch_pos_embed = torch.cat(
            [
                _resize(patch_pos_embed[:, :patches_per_channel]),
                _resize(patch_pos_embed[:, patches_per_channel:]),
            ],
            dim=1,
        )
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x, chunk: str, training_chunks_str, new_channel_init, extra_tokens):
        """Build one CLS + two-channel token sequence per (sample, channel) pair.

        Args:
            x: Image batch ``(B, C, H, W)`` with ``C >= 2``.
            chunk, training_chunks_str, new_channel_init, extra_tokens: Passed
                through to ``patch_embed``.

        Returns:
            ``(tokens, ortho_proxy_loss, channel_indices)``. ``tokens`` has shape
            ``(B*C, 1 + 2*num_patches, embed_dim)``. ``channel_indices`` is a flat
            tensor of length ``2*B*C`` laid out as ``[orig_0, other_0, orig_1, ...]``.

        Raises:
            ValueError: If ``C < 2`` (no second channel to pair with).
        """
        B, C, H, W = x.shape
        if C < 2:
            raise ValueError(f"prepare_tokens needs at least 2 channels to form pairs, got C={C}")

        # Draw an offset in [0, C-2] and shift it past the original channel: the result
        # is a uniformly random channel that always differs from the original one.
        random_offsets = torch.randint(0, C - 1, (B * C,), device=x.device)
        orig_channels = torch.arange(C, device=x.device).repeat(B)
        additional_channels = torch.where(random_offsets < orig_channels, random_offsets, random_offsets + 1)

        x_original = x.reshape(B * C, 1, H, W)
        x_additional = torch.gather(
            x, 1, additional_channels.view(B, C, 1, 1).expand(-1, -1, H, W)
        ).reshape(B * C, 1, H, W)

        # Interleave to (orig_0, other_0, orig_1, other_1, ...), one channel per image,
        # matching the order of all_channel_indices.
        x_pairs = torch.cat([x_original, x_additional], dim=1).view(B * C * 2, 1, H, W)
        all_channel_indices = torch.stack([orig_channels, additional_channels], dim=1).flatten()

        x_embedded, nc, ortho_proxy_loss = self.patch_embed(
            x_pairs, all_channel_indices, chunk, training_chunks_str, new_channel_init, extra_tokens
        )

        # Regroup so each sequence holds its two channels back to back.
        batch_samples = B * C
        num_patches = x_embedded.shape[1]
        x_embedded = x_embedded.reshape(batch_samples, 2 * num_patches, self.embed_dim)

        cls_tokens = self.cls_token.expand(batch_samples, -1, -1)
        x_embedded = torch.cat((cls_tokens, x_embedded), dim=1)
        x_embedded = x_embedded + self.interpolate_pos_encoding(x_embedded, W, H)

        return self.pos_drop(x_embedded), ortho_proxy_loss, all_channel_indices

    def patchify(self, imgs):
        """Split square images ``(N, C, H, H)`` into patches ``(N, (H/p)**2, C*p*p)``."""
        p = self.patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        C = imgs.shape[1]

        x = imgs.reshape(imgs.shape[0], C, h, p, w, p)
        x = x.permute(0, 2, 4, 1, 3, 5)
        return x.reshape(imgs.shape[0], h * w, C * p * p)

    def unpatchify(self, x):
        """Inverse of :meth:`patchify` for single-channel patches: ``(N, h*w, p*p) -> (N, 1, H, W)``."""
        p = self.patch_size
        C = 1

        h = w = self.patch_embed.img_size // p
        assert h * w == x.shape[1], f"Mismatch: h*w={h*w}, x.shape[1]={x.shape[1]}"

        x = x.reshape(x.shape[0], h, w, C, p, p)
        x = x.permute(0, 3, 1, 4, 2, 5)
        return x.reshape(x.shape[0], C, h * p, w * p)

    def saliency_guided_masking(self, x, base_mask_ratio=0.75):
        """Keep the most self-affine patch tokens, with a data-dependent mask ratio.

        The mask ratio is ``base_mask_ratio`` shifted by up to +/-0.15 depending on
        how many patches have a normalised affinity above 0.1 (batch mean), then
        clamped to [0, 1].

        Args:
            x: Tokens ``(N, 1 + L, D)``; position 0 is the CLS token.
            base_mask_ratio: Centre of the dynamic mask ratio.

        Returns:
            ``(x_masked, mask, ids_restore)`` where ``x_masked`` keeps CLS plus the
            selected tokens, ``mask`` is ``(N, L)`` with 0 = kept / 1 = removed in
            original order, and ``ids_restore`` undoes the shuffle.
        """
        N, _, D = x.shape

        cls_token = x[:, :1, :]
        x_patches = x[:, 1:, :]
        L = x_patches.shape[1]

        aff = F.softmax(torch.matmul(x_patches, x_patches.transpose(1, 2)), dim=2)
        aff_sum = aff.sum(dim=1)
        aff_min = aff_sum.min(dim=1, keepdim=True)[0]
        aff_max = aff_sum.max(dim=1, keepdim=True)[0]
        aff_sum_normalized = (aff_sum - aff_min) / (aff_max - aff_min + 1e-8)

        y_normalized = (aff_sum_normalized > 0.1).sum(dim=1).float().mean() / L
        dynamic_mask_ratio = base_mask_ratio - 0.15 + 2 * 0.15 * y_normalized
        dynamic_mask_ratio = torch.clamp(dynamic_mask_ratio, 0.0, 1.0)

        len_keep = int(L * (1 - dynamic_mask_ratio))

        # Saliency plus a little noise; highest scores are kept.
        noise = torch.rand(N, L, device=x.device) / 2
        ids_shuffle = torch.argsort(aff_sum_normalized + noise, dim=1, descending=True)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_kept_patches = torch.gather(x_patches, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        x_masked = torch.cat([cls_token, x_kept_patches], dim=1)

        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def random_masking(self, x, mask_ratio, mask_type='channel', num_channels=2):
        """Drop patch tokens either uniformly at random or one whole channel at a time.

        Args:
            x: Tokens ``(N, 1 + L, D)``; position 0 is the CLS token.
            mask_ratio: Fraction of tokens removed in ``'random'`` mode.
            mask_type: ``'random'`` or ``'channel'``. In ``'channel'`` mode exactly one
                randomly chosen channel's ``L // num_channels`` tokens are kept.
            num_channels: Number of equal-sized channel blocks in the sequence.

        Returns:
            ``(x_masked, mask, ids_restore)``. ``mask`` is ``(N, L)`` in original
            token order (0 = kept, 1 = removed); ``ids_restore`` satisfies
            ``cat([kept, fillers]).gather(1, ids_restore) -> original order``.

        Raises:
            ValueError: On an unknown ``mask_type``.
        """
        N, L_plus_1, D = x.shape
        cls_tokens = x[:, :1]
        x_patches = x[:, 1:]
        L = L_plus_1 - 1
        tokens_per_channel = L // num_channels

        if mask_type == 'random':
            len_keep = int(L * (1 - mask_ratio))
            noise = torch.rand(N, L, device=x.device)
            ids_shuffle = torch.argsort(noise, dim=1)
        elif mask_type == 'channel':
            len_keep = tokens_per_channel
            keep_channels = torch.randint(0, num_channels, (N,), device=x.device)
            positions = torch.arange(L, device=x.device)
            token_channel = positions // max(tokens_per_channel, 1)
            # Sort key: kept-channel tokens first, then everything else, each group in
            # original order, so ids_shuffle/ids_restore describe a real permutation.
            is_removed = (token_channel.unsqueeze(0) != keep_channels.unsqueeze(1)).long()
            ids_shuffle = torch.argsort(is_removed * L + positions.unsqueeze(0), dim=1)
        else:
            raise ValueError("mask_type should be 'random' or 'channel'")

        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :len_keep]

        x_masked = torch.gather(x_patches, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        x_masked = torch.cat([cls_tokens, x_masked], dim=1)

        # Mask is built in shuffled order and unshuffled exactly once.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio, chunk_name="train"):
        """Tokenise ``x``, mask tokens and run the encoder.

        Returns ``(latent, mask, ids_restore, channel_indices)``; ``mask`` and
        ``ids_restore`` come from ``adaptive_inter_channel_masking``.
        """
        x, ortho_proxy_loss, channel_indices = self.prepare_tokens(x, chunk_name, None, None, {})

        # NOTE(review): mask_ratio is not forwarded; the masking helper is always
        # called with 0.5 — confirm this is intended.
        x, mask, ids_restore = adaptive_inter_channel_masking(x, 0.5)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore, channel_indices

    def forward_decoder(self, x, ids_restore, channel_indices):
        """Decode visible tokens into per-patch pixel predictions.

        The CLS token is dropped, mask tokens fill the removed positions (unshuffled
        with ``ids_restore``), then channel and fixed positional embeddings are added.

        Returns:
            Predictions of shape ``(batch_samples, ids_restore.shape[1], patch_size**2)``.
        """
        batch_samples = x.shape[0]

        x = self.decoder_embed(x[:, 1:])  # CLS token is not decoded
        num_tokens = ids_restore.shape[1]

        mask_tokens = self.mask_token.repeat(batch_samples, num_tokens - x.shape[1], 1)
        x_ = torch.cat([x, mask_tokens], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))

        # First half of the sequence belongs to the original channel, second to its partner.
        pairs = channel_indices.view(batch_samples, 2)
        patches_per_channel = num_tokens // 2
        ch1_embeds = self.decoder_channel_embed(pairs[:, 0]).unsqueeze(1).repeat(1, patches_per_channel, 1)
        ch2_embeds = self.decoder_channel_embed(pairs[:, 1]).unsqueeze(1).repeat(1, patches_per_channel, 1)
        x_ = x_ + torch.cat([ch1_embeds, ch2_embeds], dim=1)

        x_ = x_ + self.decoder_pos_embed

        for blk in self.decoder_blocks:
            x_ = blk(x_)
        x_ = self.decoder_norm(x_)

        return self.decoder_pred(x_)

    def forward_loss(self, imgs, pred, mask, channel_indices):
        """Mean squared reconstruction error over masked patches.

        Args:
            imgs: Original batch ``(B, C, H, W)``.
            pred: Decoder output ``(B*C, 2*num_patches, patch_size**2)``.
            mask: ``(B*C, 2*num_patches)``; 1 marks positions that count.
            channel_indices: Flat ``(2*B*C,)`` channel pairs from :meth:`prepare_tokens`.
        """
        B, C, H, W = imgs.shape

        pairs = channel_indices.view(B * C, 2)
        # Paired row n came from image n // C: index its two source channels directly
        # instead of materialising a (B*C, C, H, W) copy of the batch.
        sample_idx = torch.arange(B, device=imgs.device).repeat_interleave(C)
        ch1_imgs = imgs[sample_idx, pairs[:, 0]].unsqueeze(1)
        ch2_imgs = imgs[sample_idx, pairs[:, 1]].unsqueeze(1)

        target = torch.cat([self.patchify(ch1_imgs), self.patchify(ch2_imgs)], dim=1)

        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = ((pred - target) ** 2).mean(dim=-1)
        return (loss * mask).sum() / mask.sum()

    def forward(self, x, mask_ratio=0.75, chunk_name="train"):
        """Run encoder, decoder and loss; return ``(loss, pred, mask, channel_indices)``."""
        latent, mask, ids_restore, channel_indices = self.forward_encoder(x, mask_ratio, chunk_name)
        pred = self.forward_decoder(latent, ids_restore, channel_indices)
        loss = self.forward_loss(x, pred, mask, channel_indices)
        return loss, pred, mask, channel_indices

    def forward_features(self, x, chunk_name="train"):
        """Encode without masking and return CLS features reshaped to ``(B, C, -1)``."""
        B, C, _, _ = x.shape

        x, _, _ = self.prepare_tokens(x, chunk_name, None, None, {})

        for blk in self.blocks:
            # Blocks whose class name contains 'BlockV2' take extra args and return a pair.
            if 'BlockV2' in str(blk.__class__):
                x, _ = blk(x, pruning_method=None, nc=0)
            else:
                x = blk(x)

        x = self.norm(x)
        return x[:, 0].reshape(B, C, -1)

    def forward_classification(self, x, chunk_name="train"):
        """Return :meth:`forward_features`, passed through ``self.head`` when one is set."""
        x = self.forward_features(x, chunk_name)
        if hasattr(self, 'head') and self.head is not None:
            x = self.head(x)
        return x


def mae_vit_small_patch16_dec512d8b(**kwargs):
    """Build a small ``MaskedAutoencoderViT`` (patch 16, 384-dim, 12 blocks, 6 heads).

    Despite the ``dec512d8b`` suffix, the decoder built here is 384-dim with
    2 blocks and 6 heads. Extra keyword arguments go to the constructor.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        decoder_embed_dim=384,
        decoder_depth=2,
        decoder_num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


def mae_vit_large_patch16_dec512d8b(**kwargs):
    """Build a large ``MaskedAutoencoderViT`` (patch 16, 1024-dim, 24 blocks, 16 heads).

    Decoder: 512-dim, 8 blocks, 16 heads. Extra keyword arguments go to the constructor.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """Build a huge ``MaskedAutoencoderViT`` (patch 14, 1280-dim, 32 blocks, 16 heads).

    Decoder: 512-dim, 8 blocks, 16 heads. Extra keyword arguments go to the constructor.
    """
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


def mae_vit_base_patch16_dec512d8b(**kwargs):
    """Build a base ``MaskedAutoencoderViT`` (patch 16, 768-dim, 12 blocks, 12 heads).

    Decoder: 512-dim, 8 blocks, 16 heads. Extra keyword arguments go to the
    constructor. Defined here because the ``mae_vit_base_patch16`` alias below
    refers to it; without it, importing the module raised ``NameError``.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


# Short public names for the model factories.
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b