import torch
from network import MNIST_Net
from torchvision import transforms

class Trainer:
    """Train and evaluate an ``MNIST_Net`` using cluster-derived labels.

    Training labels are not read from the examples directly: every training
    image carries an index, ``cluster_ids`` maps that index to a cluster, and
    ``cluster_labels`` maps the cluster to the class used as the target.
    """

    def __init__(self, num_numbers, n_clusters, examples, test_X, test_Y):
        """Store the data and build the network, optimizer and loss.

        Args:
            num_numbers: Number of image groups per example; used as the loop
                bound over ``example[0]`` and ``example[2]``.
            n_clusters: Number of clusters. Stored, but not read by this class.
            examples: Iterable of training examples. Each example is indexed
                as ``example[0]`` (groups of images) and ``example[2]``
                (matching groups of indices into ``cluster_ids``).
            test_X: Test images; must support ``len()`` and iteration.
            test_Y: Test labels, indexable in the same order as ``test_X``.
        """
        self.num_numbers = num_numbers
        self.n_clusters = n_clusters
        self.examples = examples
        self.test_X = test_X
        self.test_Y = test_Y
        self.network = MNIST_Net()
        # The optimizer lives on the network object; processExample looks it
        # up there, so keep that location.
        self.network.optimizer = torch.optim.Adam(self.network.parameters(), lr=1e-3)
        self.criterion = torch.nn.CrossEntropyLoss()
        # Image-to-tensor converter, built once and shared by training and
        # evaluation instead of being re-created for every call/sample.
        self._to_tensor = transforms.ToTensor()

    def train(self, cluster_ids, cluster_labels, epochs):
        """Run ``epochs`` passes over ``self.examples``.

        Args:
            cluster_ids: Mapping/sequence from image index to cluster id.
            cluster_labels: Mapping/sequence from cluster id to class label.
            epochs: Number of full passes over the examples.
        """
        for _ in range(epochs):
            for example in self.examples:
                self.processExample(example, cluster_ids, cluster_labels)

    def processExample(self, example, cluster_ids, cluster_labels):
        """Accumulate gradients for one example and take one optimizer step.

        Args:
            example: Indexable example; ``example[0]`` holds the image groups
                and ``example[2]`` the matching groups of image indices.
            cluster_ids: Mapping/sequence from image index to cluster id.
            cluster_labels: Mapping/sequence from cluster id to class label.
        """
        outputs = self.getNNOutput(example[0], self.num_numbers)
        for i in range(self.num_numbers):
            for j, idx in enumerate(example[2][i]):
                label = cluster_labels[cluster_ids[idx]]
                output = outputs[i][j]
                # One-hot target shaped like the network output, so the number
                # of classes follows the network instead of a hard-coded 10.
                target = torch.zeros_like(output)
                target[..., label] = 1.0
                # Each image has its own graph; gradients accumulate in the
                # parameters until the single step below.
                self.criterion(output, target).backward()
        self.network.optimizer.step()
        self.network.optimizer.zero_grad()

    def getNNOutput(self, images, num_numbers):
        """Run the network on every image of the first ``num_numbers`` groups.

        Args:
            images: Indexable collection of image groups; each image must be
                accepted by the image-to-tensor converter.
            num_numbers: Number of leading groups of ``images`` to process.

        Returns:
            A list of ``num_numbers`` lists; entry ``[i][j]`` is the network
            output for ``images[i][j]``.
        """
        return [
            [self.network(self._to_tensor(image)) for image in images[i]]
            for i in range(num_numbers)
        ]

    def test(self):
        """Evaluate the network on the stored test set.

        The network is switched to evaluation mode and run without gradient
        tracking; its previous training/eval mode is restored afterwards.

        Returns:
            A tuple ``(correct, total)``: the number of test images whose
            arg-max prediction equals the label, and the test-set size.
        """
        total = len(self.test_X)
        correct = 0
        was_training = self.network.training
        self.network.eval()
        try:
            with torch.no_grad():
                for i, image in enumerate(self.test_X):
                    output = self.network(self._to_tensor(image))
                    predicted = int(torch.argmax(output))
                    if predicted == self.test_Y[i]:
                        correct += 1
        finally:
            # Leave the network in whatever mode the caller had it in.
            self.network.train(was_training)
        return correct, total