import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import random
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from scipy.stats import entropy


def entropy_loss(x):
    """Return the mean per-sample entropy of ``x``, normalised by log(x.size(1)).

    ``x`` is smoothed by adding 1e-5 and then divided by its sum along dim 1,
    so each slice along dim 0 is treated as one distribution.  The normaliser
    is computed under ``torch.no_grad()``, so autograd treats it as a
    constant.

    Args:
        x: Tensor of non-negative scores with at least two dimensions;
            dim 0 indexes the samples and dim 1 the classes.

    Returns:
        A 0-dim tensor.  For 2-D input the value is 1 when every row is
        uniform and approaches 0 when every row is one-hot.  An empty batch
        yields a zero tensor.
    """
    x = x + 1e-5
    with torch.no_grad():
        Z = torch.sum(x, dim=1, keepdim=True)
    x = x / Z

    num_samples = x.size(0)
    if num_samples == 0:
        # Nothing to average over; avoid a 0/0 division below.
        return x.new_zeros(())

    # One vectorised reduction instead of a Python loop over the rows:
    # summing -p*log(p) over everything and dividing by the number of samples
    # is the same as averaging the per-sample entropies.
    total_entropy = -torch.sum(x * x.log())
    return total_entropy / np.log(x.size(1)) / num_samples

def batch_entropy_loss(x):
    """Return the entropy of the batch-averaged distribution, normalised to [0, 1].

    ``x`` is smoothed by adding 1e-5 and divided by its sum along dim 1 (the
    normaliser is computed under ``torch.no_grad()``, so autograd treats it
    as a constant).  The rows are then averaged over dim 0 and the entropy of
    that mean distribution is divided by ``log(num_classes)``.

    The number of classes is read from ``x.size(1)`` rather than being fixed
    at 10, matching ``entropy_loss``; for 10-class input the result is
    unchanged.

    Args:
        x: 2-D tensor of non-negative scores; dim 0 indexes the samples and
            dim 1 the classes.

    Returns:
        A 0-dim tensor that is 1 when the averaged distribution is uniform
        and approaches 0 when the whole batch puts its mass on one class.
    """
    x = x + 1e-5
    with torch.no_grad():
        Z = torch.sum(x, dim=1, keepdim=True)
    x = x / Z

    num_classes = x.size(1)
    mean_dist = x.mean(dim=0)
    return -torch.sum(mean_dist * mean_dist.log()) / np.log(num_classes)

# Dataset of side-by-side digit pairs labelled with the parity of their sum.
class MNISTAddition(Dataset):
    """Wrap a dataset of ``(image, label)`` items into random image pairs.

    Every access draws two items uniformly at random (with replacement) from
    the wrapped dataset using Python's ``random`` module, so the requested
    index is ignored and repeated accesses return different pairs unless the
    ``random`` state is reset.
    """

    def __init__(self, mnist_data):
        # Underlying dataset; indexing it must yield ``(image, label)``.
        self.mnist_data = mnist_data

    def __len__(self):
        # Reported length equals that of the wrapped dataset.
        return len(self.mnist_data)

    def _random_item(self):
        """Return one uniformly drawn ``(image, label)`` item."""
        last_index = len(self.mnist_data) - 1
        return self.mnist_data[random.randint(0, last_index)]

    def __getitem__(self, idx):
        """Return ``(pair_image, parity, label_a, label_b)``; ``idx`` is unused.

        ``pair_image`` is the two images concatenated along dim 2, ``parity``
        is 1 when the two labels sum to an even number and 0 otherwise, and
        the two labels are returned as tensors.
        """
        left_img, left_label = self._random_item()
        right_img, right_label = self._random_item()

        sum_is_even = (left_label + right_label) % 2 == 0
        pair_image = torch.cat((left_img, right_img), dim=2)

        return (
            pair_image,
            int(sum_is_even),
            torch.tensor(left_label),
            torch.tensor(right_label),
        )


class FullyNeuralAdditionModel(nn.Module):
    """Two-digit model: shared per-digit classifier plus a linear head.

    Each half of the input is passed through the same CNN and MLP, giving a
    10-way softmax per half.  The outer product of the two softmax vectors
    (100 values) is fed to a bias-free linear layer producing 2 logits.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # 3x3 same-padding convolution, ReLU, then 2x2 max-pool.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]

        # For a 28x28 half the spatial size goes 28 -> 14 -> 7 -> 3.
        self.cnn = nn.Sequential(
            *conv_stage(1, 32),
            *conv_stage(32, 64),
            *conv_stage(64, 128),
        )
        self.flatten = nn.Flatten()

        # MLP on the flattened 128*3*3 feature map, ending in a 10-way softmax.
        hidden_widths = [128 * 3 * 3, 256, 128, 128]
        mlp_layers = []
        for fan_in, fan_out in zip(hidden_widths, hidden_widths[1:]):
            mlp_layers.append(nn.Linear(fan_in, fan_out))
            mlp_layers.append(nn.ReLU())
        mlp_layers.append(nn.Linear(128, 10))
        mlp_layers.append(nn.Softmax(dim=1))
        self.fc_individual = nn.Sequential(*mlp_layers)

        # Maps the 10x10 joint table of the two softmax outputs to 2 logits.
        self.fc_aggregate = nn.Sequential(
            nn.Linear(100, 2, bias=False),
        )
        nn.init.normal_(self.fc_aggregate[0].weight, mean=0.0, std=3.0)

    def _digit_probs(self, half):
        """Run the shared CNN + MLP on one image half; returns softmax output."""
        return self.fc_individual(self.flatten(self.cnn(half)))

    def forward(self, x):
        """Return ``(logits, probs_left, probs_right)`` for a batch of pairs.

        The input is split along its last dimension at column 28: the first
        28 columns form the left half and the remainder the right half.
        """
        features1 = self._digit_probs(x[:, :, :, :28])
        features2 = self._digit_probs(x[:, :, :, 28:])

        # Outer product of the two probability vectors, flattened per sample.
        joint = torch.mul(features1[:, :, None], features2[:, None, :])
        combined_features = joint.view(features1.shape[0], -1)

        output = self.fc_aggregate(combined_features)
        return output, features1, features2

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic kernels and turn off its autotuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def main(seed, save_path, entropy):
    """Train, save and evaluate one parity model on random MNIST digit pairs.

    Steps: seed all RNGs, train ``FullyNeuralAdditionModel`` for 100 epochs
    with Adam, save its ``state_dict`` under ``save_path``, print accuracy and
    macro-F1 on pairs drawn from the MNIST test split, and write a confusion
    matrix of the per-digit predictions as a PNG in the current working
    directory.

    Args:
        seed: Seed passed to ``set_seed`` and embedded in the output file names.
        save_path: Directory (must already exist) for the saved weights.
        entropy: When truthy, add ``(1 - batch_entropy_loss(p))`` for each of
            the two per-digit outputs to the training loss.  Inside this
            function the name shadows the ``entropy`` imported from
            ``scipy.stats``; it is kept for interface compatibility.
    """
    from sklearn.metrics import accuracy_score, f1_score

    # Seed before anything random happens (weight init, shuffling, pair
    # sampling); otherwise ``seed`` would only affect the output file names.
    set_seed(seed)

    # Fall back to the CPU instead of crashing when no GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(root='./data', train=False, download=False, transform=transform)

    mnist_addition = MNISTAddition(mnist_train)
    mnist_addition_test = MNISTAddition(mnist_test)
    train_loader = DataLoader(mnist_addition, batch_size=64, shuffle=True)
    test_loader = DataLoader(mnist_addition_test, batch_size=64, shuffle=False)

    num_epochs = 100

    model = FullyNeuralAdditionModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        # The per-digit labels are not used for training, only the parity label.
        for images, labels, _c1, _c2 in train_loader:
            optimizer.zero_grad()

            images = images.to(device)
            labels = labels.to(device)
            outputs, p1, p2 = model(images)

            loss = criterion(outputs, labels)
            if entropy:
                # Penalise a low-entropy batch-averaged digit distribution for
                # each half, i.e. discourage collapsing onto a few classes.
                loss = loss + (1 - batch_entropy_loss(p1)) + (1 - batch_entropy_loss(p2))

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")

    print("Training finished!")

    torch.save(model.state_dict(), os.path.join(save_path, f"cbm_best_addmnist_seed_{seed}_{entropy}.pth"))
    print(f"Model saved for seed {seed} and entropy {entropy}")

    model.eval()

    all_preds = []
    all_labels = []
    all_concepts = []
    all_c_preds = []

    with torch.no_grad():
        for images, labels, c1, c2 in test_loader:
            images = images.to(device)

            outputs, c_pred_1, c_pred_2 = model(images)
            # Softmax is monotonic, so the argmax can be taken on the logits.
            preds = outputs.argmax(dim=1)

            # Stack left-digit and right-digit entries so that ground truth
            # and predictions line up row by row.
            full_concepts = torch.cat([c1, c2], dim=0)
            full_c_pred = torch.cat([c_pred_1, c_pred_2], dim=0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_c_preds.extend(full_c_pred.cpu().numpy())
            all_concepts.extend(full_concepts.numpy())

    f1_macro = f1_score(all_labels, all_preds, average='macro')
    accuracy = accuracy_score(all_labels, all_preds)

    print("#########")
    print(f"Evaluation for seed: {seed} and entropy {entropy}")
    print(f"Validation Accuracy: {accuracy:.4f}")
    print(f"Validation F1-Macro: {f1_macro:.4f}")

    # Confusion matrix between true digits and the argmax of the per-digit
    # softmax outputs.
    cm = confusion_matrix(all_concepts, np.argmax(all_c_preds, axis=1))
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f"Confusion Matrix for Concept")
    plt.xlabel("Predicted")
    plt.ylabel("Ground Truth")
    plt.savefig(f"confusion_matrix_concept_{seed}_{entropy}.png")
    plt.close()

if __name__ == "__main__":
    # Train and evaluate one model per seed, saving weights under save_path.
    save_path = "./cbm_models"
    os.makedirs(save_path, exist_ok=True)
    seeds = [1011, 1213, 1415, 1617, 1819]
    # The loop variable is deliberately not called ``entropy``: at module
    # level that would rebind the ``entropy`` imported from scipy.stats.
    for use_entropy in [True]:
        for seed in seeds:
            main(seed, save_path, use_entropy)
