import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from collections import defaultdict
import numpy as np
from tqdm import tqdm
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import random
from scipy.stats import entropy
from sklearn.metrics import mutual_info_score

import torch
import matplotlib.pyplot as plt
import matplotlib
import os

from utils import NoisyMNISTDataset

class linear_model(nn.Module):
    """Bias-free fully connected network: 784 -> 128 -> 64 -> 32 -> 10.

    Each of the first three linear layers is followed by a ReLU. The last
    layer returns its raw output with no activation applied. Besides
    ``forward``, the activations after each hidden ReLU can be read through
    the ``get_hidden_layer_fc*_output`` methods.
    """

    def __init__(self):
        super().__init__()

        # Sub-modules are registered in the order they are applied.
        self.fc1 = nn.Linear(28 * 28, 128, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64, bias=False)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 32, bias=False)
        self.relu3 = nn.ReLU()
        # Deliberately no ReLU after the output layer: its values are meant
        # to be fed to a softmax.
        self.fc4 = nn.Linear(32, 10, bias=False)

    def forward(self, x):
        """Return the output of the final layer for the batch ``x``."""
        return self.fc4(self.get_hidden_layer_fc3_output(x))

    def get_hidden_layer_fc1_output(self, x):
        """Return the activations after the first linear layer and its ReLU."""
        # Collapse every non-batch dimension into one feature vector per sample.
        flat = x.view(x.size(0), -1)
        return self.relu1(self.fc1(flat))

    def get_hidden_layer_fc2_output(self, x):
        """Return the activations after the second linear layer and its ReLU."""
        first_hidden = self.get_hidden_layer_fc1_output(x)
        return self.relu2(self.fc2(first_hidden))

    def get_hidden_layer_fc3_output(self, x):
        """Return the activations after the third linear layer and its ReLU."""
        second_hidden = self.get_hidden_layer_fc2_output(x)
        return self.relu3(self.fc3(second_hidden))

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10,
                save_path='best_model.pth'):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    For every epoch the model is optimised over all batches of
    ``train_loader``, then evaluated with ``validate_model`` on
    ``val_loader``. Whenever the validation loss improves on the best value
    seen so far, the model's ``state_dict`` is written to ``save_path``.

    Args:
        model: Module to train; called as ``model(data)``.
        train_loader: Sized iterable of ``(data, target)`` training batches.
        val_loader: Batches passed to ``validate_model`` after each epoch.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        optimizer: Optimizer that updates the parameters of ``model``.
        num_epochs: Number of passes over ``train_loader``.
        save_path: File path the best ``state_dict`` is saved to. Missing
            parent directories are created.

    Returns:
        The lowest validation loss observed (``float('inf')`` if no epoch ran).

    Raises:
        ZeroDivisionError: If ``train_loader`` has length zero.
    """
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Mean of the per-batch losses for this epoch.
        avg_train_loss = total_loss / len(train_loader)
        val_loss = validate_model(model, val_loader, criterion)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # torch.save does not create directories, so make sure the
            # target folder exists when save_path points into one.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print('Best model saved')

    return best_val_loss

def validate_model(model, val_loader, criterion):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    The model is switched to evaluation mode (and left in it) and all
    batches are processed with gradient tracking disabled.

    Args:
        model: Module to evaluate; called as ``model(data)``.
        val_loader: Sized iterable of ``(data, target)`` batches.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.

    Returns:
        Sum of the per-batch loss values divided by ``len(val_loader)``.
    """
    model.eval()
    loss_sum = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            loss_sum += criterion(model(inputs), labels).item()
    return loss_sum / len(val_loader)

def test_model(model, test_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    avg_loss = total_loss / len(test_loader)
    accuracy = correct / len(test_loader.dataset)
    print(f'Test Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')

if __name__ == "__main__":
    # Script entry point: train linear_model on the noisy MNIST data and
    # report test metrics for the checkpoint with the best validation loss.

    train_dir = 'data/noisy_mnist/train'
    test_dir = 'data/noisy_mnist/test'
    # train_model writes its best checkpoint to 'best_model.pth' in the
    # working directory. The evaluation below must load that same file;
    # 'weights/best_model.pth' is never written by this script.
    checkpoint_path = 'best_model.pth'

    transform = transforms.Compose([transforms.ToTensor(),])
    mnist_train = NoisyMNISTDataset(
        image_folder=train_dir,
        labels_file=os.path.join(train_dir, 'labels.txt'),
        transform=transform
        )
    mnist_test = NoisyMNISTDataset(
        image_folder=test_dir,
        labels_file=os.path.join(test_dir, 'labels.txt'),
        transform=transform
        )

    # Hold out 20% of the training set for validation.
    train_size = int(0.8 * len(mnist_train))
    val_size = len(mnist_train) - train_size
    train_dataset, val_dataset = random_split(mnist_train, [train_size, val_size])
    bz = 64
    train_loader = DataLoader(train_dataset, batch_size=bz, shuffle=True)
    # Evaluation does not depend on sample order, so only the training
    # loader is shuffled.
    val_loader = DataLoader(val_dataset, batch_size=bz, shuffle=False)
    test_loader = DataLoader(mnist_test, batch_size=bz, shuffle=False)

    model = linear_model()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)

    # Evaluate the best checkpoint rather than the weights of the last epoch.
    model.load_state_dict(torch.load(checkpoint_path))
    test_model(model, test_loader, criterion)
