from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import trunc_normal_
from timm.models.resnet import Bottleneck, ResNet

from .frequency_loss import FrequencyLoss
from .swin_transformer import SwinTransformer
from .utils import get_2d_sincos_pos_embed
from .vision_transformer import VisionTransformer


class SwinTransformerForMFM(SwinTransformer):
    """Swin Transformer encoder that returns a 2-D feature map for MFM.

    The parent is built from ``**kwargs`` and must be headless
    (``num_classes == 0``). ``config.DATA.FILTER_TYPE`` is stored on the
    instance but is not read by any code in this class.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.filter_type = config.DATA.FILTER_TYPE
        assert self.num_classes == 0

    def forward(self, x, x_fft):
        """Encode ``x_fft`` and return features laid out as ``(B, C, H, W)``.

        Args:
            x: Accepted for interface parity with the other encoders in this
                file; it is not used here.
            x_fft: The image tensor that is actually patch-embedded and encoded.

        Returns:
            The normalised token sequence reshaped from ``(B, L, C)`` to
            ``(B, C, H, W)`` with ``H == W == int(L ** 0.5)``.
        """
        tokens = self.patch_embed(x_fft)
        if self.ape:
            tokens = tokens + self.absolute_pos_embed
        tokens = self.pos_drop(tokens)

        for stage in self.layers:
            tokens = stage(tokens)
        tokens = self.norm(tokens)

        # Fold the token axis back into a square spatial grid; the reshape
        # only succeeds when the token count L is a perfect square.
        batch, length, channels = tokens.shape
        side = int(length ** 0.5)
        return tokens.transpose(1, 2).reshape(batch, channels, side, side)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parent's no-weight-decay names plus ``'mask_token'``."""
        return super().no_weight_decay() | {'mask_token'}


class VisionTransformerForMFM(VisionTransformer):
    """ViT encoder for MFM that encodes the ``x_fft`` input.

    The parent is built from ``**kwargs`` and must be headless
    (``num_classes == 0``). When ``use_fixed_pos_emb`` is true, a frozen
    2-D sin-cos positional embedding (including a slot for the class token)
    is installed; in every case a positional embedding must exist once
    construction finishes.
    """

    def __init__(self, config, use_fixed_pos_emb=False, **kwargs):
        super().__init__(**kwargs)
        self.decoder_depth = config.MODEL.VIT.DECODER.DEPTH
        self.filter_type = config.DATA.FILTER_TYPE
        assert self.num_classes == 0

        if use_fixed_pos_emb:
            # The parent must not already have created a positional embedding.
            assert self.pos_embed is None
            num_patches = self.patch_embed.num_patches
            placeholder = torch.zeros(1, num_patches + 1, self.embed_dim)
            # Fixed sin-cos embedding: registered as a parameter but frozen.
            self.pos_embed = nn.Parameter(placeholder, requires_grad=False)
            sincos_table = get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1], int(num_patches ** .5), cls_token=True)
            self.pos_embed.data.copy_(
                torch.from_numpy(sincos_table).float().unsqueeze(0))
        assert self.pos_embed is not None

    def _trunc_normal_(self, tensor, mean=0., std=1.):
        """Fill ``tensor`` in place via ``trunc_normal_`` with bounds ``a=-std``, ``b=std``."""
        trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)

    def forward(self, x, x_fft):
        """Encode ``x_fft``; ``x`` is accepted but not used.

        Returns:
            If ``decoder_depth == 0``: patch features with the class token
            removed, reshaped to ``(B, C, H, W)`` where
            ``H == W == int(L ** 0.5)``. Otherwise the normalised token
            sequence, class token included.
        """
        tokens = self.patch_embed(x_fft)

        # cls-token expansion borrowed from Phil Wang's implementation, thanks
        cls_tokens = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        if self.pos_embed is not None:
            tokens = tokens + self.pos_embed
        tokens = self.pos_drop(tokens)

        shared_bias = None if self.rel_pos_bias is None else self.rel_pos_bias()
        for block in self.blocks:
            tokens = block(tokens, rel_pos_bias=shared_bias)
        tokens = self.norm(tokens)

        if self.decoder_depth != 0:
            return tokens

        # No decoder: drop the class token and fold patches into a square map.
        patches = tokens[:, 1:]
        batch, length, channels = patches.shape
        side = int(length ** 0.5)
        return patches.permute(0, 2, 1).reshape(batch, channels, side, side)


class ResNetForMFM(ResNet):
    """Headless ResNet encoder that extracts features from ``x_fft``.

    ``config.DATA.FILTER_TYPE`` is stored on the instance but is not read by
    any code in this class.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.filter_type = config.DATA.FILTER_TYPE
        assert self.num_classes == 0

    def forward(self, x, x_fft):
        """Return ``forward_features(x_fft)``; ``x`` is accepted but not used."""
        return self.forward_features(x_fft)


class ResNetVa(ResNet):
    """Headless ResNet encoder that extracts features from ``x`` itself.

    In contrast to ``ResNetForMFM`` in this file, the second input is ignored.
    ``config.DATA.FILTER_TYPE`` is stored on the instance but is not read by
    any code in this class.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.filter_type = config.DATA.FILTER_TYPE
        # NOTE(review): build_resnet() in this file forwards
        # config.MODEL.NUM_CLASSES, so this assertion only holds when that
        # setting is 0 -- confirm this is intended.
        assert self.num_classes == 0

    def forward(self, x, x_fft):
        """Return ``forward_features(x)``; ``x_fft`` is accepted but not used."""
        return self.forward_features(x)

class MFMStudent(nn.Module):
    """Wrap an encoder with frequency-domain input corruption.

    In student mode (``is_student=True``) the encoder receives the normalised
    clean image together with a normalised, frequency-masked copy of it. In
    teacher mode it receives the normalised clean image for both inputs.

    Args:
        encoder: Module called as ``encoder(x, x_second)``. For non-ResNet
            model types it must expose ``in_chans`` and ``patch_size``.
        config: Experiment configuration; the fields read are visible in
            ``__init__`` below.
        is_student: Selects student (corrupted input) or teacher behaviour.
    """

    def __init__(self, encoder, config, is_student=True):
        super().__init__()
        self.encoder = encoder
        self.is_student = is_student
        self.student_strategy = config.TRAIN.STUDENT_STRATEGY
        self.filter_type = config.DATA.FILTER_TYPE
        self.mask_radius1 = config.DATA.MASK_RADIUS1
        self.mask_radius2 = config.DATA.MASK_RADIUS2
        self.recover_target_type = config.MODEL.RECOVER_TARGET_TYPE

        freq_cfg = config.MODEL.FREQ_LOSS
        criterion = FrequencyLoss(
            loss_gamma=freq_cfg.LOSS_GAMMA,
            matrix_gamma=freq_cfg.MATRIX_GAMMA,
            patch_factor=freq_cfg.PATCH_FACTOR,
            ave_spectrum=freq_cfg.AVE_SPECTRUM,
            with_matrix=freq_cfg.WITH_MATRIX,
            log_matrix=freq_cfg.LOG_MATRIX,
            batch_matrix=freq_cfg.BATCH_MATRIX)
        # Move the loss to the GPU only when one exists, so the model can
        # still be constructed on CPU-only hosts. Behaviour on CUDA machines
        # is the same as the previous unconditional ``.cuda()`` call.
        if torch.cuda.is_available():
            criterion = criterion.cuda()
        # Stored for external use; ``forward`` below does not call it.
        self.criterion = criterion

        self.normalize_img = T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

        if config.MODEL.TYPE == 'resnet':
            self.in_chans = config.MODEL.RESNET.IN_CHANS
            self.patch_size = 1
        else:
            self.in_chans = self.encoder.in_chans
            self.patch_size = self.encoder.patch_size

    def frequency_transform(self, x, mask):
        """Apply ``mask`` to the centred 2-D spectrum of ``x``.

        Args:
            x: Image tensor; the FFT is taken over its last two dimensions.
            mask: Tensor broadcastable against the spectrum of ``x``; it is
                multiplied element-wise with the shifted spectrum.

        Returns:
            The real part of the inverse transform, clamped to ``[0, 1]``.
        """
        # Shift the zero frequency to the centre so the mask lines up with it.
        spectrum = torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
        masked = torch.fft.ifftshift(spectrum * mask, dim=(-2, -1))
        x_corrupted = torch.fft.ifft2(masked).real
        # presumably x holds un-normalised pixels in [0, 1] -- TODO confirm
        return torch.clamp(x_corrupted, min=0., max=1.)

    def forward(self, x, mask=None):
        """Normalise ``x`` and run the encoder.

        Args:
            x: Un-normalised image batch.
            mask: Frequency mask without a channel axis (one is inserted at
                dim 1 before it is applied). Required despite the ``None``
                default.

        Returns:
            ``(x_normalised, z)`` where ``z`` is the encoder output.
        """
        assert mask is not None
        x_clean = self.normalize_img(x)
        if not self.is_student:
            # Teacher path: the corrupted view is never consumed, so the FFT
            # round trip is skipped entirely.
            return x_clean, self.encoder(x_clean, x_clean)

        x_corrupted = self.frequency_transform(x, mask.unsqueeze(1))
        x_corrupted = self.normalize_img(x_corrupted)
        return x_clean, self.encoder(x_clean, x_corrupted)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the encoder's no-weight-decay names prefixed with ``'encoder.'``.

        Always returns a set (empty when the encoder has no such method).
        """
        if hasattr(self.encoder, 'no_weight_decay'):
            return {'encoder.' + name for name in self.encoder.no_weight_decay()}
        # ``{}`` would be an empty dict, not a set; keep the return type uniform.
        return set()

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the encoder's no-weight-decay keywords prefixed with ``'encoder.'``.

        Always returns a set (empty when the encoder has no such method).
        """
        if hasattr(self.encoder, 'no_weight_decay_keywords'):
            return {'encoder.' + name for name in self.encoder.no_weight_decay_keywords()}
        return set()


class MFMHeader(nn.Module):
    """Prediction head: 1x1 convolution followed by a pixel shuffle.

    The convolution maps ``num_features`` channels to
    ``encoder_stride ** 2 * 3`` channels, and ``nn.PixelShuffle`` then
    rearranges them into a 3-channel map upscaled by ``encoder_stride``.
    """

    def __init__(self, num_features, encoder_stride):
        super().__init__()
        self.num_features = num_features
        self.encoder_stride = encoder_stride

        projection = nn.Conv2d(
            in_channels=num_features,
            out_channels=encoder_stride ** 2 * 3,
            kernel_size=1,
        )
        upsample = nn.PixelShuffle(encoder_stride)
        # Kept as a Sequential named ``header`` so state_dict keys stay
        # ``header.0.weight`` / ``header.0.bias``.
        self.header = nn.Sequential(projection, upsample)

    def forward(self, x):
        """Project ``x`` and upsample it through the pixel shuffle."""
        return self.header(x)
    

class DistillationHeader(nn.Module):
    """Projection head: a single 1x1 convolution that keeps the channel count."""

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        projection = nn.Conv2d(
            in_channels=num_features,
            out_channels=num_features,
            kernel_size=1,
        )
        # Kept as a one-element Sequential named ``header`` so state_dict
        # keys stay ``header.0.weight`` / ``header.0.bias``.
        self.header = nn.Sequential(projection)

    def forward(self, x):
        """Apply the 1x1 projection to ``x``."""
        return self.header(x)


def build_mfm_student(config, is_student=True):
    """Build the MFM wrapper and its two heads for the configured backbone.

    Args:
        config: Experiment configuration. ``config.MODEL.TYPE`` selects the
            backbone: ``'swin'``, ``'vit'`` or ``'resnet'``.
        is_student: Forwarded unchanged to ``MFMStudent``.

    Returns:
        ``(model, mfm_header, distillation_header)``.

    Raises:
        NotImplementedError: If ``config.MODEL.TYPE`` is not a supported type.
    """
    model_cfg = config.MODEL
    model_type = model_cfg.TYPE

    if model_type == 'swin':
        swin_cfg = model_cfg.SWIN
        encoder_stride = 32
        encoder = SwinTransformerForMFM(
            config=config,
            img_size=config.DATA.IMG_SIZE,
            patch_size=swin_cfg.PATCH_SIZE,
            in_chans=swin_cfg.IN_CHANS,
            num_classes=0,
            embed_dim=swin_cfg.EMBED_DIM,
            depths=swin_cfg.DEPTHS,
            num_heads=swin_cfg.NUM_HEADS,
            window_size=swin_cfg.WINDOW_SIZE,
            mlp_ratio=swin_cfg.MLP_RATIO,
            qkv_bias=swin_cfg.QKV_BIAS,
            qk_scale=swin_cfg.QK_SCALE,
            drop_rate=model_cfg.DROP_RATE,
            drop_path_rate=model_cfg.DROP_PATH_RATE,
            ape=swin_cfg.APE,
            patch_norm=swin_cfg.PATCH_NORM,
            use_checkpoint=config.TRAIN.USE_CHECKPOINT,
        )
    elif model_type == 'vit':
        vit_cfg = model_cfg.VIT
        encoder_stride = 16
        encoder = VisionTransformerForMFM(
            config=config,
            img_size=config.DATA.IMG_SIZE,
            patch_size=vit_cfg.PATCH_SIZE,
            in_chans=vit_cfg.IN_CHANS,
            num_classes=0,
            embed_dim=vit_cfg.EMBED_DIM,
            depth=vit_cfg.DEPTH,
            num_heads=vit_cfg.NUM_HEADS,
            mlp_ratio=vit_cfg.MLP_RATIO,
            qkv_bias=vit_cfg.QKV_BIAS,
            drop_rate=model_cfg.DROP_RATE,
            drop_path_rate=model_cfg.DROP_PATH_RATE,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            init_values=vit_cfg.INIT_VALUES,
            use_abs_pos_emb=vit_cfg.USE_APE,
            use_fixed_pos_emb=vit_cfg.USE_FPE,
            use_rel_pos_bias=vit_cfg.USE_RPB,
            use_shared_rel_pos_bias=vit_cfg.USE_SHARED_RPB,
            use_mean_pooling=vit_cfg.USE_MEAN_POOLING,
        )
    elif model_type == 'resnet':
        resnet_cfg = model_cfg.RESNET
        encoder_stride = 32
        encoder = ResNetForMFM(
            config=config,
            block=Bottleneck,
            layers=resnet_cfg.LAYERS,
            in_chans=resnet_cfg.IN_CHANS,
            num_classes=0,
            drop_rate=model_cfg.DROP_RATE,
            drop_path_rate=model_cfg.DROP_PATH_RATE,
        )
    else:
        raise NotImplementedError(f"Unknown pre-train model: {model_type}")

    # Construction order (wrapper, MFM head, distillation head) is kept as-is
    # so that random parameter initialisation is drawn in the same sequence.
    feature_dim = encoder.num_features
    model = MFMStudent(encoder=encoder, config=config, is_student=is_student)
    mfm_header = MFMHeader(num_features=feature_dim, encoder_stride=encoder_stride)
    distillation_header = DistillationHeader(num_features=feature_dim)
    return model, mfm_header, distillation_header


def build_resnet(config):
    """Build a ``ResNetVa`` with ``Bottleneck`` blocks from ``config``.

    Args:
        config: Experiment configuration; reads ``MODEL.RESNET.LAYERS``,
            ``MODEL.RESNET.IN_CHANS``, ``MODEL.NUM_CLASSES``,
            ``MODEL.DROP_RATE`` and ``MODEL.DROP_PATH_RATE``.

    Returns:
        The constructed ``ResNetVa`` instance.
    """
    model_cfg = config.MODEL
    resnet_cfg = model_cfg.RESNET
    # NOTE(review): ResNetVa.__init__ in this file asserts num_classes == 0,
    # so this call only succeeds when MODEL.NUM_CLASSES is 0 -- confirm intent.
    return ResNetVa(
        config=config,
        block=Bottleneck,
        layers=resnet_cfg.LAYERS,
        in_chans=resnet_cfg.IN_CHANS,
        num_classes=model_cfg.NUM_CLASSES,
        drop_rate=model_cfg.DROP_RATE,
        drop_path_rate=model_cfg.DROP_PATH_RATE,
    )


