import os
import sys
import time
import pandas as pd
import argparse
import numpy as np
import torch
import csv
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder
import pandas as pd
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import GroupShuffleSplit
from utils import (
    save_plot_with_timestamp,
    sliding_window,
    load_dataset_NIRS,
    train_test_split,
    is_file_empty_or_nonexistent,
)


def _write_results_header_if_needed(csv_file_path):
    """Write the header row of the results CSV if the file is missing or empty.

    Args:
        csv_file_path: Path of the CSV file that collects one row per run.
    """
    if not is_file_empty_or_nonexistent(csv_file_path):
        return
    headers = [
        "Datetime",
        "modality",
        "CV_method",
        "window_overlap",
        "window_size",
        "look_at",
        "offset",
        "num_epochs",
        "batch_size",
        "learning_rate",
        "valence_accuracy",
        "valence_std_dev",
        "arousal_accuracy",
        "arousal_std_dev",
        "total_loss",
        "total_loss_std_dev",
        "baseline_arousal_accuracy",
        "baseline_arousal_accuracy_std_dev",
        "baseline_valence_accuracy",
        "baseline_valence_accuracy_std_dev",
        "cm_path_arousal",
        "cm_path_valence",
    ]
    with open(csv_file_path, "w", newline="") as file:
        csv.writer(file).writerow(headers)


def _save_confusion_matrix(cm, class_names, title, plot_name, output_folder):
    """Draw ``cm`` as an annotated heatmap, save it, and close the figure.

    Args:
        cm: Square integer confusion matrix (rows = true, columns = predicted).
        class_names: Tick labels for both axes, one per row/column of ``cm``.
        title: Figure title.
        plot_name: Base name handed to ``save_plot_with_timestamp``.
        output_folder: Directory handed to ``save_plot_with_timestamp``.

    Returns:
        Whatever ``save_plot_with_timestamp`` returns; the caller logs it to
        the results CSV as the plot path.
    """
    fig = plt.figure(figsize=(20, 14))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        xticklabels=class_names,
        yticklabels=class_names,
        cmap="Blues",
        annot_kws={"size": 16},
    )
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    file_path = save_plot_with_timestamp(plt, plot_name, output_folder)
    # Release the figure; otherwise every call leaves a 20x14 figure open in
    # pyplot's global state. Closing an already-closed figure is a no-op.
    plt.close(fig)
    return file_path


def classify_CNN_Affective_Individual_Task_NIRS(
    path, num_epochs, batch_size, learning_rate, window_size, gpu, offset
):
    """Train and evaluate a CNN on windowed fNIRS data with k-fold CV.

    The function:

    * loads the dataset via ``load_dataset_NIRS``;
    * builds sliding windows;
    * runs 5-fold cross-validation through ``utils.train_test_split``;
    * prints per-fold arousal/valence accuracies and loss statistics;
    * saves one confusion-matrix plot per target into ``output/``;
    * appends a summary row to ``output/results.csv``.

    Args:
        path: Directory with the derived affective task data; forwarded to
            ``load_dataset_NIRS``.
        num_epochs: Training epochs per fold.
        batch_size: Mini-batch size.
        learning_rate: Optimizer learning rate.
        window_size: Sliding-window length (look-back); also forwarded to
            ``load_dataset_NIRS``.
        gpu: Requested CUDA device. Any value other than ``"cuda:1"`` is
            replaced by ``"cuda:0"``; CPU is used when CUDA is unavailable.
        offset: Signal offset in samples; forwarded to ``load_dataset_NIRS``.

    Returns:
        None. Returns early (after printing the error) if the output folder
        cannot be created.
    """
    if gpu != "cuda:1":
        # Set the device to be used for training
        gpu = "cuda:0"

    use_wavelet = False

    # Create the output folder if it doesn't exist. exist_ok avoids the
    # check-then-create race of os.path.exists() followed by makedirs().
    output_folder = "output"
    try:
        os.makedirs(output_folder, exist_ok=True)
    except OSError as e:
        print(f"Error creating output folder: {e}")
        return

    # File path for the csv to log results
    csv_file_path = os.path.join(output_folder, "results.csv")
    _write_results_header_if_needed(csv_file_path)

    # Load dataset
    merged_df = load_dataset_NIRS(path, offset=offset, window_size=window_size)

    # After dropping subject_id, the second-to-last column is the arousal
    # score, the last column is the valence score, and everything before
    # them is a feature column.
    pos = [-2, -1]
    subject_ids = merged_df["subject_id"]
    merged_df = merged_df.drop(["subject_id"], axis=1)

    # Check if CUDA is available
    device = torch.device(gpu if torch.cuda.is_available() else "cpu")

    # Preprocess data
    features = merged_df.iloc[:, : pos[0]].values
    # NOTE(review): LabelEncoder re-encodes the sorted unique values to
    # 0..k-1, so the intended mapping -2 -> 0, ..., 2 -> 4 only holds when
    # all five scores occur in the data — confirm this is guaranteed.
    arousal_score = LabelEncoder().fit_transform(merged_df.iloc[:, pos[0]] + 2)
    valence_score = LabelEncoder().fit_transform(merged_df.iloc[:, pos[1]] + 2)

    # Get images from sliding window
    look_back = window_size

    window_overlap_str = True
    features, valence, arousal = sliding_window(
        features,
        valence_score,
        arousal_score,
        subject_ids,
        "nirs",
        use_wavelet,
        look_back=look_back,
    )

    # NOTE(review): targets are ordered (valence, arousal) here, while the
    # results of train_test_split are unpacked as (arousal, valence) below.
    # Confirm which target column train_test_split treats as arousal.
    targets = list(zip(valence, arousal))

    # Hyperparameters
    input_size = features.shape[1:]
    num_classes = 5  # Classes representing -2, -1, 0, 1, 2
    num_folds = 5

    # Create DataLoaders
    dataset = TensorDataset(
        torch.tensor(features).float().to(device),
        torch.tensor(targets).long().to(device),
    )

    kfold = KFold(n_splits=num_folds, shuffle=True, random_state=42)

    (
        fold_losses,
        fold_accuracies,
        all_true_arousal,
        all_pred_arousal,
        all_true_valence,
        all_pred_valence,
        average_chance_accuracy_arousal,
        average_chance_accuracy_valence,
        stdev_chance_accuracy_arousal,
        stdev_chance_accuracy_valence,
    ) = train_test_split(
        kfold,
        dataset,
        num_folds,
        num_epochs,
        batch_size,
        input_size,
        num_classes,
        time,
        tqdm,
        Subset,
        DataLoader,
        device,
        learning_rate,
        "nirs",
    )

    # Summary statistics across folds, computed once and reused for the
    # console output, the plot titles and the CSV row.
    arousal_accuracies, valence_accuracies = zip(*fold_accuracies)
    mean_arousal, std_arousal = np.mean(arousal_accuracies), np.std(arousal_accuracies)
    mean_valence, std_valence = np.mean(valence_accuracies), np.std(valence_accuracies)
    mean_loss, std_loss = np.mean(fold_losses), np.std(fold_losses)

    print("Arousal Accuracies for Each Fold:")
    for i, accuracy in enumerate(arousal_accuracies, start=1):
        print(f"Fold {i}: {accuracy}")

    print("\nValence Accuracies for Each Fold:")
    for i, accuracy in enumerate(valence_accuracies, start=1):
        print(f"Fold {i}: {accuracy}")

    print("Average accuracy for arousal_score:", mean_arousal)
    print("Standard deviation for arousal_score:", std_arousal)
    print("Average accuracy for valence_score:", mean_valence)
    print("Standard deviation for valence_score:", std_valence)

    # Print the average loss per fold.
    print(f"Average loss per fold: {mean_loss}")
    print(f"Standard deviation of loss per fold: {std_loss}")

    # Pin the label set so each matrix is always num_classes x num_classes
    # and lines up with the five tick labels, even when a class never
    # appears among the true or predicted labels.
    class_labels = list(range(num_classes))
    arousal_cm = confusion_matrix(
        all_true_arousal, all_pred_arousal, labels=class_labels
    )
    valence_cm = confusion_matrix(
        all_true_valence, all_pred_valence, labels=class_labels
    )

    # Define the class names (assuming -2 to 2 for arousal and valence scores)
    class_names = [-2, -1, 0, 1, 2]

    subject_holdout_str = "regular kfold"

    # Shared part of both plot titles.
    run_summary = (
        f"Sliding window overlap: {window_overlap_str}, "
        f"Holdout method: {subject_holdout_str}, "
        f"Window size: {look_back}, Batch Size: {batch_size}, "
        f"Learning Rate: {learning_rate}, Epochs: {num_epochs}"
    )
    loss_summary = f"loss: {mean_loss:.2f}, std: {std_loss:.2f}"

    confusion_matrix_arousal_file_path = _save_confusion_matrix(
        arousal_cm,
        class_names,
        f"NIRS-CNN: Confusion Matrix for Arousal\n{run_summary}, "
        f"Accuracy: {mean_arousal:.2f}%, std: {std_arousal:.2f}%, {loss_summary}",
        "confusion_matrix_arousal",
        output_folder,
    )
    confusion_matrix_valence_file_path = _save_confusion_matrix(
        valence_cm,
        class_names,
        f"NIRS-CNN: Confusion Matrix for Valence\n{run_summary}, "
        f"Accuracy: {mean_valence:.2f}%, std: {std_valence:.2f}%, {loss_summary}",
        "confusion_matrix_valence",
        output_folder,
    )

    # Write results to csv; column order matches _write_results_header_if_needed.
    with open(csv_file_path, "a", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(
            [
                time.strftime("%Y-%m-%d %H:%M:%S"),
                "NIRS",
                subject_holdout_str,
                window_overlap_str,
                window_size,
                window_size * 2,
                offset,
                num_epochs,
                batch_size,
                learning_rate,
                mean_valence,
                std_valence,
                mean_arousal,
                std_arousal,
                mean_loss,
                std_loss,
                average_chance_accuracy_arousal,
                stdev_chance_accuracy_arousal,
                average_chance_accuracy_valence,
                stdev_chance_accuracy_valence,
                confusion_matrix_arousal_file_path,
                confusion_matrix_valence_file_path,
            ]
        )


if __name__ == "__main__":
    # Command-line entry point: parse hyperparameters and run one
    # train/evaluate cycle with classify_CNN_Affective_Individual_Task_NIRS.
    parser = argparse.ArgumentParser(
        description="Train and evaluate a CNN on fNIRS affective task data "
        "using k-fold cross-validation"
    )
    parser.add_argument(
        "--p",
        required=True,
        help="Path to the directory with the derived affective task data",
    )

    parser.add_argument(
        "--num_epochs", type=int, default=100, help="Number of epochs for training"
    )
    parser.add_argument(
        "--batch_size", type=int, default=256, help="Batch size for training"
    )
    parser.add_argument(
        "--learning_rate", type=float, default=0.001, help="Learning rate for optimizer"
    )

    parser.add_argument(
        "--gpu", default="cuda:1", required=False, type=str, help="cuda:0 or cuda:1"
    )

    parser.add_argument(
        "--window_size",
        type=int,
        default=10,
        help="Sliding window length (look-back) used to build the input windows",
    )

    # argparse does not run `type` on non-string defaults, so the default
    # must already be an int (previously it was False, which was passed
    # through as a bool and logged as "False" in the results CSV).
    parser.add_argument(
        "--offset",
        type=int,
        default=0,
        help="Offset the signal by number of samples. fNIRS: 1 sec = 10 samples",
    )

    args = parser.parse_args()
    path = args.p
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    window_size = args.window_size
    gpu = args.gpu
    offset = args.offset

    sys.exit(
        classify_CNN_Affective_Individual_Task_NIRS(
            path, num_epochs, batch_size, learning_rate, window_size, gpu, offset
        )
    )
