import os, sys
import torch.nn as nn
import torch
from datetime import datetime
import numpy as np

def random_invert(image, p, rng=None):
    """
    Inverts a random fraction of the values in an image.

    Each selected array element ``v`` is replaced by ``255 - v``, so the
    function is intended for 8-bit images. Selection is made over the
    flattened array, so for multi-channel images individual channel values
    are inverted independently rather than whole pixels.

    Parameters:
    image (PIL.Image or numpy.ndarray): The image to be inverted. It is
        converted with ``np.array`` (a copy), so the input is not modified.
    p (float): The fraction of values to invert, in the range [0, 1]
        (e.g. 0.1 inverts 10% of the values).
    rng (numpy.random.Generator or numpy.random.RandomState, optional):
        Source of randomness. Defaults to the global ``np.random`` state,
        so ``np.random.seed`` keeps controlling the result.

    Returns:
    numpy.ndarray: A new array with the same shape as ``np.array(image)``
    in which the selected values are inverted.

    Raises:
    ValueError: If ``p`` is not within [0, 1] (NaN included).
    """
    # Reject percentages like 50 (and NaN) up front with a clear message
    # instead of an opaque numpy sampling error.
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"p must be a fraction in [0, 1], got {p!r}")

    pixels = np.array(image)
    flat = pixels.flatten()  # flatten() returns a copy we can edit freely
    num_values = flat.size

    # Sampling without replacement guarantees exactly `num_to_invert`
    # distinct values are inverted.
    sampler = np.random if rng is None else rng
    num_to_invert = int(p * num_values)
    indices = sampler.choice(num_values, num_to_invert, replace=False)

    flat[indices] = 255 - flat[indices]
    return flat.reshape(pixels.shape)


def create_folder_by_date(base_folder):
    """
    Creates (if needed) a folder named after the current local date inside base_folder.

    The folder name uses the ``YYYY-MM-DD`` format, derived from
    ``datetime.now()`` (local, naive time). Missing intermediate
    directories, including ``base_folder`` itself, are created as well.
    Calling the function again on the same day is a no-op that returns
    the same path.

    Parameters:
    base_folder (str): The path to the base folder where the new folder will be created.

    Returns:
    str: The path to the (possibly pre-existing) dated folder.
    """
    current_date = datetime.now().strftime('%Y-%m-%d')
    new_folder_path = os.path.join(base_folder, current_date)

    # exist_ok=True replaces the exists()/makedirs() pair, which could race
    # with another process creating the folder between the check and the call.
    os.makedirs(new_folder_path, exist_ok=True)

    return new_folder_path


class PrintToFile:
    """
    A class used to tee the standard output to a file.

    While active, every string written to ``sys.stdout`` is appended to
    ``filename`` (UTF-8) and also forwarded to the stream that was
    ``sys.stdout`` when redirection started. It can be driven with
    ``start()``/``stop()`` or used as a context manager::

        with PrintToFile("log.txt"):
            print("shown on screen and appended to log.txt")

    Attributes:
    filename (str): The name of the file where the output will be written.
    stdout: The stream output is forwarded to and restored by ``stop()``.

    Methods:
    start(): Starts redirecting the output to the file.
    stop(): Stops redirecting the output to the file.
    write(text): Writes the specified text to the file and the standard output.
    flush(): Flushes the forwarded standard output stream.
    """
    def __init__(self, filename):
        self.stdout = sys.stdout
        self.filename = filename

    def start(self):
        """Starts redirecting the output to the file."""
        # Capture the stream that is current *now*, not at construction time,
        # so stop() restores what was actually active. Skip the capture when
        # we are already installed: storing ourselves would make write()
        # recurse forever.
        if sys.stdout is not self:
            self.stdout = sys.stdout
        sys.stdout = self

    def stop(self):
        """Stops redirecting the output to the file."""
        sys.stdout = self.stdout

    def write(self, text):
        """
        Writes the specified text to the file and the standard output.

        Parameters:
        text (str): The text to be written.
        """
        # Reopening per call means each write is flushed to the file as soon
        # as the `with` block closes it, at the cost of an open() per write.
        with open(self.filename, 'a', encoding='utf-8') as file:
            file.write(text)
        self.stdout.write(text)

    def flush(self):
        """Flushes the forwarded stream (the file is flushed on every write)."""
        self.stdout.flush()

    def __enter__(self):
        """Starts redirection and returns self for use in a ``with`` block."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stops redirection; exceptions raised in the block propagate."""
        self.stop()
        return False

class CNNModel200k(nn.Module):
    """
    A class used to represent a Convolutional Neural Network (CNN) with approximately 200k parameters.

    The fully connected head is sized for 28x28 inputs: two 2x2 max-poolings
    reduce 28x28 to 7x7 with 32 channels (7 * 7 * 32 features per sample).
    With the default arguments the model has 206,922 trainable parameters.

    Parameters:
    in_channels (int): Number of input image channels (default 1).
    num_classes (int): Number of output logits (default 10).

    Attributes:
    conv1 (nn.Conv2d): The first convolutional layer.
    conv2 (nn.Conv2d): The second convolutional layer.
    relu (nn.ReLU): The ReLU activation function.
    maxpool (nn.MaxPool2d): The max pooling layer.
    dropout (nn.Dropout): The dropout layer.
    fc1 (nn.Linear): The first fully connected layer.
    fc2 (nn.Linear): The second fully connected layer.

    Methods:
    forward(x): Defines the forward pass of the CNN.
    """
    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(7 * 7 * 32, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """
        Defines the forward pass of the CNN.

        Parameters:
        x (torch.Tensor): Input of shape (N, C, 28, 28), or (C, 28, 28) for a
            single unbatched image.

        Returns:
        torch.Tensor: Logits of shape (N, num_classes); (1, num_classes) for
        unbatched input.

        Raises:
        RuntimeError: If the spatial size does not reduce to 7x7 after pooling.
        """
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        # Flatten (C, H, W) per sample. The old view(-1, 7*7*32) silently split
        # a larger sample into several fake "batch" rows whenever its size
        # divided evenly; check the feature count explicitly instead.
        x = x.flatten(-3)
        expected = self.fc1.in_features
        if x.shape[-1] != expected:
            raise RuntimeError(
                f"expected {expected} features per sample (28x28 input), got {x.shape[-1]}"
            )
        x = x.reshape(-1, expected)  # unbatched input becomes a batch of one
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


def evaluate_model(model, criterion, test_loader, device):
    """
    Evaluates the performance of a model on a test dataset.

    The model is switched to eval mode (and left there) and evaluated
    without gradient tracking.

    Parameters:
    model (nn.Module): The model to be evaluated.
    criterion (nn.Module): The loss function. Assumed to use the default
        ``reduction='mean'`` -- TODO confirm for custom criteria.
    test_loader (iterable): Yields ``(images, labels)`` batches, e.g. a
        torch.utils.data.DataLoader; ``len()`` is not required.
    device (torch.device): The device where the tensors will be allocated.

    Returns:
    float: The accuracy of the model on the test dataset, in percent.
    float: The per-sample average loss of the model on the test dataset.

    Raises:
    ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            batch_size = labels.size(0)
            correct += (predicted == labels).sum().item()
            total += batch_size
            # Scale each batch's mean loss back to a sum so a smaller final
            # batch is not over-weighted in the average.
            loss_sum += criterion(outputs, labels).item() * batch_size

    if total == 0:
        raise ValueError("test_loader yielded no samples to evaluate")

    accuracy = 100 * correct / total
    avg_valid_loss = loss_sum / total
    return accuracy, avg_valid_loss