import torch.nn as nn
import torch
# import faiss
from sklearn.cluster import KMeans, MiniBatchKMeans
import time

class KMeansTorch:
    """Batched Lloyd's k-means implemented with torch tensors.

    Every batch element of the input is clustered independently.

    Args:
        n_clusters: Number of centroids per batch element.
        max_iter: Maximum number of assignment/update iterations.
        tol: Iteration stops once the norm of the centroid update (taken over
            the whole batch) drops below this value.

    Attributes set by ``fit``:
        centroids: Tensor ``[B, n_clusters, D]`` holding the final centroids.
        closest_points: List with one ``[B, n_clusters, D]`` tensor per
            iteration; entry ``k`` is the input point nearest to centroid
            ``k`` as it was during that iteration's assignment step.
    """

    def __init__(self, n_clusters, max_iter=100, tol=1e-4):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol

    def fit(self, X):
        """Cluster ``X`` of shape ``[B, N, D]`` and store the centroids."""
        B, N, D = X.shape
        K = self.n_clusters

        # Seed the centroids from the data. Sampling without replacement avoids
        # starting with duplicated centroids whenever enough points exist.
        if N >= K:
            init_idx = torch.randperm(N, device=X.device)[:K]
        else:
            init_idx = torch.randint(0, N, (K,), device=X.device)
        centroids = X[:, init_idx]  # [B, K, D]

        self.closest_points = []
        cluster_ids = torch.arange(K, device=X.device)

        for _ in range(self.max_iter):
            distances = torch.cdist(X, centroids, p=2)  # [B, N, K]
            labels = torch.argmin(distances, dim=2)  # [B, N]

            # Record, for every centroid, the actual data point closest to it.
            nearest = torch.argmin(distances, dim=1)  # [B, K]
            gather_idx = nearest.unsqueeze(-1).expand(-1, -1, D)
            self.closest_points.append(torch.gather(X, 1, gather_idx))  # [B, K, D]

            # Per-cluster mean through a one-hot membership matrix. The previous
            # implementation averaged over all points at once and produced a
            # single [B, D] vector instead of K centroids.
            membership = (labels.unsqueeze(-1) == cluster_ids).to(X.dtype)  # [B, N, K]
            counts = membership.sum(dim=1).unsqueeze(-1)  # [B, K, 1]
            sums = torch.matmul(membership.transpose(1, 2), X)  # [B, K, D]
            # A cluster without members keeps its previous centroid.
            new_centroids = torch.where(counts > 0, sums / counts.clamp(min=1), centroids)

            shift = torch.norm(new_centroids - centroids)
            centroids = new_centroids
            if shift < self.tol:
                break

        self.centroids = centroids

    def predict(self, X):
        """Return the index ``[B, N]`` of the nearest fitted centroid for each point."""
        distances = torch.cdist(X, self.centroids, p=2)
        return torch.argmin(distances, dim=2)

def conv3x3(in_channels, out_channels, stride=1):
    """Create a biased 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        stride: Convolution stride (default 1).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=True)
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)


import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel-wise gating of a feature map.

    The input is globally average-pooled to one value per channel, passed
    through a 1x1-conv bottleneck (``num_features // reduction`` hidden
    channels, PReLU in between) and a sigmoid; the resulting per-channel
    weights rescale the input.

    Args:
        num_features: Number of channels of the input feature map.
        reduction: Divisor applied to ``num_features`` for the bottleneck width.
    """

    def __init__(self, num_features, reduction=4):
        super().__init__()
        hidden = num_features // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Conv2d(num_features, hidden, 1, padding=0, bias=True),
            nn.PReLU(),
            nn.Conv2d(hidden, num_features, 1, padding=0, bias=True),
            nn.Sigmoid(),
        ]
        self.conv_du = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate."""
        gate = self.conv_du(self.avg_pool(x))
        return x * gate


class DenseLayer(nn.Module):
    """A biased 3x3 convolution (padding 1) followed by a PReLU activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            bias=True,
        )
        self.relu = nn.PReLU()  # learnable-slope activation, despite the attribute name

    def forward(self, x):
        """Apply the convolution and then the activation."""
        return self.relu(self.conv(x))


class RDB_CA(nn.Module):
    """Residual dense block made of PReLU-activated dense layers.

    Each dense layer receives the channel-wise concatenation of the block
    input and the outputs of all earlier layers. A 1x1 convolution then fuses
    the concatenated features back to ``in_channels``; the fused result is
    scaled by ``res_scale`` and added to the block input.

    The channel-attention stage suggested by the class name is currently
    disabled (its construction is commented out below).

    Args:
        in_channels: Channels of the block input and output.
        growth_rate: Channels produced by every dense layer.
        num_layers: Number of dense layers.
        res_scale: Factor applied to the fused features before the skip add.
    """

    def __init__(self, in_channels, growth_rate=64, num_layers=4, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale
        self.num_layers = num_layers

        # Layer i sees the input plus the i earlier growth_rate-wide outputs.
        self.layers = nn.ModuleList([
            DenseLayer(in_channels + depth * growth_rate, growth_rate)
            for depth in range(num_layers)
        ])

        # Local feature fusion back to the input width.
        fused_width = in_channels + num_layers * growth_rate
        self.lff = nn.Conv2d(fused_width, in_channels, kernel_size=1, bias=True)

        # self.ca = ChannelAttention(in_channels)

    def forward(self, x):
        """Run the dense layers, fuse their features and add the scaled residual."""
        features = x
        for layer in self.layers:
            # Grow the running concatenation with this layer's output.
            features = torch.cat((features, layer(features)), 1)

        fused = self.lff(features)
        return fused * self.res_scale + x
class ResBlock(nn.Module):
    """Two 3x3 convolutions with a LeakyReLU(0.1) in between and a skip connection.

    The output is ``conv2(relu(conv1(x))) * res_scale + x``, so the skip add
    only works when the convolution branch keeps the shape of ``x``.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Accepted for interface compatibility; not used by this block.
        res_scale: Factor applied to the convolution branch before the skip add.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, res_scale=1):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.relu = nn.LeakyReLU(0.1, inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)

    def forward(self, x):
        """Return the scaled convolution branch added to the input."""
        branch = self.conv2(self.relu(self.conv1(x)))
        return branch * self.res_scale + x


def compute_channel_similarity(tensor):
    """Compute the cosine similarity between every pair of channels.

    Each channel of each sample is flattened to a vector of length ``H * W``
    and L2-normalised (with ``1e-8`` added to the norm so an all-zero channel
    does not divide by zero); the similarity is the dot product of those
    normalised vectors.

    Args:
        tensor: Feature map of shape ``[B, C, H, W]``.

    Returns:
        Tensor of shape ``[B, C, C]`` with the per-sample similarity matrices.
    """
    B, C, H, W = tensor.shape

    # reshape (rather than view) also accepts non-contiguous inputs, and a
    # single batched matmul replaces the per-sample Python loop, which
    # additionally makes an empty batch return an empty [0, C, C] result.
    flat = tensor.reshape(B, C, H * W)  # [B, C, H*W]
    flat = flat / (flat.norm(dim=2, keepdim=True) + 1e-8)
    return torch.bmm(flat, flat.transpose(1, 2))  # [B, C, C]

class channel_clustering(nn.Module):
    """Cluster the channels of a feature map and project it onto the centroids.

    For every sample the ``[C, C]`` channel-similarity matrix is clustered with
    a freshly constructed ``MiniBatchKMeans``; the resulting centroids
    (``[n_clusters, C]``) are used as mixing weights over the flattened
    channels to obtain ``n_clusters`` maps.

    Args:
        n_clusters: Number of channel clusters.
        speKmeans: Estimator stored as ``self.kmeans``. NOTE(review): ``forward``
            does not use it and builds its own ``MiniBatchKMeans``
            (``random_state=0``) instead — confirm whether that is intended.
    """

    def __init__(self, n_clusters, speKmeans):
        super().__init__()
        self.n_clusters = n_clusters
        self.kmeans = speKmeans

    def forward(self, tensor):
        """Return the clustered maps, shape ``[B, n_clusters, H, W]``.

        The clustering runs on detached CPU copies, so gradients reach
        ``tensor`` only through the final matrix product.
        """
        similarity_matrices = compute_channel_similarity(tensor)
        B, C, H, W = tensor.shape
        # reshape instead of view so non-contiguous inputs are accepted.
        flat = tensor.reshape(B, C, H * W)
        cluster_centers = []

        for b in range(B):
            similarity_matrix = similarity_matrices[b].detach().cpu().numpy()  # [C, C]

            kmeans = MiniBatchKMeans(
                n_clusters=self.n_clusters,
                batch_size=self.n_clusters * 2,
                random_state=0,
            )
            kmeans.fit(similarity_matrix)

            # Cast to the input dtype: scikit-learn may hand back float64
            # centroids (e.g. for half-precision inputs), which would make the
            # matmul below fail with a dtype mismatch.
            centroids = torch.tensor(
                kmeans.cluster_centers_, dtype=tensor.dtype, device=tensor.device
            )  # [n_clusters, C]

            centers = torch.matmul(centroids, flat[b])  # [n_clusters, H*W]
            cluster_centers.append(centers.view(self.n_clusters, H, W))

        return torch.stack(cluster_centers, dim=0)  # [B, n_clusters, H, W]


class spatial_clustering(nn.Module):
    """Cluster the spatial positions of a feature map with an injected estimator.

    Every pixel of a sample is treated as a ``K_c``-dimensional point; the
    fitted cluster centres form that sample's core matrix.

    Args:
        hw_clusters: Number of spatial clusters (stored for reference; the
            estimator itself determines how many centres are produced).
        spaKmeans: Estimator exposing ``fit(array)`` and ``cluster_centers_``
            (a ``MiniBatchKMeans`` in this file). It is re-fitted for every
            sample of every forward pass.
    """

    def __init__(self, hw_clusters, spaKmeans):
        super().__init__()
        self.hw_clusters = hw_clusters
        self.kmeans = spaKmeans

    def forward(self, tensor):
        """Return the stacked core matrices, shape ``[B, K_c, n_clusters]``.

        The estimator works on detached CPU copies, so the result carries no
        gradient with respect to ``tensor``.
        """
        B, K_c, H, W = tensor.shape
        core_tensor = []
        for b in range(B):
            # One row per pixel, one column per channel: [H*W, K_c].
            flat_tensor = tensor[b].reshape(K_c, -1).T

            self.kmeans.fit(flat_tensor.detach().cpu().numpy())

            # Cast to the input dtype: the estimator may return float64
            # centres (e.g. for half-precision inputs), which would not match
            # the tensors this result is later multiplied with.
            spatial_clusters = torch.tensor(
                self.kmeans.cluster_centers_, dtype=tensor.dtype, device=tensor.device
            )  # [n_clusters, K_c]
            core_tensor.append(spatial_clusters.T)  # [K_c, n_clusters]

        return torch.stack(core_tensor, dim=0)  # [B, K_c, n_clusters]




class tensor_decomposition(nn.Module):
    """Derive a core matrix plus channel and spatial attention maps from a feature map.

    Channel clustering condenses the ``C`` input channels into ``K_c`` maps,
    spatial clustering condenses those maps into a ``[K_c, K_s]`` core matrix.
    The attention maps are scaled dot products between the input (or the
    clustered maps) and the respective centres, passed through a small
    Linear-LeakyReLU-Linear-Sigmoid network and a softmax over dimension 1.

    Args:
        n_clusters: Number of channel clusters (``K_c``).
        hw_clusters: Number of spatial clusters (``K_s``).
        speKmeans: Estimator forwarded to ``channel_clustering``.
        spaKmeans: Estimator forwarded to ``spatial_clustering``.
    """

    @staticmethod
    def _make_gate(width):
        """Build the Linear -> LeakyReLU(0.1) -> Linear -> Sigmoid adjustment net."""
        return nn.Sequential(
            nn.Linear(width, width, bias=True),
            nn.LeakyReLU(0.1),
            nn.Linear(width, width, bias=True),
            nn.Sigmoid(),
        )

    def __init__(self, n_clusters, hw_clusters, speKmeans, spaKmeans):
        super().__init__()

        self.channel_clustering = channel_clustering(n_clusters, speKmeans)
        self.spatial_clustering = spatial_clustering(hw_clusters, spaKmeans)

        self.channel_adjust = self._make_gate(n_clusters)
        self.spatial_adjust = self._make_gate(hw_clusters)

    def forward(self, tensor):
        """Return ``(core_tensor, attn_channel, attn_spatial)``.

        Shapes: ``core_tensor`` is ``[B, K_c, K_s]``, ``attn_channel`` is
        ``[B, C, K_c]`` and ``attn_spatial`` is ``[B, H*W, K_s]``.
        """
        B, C, H, W = tensor.shape

        # 1. Channel clustering and channel attention.
        channel_clustered = self.channel_clustering(tensor)  # [B, K_c, H, W]
        _, k_c, _, _ = channel_clustered.shape

        pixels_by_channel = tensor.view(B, C, -1)  # [B, C, H*W]
        centers_flat = channel_clustered.view(B, k_c, -1)  # [B, K_c, H*W]
        channel_scores = torch.matmul(pixels_by_channel, centers_flat.transpose(1, 2))
        channel_scores = channel_scores / torch.sqrt(torch.tensor(k_c, dtype=torch.float32))
        attn_channel = self.channel_adjust(channel_scores).softmax(dim=1)

        # 2. Spatial clustering and spatial attention.
        core_tensor = self.spatial_clustering(channel_clustered)  # [B, K_c, K_s]
        _, _, k_s = core_tensor.shape
        spatial_scores = torch.matmul(
            channel_clustered.view(B, k_c, -1).permute(0, 2, 1),  # [B, H*W, K_c]
            core_tensor.view(B, k_c, -1),  # [B, K_c, K_s]
        )
        spatial_scores = spatial_scores / torch.sqrt(torch.tensor(k_s, dtype=torch.float32))
        attn_spatial = self.spatial_adjust(spatial_scores).softmax(dim=1)

        return core_tensor, attn_channel, attn_spatial

class attn_compute(nn.Module):
    """Rebuild a feature map from a given core matrix and channel attention.

    The spatial attention is recomputed for the incoming feature map, then the
    core matrix is expanded back through the spatial and channel attention
    maps and added to the input as a residual.

    Args:
        n_clusters: Accepted for interface symmetry; not used by this module.
        hw_clusters: Width of the spatial adjustment network (``K_s``).
    """

    def __init__(self, n_clusters, hw_clusters):
        super().__init__()

        adjust_layers = [
            nn.Linear(hw_clusters, hw_clusters, bias=True),
            nn.LeakyReLU(0.1),
            nn.Linear(hw_clusters, hw_clusters, bias=True),
            nn.Sigmoid(),
        ]
        self.spatial_adjust = nn.Sequential(*adjust_layers)

    def forward(self, tensor, core, attn_channel):
        """Return the reconstructed map, same shape as ``tensor``.

        Args:
            tensor: Feature map ``[B, C, H, W]``.
            core: Three-dimensional core matrix; the matmuls below require
                ``[B, K_c, K_s]``.
            attn_channel: Channel attention; the matmuls below require
                ``[B, C, K_c]``.
        """
        B, C, H, W = tensor.shape

        # Project every pixel's channel vector onto the channel clusters.
        pixels = tensor.view(B, C, -1).permute(0, 2, 1)  # [B, H*W, C]
        projected = torch.matmul(pixels, attn_channel)  # [B, H*W, K_c]

        # Scaled scores against the core, adjusted and normalised over pixels.
        _, _, k_s = core.shape
        scores = torch.matmul(projected, core)
        scores = scores / torch.sqrt(torch.tensor(k_s, dtype=torch.float32))
        attn_spatial = self.spatial_adjust(scores).softmax(dim=1)  # [B, H*W, K_s]

        # Expand the core back over space, then over channels.
        spatial_mix = torch.matmul(core, attn_spatial.transpose(-2, -1))  # [B, K_c, H*W]
        channel_mix = torch.matmul(
            spatial_mix.permute(0, 2, 1), attn_channel.transpose(-2, -1)
        )  # [B, H*W, C]

        return channel_mix.permute(0, 2, 1).view(B, C, H, W) + tensor



class tensor_reconstruction(nn.Module):
    """Decompose a feature map and rebuild it from the decomposition.

    Args:
        n_clusters: Number of channel clusters.
        hw_clusters: Number of spatial clusters.
        speKmeans: Estimator forwarded to ``tensor_decomposition``.
        spaKmeans: Estimator forwarded to ``tensor_decomposition``.
    """

    def __init__(self, n_clusters, hw_clusters, speKmeans, spaKmeans):
        super().__init__()
        self.decomposition = tensor_decomposition(n_clusters, hw_clusters, speKmeans, spaKmeans)

    def forward(self, LR):
        """Return ``(reconstruction, core, attn_channel)`` for input ``LR`` ``[B, C, H, W]``.

        The reconstruction expands the core through the spatial and channel
        attention maps and adds ``LR`` as a residual; it has the shape of ``LR``.
        """
        B, C, H, W = LR.shape

        core, attn_channel, attn_spatial = self.decomposition(LR)

        # Expand the core over the spatial positions ...
        spatial_mix = torch.matmul(core, attn_spatial.transpose(-2, -1))
        # ... and then over the original channels.
        channel_mix = torch.matmul(spatial_mix.permute(0, 2, 1), attn_channel.transpose(-2, -1))

        restored = channel_mix.permute(0, 2, 1).view(B, C, H, W) + LR
        return restored, core, attn_channel
import math
class Downsample(nn.Sequential):
    """Downsample by ``scale`` using 3x3 convolutions and ``PixelUnshuffle``.

    For ``scale = 2^n`` the module stacks ``n`` stages of a convolution to
    ``num_feat // 4`` channels followed by ``PixelUnshuffle(2)``; for
    ``scale = 3`` a single stage with ``num_feat // 9`` channels and
    ``PixelUnshuffle(3)``. The channel count is therefore only preserved when
    ``num_feat`` is divisible by 4 (respectively 9). ``scale = 1`` yields an
    empty, identity ``Sequential``.

    Args:
        scale: Downsampling factor; a positive power of two, or 3.
        num_feat: Number of channels of the input feature map.

    Raises:
        ValueError: If ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # Require scale > 0 explicitly: 0 passes the bit trick on its own and
        # used to fail later with an unrelated "math domain error".
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # bit_length gives the exact exponent without float logarithms.
            for _ in range(int(scale).bit_length() - 1):
                m.append(nn.Conv2d(num_feat, num_feat // 4, 3, 1, 1))
                m.append(nn.PixelUnshuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, num_feat // 9, 3, 1, 1))
            m.append(nn.PixelUnshuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Downsample, self).__init__(*m)


class Upsample(nn.Sequential):
    """Upsample by ``scale`` using 3x3 convolutions and ``PixelShuffle``.

    For ``scale = 2^n`` the module stacks ``n`` stages of a convolution to
    ``4 * num_feat`` channels followed by ``PixelShuffle(2)``; for
    ``scale = 3`` a single stage with ``9 * num_feat`` channels and
    ``PixelShuffle(3)``. ``scale = 1`` yields an empty, identity
    ``Sequential``.

    Args:
        scale: Upsampling factor; a positive power of two, or 3.
        num_feat: Number of channels of the input (and output) feature map.

    Raises:
        ValueError: If ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # Require scale > 0 explicitly: 0 passes the bit trick on its own and
        # used to fail later with an unrelated "math domain error".
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # bit_length gives the exact exponent without float logarithms.
            for _ in range(int(scale).bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
class main(nn.Module):
    """Three-stage network that upsamples its input by a factor of four.

    Stage 1 decomposes and reconstructs the input at its native resolution.
    Stages 2 and 3 reuse the core matrix and channel attention of stage 1 at
    2x and 4x resolution. Every stage is followed by two residual blocks, and
    bicubically upsampled copies of the input are added as skip connections.

    Args:
        in_channels: Number of channels of the input (and of every output).
        n_clusters: Number of channel clusters.
        hw_clusters: Number of spatial clusters.
    """

    def __init__(self, in_channels, n_clusters=3, hw_clusters=16):
        super().__init__()

        # scikit-learn estimators handed down to the clustering modules.
        self.speKmeans = MiniBatchKMeans(n_clusters=n_clusters, batch_size=in_channels, random_state=42)
        self.spaKmeans = MiniBatchKMeans(n_clusters=hw_clusters, batch_size=256, random_state=42)

        self.reconstruction1 = tensor_reconstruction(n_clusters, hw_clusters, self.speKmeans, self.spaKmeans)
        self.reconstruction2 = attn_compute(n_clusters, hw_clusters)
        self.reconstruction3 = attn_compute(n_clusters, hw_clusters)

        # One pair of residual blocks per stage.
        self.RBS = nn.ModuleList([
            nn.Sequential(
                ResBlock(in_channels, in_channels, 1),
                ResBlock(in_channels, in_channels, 1),
            )
            for _ in range(3)
        ])

        self.upsampleX2 = nn.Upsample(scale_factor=2, mode='bicubic')
        self.upsampleX4 = nn.Upsample(scale_factor=4, mode='bicubic')

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=(3, 3), padding=1),
        )

    def forward(self, LR):
        """Return ``(final_out, LR_re2, LR_re)`` at 4x, 2x and 1x the input resolution."""
        # Bicubic skip connections at the two higher resolutions.
        skip_x2 = self.upsampleX2(LR)
        skip_x4 = self.upsampleX4(LR)

        # Stage 1: native resolution; also yields the shared core and attention.
        stage1, core, attn_channel = self.reconstruction1(LR)
        stage1 = self.RBS[0](stage1)

        # Stage 2: 2x resolution.
        stage2_in = self.upsampleX2(stage1) + skip_x2
        stage2 = self.RBS[1](self.reconstruction2(stage2_in, core, attn_channel))

        # Stage 3: another 2x step brings the features to 4x resolution.
        stage3_in = self.upsampleX2(stage2) + skip_x4
        stage3 = self.RBS[2](self.reconstruction3(stage3_in, core, attn_channel))

        final_out = self.final_conv(stage3) + skip_x4

        return final_out, stage2, stage1




if __name__ == "__main__":
    # Profile the network on a CUDA device: build it, run thop on one random
    # input and report wall time, parameter count and FLOPs.
    # net = Global_Tucker(rank=[1,5,8,8]).cuda()
    net = main(in_channels=102).cuda()
    # net = tensor_reconstruction(5, 64).cuda()
    from thop import profile

    input1 = torch.randn(1, 102, 16, 16).cuda()  # input actually profiled
    input2 = torch.randn(1, 102, 64, 64).cuda()  # allocated but not used below

    tic = time.time()
    flops, params = profile(net, inputs=(input1,))
    toc = time.time()
    print(f"Took {round((toc - tic), 5)} seconds to calculate.")

    total = sum(p.nelement() for p in net.parameters())
    print('   Number of params: %.4fM' % (total / 1e6))
    print('   Number of FLOPs: %.4fGFLOPs' % (flops / 1e9))