import os
import numpy as np
from copy import deepcopy

import torch
from torch.utils.data import DataLoader, TensorDataset


def train_multilayer_network(model, X_train, y_train, X_test, y_test, criterion, optimizer, epochs, batch_size=32,
                             pre_train=False):
    """
    Train a classifier on mini-batches and, after every epoch, score it on a test set.

    Batches are moved to the device that holds the model's parameters, so the
    function runs unchanged on CPU or GPU as long as the caller has placed the
    model where it should run.

    Args:
        model (torch.nn.Module): Network to train. Its output is reduced with
            ``argmax(1)`` during evaluation, so it must produce one score per
            class along dimension 1.
        X_train (torch.Tensor): Training inputs; first dimension indexes samples.
        y_train (torch.Tensor): Training targets, passed to ``criterion`` as-is.
        X_test (torch.Tensor): Test inputs; first dimension indexes samples.
        y_test (torch.Tensor): Test targets. They are compared directly with the
            argmax predictions, so they are expected to be integer class labels.
        criterion (Callable): Loss called as ``criterion(outputs, targets)``.
        optimizer (torch.optim.Optimizer): Optimizer already bound to the
            model's parameters.
        epochs (int): Number of passes over the training data.
        batch_size (int, optional): Batch size for both loaders. Default is 32.
        pre_train (bool, optional): If ``True``, skip the per-epoch evaluation
            and progress print entirely. Default is ``False``.

    Returns:
        Tuple ``(results, model)``:
        - results (list of lists): One ``[TP, TN, FP, FN, N_acc]`` entry per
          epoch when ``pre_train`` is ``False``; empty when ``pre_train`` is
          ``True``. TP/TN/FP/FN only count samples whose prediction and label
          are both in {0, 1}, treating class 1 as positive; ``N_acc`` counts
          every correctly predicted test sample.
        - model (torch.nn.Module): The same model object, trained in place.

        If a training loss becomes NaN, training stops immediately and
        ``([], model)`` is returned.
    """
    # Training batches are reshuffled every epoch; the test order stays fixed.
    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=batch_size, shuffle=False)

    # Follow the model's own device instead of assuming CUDA. A model without
    # parameters gives no hint, so keep the historical CUDA target for it.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    n_test = len(test_loader.dataset)
    results = []
    for epoch in range(epochs):
        # Loss of the last processed batch, reported in the progress line.
        last_loss = float("nan")
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)

            # A NaN loss cannot be recovered from: abort and signal it with empty results.
            if torch.isnan(loss):
                print(f'nan, epoch: {epoch}')
                return [], model
            loss.backward()
            optimizer.step()
            last_loss = loss.item()

        if pre_train:
            continue

        TP = 0
        TN = 0
        FP = 0
        FN = 0
        N_acc = 0

        # Evaluate in eval mode (dropout off, batch-norm frozen), then restore
        # whatever mode the caller had selected so training is unaffected.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for batch_x, batch_y in test_loader:
                    batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                    predictions = model(batch_x).argmax(1)
                    TP += ((predictions == 1) & (batch_y == 1)).float().sum().item()
                    TN += ((predictions == 0) & (batch_y == 0)).float().sum().item()
                    FP += ((predictions == 1) & (batch_y == 0)).float().sum().item()
                    FN += ((predictions == 0) & (batch_y == 1)).float().sum().item()
                    N_acc += (predictions == batch_y).float().sum().item()
        finally:
            model.train(was_training)

        # Report the accuracy that was just measured (the progress line used
        # to be printed before evaluation and therefore always showed 0).
        accuracy = N_acc / n_test if n_test else 0
        print(f"\repoch: {epoch} / {epochs} Accuracy: {accuracy} Loss: {last_loss}", end='')

        results.append([TP, TN, FP, FN, N_acc])

    return results, model


def train_single_layer_network(model, train_loader, val_loader, criterion, optimizer, data_name, epochs):
    """
    Train a model on one of the supported toy tasks and score it on a validation loader.

    ``data_name`` selects how outputs and labels are prepared before the loss:

    - 'classification', 'circles', 'moons' (binary): sigmoid is applied to the
      model output and labels become float tensors with a trailing unit
      dimension. Predictions are ``output > 0.5``.
    - 'blobs' (multi-class): raw outputs are used and labels are cast to long.
      Predictions are ``argmax`` over dimension 1.
    - 'friedman1' (regression): raw outputs are used and labels become float
      tensors with a trailing unit dimension.

    Because this function adds the trailing label dimension itself
    (``unsqueeze(1)``), the loaders are expected to yield labels of shape
    (N,), not (N, 1).

    Batches are moved to the device that holds the model's parameters.

    Args:
        model (torch.nn.Module): Model to train (updated in place).
        train_loader (Iterable): Yields ``(x, y)`` training batches; must
            support ``len()``.
        val_loader (Iterable): Yields ``(x, y)`` validation batches; must
            support ``len()``.
        criterion (Callable): Loss called as ``criterion(output, target)`` and
            returning a scalar tensor.
        optimizer (torch.optim.Optimizer): Optimizer bound to the model's
            parameters.
        data_name (str): One of 'classification', 'circles', 'moons', 'blobs',
            'friedman1'.
        epochs (int): Number of training epochs.

    Returns:
        Tuple:
        - For the classification tasks: ``(accuracy, model)`` where accuracy is
          correct predictions divided by validation samples seen.
        - For 'friedman1': ``(average_loss, model)`` where average_loss is the
          mean of the per-batch validation losses.

    Raises:
        ValueError: If ``data_name`` is not one of the supported names
            (previously the function trained and then implicitly returned
            ``None``).
    """
    binary_tasks = ('classification', 'circles', 'moons')
    classification_tasks = binary_tasks + ('blobs',)
    if data_name not in classification_tasks and data_name != 'friedman1':
        raise ValueError(f"Unsupported data_name: {data_name!r}")

    # Follow the model's own device instead of assuming CUDA. A model without
    # parameters gives no hint, so keep the historical CUDA target for it.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    def forward_batch(x, y):
        """Run the model on one batch and return (output, target) ready for the criterion."""
        x, y = x.to(device), y.to(device)
        out = model(x)
        if data_name in binary_tasks:
            return torch.sigmoid(out), y.unsqueeze(1).float()
        if data_name == 'blobs':
            return out, y.long()
        return out, y.unsqueeze(1).float()

    def evaluate():
        """Return (summed batch loss, correct predictions, samples seen) over val_loader."""
        loss_sum = 0
        correct = 0
        total = 0
        # Score in eval mode, then restore the caller's mode for further training.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for x, y in val_loader:
                    out, target = forward_batch(x, y)
                    loss_sum += criterion(out, target).item()
                    if data_name in classification_tasks:
                        if data_name in binary_tasks:
                            preds = (out > 0.5).float()
                        else:
                            preds = torch.argmax(out, dim=1)
                        correct += (preds == target).sum().item()
                        total += target.size(0)
        finally:
            model.train(was_training)
        return loss_sum, correct, total

    # Lowest mean validation loss seen so far; only used for the progress line.
    best_loss = np.inf

    for epoch in range(epochs):
        train_loss = 0
        for x, y in train_loader:
            optimizer.zero_grad()
            out, target = forward_batch(x, y)
            loss = criterion(out, target)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
        print(
            f'\rEpoch {epoch + 1}/{epochs}, Best Loss: {best_loss}, Loss: {train_loss / len(train_loader)}',
            end='')

        # Normalise by the number of *validation* batches so the logged best
        # loss is comparable with the per-batch training loss printed above.
        val_loss_sum, _, _ = evaluate()
        val_loss = val_loss_sum / max(len(val_loader), 1)
        if val_loss < best_loss:
            best_loss = val_loss

    # Final pass over the validation data with the trained model.
    loss_sum, correct_predictions, total_samples = evaluate()

    if data_name in classification_tasks:
        return correct_predictions / total_samples, model
    return loss_sum / len(val_loader), model


def train_model(model, train_loader, criterion, optimizer, epoch, epochs, l1_lambda, l2_lambda):
    """
    Train the model for one pass over ``train_loader``, optionally with L1/L2 penalties.

    The penalties are computed only over parameters whose name contains
    ``'darts_weight'``. Batches are moved to the device that holds the model's
    parameters. Per-batch and per-epoch progress is printed to stdout; the
    epoch summary reports the mean *unregularized* batch loss and the fraction
    of samples whose ``argmax(1)`` prediction equals the label.

    Args:
        model (torch.nn.Module): Model to train (put into training mode and
            updated in place).
        train_loader (torch.utils.data.DataLoader): Yields ``(inputs, labels)``
            batches; ``len(train_loader)`` and ``len(train_loader.dataset)``
            are used for the summary.
        criterion (Callable): Loss called as ``criterion(outputs, labels)``.
        optimizer (torch.optim.Optimizer): Optimizer bound to the model's
            parameters.
        epoch (int): Current epoch number (used for logging only).
        epochs (int): Total number of epochs (used for logging only).
        l1_lambda (float): Weight of the L1 penalty. 0 disables it.
        l2_lambda (float): Weight of the L2 penalty. 0 disables it.

    Returns:
        torch.nn.Module: The same ``model`` object after the epoch.
    """
    # Follow the model's own device instead of assuming CUDA. A model without
    # parameters gives no hint, so keep the historical CUDA target for it.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    # Start from clean gradients in case the caller left some accumulated.
    optimizer.zero_grad()
    correct = 0
    total_loss = 0

    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        correct += (outputs.argmax(1) == labels).float().sum().item()
        # Logged before the penalty is added, so the summary shows the data loss only.
        total_loss += loss.item()

        if l1_lambda or l2_lambda:
            # Collect the penalized parameters once and reuse them for both norms.
            penalized = [p for name, p in model.named_parameters() if 'darts_weight' in name]
            l1_norm = sum(p.abs().sum() for p in penalized)
            l2_norm = sum(p.pow(2).sum() for p in penalized)
            loss = loss + l2_lambda * l2_norm + l1_lambda * l1_norm

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f"\repoch: {epoch}/{epochs} batch: {i} / {len(train_loader)} Loss: {loss.item()}", end="")
    print(f"\repoch: {epoch}/{epochs} Loss: {total_loss / len(train_loader)} Accuracy: {correct / len(train_loader.dataset)}")
    return model


def val_model(model, val_loader, criterion):
    """
    Evaluate the model on ``val_loader`` without computing gradients.

    The model is switched to evaluation mode for the duration of the pass (so
    layers such as dropout and batch-norm behave deterministically) and its
    previous training/eval mode is restored afterwards. Batches are moved to
    the device that holds the model's parameters. Per-batch progress is
    printed to stdout.

    Args:
        model (torch.nn.Module): Model to evaluate. Predictions are taken as
            ``argmax(1)`` of its output.
        val_loader (torch.utils.data.DataLoader): Yields ``(inputs, labels)``
            batches; ``len(val_loader)`` and ``len(val_loader.dataset)`` are
            used for averaging.
        criterion (Callable): Loss called as ``criterion(outputs, labels)``.

    Returns:
        Tuple ``(loss, accuracy)``: the mean of the per-batch losses and the
        number of correct predictions divided by the dataset size.
    """
    # Follow the model's own device instead of assuming CUDA. A model without
    # parameters gives no hint, so keep the historical CUDA target for it.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    loss_sum = 0
    correct_total = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i, (images, labels) in enumerate(val_loader):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss_sum += loss.item()
                # Number of correct predictions in this batch (a count, not a ratio).
                accuracy = (outputs.argmax(1) == labels).float().sum().item()
                correct_total += accuracy
                print(f"\rVal: {i}/{len(val_loader)} Loss: {loss.item()} Accuracy: {accuracy}", end="")
    finally:
        # Leave the model in the mode the caller had selected.
        model.train(was_training)

    return loss_sum / len(val_loader), correct_total / len(val_loader.dataset)


def basic_train(model, train_loader, val_loader, criterion, optimizer, epochs, model_file=None, l1_lambda=0, l2_lambda=0):
    """
    Run the train/validate loop for ``epochs`` epochs, checkpointing on accuracy improvement.

    Each epoch calls ``train_model`` followed by ``val_model``. Whenever the
    validation accuracy exceeds the best seen so far, the model's
    ``state_dict`` is saved (if ``model_file`` is given) and the best
    loss/accuracy pair is updated.

    Args:
        model (torch.nn.Module): Model to train. It is moved to CUDA when CUDA
            is available.
        train_loader: Training data, forwarded to ``train_model``.
        val_loader: Validation data, forwarded to ``val_model``.
        criterion: Loss function, forwarded to both helpers.
        optimizer (torch.optim.Optimizer): Optimizer, forwarded to ``train_model``.
        epochs (int): Number of epochs to run.
        model_file (sequence, optional): ``(directory, filename)`` pair naming
            the checkpoint file. The directory is created if it does not
            exist. ``None`` (default) disables checkpointing.
        l1_lambda (float, optional): L1 weight forwarded to ``train_model``.
        l2_lambda (float, optional): L2 weight forwarded to ``train_model``.

    Returns:
        Tuple ``(model, best_loss, best_accuracy)``. ``model`` holds the
        weights after the *last* epoch (not necessarily the checkpointed
        best). ``best_loss`` is the validation loss of the epoch with the best
        accuracy; both values stay 0 if validation accuracy never rises above 0.
    """
    # Prefer the GPU as before, but do not crash here on CPU-only hosts.
    if torch.cuda.is_available():
        model = model.cuda()

    checkpoint_path = None
    if model_file is not None:
        checkpoint_dir = model_file[0]
        # torch.save does not create missing directories, so make sure the
        # target exists before spending time on training.
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, model_file[1])

    best_loss, best_accuracy = 0, 0
    for epoch in range(epochs):
        model = train_model(model, train_loader, criterion, optimizer, epoch, epochs, l1_lambda, l2_lambda)
        loss, accuracy = val_model(model, val_loader, criterion)
        print(f"\rEpoch: {epoch}/{epochs} Loss: {loss} Accuracy: {accuracy}")
        if accuracy > best_accuracy:
            if checkpoint_path is not None:
                torch.save(model.state_dict(), checkpoint_path)
            best_loss, best_accuracy = loss, accuracy

    return model, best_loss, best_accuracy


def train_supernet_body(model, train_loader, test_loader, criterion, model_dir, lr, round_num=30):
    """
    Train the model in three consecutive stages, each with its own checkpoint file.

    Every stage builds a fresh Adam optimizer over the model's parameters with
    a learning rate of ``lr / 10`` and runs ``basic_train`` for ``round_num``
    epochs, using ``test_loader`` as the validation data. The stages write
    their checkpoints to 'final1.pth', 'final2.pth' and 'final3.pth' inside
    ``model_dir``. Nothing is returned.

    Args:
        model (torch.nn.Module): Model to train.
        train_loader: Training data, forwarded to ``basic_train``.
        test_loader: Validation data, forwarded to ``basic_train``.
        criterion: Loss function, forwarded to ``basic_train``.
        model_dir (str): Directory in which the stage checkpoints are stored.
        lr (float): Base learning rate; each stage uses one tenth of it.
        round_num (int, optional): Epochs per stage. Default is 30.
    """
    for checkpoint_name in ('final1.pth', 'final2.pth', 'final3.pth'):
        # A new optimizer per stage discards Adam's running moment estimates.
        stage_optimizer = torch.optim.Adam(model.parameters(), lr=lr / 10)
        model, _, _ = basic_train(
            model,
            train_loader,
            test_loader,
            criterion,
            stage_optimizer,
            model_file=(model_dir, checkpoint_name),
            epochs=round_num,
        )
