import os
from math import sqrt, pi

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

_INV_SQRT_2PI = 1.0 / sqrt(2 * pi)  # normalizing constant of the standard normal density


def gaussian_density(x):
    """Evaluate the standard normal density exp(-x**2 / 2) / sqrt(2*pi) elementwise on tensor ``x``."""
    exponent = -0.5 * x.square()
    return _INV_SQRT_2PI * exponent.exp()


# Standard normal distribution; its .cdf() is used as Phi(.) in GMMNet.forward.
normal_dist = Normal(loc=0.0, scale=1.0)

def normalize_individual_image(img, eps=1e-5):
    """Standardize one image tensor by its own mean and standard deviation.

    All elements are pooled regardless of the tensor's shape. The result is
    ``(img - mean) / (std + eps)``, using the unbiased sample std, and it is
    returned in the original shape. ``reshape`` is used rather than ``view``
    so non-contiguous inputs are accepted.

    Args:
        img: image tensor of any shape, e.g. the (1, 28, 28) output of
            ``transforms.ToTensor()`` for an MNIST digit.
        eps: small constant added to the std to avoid division by zero.

    Returns:
        A tensor with the same shape as ``img``. An image with fewer than two
        elements has an undefined sample std and maps to zeros.
    """
    flat = img.reshape(-1)
    mean = flat.mean()
    # The unbiased std of fewer than 2 values is NaN; treat it as 0 instead.
    std = flat.std() if flat.numel() > 1 else flat.new_zeros(())
    normalized = (flat - mean) / (std + eps)
    return normalized.reshape(img.shape)


# Preprocessing: convert each MNIST image to a tensor, then standardize it by
# its OWN mean and std (per-image normalization) rather than the dataset-wide
# MNIST statistics (0.1307, 0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(normalize_individual_image),
])

# Load MNIST into ./data (download=True fetches it if it is not already there).
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
# The test set is evaluated in batches of 1000 (not in one batch); the training
# loop accumulates correct/total over all of these batches.
test_loader = DataLoader(test_data, batch_size=1000, shuffle=False)

class FullyConnectedNet(nn.Module):
    """Two-layer ReLU network with 1/width output scaling and a fixed zero logit.

    The hidden layer maps the flattened 28x28 image to ``width`` units. The
    output layer maps those units to 9 logits, which are divided by ``width``.
    A constant zero column is then appended, so the network scores 10 classes.

    Args:
        width: number of hidden units; also the divisor applied to the output.
        Beta: initial hidden-layer weight, shape (width, 784).
        W: initial output-layer weight, shape (9, width).

    Raises:
        ValueError: if ``Beta`` or ``W`` does not have the shape above.
    """

    def __init__(self, width, Beta, W):
        super(FullyConnectedNet, self).__init__()
        if tuple(Beta.shape) != (width, 28 * 28):
            raise ValueError(f"Beta must have shape ({width}, {28 * 28}), got {tuple(Beta.shape)}")
        if tuple(W.shape) != (9, width):
            raise ValueError(f"W must have shape (9, {width}), got {tuple(W.shape)}")
        # forward() needs the width for its 1/width scaling. Store it on the
        # instance instead of relying on a (possibly undefined) global.
        self.width = width
        self.fc1 = nn.Linear(28 * 28, width)  # 784 inputs -> width hidden units
        self.fc2 = nn.Linear(width, 9)        # width hidden units -> 9 logits
        # Replace the default-initialized weights with the supplied ones; the
        # biases keep nn.Linear's default initialization.
        self.fc1.weight = nn.Parameter(Beta)
        self.fc2.weight = nn.Parameter(W)

    def forward(self, x):
        """Return (n, 10) logits for ``x`` flattened to (n, 784); the last column is 0."""
        x = x.reshape(-1, 28 * 28)
        logits = self.fc2(F.relu(self.fc1(x))) / self.width
        # Zero reference logit, matching logits' dtype and device.
        zeros = logits.new_zeros(logits.shape[0], 1)
        return torch.cat((logits, zeros), dim=1)
    
class GMMNet(nn.Module):
    """Network whose hidden weights follow a uniform K-component Gaussian mixture.

    Component k has mean ``mu_all[:, k]`` and per-dimension standard deviation
    ``sq_cov_all[:, k]``, which is squared to give a diagonal covariance. It
    also has read-out rows ``U_all[:, :, k]`` and offsets ``V_all[:, k]``.
    For beta ~ N(mu_k, diag(sq_cov_k ** 2)), the forward pass computes in
    closed form

        E[relu(x . beta) * (U_k beta - V_k)]

    for each of the first ``num_classes - 1`` logits. The result is averaged
    uniformly over the K components, and a constant zero logit is appended
    last.

    Args:
        K: number of mixture components (must be >= 1).
        gamma: scale applied to the initialization of every parameter.
        num_classes: number of output logits (default 10).
        input_dim: flattened input size (default 28*28).

    Raises:
        ValueError: if ``K < 1``.
    """

    def __init__(self, K=20, gamma=1, num_classes=10, input_dim=28 * 28):
        super(GMMNet, self).__init__()
        if K < 1:
            raise ValueError(f"K must be >= 1, got {K}")
        self.K = K
        self.L = num_classes
        self.d = input_dim
        self.mu_all = nn.Parameter(gamma * torch.randn(self.d, self.K))             # (d, K) means
        self.sq_cov_all = nn.Parameter(gamma * torch.ones(self.d, self.K))          # (d, K) std devs
        self.U_all = nn.Parameter(gamma * torch.randn(self.L - 1, self.d, self.K))  # (L-1, d, K)
        self.V_all = nn.Parameter(gamma * torch.randn(self.L - 1, self.K))          # (L-1, K)

    def forward(self, x):
        """Return (n, num_classes) logits for ``x`` flattened to (n, input_dim).

        The last column is identically zero.
        """
        X = x.reshape(-1, self.d)  # (n, d)
        d, K, L = self.d, self.K, self.L
        FCN = None
        for k in range(K):
            mu = self.mu_all[:, k].reshape([d, 1])          # (d, 1)
            sq_cov = self.sq_cov_all[:, k].reshape([d, 1])  # (d, 1)
            U = self.U_all[:, :, k]                         # (L-1, d)
            V = self.V_all[:, k].reshape([L - 1, 1])        # (L-1, 1)
            cov = torch.square(sq_cov)                      # per-dimension variance, (d, 1)

            # A = X beta ~ N(mu_A, sig_A^2) for every sample: (n, 1) each.
            mu_A = torch.matmul(X, mu)
            sig_A = torch.sqrt(torch.sum(X * X * cov.T, dim=1)).unsqueeze(1)
            # B = U beta ~ N(mu_B, sig_B^2) for every logit: (L-1, 1) each.
            mu_B = torch.matmul(U, mu)
            sig_B = torch.sqrt(torch.sum(U * U * cov.T, dim=1)).reshape([L - 1, 1])
            # Correlation of (A, B), then alpha = Cov(A, B) / Var(A), the slope
            # of the regression of B on A: (n, L-1) each.
            rho = torch.matmul(X * cov.T, U.T) / (sig_A * sig_B.T)
            alpha = (1.0 / sig_A) * rho * sig_B.T

            # Gaussian moments with r = mu_A / sig_A, phi = density, Phi = CDF,
            # using phi(-r) = phi(r) and 1 - Phi(-r) = Phi(r):
            #   E[relu(A)]   = mu_A Phi(r) + sig_A phi(r)
            #   E[A relu(A)] = (mu_A^2 + sig_A^2) Phi(r) + mu_A sig_A phi(r)
            ratio = mu_A / sig_A
            PDF = gaussian_density(-ratio)
            CDF = normal_dist.cdf(-ratio)
            E_relu_A = mu_A + sig_A * PDF - mu_A * CDF  # (n, 1)
            E_A_relu_A = (torch.square(mu_A) + torch.square(sig_A)) * (1 - CDF) + mu_A * sig_A * PDF  # (n, 1)
            # A and B are jointly Gaussian, so B = mu_B + alpha (A - mu_A) + noise
            # independent of A. Hence
            # E[B relu(A)] = mu_B E[relu(A)] + alpha (E[A relu(A)] - mu_A E[relu(A)]).
            E_B_relu_A = E_relu_A * mu_B.T - alpha * mu_A * E_relu_A + alpha * E_A_relu_A  # (n, L-1)
            fcn = E_B_relu_A - E_relu_A * V.T  # E[relu(A) (B - V)], (n, L-1)

            # Uniform average over components. Accumulate out-of-place so no
            # tensor in the autograd graph is modified in place.
            FCN = fcn / K if FCN is None else FCN + fcn / K

        # Append a constant zero logit (matching dtype/device) so the output
        # has num_classes columns.
        zeros = FCN.new_zeros(FCN.shape[0], 1)
        return torch.cat((FCN, zeros), dim=1)

# Experiment configuration.
K = 5            # number of Gaussian mixture components in GMMNet
num_epochs = 200
nrep = 10        # NOTE(review): not referenced in the code shown here
gamma = 1/2      # initialization scale passed to GMMNet
n_mcmc = 20000   # Monte Carlo draws used for the final 2-D projection

# The training loop takes a snapshot whenever batch_idx % 100 == 0, i.e.
# ceil(len(train_loader) / 100) times per epoch (10 for 60000 images at batch
# size 64). Derive the buffer length from the loader so a different batch size
# cannot overflow these arrays.
snapshots_per_epoch = -(-len(train_loader) // 100)
n_snapshots = snapshots_per_epoch * num_epochs
mu_history = np.zeros([28*28, K, n_snapshots])       # mixture means per snapshot
cov_history = np.zeros([28*28, K, n_snapshots])      # sq_cov (per-dimension std) per snapshot
mu_proj_history = np.zeros([2, K, n_snapshots])      # means projected on 2 principal axes
cov_proj_history = np.zeros([2, 2, K, n_snapshots])  # covariances projected on 2 principal axes
test_err_GMM = np.zeros([n_snapshots])               # test error per snapshot

# Model and loss. With its defaults, GMMNet outputs 10 logits (the last fixed
# at 0), so standard cross-entropy over the 10 digit classes applies directly.
model = GMMNet(K=K, gamma=gamma)
criterion = nn.CrossEntropyLoss()

# Learning rates for the manual SGD step in the training loop. mu_all and
# sq_cov_all each have their own rate; all remaining parameters (U_all, V_all)
# use learning_rate_others.
learning_rate_mu = 0.1
learning_rate_cov = 1
learning_rate_others = 0.1
        
def _test_error(model, loader):
    """Return the classification error rate of ``model`` over every batch of ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            _, predicted = model(images).max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 1 - correct / total


# Training loop: plain SGD with per-parameter learning rates (no optimizer
# object). Every 100 batches, record the test error and a snapshot of the
# mixture parameters.
counter = 0  # index of the next free snapshot slot in the *_history arrays
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Zero gradients manually.
        for param in model.parameters():
            if param.grad is not None:
                param.grad.zero_()

        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        # Manual SGD step; the learning rate depends on the parameter group.
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name == 'sq_cov_all':
                    lr = learning_rate_cov
                elif name == 'mu_all':
                    lr = learning_rate_mu
                else:
                    lr = learning_rate_others
                param -= lr * param.grad

        if batch_idx % 100 == 0:
            test_error = _test_error(model, test_loader)
            print(f'Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {loss.item()}, Test Error: {test_error}')
            # Snapshot the mixture means and per-dimension std devs.
            mu_current = model.mu_all.detach().clone().numpy()
            sq_cov_current = model.sq_cov_all.detach().clone().numpy()
            mu_history[:, :, counter] = mu_current
            cov_history[:, :, counter] = sq_cov_current
            test_err_GMM[counter] = test_error
            counter += 1
                
# Estimate the covariance of the last recorded hidden-weight mixture by Monte
# Carlo. Each of the n_mcmc draws picks a component uniformly at random and
# samples beta ~ N(mu_k, diag(sq_cov_k ** 2)).
latent_class = np.random.randint(0, K, size=n_mcmc)
Beta = np.random.randn(28*28, n_mcmc)
Beta = mu_current[:, latent_class] + sq_cov_current[:, latent_class] * Beta
sample_cov = np.cov(Beta)  # (784, 784): each row of Beta is one variable

# sample_cov is symmetric, so use eigh: it returns real eigenpairs with
# eigenvalues in ascending order. np.linalg.eig makes no ordering guarantee,
# so its first two columns need not be the leading principal axes.
eigvals, eigvecs = np.linalg.eigh(sample_cov)
top2 = eigvecs[:, np.argsort(eigvals)[::-1][:2]]  # (784, 2), largest variance first

# Project every recorded snapshot onto the two leading axes:
#   mean:       top2.T @ mu_k                   -> (2,)
#   covariance: top2.T @ diag(std_k**2) @ top2  -> (2, 2)
# einsum avoids building a 784x784 diagonal matrix per (snapshot, component).
mu_proj_history[:, :, :counter] = np.einsum(
    'dp,dki->pki', top2, mu_history[:, :, :counter])
cov_proj_history[:, :, :, :counter] = np.einsum(
    'dp,dki,dq->pqki', top2, np.square(cov_history[:, :, :counter]), top2)

# Save the projected trajectories and test errors. Create the output directory
# first so a long run does not fail at the very last step; the file names
# record the actual number of components K.
out_dir = 'MNIST_GMM_dynamic'
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, f'MNIST_GMM_dynamic_mu_K{K}.npy'), mu_proj_history)
np.save(os.path.join(out_dir, f'MNIST_GMM_dynamic_cov_K{K}.npy'), cov_proj_history)
np.save(os.path.join(out_dir, f'MNIST_GMM_dynamic_test_K{K}.npy'), test_err_GMM)