import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import resnet18
import os

# Locations of the serialized train / test image arrays.
train_data_file = "data/mnist2d_train.npy"
test_data_file = "data/mnist2d_test.npy"


def _load_class_major(path, per_class):
    """Load ``path`` as float32, shape it to (100, per_class, 1, 32, 32) and divide by 255."""
    raw = np.load(path).astype(np.float32)
    return raw.reshape(100, per_class, 1, 32, 32) / 255.


def _flatten_with_labels(class_major):
    """Return ``(images, labels)`` tensors for a class-major array.

    Axis 0 of ``class_major`` indexes the class and axis 1 the samples of that
    class, so the label of every image is simply its index along axis 0.
    Images are flattened to ``(-1, 1, 32, 32)``; labels are int64.
    """
    n_classes, per_class = class_major.shape[:2]
    images = torch.from_numpy(class_major).reshape(-1, 1, 32, 32)
    # 0 repeated per_class times, then 1 repeated per_class times, ...
    labels = torch.arange(n_classes).repeat_interleave(per_class)
    return images, labels


train_data = _load_class_major(train_data_file, 6000)
test_data = _load_class_major(test_data_file, 1000)


np.random.seed(42)
X_train_tensor, Y_train_tensor = _flatten_with_labels(train_data)
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
# Only the training loader shuffles; evaluation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=4096, shuffle=True)

X_test_tensor, Y_test_tensor = _flatten_with_labels(test_data)
test_dataset = TensorDataset(X_test_tensor, Y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=4096, shuffle=False)


# ResNet-18 with randomly initialised weights. Calling resnet18() without
# arguments requests no pretrained weights on both old and new torchvision
# releases and avoids the deprecation warning raised by `pretrained=False`.
model = resnet18()
# Swap the stem convolution so the network accepts 1-channel input.
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Swap the classification head for a 100-way linear layer.
model.fc = nn.Linear(model.fc.in_features, 100)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def train(model, device, train_loader, criterion, optimizer, epochs=50):
    """Train ``model`` in place and print train/test metrics after every epoch.

    Args:
        model: network whose parameters are optimised.
        device: device each batch is moved to before the forward pass.
        train_loader: iterable yielding ``(data, target)`` batches.
        criterion: loss callable invoked as ``criterion(output, target)``.
        optimizer: optimizer that updates ``model``'s parameters.
        epochs: number of passes over ``train_loader``.

    Note:
        The per-epoch evaluation calls ``test`` on the module-level
        ``test_loader``; that loader is not a parameter of this function.
    """
    for epoch in range(epochs):
        # test() switches the model to eval mode at the end of every epoch, so
        # training mode must be re-enabled here; otherwise every epoch after
        # the first would run with BatchNorm frozen.
        model.train()
        total_loss = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        test_loss, accuracy = test(model, device, test_loader, criterion)
        print(f'Epoch {epoch+1}, Train Loss: {total_loss/len(train_loader)}, Test Loss: {test_loss}, Accuracy: {accuracy}')

def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            #print(torch.cat([pred.reshape(-1,1),target.reshape(-1,1)],dim=1))
           
    test_loss /= len(test_loader)
    accuracy = correct / len(test_loader.dataset)
    return test_loss, accuracy

def evaluate_accuracy(X, Y, weights_path='weights/classifier_mnist2d.pth'):
    """Return the accuracy of the saved classifier on the samples ``(X, Y)``.

    A fresh ResNet-18 (1-channel stem, 100-way head) is built, the checkpoint
    at ``weights_path`` is loaded into it, and the model is scored with
    ``test``.

    Args:
        X: tensor of images; reshaped to ``(-1, 1, 32, 32)``.
        Y: NumPy array of integer class labels; flattened and cast to int64.
        weights_path: path of the ``state_dict`` checkpoint to load.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    X_test_tensor = X.reshape(-1,1,32,32)
    Y_test_tensor = torch.from_numpy(Y.reshape(-1)).long()
    test_dataset = TensorDataset(X_test_tensor, Y_test_tensor)
    test_loader = DataLoader(test_dataset, batch_size=4096, shuffle=False)

    # resnet18() without arguments requests no pretrained weights and avoids
    # the deprecation warning attached to the `pretrained` keyword.
    model = resnet18()
    model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    model.fc = nn.Linear(model.fc.in_features, 100)
    # map_location lets a checkpoint written on a GPU machine load on a
    # CPU-only host instead of raising while deserialising CUDA tensors.
    model_state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(model_state_dict)
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    _, accuracy = test(model, device, test_loader, criterion)
    return accuracy


# Reuse a saved checkpoint when one exists; otherwise train from scratch and
# save the result.
weights_path = 'weights/classifier_mnist2d.pth'
if os.path.exists(weights_path):
    # map_location lets a checkpoint written on a GPU machine load on a
    # CPU-only host.
    model_state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(model_state_dict)
else:
    # Create the output directory before the long training run, so the save
    # below cannot fail on a missing folder after all the work is done.
    weights_dir = os.path.dirname(weights_path)
    if weights_dir:
        os.makedirs(weights_dir, exist_ok=True)
    train(model, device, train_loader, criterion, optimizer)
    torch.save(model.state_dict(), weights_path)
    print("Model trained and weights saved.")