import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

import warnings
warnings.filterwarnings("ignore", category=UserWarning)


class MLP(nn.Module):
    """Two-layer perceptron whose weights can be read/written as one flat vector.

    The flat ("particle") layout used by ``initialize`` and ``set_net_param``
    is the concatenation, in this order, of the flattened
    ``fc1.weight``, ``fc1.bias``, ``fc2.weight`` and ``fc2.bias``.

    Args:
        num_hidden: Width of the hidden layer.
        num_classes: Number of output classes (default 10).
        input_dim: Number of input features per sample (default 28 * 28).

    Attributes:
        num_vars: Total number of scalars in the flat parameter vector.
        device: Device the model is moved to at construction time
            (CUDA when available, otherwise CPU).
    """

    def __init__(self, num_hidden, num_classes=10, input_dim=28 * 28):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_classes)
        self.num_vars = self._calculate_num_vars()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

    def _ordered_params(self):
        """Return the parameters in the order used by the flat vector layout."""
        return (self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias)

    def _calculate_num_vars(self):
        """Return the total number of scalar parameters in both layers."""
        return sum(param.numel() for param in self._ordered_params())

    def initialize(self, M):
        """Sample ``M`` flat parameter vectors and zero the model's own weights.

        Weight matrices are drawn with Kaiming-uniform initialization
        (``nonlinearity='relu'``); biases are zero.

        Args:
            M: Number of parameter vectors to generate. ``M <= 0`` yields an
                empty ``(0, num_vars)`` tensor.

        Returns:
            Tensor of shape ``(M, num_vars)`` on the model's device.

        Side effects:
            Every parameter of this module is set to zero in place.
        """
        device = self.fc1.weight.device
        rows = []
        for _ in range(M):
            pieces = []
            for param in self._ordered_params():
                if param.dim() > 1:
                    # Sample on CPU first (as before) so seeded runs draw from
                    # the CPU generator regardless of where the model lives.
                    piece = nn.init.kaiming_uniform_(
                        torch.empty(param.shape), nonlinearity='relu'
                    )
                else:
                    piece = torch.zeros(param.shape)
                pieces.append(piece.flatten())
            rows.append(torch.cat(pieces))

        if rows:
            particles = torch.stack(rows).to(device)
        else:
            # torch.cat/stack reject an empty list; return a well-formed
            # empty batch instead of raising.
            particles = torch.empty((0, self._calculate_num_vars()), device=device)

        with torch.no_grad():
            for param in self.parameters():
                param.zero_()

        return particles

    def forward(self, x):
        """Return class probabilities for a batch of inputs.

        Args:
            x: ``torch.Tensor``, ``numpy.ndarray`` or array-like whose elements
                can be reshaped to ``(-1, input_dim)``. Any dtype/device is
                accepted; the input is cast to the layer's dtype and device.
                Integer data (e.g. uint8) is cast to float without rescaling.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding softmax
            probabilities (each row sums to 1), not raw logits.
        """
        weight = self.fc1.weight
        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x)
        # Follow the parameters' actual device/dtype rather than the cached
        # ``self.device`` so the model keeps working after a later ``.to(...)``.
        x = x.to(device=weight.device, dtype=weight.dtype)
        # ``reshape`` (unlike ``view``) also handles non-contiguous inputs.
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        return torch.softmax(self.fc2(x), dim=1)

    def set_net_param(self, particles):
        """Load a flat parameter vector into the network in place.

        Args:
            particles: Tensor, ``numpy.ndarray`` or array-like with exactly
                ``num_vars`` elements, laid out as described in the class
                docstring.

        Raises:
            ValueError: If ``particles`` does not contain ``num_vars`` elements.

        Side effects:
            Overwrites the existing parameter tensors (same storage, device and
            dtype). Prints a note when the new values equal the current ones.
        """
        params = self._ordered_params()
        reference = params[0]
        flat = torch.as_tensor(
            particles, dtype=reference.dtype, device=reference.device
        ).detach().reshape(-1)

        expected = sum(param.numel() for param in params)
        if flat.numel() != expected:
            raise ValueError(
                f"expected {expected} parameter values, got {flat.numel()}"
            )

        chunks = torch.split(flat, [param.numel() for param in params])
        unchanged = True
        with torch.no_grad():
            for param, chunk in zip(params, chunks):
                new_value = chunk.view_as(param)
                if unchanged and not torch.equal(param, new_value):
                    unchanged = False
                # In-place copy keeps the Parameter object, its device and
                # any references held elsewhere (e.g. optimizers) valid.
                param.copy_(new_value)

        if unchanged:
            print("Note that the network parameters have not been modified!")

