import torch
import torch.nn as nn
import torch.optim as optim
from .networks import Network

class Trainer:
    """Builds a ``Network`` from an architecture description and trains/evaluates it."""

    def __init__(self, net_architecture, train_loader, test_loader):
        """Constructs trainer which manages and trains neural network
        Args:
            net_architecture: Sequence of per-layer dictionaries describing the network
                architecture (it is enumerated and every entry is indexed with 'type').
                According to the original documentation each entry also needs the key
                'dims', and low-rank layers need the key 'rank'; those keys are consumed
                by ``Network``, not here.
            train_loader: loader for training data
            test_loader: loader for test data
        """
        # Set the device (GPU or CPU)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize the model
        self.model = Network(net_architecture).to(self.device)

        # find all ids of dynamical low-rank layers, since these layers require two steps
        self.dlr_layer_ids = [
            index
            for index, layer in enumerate(net_architecture)
            if layer['type'] == 'dynamical_low_rank'
        ]

        # find all rank-adaptive layers
        adaptive_types = ('dynamical_low_rank', 'parallel_low_rank')
        self.adaptive_layer_ids = [
            index
            for index, layer in enumerate(net_architecture)
            if layer['type'] in adaptive_types
        ]

        # store train and test data
        self.train_loader = train_loader
        self.test_loader = test_loader

    def _forward_backward(self, criterion, optimizer, data, targets):
        """Runs one forward pass and backpropagates the resulting loss
        Args:
            criterion: loss function called as ``criterion(outputs, targets)``
            optimizer: optimizer whose ``zero_grad`` clears the stale gradients
            data: input batch, already moved to ``self.device``
            targets: target batch, already moved to ``self.device``
        Returns:
            The loss tensor of this pass (gradients are populated on return).
        """
        outputs = self.model(data)
        loss = criterion(outputs, targets)

        # clear old gradients right before computing the new ones
        optimizer.zero_grad()
        loss.backward()
        return loss

    def train(self, num_epochs, learning_rate, optimizer_type="Adam", log_interval=300):
        """Trains neural network for specified number of epochs with specified learning rate
        Args:
            num_epochs: number of epochs for training
            learning_rate: learning rate for optimization method
            optimizer_type: used optimizer. Use Adam for vanilla training. Any other value
                (or the presence of dynamical low-rank layers) routes the update through
                ``self.model.step(learning_rate)`` instead of the Adam optimizer.
            log_interval: print the current loss every ``log_interval`` batches.
                A falsy value (0 or None) disables progress printing.
        """
        # Define the loss function and optimizer. The optimizer always clears the gradients;
        # it performs the parameter update only on the Adam path without dynamical
        # low-rank layers.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

        # Training loop
        for epoch in range(num_epochs):
            self.model.train()
            for batch_idx, (data, targets) in enumerate(self.train_loader):
                data = data.to(self.device)
                targets = targets.to(self.device)

                # Forward + backward to calculate gradients of parameters
                loss = self._forward_backward(criterion, optimizer, data, targets)

                ################ update entire network without low-rank coefficients ################
                if optimizer_type == "Adam" and not self.dlr_layer_ids:
                    optimizer.step()
                else:
                    self.model.step(learning_rate)

                ################### Coefficient update ###################
                if self.dlr_layer_ids:
                    # second forward + backward pass to calculate gradients of coefficients
                    loss = self._forward_backward(criterion, optimizer, data, targets)

                    # update coefficients in low-rank layers
                    for i in self.dlr_layer_ids:
                        self.model.layers[i].step(learning_rate, "coefficients")

                # print progress
                if log_interval and (batch_idx + 1) % log_interval == 0:
                    print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(self.train_loader)}], Loss: {loss.item():.4f}")

            # evaluate model on test data
            self.test_model()

    def test_model(self):
        """Prints the model's accuracy on the test data and the ranks of adaptive layers
        Returns:
            The accuracy in percent, or None when the test loader yields no samples.
        """
        # Test the model
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, targets in self.test_loader:
                data = data.to(self.device)
                targets = targets.to(self.device)

                outputs = self.model(data)
                # predicted class = index of the largest output along dimension 1
                predicted = outputs.argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        if total:
            accuracy = 100 * correct / total
            print(f"Accuracy of the network on the test images: {accuracy}%")
        else:
            # an empty test loader must not abort training with a ZeroDivisionError
            accuracy = None
            print("Accuracy of the network on the test images: n/a (no test samples)")

        print("Ranks: ", end=' ')
        for i in self.adaptive_layer_ids:
            print(self.model.layers[i].r, end=' ')
        print()

        return accuracy
