import os
import numpy as np
import pickle
import torch
import shutil
from typing import Any, Dict, Optional, List


# Normalisation parameters for the movement-estimation data, keyed by the
# `name` passed to normalize_movementEst_data(). Each entry holds the global
# "min" and "max" tensors that denormalize_movementEst_data() uses to invert
# the scaling.
movementEst_norm_data = {}


def normalize_movementEst_data(
    data_list: list[torch.Tensor], name: str
) -> list[torch.Tensor]:
    """Min-max normalise a list of tensors to [0, 1] using one shared range.

    The minimum and maximum are taken over all elements of all tensors in
    ``data_list`` (after concatenating them with ``torch.cat``), so every
    tensor is scaled with the same parameters. Those parameters are stored in
    ``movementEst_norm_data[name]`` so ``denormalize_movementEst_data`` can
    invert the mapping later.

    Args:
        data_list: Tensors to normalise; they must be concatenable with
            ``torch.cat`` (same shape apart from the first dimension).
        name: Key under which the min/max are recorded; an existing entry
            with the same name is overwritten.

    Returns:
        The normalised tensors, in the same order as ``data_list``. If all
        values are identical, every output element is 0 instead of NaN.
    """
    data = torch.cat(data_list)
    min_value = torch.min(data)
    max_value = torch.max(data)

    # A zero range (constant data) would give 0/0 = NaN. Dividing by 1 maps
    # everything to 0, and the inverse x * (max - min) + min still recovers
    # the constant, so the round trip stays exact.
    value_range = max_value - min_value
    if value_range == 0:
        value_range = torch.ones_like(value_range)

    normed_data_list = [(tensor - min_value) / value_range for tensor in data_list]

    movementEst_norm_data[name] = {
        "min": min_value,
        "max": max_value,
    }

    return normed_data_list


def denormalize_movementEst_data(
    data_list: list[torch.Tensor], name: str
) -> list[torch.Tensor]:
    """Invert ``normalize_movementEst_data`` for every tensor in a list.

    Args:
        data_list: Tensors in normalised [0, 1] space.
        name: Key that was passed to ``normalize_movementEst_data``.

    Returns:
        Tensors mapped back with ``x * (max - min) + min`` using the stored
        parameters, in the same order as ``data_list``.

    Raises:
        KeyError: If ``name`` was never normalised.
    """
    try:
        norm_params = movementEst_norm_data[name]
    except KeyError:
        raise KeyError(f"'{name}' variable not normalized before !") from None

    min_value = norm_params["min"]
    value_range = norm_params["max"] - min_value

    return [tensor * value_range + min_value for tensor in data_list]


def normalize_sensor_data(X_r, X_l):
    """Min-max normalise right and left sensor data with one shared range.

    The min and max are computed over both tensors together (concatenated
    along dim 0), so right and left signals keep their relative scale. The
    normalisation parameters are not stored, so the scaling cannot be
    inverted afterwards.

    Args:
        X_r: Right-side sensor tensor.
        X_l: Left-side sensor tensor; must be concatenable with ``X_r``
            along dim 0.

    Returns:
        Tuple ``(X_r_norm, X_l_norm)`` scaled to [0, 1]. If all values are
        identical, both outputs are all zeros instead of NaN.
    """
    X = torch.cat((X_r, X_l))
    min_value = torch.min(X)
    value_range = torch.max(X) - min_value
    # A zero range (constant data) would give 0/0 = NaN; divide by 1 instead.
    if value_range == 0:
        value_range = torch.ones_like(value_range)

    return (X_r - min_value) / value_range, (X_l - min_value) / value_range


# Global registry of normalisation parameters for biomechanical variables,
# keyed by the `name` passed to normalize_variable(). Each entry holds
# per-variable "min"/"max" lists and the "shape" (without the batch
# dimension) that denormalize_variable() uses to validate its input.
Norm_Values_dic = {}


def normalize_variable(Y, name, device="cpu"):
    """Min-max normalise each variable (last dimension) of ``Y`` to [0, 1].

    For every variable ``Y[..., j]`` the min and max are taken over all the
    leading dimensions (dims 0 and 1 for a 3-D tensor). They are stored in
    ``Norm_Values_dic[name]`` so ``denormalize_variable`` can undo the
    scaling.

    Args:
        Y: Tensor with at least 2 dimensions whose last dimension indexes
            the variables (e.g. angles, moments, GRFs).
        name: String key for the stored parameters; an existing entry with
            the same name is overwritten.
        device: Device on which ``Y`` is processed and the result returned.

    Returns:
        Normalised tensor with the same shape as ``Y``, in torch's default
        floating dtype. Variables whose values are all identical are mapped
        to 0 instead of NaN.

    Raises:
        ValueError: If ``name`` is not a string or ``Y`` has fewer than 2
            dimensions.
    """
    if not isinstance(name, str):
        raise ValueError("name should be a string")
    if Y.dim() < 2:
        raise ValueError("Y must have at least 2 dimensions (..., variables)")

    Y = Y.to(device)
    # Reduce over every dimension except the last (variable) one.
    reduce_dims = tuple(range(Y.dim() - 1))
    min_values = Y.amin(dim=reduce_dims)
    max_values = Y.amax(dim=reduce_dims)

    # A constant variable would otherwise give 0/0 = NaN; divide by 1 instead.
    value_range = max_values - min_values
    value_range = torch.where(
        value_range == 0, torch.ones_like(value_range), value_range
    )

    # Broadcasts the per-variable vectors over the last dimension. The result
    # is cast to the default dtype to match the previous zeros-buffer output.
    norm_Y = ((Y - min_values) / value_range).to(torch.get_default_dtype())

    Norm_Values_dic[name] = {
        # Lists of 0-d tensors, one per variable.
        "min": list(min_values.unbind(0)),
        "max": list(max_values.unbind(0)),
        "shape": Y.shape[1:],  # not batch_size
    }

    return norm_Y


def denormalize_variable(Y, name, device="cpu"):
    """Invert ``normalize_variable`` for ``Y`` using the stored parameters.

    Args:
        Y: Normalised tensor whose shape, excluding the first (batch)
            dimension, must equal the shape recorded at normalisation time.
            Its last dimension indexes the variables.
        name: Key that was used when the variable was normalised.
        device: Device on which the result is returned.

    Returns:
        Tensor of the same shape as ``Y``, in torch's default floating dtype,
        with each variable mapped back via ``y * (max - min) + min``.

    Raises:
        KeyError: If ``name`` was never normalised.
        ValueError: If ``Y.shape[1:]`` differs from the stored shape.
    """
    try:
        norm_values = Norm_Values_dic[name]
    except KeyError:
        raise KeyError(f"'{name}' variable not normalized before !") from None

    if norm_values["shape"] != Y.shape[1:]:  # not batch_size
        raise ValueError("Variable shape not correct, mismatch with normalized shape!")

    # Stack the per-variable scalars into vectors that broadcast over the
    # last dimension of Y.
    min_values = torch.stack(norm_values["min"]).to(device)
    max_values = torch.stack(norm_values["max"]).to(device)

    denorm_Y = Y.to(device) * (max_values - min_values) + min_values
    # Match the previous zeros-buffer output dtype.
    return denorm_Y.to(torch.get_default_dtype())


def take_derivatives(batch_data, device="cpu"):
    """Differentiate coordinates and speeds along time (autograd compatible).

    Derivatives are forward differences along dim 1 scaled by a per-sample
    rate ``fs = nSample / duration``. The first time step reuses the value of
    the second, so outputs keep the input length.

    Args:
        batch_data: Dict of tensors. Keys read:
            - "subject_data": 2-D; column 0 holds each row's duration and
              ``[0, -1]`` the number of samples (read from the first row).
            - "translation_data", "angles": concatenated on the last dim to
              form the coordinates, indexed as (batch, time, n).
            - "qdot_translation_data": speeds placed in front of the
              derivatives of coordinates 2 onwards to build the speeds.
        device: Device for the returned tensors.

    Returns:
        Tuple ``(der_coordinates, der_speeds)`` in torch's default floating
        dtype. Each has the same (batch, time, n) layout as the tensor it
        was derived from; every speed column is differentiated.

    Raises:
        ValueError: If there are fewer than 2 time steps.
    """
    nSample = int(batch_data["subject_data"][0, -1])
    durations = batch_data["subject_data"][:, 0]
    # NOTE(review): with t = linspace(0, duration, nSample) the sample spacing
    # is duration / (nSample - 1); confirm whether nSample or nSample - 1 is
    # intended here.
    fs = nSample / durations

    def derivative(tensor):
        """Time derivative of a (batch, time, n) tensor along dim 1."""
        if tensor.shape[1] < 2:
            raise ValueError("at least 2 time steps are needed to differentiate")
        differences = torch.diff(tensor, dim=1) * fs[:, None, None]
        # Repeat the first difference so the output has the input's length.
        der_tensor = torch.cat((differences[:, :1], differences), dim=1)
        return der_tensor.to(device=device, dtype=torch.get_default_dtype())

    # coordinates defining the motion and DoF
    coordinates = torch.cat(
        (batch_data["translation_data"], batch_data["angles"]), dim=-1
    )
    der_coordinates = derivative(coordinates)

    # Translational speeds come from the data; the remaining speeds are the
    # derivatives of coordinates 2 onwards (assumes 2 translational DoF --
    # TODO confirm against the data layout).
    speeds = torch.cat(
        (batch_data["qdot_translation_data"], der_coordinates[:, :, 2:]), dim=-1
    )
    der_speeds = derivative(speeds)

    return der_coordinates, der_speeds


def save_test_for_visualisation(batch_est, batch_true, indices, save_path):
    """Pickle a subset of estimated and true test batches for the notebook.

    Every tensor is detached, moved to CPU and converted to a NumPy array.
    All entries except "subject_data" and "body_constants" are first indexed
    with ``indices``; those two keys are stored whole. The input dicts are
    left unmodified.

    The result is written to
    ``<save_path>/jupyter_test_vis/vis_test_data.pkl`` as
    ``{"batch_est": ..., "batch_true": ...}``.

    Args:
        batch_est: Dict of estimated tensors; must contain "subject_data"
            and "body_constants".
        batch_true: Dict of ground-truth tensors with the same requirement.
        indices: Index (e.g. list of ints or index tensor) selecting which
            samples to keep from each indexed entry.
        save_path: Experiment directory; the sub-folder is created if needed.

    Returns:
        Path of the written pickle file.

    Raises:
        KeyError: If "subject_data" or "body_constants" is missing.
    """
    # NOTE(review): elsewhere in this module subject_data is indexed per
    # batch row (subject_data[:, 0]); confirm these two keys really should
    # not be subset with `indices`.
    unbatched_keys = ("subject_data", "body_constants")

    def subset_to_numpy(batch):
        """Return a new dict with every tensor detached, on CPU, as NumPy."""
        converted = {}
        for key, value in batch.items():
            if key not in unbatched_keys:
                value = value[indices]
            converted[key] = value.detach().cpu().numpy()
        missing = [key for key in unbatched_keys if key not in converted]
        if missing:
            raise KeyError(f"missing required keys: {missing}")
        return converted

    data_to_save = {
        "batch_est": subset_to_numpy(batch_est),
        "batch_true": subset_to_numpy(batch_true),
    }

    save_dir = os.path.join(save_path, "jupyter_test_vis/")
    os.makedirs(save_dir, exist_ok=True)
    file_path = os.path.join(save_dir, "vis_test_data.pkl")

    with open(file_path, "wb") as file:
        pickle.dump(data_to_save, file)

    return file_path


def load_saved_batch_data(iCycle, file_path="vis_test_data.pkl"):
    """Load the pickled visualisation data and select a single test cycle.

    Args:
        iCycle: Index selecting one entry along the first dimension of every
            stored array; negative indices count from the end.
        file_path: Pickle written by ``save_test_for_visualisation``. The
            default is relative to the current working directory, presumably
            the ``jupyter_test_vis`` folder the notebook is copied into.

    Returns:
        Tuple ``(batch_est, batch_true, t)``. Both dicts hold arrays sliced to
        length 1 along the first dimension. ``t`` holds ``nSample`` evenly
        spaced times from 0 to the selected duration (both read from
        ``batch_true["subject_data"]``).
    """
    # NOTE: pickle can execute arbitrary code while loading; only open files
    # produced by this project.
    with open(file_path, "rb") as file:
        batch_data = pickle.load(file)

    # `iCycle + 1 or None` keeps iCycle == -1 working ([-1:0] would be empty).
    stop = iCycle + 1 or None
    batch_est = {
        name: array[iCycle:stop] for name, array in batch_data["batch_est"].items()
    }
    batch_true = {
        name: array[iCycle:stop] for name, array in batch_data["batch_true"].items()
    }

    # NOTE(review): "subject_data" and "body_constants" are saved without the
    # `indices` subset, so their row iCycle may not correspond to the same
    # sample as the other arrays -- confirm.
    nSample = int(batch_true["subject_data"][0, -1])
    duration = batch_true["subject_data"][0, 0]
    t = np.linspace(0, duration, nSample)

    return batch_est, batch_true, t


def create_test_visualisation_jupyter(
    destination_dir, source_file="notebooks/test_visualisation.ipynb"
):
    """Copy the test-visualisation notebook into an experiment folder.

    Args:
        destination_dir: Experiment directory; the notebook is copied into
            its ``jupyter_test_vis/`` sub-folder, which is created if needed.
        source_file: Notebook to copy. The default is relative to the current
            working directory.

    Returns:
        Path of the copied notebook.
    """
    target_dir = os.path.join(destination_dir, "jupyter_test_vis/")
    # Without this, the copy fails unless the folder was already created
    # (e.g. by save_test_for_visualisation).
    os.makedirs(target_dir, exist_ok=True)
    return shutil.copy(source_file, target_dir)
