import torch
import torch.nn as nn
import torch.nn.functional as F



class LinearModel(nn.Module):
    """Bias-free linear server head.

    Applies a single ``nn.Linear(num_clients, num_classes, bias=False)`` to
    the last dimension of the input.

    Args:
        num_classes: Output width (number of logits).
        num_clients: Input width expected on the last dimension of ``x``.
        dim: Accepted only so this head shares a constructor signature with
            the MLP heads; it is not used here.
            NOTE(review): the input width is ``num_clients`` rather than
            ``num_clients * dim`` — confirm this matches what callers feed in.
    """

    def __init__(self, num_classes=10, num_clients=2, dim=128):
        super().__init__()
        # Attribute name ``fc`` is part of the state_dict key; keep it stable.
        self.fc = nn.Linear(num_clients, num_classes, bias=False)

    def forward(self, x):
        """Return ``x @ fc.weight.T`` (no bias term)."""
        return self.fc(x)

class mlpModel_MNIST(nn.Module):
    """Two-layer MLP server head: dropout -> Linear -> ReLU -> Linear.

    Args:
        num_classes: Number of output logits.
        num_clients: Number of client embeddings in the input; the input's
            last dimension must be ``num_clients * emb_dim`` (presumably the
            concatenated per-client embeddings — confirm against caller).
        emb_dim: Per-client embedding width, also used as the hidden width.
    """

    def __init__(self, num_classes=10, num_clients=2, emb_dim=60):
        super().__init__()
        concat_width = num_clients * emb_dim
        # Registration order (fc1, fc2, dropout1) fixes parameters() order and
        # the RNG sequence used for weight init; keep it unchanged.
        self.fc1 = nn.Linear(concat_width, emb_dim)
        self.fc2 = nn.Linear(emb_dim, num_classes)
        self.dropout1 = nn.Dropout(0.25)

    def forward(self, x):
        """Map the input features to ``num_classes`` logits."""
        # Dropout is applied directly to the incoming features, before fc1.
        hidden = F.relu(self.fc1(self.dropout1(x)))
        return self.fc2(hidden)

class mlpModel_CIFAR(nn.Module):
    """Three-layer MLP server head with BatchNorm and dropout.

    Layout: dropout -> Linear(in, 512) -> BN -> ReLU
            -> dropout -> Linear(512, 256) -> BN -> ReLU
            -> Linear(256, num_classes).

    Args:
        num_classes: Number of output logits.
        num_clients: Number of client embeddings; the input's last dimension
            must be ``num_clients * emb_dim`` (presumably concatenated
            per-client embeddings — confirm against caller).
        emb_dim: Per-client embedding width (default 60, intended to match
            the client model's output width).
    """

    def __init__(self, num_classes=10, num_clients=2, emb_dim=60):
        super().__init__()
        width_in = num_clients * emb_dim
        width_h1, width_h2 = 512, 256
        # Registration order below fixes parameters()/state_dict order and the
        # RNG sequence for weight init; keep it unchanged.
        self.fc1 = nn.Linear(width_in, width_h1)
        self.bn1 = nn.BatchNorm1d(width_h1)
        self.dropout1 = nn.Dropout(0.3)

        self.fc2 = nn.Linear(width_h1, width_h2)
        self.bn2 = nn.BatchNorm1d(width_h2)
        self.dropout2 = nn.Dropout(0.3)

        self.fc3 = nn.Linear(width_h2, num_classes)

    def forward(self, x):
        """Map the input features to ``num_classes`` logits."""
        # Each dropout precedes its Linear layer; the first one therefore acts
        # on the raw incoming features.
        h = F.relu(self.bn1(self.fc1(self.dropout1(x))))
        h = F.relu(self.bn2(self.fc2(self.dropout2(h))))
        return self.fc3(h)

class WeightedAgg(nn.Module):
    """Learnable weighted sum of per-client inputs.

    Holds one scalar weight per client in ``agg_weights`` (a trainable
    ``nn.Parameter`` initialised uniformly to ``1 / num_clients``) and
    returns ``sum_i agg_weights[i] * x_list[i]``.

    Args:
        num_clients: Number of weights to allocate.
    """

    def __init__(self, num_clients):
        super().__init__()
        self.agg_weights = nn.Parameter(torch.ones(num_clients) / num_clients)

    def forward(self, x_list):
        """Return the weighted sum of the items in ``x_list``.

        Raises:
            IndexError: if ``x_list`` is empty or has more items than there
                are weights.
        """
        weights = self.agg_weights
        # Start from a fresh product tensor so the in-place adds below never
        # modify the caller's tensors.
        total = x_list[0] * weights[0]
        for idx, item in enumerate(x_list):
            if idx == 0:
                continue
            total += item * weights[idx]
        return total

def serverModel(args, num_classes=10, num_clients=2, dim=60):
    """Build the server-side model selected by ``args``.

    Args:
        args: Config object. ``args.server_arch`` selects the head:
            ``'linear'`` -> ``LinearModel``; ``'mlp'`` -> ``mlpModel_CIFAR``
            when ``args.data == 'CIFAR10'``, otherwise ``mlpModel_MNIST``.
            ``args.data`` is only read for the ``'mlp'`` architecture.
        num_classes: Number of output logits, forwarded to the model.
        num_clients: Number of clients, forwarded to the model.
        dim: Per-client embedding width, forwarded positionally to the model
            (``LinearModel`` accepts but ignores it).

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: if ``args.server_arch`` is not ``'linear'`` or ``'mlp'``.
            Previously this case silently returned ``None``, which only
            failed later with an opaque ``NoneType`` error.
    """
    arch = args.server_arch
    if arch == 'linear':
        return LinearModel(num_classes, num_clients, dim)
    if arch == 'mlp':
        if args.data == 'CIFAR10':
            return mlpModel_CIFAR(num_classes, num_clients, dim)
        return mlpModel_MNIST(num_classes, num_clients, dim)
    raise ValueError(
        f"Unknown server_arch {arch!r}; expected 'linear' or 'mlp'"
    )