import os
import random

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms

from network import MNIST_Net

class Trainer:
    """Train and evaluate an ``MNIST_Net`` classifier on cluster-derived labels.

    The trainer owns the network, an Adam optimizer and a cross-entropy loss.
    Training labels are not ground truth: they are looked up in the
    ``cluster_labels`` mapping passed to :meth:`train`.
    """

    # Destination of the weights written by `train`.
    MODEL_PATH = './classifier_result/MNIST_NET.pth'
    TRAIN_BATCH_SIZE = 128
    TEST_BATCH_SIZE = 1000

    def __init__(self, num_numbers, n_clusters, examples, image_ids, train_X, test_X, test_Y):
        """Store the data references and build the network, optimizer and loss.

        Args:
            num_numbers: Number of operand slots per example that
                `process_examples_old` iterates over.
            n_clusters: Stored on the instance; not read by this class.
            examples: Sequence of examples. `process_examples_old` reads
                ``ex[0][i][j]`` as an image and ``ex[2][i][j]`` as its id.
            image_ids: Indices used to look up both ``train_X`` and the
                ``cluster_labels`` given to `train`.
            train_X: Indexable collection of training images.
            test_X: Test images, convertible with ``torch.Tensor``.
            test_Y: Test labels, convertible with ``torch.Tensor``.
        """
        self.num_numbers = num_numbers
        self.n_clusters = n_clusters
        self.examples = examples
        self.image_ids = image_ids
        self.train_X = train_X
        self.test_X = test_X
        self.test_Y = test_Y
        self.network = MNIST_Net()
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=1e-3)
        self.criterion = torch.nn.CrossEntropyLoss()

    @staticmethod
    def _build_loader(images, labels, batch_size):
        """Wrap parallel image/label sequences in a sequential DataLoader."""
        tensor_x = torch.Tensor(np.array(images))
        tensor_y = torch.Tensor(np.array(labels)).type(torch.LongTensor)
        return DataLoader(TensorDataset(tensor_x, tensor_y), batch_size=batch_size)

    def train(self, cluster_labels, epochs):
        """Fit the network on the images in ``image_ids`` and save its weights.

        Args:
            cluster_labels: Mapping/sequence from image id to class label.
            epochs: Number of full passes over the training data.

        The state dict is written to ``MODEL_PATH``; the target directory is
        created when it does not exist yet.
        """
        train_loader = self.process_examples(cluster_labels)

        # Make sure layers such as dropout behave in training mode even if the
        # network was switched to eval mode beforehand.
        self.network.train()
        for _ in range(epochs):
            for inputs, labels in train_loader:
                # zero the parameter gradients
                self.optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self.network(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

        # save the model (torch.save does not create missing directories)
        os.makedirs(os.path.dirname(self.MODEL_PATH), exist_ok=True)
        torch.save(self.network.state_dict(), self.MODEL_PATH)

    def process_examples(self, cluster_labels):
        """Build the training DataLoader from ``image_ids``.

        Each id selects an image from ``train_X`` and its label from
        ``cluster_labels``; order follows ``image_ids``.
        """
        train_x = [self.train_X[idx] for idx in self.image_ids]
        train_y = [cluster_labels[idx] for idx in self.image_ids]
        return self._build_loader(train_x, train_y, self.TRAIN_BATCH_SIZE)

    def process_examples_old(self, cluster_labels):
        """Build the training DataLoader by walking ``self.examples``.

        Every image id is used once, at its first occurrence; the number of
        distinct images is printed.
        """
        seen = set()
        train_X = []
        train_Y = []
        for ex in self.examples:
            for i in range(self.num_numbers):
                for j, idx in enumerate(ex[2][i]):
                    if idx in seen:
                        continue
                    seen.add(idx)
                    train_X.append(ex[0][i][j])
                    train_Y.append(cluster_labels[idx])

        print("Number of examples used to train classifier: {}".format(len(seen)))
        return self._build_loader(train_X, train_Y, self.TRAIN_BATCH_SIZE)

    def transform_image(self, image):
        """Convert ``image`` to a tensor and normalize it with mean 0.5, std 1."""
        image_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (1,)),
            ]
        )
        return image_transform(image)

    def test(self):
        """Evaluate the network on ``test_X`` / ``test_Y``.

        Returns:
            ``(correct, total)`` prediction counts. The accuracy is also
            printed; an empty test set is reported as 0 %.
        """
        tensor_x_test = torch.Tensor(self.test_X)
        tensor_y_test = torch.Tensor(self.test_Y).type(torch.LongTensor)

        test_dataset = TensorDataset(tensor_x_test, tensor_y_test)
        test_loader = DataLoader(test_dataset, batch_size=self.TEST_BATCH_SIZE)

        correct = 0
        total = 0
        # Evaluate in eval mode so dropout/batch-norm layers are deterministic,
        # then restore whatever mode the network was in.
        was_training = self.network.training
        self.network.eval()
        try:
            # since we're not training, we don't need to calculate the gradients
            with torch.no_grad():
                for images, labels in test_loader:
                    outputs = self.network(images)
                    # the class with the highest energy is the prediction
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
        finally:
            self.network.train(was_training)

        accuracy = 100 * correct // total if total else 0
        print(f'Accuracy of the network on the {total} test images: {accuracy} %')
        return correct, total

    def test_addition(self, examples):
        """Evaluate multi-digit addition accuracy.

        For each example, every number in ``ex[0]`` is a sequence of digit
        images (most significant first). The predicted digits are combined
        into numbers, the numbers are summed, and the result is compared
        with ``ex[1]``.

        Returns:
            ``(correct, total)`` where ``total`` is ``len(examples)``.
        """
        total = len(examples)
        correct = 0
        to_tensor = transforms.ToTensor()

        was_training = self.network.training
        self.network.eval()
        try:
            with torch.no_grad():
                for ex in examples:
                    predicted_sum = 0
                    for digits in ex[0]:
                        n_digits = len(digits)
                        for pos, image in enumerate(digits):
                            # NOTE(review): the image is fed without a batch
                            # dimension and without the Normalize step of
                            # `transform_image` - confirm this matches the
                            # preprocessing used for the training data.
                            out = self.network(to_tensor(image))
                            label = int(torch.argmax(out))
                            predicted_sum += label * 10 ** (n_digits - pos - 1)
                    if predicted_sum == ex[1]:
                        correct += 1
        finally:
            self.network.train(was_training)

        accuracy = 100 * correct // total if total else 0
        print(f'Accuracy of the network on the {total} addition examples: {accuracy} %')
        return correct, total




