import argparse
import time
import warnings

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms, models

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn import metrics
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score


def get_best_params(algorithm):
    """Return the pre-selected hyperparameters for a supported algorithm.

    Args:
        algorithm (str): "Logistic_regression", "SVM" or "MLP".

    Returns:
        dict: A newly built mapping from hyperparameter name to value.

    Raises:
        NotImplementedError: If ``algorithm`` is none of the names above.
    """
    if algorithm == "Logistic_regression":
        return dict(
            C=0.0001,
            max_iter=50,
            multi_class='multinomial',
            solver='lbfgs',
        )
    if algorithm == "SVM":
        return dict(
            C=1,
            max_iter=50,
            kernel='rbf',
            probability=True,
        )
    if algorithm == "MLP":
        return dict(
            hidden_layer_sizes=(256,),
            max_iter=100,
            activation='logistic',
        )
    # Anything else has no stored configuration.
    raise NotImplementedError("Not implemented yet!")


def get_final_model(algorithm, best_params):
    """Build an estimator for ``algorithm`` from the chosen hyperparameters.

    The estimator is only constructed here; it is not fitted.

    Args:
        algorithm (str): "Logistic_regression", "SVM" or "MLP".
        best_params (dict): Hyperparameters for that algorithm. The keys
            read are ``C``/``max_iter`` (logistic regression),
            ``C``/``max_iter``/``kernel`` (SVM) and
            ``hidden_layer_sizes``/``max_iter``/``activation`` (MLP).

    Returns:
        A LogisticRegression, SVC or MLPClassifier instance.

    Raises:
        NotImplementedError: If ``algorithm`` is not supported.
        KeyError: If ``best_params`` lacks a key the chosen branch reads.
    """
    if algorithm == "Logistic_regression":
        # The multinomial/lbfgs setup is fixed; only C and max_iter are tuned.
        tuned = {key: best_params[key] for key in ('C', 'max_iter')}
        return LogisticRegression(
            multi_class='multinomial',
            solver='lbfgs',
            **tuned
        )

    if algorithm == "SVM":
        tuned = {key: best_params[key] for key in ('C', 'max_iter', 'kernel')}
        # Probability estimates are always enabled for the final SVM.
        return SVC(probability=True, **tuned)

    if algorithm == "MLP":
        layer_sizes = best_params['hidden_layer_sizes']
        print(": ", layer_sizes)
        return MLPClassifier(
            hidden_layer_sizes=layer_sizes,
            max_iter=best_params['max_iter'],
            activation=best_params['activation']
        )

    raise NotImplementedError("Not implemented yet!")


def perform_cross_validation(X_train, y_train, algorithm):
    """Run the grid search that belongs to ``algorithm``.

    Args:
        X_train: Training features, forwarded unchanged to the search routine.
        y_train: Training labels, forwarded unchanged to the search routine.
        algorithm (str): "Logistic_regression", "SVM" or "MLP".

    Returns:
        dict: The best hyperparameters reported by the selected routine.

    Raises:
        NotImplementedError: If ``algorithm`` is not supported.
    """
    if algorithm == "Logistic_regression":
        return cross_validation_logistic_regression(X_train, y_train)
    if algorithm == "SVM":
        return cross_validation_SVM(X_train, y_train)
    if algorithm == "MLP":
        return cross_validation_MLP(X_train, y_train)
    raise NotImplementedError("Not implemented yet!")


def cross_validation_MLP(X_train, y_train):
    """Grid-search MLPClassifier hyperparameters with 5-fold cross-validation.

    Every combination of hidden-layer layout, iteration cap and activation
    is trained on the training part of each fold and scored by accuracy on
    the held-out part. The combination with the highest mean accuracy is
    returned; on ties the combination evaluated first is kept.

    Args:
        X_train: Training features. Must support indexing with an integer
            index array (e.g. a NumPy array).
        y_train: Training labels, indexable the same way as ``X_train``.

    Returns:
        dict: Best combination with keys ``hidden_layer_sizes``,
        ``max_iter`` and ``activation``.
    """
    param_grid = {
        # One tuple element per hidden layer. The trailing comma in (200,)
        # matters: a bare (200) is the int 200, not a one-layer tuple.
        'hidden_layer_sizes': [(100,), (100, 100), (100, 100, 100), (200,)],
        'activation': ['relu', 'logistic'],
        'max_iter': [100, 200, 300],
    }

    best_score = -1
    best_params = {}
    # Seeded splitter: every parameter combination sees the same folds.
    cv = KFold(n_splits=5, shuffle=True, random_state=42)

    # NOTE(review): MLPClassifier is created without random_state, so the
    # scores (and possibly the winner) can vary between runs.
    for hidd_layer_val in param_grid['hidden_layer_sizes']:
        for max_iter_val in param_grid['max_iter']:
            for act_val in param_grid['activation']:
                # Accuracy on each held-out fold for this combination.
                fold_scores = []
                for train_index, val_index in cv.split(X_train):
                    X_train_fold, X_val_fold = X_train[train_index], X_train[val_index]
                    y_train_fold, y_val_fold = y_train[train_index], y_train[val_index]

                    model = MLPClassifier(
                        hidden_layer_sizes=hidd_layer_val,
                        max_iter=max_iter_val,
                        activation=act_val
                    )
                    model.fit(X_train_fold, y_train_fold)

                    y_pred_val = model.predict(X_val_fold)
                    fold_scores.append(accuracy_score(y_val_fold, y_pred_val))

                mean_score = np.mean(fold_scores)
                # Strict '>' keeps the earliest combination on ties.
                if mean_score > best_score:
                    best_score = mean_score
                    best_params = {
                        'hidden_layer_sizes': hidd_layer_val,
                        'max_iter': max_iter_val,
                        'activation': act_val,
                    }

    print("\n--- Manual Grid Search Results ---")
    print(f"Best hyperparameters: {best_params}")
    print(f"Best cross-validation score: {best_score:.4f}")
    return best_params


def cross_validation_SVM(X_train, y_train):
    """Grid-search SVC hyperparameters with 5-fold cross-validation.

    Each (C, max_iter, kernel) combination is scored by its mean held-out
    accuracy over the folds; the best one is printed and returned. On ties
    the combination evaluated first wins.

    Args:
        X_train: Training features, indexable with an integer index array.
        y_train: Training labels, indexable the same way.

    Returns:
        dict: Best combination with keys ``C``, ``max_iter`` and ``kernel``.
    """
    c_candidates = [0.0001, 0.001,  0.01, 0.1, 1, 10, 100]
    kernel_candidates = ['rbf', 'poly', 'sigmoid']
    iter_candidates = [10, 20, 50, 100]

    # Seeded splitter so every combination is judged on identical folds.
    splitter = KFold(n_splits=5, shuffle=True, random_state=42)

    def mean_fold_accuracy(penalty, iter_cap, kernel_name):
        """Average held-out accuracy of one SVC configuration."""
        accuracies = []
        for fit_idx, holdout_idx in splitter.split(X_train):
            features_fit, features_holdout = X_train[fit_idx], X_train[holdout_idx]
            labels_fit, labels_holdout = y_train[fit_idx], y_train[holdout_idx]

            classifier = SVC(C=penalty, max_iter=iter_cap, kernel=kernel_name)
            classifier.fit(features_fit, labels_fit)

            predicted = classifier.predict(features_holdout)
            accuracies.append(accuracy_score(labels_holdout, predicted))
        return np.mean(accuracies)

    winner = {}
    winner_score = -1
    for penalty in c_candidates:
        for iter_cap in iter_candidates:
            for kernel_name in kernel_candidates:
                candidate_score = mean_fold_accuracy(penalty, iter_cap, kernel_name)
                # Strict comparison: the first of several equal scores stays.
                if candidate_score > winner_score:
                    winner_score = candidate_score
                    winner = {'C': penalty, 'max_iter': iter_cap, 'kernel': kernel_name}

    print("\n--- Manual Grid Search Results ---")
    print(f"Best hyperparameters: {winner}")
    print(f"Best cross-validation score: {winner_score:.4f}")
    return winner


def cross_validation_logistic_regression(X_train, y_train):
    """Grid-search logistic-regression hyperparameters with 5-fold CV.

    A multinomial LogisticRegression with the lbfgs solver is trained for
    every (C, max_iter) pair; pairs are ranked by mean held-out accuracy.
    On ties the pair evaluated first is kept.

    Args:
        X_train: Training features, indexable with an integer index array.
        y_train: Training labels, indexable the same way.

    Returns:
        dict: Best pair with keys ``C`` and ``max_iter``.
    """
    c_candidates = [0.1, 1, 10, 100]
    iter_candidates = [100, 200]

    # Seeded splitter so every pair is evaluated on the same folds.
    splitter = KFold(n_splits=5, shuffle=True, random_state=42)

    winner = {}
    winner_score = -1
    for penalty in c_candidates:
        for iter_cap in iter_candidates:
            accuracies = []
            for fit_idx, holdout_idx in splitter.split(X_train):
                features_fit, features_holdout = X_train[fit_idx], X_train[holdout_idx]
                labels_fit, labels_holdout = y_train[fit_idx], y_train[holdout_idx]

                classifier = LogisticRegression(
                    C=penalty,
                    max_iter=iter_cap,
                    multi_class='multinomial',
                    solver='lbfgs'
                )
                classifier.fit(features_fit, labels_fit)

                predicted = classifier.predict(features_holdout)
                accuracies.append(accuracy_score(labels_holdout, predicted))

            candidate_score = np.mean(accuracies)
            # Strict comparison: the first of several equal scores stays.
            if candidate_score > winner_score:
                winner_score = candidate_score
                winner = {'C': penalty, 'max_iter': iter_cap}

    print("\n--- Manual Grid Search Results ---")
    print(f"Best hyperparameters: {winner}")
    print(f"Best cross-validation score: {winner_score:.4f}")

    # Only the winning parameters are returned; no final model is built here.
    return winner


def write_metrics_to_file(file_path, args, argmin_accuracy, argmax_accuracy, best_params):
    """Append a human-readable summary of one experiment to a text file.

    Args:
        file_path (str): File to append to; it is created if missing.
        args (argparse.Namespace): Run configuration. The attributes read
            are ``dataset``, ``noise_mode``, ``symmetric_flip_prob``,
            ``feature_type`` and ``encoder_name``.
        argmin_accuracy (float): Test accuracy with argmin prediction.
        argmax_accuracy (float): Test accuracy with argmax prediction.
        best_params: Selected hyperparameters, written via their ``str`` form.
    """
    def report_chunks():
        # Chunks are produced lazily, in output order, so each piece is
        # formatted only when it is about to be written.
        yield "\n"
        yield f"================ {args.dataset} with {args.noise_mode} (symmetric flip prob = {args.symmetric_flip_prob}) ========================\n"
        yield f"Linear model + feature by {args.feature_type}, encoder: {args.encoder_name}\n"
        yield f"Best params: {best_params}"
        yield "\n"
        yield f"Accuracy on test set with argmin prediction: {argmin_accuracy:.4f}\n"
        yield f"Accuracy on test set with argmax prediction: {argmax_accuracy:.4f}\n"
        yield "\n"

    with open(file_path, 'a') as report:
        report.writelines(report_chunks())


def generate_symmetric_noise_matrix(num_classes, noise_rate):
    """Build the transition matrix for symmetric (uniform) label noise.

    Each diagonal entry is ``1 - noise_rate``; the remaining ``noise_rate``
    is shared equally between the other ``num_classes - 1`` entries of the
    row. With a single class the off-diagonal share is 0.

    Args:
        num_classes (int): Number of classes.
        noise_rate (float): Probability that a label flips to some other
            class. Must lie in [0, 1].

    Returns:
        numpy.ndarray: Float matrix of shape (num_classes, num_classes).

    Raises:
        ValueError: If ``noise_rate`` is outside [0, 1].
    """
    if not (0 <= noise_rate <= 1):
        raise ValueError("Noise rate must be between 0 and 1.")

    keep_prob = 1 - noise_rate
    flip_prob = noise_rate / (num_classes - 1) if num_classes > 1 else 0

    # Start with the flip probability everywhere, then overwrite the
    # diagonal. dtype is forced to float so an integer 0 fill cannot
    # produce an integer matrix.
    transition = np.full((num_classes, num_classes), flip_prob, dtype=float)
    np.fill_diagonal(transition, keep_prob)
    return transition



def generate_anti_diag_noise_matrix(num_classes, noise_rate):
    """Return a hard-coded 10-class transition matrix for heavy label noise.

    Tables exist only for ``num_classes == 10`` and ``noise_rate`` in
    {0.99, 0.97, 0.95, 0.93, 0.91}. In every table the diagonal entry is
    ``1 - noise_rate``; the off-diagonal entries sit near ``noise_rate / 9``,
    with a few entries per row shifted up or down.

    Args:
        num_classes (int): Number of classes; only 10 is supported.
        noise_rate (float): Total flip probability. It is compared with
            ``==``, so it must equal one of the supported values exactly.

    Returns:
        numpy.ndarray: Matrix of shape (10, 10).

    Raises:
        ValueError: If ``noise_rate`` is outside [0, 1].
        NotImplementedError: If no table exists for the given arguments.
    """
    if not (0 <= noise_rate <= 1):
        raise ValueError("Noise rate must be between 0 and 1.")

    if noise_rate == 0.99 and num_classes == 10:
        matrix = np.array(
            [[0.01, 0.09, 0.13, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11],
            [0.13, 0.01, 0.09, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11],
            [0.11, 0.11, 0.01, 0.08, 0.14, 0.11, 0.11, 0.11, 0.11, 0.11],
            [0.11, 0.11, 0.08, 0.01, 0.14, 0.11, 0.11, 0.11, 0.11, 0.11],
            [0.11, 0.11, 0.11, 0.11, 0.01, 0.13, 0.09, 0.11, 0.11, 0.11],
            [0.11, 0.11, 0.11, 0.11, 0.11, 0.01, 0.09, 0.13, 0.11, 0.11],
            [0.11, 0.11, 0.11, 0.11, 0.11, 0.13, 0.01, 0.09, 0.11, 0.11],
            [0.11, 0.11, 0.11, 0.11, 0.11, 0.09, 0.13, 0.01, 0.11, 0.11],
            [0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.07, 0.01, 0.15],
            [0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.07, 0.15, 0.01]]
        )
    elif noise_rate == 0.97 and num_classes == 10:
        matrix = np.array(
            [[0.03, 0.117778, 0.097778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778],
            [0.117778, 0.03, 0.097778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778],
            [0.107778, 0.087778, 0.03, 0.127778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778],
            [0.107778, 0.107778, 0.097778, 0.03, 0.117778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778],
            [0.107778, 0.107778, 0.107778, 0.097778, 0.03, 0.117778, 0.107778, 0.107778, 0.107778, 0.107778],
            [0.107778, 0.107778, 0.107778, 0.107778, 0.087778, 0.03, 0.127778, 0.107778, 0.107778, 0.107778],
            [0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.127778, 0.03, 0.087778, 0.107778, 0.107778],
            [0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.087778, 0.03, 0.127778, 0.107778],
            [0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.087778, 0.03, 0.127778],
            [0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.107778, 0.127778, 0.087778, 0.03]]
        )
    elif noise_rate == 0.95 and num_classes == 10:
        matrix = np.array(
            [[0.05, 0.115556, 0.095556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556],
            [0.095556, 0.05, 0.115556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556],
            [0.105556, 0.105556, 0.05, 0.125556, 0.085556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556],
            [0.105556, 0.105556, 0.125556, 0.05, 0.085556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556],
            [0.105556, 0.105556, 0.105556, 0.105556, 0.05, 0.135556, 0.075556, 0.105556, 0.105556, 0.105556],
            [0.105556, 0.105556, 0.105556, 0.105556, 0.075556, 0.05, 0.135556, 0.105556, 0.105556, 0.105556],
            [0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.05, 0.145556, 0.065556, 0.105556],
            [0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.065556, 0.145556, 0.05, 0.105556, 0.105556],
            [0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.095556, 0.05, 0.115556],
            [0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.105556, 0.115556, 0.095556, 0.05]]
        )
    elif noise_rate == 0.93 and num_classes == 10:
        # The fourth row uses 0.093333 for its lowered entries, like the
        # other rows of this table; with 0.903333 the row summed to ~4.1.
        matrix = np.array(
            [[0.07, 0.113333, 0.093333, 0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.103333],
            [0.123333, 0.07, 0.093333, 0.093333, 0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.103333],
            [0.103333, 0.133333, 0.07, 0.093333, 0.093333, 0.093333, 0.103333, 0.103333, 0.103333, 0.103333],
            [0.103333, 0.103333, 0.143333, 0.07, 0.093333, 0.093333, 0.093333, 0.093333, 0.103333, 0.103333],
            [0.103333, 0.103333, 0.103333, 0.113333, 0.07, 0.093333, 0.103333, 0.103333, 0.103333, 0.103333],
            [0.103333, 0.103333, 0.103333, 0.103333, 0.113333, 0.07, 0.093333, 0.103333, 0.103333, 0.103333],
            [0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.113333, 0.07, 0.093333, 0.103333, 0.103333],
            [0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.123333, 0.07, 0.093333, 0.093333],
            [0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.093333, 0.07, 0.113333],
            [0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.103333, 0.113333, 0.093333, 0.07]]
        )
    elif noise_rate == 0.91 and num_classes == 10:
        # NOTE(review): every row of this table sums to about 0.998 rather
        # than 1 — confirm the values or normalize the rows before use.
        matrix = np.array(
            [[0.09, 0.111111, 0.095111, 0.095111, 0.101111, 0.101111, 0.101111, 0.101111, 0.101111, 0.101111],
            [0.111111, 0.09, 0.095111, 0.095111, 0.101111, 0.101111, 0.101111, 0.101111, 0.101111, 0.101111],
            [0.101111, 0.111111, 0.09, 0.095111, 0.095111, 0.101111, 0.101111, 0.101111, 0.101111, 0.101111],
            [0.101111, 0.101111, 0.111111, 0.09, 0.095111, 0.095111, 0.101111, 0.101111, 0.101111, 0.101111],
            [0.101111, 0.101111, 0.101111, 0.111111, 0.09, 0.095111, 0.095111, 0.101111, 0.101111, 0.101111],
            [0.101111, 0.101111, 0.101111, 0.101111, 0.111111, 0.09, 0.095111, 0.095111, 0.101111, 0.101111],
            [0.101111, 0.101111, 0.101111, 0.101111, 0.101111, 0.111111, 0.09, 0.095111, 0.095111, 0.101111],
            [0.101111, 0.101111, 0.101111, 0.101111, 0.101111, 0.101111, 0.111111, 0.09, 0.095111, 0.095111],
            [0.101111, 0.101111, 0.101111, 0.101111, 0.101111, 0.101111, 0.095111, 0.111111, 0.09, 0.095111],
            [0.101111, 0.101111, 0.101111, 0.101111, 0.101111, 0.101111, 0.111111, 0.095111, 0.095111, 0.09]]
        )
    else:
        raise NotImplementedError("not implemented T yet")
    return matrix

def normalize_rows_to_one(matrix: np.ndarray) -> np.ndarray:
    """Scale a matrix so that its entries sum to 1 along axis 1.

    Slices whose sum is exactly zero are divided by 1 and therefore stay
    unchanged. The input array is not modified. The axis-1 sums of the
    result are printed as a diagnostic.

    Args:
        matrix (np.ndarray): Array with at least two dimensions.

    Returns:
        np.ndarray: The normalized array, same shape as ``matrix``.

    Raises:
        ValueError: If ``matrix`` is not a NumPy array or has fewer than
            two dimensions.
    """
    has_rows = isinstance(matrix, np.ndarray) and matrix.ndim >= 2
    if not has_rows:
        raise ValueError("Input must be a 2D or higher NumPy array.")

    # keepdims leaves a length-1 axis so the totals broadcast in the division.
    totals = matrix.sum(axis=1, keepdims=True)

    # Swap zero totals for one (same dtype as the totals) to avoid 0/0.
    divisors = np.where(totals == 0, np.ones_like(totals), totals)

    scaled = matrix / divisors
    print("row_sums: ", scaled.sum(axis=1))
    return scaled


def is_valid_transition_matrix(matrix: np.ndarray, tolerance: float = 1e-5) -> bool:
    """Check whether ``matrix`` is a valid (row-stochastic) transition matrix.

    A matrix is accepted when:
    1. All elements are non-negative (>= 0).
    2. Every row sums to 1 within ``tolerance``.

    A message describing the first failed check is printed, and the row
    sums are always printed as a diagnostic.

    Args:
        matrix (np.ndarray): The matrix to check.
        tolerance (float): Largest allowed absolute difference between a row
            sum and 1. The default of 1e-5 keeps the threshold this check
            previously got implicitly from ``np.allclose``'s default ``rtol``.

    Returns:
        bool: True if the matrix is a valid transition matrix, False otherwise.
    """
    if not isinstance(matrix, np.ndarray):
        print("Error: Input is not a NumPy array.")
        return False

    # Row sums need an axis 1; report failure instead of raising AxisError.
    if matrix.ndim < 2:
        print("Error: Input must be at least 2-dimensional.")
        return False

    if np.any(matrix < 0):
        print("Error: Matrix contains negative values.")
        return False

    row_sums = np.sum(matrix, axis=1)
    print("row_sums: ", row_sums)
    # rtol=0 makes ``tolerance`` the one explicit absolute threshold.
    if not np.allclose(row_sums, 1.0, rtol=0, atol=tolerance):
        print(f"Error: Not all rows sum to 1.0. Row sums are: {row_sums}")
        return False

    return True

# Helper for parsing boolean-like command-line argument values
def str2bool(v):
    """Interpret a boolean-like value.

    Real booleans are returned unchanged. Otherwise ``v.lower()`` is matched
    against the accepted spellings of true ('yes', 'true', 't', 'y', '1')
    and false ('no', 'false', 'f', 'n', '0').

    Raises:
        argparse.ArgumentTypeError: If the value matches neither list.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

## load noisy training labels
#if args.dataset == 'cifar10':
#    noise_file_path = './data/CIFAR-10_human.pt'
#elif args.dataset == 'cifar100':
#    noise_file_path = './data/CIFAR-100_human.pt'
#noise_file = torch.load(noise_file_path)
#if args.noise_mode == 'clean_label':
#    y_train = noise_file['clean_label']
#elif args.noise_mode == 'aggre_label':
#    y_train = noise_file['aggre_label']
#elif args.noise_mode == 'rand_1_label':
#    y_train = noise_file['random_label1']
#elif args.noise_mode == 'rand_2_label':
#    y_train = noise_file['random_label2']
#elif args.noise_mode == 'rand_3_label':
#    y_train = noise_file['random_label3']
#elif args.noise_mode == 'worst_label':
#    if args.dataset == 'cifar10':
#        y_train = noise_file['worse_label']
#    elif args.dataset == 'cifar100':
#        y_train = noise_file['noisy_label']

