import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.distributions import Normal
from math import sqrt, pi
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

# 1 / sqrt(2 * pi): normalising constant of the standard normal density.
_INV_SQRT_2PI = 1.0 / sqrt(2 * pi)


def gaussian_density(x):
    """Evaluate the standard normal pdf element-wise.

    Computes ``exp(-x**2 / 2) / sqrt(2 * pi)``.

    Args:
        x: Torch tensor (``torch.square`` / ``torch.exp`` are applied to it).

    Returns:
        Tensor of the same shape as ``x`` holding the N(0, 1) density values.
    """
    neg_half_sq = -0.5 * torch.square(x)
    return _INV_SQRT_2PI * torch.exp(neg_half_sq)


# Standard normal N(0, 1); its ``cdf`` is used by the GMM layers in this file.
normal_dist = Normal(loc=0.0, scale=1.0)

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def normalize_individual_image(img, eps=1e-5):
    """Standardise one image tensor using its own mean and standard deviation.

    Statistics are taken over *all* elements of ``img`` together, so every
    image is normalised independently of the rest of the dataset.

    Args:
        img: Tensor of any shape (e.g. the output of ``transforms.ToTensor()``).
            Non-contiguous tensors (transposed, sliced) are accepted.
        eps: Small constant added to the standard deviation so that constant
            images do not cause a division by zero.

    Returns:
        Tensor with the same shape as ``img``, shifted to zero mean and divided
        by ``std + eps``.  ``std`` is torch's default unbiased (N - 1)
        estimator, so a single-element input yields NaN.
    """
    # reshape (rather than view) also works on non-contiguous tensors.
    flat = img.reshape(-1)
    mean = flat.mean()
    std = flat.std()
    normalized = (flat - mean) / (std + eps)
    return normalized.reshape(img.shape)  # restore the original shape


# Transformation for MNIST data: each image is standardised with its *own*
# mean and std (normalize_individual_image).  The global MNIST mean/std
# normalisation is kept commented out for reference.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(normalize_individual_image),
    # transforms.Normalize((0.1307,), (0.3081,))  # global MNIST mean and std
])

# Load MNIST dataset (download=True fetches it into ./data if missing).
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
# Test-time batches of 1000 images; evaluation iterates the whole loader to
# cover the full test split.
test_loader = DataLoader(test_data, batch_size=1000, shuffle=False)

class FullyConnectedNet(nn.Module):
    """Two-layer ReLU network with caller-supplied weights and 1/width scaling.

    Computes ``fc2(relu(fc1(x))) / width`` giving 9 logits, then appends a
    constant zero logit for the last class, for 10 outputs in total.

    Args:
        width: Number of hidden units.
        Beta: First-layer weight, shape ``(width, 784)``.
        W: Second-layer weight, shape ``(9, width)``.

    Raises:
        ValueError: If ``Beta`` or ``W`` does not have the shape above.
    """

    def __init__(self, width, Beta, W):
        super(FullyConnectedNet, self).__init__()
        in_features = 28 * 28
        # Fail at construction rather than with an opaque matmul error in forward().
        if tuple(Beta.shape) != (width, in_features):
            raise ValueError(
                f"Beta must have shape ({width}, {in_features}), got {tuple(Beta.shape)}"
            )
        if tuple(W.shape) != (9, width):
            raise ValueError(f"W must have shape (9, {width}), got {tuple(W.shape)}")
        # Stored so forward() can apply the 1/width output scaling.
        self.width = width
        self.fc1 = nn.Linear(in_features, width)  # 784 pixels -> width hidden units
        self.fc2 = nn.Linear(width, 9)  # width hidden units -> 9 learned logits
        # Replace the default weights with the supplied ones; biases keep
        # nn.Linear's default initialisation.
        self.fc1.weight = nn.Parameter(Beta)
        self.fc2.weight = nn.Parameter(W)

    def forward(self, x):
        """Return ``(n, 10)`` logits for ``x`` reshapeable to ``(n, 784)``."""
        x = x.view(-1, 28 * 28)  # flatten each image to a vector
        x = self.fc2(F.relu(self.fc1(x))) / self.width
        # Fixed zero logit for the 10th class; same device/dtype as the others.
        zeros = torch.zeros(x.shape[0], 1, device=x.device, dtype=x.dtype)
        return torch.cat((x, zeros), dim=1)
    

class GMMNet(nn.Module):
    """Mixture-of-Gaussians surrogate for a one-hidden-layer ReLU network.

    Component ``k`` holds a Gaussian over a weight vector,
    ``beta ~ N(mu_k, diag(sq_cov_k ** 2))``, a readout matrix ``U_k`` of shape
    ``(L-1, d)`` and an offset ``V_k`` of length ``L-1``.  For an input ``x``,
    with ``A = x . beta`` and ``B = U_k beta``, the component contributes
    ``E[B * relu(A)] - V_k * E[relu(A)]`` in closed form.  The output is the
    average over components plus a constant zero logit for the last class.

    Args:
        K: Number of mixture components (must be >= 1).
        gamma: Scale applied to the initialisation of every parameter.

    Raises:
        ValueError: If ``K < 1``.
    """

    def __init__(self, K=10, gamma=1):
        super(GMMNet, self).__init__()
        if K < 1:
            # forward() averages over components and needs at least one.
            raise ValueError(f"K must be >= 1, got {K}")
        self.K = K
        self.L = 10  # number of classes; L-1 logits are learned
        self.d = 28 * 28  # flattened input size
        # Creation order is unchanged so seeded runs draw identical values.
        self.mu_all = nn.Parameter(gamma * torch.randn(self.d, self.K))
        self.sq_cov_all = nn.Parameter(gamma * torch.ones(self.d, self.K))
        self.U_all = nn.Parameter(gamma * torch.randn(self.L - 1, self.d, self.K))
        self.V_all = nn.Parameter(gamma * torch.randn(self.L - 1, self.K))

    @staticmethod
    def _component_logits(X, mu, cov, U, V):
        """Return one component's ``(n, L-1)`` contribution.

        X: ``(n, d)`` inputs; mu, cov: ``(d, 1)`` mean and diagonal variance
        of ``beta``; U: ``(L-1, d)`` readout; V: ``(L-1, 1)`` offset.
        """
        mu_A = torch.matmul(X, mu)  # (n, 1) mean of A
        sig_A = torch.sqrt(torch.sum(X * X * cov.T, axis=1)).unsqueeze(1)  # (n, 1) std of A
        mu_B = torch.matmul(U, mu)  # (L-1, 1) mean of B
        sig_B = torch.sqrt(torch.sum(U * U * cov.T, axis=1)).unsqueeze(1)  # (L-1, 1) std of B
        # Correlation of A with each B_l, then the regression slope of B on A.
        rho = torch.matmul(X * cov.T, U.T) / (sig_A * sig_B.T)  # (n, L-1)
        alpha = (1.0 / sig_A) * rho * sig_B.T  # (n, L-1)
        ratio = mu_A / sig_A
        PDF = gaussian_density(-ratio)
        CDF = normal_dist.cdf(-ratio)
        E_relu_A = mu_A + sig_A * PDF - mu_A * CDF  # (n, 1)
        E_A_relu_A = (torch.square(mu_A) + torch.square(sig_A)) * (1 - CDF) + mu_A * sig_A * PDF  # (n, 1)
        E_B_relu_A = E_relu_A * mu_B.T - alpha * mu_A * E_relu_A + alpha * E_A_relu_A  # (n, L-1)
        return E_B_relu_A - E_relu_A * V.T

    def forward(self, x):
        """Return ``(n, L)`` logits for ``x`` reshapeable to ``(n, 784)``."""
        X = x.view(-1, 28 * 28)  # flatten each image to a vector
        d, K, L = self.d, self.K, self.L
        FCN = None
        for k in range(K):
            fcn = self._component_logits(
                X,
                self.mu_all[:, k].reshape([d, 1]),
                torch.square(self.sq_cov_all[:, k].reshape([d, 1])),
                self.U_all[:, :, k],
                self.V_all[:, k].reshape([L - 1, 1]),
            )
            # Each term is divided by K before summing, yielding the mean.
            FCN = fcn / K if FCN is None else FCN + fcn / K
        # Append a zero (constant) logit so there are L = 10 class scores.
        zeros = torch.zeros(FCN.shape[0], 1, device=FCN.device, dtype=FCN.dtype)
        return torch.cat((FCN, zeros), dim=1)

class GMMNet2(nn.Module):
    """Two stacked Gaussian-mixture ReLU layers.

    Layer 1 maps the 784-pixel input to ``m`` features.  Those features are
    L2-normalised per sample and fed to layer 2, which produces ``L-1 = 9``
    logits; a constant zero logit is appended for the last class.  Each layer
    averages ``K`` components.  Component ``k`` holds
    ``beta ~ N(mu_k, diag(sq_cov_k ** 2))``, a readout ``U_k`` and an offset
    ``V_k``, and contributes ``E[B * relu(A)] - V_k * E[relu(A)]`` with
    ``A = x . beta`` and ``B = U_k beta``.

    Args:
        K: Number of mixture components per layer (must be >= 1).
        m: Width of the intermediate representation (must be >= 1).
        gamma: Pair ``(gamma_1, gamma_2)`` scaling the initial mean, std and
            offset parameters of layer 1 and layer 2.  The readout tensors
            ``U_all_*`` are drawn from N(0, 1) without scaling.

    Raises:
        ValueError: If ``K < 1`` or ``m < 1``.
    """

    def __init__(self, K=20, m=100, gamma=(1 / 2, 1 / 2)):
        super(GMMNet2, self).__init__()
        if K < 1:
            raise ValueError(f"K must be >= 1, got {K}")
        if m < 1:
            raise ValueError(f"m must be >= 1, got {m}")
        g1, g2 = gamma[0], gamma[1]
        self.K = K
        self.L = 10  # number of classes; L-1 logits are learned
        self.d = 28 * 28  # flattened input size
        self.m = m
        # First layer (d -> m).  Creation order is unchanged so seeded runs
        # draw identical values.
        self.mu_all_1 = nn.Parameter(g1 * torch.randn(self.d, self.K))
        self.sq_cov_all_1 = nn.Parameter(g1 * torch.ones(self.d, self.K))
        self.U_all_1 = nn.Parameter(torch.randn(self.m, self.d, self.K))
        self.V_all_1 = nn.Parameter(g1 * torch.randn(self.m, self.K))
        # Second layer (m -> L-1).
        self.mu_all_2 = nn.Parameter(g2 * torch.randn(self.m, self.K))
        self.sq_cov_all_2 = nn.Parameter(g2 * torch.ones(self.m, self.K))
        self.U_all_2 = nn.Parameter(torch.randn(self.L - 1, self.m, self.K))
        self.V_all_2 = nn.Parameter(g2 * torch.randn(self.L - 1, self.K))

    @staticmethod
    def _component_logits(X, mu, cov, U, V):
        """Return one component's ``(n, out)`` contribution.

        X: ``(n, in)`` inputs; mu, cov: ``(in, 1)`` mean and diagonal variance
        of ``beta``; U: ``(out, in)`` readout; V: ``(out, 1)`` offset.
        """
        mu_A = torch.matmul(X, mu)  # (n, 1) mean of A
        sig_A = torch.sqrt(torch.sum(X * X * cov.T, axis=1)).unsqueeze(1)  # (n, 1) std of A
        mu_B = torch.matmul(U, mu)  # (out, 1) mean of B
        sig_B = torch.sqrt(torch.sum(U * U * cov.T, axis=1)).unsqueeze(1)  # (out, 1) std of B
        # Correlation of A with each B_l, then the regression slope of B on A.
        rho = torch.matmul(X * cov.T, U.T) / (sig_A * sig_B.T)  # (n, out)
        alpha = (1.0 / sig_A) * rho * sig_B.T  # (n, out)
        ratio = mu_A / sig_A
        PDF = gaussian_density(-ratio)
        CDF = normal_dist.cdf(-ratio)
        E_relu_A = mu_A + sig_A * PDF - mu_A * CDF  # (n, 1)
        E_A_relu_A = (torch.square(mu_A) + torch.square(sig_A)) * (1 - CDF) + mu_A * sig_A * PDF  # (n, 1)
        E_B_relu_A = E_relu_A * mu_B.T - alpha * mu_A * E_relu_A + alpha * E_A_relu_A  # (n, out)
        return E_B_relu_A - E_relu_A * V.T

    def _mixture_layer(self, X, mu_all, sq_cov_all, U_all, V_all):
        """Average the K component contributions of one layer; returns ``(n, out)``."""
        in_dim = mu_all.shape[0]
        out_dim = V_all.shape[0]
        K = self.K
        total = None
        for k in range(K):
            fcn = self._component_logits(
                X,
                mu_all[:, k].reshape([in_dim, 1]),
                torch.square(sq_cov_all[:, k].reshape([in_dim, 1])),
                U_all[:, :, k],
                V_all[:, k].reshape([out_dim, 1]),
            )
            # Each term is divided by K before summing, yielding the mean.
            total = fcn / K if total is None else total + fcn / K
        return total

    def forward(self, x):
        """Return raw ``(n, L)`` logits for ``x`` reshapeable to ``(n, 784)``."""
        X = x.view(-1, 28 * 28)  # flatten each image to a vector
        hidden = self._mixture_layer(
            X, self.mu_all_1, self.sq_cov_all_1, self.U_all_1, self.V_all_1
        )  # (n, m)
        # Per-sample L2 normalisation of the intermediate features.
        hidden = hidden / torch.linalg.norm(hidden, dim=1, keepdim=True)
        logits = self._mixture_layer(
            hidden, self.mu_all_2, self.sq_cov_all_2, self.U_all_2, self.V_all_2
        )  # (n, L-1)
        # Append a zero (constant) logit so there are L = 10 class scores.
        zeros = torch.zeros(logits.shape[0], 1, device=logits.device, dtype=logits.dtype)
        return torch.cat((logits, zeros), dim=1)

import math
import os

gamma_cand = np.array([2])
num_gamma = gamma_cand.shape[0]
K = 10
num_epochs = 100
nrep = 3
eval_every = 100  # evaluate on the test set every `eval_every` training batches
# One slot per evaluation.  Derived from the loader length instead of a
# hard-coded 10 so a different batch size cannot overflow the arrays
# (with batch_size=64 on MNIST this is still 10 per epoch).
evals_per_epoch = math.ceil(len(train_loader) / eval_every)
test_err_GMM = np.zeros([nrep, evals_per_epoch * num_epochs])
test_err_GMM2 = np.zeros([num_gamma, nrep, evals_per_epoch * num_epochs])

out_dir = 'MNIST_GMM_two_layers'
# np.save does not create missing parent directories.
os.makedirs(out_dir, exist_ok=True)


def _test_error(model, loader):
    """Return the misclassification rate of ``model`` over every batch of ``loader``."""
    correct = 0
    total = 0
    with torch.no_grad():
        # Fresh loop names so the module-level `test_data` dataset is not rebound.
        for images, labels in loader:
            _, predicted = model(images).max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return 1 - correct / total


for rep in range(nrep):
    # NOTE(review): this loop only leaves the last candidate in `i`/`wid`;
    # GMMNet below does not use `wid`, and `i` only appears in the log line.
    for i in range(num_gamma):
        wid = gamma_cand[i]
    # GMM step
    model = GMMNet(gamma=1/2)
    criterion = nn.CrossEntropyLoss()

    # Training hyperparameters
    learning_rate = 0.1

    # Training loop
    counter = 0
    for epoch in range(num_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            model.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            # Manual SGD: update weights in place without tracking gradients.
            with torch.no_grad():
                for param in model.parameters():
                    param -= learning_rate * param.grad

            # Periodically report training loss and full test-set error.
            if batch_idx % eval_every == 0:
                test_error = _test_error(model, test_loader)
                print(f'Gamma index: {i}, replicate: {rep}, GMM, Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {loss.item()}, Test Error: {test_error}')
                test_err_GMM[rep, counter] = test_error
                counter += 1
    # Saved after every replicate so partial results survive an interruption.
    np.save(os.path.join(out_dir, 'MNIST_GMM_final_100.npy'), test_err_GMM)