import numpy as np
import pandas as pd
import tensor.tensor_product_wrapper as tp
from data.utils import rescale, normalize
import matplotlib.pyplot as plt
from utils import similarity_metrics as sm
import seaborn as sns
from tqdm import tqdm
import torch
from geospca import geospca_solver

def projection(A, U, prod_type='t'):
    """Project the tensor ``A`` onto the basis tensor ``U``.

    Evaluates ``U * (U^T * A)``, where the product and transpose are the
    ``ten_prod`` / ``ten_tran`` operations of ``tensor.tensor_product_wrapper``.

    Args:
        A: Tensor to be projected.
        U: Basis tensor to project onto.
        prod_type: Product type, forwarded unchanged to every ``tp`` call.

    Returns:
        Whatever ``tp.ten_prod`` returns for ``U`` times the coefficients
        ``U^T * A``.
    """
    basis_transposed = tp.ten_tran(U, prod_type=prod_type)
    coefficients = tp.ten_prod(basis_transposed, A, prod_type=prod_type)
    return tp.ten_prod(U, coefficients, prod_type=prod_type)

def normalize(data):
    """Min-max scale ``data`` to the range [0, 1] along axis 0.

    NOTE: this definition shadows the ``normalize`` imported from
    ``data.utils`` at the top of the file; every later call in this module
    resolves to this function.

    Args:
        data: Array-like of any rank >= 1. The minimum and maximum are taken
            along axis 0, so each position in the remaining axes is scaled
            independently.

    Returns:
        ``(data - min) / (max - min)``. Positions whose minimum equals their
        maximum are divided by 1 instead of 0, so they come out as 0.
    """
    lowest = np.min(data, axis=0)
    span = np.max(data, axis=0) - lowest
    # np.where instead of in-place masking: for 1-D input the reductions above
    # return NumPy scalars, which do not support item assignment.
    span = np.where(span == 0, 1, span)
    return (data - lowest) / span

def load_csv_data(file_path, selected_labels, num_train=100, num_test=9):
    """Load labelled images from a CSV file and cut a per-class train/test split.

    The first CSV column is read as the label and all remaining columns as
    pixel values. Pixels are divided by 255.0, passed through this module's
    ``normalize`` and reshaped to ``(-1, IMAGE_HEIGHT, IMAGE_WIDTH)`` using the
    module-level constants, which must be defined before this is called.

    For every label, rows are taken in file order: the first ``num_train``
    rows of that class go to the training set and the following ``num_test``
    rows to the test set.

    Args:
        file_path: Path of the CSV file, handed to ``pandas.read_csv``.
        selected_labels: Iterable of label values to keep.
        num_train: Number of training samples taken per class.
        num_test: Number of test samples taken per class.

    Returns:
        Tuple ``(train_data, train_labels, test_data, test_labels)``; the
        per-class groups are concatenated in the order of ``selected_labels``.

    Raises:
        ValueError: If a class has fewer than ``num_train + num_test`` rows.
    """
    frame = pd.read_csv(file_path)
    labels = frame.iloc[:, 0].values
    pixels = normalize(frame.iloc[:, 1:].values / 255.0)
    images = pixels.reshape(-1, IMAGE_HEIGHT, IMAGE_WIDTH)

    # Collect row indices per class first; the arrays are gathered afterwards.
    train_index_groups, test_index_groups = [], []
    for label in selected_labels:
        class_indices = np.where(labels == label)[0]
        if len(class_indices) < num_train + num_test:
            raise ValueError(
                f"Insufficient samples for class {label}: at least {num_train + num_test} samples are required. Please verify the dataset.")
        train_index_groups.append(class_indices[:num_train])
        test_index_groups.append(class_indices[num_train:num_train + num_test])

    train_data = np.vstack([images[idx] for idx in train_index_groups])
    train_labels = np.concatenate([labels[idx] for idx in train_index_groups])
    test_data = np.vstack([images[idx] for idx in test_index_groups])
    test_labels = np.concatenate([labels[idx] for idx in test_index_groups])

    return train_data, train_labels, test_data, test_labels



def geospca_per_class_fixed_k(data, labels, label_list, geospca_solver,
                              nc=20, fixed_k=50, epsilon=0.001, maxiter=1000,
                              pad_mode='wrap'):
    """Pick exactly ``fixed_k`` samples per class from the solver's output.

    For each label, the class samples are flattened to a 2-D float32 torch
    tensor (one row per sample) and handed to ``geospca_solver``. The indices
    found under ``result["Bindices"]`` are reduced modulo the number of class
    samples and then truncated or padded to length ``fixed_k``.

    Args:
        data: Array indexed as ``data[mask]`` and ``class_data[idx, :, :]``,
            i.e. three-dimensional with samples on axis 0.
        labels: 1-D array of labels aligned with axis 0 of ``data``.
        label_list: Labels to process, in output order.
        geospca_solver: Callable invoked as
            ``geospca_solver(A, nc=..., k=..., epsilon=..., maxiter=...)``;
            must return a mapping containing the key ``"Bindices"``.
        nc: Forwarded to the solver.
        fixed_k: Number of samples kept per class; also forwarded as ``k``.
        epsilon: Forwarded to the solver.
        maxiter: Forwarded to the solver.
        pad_mode: ``numpy.pad`` mode used when the solver returns fewer than
            ``fixed_k`` indices.

    Returns:
        Tuple ``(new_data, new_labels)`` with ``fixed_k`` entries per label,
        concatenated along axis 0 in ``label_list`` order.

    Raises:
        ValueError: If a label in ``label_list`` has no samples in ``data``.
    """
    picked_data = []
    picked_labels = []
    for label in label_list:
        class_data = data[labels == label]
        num_samples = class_data.shape[0]
        if num_samples == 0:
            # Fail early with the offending label instead of an opaque
            # reshape / modulo-by-zero error further down.
            raise ValueError(
                f"No samples found for class {label}; cannot select "
                f"{fixed_k} samples for it.")

        flattened = torch.tensor(class_data.reshape(num_samples, -1),
                                 dtype=torch.float32)
        result = geospca_solver(flattened, nc=nc, k=fixed_k,
                                epsilon=epsilon, maxiter=maxiter)

        # NOTE(review): the modulo folds whatever indices the solver returns
        # into the valid sample range; if "Bindices" are column (feature)
        # indices of the flattened matrix this maps them onto samples --
        # confirm that this is the intended selection rule.
        indices = np.array(result["Bindices"]) % num_samples

        shortfall = fixed_k - len(indices)
        if shortfall < 0:
            indices = indices[:fixed_k]
        elif shortfall > 0:
            indices = np.pad(indices, (0, shortfall), mode=pad_mode)

        picked_data.append(class_data[indices, :, :])
        picked_labels.append(np.full((fixed_k,), label))

    new_data = np.concatenate(picked_data, axis=0)
    new_labels = np.concatenate(picked_labels, axis=0)

    return new_data, new_labels

# Seed NumPy's global random number generator.
np.random.seed(20)

# --- Experiment configuration -------------------------------------------
IMAGE_HEIGHT = 28
IMAGE_WIDTH = 28
SELECTED_LABELS = list(range(10))
SAMPLES_PER_CLASS = 500  # passed to load_csv_data as num_train
NUM_TEST_SAMPLES = 80    # passed to load_csv_data as num_test
train_file = "data/mnist_train.csv"
test_file = "data/mnist_test.csv"  # NOTE(review): not read in the visible code

# --- Load data -----------------------------------------------------------
# Both the training and the test split are cut from train_file.
training_data, training_labels, test_data, test_labels = load_csv_data(
    train_file,
    SELECTED_LABELS,
    num_train=SAMPLES_PER_CLASS,
    num_test=NUM_TEST_SAMPLES,
)

# --- Reduce every class to a fixed number of training samples ------------
training_data, training_labels = geospca_per_class_fixed_k(
    training_data, training_labels, SELECTED_LABELS, geospca_solver,
    nc=25, fixed_k=50, epsilon=0.001, maxiter=1000,
)

print("New training data shape:", training_data.shape)

# --- Re-orient and scale -------------------------------------------------
# Swap the first two axes so that samples sit on axis 1.
# NOTE(review): `normalize` resolves to the function defined in this file
# (it shadows the import from data.utils) and scales along axis 0, which
# after the transpose is the image-row axis -- confirm this is intended.
training_data = normalize(rescale(training_data.transpose((1, 0, 2))))
test_data = normalize(rescale(test_data.transpose((1, 0, 2))))


# visualize data
# for i in range(len(SELECTED_LABELS)):
#     slice_subplots(test_data[:, test_labels == i, :], axis=1, title='class ' + str(i))
#     plt.show()

# One basis per class: U[j] is the first output of tp.ten_svd (k=2) computed
# from the training samples whose label is SELECTED_LABELS[j].
U = []
for label in tqdm(SELECTED_LABELS, desc="Computing class bases"):
    # Mask by label value rather than loop position so the selection stays
    # correct when SELECTED_LABELS is not exactly 0..n-1.
    u, _, _, _ = tp.ten_svd(training_data[:, training_labels == label], k=2)
    U.append(u)


# visualizations
# for i in range(len(SELECTED_LABELS)):
#     slice_subplots(U[i], axis=1, title='First two sample-mode basis elements for class ' + str(i))
#     plt.savefig(f'pic/feature_{i}.png', bbox_inches='tight')
#     plt.close()


#%% Compute results
# Row j of each error matrix holds, for every sample (axis 1 of the data),
# the sm.frobenius_metric between the sample and its projection onto U[j].
training_error = np.zeros([len(SELECTED_LABELS), training_data.shape[1]])
test_error = np.zeros([len(SELECTED_LABELS), test_data.shape[1]])

for i, basis in enumerate(tqdm(U, desc="Computing projection errors")):
    training_projection = projection(training_data, basis)
    training_error[i, :] = sm.frobenius_metric(training_data, training_projection, axis=1)
    test_projection = projection(test_data, basis)
    test_error[i, :] = sm.frobenius_metric(test_data, test_projection, axis=1)

# classification
# argmin yields a row index into U; translate it to the actual label value so
# the comparison below is valid even when SELECTED_LABELS is not 0..n-1.
label_lookup = np.asarray(SELECTED_LABELS)
training_predicted_classes = label_lookup[np.argmin(training_error, axis=0)].reshape(-1, 1)
test_predicted_classes = label_lookup[np.argmin(test_error, axis=0)].reshape(-1, 1)

# results
training_num_correct = np.sum(training_predicted_classes == training_labels.reshape(-1, 1), axis=0)
training_accuracy = training_num_correct / training_data.shape[1]
test_num_correct = np.sum(test_predicted_classes == test_labels.reshape(-1, 1), axis=0)
test_accuracy = test_num_correct / test_data.shape[1]

# The literal 2 mirrors the k=2 passed to tp.ten_svd when the bases were built.
print('k = %d: train accuracy = %0.2f\ttest accuracy = %0.2f' %
      (2, 100 * training_accuracy.item(), 100 * test_accuracy.item()))


#%% Confusion matrix
conf_mat = sm.confusion_matrix(test_predicted_classes, test_labels)
print(conf_mat)
correct = np.trace(conf_mat)
total = np.sum(conf_mat)
accuracy = correct / total * 100
print(f"Overall Accuracy: {accuracy:.2f}%")

# Normalise each row to sum to 1. A row with no entries is divided by 1
# instead of 0 so it is shown as zeros rather than NaN.
conf_mat_normalized = conf_mat.astype(np.float32)
row_totals = conf_mat_normalized.sum(axis=1, keepdims=True)
row_totals[row_totals == 0] = 1
conf_mat_normalized = conf_mat_normalized / row_totals

# NOTE(review): the axis labels below assume rows are true labels and columns
# are predictions -- confirm against the orientation sm.confusion_matrix uses.
plt.figure(figsize=(6, 5))
sns.heatmap(conf_mat_normalized, annot=True, fmt=".2f", cmap="Blues",
            xticklabels=SELECTED_LABELS, yticklabels=SELECTED_LABELS)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Normalized Test Confusion Matrix")
plt.show()

