from typing import List
import torch
import numpy as np

def linear_interpolation_1d(origin_signal: np.ndarray, origin_sr, final_sr):
    '''
    Resample a 1-d signal from `origin_sr` to `final_sr` by linear interpolation.

    The output covers the same duration as the input (`len(origin_signal) / origin_sr`),
    truncated to a whole number of samples at `final_sr`. Output sample `i` sits at
    time `i / final_sr`:
    - If it coincides with an input sample, that value is copied.
    - Otherwise it is linearly interpolated between the two neighbouring input samples.
    - Points past the last input sample take the last input value.

    Args:
        origin_signal: 1-d array of samples recorded at `origin_sr`.
        origin_sr: sample rate of `origin_signal` (expected to be positive).
        final_sr: target sample rate (expected to be positive).

    Returns:
        float32 array of length `floor(len(origin_signal) * final_sr / origin_sr)`.

    Raises:
        AssertionError: if `origin_signal` is not 1-d.
    '''
    assert len(origin_signal.shape) == 1, "The signal is not 1-d"
    n_point_origin = origin_signal.shape[0]
    # Compute the length as n * final_sr / origin_sr in a single division.
    # Going through the duration (n / origin_sr) first loses precision:
    # e.g. 29 samples at 100 Hz -> 100 Hz would give int(0.29 * 100) == 28 points.
    n_point_final = int(n_point_origin * final_sr / origin_sr)
    if n_point_final <= 0:
        # Empty input (or a target rate too low for a single point).
        return np.zeros(0, np.float32)

    # Position of every output sample, measured in input-sample indices.
    # i * origin_sr / final_sr is exact whenever the true value is an integer,
    # so points that overlap on the timeline reproduce the input sample exactly.
    positions = np.arange(n_point_final) * origin_sr / final_sr
    # np.interp interpolates linearly between neighbours and returns fp[-1] for
    # positions beyond the last input index (the "no right point" special case).
    final_signal = np.interp(positions, np.arange(n_point_origin), origin_signal)
    return final_signal.astype(np.float32)

def _pad_or_truncate_rows(data, n_rows):
    '''
    Fit `data` to exactly `n_rows` rows along axis 0.

    Missing rows are zero-filled; surplus rows are dropped.

    Returns:
        (fitted, mask):
        - `fitted` is a float32 array of shape `(n_rows,) + data.shape[1:]`.
        - `mask` is a bool array of length `n_rows`: True (invalid) on padded
          rows, False (valid) on rows copied from `data`.
    '''
    n_valid = min(data.shape[0], n_rows)
    fitted = np.zeros((n_rows,) + tuple(data.shape[1:]), dtype=np.float32)
    fitted[:n_valid] = data[:n_valid]
    mask = np.ones(n_rows, dtype=bool)
    mask[:n_valid] = False
    return fitted, mask


def align_time_dimension(visual_info, eeg_info, canbus_info, n_sample_per_frame):
    '''
    Align visual, EEG and CAN-bus data on a common time axis.

    Each visual frame corresponds to `n_sample_per_frame` human-data samples.
    The EEG data decides how many frames it can cover: a frame is kept if it
    holds at least one EEG sample, i.e. ceil(n_eeg_sample / n_sample_per_frame).
    The final frame count is the smaller of that and the number of visual frames.
    EEG and CAN-bus data are then zero-padded or truncated to exactly
    `final_n_frame * n_sample_per_frame` samples.

    Args:
        visual_info: array whose first axis is frames, e.g. (n_frame, C, H, W).
        eeg_info: (n_sample, n_eeg_channel) array.
        canbus_info: (n_canbus_sample, n_canbus_channel) array, on the same
            per-sample grid as the EEG data.
        n_sample_per_frame: integer (Python or numpy) samples per visual frame.

    Returns:
        (final_visual_info, final_eeg_info, final_eeg_mask,
         final_canbus_info, final_canbus_mask)
        - Data arrays are float32.
        - Each mask has one entry per sample: True = padded/invalid, False = valid.

    Raises:
        AssertionError: if `n_sample_per_frame` is not an integer.
    '''
    n_frame = visual_info.shape[0]
    n_sample = eeg_info.shape[0]

    # ceil(n_sample / n_sample_per_frame): keep a partially-filled last frame.
    n_frame_from_eeg = (n_sample + n_sample_per_frame - 1) // n_sample_per_frame
    # numpy integer frame sizes are accepted; floats are rejected.
    assert isinstance(n_frame_from_eeg, (int, np.integer)), "type of final_n_frame is not int"

    # Cut whichever side is longer: the visual frames or the human data.
    final_n_frame = min(n_frame, int(n_frame_from_eeg))
    final_n_sample = final_n_frame * int(n_sample_per_frame)

    # (final_n_frame, C, H, W)
    final_visual_info = visual_info[:final_n_frame]
    # (final_n_sample, n_eeg_channel) plus a per-sample validity mask
    final_eeg_info, final_eeg_mask = _pad_or_truncate_rows(eeg_info, final_n_sample)
    # (final_n_sample, n_canbus_channel) plus a per-sample validity mask
    final_canbus_info, final_canbus_mask = _pad_or_truncate_rows(canbus_info, final_n_sample)

    return final_visual_info, final_eeg_info, final_eeg_mask, final_canbus_info, final_canbus_mask

def align_batch_n_sample_dimension(info_list: List[torch.Tensor], info_mask_list: List[torch.Tensor]):
    """
    Pad a batch of per-sample label tensors and their masks to a common length.

    Each element of `info_list` is (n_sample_i, n_label) and each element of
    `info_mask_list` is (n_sample_i,). Every item is:
    - zero-padded to the longest n_sample in the batch and cast to float32;
    - given a mask that keeps its own values and is True (masked) everywhere
      past the end of the given mask.
    A True mask value means the position should be ignored during training/testing.

    Returns:
        (tensor, mask) with shapes (n_batch, n_sample_max, n_label) and
        (n_batch, n_sample_max).
    """
    longest = np.max([info.shape[0] for info in info_list])
    width = info_list[0].shape[1]

    def _pad_one(info, info_mask):
        # Zero rows past the end of this item's data.
        padded = torch.zeros((longest, width), dtype=torch.float32)
        padded[:info.shape[0], :] = info
        # Anything not covered by the provided mask defaults to "masked".
        mask = torch.ones(longest, dtype=torch.bool)
        mask[:info_mask.shape[0]] = info_mask
        return padded, mask

    pairs = [_pad_one(info, info_mask) for info, info_mask in zip(info_list, info_mask_list)]
    batch_tensor = torch.stack([padded for padded, _ in pairs])
    batch_mask = torch.stack([mask for _, mask in pairs])
    return batch_tensor, batch_mask

def align_batch_n_frame_dimension(val_list: List[torch.Tensor]):
    """
    Zero-pad a batch of frame sequences to the longest sequence and stack them.

    Each element of `val_list` is (n_frame_i, C, H, W). C, H and W are taken
    from the first element. Each padded tensor keeps the dtype of its own input.

    Returns:
        (tensor, mask):
        - tensor is (batch_size, n_frame_max, C, H, W).
        - mask is (batch_size, n_frame_max), False on real frames and True on
          padded ones.
    """
    n_frame_max = np.max([val.shape[0] for val in val_list])
    _, C, H, W = val_list[0].shape

    def _pad(val):
        n_frame = val.shape[0]
        padded = torch.zeros((n_frame_max, C, H, W), dtype=val.dtype)
        mask = torch.ones((n_frame_max), dtype=torch.bool)
        padded[:n_frame, :, :, :] = val
        mask[:n_frame] = False
        return padded, mask

    pairs = [_pad(val) for val in val_list]
    batch_tensor = torch.stack([padded for padded, _ in pairs], dim=0)
    batch_mask = torch.stack([mask for _, mask in pairs], dim=0)
    return batch_tensor, batch_mask