from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import trunc_normal_
from timm.models.resnet import Bottleneck, ResNet

from .frequency_loss import FrequencyLoss
from .swin_transformer import SwinTransformer
from .utils import get_2d_sincos_pos_embed
from .vision_transformer import VisionTransformer


class SwinTransformerForMFM(SwinTransformer):
    """Swin encoder for MFM that encodes only the frequency-corrupted input.

    ``forward(x, x_fft)`` ignores ``x`` and returns ``(feature_map, pooled)`` like the ViT and
    ResNet MFM encoders in this module, so it can be unpacked by ``MFMStudentMultiHead``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.filter_type = config.DATA.FILTER_TYPE
        assert self.num_classes == 0

    def forward(self, x, x_fft):
        """Encode ``x_fft``.

        Args:
            x: unused; accepted for call-signature parity with the other MFM encoders.
            x_fft: the frequency-filtered image batch that is actually encoded.

        Returns:
            ``(feature_map, pooled)``: the normalised final-stage tokens reshaped to
            ``(B, C, H, W)`` with ``H = W = int(sqrt(L))``, and their mean over tokens, ``(B, C)``.
        """
        # TODO: remove once filter types other than 'mfm' are settled; the encoder always reads x_fft.
        tokens = self.patch_embed(x_fft)
        if self.ape:
            tokens = tokens + self.absolute_pos_embed
        tokens = self.pos_drop(tokens)

        for layer in self.layers:
            tokens = layer(tokens)
        tokens = self.norm(tokens)  # (B, L, C)

        # Global average over tokens: the per-image vector the caller unpacks as ``z_cls``.
        pooled = tokens.mean(dim=1)

        feature_map = tokens.transpose(1, 2)
        batch, channels, length = feature_map.shape
        side = int(length ** 0.5)  # assumes a square token grid
        return feature_map.reshape(batch, channels, side, side), pooled

    @torch.jit.ignore
    def no_weight_decay(self):
        # NOTE(review): this class defines no ``mask_token``; the name is kept so optimizer
        # parameter grouping stays as before.
        return super().no_weight_decay() | {'mask_token'}


class VisionTransformerForMFM(VisionTransformer):
    """ViT encoder for MFM that encodes only the frequency-corrupted input.

    ``forward(x, x_fft)`` ignores ``x`` and returns ``(features, cls_feature)``. When
    ``config.MODEL.VIT.DECODER.DEPTH`` is 0, ``features`` is the patch-token sequence (CLS token
    removed) reshaped to a square ``(B, C, H, W)`` map; otherwise it is the full normalised token
    sequence, CLS token included.
    """

    def __init__(self, config, use_fixed_pos_emb=False, **kwargs):
        super().__init__(**kwargs)
        self.decoder_depth = config.MODEL.VIT.DECODER.DEPTH
        self.filter_type = config.DATA.FILTER_TYPE
        assert self.num_classes == 0

        if use_fixed_pos_emb:
            # A frozen 2-D sin-cos table (with a CLS slot) takes the place of a learnable embedding.
            assert self.pos_embed is None
            num_patches = self.patch_embed.num_patches
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches + 1, self.embed_dim), requires_grad=False)
            grid_side = int(num_patches ** .5)
            sincos_table = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_side, cls_token=True)
            self.pos_embed.data.copy_(torch.from_numpy(sincos_table).float().unsqueeze(0))
        assert self.pos_embed is not None

    def _trunc_normal_(self, tensor, mean=0., std=1.):
        trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)

    def forward(self, x, x_fft):
        # TODO: remove once filter types other than 'mfm' are settled; only x_fft is encoded.
        tokens = self.patch_embed(x_fft)

        # Prepend one CLS token per sample.
        cls_column = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls_column, tokens), dim=1)

        if self.pos_embed is not None:
            tokens = tokens + self.pos_embed
        tokens = self.pos_drop(tokens)

        shared_bias = None if self.rel_pos_bias is None else self.rel_pos_bias()
        for block in self.blocks:
            tokens = block(tokens, rel_pos_bias=shared_bias)
        tokens = self.norm(tokens)
        cls_feature = tokens[:, 0]

        if self.decoder_depth != 0:
            return tokens, cls_feature

        # Without a decoder, turn the patch tokens into a square spatial feature map.
        patch_tokens = tokens[:, 1:]
        batch, length, channels = patch_tokens.shape
        side = int(length ** 0.5)
        feature_map = patch_tokens.permute(0, 2, 1).reshape(batch, channels, side, side)
        return feature_map, cls_feature


class ResNetForMFM(ResNet):
    """ResNet encoder for MFM that encodes only the frequency-corrupted input.

    ``forward(x, x_fft)`` ignores ``x`` and returns ``(feature_map, pooled)``, where ``pooled``
    is the global spatial average of ``feature_map`` with shape ``(B, C)``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.filter_type = config.DATA.FILTER_TYPE
        assert self.num_classes == 0

    def forward(self, x, x_fft):
        # TODO: remove once filter types other than 'mfm' are settled; only x_fft is encoded.
        feature_map = self.forward_features(x_fft)
        batch, channels, _, _ = feature_map.shape
        # A kernel covering the whole spatial extent averages down to (B, C, 1, 1).
        averaged = F.avg_pool2d(feature_map, kernel_size=feature_map.shape[-2:])
        # Drop the singleton spatial dims -> (B, C).
        return feature_map, averaged.view(batch, channels)
    

class ResNetVa(ResNet):
    """Plain ResNet feature extractor with the MFM encoder call signature ``(x, x_fft)``.

    Unlike the MFM encoders it encodes the original input ``x``, ignores ``x_fft``, and returns
    only the feature map produced by ``forward_features``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.filter_type = config.DATA.FILTER_TYPE
        assert self.num_classes == 0

    def forward(self, x, x_fft):
        # TODO: remove once filter types are settled; x_fft is intentionally unused here.
        return self.forward_features(x)


class MFMStudentMultiHead(nn.Module):
    """Wrap an MFM encoder with a frequency-reconstruction head and a distillation head.

    ``forward`` corrupts the image in the frequency domain with ``mask``, runs the encoder and
    returns ``(x, z_cls, mfm_res, distillation_res)``.

    Args:
        encoder: module called as ``encoder(clean, corrupted)`` that returns
            ``(features, cls_feature)`` and exposes ``num_features``.
        config: attribute-style experiment config.
        is_student: if True the encoder receives the corrupted image as its second argument;
            otherwise it receives the clean image for both arguments.
        encoder_stride: upsampling factor used by the reconstruction head.
        logger: optional logger; stored but not used here.

    Raises:
        ValueError: if ``config.MODEL.DISTILLATION_HEADER_VERSION`` is not a supported head.
    """

    def __init__(self, encoder, config, is_student=True, encoder_stride=16, logger=None):
        super().__init__()
        self.distillation_header_version = config.MODEL.DISTILLATION_HEADER_VERSION
        self.logger = logger
        self.encoder = encoder
        self.encoder_stride = encoder_stride
        self.is_student = is_student
        self.student_strategy = config.TRAIN.STUDENT_STRATEGY
        assert config.DATA.FILTER_TYPE in ['mfm', 'sr', 'deblur', 'denoise', 'fftComp',
                                           'gabor', 'r_mask', 'cicle']
        assert config.MODEL.RECOVER_TARGET_TYPE in ['masked', 'normal']
        self.filter_type = config.DATA.FILTER_TYPE
        self.mask_radius1 = config.DATA.MASK_RADIUS1
        self.mask_radius2 = config.DATA.MASK_RADIUS2
        self.recover_target_type = config.MODEL.RECOVER_TARGET_TYPE
        # NOTE(review): the hard-coded ``.cuda()`` makes construction fail on CPU-only machines.
        self.criterion = FrequencyLoss(
            loss_gamma=config.MODEL.FREQ_LOSS.LOSS_GAMMA,
            matrix_gamma=config.MODEL.FREQ_LOSS.MATRIX_GAMMA,
            patch_factor=config.MODEL.FREQ_LOSS.PATCH_FACTOR,
            ave_spectrum=config.MODEL.FREQ_LOSS.AVE_SPECTRUM,
            with_matrix=config.MODEL.FREQ_LOSS.WITH_MATRIX,
            log_matrix=config.MODEL.FREQ_LOSS.LOG_MATRIX,
            batch_matrix=config.MODEL.FREQ_LOSS.BATCH_MATRIX).cuda()
        if self.filter_type == 'sr':
            self.sr_factor = config.DATA.SR_FACTOR
            self.sr_mode = config.DATA.INTERPOLATION
        self.normalize_img = T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
        self.mfm_header = MFMHeader(num_features=encoder.num_features, encoder_stride=encoder_stride)
        self.distillation_header = self._build_distillation_header(config, encoder.num_features)

        if config.MODEL.TYPE == 'resnet':
            self.in_chans = config.MODEL.RESNET.IN_CHANS
            self.patch_size = 1
        else:
            self.in_chans = self.encoder.in_chans
            self.patch_size = self.encoder.patch_size

    @staticmethod
    def _build_distillation_header(config, num_features):
        """Instantiate the distillation head named by ``config.MODEL.DISTILLATION_HEADER_VERSION``.

        Raises:
            ValueError: if the version string is not one of the supported heads.
        """
        version = config.MODEL.DISTILLATION_HEADER_VERSION
        if version == "DINOHead":
            return DINOHead(in_dim=num_features,
                            out_dim=config.TRAIN.OUT_DIM,
                            use_bn=config.TRAIN.USE_BN_IN_HEAD,
                            norm_last_layer=config.TRAIN.NORM_LAST_LAYER,
                            nlayers=config.N_LAYERS)

        # NOTE(review): the hidden width 2024 below looks like a typo for 2048; it is kept as-is
        # because it fixes parameter shapes in existing checkpoints.
        conv_heads = {  # applied by ``forward`` to the spatial feature map ``z``
            "DistillationHeaderV3": dict(out_dim=num_features, nlayers=3, num_features=num_features),
            "DistillationHeaderV31": dict(out_dim=256, nlayers=3, num_features=2024),
            "DistillationHeaderV2": dict(out_dim=num_features, nlayers=2, num_features=num_features),
            "DistillationHeaderV21": dict(out_dim=256, nlayers=2, num_features=2024),
        }
        if version in conv_heads:
            return DistillationHeaderV3(in_dim=num_features, act='gelu', **conv_heads[version])

        cls_mlp_heads = {  # applied by ``forward`` to the pooled / CLS vector ``z_cls``
            "DistillationHeaderV3CLS": dict(out_dim=num_features, nlayers=3),
            "DistillationHeaderV31CLS": dict(out_dim=256, nlayers=3, num_features=2024),
            "DistillationHeaderV2CLS": dict(out_dim=num_features, nlayers=2),
            "DistillationHeaderV21CLS": dict(out_dim=256, nlayers=2, num_features=2024),
        }
        if version in cls_mlp_heads:
            return DistillationHeaderV33(in_dim=num_features, act='gelu', **cls_mlp_heads[version])

        if version == "DistillationHeaderV1":
            return DistillationHeaderV1(num_features=num_features)
        if version == "DistillationHeaderV11":
            # NOTE(review): a 256->256 conv only fits encoders whose num_features is 256 — confirm.
            return DistillationHeaderV1(num_features=256)
        if version == "DistillationHeaderV1CLS":
            return DistillationHeaderV13(in_dim=num_features, out_dim=num_features)
        if version == "DistillationHeaderV11CLS":
            return DistillationHeaderV13(in_dim=num_features, out_dim=256)

        raise ValueError(f"Unknown distillation header version: {version}")

    def frequency_transform(self, x, mask):
        """Multiply the centred 2-D spectrum of ``x`` by ``mask`` and return the real image clamped to [0, 1].

        The FFT runs over the last two dims; ``fftshift`` moves the zero frequency to the centre,
        so ``mask`` is indexed in centred-frequency coordinates.
        """
        x_freq = torch.fft.fft2(x)
        x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
        x_freq_masked = x_freq * mask
        x_freq_masked = torch.fft.ifftshift(x_freq_masked, dim=(-2, -1))
        x_corrupted = torch.fft.ifft2(x_freq_masked).real
        # The clamp suggests x is expected in [0, 1] before normalisation (TODO confirm).
        x_corrupted = torch.clamp(x_corrupted, min=0., max=1.)
        return x_corrupted

    def forward(self, x, mask=None):
        """Corrupt ``x`` with ``mask``, encode, and apply both heads.

        Args:
            x: image batch (see ``frequency_transform`` for the expected value range).
            mask: frequency mask, presumably (B, H, W); a channel axis is inserted so it
                broadcasts over channels. Required.

        Returns:
            ``(x, z_cls, mfm_res, distillation_res)``. ``x`` is the normalised corrupted image
            for the student and the normalised clean image otherwise.
        """
        assert mask is not None
        mask = mask.unsqueeze(1)
        x_corrupted = self.frequency_transform(x, mask)
        x_corrupted = self.normalize_img(x_corrupted)
        x = self.normalize_img(x)
        if self.is_student:
            z, z_cls = self.encoder(x, x_corrupted)
            x = x_corrupted
        else:
            z, z_cls = self.encoder(x, x)

        mfm_res = self.mfm_header(z)
        # CLS/DINO heads consume the pooled vector; the others consume the spatial feature map.
        if 'CLS' in self.distillation_header_version or self.distillation_header_version == "DINOHead":
            distillation_res = self.distillation_header(z_cls)
        else:
            distillation_res = self.distillation_header(z)
        return x, z_cls, mfm_res, distillation_res

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return encoder parameter names (prefixed with ``encoder.``) excluded from weight decay."""
        if hasattr(self.encoder, 'no_weight_decay'):
            return {'encoder.' + i for i in self.encoder.no_weight_decay()}
        return set()

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return encoder name keywords (prefixed with ``encoder.``) excluded from weight decay."""
        if hasattr(self.encoder, 'no_weight_decay_keywords'):
            return {'encoder.' + i for i in self.encoder.no_weight_decay_keywords()}
        return set()


class MFMHeader(nn.Module):
    """Reconstruction head: 1x1 conv to ``3 * encoder_stride**2`` channels, then ``PixelShuffle``.

    Maps a ``(B, num_features, h, w)`` feature map to ``(B, 3, h * encoder_stride, w * encoder_stride)``.
    """

    def __init__(self, num_features, encoder_stride):
        super().__init__()
        self.num_features = num_features
        self.encoder_stride = encoder_stride
        to_subpixels = nn.Conv2d(
            in_channels=num_features,
            out_channels=3 * encoder_stride * encoder_stride,
            kernel_size=1)
        rearrange = nn.PixelShuffle(encoder_stride)
        self.header = nn.Sequential(to_subpixels, rearrange)

    def forward(self, x):
        return self.header(x)
    

class DistillationHeaderV1(nn.Module):
    """Distillation head made of one 1x1 convolution that keeps the channel count ``num_features``."""

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        pointwise = nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=1)
        self.header = nn.Sequential(pointwise)

    def forward(self, x):
        return self.header(x)
    
class DistillationHeaderV13(nn.Module):
    """Distillation head for pooled / CLS vectors: a single ``Linear(in_dim, out_dim)``."""

    def __init__(self, in_dim, out_dim=2048):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.header = nn.Sequential(nn.Linear(in_dim, out_dim))

    def forward(self, x):
        return self.header(x)
    

class DistillationHeaderV3(nn.Module):
    """Convolutional distillation head: 3x3 convs with SyncBatchNorm + activation between them.

    Args:
        in_dim: channels of the input feature map.
        out_dim: channels of the output feature map.
        act: 'relu' or 'gelu'.
        nlayers: number of conv layers; values below 2 still produce two convs.
        num_features: hidden channel width.

    Raises:
        ValueError: for an unknown ``act``.
    """

    def __init__(self, in_dim, out_dim=2048, act='gelu', nlayers=3, num_features=2048):
        super().__init__()
        # Each hidden conv gets its own norm so affine parameters and running statistics are not
        # shared across layers. ``self.norm`` is the first of them, which keeps state_dict keys
        # identical to the previous shared-norm layout.
        self.norm = nn.SyncBatchNorm(num_features, eps=1e-6)
        self.act = self._build_act(act)  # stateless, so sharing it across layers is harmless

        nlayers = max(nlayers, 1)
        layers = [nn.Conv2d(in_dim, num_features, kernel_size=3, padding=1), self.norm, self.act]
        for _ in range(nlayers - 2):  # empty for nlayers <= 2
            layers += [nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
                       nn.SyncBatchNorm(num_features, eps=1e-6),
                       self.act]
        layers.append(nn.Conv2d(num_features, out_dim, kernel_size=3, padding=1))
        self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        return x

    def _build_act(self, act):
        if act == 'relu':
            return nn.ReLU()
        if act == 'gelu':
            return nn.GELU()
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise ValueError("unknown act type {}".format(act))
    

class DistillationHeaderV33(nn.Module):
    """MLP distillation head for pooled / CLS vectors: Linear layers with SyncBatchNorm + activation.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        act: 'relu' or 'gelu'.
        nlayers: number of Linear layers; values below 2 still produce two layers.
        num_features: hidden width.

    Raises:
        ValueError: for an unknown ``act``.
    """

    def __init__(self, in_dim, out_dim=512, act='gelu', nlayers=3, num_features=512):
        super().__init__()
        # Each hidden Linear gets its own norm so affine parameters and running statistics are
        # not shared across layers. ``self.norm`` is the first of them, which keeps state_dict
        # keys identical to the previous shared-norm layout.
        self.norm = nn.SyncBatchNorm(num_features, eps=1e-6)
        # NOTE(review): ``norm_zero`` is never used in forward. It is kept so strict checkpoint
        # loading still works, but its parameters receive no gradient, which may require
        # ``find_unused_parameters=True`` under DDP — confirm before removing.
        self.norm_zero = nn.SyncBatchNorm(in_dim, eps=1e-6)
        self.act = self._build_act(act)  # stateless, so sharing it across layers is harmless

        nlayers = max(nlayers, 1)
        layers = [nn.Linear(in_dim, num_features), self.norm, self.act]
        for _ in range(nlayers - 2):  # empty for nlayers <= 2
            layers += [nn.Linear(num_features, num_features),
                       nn.SyncBatchNorm(num_features, eps=1e-6),
                       self.act]
        layers.append(nn.Linear(num_features, out_dim))
        self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        return x

    def _build_act(self, act):
        if act == 'relu':
            return nn.ReLU()
        if act == 'gelu':
            return nn.GELU()
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise ValueError("unknown act type {}".format(act))


class DINOHead(nn.Module):
    """DINO-style projection head: MLP -> L2 normalisation -> weight-normalised linear layer.

    Args:
        in_dim: input feature size.
        out_dim: output size of the final weight-normalised layer.
        use_bn: insert ``BatchNorm1d`` after every hidden Linear.
        norm_last_layer: if True, freeze the weight-norm magnitude ``weight_g`` (initialised to 1).
        nlayers: number of Linear layers in the MLP; values below 1 are treated as 1.
        hidden_dim: width of the hidden layers.
        bottleneck_dim: MLP output size and input size of the last layer.
    """

    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        depth = max(nlayers, 1)
        if depth > 1:
            # Linear (+BN) + GELU for every hidden stage, then a projection to the bottleneck.
            widths = [in_dim] + [hidden_dim] * (depth - 1)
            stages = []
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                stages.append(nn.Linear(fan_in, fan_out))
                if use_bn:
                    stages.append(nn.BatchNorm1d(fan_out))
                stages.append(nn.GELU())
            stages.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*stages)
        else:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        self.apply(self._init_weights)

        # Created after ``apply`` so the custom init above does not touch it.
        projection = nn.Linear(bottleneck_dim, out_dim, bias=False)
        self.last_layer = nn.utils.weight_norm(projection)
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad_(False)

    def _init_weights(self, m):
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        embedded = self.mlp(x)
        unit = nn.functional.normalize(embedded, dim=-1, p=2)
        return self.last_layer(unit)
 

def build_mfm_multi_head_student(config, is_student=True, logger=None):
    """Build the MFM encoder selected by ``config.MODEL.TYPE`` and wrap it in ``MFMStudentMultiHead``.

    Supported types and the reconstruction-head stride used for each: 'swin' (32), 'vit' (16),
    'resnet' (32). Every encoder is built headless (``num_classes=0``).

    Raises:
        NotImplementedError: for any other model type.
    """
    model_type = config.MODEL.TYPE
    if model_type not in ('swin', 'vit', 'resnet'):
        raise NotImplementedError(f"Unknown pre-train model: {model_type}")

    # Keyword arguments shared by every encoder variant.
    common = dict(
        num_classes=0,
        drop_rate=config.MODEL.DROP_RATE,
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        config=config,
    )

    if model_type == 'swin':
        swin_cfg = config.MODEL.SWIN
        encoder = SwinTransformerForMFM(
            img_size=config.DATA.IMG_SIZE,
            patch_size=swin_cfg.PATCH_SIZE,
            in_chans=swin_cfg.IN_CHANS,
            embed_dim=swin_cfg.EMBED_DIM,
            depths=swin_cfg.DEPTHS,
            num_heads=swin_cfg.NUM_HEADS,
            window_size=swin_cfg.WINDOW_SIZE,
            mlp_ratio=swin_cfg.MLP_RATIO,
            qkv_bias=swin_cfg.QKV_BIAS,
            qk_scale=swin_cfg.QK_SCALE,
            ape=swin_cfg.APE,
            patch_norm=swin_cfg.PATCH_NORM,
            use_checkpoint=config.TRAIN.USE_CHECKPOINT,
            **common)
        stride = 32
    elif model_type == 'vit':
        vit_cfg = config.MODEL.VIT
        encoder = VisionTransformerForMFM(
            img_size=config.DATA.IMG_SIZE,
            patch_size=vit_cfg.PATCH_SIZE,
            in_chans=vit_cfg.IN_CHANS,
            embed_dim=vit_cfg.EMBED_DIM,
            depth=vit_cfg.DEPTH,
            num_heads=vit_cfg.NUM_HEADS,
            mlp_ratio=vit_cfg.MLP_RATIO,
            qkv_bias=vit_cfg.QKV_BIAS,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            init_values=vit_cfg.INIT_VALUES,
            use_abs_pos_emb=vit_cfg.USE_APE,
            use_fixed_pos_emb=vit_cfg.USE_FPE,
            use_rel_pos_bias=vit_cfg.USE_RPB,
            use_shared_rel_pos_bias=vit_cfg.USE_SHARED_RPB,
            use_mean_pooling=vit_cfg.USE_MEAN_POOLING,
            **common)
        stride = 16
    else:  # 'resnet'
        resnet_cfg = config.MODEL.RESNET
        encoder = ResNetForMFM(
            block=Bottleneck,
            layers=resnet_cfg.LAYERS,
            in_chans=resnet_cfg.IN_CHANS,
            **common)
        stride = 32

    return MFMStudentMultiHead(encoder=encoder, config=config, is_student=is_student,
                               encoder_stride=stride, logger=logger)


def build_resnet(config):
    """Build a headless ``ResNetVa`` feature extractor from ``config.MODEL.RESNET``.

    Returns:
        A ``ResNetVa`` whose ``forward(x, x_fft)`` returns ``forward_features(x)``.
    """
    model = ResNetVa(
        block=Bottleneck,
        layers=config.MODEL.RESNET.LAYERS,
        in_chans=config.MODEL.RESNET.IN_CHANS,
        # ResNetVa asserts num_classes == 0 and never uses the classifier, so passing
        # config.MODEL.NUM_CLASSES would trip that assertion for any non-zero class count.
        num_classes=0,
        drop_rate=config.MODEL.DROP_RATE,
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        config=config)

    return model