import copy

import numpy as np
import torch
import torch.optim as optim

from . import utils

class _Square(torch.nn.Module):
    """Element-wise square activation, ``x -> x ** 2``.

    Implemented as a module rather than a ``lambda`` so that a model holding
    it can be pickled (a ``lambda`` attribute cannot be). It has no
    parameters, so it adds nothing to the owning model's ``state_dict``.
    """

    def forward(self, x):
        return x ** 2


class TwoLayerNN(torch.nn.Module):
    """Fully connected network with one hidden layer: fc1 -> activation -> fc2.

    Args:
        input_size: Number of input features of ``fc1``.
        hidden_size: Output width of ``fc1`` / input width of ``fc2``.
        num_classes: Output width of ``fc2``.
        activation: ``'relu'`` or ``'quadratic'`` (element-wise square).

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, input_size, hidden_size, num_classes, activation='relu'):
        super().__init__()
        # fc1 is created before fc2 so that, under a fixed seed, the layers
        # draw their initial weights in the same order as before.
        self.fc1 = torch.nn.Linear(input_size, hidden_size)

        # Set activation based on the argument passed
        if activation == 'relu':
            self.activation = torch.nn.ReLU()
        elif activation == 'quadratic':
            self.activation = _Square()
        else:
            raise ValueError("Unsupported activation type. Choose 'relu' or 'quadratic'.")

        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the raw (un-normalized) outputs of ``fc2`` for input ``x``."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)

def _loss_and_num_correct(model, X, y, criterion):
    """Return ``(loss, number of argmax-correct rows)`` of ``model`` on ``(X, y)``."""
    model.eval()
    with torch.no_grad():
        outputs = model(X)
        loss = criterion(outputs, y).cpu().item()
        predicted = torch.argmax(outputs, dim=1)
        true_labels = torch.argmax(y, dim=1)
        correct = (predicted == true_labels).sum().item()
    return loss, correct


def _first_layer_gram(model):
    """Return ``W.T @ W`` on the CPU, where ``W`` is the current ``fc1`` weight."""
    W = model.fc1.weight.data
    return (W.T @ W).cpu()


def _cpu_snapshot(model):
    """Return an independent CPU copy of ``model``, leaving ``model`` on its device."""
    # nn.Module.cpu() moves the module in place, so copy first and move the copy;
    # calling model.cpu() directly would pull the live model off `device` mid-training.
    return copy.deepcopy(model).cpu()


def nn_train_and_test(X_train, y_train, X_test, y_test, c,
                  hyperparams, log_hyperparams,
                  device='cpu'):
    """Train a TwoLayerNN with MSE loss on one-hot labels and evaluate it each epoch.

    The model is evaluated on the full train and test sets once before
    training and once after every epoch, so each metric list has
    ``num_epochs + 1`` entries.

    Args:
        X_train, X_test: Feature arrays; they are passed to
            ``torch.from_numpy``, so they must be NumPy arrays. The second
            dimension of ``X_train`` is used as the input size. The caller's
            arrays are not modified.
        y_train, y_test: Labels, converted with ``utils.convert_one_hot(y, c)``.
        c: Number of classes (MSE loss is used on the one-hot targets).
        hyperparams: Dict with keys ``'num_epochs'``, ``'weight_decay'``,
            ``'hidden_size'``, ``'lr'``, ``'batch_size'``, ``'activation'``
            and ``'normalize'``. If ``'normalize'`` is truthy every row of
            ``X_train`` / ``X_test`` is divided by its norm (an all-zero row
            therefore becomes NaN).
        log_hyperparams: Dict with keys
            ``'return_model'`` (include a model copy in the result),
            ``'return_last'`` (return the model after the LAST epoch instead
            of the one with the best test accuracy),
            ``'return_M_interval'`` (if > 0, record ``W.T @ W`` of the first
            layer every that many epochs; otherwise record nothing), and
            ``'verbose'`` (print progress).
        device: Device on which the network is trained.

    Returns:
        Dict with ``'test_accs'`` (fractions in [0, 1]), ``'train_accs'``
        (percentages in [0, 100] -- note the different scale),
        ``'train_losses'`` and ``'test_losses'``. If ``return_model`` it also
        has ``'model'`` (a CPU copy) and ``'epoch'`` (number of epochs the
        returned model was trained for; 0 is the untrained model). If
        ``return_M_interval > 0`` it also has ``'Ms'``, a list of
        ``(epochs_completed, W.T @ W)`` tuples.
    """
    ## Given hyperparameters in hyperparams, trains a neural network on the task
    num_epochs = hyperparams['num_epochs']
    weight_decay = hyperparams['weight_decay']
    hidden_size = hyperparams['hidden_size']
    lr = hyperparams['lr']
    batch_size = hyperparams['batch_size']
    activation = hyperparams['activation']
    normalize = hyperparams['normalize']
    ## log_hyperparams is a dictionary with information about what to log and return
    return_model = log_hyperparams['return_model']
    return_last = log_hyperparams['return_last']
    return_M_interval = log_hyperparams['return_M_interval']
    verbose = log_hyperparams['verbose']

    ## DATA PREPARATION
    y_train = utils.convert_one_hot(y_train, c)
    y_test = utils.convert_one_hot(y_test, c)

    if normalize:
        # Out-of-place division: does not overwrite the caller's arrays and
        # also works for integer-typed inputs (where `/=` would raise).
        X_train = X_train / np.linalg.norm(X_train, axis=-1).reshape(-1, 1)
        X_test = X_test / np.linalg.norm(X_test, axis=-1).reshape(-1, 1)

    X_train = torch.from_numpy(X_train).float().to(device)
    y_train = torch.from_numpy(y_train).float().to(device)
    X_test = torch.from_numpy(X_test).float().to(device)
    y_test = torch.from_numpy(y_test).float().to(device)

    # Create the model and move it to the device
    input_size = X_train.shape[1]  # Number of features
    num_classes = y_train.shape[1]  # Number of classes (from the one-hot encoded labels)
    model = TwoLayerNN(input_size, hidden_size, num_classes, activation=activation).to(device)

    # Loss function and optimizer
    criterion = torch.nn.MSELoss()  # Use MSE for one-hot encoded classification
    # weight_decay is a required hyperparameter; pass it on so it takes effect.
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_epoch = -1
    best_acc = -1
    best_model = None
    Ms = []
    test_accs = []
    train_accs = []
    train_losses = []
    test_losses = []
    record_Ms = return_M_interval > 0  # 0 would otherwise cause a modulo-by-zero

    def evaluate_and_log():
        """Evaluate on train and test, append to the metric lists, return test accuracy."""
        train_loss, train_correct = _loss_and_num_correct(model, X_train, y_train, criterion)
        train_accuracy = 100 * train_correct / X_train.size(0)
        if verbose:
            print(f'Training Accuracy: {train_accuracy:.2f}%')
        train_accs.append(train_accuracy)
        train_losses.append(train_loss)

        test_loss, test_correct = _loss_and_num_correct(model, X_test, y_test, criterion)
        test_accuracy = test_correct / X_test.size(0)
        if verbose:
            print(f'Test Accuracy: {100*test_accuracy:.2f}%')
        test_accs.append(test_accuracy)
        test_losses.append(test_loss)
        return test_accuracy

    ## EVALUATE THE UNTRAINED MODEL (counts as "epoch 0")
    test_accuracy = evaluate_and_log()
    if not return_last and test_accuracy >= best_acc:
        best_acc = test_accuracy
        best_epoch = 0
        if return_model:
            best_model = _cpu_snapshot(model)

    ## TRAINING LOOP
    for epoch in range(num_epochs):
        if record_Ms and epoch % return_M_interval == 0:
            Ms.append((epoch, _first_layer_gram(model)))
        model.train()
        permutation = torch.randperm(X_train.size(0))

        for i in range(0, X_train.size(0), batch_size):
            indices = permutation[i:i + batch_size]
            batch_x, batch_y = X_train[indices], y_train[indices]

            # Forward pass and MSE loss against the one-hot targets
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if verbose:
            # Reports the loss of the last mini-batch of the epoch only.
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

        test_accuracy = evaluate_and_log()
        # `>=` means ties are resolved in favour of the later epoch.
        if not return_last and test_accuracy >= best_acc:
            best_acc = test_accuracy
            best_epoch = epoch + 1
            if return_model:
                best_model = _cpu_snapshot(model)

    if record_Ms and num_epochs % return_M_interval == 0:
        Ms.append((num_epochs, _first_layer_gram(model)))
    if return_last:
        # Same "epochs completed" convention as best_epoch above and as Ms.
        best_epoch = num_epochs
        if return_model:
            best_model = _cpu_snapshot(model)

    log_dict = {'test_accs' : test_accs, 'train_accs' : train_accs, 'train_losses' : train_losses, 'test_losses' : test_losses}
    if return_model:
        log_dict['model'] = best_model
        log_dict['epoch'] = best_epoch
    if record_Ms:
        log_dict['Ms'] = Ms
    return log_dict
