import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import normalize

class SEBlock(nn.Module):
    """Squeeze-and-excitation gate for 3-D ``(batch, channel, length)`` inputs.

    The input is averaged over its last axis, passed through a two-layer
    bias-free bottleneck MLP ending in a sigmoid, and the resulting
    per-channel gate is multiplied back onto the input.

    Args:
        channel: Size of dim 1 of the tensors passed to ``forward``.
        reduction: Divisor applied to ``channel`` to obtain the bottleneck
            width. The width is clamped to at least 1.
    """

    def __init__(self, channel, reduction=8):
        super().__init__()
        # Without the clamp, channel < reduction produces a zero-width hidden
        # layer; the second Linear then outputs all zeros and the gate
        # degenerates to a constant sigmoid(0) == 0.5 for every channel.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` rescaled channel-wise by the learned gate (same shape as ``x``)."""
        batch, channels, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1)
        # The trailing singleton axis broadcasts over the length axis.
        return x * gate

class ResidualConvBlock(nn.Module):
    """Residual 1-D convolution block with a squeeze-and-excitation gate.

    Main path: Conv1d(k=5) -> BN -> ReLU -> Conv1d(k=3) -> BN -> ReLU -> SE.
    The shortcut is the identity unless the channel count or the stride
    changes, in which case it is a 1x1 convolution followed by batch norm.
    The two paths are summed and passed through a final ReLU.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
        stride: Stride of the first convolution and of the shortcut projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        first_conv = nn.Conv1d(
            in_channels, out_channels, kernel_size=5, stride=stride, padding=2
        )
        self.conv1 = nn.Sequential(
            first_conv,
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
        )

        second_conv = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Sequential(
            second_conv,
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.se = SEBlock(out_channels)

        # A projection is only required when the shortcut would otherwise not
        # match the main path's channel count or temporal resolution.
        shapes_match = in_channels == out_channels and stride == 1
        if shapes_match:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        """Apply the gated main path, add the shortcut and rectify the sum."""
        features = self.se(self.conv2(self.conv1(x)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(features + shortcut)


class UNet1D_Multimodal(nn.Module):
    """1-D U-Net operating on a ``(batch, modal_num, length)`` tensor.

    Each modality (dim 1) is first scaled by a learnable weight, then the
    stack goes through a stem convolution, ``depth`` encoder stages (each
    doubling the channels and halving the length with max pooling), a
    bottleneck, and ``depth`` decoder stages with skip connections. The stem
    output is added back before the final convolution, and the result is
    flattened to ``(batch, modal_num * length)``.

    Args:
        modal_num: Number of modalities; used as input and output channels.
        feature_dim: Stored on the instance; not otherwise used by this class.
        base_ch: Channel width of the stem.
        depth: Number of encoder stages (and of decoder stages).
    """

    def __init__(self, modal_num, feature_dim, base_ch=32, depth=4):
        super().__init__()
        self.depth = depth
        self.modal_num = modal_num
        self.feature_dim = feature_dim

        # Stem: lifts the modality axis to base_ch channels, length preserved.
        self.init_conv = nn.Sequential(
            nn.Conv1d(modal_num, base_ch, kernel_size=7, padding=3),
            nn.BatchNorm1d(base_ch),
            nn.ReLU(inplace=True),
        )

        # Encoder path: every stage doubles the channel width.
        self.encoders = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.encoder_channels = [base_ch]  # stem width followed by each stage's width
        width = base_ch
        for _ in range(depth):
            doubled = width * 2
            self.encoders.append(ResidualConvBlock(width, doubled))
            self.pools.append(nn.MaxPool1d(2))
            self.encoder_channels.append(doubled)
            width = doubled

        self.bottleneck = ResidualConvBlock(width, width)

        # Decoder path: every stage halves the channel width and consumes the
        # encoder output of matching resolution, deepest stage first.
        self.decoders = nn.ModuleList()
        self.upconvs = nn.ModuleList()
        self.decoder_input_channels = []  # bookkeeping, ordered shallowest first
        for skip_width in reversed(self.encoder_channels[1:]):
            halved = width // 2
            merged_width = halved + skip_width
            self.decoder_input_channels.insert(0, merged_width)
            self.upconvs.append(
                nn.ConvTranspose1d(width, halved, kernel_size=4, stride=2, padding=1)
            )
            self.decoders.append(ResidualConvBlock(merged_width, halved))
            width = halved

        # Head: maps back to one channel per modality.
        self.final_conv = nn.Sequential(
            nn.Conv1d(base_ch, modal_num, kernel_size=3, padding=1),
            nn.BatchNorm1d(modal_num),
            nn.ReLU(inplace=True),
        )

        # One learnable scale per modality, initialised uniformly to 1 / modal_num.
        self.modal_weights = nn.Parameter(torch.ones(modal_num) / modal_num)

    def forward(self, x):
        """Fuse ``x`` of shape ``(batch, modal_num, length)`` into ``(batch, modal_num * length)``."""
        scaled = x * self.modal_weights.view(1, -1, 1)
        stem = self.init_conv(scaled)

        # Contracting path; keep every pre-pooling activation as a skip.
        skips = []
        h = stem
        for encode, pool in zip(self.encoders, self.pools):
            h = encode(h)
            skips.append(h)
            h = pool(h)

        h = self.bottleneck(h)

        # Expanding path; the deepest skip is consumed first.
        for upsample, decode, skip in zip(self.upconvs, self.decoders, reversed(skips)):
            h = upsample(h)
            target_len = skip.shape[-1]
            if h.shape[-1] != target_len:
                # Pooling floors odd lengths, so the upsampled length can be
                # off from the skip's; resize before concatenating.
                h = F.interpolate(h, size=target_len)
            h = decode(torch.cat([h, skip], dim=1))

        # Long residual connection from the stem.
        h = h + stem[:, :, :h.shape[-1]]

        fused = self.final_conv(h)
        return fused.reshape(fused.size(0), -1)


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Size of the input and output feature (last) dimension.
        d_ff: Size of the hidden layer.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the transformed tensor; the last dimension stays ``d_model``."""
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

class Encoder(nn.Module):
    """Fully connected encoder: input_dim -> 500 -> 500 -> 2000 -> feature_dim.

    ReLU follows every linear layer except the last one, so the output is an
    unconstrained linear projection.

    Args:
        input_dim: Size of the input feature dimension.
        feature_dim: Size of the produced embedding.
    """

    def __init__(self, input_dim, feature_dim):
        super().__init__()
        widths = (input_dim, 500, 500, 2000, feature_dim)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the output layer
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` to its ``feature_dim``-sized embedding."""
        return self.encoder(x)

class Decoder(nn.Module):
    """Fully connected decoder: feature_dim -> 2000 -> 500 -> 500 -> input_dim.

    ReLU follows every linear layer except the last one.

    Args:
        input_dim: Size of the reconstructed output dimension.
        feature_dim: Size of the embedding fed to the decoder.
    """

    def __init__(self, input_dim, feature_dim):
        super().__init__()
        widths = (feature_dim, 2000, 500, 500, input_dim)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the output layer
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Map an embedding ``x`` back to an ``input_dim``-sized vector."""
        return self.decoder(x)

class THCRL(nn.Module):
    """Multi-view autoencoder with a U-Net based cross-view fusion head.

    Every view has its own ``Encoder``/``Decoder`` pair. ``forward`` returns
    per-view reconstructions and embeddings; ``DSHF`` stacks the per-view
    embeddings, fuses them, and returns a single normalized representation
    together with an averaged k-NN adjacency matrix of the inputs.

    Args:
        view: Number of views.
        input_size: Indexable of per-view input dimensions (``input_size[v]``).
        low_feature_dim: Embedding size produced by each encoder.
        high_feature_dim: Size of the projected (normalized) representations.
        device: Device the per-view encoders and decoders are moved to. Only
            those submodules are moved here; the remaining submodules stay
            where they were created until the caller moves the whole model.
    """

    def __init__(self, view, input_size, low_feature_dim, high_feature_dim, device):
        super().__init__()
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        # Encoder and decoder of a view are created back to back so the
        # parameter initialisation order stays encoder_0, decoder_0, encoder_1, ...
        for v in range(view):
            self.encoders.append(Encoder(input_size[v], low_feature_dim).to(device))
            self.decoders.append(Decoder(input_size[v], low_feature_dim).to(device))
        self.Specific_view = nn.Sequential(
            nn.Linear(low_feature_dim, high_feature_dim),
        )
        self.Common_view = nn.Sequential(
            nn.Linear(low_feature_dim * view, high_feature_dim),
        )
        self.view = view

        self.fusion = UNet1D_Multimodal(modal_num=view, feature_dim=low_feature_dim)

        self.ffn = FeedForward(d_model=low_feature_dim * view, d_ff=256)

    def forward(self, xs):
        """Encode, project and reconstruct every view independently.

        Args:
            xs: Indexable of per-view input batches; ``xs[v]`` feeds view ``v``.

        Returns:
            Tuple ``(xrs, zs, hs)`` of lists with one entry per view:
            reconstructions, encoder embeddings, and L2-normalized (dim 1)
            projections of the embeddings through ``Specific_view``.
        """
        xrs = []
        zs = []
        hs = []
        for v, (encoder, decoder) in enumerate(zip(self.encoders, self.decoders)):
            z = encoder(xs[v])
            hs.append(normalize(self.Specific_view(z), dim=1))
            zs.append(z)
            xrs.append(decoder(z))
        return xrs, zs, hs

    def DSHF(self, xs):
        """Fuse all views into one representation and build a k-NN graph.

        Args:
            xs: Indexable of per-view input batches; ``xs[v]`` feeds view ``v``.

        Returns:
            Tuple ``(z_refined, A)``: the fused representation, L2-normalized
            along dim 1 after ``Common_view``, and the element-wise mean over
            views of the k-NN adjacency matrices of the row-normalized inputs.
        """
        zs = []
        adjacencies = []
        for v in range(self.view):
            x = xs[v]
            adjacencies.append(self.computeA(F.normalize(x), mode='knn'))
            zs.append(self.encoders[v](x))

        # (batch, view, low_feature_dim): views become the fusion net's channels.
        commonz = torch.stack(zs, dim=1)
        z_refined = self.fusion(commonz)
        z_refined = self.ffn(z_refined)
        z_refined = normalize(self.Common_view(z_refined), dim=1)

        return z_refined, torch.mean(torch.stack(adjacencies), dim=0)

    def computeA(self, x, mode, k=10):
        """Build a pairwise affinity matrix between the rows of ``x``.

        Args:
            x: 2-D tensor with one sample per row.
            mode: One of
                ``'cos'``    - cosine similarity rescaled to [0, 1];
                ``'kernel'`` - ``exp(-1000 * squared distance)`` between
                L1-normalized rows;
                ``'knn'``    - binary matrix with ones at each row's ``k``
                smallest squared Euclidean distances (the row itself included);
                ``'sigmod'`` (or ``'sigmoid'``) - sigmoid of the Gram matrix.
            k: Neighbour count for ``'knn'``; clamped to the number of rows.

        Returns:
            Tensor of shape ``(n, n)`` on the same device as ``x``.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        if mode == 'cos':
            unit = F.normalize(x, p=2, dim=1)
            return (torch.mm(unit, unit.t()) + 1) / 2

        if mode == 'kernel':
            unit = F.normalize(x, p=1.0, dim=1)
            diff = unit.unsqueeze(1) - unit.unsqueeze(0)
            return torch.exp(-torch.sum((diff ** 2) * 1000, dim=2))

        if mode == 'knn':
            sq_norm = torch.sum(torch.square(x), dim=1, keepdim=True)
            dis2 = (-2 * x.mm(x.t())) + sq_norm + sq_norm.t()
            # Allocate next to the input instead of forcing CUDA, so the
            # method also works on CPU and on non-default GPUs.
            A = torch.zeros(dis2.shape, device=dis2.device)
            # topk raises when asked for more neighbours than there are rows.
            neighbours = min(k, dis2.shape[1])
            nearest = torch.topk(dis2, neighbours, largest=False).indices
            A.scatter_(1, nearest, 1.0)
            return A.detach()

        # 'sigmod' is the historical spelling callers may already use.
        if mode in ('sigmod', 'sigmoid'):
            return torch.sigmoid(torch.mm(x, x.t()))

        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'cos', 'kernel', 'knn' or 'sigmod'"
        )