import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class CNN_MNIST(nn.Module):
    """Two-stage convolutional encoder for single-channel images.

    Each stage is a 3x3 convolution (stride 1, padding 1, so the spatial size
    is preserved), a ReLU, and a 2x2 max-pool with stride 2 (each side is
    halved, rounding down). The flattened feature map is passed through one
    linear layer followed by a ReLU, so every output value is non-negative.

    Args:
        num_classes: Width of the output vector.
        input_size: Spatial size of the images the model will receive, given
            as a single int for square inputs or as a ``(height, width)``
            pair. The default of 14 yields a 3x3 feature map, i.e. the
            ``64 * 3 * 3`` linear input width that used to be hard-coded.

    Raises:
        ValueError: If ``input_size`` is so small that nothing is left after
            the two pooling stages.
    """

    def __init__(self, num_classes=60, input_size=14):
        super().__init__()

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        # The convolutions keep the spatial size; each pool floor-halves it.
        feat_h = height // 2 // 2
        feat_w = width // 2 // 2
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small: at least 4 pixels "
                "per side are needed to survive two 2x2 poolings"
            )

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(in_features=64 * feat_h * feat_w, out_features=num_classes)

    def forward(self, x):
        """Encode a batch of images into ``num_classes``-wide vectors.

        Args:
            x: Tensor of shape ``(batch, 1, height, width)`` whose spatial
                size pools down to the feature map ``fc1`` was built for.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with non-negative values.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.flatten(1)  # keep the batch axis, merge channels and pixels
        x = F.relu(self.fc1(x))
        return x

class CNN_CIFAR(nn.Module):
    """Six-convolution encoder mapping 3-channel images to an embedding.

    The network has three stages of two 3x3 convolutions each, every
    convolution followed by batch normalization and ReLU. The first two
    stages end with a 2x2 max-pool; the third is followed by global average
    pooling. A linear layer, batch normalization, dropout and a final ReLU
    produce the embedding.

    Args:
        emb_dim: Width of the output embedding.
        dropout: Drop probability of the dropout layer in the embedding head.
    """

    def __init__(self, emb_dim=60, dropout=0.2):
        super().__init__()

        def conv3x3(c_in, c_out):
            # Size-preserving 3x3 convolution shared by all stages.
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)

        # Stage 1: 3 -> 64 -> 64 channels, then halve the resolution.
        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 64 -> 128 -> 128 channels, then halve the resolution.
        self.conv3 = conv3x3(64, 128)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = conv3x3(128, 128)
        self.bn4 = nn.BatchNorm2d(128)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: 128 -> 256 -> 256 channels, no max-pool afterwards.
        self.conv5 = conv3x3(128, 256)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = conv3x3(256, 256)
        self.bn6 = nn.BatchNorm2d(256)

        # Collapse each channel to one value and flatten to (batch, 256).
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten(start_dim=1, end_dim=-1)

        # Embedding head.
        self.fc1 = nn.Linear(256, emb_dim)
        self.bn_fc = nn.BatchNorm1d(emb_dim)
        self.dropout = nn.Dropout(dropout)

    def calc_representation(self, x):
        """Run the convolutional stages and the embedding head on ``x``.

        Args:
            x: Batch of 3-channel images.

        Returns:
            Tensor of shape ``(batch, emb_dim)`` with non-negative values.
        """
        feats = x

        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            feats = F.relu(norm(conv(feats)))
        feats = self.maxpool1(feats)

        for conv, norm in ((self.conv3, self.bn3), (self.conv4, self.bn4)):
            feats = F.relu(norm(conv(feats)))
        feats = self.maxpool2(feats)

        for conv, norm in ((self.conv5, self.bn5), (self.conv6, self.bn6)):
            feats = F.relu(norm(conv(feats)))

        pooled = self.flat(self.global_pool(feats))

        # Order in the head is linear -> batch norm -> dropout -> ReLU.
        embedding = self.dropout(self.bn_fc(self.fc1(pooled)))
        return F.relu(embedding)

    def forward(self, x, vis=False):
        """Resize ``x`` to 32x32 and return its embedding.

        Args:
            x: Batch of 3-channel images of any spatial size.
            vis: Accepted but not used by this method.

        Returns:
            The result of ``calc_representation`` on the resized batch.
        """
        resized = F.interpolate(x, (32, 32))
        return self.calc_representation(resized)
# class CNN_CIFAR(nn.Module):
#     def __init__(self, emb_dim=60, dropout=0.0):
#         super(CNN_CIFAR, self).__init__()
#         self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
#         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
#         self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
#         self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
#         self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
#         self.adapt_pool = nn.AdaptiveAvgPool2d((1, 1))
#         self.flat =nn.Flatten(start_dim=1, end_dim=-1)  # size 128 
        

#         self.fc1 = nn.Linear(128, emb_dim, bias=True)
 

#         self.bn1 = nn.BatchNorm1d(emb_dim)

    
#     def calc_representation(self, x):
        
#         x = self.avg_pool(F.relu(self.conv1(x)))
#         x = self.avg_pool(F.relu(self.conv2(x)))
#         x = self.avg_pool(F.relu(self.conv3(x)))
#         x = self.adapt_pool(F.relu(self.conv4(x)))
#         x = F.relu(self.bn1(self.fc1(self.flat(x))))
#         # x = F.relu( self.fc1(self.flat(x)))
#         return x


#     def forward(self, x, vis=False ):
#         x=  F.interpolate(x, (32, 32))
#         x = self.calc_representation(x)

#         return x

def clientModel(args):
    """Build the client-side model selected by ``args``.

    Args:
        args: Object exposing ``client_arch`` and, when ``client_arch`` is
            ``'CNN'``, ``data``.

    Returns:
        ``CNN_MNIST()`` or ``CNN_CIFAR()`` when ``args.client_arch == 'CNN'``
        and ``args.data`` is ``'MNIST'`` or ``'CIFAR10'`` respectively.
        For any other ``client_arch``, the entry of that name in
        ``torchvision.models`` called with ``num_classes=128``.

    Raises:
        ValueError: If ``client_arch`` is ``'CNN'`` but ``data`` is neither
            ``'MNIST'`` nor ``'CIFAR10'`` (previously this silently returned
            ``None``).
        KeyError: If ``client_arch`` is not a name in ``torchvision.models``.
    """
    if args.client_arch == 'CNN':
        if args.data == 'MNIST':
            return CNN_MNIST()
        if args.data == 'CIFAR10':
            return CNN_CIFAR()
        raise ValueError(
            f"No CNN client model is defined for data={args.data!r}; "
            "expected 'MNIST' or 'CIFAR10'"
        )

    # NOTE(review): 128 outputs here vs. the 60-wide default of the CNN
    # models above -- confirm this matches what the consumer expects.
    return torchvision.models.__dict__[args.client_arch](num_classes=128)
    
#     # return mlpModel()