import random
import numpy as np
import torch
import math
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_squared_error


def set_random_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators and set cuDNN flags.

    The same ``seed`` is handed to ``random``, ``numpy.random``, the PyTorch
    CPU generator and the PyTorch CUDA generators.  cuDNN is then switched to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Seed value forwarded to every generator (default ``42``).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def zscore_2d(arr, axis=1):
    """Standardize ``arr`` using the mean and standard deviation along axis 0.

    Each entry has the axis-0 mean subtracted and is divided by the axis-0
    standard deviation.  Where that standard deviation is exactly zero
    (a constant column), a divisor of 1 is used so the result is 0 instead
    of NaN/inf.

    Args:
        arr: NumPy array (or array-like accepted by ``np.mean`` / ``np.std``).
        axis: Accepted for backward compatibility but not used; statistics
            are always taken along axis 0.
            NOTE(review): the default of 1 suggests a different axis may have
            been intended -- confirm with callers before honouring it.

    Returns:
        Array with the same shape as ``arr`` holding the standardized values.
    """
    means = np.mean(arr, axis=0)
    stds = np.std(arr, axis=0)
    # Zero spread would divide by zero; fall back to 1 so constant columns map to 0.
    safe_stds = np.where(stds == 0, 1, stds)
    return (arr - means) / safe_stds


def rescale_to_01(sample):
    """Linearly rescale ``sample`` so its minimum maps to 0 and its maximum to 1.

    The minimum and maximum are taken over the whole of ``sample`` via its
    ``.min()`` / ``.max()`` methods.  When every value is identical the range
    is zero; a divisor of 1 is used in that case, so the result is all zeros
    rather than NaN.

    Args:
        sample: Object exposing ``.min()``, ``.max()`` and element-wise
            arithmetic (e.g. a NumPy array).

    Returns:
        The rescaled values, same shape as ``sample``.
    """
    min_val = sample.min()
    max_val = sample.max()
    span = max_val - min_val
    if span == 0:
        # Constant input: avoid 0/0 while keeping the true-division result type.
        span = 1
    return (sample - min_val) / span


def rescale_to_minus1_1(sample):
    """Linearly rescale ``sample`` so its minimum maps to -1 and its maximum to 1.

    The minimum and maximum are taken over the whole of ``sample`` via its
    ``.min()`` / ``.max()`` methods.  When every value is identical the range
    is zero; a divisor of 1 is used in that case, so the result is all -1
    (the lower bound) rather than NaN.

    Args:
        sample: Object exposing ``.min()``, ``.max()`` and element-wise
            arithmetic (e.g. a NumPy array).

    Returns:
        The rescaled values, same shape as ``sample``.
    """
    min_val = sample.min()
    max_val = sample.max()
    span = max_val - min_val
    if span == 0:
        # Constant input: avoid 0/0 while keeping the true-division result type.
        span = 1
    return 2 * (sample - min_val) / span - 1


def zscore_by_column(data):
    """Standardize ``data`` with the mean and standard deviation along axis 0.

    Positions where the axis-0 standard deviation equals zero are divided by
    1 instead, so those entries come out as 0.

    Args:
        data: Object exposing ``.mean(axis=0)`` / ``.std(axis=0)`` whose
            standard-deviation result supports boolean-mask assignment
            (e.g. a 2-D NumPy array).

    Returns:
        The standardized values, same shape as ``data``.
    """
    column_mean = data.mean(axis=0)
    column_std = data.std(axis=0)
    zero_spread = column_std == 0
    column_std[zero_spread] = 1
    centered = data - column_mean
    return centered / column_std


def normalize_to_01(tensor):
    """Min-max normalize ``tensor`` and report the extrema that were used.

    The global minimum is subtracted and the result is divided by
    ``max - min + 1e-8``; the small constant keeps the divisor non-zero when
    all values are equal.

    Args:
        tensor: Array-like accepted by ``np.min`` / ``np.max``.

    Returns:
        Tuple ``(normalized, data_min, data_max)``.
    """
    lowest = np.min(tensor)
    highest = np.max(tensor)
    denominator = highest - lowest + 1e-8
    shifted = tensor - lowest
    return shifted / denominator, lowest, highest


class StimulusDataset(Dataset):
    """Dataset that serves single stimulus items with an optional transform.

    Args:
        stimulus_data: Indexable, sized collection of stimuli.
        transform: Optional callable applied to an item on access.
        apply_transform: The transform runs only when this is truthy and
            ``transform`` is not ``None``.
    """

    def __init__(self, stimulus_data, transform=None, apply_transform=False):
        self.stimulus_data = stimulus_data
        self.apply_transform = apply_transform
        self.transform = transform

    def __len__(self):
        """Return the number of stimuli."""
        return len(self.stimulus_data)

    def __getitem__(self, idx):
        """Return the stimulus at ``idx``, transformed when enabled."""
        item = self.stimulus_data[idx]
        if self.transform is None or not self.apply_transform:
            return item
        return self.transform(item)


class PairwiseDataset(Dataset):
    """Dataset that serves ``(neural, stimulus)`` pairs sharing one index.

    The optional transform is applied to the stimulus only; the neural item
    is returned unchanged.  The reported length is that of ``stimulus_data``.

    Args:
        neural_data: Indexable collection of neural recordings.
        stimulus_data: Indexable, sized collection of stimuli.
        transform: Optional callable applied to the stimulus on access.
        apply_transform: The transform runs only when this is truthy and
            ``transform`` is not ``None``.
    """

    def __init__(self, neural_data, stimulus_data, transform=None, apply_transform=False):
        self.neural_data = neural_data
        self.stimulus_data = stimulus_data
        self.apply_transform = apply_transform
        self.transform = transform

    def __len__(self):
        """Return the number of stimuli."""
        return len(self.stimulus_data)

    def __getitem__(self, idx):
        """Return the ``(neural, stimulus)`` pair stored at ``idx``."""
        neural_item = self.neural_data[idx]
        stimulus_item = self.stimulus_data[idx]
        if self.transform is None or not self.apply_transform:
            return neural_item, stimulus_item
        return neural_item, self.transform(stimulus_item)


def warmup_then_decay_lr(current_step, warmup_steps, total_steps, portion=0.4):
    """Return a schedule multiplier with sine warmup followed by cosine decay.

    For ``current_step < warmup_steps`` the value rises along a quarter sine
    wave from ``portion`` (at step 0) towards 1.  From ``warmup_steps``
    onwards it follows a half cosine from 1 down to 0 at ``total_steps``.
    Presumably intended as an ``lr_lambda``-style factor -- verify against
    the caller.

    Args:
        current_step: Step index the multiplier is computed for.
        warmup_steps: Number of warmup steps.
        total_steps: Step at which the cosine phase reaches 0.
        portion: Multiplier at step 0 (default ``0.4``).

    Returns:
        The multiplier as a plain Python ``float`` in both phases.
    """
    if current_step < warmup_steps:
        return portion + (1 - portion) * math.sin(0.5 * math.pi * current_step / warmup_steps)

    decay_steps = total_steps - warmup_steps
    if decay_steps == 0:
        # No decay phase exists; stay at the end-of-warmup value instead of
        # raising ZeroDivisionError.
        return 1.0
    progress = (current_step - warmup_steps) / decay_steps
    # math.cos keeps the return type a float, matching the warmup branch
    # (the tensor-based version returned a 0-dim torch.Tensor).
    return 0.5 * (1 + math.cos(progress * math.pi))


def plotting_random_images(input_matrice):
    """Show nine randomly chosen images from ``input_matrice`` in a 3x3 grid.

    Indices are drawn without replacement via ``np.random.choice``, so
    ``input_matrice`` must contain at least nine entries.  Images are drawn
    with the gray colormap, axes hidden, on a 4x4 figure.

    Args:
        input_matrice: Sized, indexable collection of images accepted by
            ``plt.imshow``.
    """
    chosen = np.random.choice(len(input_matrice), 9, replace=False)
    plt.figure(figsize=(4, 4))
    for cell, image_index in enumerate(chosen, start=1):
        plt.subplot(3, 3, cell)
        plt.imshow(input_matrice[image_index], cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.show()


def plotting_ordering_images(input_matrice):
    """Show the first nine images of ``input_matrice`` in a 3x3 grid.

    Entries 0 through 8 are drawn in order with the gray colormap, axes
    hidden, on an 8x8 figure.

    Args:
        input_matrice: Indexable collection of images accepted by
            ``plt.imshow``; indices 0-8 are accessed.
    """
    first_nine = np.arange(9)
    plt.figure(figsize=(8, 8))
    for cell, image_index in enumerate(first_nine, start=1):
        plt.subplot(3, 3, cell)
        plt.imshow(input_matrice[image_index], cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.show()


def get_latents_and_recons(data_loader, model):
    """Run ``model`` over ``data_loader`` and collect per-batch ``mu`` and reconstructions.

    Each batch gets a new dimension inserted at position 1, is moved to the
    device of the model's first parameter and is passed through the model
    with gradients disabled.  The model must return exactly four values,
    ``(reconstruction, mu, z, logvar)``; only the first two are kept.

    Args:
        data_loader: Iterable yielding image batches as tensors.
        model: Module exposing ``parameters()`` and returning the 4-tuple above.

    Returns:
        Tuple ``(latents_list, recons_list)`` of lists with one CPU tensor
        per batch.
    """
    target_device = next(model.parameters()).device
    collected_mu, collected_recons = [], []
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.unsqueeze(1).to(target_device)
            recon, mu, _z, _logvar = model(batch)
            collected_mu.append(mu.cpu())
            collected_recons.append(recon.cpu())
    return collected_mu, collected_recons


def get_latents_and_recons_ae(data_loader, model):
    """Run ``model`` over ``data_loader`` and collect per-batch latents and reconstructions.

    Each batch gets a new dimension inserted at position 1, is moved to the
    device of the model's first parameter and is passed through the model
    with gradients disabled.  The model must return exactly two values,
    ``(reconstruction, latent)``.

    Args:
        data_loader: Iterable yielding image batches as tensors.
        model: Module exposing ``parameters()`` and returning the 2-tuple above.

    Returns:
        Tuple ``(latents_list, recons_list)`` of lists with one CPU tensor
        per batch.
    """
    target_device = next(model.parameters()).device
    collected_latents, collected_recons = [], []
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.unsqueeze(1).to(target_device)
            recon, latent = model(batch)
            collected_latents.append(latent.cpu())
            collected_recons.append(recon.cpu())
    return collected_latents, collected_recons


def reshape_outputs(full_latents, full_reconstructed_images, latent_size):
    """Concatenate per-batch outputs and flatten them into sample-major arrays.

    Reconstructions are concatenated along axis 0, must form a 4-D array and
    are reshaped to ``(-1, height, width)`` using the last two dimensions.
    Latents are concatenated along axis 0 and reshaped to ``(-1, latent_size)``.

    Args:
        full_latents: Sequence of per-batch latent arrays.
        full_reconstructed_images: Sequence of per-batch 4-D reconstruction arrays.
        latent_size: Width of each latent row after reshaping.

    Returns:
        Tuple ``(latents, images)`` of the two reshaped arrays.
    """
    stacked_images = np.concatenate(full_reconstructed_images, axis=0)
    # Unpacking four values also enforces that the stacked array is 4-D.
    _, _, height, width = stacked_images.shape
    stacked_latents = np.concatenate(full_latents, axis=0)
    latents_2d = stacked_latents.reshape(-1, latent_size)
    images_3d = stacked_images.reshape(-1, height, width)
    return latents_2d, images_3d


def get_dataloaders(X_train_torch, y_train_torch, X_test_torch, y_test_torch, batch_size=64):
    """Wrap train and test tensors in ``DataLoader`` objects.

    The training loader shuffles its samples; the test loader keeps the
    original order.

    Args:
        X_train_torch: Training inputs.
        y_train_torch: Training targets.
        X_test_torch: Test inputs.
        y_test_torch: Test targets.
        batch_size: Batch size used by both loaders (default ``64``).

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    train_loader = DataLoader(
        TensorDataset(X_train_torch, y_train_torch),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        TensorDataset(X_test_torch, y_test_torch),
        batch_size=batch_size,
        shuffle=False,
    )
    return train_loader, test_loader


def evaluate_vae_on_loader(model, dataloader, mode="train", device="cuda"):
    """Score a model's neural and stimulus reconstructions over ``dataloader``.

    The model is put into eval mode (and left there) and run without
    gradients on the neural batch only.  The first two values it returns are
    taken as the neural and stimulus reconstructions.  R² and RMSE are
    computed over all batches stacked together; the last column of each
    stimulus row is dropped before it is compared with the reconstructed
    stimulus.  Both metric pairs are printed, but only the R² values are
    returned.

    Args:
        model: Callable module returning at least
            ``(recon_neural, recon_stimulus)``.
        dataloader: Iterable yielding ``(neural_batch, stimulus_batch)`` tensors.
        mode: Label inserted into the printed lines (default ``"train"``).
        device: Device the batches are moved to (default ``"cuda"``).

    Returns:
        Tuple ``(neural_r2, stimulus_r2)``.
    """
    model.eval()
    neural_true, neural_pred = [], []
    stimulus_true, stimulus_pred = [], []

    with torch.no_grad():
        for neural_batch, stimulus_batch in dataloader:
            neural_batch = neural_batch.to(device)
            stimulus_batch = stimulus_batch.to(device)
            recon_neural, recon_stimulus, *_ = model(neural_batch)

            neural_pred.append(recon_neural.cpu())
            neural_true.append(neural_batch.cpu())
            stimulus_pred.append(recon_stimulus.cpu())
            stimulus_true.append(stimulus_batch.cpu())

    neural_true = np.vstack(neural_true)
    neural_pred = np.vstack(neural_pred)
    stimulus_true = np.vstack(stimulus_true)
    stimulus_pred = np.vstack(stimulus_pred)

    # The reconstruction is scored against every stimulus column except the last.
    stimulus_target = stimulus_true[:, :-1]

    neural_r2 = r2_score(neural_true, neural_pred)
    stimulus_r2 = r2_score(stimulus_target, stimulus_pred)
    neural_rmse = np.sqrt(mean_squared_error(neural_true, neural_pred))
    stimulus_rmse = np.sqrt(mean_squared_error(stimulus_target, stimulus_pred))

    print(f"R² Score [{mode}] - Neural: {neural_r2:.4f} | Stimulus: {stimulus_r2:.4f}")
    print(f"RMSE Score [{mode}] - Neural: {neural_rmse:.4f} | Stimulus: {stimulus_rmse:.4f}")

    return neural_r2, stimulus_r2
