import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, pack, unpack

from .Trans_En import TransformerModel
from .vit import ViT
from .vit_1d import ViT_1d

## Separate (non-shared) parameters for view1 and view3
class Refine_Fea_DoubleView_noshare(nn.Module):
    """Refine view1 and view3 features with two independent (non-shared) encoders.

    Each view gets its own ``Refine_Fea_DoubleView`` built from the same ``cfg``.
    Only the view1 encoder's temperature (``logit_scale``) is returned.
    """

    def __init__(self, cfg):
        super().__init__()
        self.view1_refine = Refine_Fea_DoubleView(cfg)
        self.view3_refine = Refine_Fea_DoubleView(cfg)
        # NOTE(review): self.view3_refine's logit_scale is never read by forward(),
        # so it gets no gradient; DDP may require find_unused_parameters=True — confirm.

    def forward(self, view1_x, view3_x):
        """Refine both views and return their features plus view1's logit scale."""
        refined_v1 = self.view1_refine(view1_x)
        refined_v3 = self.view3_refine(view3_x)
        return dict(
            view1_feature=refined_v1['feature'],
            view3_feature=refined_v3['feature'],
            logit_scale=refined_v1['logit_scale'],
        )

##======================== shared parameters for view1 and view3
class Refine_Fea_DoubleView(nn.Module):  # for pretrain & down, just single_view1 and single_view3
    """Refine a window of frame features into one L2-normalized embedding.

    Reads ``type_frame``, ``in_channel``, ``d_model`` and ``init_logit_scale``
    from ``cfg.refine_fea``.
    """

    def __init__(self, cfg):
        super().__init__()
        rf_cfg = cfg.refine_fea
        # Learnable temperature kept in log space; forward() returns its exp().
        self.logit_scale = nn.Parameter(torch.ones([]) * rf_cfg.init_logit_scale)
        # Every refinement weight lives inside this sub-module.
        self.refine = Refine_Fea(rf_cfg.type_frame, rf_cfg.in_channel, rf_cfg.d_model)

    def forward(self, x):
        """Return ``{"feature": unit-norm refined feature, "logit_scale": exp(scale)}``."""
        _rows, _win, _chan = x.shape  # unpacking enforces a 3-D input
        # [b*nh, win, c] -> [b*nh, c], then unit length along the last axis
        embedding = F.normalize(self.refine(x), dim=-1)
        return {"feature": embedding, "logit_scale": self.logit_scale.exp()}


class Refine_Fea_MultiView(nn.Module):  # just for pretrain, multi_view1 and multi_view3
    """Refine each view's frame window, then aggregate the views of each sample.

    Reads ``type_frame``, ``in_channel``, ``d_model`` and ``init_logit_scale`` from
    ``cfg.refine_fea`` and ``type_aggre`` from ``cfg.aggre``.
    """

    def __init__(self, cfg):
        super().__init__()

        type_frame = cfg.refine_fea.type_frame
        in_channel = cfg.refine_fea.in_channel
        d_model = cfg.refine_fea.d_model
        init_logit_scale = cfg.refine_fea.init_logit_scale
        type_aggre = cfg.aggre.type_aggre

        # Learnable temperature kept in log space; forward() returns its exp().
        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)

        self.refine = Refine_Fea(type_frame, in_channel, d_model)  # all refine parameters are included here
        self.aggre = Aggregate(type_aggre, d_model, d_model)

    def forward(self, x, n_view):
        """Refine every view, regroup per sample, and aggregate the views.

        Args:
            x: 3-D tensor whose first axis holds ``batch * n_view`` rows. Following
                the ``'(bg n) c'`` regrouping below, consecutive blocks of ``n_view``
                rows are treated as the views of one sample.
            n_view: number of views per sample (must be positive).

        Returns:
            dict with ``"feature"`` (output of ``self.aggre``) and ``"logit_scale"``.

        Raises:
            ValueError: if ``x`` is not 3-D, ``n_view`` is not positive, or the row
                count is not divisible by ``n_view``.
        """
        if x.dim() != 3:
            raise ValueError(f"Expected a 3-D input, got shape {tuple(x.shape)}.")
        if n_view <= 0:
            raise ValueError(f"n_view must be positive, got {n_view}.")

        x = self.refine(x)  # [b*nh, win, c] -> [b*nh, c]
        # Unlike Refine_Fea_DoubleView, no L2 normalization is applied here.
        n_rows, _ = x.shape  # n_rows is the per-GPU batch size times n_view
        # Explicit check (not `assert`) so it survives `python -O`.
        if n_rows % n_view != 0:
            raise ValueError(
                f"The bn is set wrong: {n_rows} rows are not divisible by n_view={n_view}."
            )
        x = rearrange(x, '(bg n) c -> bg n c', n=n_view)

        x = self.aggre(x)
        return {
            "feature": x,
            "logit_scale": self.logit_scale.exp(),
        }

#####=============== basic model for feature refine ================
## can be used once or multiple times
class Refine_Fea(nn.Module):  # input win, output 1
    """Refine a window of per-frame features into a single feature vector.

    Pipeline:
      1. Optional ``LayerNorm -> Linear`` projection from ``in_channel`` to
         ``d_model``, created only when the two differ.
      2. ``LayerNorm(d_model)``.
      3. A ``TransformerModel`` with 6 encoder layers and 8 heads.
      4. A reduction over the window axis (dim -2), chosen by ``type_frame``:

         * ``'AvgPool'``: mean over the window.
         * ``'MaxPool'``: element-wise max over the window.
         * ``'ClsToken'``: prepend a learned token and return its output.
         * ``'AttentionPool'``: prepend the window mean as a token and return its output.
    """

    def __init__(self, type_frame, in_channel, d_model):
        super().__init__()
        self.type_frame = type_frame

        if in_channel != d_model:
            self.norm0 = nn.LayerNorm(in_channel)  # "norm3": normalize before projecting
            self.proj_in = nn.Linear(in_features=in_channel, out_features=d_model, bias=True)
        else:
            self.proj_in = None

        self.norm = nn.LayerNorm(d_model)  # "norm2": normalize d_model-wide features

        self.refine_fea = TransformerModel(d_model, d_model, nhead=8, num_encoder_layers=6,
                                           dim_feedforward=2048,
                                           dropout=0.1, activation="relu", normalize_before=False, max_len=200)

        if self.type_frame == 'ClsToken':
            self.cls_token = nn.Parameter(torch.randn(d_model))

    def forward(self, x, **kwargs):
        """Collapse the window axis of ``x`` into one vector per sample.

        Args:
            x: features whose last two axes are (window, channels). Presumably
                ``[b*nh, win, in_channel]``; TODO confirm. 'ClsToken' and
                'AttentionPool' require a 3-D input.
            **kwargs: passed unchanged to ``self.refine_fea`` (``TransformerModel``).

        Returns:
            Tensor with the window axis removed, e.g. ``[b*nh, d_model]``.

        Raises:
            ValueError: if ``type_frame`` is not one of the supported modes.
        """
        if self.proj_in is not None:
            x = self.proj_in(self.norm0(x))  # cls is added after the projection
        x = self.norm(x)

        if self.type_frame == 'AvgPool':
            return torch.mean(self.refine_fea(x, **kwargs), dim=-2)

        if self.type_frame == 'MaxPool':
            # torch.max(..., dim=) returns a (values, indices) namedtuple; keep the values.
            return torch.max(self.refine_fea(x, **kwargs), dim=-2).values

        if self.type_frame == 'ClsToken':
            b, _, _ = x.shape
            cls_tokens = repeat(self.cls_token, 'd -> b d', b=b)
            x, ps = pack([cls_tokens, x], 'b * d')  # [b, 1 + win, d]
            x = self.refine_fea(x, **kwargs)
            cls_tokens, _ = unpack(x, ps, 'b * d')  # [b, d], [b, win, d]
            return cls_tokens

        if self.type_frame == 'AttentionPool':
            # The extra token is initialized from the mean over the window.
            cls_tokens = x.mean(1, keepdim=True)  # [b, 1, d]
            x, ps = pack([cls_tokens, x], 'b * d')  # [b, 1 + win, d]
            x = self.refine_fea(x, **kwargs)
            cls_tokens, _ = unpack(x, ps, 'b * d')  # [b, 1, d], [b, win, d]
            return cls_tokens.squeeze(dim=-2)

        raise ValueError(f"Unsupported type_frame: {self.type_frame!r}")

class Aggregate(nn.Module):
    """Fuse the ``n`` per-view features of each sample into one vector.

    A small ``TransformerModel`` (1 encoder layer, 2 heads, ``normalize_before=True``)
    mixes the views. ``type_aggre`` then selects the reduction over the view axis:
    ``'AvgPool'``, ``'MaxPool'`` or ``'AttentionPool'``.
    """

    def __init__(self, type_aggre, in_channel, d_model):
        super().__init__()
        self.type_aggre = type_aggre
        self.embed = TransformerModel(in_channel, d_model, nhead=2, num_encoder_layers=1, dim_feedforward=512,
                                      dropout=0.1, activation="relu", normalize_before=True, max_len=200)
        # Output projection; only applied in the 'AttentionPool' branch.
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Reduce ``x`` of shape ``[b, n, c]`` to ``[b, d_model]``.

        Raises:
            ValueError: if ``x`` is not 3-D or ``type_aggre`` is unsupported.
        """
        if x.dim() != 3:
            raise ValueError(f"Expected a 3-D input [b, n, c], got shape {tuple(x.shape)}.")

        if self.type_aggre == 'AvgPool':
            return torch.mean(self.embed(x), dim=-2)

        if self.type_aggre == 'MaxPool':
            # torch.max(..., dim=) returns a (values, indices) namedtuple; keep the values.
            return torch.max(self.embed(x), dim=-2).values

        if self.type_aggre == 'AttentionPool':
            # Prepend the mean over views as an extra token, then read its output back.
            cls_tokens = x.mean(1, keepdim=True)  # [b, 1, c]
            x, ps = pack([cls_tokens, x], 'b * d')  # [b, 1 + n, c]
            embed = self.embed(x)
            cls_tokens, _ = unpack(embed, ps, 'b * d')  # unpack the transformer output, not its input
            return self.linear(cls_tokens.squeeze(dim=-2))

        raise ValueError(f"Unsupported type_aggre: {self.type_aggre!r}")

######=================== vit

class Refine_Fea_DoubleView_vit(nn.Module):  # for pretrain & down, just single_view1 and single_view3
    """ViT-based refiner that returns an un-normalized feature and a temperature.

    Reads ``type_frame``, ``in_channel``, ``d_model``, ``init_logit_scale`` and
    ``num_patch`` from ``cfg.refine_fea``.
    """

    def __init__(self, cfg):
        super().__init__()
        rf_cfg = cfg.refine_fea
        # Learnable temperature kept in log space; forward() returns its exp().
        self.logit_scale = nn.Parameter(torch.ones([]) * rf_cfg.init_logit_scale)
        # Every refinement weight lives inside this sub-module.
        self.refine = Refine_Fea_vit(rf_cfg.type_frame, rf_cfg.in_channel, rf_cfg.d_model, rf_cfg.num_patch)

    def forward(self, x):
        """Return ``{"feature": refined feature, "logit_scale": exp(scale)}``."""
        _rows, _win, _chan = x.shape  # unpacking enforces a 3-D input
        # [b*nh, win, c] -> [b*nh, c]; no normalization in this variant
        return {"feature": self.refine(x), "logit_scale": self.logit_scale.exp()}

class Refine_Fea_vit(nn.Module):  # input win, output 1
    """Thin wrapper around ``ViT`` with fixed hyper-parameters.

    ``type_frame`` is passed to ViT as its ``pool`` argument.
    """

    def __init__(self, type_frame, in_channel, d_model, num_patch):
        super().__init__()
        vit_kwargs = dict(
            in_dim=in_channel,
            out_dim=d_model,
            dim=1024,
            num_patches=num_patch,
            depth=6,
            heads=16,
            mlp_dim=2048,
            pool=type_frame,
            dropout=0.1,
            emb_dropout=0.1,
        )
        self.refine_fea = ViT(**vit_kwargs)

    def forward(self, x):
        """Delegate directly to the wrapped ViT."""
        return self.refine_fea(x)

######=================== vit_1d
## Almost the same as the vit variant above; only the mlp_head differs.
## This variant uses ViT_1d.

class Refine_Fea_DoubleView_vit1d(nn.Module):  # for pretrain & down, just single_view1 and single_view3
    """ViT_1d-based refiner that returns an un-normalized feature and a temperature.

    Reads ``in_channel``, ``d_model``, ``init_logit_scale`` and ``num_patch`` from
    ``cfg.refine_fea``. This variant does not read ``type_frame``.
    """

    def __init__(self, cfg):
        super().__init__()
        rf_cfg = cfg.refine_fea
        # Learnable temperature kept in log space; forward() returns its exp().
        self.logit_scale = nn.Parameter(torch.ones([]) * rf_cfg.init_logit_scale)
        # Every refinement weight lives inside this sub-module.
        self.refine = Refine_Fea_vit1d(rf_cfg.in_channel, rf_cfg.d_model, rf_cfg.num_patch)

    def forward(self, x):
        """Return ``{"feature": refined feature, "logit_scale": exp(scale)}``."""
        _rows, _win, _chan = x.shape  # unpacking enforces a 3-D input
        # [b*nh, win, c] -> [b*nh, c]; no normalization in this variant
        return {"feature": self.refine(x), "logit_scale": self.logit_scale.exp()}

class Refine_Fea_vit1d(nn.Module):  # input win, output 1
    """Thin wrapper around ``ViT_1d`` built with the "large" preset.

    Active preset (large): dim=1024, depth=24, heads=16, mlp_dim=4096, dim_head=64.
    Alternative presets, kept for reference:

    * base: dim=768,  depth=12, heads=12, mlp_dim=3072, dim_head=64
    * ours: dim=1024, depth=6,  heads=16, mlp_dim=2048, dim_head=64, channels=3

    All presets use dropout=0.1 and emb_dropout=0.1.
    """

    def __init__(self, in_channel, d_model, num_patch):
        super().__init__()
        self.refine_fea = ViT_1d(
            num_patches=num_patch,
            in_dim=in_channel,
            out_dim=d_model,
            dim=1024,
            depth=24,
            heads=16,
            mlp_dim=4096,
            dim_head=64,
            dropout=0.1,
            emb_dropout=0.1,
        )

    def forward(self, x):
        """Delegate directly to the wrapped ViT_1d."""
        return self.refine_fea(x)
