import pickle
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset





class MyDataset(Dataset):
    """Flatten a ``{label: [feature, ...]}`` dict into (feature, label) pairs.

    Samples are ordered by the dict's iteration order, then by position within
    each label's sequence.

    Args:
        data_dict: mapping from a numeric class label to a sequence of feature
            vectors. Each vector may be a NumPy array, a torch tensor, or
            anything else ``torch.as_tensor`` accepts.
    """

    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.samples = []
        self.labels = []
        # Build both lists in a single pass so sample i and label i always
        # correspond.
        for label, samples in data_dict.items():
            self.samples.extend(samples)
            self.labels.extend([label] * len(samples))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features, label)`` as float32 and int64 tensors."""
        # as_tensor (unlike from_numpy) also accepts tensors and lists. For a
        # float32 NumPy array it shares memory, exactly as from_numpy(...).float() did.
        features = torch.as_tensor(self.samples[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return features, label


class MyNet(nn.Module):
    """Single linear layer mapping feature vectors to per-class scores.

    Args:
        in_features: length of each input feature vector (default 768).
        num_classes: number of output classes (default 10).
        apply_softmax: if True, ``forward`` returns softmax probabilities.
            Leave it False (the default) when training with
            ``nn.CrossEntropyLoss``. That loss applies log-softmax itself, so
            feeding it softmax outputs normalizes twice and flattens gradients.
    """

    def __init__(self, in_features=768, num_classes=10, apply_softmax=False):
        super().__init__()
        # Attribute name kept as ``fc3`` so existing state_dict checkpoints
        # (keys "fc3.weight" / "fc3.bias") still load.
        self.fc3 = nn.Linear(in_features, num_classes)
        self.apply_softmax = apply_softmax

    def forward(self, x):
        """Map ``x`` of shape (batch, in_features) to (batch, num_classes) scores.

        Returns raw logits unless ``apply_softmax`` was set. Either way the
        argmax per row is the same, since softmax is monotonic.
        """
        logits = self.fc3(x)
        if self.apply_softmax:
            return F.softmax(logits, dim=1)
        return logits

def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch over ``dataloader``.

    Args:
        model: module to train; it is switched to training mode here.
        dataloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``. Assumed to
            return the batch mean, as ``nn.CrossEntropyLoss`` does by default.
        optimizer: optimizer that updates ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_accuracy)``: the per-sample mean loss and top-1
        accuracy over the samples actually seen this epoch.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        # criterion gives a batch mean. Re-weight by batch size so the epoch
        # figure is a per-sample mean even when the last batch is short.
        running_loss += loss.item() * n
        predicted = outputs.detach().argmax(dim=1)
        total += n
        correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("dataloader yielded no samples")
    # Normalize by the samples actually seen, not len(dataloader.dataset).
    # The two differ when a sampler draws a subset or drop_last=True.
    epoch_loss = running_loss / total
    epoch_accuracy = correct / total
    return epoch_loss, epoch_accuracy

def test(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total

    return accuracy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build paths with pathlib. The former "..\name" literals only resolve on
# Windows; on POSIX the backslash is an ordinary filename character.
train_features_file = Path("..") / "cifar10_r200_TO5000_features.pkl"
test_features_file = Path("..") / "cifar10_test_features.pkl"

# SECURITY: pickle.load can execute arbitrary code. Only load these files if
# they come from a trusted source (e.g. your own feature-extraction run).
# NOTE(review): each file presumably holds {class_label: [feature_vector, ...]},
# the layout MyDataset flattens. Confirm against the extraction script.
with open(train_features_file, "rb") as f:
    train_features = pickle.load(f)
with open(test_features_file, "rb") as f:
    test_features = pickle.load(f)

train_dict = train_features
test_dict = test_features

# Hyperparameters.
batch_size = 64
learning_rate = 0.001
num_epochs = 60

train_dataset = MyDataset(train_dict)
test_dataset = MyDataset(test_dict)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model = MyNet().to(device)
# CrossEntropyLoss applies log-softmax internally and expects raw logits.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Collects per-epoch test accuracy if the training loop appends to it.
test_accuracies = []



def test_per_class(model, dataloader, device):
    class_correct_best = [0] * 10
    class_total_best = [0] * 10
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            for i in range(len(labels)):
                label = labels[i].item()
                class_correct_best[label] += (predicted[i] == labels[i]).item()
                class_total_best[label] += 1

    global class_accuracies_best
    class_accuracies_best = [class_correct_best[i] / class_total_best[i]  for i in range(10)]
    for i in range(10):
        print(f'Accuracy of class {i}: {class_accuracies_best[i]:.4f}')


best_test_accuracy = 0.0
best_epoch = 0
best_model_path = "new_cifar10_best_model.pth"

for epoch in range(num_epochs):
    train_loss, train_accuracy = train(model, train_dataloader, criterion, optimizer, device)
    test_accuracy = test(model, test_dataloader, device)
    test_accuracies.append(test_accuracy)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Accuracy: {test_accuracy:.4f}')

    # Always checkpoint the first epoch (best_epoch == 0). Otherwise, if
    # accuracy never beats 0.0, the reload below would pick up a stale file
    # from an earlier run, or fail.
    if best_epoch == 0 or test_accuracy > best_test_accuracy:
        best_test_accuracy = test_accuracy
        best_epoch = epoch + 1
        torch.save(model.state_dict(), best_model_path)

if best_epoch == 0:
    raise RuntimeError("no training epochs ran (num_epochs == 0); no checkpoint to evaluate")

# Evaluate the checkpoint with the highest test accuracy, not the final weights.
# map_location keeps loading on the current device regardless of where the
# tensors were saved.
best_model = MyNet().to(device)
best_model.load_state_dict(torch.load(best_model_path, map_location=device))

test_per_class(best_model, test_dataloader, device)

print(f"Best Test Accuracy: {best_test_accuracy:.4f} at Epoch {best_epoch}")

