import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SAE(nn.Module):
    """Single-layer top-k sparse autoencoder with sigmoid latents.

    Encodes with ``Linear -> Sigmoid``, keeps the ``topk`` largest-magnitude
    latents along the last dimension, then decodes with a single ``Linear``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Number of latent units.
        topk: Number of latents kept per row.
    """

    def __init__(self, input_dim, hidden_dim, topk=20):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.activations = nn.Sigmoid()
        self.decoder = nn.Linear(hidden_dim, input_dim)

        self.topk = topk
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

    def sparsify(self, embeddings):
        """Keep the ``self.topk`` largest-magnitude entries along the last dim.

        All other entries are zeroed. Exactly ``min(self.topk, n)`` entries
        survive per row (``n`` = size of the last dim); when several entries
        tie at the cut-off, ``torch.topk`` decides which of them are kept.
        Selection runs in float32 so reduced-precision inputs work; the
        result keeps the input's dtype. Gradients reach only kept entries.

        Args:
            embeddings: Tensor whose last dimension indexes latents.

        Returns:
            Tensor with the same shape and dtype as ``embeddings``.

        Raises:
            ValueError: If ``self.topk`` is negative.
        """
        num_latents = embeddings.shape[-1]
        if self.topk < 0:
            raise ValueError(f"topk must be non-negative, got {self.topk}")
        if self.topk >= num_latents:
            # Every latent survives; there is nothing to mask.
            return embeddings
        # Indices need no gradient; the multiplicative mask routes gradient
        # to the kept entries only.
        scores = embeddings.detach().abs().float()
        keep = torch.topk(scores, self.topk, dim=-1).indices
        mask = torch.zeros_like(embeddings).scatter_(-1, keep, 1.0)
        return embeddings * mask

    def forward(self, x):
        """Reconstruct ``x`` from its top-k sparse sigmoid code.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Reconstruction with the leading dimensions of ``x`` and last
            dimension ``input_dim``.
        """
        latents = self.activations(self.encoder(x))
        return self.decoder(self.sparsify(latents))


class VL_SAE(nn.Module):
    """Top-k sparse autoencoder with a shared encoder and per-modality decoders.

    The encoder is ``Dropout -> Linear -> ReLU -> Linear`` followed by a
    sigmoid and top-k sparsification; vision and text inputs both go through
    it. Each modality has its own two-layer MLP decoder.

    Args:
        input_dim: Size of the last dimension of both modalities' inputs.
        hidden_dim: Number of latent units.
        topk: Number of latents kept per row.
        dropout: Dropout probability used in the encoder and decoders.
    """

    def __init__(self, input_dim, hidden_dim, topk=32, dropout=0.1):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, hidden_dim)
        )
        self.activations = nn.Sigmoid()

        self.vision_decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(input_dim, input_dim),
            )
        self.text_decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(input_dim, input_dim),
            )

        self.topk = topk
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

    def sparsify(self, embeddings):
        """Keep the ``self.topk`` largest-magnitude entries along the last dim.

        All other entries are zeroed. Exactly ``min(self.topk, n)`` entries
        survive per row (``n`` = size of the last dim); when several entries
        tie at the cut-off, ``torch.topk`` decides which of them are kept.
        Selection runs in float32 so reduced-precision inputs work; the
        result keeps the input's dtype. Gradients reach only kept entries.

        Args:
            embeddings: Tensor whose last dimension indexes latents.

        Returns:
            Tensor with the same shape and dtype as ``embeddings``.

        Raises:
            ValueError: If ``self.topk`` is negative.
        """
        num_latents = embeddings.shape[-1]
        if self.topk < 0:
            raise ValueError(f"topk must be non-negative, got {self.topk}")
        if self.topk >= num_latents:
            # Every latent survives; there is nothing to mask.
            return embeddings
        scores = embeddings.detach().abs().float()
        keep = torch.topk(scores, self.topk, dim=-1).indices
        mask = torch.zeros_like(embeddings).scatter_(-1, keep, 1.0)
        return embeddings * mask

    def encode(self, embeddings):
        """Return the sparse sigmoid latent code of ``embeddings``."""
        return self.sparsify(self.activations(self.encoder(embeddings)))

    def forward(self, vision_embeddings=None, text_embeddings=None):
        """Encode and reconstruct whichever modalities are provided.

        Returns:
            ``(recon_vision, recon_text, latent_v, latent_t)``; the entries
            for a modality that was not passed are ``None``.
        """
        recon_vision_embeddings = None
        recon_text_embeddings = None
        latent_v = None
        latent_t = None
        if vision_embeddings is not None:
            latent_v = self.encode(vision_embeddings)
            recon_vision_embeddings = self.vision_decoder(latent_v)
        if text_embeddings is not None:
            latent_t = self.encode(text_embeddings)
            recon_text_embeddings = self.text_decoder(latent_t)
        return recon_vision_embeddings, recon_text_embeddings, latent_v, latent_t

class VL_SAE_COS(nn.Module):
    """Top-k SAE whose latents are distance-based similarities to encoder rows.

    Inputs and encoder rows are L2-normalized; latent ``j`` is
    ``2 - ||x_hat - w_hat_j||``. The inputs and the rows are unit vectors, so
    the distance lies in ``[0, 2]`` and every latent therefore also lies in
    ``[0, 2]``. The encoder is shared by both modalities; each modality has
    its own linear decoder.

    Args:
        input_dim: Size of the last dimension of the inputs.
        hidden_dim: Number of latent units (rows of ``encoder``).
        topk: Number of latents kept per row.
        dropout: Accepted for constructor compatibility; not used.
    """

    def __init__(self, input_dim, hidden_dim, topk=32, dropout=0):
        super().__init__()
        # randn is immediately overwritten; it is kept so the global RNG
        # stream (and thus later layers' init under a fixed seed) is unchanged.
        self.encoder = nn.Parameter(torch.randn(hidden_dim, input_dim))
        nn.init.kaiming_uniform_(self.encoder, a=math.sqrt(5))

        self.vision_decoder = nn.Linear(hidden_dim, input_dim)
        self.text_decoder = nn.Linear(hidden_dim, input_dim)

        self.topk = topk
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

    def sparsify(self, embeddings, topk):
        """Keep the ``topk`` largest-magnitude entries along the last dim.

        All other entries are zeroed. Exactly ``min(topk, n)`` entries survive
        per row (``n`` = size of the last dim); when several entries tie at
        the cut-off, ``torch.topk`` decides which of them are kept. Selection
        runs in float32; the result keeps the input's dtype.

        Args:
            embeddings: Tensor whose last dimension indexes latents.
            topk: Number of entries to keep per row.

        Returns:
            Tensor with the same shape and dtype as ``embeddings``.

        Raises:
            ValueError: If ``topk`` is negative.
        """
        num_latents = embeddings.shape[-1]
        if topk < 0:
            raise ValueError(f"topk must be non-negative, got {topk}")
        if topk >= num_latents:
            # Every latent survives; there is nothing to mask.
            return embeddings
        scores = embeddings.detach().abs().float()
        keep = torch.topk(scores, topk, dim=-1).indices
        mask = torch.zeros_like(embeddings).scatter_(-1, keep, 1.0)
        return embeddings * mask

    def encode(self, embeddings, mode='eval'):
        """Return the sparse similarity code of ``embeddings``.

        Args:
            embeddings: Inputs; normalization is along dim 1, so this
                presumably expects a 2-D ``(batch, input_dim)`` tensor.
            mode: Accepted for API compatibility; ``'train'`` and ``'eval'``
                currently both use ``self.topk``.
        """
        weights = F.normalize(self.encoder, p=2, dim=1)
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity = 2 - torch.cdist(embeddings, weights, p=2)
        return self.sparsify(similarity, topk=self.topk)

    def forward(self, vision_embeddings=None, text_embeddings=None, mode='eval'):
        """Encode and reconstruct whichever modalities are provided.

        Returns:
            ``(recon_vision, recon_text, latent_v, latent_t)``; the entries
            for a modality that was not passed are ``None``.
        """
        recon_vision_embeddings = None
        recon_text_embeddings = None
        latent_v = None
        latent_t = None
        if vision_embeddings is not None:
            latent_v = self.encode(vision_embeddings, mode=mode)
            recon_vision_embeddings = self.vision_decoder(latent_v)
        if text_embeddings is not None:
            latent_t = self.encode(text_embeddings, mode=mode)
            recon_text_embeddings = self.text_decoder(latent_t)
        return recon_vision_embeddings, recon_text_embeddings, latent_v, latent_t

class VL_SAE_COS_ABL(nn.Module):
    """Ablation of ``VL_SAE_COS`` with a single decoder shared by both modalities.

    Inputs and encoder rows are L2-normalized; latent ``j`` is
    ``2 - ||x_hat - w_hat_j||``, which lies in ``[0, 2]`` for unit vectors.
    Top-k sparsified latents from either modality go through ``decoder``.

    Args:
        input_dim: Size of the last dimension of the inputs.
        hidden_dim: Number of latent units (rows of ``encoder``).
        topk: Number of latents kept per row.
        dropout: Accepted for constructor compatibility; not used.
    """

    def __init__(self, input_dim, hidden_dim, topk=32, dropout=0):
        super().__init__()
        # randn is immediately overwritten; it is kept so the global RNG
        # stream (and thus later layers' init under a fixed seed) is unchanged.
        self.encoder = nn.Parameter(torch.randn(hidden_dim, input_dim))
        nn.init.kaiming_uniform_(self.encoder, a=math.sqrt(5))

        self.decoder = nn.Linear(hidden_dim, input_dim)

        self.topk = topk
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

    def sparsify(self, embeddings, topk):
        """Keep the ``topk`` largest-magnitude entries along the last dim.

        All other entries are zeroed. Exactly ``min(topk, n)`` entries survive
        per row (``n`` = size of the last dim); when several entries tie at
        the cut-off, ``torch.topk`` decides which of them are kept. Selection
        runs in float32; the result keeps the input's dtype.

        Args:
            embeddings: Tensor whose last dimension indexes latents.
            topk: Number of entries to keep per row.

        Returns:
            Tensor with the same shape and dtype as ``embeddings``.

        Raises:
            ValueError: If ``topk`` is negative.
        """
        num_latents = embeddings.shape[-1]
        if topk < 0:
            raise ValueError(f"topk must be non-negative, got {topk}")
        if topk >= num_latents:
            # Every latent survives; there is nothing to mask.
            return embeddings
        scores = embeddings.detach().abs().float()
        keep = torch.topk(scores, topk, dim=-1).indices
        mask = torch.zeros_like(embeddings).scatter_(-1, keep, 1.0)
        return embeddings * mask

    def encode(self, embeddings, mode='eval'):
        """Return the sparse similarity code of ``embeddings``.

        Args:
            embeddings: Inputs; normalization is along dim 1, so this
                presumably expects a 2-D ``(batch, input_dim)`` tensor.
            mode: Accepted for API compatibility; ``'train'`` and ``'eval'``
                currently both use ``self.topk``.
        """
        weights = F.normalize(self.encoder, p=2, dim=1)
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity = 2 - torch.cdist(embeddings, weights, p=2)
        return self.sparsify(similarity, topk=self.topk)

    def forward(self, vision_embeddings=None, text_embeddings=None, mode='eval'):
        """Encode and reconstruct whichever modalities are provided.

        Both modalities use the same ``decoder``.

        Returns:
            ``(recon_vision, recon_text, latent_v, latent_t)``; the entries
            for a modality that was not passed are ``None``.
        """
        recon_vision_embeddings = None
        recon_text_embeddings = None
        latent_v = None
        latent_t = None
        if vision_embeddings is not None:
            latent_v = self.encode(vision_embeddings, mode=mode)
            recon_vision_embeddings = self.decoder(latent_v)
        if text_embeddings is not None:
            latent_t = self.encode(text_embeddings, mode=mode)
            recon_text_embeddings = self.decoder(latent_t)
        return recon_vision_embeddings, recon_text_embeddings, latent_v, latent_t


class VL_SAE_DIS(nn.Module):
    """Top-k SAE whose latents are ``sigmoid(-distance)`` to encoder rows.

    Only ``encoder.weight`` is used (as a matrix of latent prototypes); the
    layer is never called as a ``Linear``. Latent ``j`` is
    ``sigmoid(-||x - W_j||)``, which is at most 0.5 because distances are
    non-negative. The encoder is shared; each modality has its own linear
    decoder.

    Args:
        input_dim: Size of the last dimension of the inputs.
        hidden_dim: Number of latent units.
        topk: Number of latents kept per row.
        dropout: Accepted for constructor compatibility; not used.
    """

    def __init__(self, input_dim, hidden_dim, topk=32, dropout=0):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)
        self.activations = nn.Sigmoid()

        self.vision_decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            )
        self.text_decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            )

        self.topk = topk
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

    def sparsify(self, embeddings):
        """Keep the ``self.topk`` largest-magnitude entries along the last dim.

        All other entries are zeroed. Exactly ``min(self.topk, n)`` entries
        survive per row (``n`` = size of the last dim); when several entries
        tie at the cut-off, ``torch.topk`` decides which of them are kept.
        Selection runs in float32; the result keeps the input's dtype.

        Args:
            embeddings: Tensor whose last dimension indexes latents.

        Returns:
            Tensor with the same shape and dtype as ``embeddings``.

        Raises:
            ValueError: If ``self.topk`` is negative.
        """
        num_latents = embeddings.shape[-1]
        if self.topk < 0:
            raise ValueError(f"topk must be non-negative, got {self.topk}")
        if self.topk >= num_latents:
            # Every latent survives; there is nothing to mask.
            return embeddings
        scores = embeddings.detach().abs().float()
        keep = torch.topk(scores, self.topk, dim=-1).indices
        mask = torch.zeros_like(embeddings).scatter_(-1, keep, 1.0)
        return embeddings * mask

    def encode(self, embeddings):
        """Return the sparse ``sigmoid(-distance)`` code of ``embeddings``."""
        distances = torch.cdist(embeddings, self.encoder.weight, p=2)
        return self.sparsify(self.activations(-distances))

    def forward(self, vision_embeddings=None, text_embeddings=None):
        """Encode and reconstruct whichever modalities are provided.

        Returns:
            ``(recon_vision, recon_text, latent_v, latent_t)``; the entries
            for a modality that was not passed are ``None``.
        """
        recon_vision_embeddings = None
        recon_text_embeddings = None
        latent_v = None
        latent_t = None
        if vision_embeddings is not None:
            latent_v = self.encode(vision_embeddings)
            recon_vision_embeddings = self.vision_decoder(latent_v)
        if text_embeddings is not None:
            latent_t = self.encode(text_embeddings)
            recon_text_embeddings = self.text_decoder(latent_t)
        return recon_vision_embeddings, recon_text_embeddings, latent_v, latent_t

class VL_SAE_COS_LLaVA(nn.Module):
    """Similarity-based top-k SAE with a learned per-latent affine and GELU.

    Latent ``j`` is ``GELU(alpha_j * (2 - ||x_hat - w_hat_j||) + beta_j)``,
    where ``x_hat`` and ``w_hat_j`` are L2-normalized along the last dim.
    The code is then top-k sparsified along the last dim. The encoder is
    shared; each modality has its own two-layer MLP decoder. Unlike the
    sibling classes, ``forward`` returns only the reconstructions.

    Args:
        input_dim: Size of the last dimension of the inputs.
        hidden_dim: Number of latent units (rows of ``encoder``).
        topk: Number of latents kept per row.
        dropout: Accepted for constructor compatibility; not used.
    """

    def __init__(self, input_dim, hidden_dim, topk=32, dropout=0):
        super().__init__()
        self.encoder = nn.Parameter(torch.randn(hidden_dim, input_dim))
        self.alpha = nn.Parameter(torch.ones(hidden_dim))
        self.beta = nn.Parameter(torch.zeros(hidden_dim))
        self.activations = nn.GELU()

        self.vision_decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, input_dim),
            )
        self.text_decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, input_dim),
            )

        self.topk = topk
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # NOTE(review): overrides nn.Module's default training=True, so a fresh
        # instance reports eval mode while its child modules still report
        # training=True. Nothing in this class reads the flag; confirm intent.
        self.training = False

    def sparsify(self, embeddings):
        """Keep the ``self.topk`` largest-magnitude entries along the last dim.

        Working on the last dim matches ``encode``'s ``dim=-1`` normalization,
        so batched ``(..., hidden_dim)`` codes are handled per row. Exactly
        ``min(self.topk, n)`` entries survive per row; when several entries
        tie at the cut-off, ``torch.topk`` decides which of them are kept.
        Selection runs in float32; the result keeps the input's dtype.

        Args:
            embeddings: Tensor whose last dimension indexes latents.

        Returns:
            Tensor with the same shape and dtype as ``embeddings``.

        Raises:
            ValueError: If ``self.topk`` is negative.
        """
        num_latents = embeddings.shape[-1]
        if self.topk < 0:
            raise ValueError(f"topk must be non-negative, got {self.topk}")
        if self.topk >= num_latents:
            # Every latent survives; there is nothing to mask.
            return embeddings
        scores = embeddings.detach().abs().float()
        keep = torch.topk(scores, self.topk, dim=-1).indices
        mask = torch.zeros_like(embeddings).scatter_(-1, keep, 1.0)
        return embeddings * mask

    def encode(self, embeddings):
        """Return the sparse GELU-activated similarity code of ``embeddings``."""
        weights = F.normalize(self.encoder, p=2, dim=-1)
        embeddings = F.normalize(embeddings, p=2, dim=-1)
        similarity = 2 - torch.cdist(embeddings, weights, p=2)
        return self.sparsify(self.activations(similarity * self.alpha + self.beta))

    def forward(self, vision_embeddings=None, text_embeddings=None):
        """Reconstruct whichever modalities are provided.

        Returns:
            ``(recon_vision, recon_text)``; an entry is ``None`` when that
            modality was not passed.
        """
        recon_vision_embeddings = None
        recon_text_embeddings = None
        if vision_embeddings is not None:
            latent_v = self.encode(vision_embeddings)
            recon_vision_embeddings = self.vision_decoder(latent_v)
        if text_embeddings is not None:
            latent_t = self.encode(text_embeddings)
            recon_text_embeddings = self.text_decoder(latent_t)
        return recon_vision_embeddings, recon_text_embeddings

class VL_SAE_CON(nn.Module):
    """Linear top-k SAE with no latent activation and per-modality decoders.

    ``encode`` applies a single ``Linear`` and keeps the ``topk``
    largest-magnitude outputs (of either sign). The encoder is shared; each
    modality has its own linear decoder.

    Args:
        input_dim: Size of the last dimension of the inputs.
        hidden_dim: Number of latent units.
        topk: Number of latents kept per row.
        dropout: Accepted for constructor compatibility; not used.
    """

    def __init__(self, input_dim, hidden_dim, topk=32, dropout=0):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
        )
        # NOTE(review): not used by encode(); presumably kept so state_dict
        # keys match existing checkpoints — confirm before removing.
        self.norm = nn.LayerNorm(hidden_dim)

        self.vision_decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            )
        self.text_decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            )

        self.topk = topk
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # NOTE(review): overrides nn.Module's default training=True; nothing in
        # this class reads the flag. Confirm whether this is intended.
        self.training = False

    def sparsify(self, embeddings):
        """Keep the ``self.topk`` largest-magnitude entries along the last dim.

        All other entries are zeroed. Exactly ``min(self.topk, n)`` entries
        survive per row (``n`` = size of the last dim); when several entries
        tie at the cut-off, ``torch.topk`` decides which of them are kept.
        Selection runs in float32; the result keeps the input's dtype.

        Args:
            embeddings: Tensor whose last dimension indexes latents.

        Returns:
            Tensor with the same shape and dtype as ``embeddings``.

        Raises:
            ValueError: If ``self.topk`` is negative.
        """
        num_latents = embeddings.shape[-1]
        if self.topk < 0:
            raise ValueError(f"topk must be non-negative, got {self.topk}")
        if self.topk >= num_latents:
            # Every latent survives; there is nothing to mask.
            return embeddings
        scores = embeddings.detach().abs().float()
        keep = torch.topk(scores, self.topk, dim=-1).indices
        mask = torch.zeros_like(embeddings).scatter_(-1, keep, 1.0)
        return embeddings * mask

    def encode(self, embeddings):
        """Return the sparse linear code of ``embeddings`` (no activation)."""
        return self.sparsify(self.encoder(embeddings))

    def forward(self, vision_embeddings=None, text_embeddings=None):
        """Encode and reconstruct whichever modalities are provided.

        Returns:
            ``(recon_vision, recon_text, latent_v, latent_t)``; the entries
            for a modality that was not passed are ``None``.
        """
        recon_vision_embeddings = None
        recon_text_embeddings = None
        latent_v = None
        latent_t = None
        if vision_embeddings is not None:
            latent_v = self.encode(vision_embeddings)
            recon_vision_embeddings = self.vision_decoder(latent_v)
        if text_embeddings is not None:
            latent_t = self.encode(text_embeddings)
            recon_text_embeddings = self.text_decoder(latent_t)
        return recon_vision_embeddings, recon_text_embeddings, latent_v, latent_t

class SAE_D(nn.Module):
    """Top-k ReLU SAE with separate encoders and decoders per modality.

    Vision inputs use ``v_encoder``/``vision_decoder``; text inputs use
    ``t_encoder``/``text_decoder``. Latents are ``ReLU(Linear(x))`` restricted
    to the ``topk`` largest values along the last dim.

    Args:
        input_dim: Size of the last dimension of the inputs.
        hidden_dim: Number of latent units.
        topk: Number of latents kept per row.
        dropout: Accepted for constructor compatibility; not used.
    """

    def __init__(self, input_dim, hidden_dim, topk=32, dropout=0.1):
        super().__init__()
        self.v_encoder = nn.Linear(input_dim, hidden_dim)
        self.activations = nn.ReLU()

        self.t_encoder = nn.Linear(input_dim, hidden_dim)
        self.vision_decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            )
        self.text_decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            )

        self.topk = topk
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

    def sparsify(self, embeddings):
        """Keep the ``self.topk`` largest-magnitude entries along the last dim.

        All other entries are zeroed. Exactly ``min(self.topk, n)`` entries
        are selected per row (``n`` = size of the last dim); when several
        entries tie at the cut-off, ``torch.topk`` decides which are kept.
        Selection runs in float32 so half/bfloat16 inputs work, and, unlike a
        float32 threshold subtracted from the input, the mask is built in the
        input's dtype so the output stays in that dtype and still matches a
        reduced-precision decoder.

        Args:
            embeddings: Tensor whose last dimension indexes latents.

        Returns:
            Tensor with the same shape and dtype as ``embeddings``.

        Raises:
            ValueError: If ``self.topk`` is negative.
        """
        num_latents = embeddings.shape[-1]
        if self.topk < 0:
            raise ValueError(f"topk must be non-negative, got {self.topk}")
        if self.topk >= num_latents:
            # Every latent survives; there is nothing to mask.
            return embeddings
        scores = embeddings.detach().abs().float()
        keep = torch.topk(scores, self.topk, dim=-1).indices
        mask = torch.zeros_like(embeddings).scatter_(-1, keep, 1.0)
        return embeddings * mask

    def encode_v(self, embeddings):
        """Return the sparse ReLU code of vision ``embeddings``."""
        return self.sparsify(self.activations(self.v_encoder(embeddings)))

    def encode_t(self, embeddings):
        """Return the sparse ReLU code of text ``embeddings``."""
        return self.sparsify(self.activations(self.t_encoder(embeddings)))

    def forward(self, vision_embeddings=None, text_embeddings=None):
        """Encode and reconstruct whichever modalities are provided.

        Returns:
            ``(recon_vision, recon_text, latent_v, latent_t)``; the entries
            for a modality that was not passed are ``None``.
        """
        recon_vision_embeddings = None
        recon_text_embeddings = None
        latent_v = None
        latent_t = None
        if vision_embeddings is not None:
            latent_v = self.encode_v(vision_embeddings)
            recon_vision_embeddings = self.vision_decoder(latent_v)
        if text_embeddings is not None:
            latent_t = self.encode_t(text_embeddings)
            recon_text_embeddings = self.text_decoder(latent_t)
        return recon_vision_embeddings, recon_text_embeddings, latent_v, latent_t

class SAE_V(nn.Module):
    """Top-k ReLU SAE with one encoder and one decoder shared by both modalities.

    Latents are ``ReLU(Linear(x))`` restricted to the ``topk`` largest values
    along the last dim; vision and text inputs go through the same weights.

    Args:
        input_dim: Size of the last dimension of the inputs.
        hidden_dim: Number of latent units.
        topk: Number of latents kept per row.
        dropout: Accepted for constructor compatibility; not used.
    """

    def __init__(self, input_dim, hidden_dim, topk=32, dropout=0.1):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.activations = nn.ReLU()

        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            )

        self.topk = topk
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

    def sparsify(self, embeddings):
        """Keep the ``self.topk`` largest-magnitude entries along the last dim.

        All other entries are zeroed. Exactly ``min(self.topk, n)`` entries
        are selected per row (``n`` = size of the last dim); when several
        entries tie at the cut-off, ``torch.topk`` decides which are kept.
        Selection runs in float32 so half/bfloat16 inputs work, and the mask
        is built in the input's dtype so the output keeps that dtype (a
        float32 threshold would otherwise promote the result to float32).

        Args:
            embeddings: Tensor whose last dimension indexes latents.

        Returns:
            Tensor with the same shape and dtype as ``embeddings``.

        Raises:
            ValueError: If ``self.topk`` is negative.
        """
        num_latents = embeddings.shape[-1]
        if self.topk < 0:
            raise ValueError(f"topk must be non-negative, got {self.topk}")
        if self.topk >= num_latents:
            # Every latent survives; there is nothing to mask.
            return embeddings
        scores = embeddings.detach().abs().float()
        keep = torch.topk(scores, self.topk, dim=-1).indices
        mask = torch.zeros_like(embeddings).scatter_(-1, keep, 1.0)
        return embeddings * mask

    def encode(self, embeddings):
        """Return the sparse ReLU code of ``embeddings``."""
        return self.sparsify(self.activations(self.encoder(embeddings)))

    def forward(self, vision_embeddings=None, text_embeddings=None, mode='eval'):
        """Encode and reconstruct whichever modalities are provided.

        Args:
            mode: Accepted for API compatibility with sibling models; unused.

        Returns:
            ``(recon_vision, recon_text, latent_v, latent_t)``; the entries
            for a modality that was not passed are ``None``.
        """
        recon_vision_embeddings = None
        recon_text_embeddings = None
        latent_v = None
        latent_t = None
        if vision_embeddings is not None:
            latent_v = self.encode(vision_embeddings)
            recon_vision_embeddings = self.decoder(latent_v)
        if text_embeddings is not None:
            latent_t = self.encode(text_embeddings)
            recon_text_embeddings = self.decoder(latent_t)
        return recon_vision_embeddings, recon_text_embeddings, latent_v, latent_t