import os
from math import sqrt, pi

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
from torchvision import datasets, transforms

def gaussian_density(x):
    """Evaluate the standard normal density elementwise on tensor ``x``.

    Returns a tensor of the same shape as ``x`` holding
    ``exp(-x**2 / 2) / sqrt(2 * pi)`` for every entry.
    """
    norm_const = 1.0 / sqrt(2 * pi)
    exponent = -0.5 * torch.square(x)
    return norm_const * torch.exp(exponent)
# Shared standard normal N(0, 1) distribution object; VCG_GMM.forward calls
# its .cdf() to evaluate the Gaussian cumulative distribution function.
normal_dist = Normal(loc=0.0, scale=1.0)

# Define the neural network structure
class VCG_GMM(nn.Module):
    """Small CNN followed by a closed-form K-component output head.

    The convolutional trunk (four 3x3 conv + batch-norm blocks with two 2x2
    max-pools, then a linear layer and dropout) maps a 3-channel image to a
    ``feature_dim``-dimensional vector. ``fc1`` expects ``256 * 8 * 8``
    inputs, i.e. 32x32 images after the two pooling steps.

    The head holds, for each of ``K`` components, a mean vector
    (``mu_all``), a per-dimension square-root variance (``sq_cov_all``), a
    projection matrix (``U_all``) and an offset vector (``V_all``). For
    every component it evaluates closed-form expressions built from the
    standard normal pdf/cdf (the variable names and algebra match the
    Gaussian moments E[relu(A)], E[A relu(A)] and E[B relu(A)]) and averages
    the result over the components.

    Args:
        K: Number of components; must be at least 1.
        gamma: Scale applied to the initial values of all head parameters.
        num_classes: Number of output columns. The head learns
            ``num_classes - 1`` columns; the last column is fixed at zero.
        feature_dim: Width of the feature vector produced by ``fc1``.

    Raises:
        ValueError: If ``K``, ``num_classes`` or ``feature_dim`` is below 1.
    """

    def __init__(self, K: int = 10, gamma: float = 0.01,
                 num_classes: int = 10, feature_dim: int = 256):
        if K < 1:
            raise ValueError(f"K must be at least 1, got {K}")
        if num_classes < 1:
            raise ValueError(f"num_classes must be at least 1, got {num_classes}")
        if feature_dim < 1:
            raise ValueError(f"feature_dim must be at least 1, got {feature_dim}")
        super().__init__()

        # Convolutional trunk.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.fc1 = nn.Linear(256 * 8 * 8, feature_dim)
        self.dropout = nn.Dropout(0.5)

        self.K = K
        self.L = num_classes
        self.d = feature_dim

        # Head parameters; the trailing axis indexes the K components.
        self.mu_all = nn.Parameter(gamma * torch.randn(self.d, self.K))
        self.sq_cov_all = nn.Parameter(gamma * torch.ones(self.d, self.K))
        self.U_all = nn.Parameter(gamma * torch.randn(self.L - 1, self.d, self.K))
        self.V_all = nn.Parameter(gamma * torch.randn(self.L - 1, self.K))

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolutional trunk and return an (n, d) feature matrix."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(F.relu(self.bn4(self.conv4(x))), 2)
        # Flatten per sample. Unlike view(-1, 256*8*8), this never folds
        # spatial positions into the batch axis: an unexpected input size
        # makes fc1 raise a shape error instead of changing the batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.dropout(x)

    def _component_output(self, X: torch.Tensor, k: int) -> torch.Tensor:
        """Return the (n, L-1) closed-form head output of component ``k``."""
        d = self.d
        n_out = self.L - 1

        mu = self.mu_all[:, k].reshape(d, 1)                        # d x 1
        var_row = torch.square(self.sq_cov_all[:, k]).reshape(1, d)  # 1 x d
        U = self.U_all[:, :, k]                                      # (L-1) x d
        V_row = self.V_all[:, k].reshape(1, n_out)                   # 1 x (L-1)

        # Mean / standard deviation of A (one scalar per sample). The 1e-6
        # keeps the square root away from zero.
        mu_A = torch.matmul(X, mu)                                   # n x 1
        sig_A = torch.sqrt(torch.sum(X * X * var_row, dim=1) + 1e-6)
        sig_A = sig_A.unsqueeze(1)                                   # n x 1

        # Mean / standard deviation of B (one scalar per learned column).
        mu_B_row = torch.matmul(U, mu).reshape(1, n_out)             # 1 x (L-1)
        sig_B_row = torch.sqrt(torch.sum(U * U * var_row, dim=1) + 1e-6)
        sig_B_row = sig_B_row.reshape(1, n_out)                      # 1 x (L-1)

        # Correlation between A and B, and the resulting regression slope
        # alpha = rho * sig_B / sig_A.
        rho = torch.matmul(X * var_row, U.T) / (sig_A * sig_B_row)   # n x (L-1)
        alpha = (1.0 / sig_A) * rho * sig_B_row                      # n x (L-1)

        ratio = mu_A / sig_A
        pdf = gaussian_density(-ratio)
        cdf = normal_dist.cdf(-ratio)

        E_relu_A = mu_A + sig_A * pdf - mu_A * cdf                   # n x 1
        E_A_relu_A = ((torch.square(mu_A) + torch.square(sig_A)) * (1 - cdf)
                      + mu_A * sig_A * pdf)                          # n x 1
        E_B_relu_A = (E_relu_A * mu_B_row - alpha * mu_A * E_relu_A
                      + alpha * E_A_relu_A)                          # n x (L-1)
        return E_B_relu_A - E_relu_A * V_row

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to an (n, L) output whose last column is 0."""
        features = self._extract_features(x)

        # Average the per-component outputs over the K components.
        logits = None
        for k in range(self.K):
            term = self._component_output(features, k) / self.K
            logits = term if logits is None else logits + term

        # Append a constant zero column so the output has L columns;
        # new_zeros matches both the device and the dtype of the logits.
        zero_col = logits.new_zeros(logits.shape[0], 1)
        return torch.cat((logits, zero_col), dim=1)

# Load the CIFAR-10 dataset
# Training pipeline: random flip / padded random crop for augmentation, then
# conversion to a tensor and per-channel normalisation (mean 0.5, std 0.5).
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation pipeline: same tensor conversion and normalisation, but without
# the random augmentations, so test error is deterministic and measured on
# the unmodified test images.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Training and evaluation routines (each call makes one pass over its loader)
def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Puts ``model`` into training mode and, for every mini-batch, moves the
    batch to ``device``, evaluates the module-level ``criterion`` on the
    model output and applies one ``optimizer`` step. The loss of every
    100th batch (starting with the first) is printed together with
    ``epoch`` and a running sample count.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        # Progress counter: batch index times the size of the current batch.
        seen = step * len(inputs)
        print(f'Train Epoch: {epoch} [{seen}/{len(train_loader.dataset)}]\tLoss: {batch_loss.item():.6f}')

def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return the error rate in %.

    Puts the model into eval mode, runs it without gradient tracking, and
    prints the per-sample average of the module-level ``criterion`` together
    with the classification error.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        device: Device the batches are moved to before the forward pass.
        test_loader: Iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute with a length.

    Returns:
        The test error as a percentage, ``100 * (1 - accuracy)``.
    """
    model.eval()
    # A loss with reduction='sum' already returns a per-batch total; any
    # other setting is treated as a per-batch mean and re-weighted below.
    loss_is_summed = getattr(criterion, 'reduction', 'mean') == 'sum'
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_loss = criterion(output, target).item()
            if not loss_is_summed:
                # Weight the batch mean by the batch size so that dividing
                # by the dataset size yields a true per-sample average.
                batch_loss *= target.size(0)
            test_loss += batch_loss
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_error = 100. * (1 - correct / len(test_loader.dataset))
    print(f'\nTest set: Average loss: {test_loss:.4f}, Error: {test_error:.2f}%\n')
    return test_error

# Hyperparameters
epochs = 50
nrep = 3
# test_err_GMM[rep, epoch - 1] holds the test error (%) after each epoch.
test_err_GMM = np.zeros([nrep, epochs])

# Create the output directory before any training starts, so a missing
# folder cannot make np.save fail after a full replicate has been trained.
os.makedirs('CIFAR_test_error', exist_ok=True)

# The device does not change between replicates, so select it once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for rep in range(nrep):
    print(f'Replicate {rep}\n')
    # Initialize the neural network, loss function, and optimizer.
    # `criterion` must stay a module-level name: train() and test() read it.
    model = VCG_GMM().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training and evaluation
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test_error = test(model, device, test_loader)
        test_err_GMM[rep, epoch - 1] = test_error

    # Save after every replicate so finished replicates survive a later crash.
    np.save('CIFAR_test_error/CIFAR_test_error_GMM.npy', test_err_GMM)