import os
from math import sqrt, pi

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def gaussian_density(x):
    """Evaluate the standard normal pdf, exp(-x**2 / 2) / sqrt(2*pi), elementwise on tensor ``x``."""
    inv_sqrt_two_pi = 1.0 / sqrt(2 * pi)
    exponent = -0.5 * torch.square(x)
    return inv_sqrt_two_pi * torch.exp(exponent)

# Standard normal N(0, 1) distribution object (provides .cdf).
normal_dist = Normal(loc=0.0, scale=1.0)

def normalize_individual_image(img):
    # Flatten the image to compute mean and std
    img_flat = img.view(-1)
    mean = img_flat.mean()
    std = img_flat.std()
    
    # Normalize the image
    img_normalized = (img_flat - mean) / (std + 1e-5)  # Adding a small constant to avoid division by zero
    return img_normalized.view(img.size())  # Reshape it back to the original shape


# Per-image standardization: every image is shifted/scaled to zero mean and unit
# std using its *own* statistics (see normalize_individual_image), not the
# dataset-level MNIST statistics.
# Alternative (dataset-level): transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(normalize_individual_image),
])

# Load MNIST (downloaded into ./data if not already present).
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
# The test set is evaluated in unshuffled batches of 1000 images.
test_loader = DataLoader(test_data, batch_size=1000, shuffle=False)

class FullyConnectedNet(nn.Module):
    """Two-layer, bias-free ReLU network with externally supplied weights.

    Computes ``fc2(relu(fc1(x))) / width`` on the flattened 28x28 input and
    appends a constant zero logit as the last column. The output therefore has
    ``W.shape[0] + 1`` columns.

    Args:
        width: number of hidden units.
        Beta: first-layer weight of shape ``(width, 784)``.
        W: second-layer weight of shape ``(n_out, width)``.
    """

    def __init__(self, width, Beta, W):
        super(FullyConnectedNet, self).__init__()
        # bias=False: the network must be exactly the supplied (Beta, W) model.
        # nn.Linear would otherwise add randomly initialised biases that are
        # never overwritten.
        self.fc1 = nn.Linear(28 * 28, width, bias=False)
        self.fc2 = nn.Linear(width, W.shape[0], bias=False)
        self.fc1.weight = nn.Parameter(Beta)
        self.fc2.weight = nn.Parameter(W)
        self.width = width

    def forward(self, x):
        x = x.reshape(-1, 28 * 28)  # flatten each image to a 784-vector
        # Mean-field scaling: average (not sum) over the hidden units.
        logits = self.fc2(F.relu(self.fc1(x))) / self.width
        # Append a constant zero logit; new_zeros keeps device and dtype of logits.
        zeros = logits.new_zeros(logits.shape[0], 1)
        return torch.cat((logits, zeros), dim=1)
    
class GMMNet(nn.Module):
    """Closed-form expected output of a Gaussian-mixture random two-layer ReLU net.

    Hidden-unit weights follow a uniform mixture of ``K`` Gaussian components.
    For component ``k``:

    - the first-layer weight is ``beta ~ N(mu_k, diag(sq_cov_k ** 2))``;
    - the second-layer weight is ``w = U_k @ beta + V_k``.

    For each input row ``x``, ``forward`` returns the component-averaged
    expectation ``E[w * relu(x . beta)]``. This is the expected value of
    ``(1/width) * sum_j w_j relu(x . beta_j)`` when each unit's component is
    drawn uniformly. A constant zero logit is appended, giving ``L`` columns.

    Args:
        K: number of mixture components (must be >= 1).
        gamma: scale of the random initialization of all parameters.

    Raises:
        ValueError: if ``K < 1``.
    """

    def __init__(self, K=20, gamma=1):
        super(GMMNet, self).__init__()
        if K < 1:
            raise ValueError(f'K must be a positive integer, got {K}')
        self.K = K        # number of mixture components
        self.L = 10       # number of classes; the last logit is a constant 0
        self.d = 28 * 28  # flattened input dimension
        # Per-component mean and per-coordinate scale of beta
        # (covariance is diag(sq_cov ** 2)).
        self.mu_all = nn.Parameter(gamma * torch.randn(self.d, self.K))
        self.sq_cov_all = nn.Parameter(gamma * torch.ones(self.d, self.K))
        # Per-component second-layer map: w = U_k @ beta + V_k.
        self.U_all = nn.Parameter(gamma * torch.randn(self.L - 1, self.d, self.K))
        self.V_all = nn.Parameter(gamma * torch.randn(self.L - 1, self.K))

    def forward(self, x):
        X = x.reshape(-1, self.d)  # (n, d)
        total = 0.0
        for k in range(self.K):
            mu = self.mu_all[:, k]                     # (d,)
            cov = torch.square(self.sq_cov_all[:, k])  # (d,) diagonal of Cov(beta)
            U = self.U_all[:, :, k]                    # (L-1, d)
            V = self.V_all[:, k]                       # (L-1,)

            # A = X beta is Gaussian per row with mean mu_A and variance var_A.
            mu_A = (X @ mu).unsqueeze(1)               # (n, 1)
            var_A = ((X * X) @ cov).unsqueeze(1)       # (n, 1)
            sig_A = torch.sqrt(var_A)                  # (n, 1)
            # B = U beta has mean mu_B and Cov(A, B) = X diag(cov) U^T.
            mu_B = U @ mu                              # (L-1,)
            cov_AB = (X * cov) @ U.T                   # (n, L-1)
            # Regression slope E[B | A] = mu_B + alpha * (A - mu_A). Computed
            # without dividing by sd(B), so sd(B) == 0 does not yield 0/0 = NaN.
            alpha = cov_AB / var_A                     # (n, L-1)

            ratio = mu_A / sig_A
            pdf = torch.exp(-0.5 * torch.square(ratio)) / sqrt(2 * pi)  # phi(ratio)
            # Phi(ratio) via erfc: accurate in the lower tail, unlike 1 - Phi(-ratio).
            cdf = 0.5 * torch.erfc(-ratio / sqrt(2))
            E_relu_A = mu_A * cdf + sig_A * pdf                                    # E[relu(A)]
            E_A_relu_A = (torch.square(mu_A) + var_A) * cdf + mu_A * sig_A * pdf   # E[A relu(A)]
            E_B_relu_A = mu_B * E_relu_A + alpha * (E_A_relu_A - mu_A * E_relu_A)  # E[B relu(A)]
            # w = U beta + V  =>  E[w relu(A)] = E[B relu(A)] + V * E[relu(A)].
            total = total + E_B_relu_A + V * E_relu_A
        FCN = total / self.K  # uniform average over the K components
        # Append a constant zero logit (same device/dtype as FCN).
        zeros = FCN.new_zeros(FCN.shape[0], 1)
        return torch.cat((FCN, zeros), dim=1)

gamma_cand = np.array([1/2])
num_gamma = gamma_cand.shape[0]
width_cand = np.arange(200, 10001, 200)
num_width = width_cand.shape[0]
K = 20
num_epochs = 20
nrep = 5
learning_rate = 0.1
eval_every = 100  # training batches between two test-set evaluations of the GMM
out_dir = 'MNIST_subsample'

# One GMM test error per evaluation: ceil(batches per epoch / eval_every) per epoch.
evals_per_epoch = (len(train_loader) + eval_every - 1) // eval_every
# NOTE: both arrays are overwritten for every gamma; only the last gamma is saved.
test_err_GMM = np.zeros([evals_per_epoch * num_epochs])
test_err_NN = np.zeros([num_width, nrep])
criterion = nn.CrossEntropyLoss()


def _evaluate(net, loader, loss_fn):
    """Return (sum of per-batch mean losses, error rate) of ``net`` over ``loader``."""
    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for x_batch, y_batch in loader:
            logits = net(x_batch)
            total_loss += loss_fn(logits, y_batch).item()
            _, predicted = logits.max(1)
            n_seen += y_batch.size(0)
            n_correct += predicted.eq(y_batch).sum().item()
    return total_loss, 1 - n_correct / n_seen


for gamma in gamma_cand:
    gamma = float(gamma)

    # --- GMM step: fit the mixture parameters with plain (manual) SGD ---
    gmm = GMMNet(K=K, gamma=gamma)
    counter = 0
    for epoch in range(num_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            gmm.zero_grad()
            loss = criterion(gmm(data), target)
            loss.backward()
            with torch.no_grad():
                for param in gmm.parameters():
                    param -= learning_rate * param.grad

            if batch_idx % eval_every == 0:
                _, test_error = _evaluate(gmm, test_loader, criterion)
                print(f'Gamma: {gamma}, GMM, Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {loss.item()}, Test Error: {test_error}')
                test_err_GMM[counter] = test_error
                counter += 1

    # Snapshot the learned mixture as detached copies.
    mu_all = gmm.mu_all.detach().clone()
    sq_cov_all = gmm.sq_cov_all.detach().clone()
    U_all = gmm.U_all.detach().clone()
    V_all = gmm.V_all.detach().clone()

    # --- NN step: sample finite-width networks from the fitted mixture ---
    for i_width in range(num_width):
        width = int(width_cand[i_width])
        for rep in range(nrep):
            # Mixture component of each hidden unit (uniform over K).
            latent_class = torch.as_tensor(np.random.randint(0, K, size=width), dtype=torch.long)
            Beta = mu_all[:, latent_class] + sq_cov_all[:, latent_class] * torch.randn(28*28, width)
            # W[j] = U_k @ Beta[:, j] + V_k for unit j in component k.
            # Computed per component (K matmuls) instead of one matvec per unit.
            W = torch.zeros(width, U_all.shape[0])
            for k in range(K):
                members = latent_class == k
                W[members] = Beta[:, members].T @ U_all[:, :, k].T + V_all[:, k]

            net = FullyConnectedNet(width, Beta.T, W.T)
            test_err_NN[i_width, rep] = _evaluate(net, test_loader, criterion)[1]
        print(f'Width: {width}, Test Error: {np.mean(test_err_NN[i_width,:])}')

# Create the output directory so saving does not fail after all the training.
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, 'MNIST_GMM_subsample.npy'), test_err_GMM)
np.save(os.path.join(out_dir, 'MNIST_NN_subsample.npy'), test_err_NN)
