import os
from math import sqrt, pi

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def gaussian_density(x):
    """Evaluate the standard normal density N(0, 1) elementwise on tensor ``x``."""
    norm_const = 1.0 / sqrt(2 * pi)
    return norm_const * torch.exp(-0.5 * torch.square(x))
# Module-wide standard normal distribution N(0, 1); used for its CDF.
normal_dist = Normal(0.0, 1.0)

def normalize_individual_image(img):
    """Standardize one image using its own pixel statistics.

    The mean and (unbiased) standard deviation are taken over every element of
    ``img``. 1e-5 is added to the std so a constant image does not divide by
    zero. The returned tensor has the same shape as ``img``.
    """
    pixels = img.view(-1)
    centered = pixels - pixels.mean()
    scale = pixels.std() + 1e-5
    return (centered / scale).view(img.size())


# Transformation for FashionMNIST images: convert to a tensor, then standardize
# each image by its own mean and std (not a fixed dataset-wide mean/std).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(normalize_individual_image),
])

# Load FashionMNIST (downloaded into ./data if it is not already there)
train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
# Evaluation loader: batches of 1000 images, in a fixed order; evaluation code
# must accumulate its metrics over all batches.
test_loader = DataLoader(test_data, batch_size=1000, shuffle=False)

class FullyConnectedNet(nn.Module):
    """Two-layer ReLU network built around externally supplied weight matrices.

    The output is ``fc2(relu(fc1(x))) / width`` (9 scores) with a constant zero
    logit appended as a 10th column.

    Args:
        width: number of hidden units; also the divisor applied to the output.
        Beta: weight for ``fc1``; must have shape ``(width, 784)`` to match the layer.
        W: weight for ``fc2``; must have shape ``(9, width)`` to match the layer.
    """

    def __init__(self, width, Beta, W):
        super().__init__()
        # Kept on the instance so forward() can apply the 1/width scaling
        # (previously forward() referenced an undefined name `width`).
        self.width = width
        self.fc1 = nn.Linear(28 * 28, width)  # 784 flattened pixels -> width hidden units
        self.fc2 = nn.Linear(width, 9)  # width hidden units -> 9 scores
        # Replace the default-initialised weights; biases keep their default init.
        self.fc1.weight = nn.Parameter(Beta)
        self.fc2.weight = nn.Parameter(W)

    def forward(self, x):
        """Return ``(n, 10)`` logits for a batch ``x`` that flattens to ``(n, 784)``."""
        x = x.view(-1, 28 * 28)  # flatten each image to a vector
        scores = self.fc2(F.relu(self.fc1(x))) / self.width
        # Constant zero logit for the 10th class, on the same device/dtype as the scores.
        zeros = torch.zeros(scores.shape[0], 1, device=scores.device, dtype=scores.dtype)
        return torch.cat((scores, zeros), dim=1)

class GMMNet(nn.Module):
    """Closed-form expected output of a two-layer ReLU network with Gaussian-mixture weights.

    For each of the ``K`` components, a hidden unit's weight vector is treated as
    ``w ~ N(mu_k, diag(sq_cov_k ** 2))``. Its contribution to score ``l`` is
    ``(U_k[l] . w - V_k[l]) * relu(x . w)``. ``A = x . w`` and ``B = U_k w`` are
    jointly Gaussian, so that expectation is computed analytically. The
    components are averaged with equal weights ``1 / K`` and a constant zero
    logit is appended, giving ``num_classes`` logits.

    Args:
        K: number of mixture components (must be >= 1).
        gamma: scale applied to every parameter's initial value.
        in_features: size of one flattened input (default 28 * 28).
        num_classes: number of output logits, including the fixed zero logit.

    Raises:
        ValueError: if ``K < 1`` (the component average would be empty).
    """

    def __init__(self, K=10, gamma=1, in_features=28 * 28, num_classes=10):
        super().__init__()
        if K < 1:
            raise ValueError(f"K must be at least 1, got {K}")
        self.K = K
        self.L = num_classes
        self.d = in_features
        # Component means of w, shape (d, K).
        self.mu_all = nn.Parameter(gamma * torch.randn(self.d, self.K))
        # Component std devs of w (squared in forward to get variances), shape (d, K).
        self.sq_cov_all = nn.Parameter(gamma * torch.ones(self.d, self.K))
        # Component linear maps from w to the L-1 free scores, shape (L-1, d, K).
        self.U_all = nn.Parameter(gamma * torch.randn(self.L - 1, self.d, self.K))
        # Component score offsets, shape (L-1, K).
        self.V_all = nn.Parameter(gamma * torch.randn(self.L - 1, self.K))

    def forward(self, x):
        """Return ``(n, L)`` logits for a batch ``x`` that flattens to ``(n, d)``."""
        X = x.reshape(-1, self.d)  # (n, d)
        out = None
        for k in range(self.K):
            mu = self.mu_all[:, k].reshape(self.d, 1)                     # (d, 1)
            cov = torch.square(self.sq_cov_all[:, k]).reshape(1, self.d)  # (1, d) diagonal variances
            U = self.U_all[:, :, k]                                       # (L-1, d)
            V = self.V_all[:, k].reshape(1, self.L - 1)                   # (1, L-1)

            # A = X w: per-sample mean and standard deviation.
            mu_A = X @ mu                                                 # (n, 1)
            sig_A = torch.sqrt(torch.sum(X * X * cov, dim=1, keepdim=True))  # (n, 1)
            # B = U w: mean of each score and its covariance with A.
            mu_B = (U @ mu).T                                             # (1, L-1)
            cov_AB = (X * cov) @ U.T                                      # (n, L-1)

            ratio = mu_A / sig_A
            pdf = gaussian_density(-ratio)
            prob_pos = 1 - normal_dist.cdf(-ratio)                        # P(A > 0), (n, 1)
            E_relu_A = mu_A * prob_pos + sig_A * pdf                      # E[relu(A)], (n, 1)

            # Stein's lemma: E[B relu(A)] = mu_B E[relu(A)] + Cov(A, B) P(A > 0).
            # This avoids dividing by std(B) (NaN when it is 0) and the
            # cancellation in E[A relu(A)] - mu_A E[relu(A)].
            fcn = E_relu_A * (mu_B - V) + cov_AB * prob_pos               # (n, L-1)
            out = fcn / self.K if out is None else out + fcn / self.K

        # Append a zero (constant) column to make it compatible with L classes.
        zeros = torch.zeros(out.shape[0], 1, device=out.device, dtype=out.dtype)
        return torch.cat((out, zeros), dim=1)

# Experiment driver: for each replicate and each candidate mixture size K, train
# a fresh GMMNet with manual per-parameter SGD and record the test error every
# `eval_every` training batches in test_err_GMM[K index, replicate, eval index].
K_cand = np.array([5, 10, 20])
num_K = K_cand.shape[0]
num_epochs = 100
nrep = 5
eval_every = 100  # evaluate on the test set at batch_idx 0, eval_every, 2*eval_every, ...
# One slot per evaluation, derived from the loader length instead of hard-coding 10.
evals_per_epoch = (len(train_loader) + eval_every - 1) // eval_every
test_err_GMM = np.zeros([num_K, nrep, evals_per_epoch * num_epochs])

criterion = nn.CrossEntropyLoss()

# Training hyperparameters (per-parameter SGD step sizes)
learning_rate_mu = 0.1
learning_rate_cov = 1
learning_rate_others = 0.1
# Parameters not listed here are updated with learning_rate_others.
lr_by_name = {'sq_cov_all': learning_rate_cov, 'mu_all': learning_rate_mu}

results_path = 'MNIST_GMM_test_error/Fashion_GMM_test_error.npy'
# np.save does not create missing directories.
os.makedirs(os.path.dirname(results_path), exist_ok=True)

for rep in range(nrep):
    for i, K in enumerate(K_cand):
        K = int(K)  # plain Python int rather than numpy.int64
        model = GMMNet(K=K, gamma=1/2)

        counter = 0  # index of the next evaluation slot for this run
        for epoch in range(num_epochs):
            for batch_idx, (data, target) in enumerate(train_loader):
                model.zero_grad()
                loss = criterion(model(data), target)
                loss.backward()

                # Plain SGD step; sq_cov_all and mu_all use their own step sizes.
                with torch.no_grad():
                    for name, param in model.named_parameters():
                        param -= lr_by_name.get(name, learning_rate_others) * param.grad

                if batch_idx % eval_every == 0:
                    correct = 0
                    total = 0
                    with torch.no_grad():
                        # Distinct loop names so the module-level `test_data` dataset is not overwritten.
                        for test_images, test_labels in test_loader:
                            _, predicted = model(test_images).max(1)
                            total += test_labels.size(0)
                            correct += predicted.eq(test_labels).sum().item()
                    test_error = 1 - correct / total
                    test_err_GMM[i, rep, counter] = test_error
                    counter += 1
            print(f'K: {K}, replicate: {rep}, Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {loss.item()}, Test Error: {test_error}')
    # Checkpoint everything collected so far after each replicate.
    np.save(results_path, test_err_GMM)