import numpy as np
import torch

def feature_mixer(feature_maps, fold_tensors = True, use_cuda = True, final_type = np.float16):
    """Build one ``LazyFeatureMixer`` per position across several feature maps.

    Args:
        feature_maps: Indexable collection of feature maps. Every feature map
            must have the same length, and the entries found at the same
            position of each feature map must have equal ``len()``.
        fold_tensors: Forwarded to ``LazyFeatureMixer`` as ``fold_tensor``.
        use_cuda: Forwarded to ``LazyFeatureMixer`` as ``use_cuda``.
        final_type: Forwarded to ``LazyFeatureMixer`` as ``final_type``.

    Returns:
        A list with one ``LazyFeatureMixer`` per position, each wrapping the
        entries taken from that position of every feature map.

    Raises:
        AssertionError: If the feature maps, or entries at a shared position,
            differ in length.
    """
    # Every feature map must hold the same number of entries.
    outer_sizes = list(map(len, feature_maps))
    assert all(size == outer_sizes[0] for size in outer_sizes)

    # Entries that share a position must agree in length across feature maps.
    for position in range(len(feature_maps[0])):
        entry_sizes = [len(fmap[position]) for fmap in feature_maps]
        reference = entry_sizes[0]
        assert all(size == reference for size in entry_sizes)

    mixers = []
    for lazy_arrs in zip(*feature_maps):
        mixers.append(
            LazyFeatureMixer(lazy_arrs, fold_tensor=fold_tensors, use_cuda=use_cuda, final_type=final_type)
        )
    return mixers

def feature_mixer2(feature_maps, final_type = np.float16):
    """Build one ``LazyFeatureMixer2`` per position across several feature maps.

    Args:
        feature_maps: Indexable collection of feature maps. Every feature map
            must have the same length, and the entries found at the same
            position of each feature map must have equal ``len()``.
        final_type: Forwarded to ``LazyFeatureMixer2`` as ``final_type``.

    Returns:
        A list with one ``LazyFeatureMixer2`` per position, each wrapping the
        entries taken from that position of every feature map.

    Raises:
        AssertionError: If the feature maps, or entries at a shared position,
            differ in length.
    """
    # Check that all feature maps contain the same number of entries.
    map_sizes = [len(single_map) for single_map in feature_maps]
    assert all(count == map_sizes[0] for count in map_sizes)

    # Check, position by position, that the entries have matching lengths.
    entries_per_map = len(feature_maps[0])
    slot = 0
    while slot < entries_per_map:
        slot_sizes = [len(single_map[slot]) for single_map in feature_maps]
        assert all(count == slot_sizes[0] for count in slot_sizes)
        slot += 1

    grouped = zip(*feature_maps)
    return [LazyFeatureMixer2(lazy_arrs, final_type=final_type) for lazy_arrs in grouped]
    
class LazyFeatureMixer():
    """Lazily concatenate several 3-D arrays along axis 1.

    The wrapped arrays are only read when the mixer is indexed. All arrays must
    agree on axis 0 (stored as ``n_samples``) and axis 2 (stored as ``delays``);
    their axis-1 sizes are summed into ``feature_size``.
    """

    def __init__(self, lazy_arrays, fold_tensor = True, use_cuda = True, final_type = np.float16):
        """Record the layout of ``lazy_arrays`` without loading any data.

        Args:
            lazy_arrays: Non-empty sequence of objects exposing a 3-element
                ``.shape`` and supporting indexing with a collection of indices.
            fold_tensor: If true, results are flattened to
                ``(rows, feature_size * delays)``; otherwise they stay 3-D.
            use_cuda: If true, ``__getitem__`` assembles the result on the
                ``"cuda"`` device through torch before copying it back.
            final_type: NumPy dtype the returned array is cast to.

        Raises:
            AssertionError: If ``lazy_arrays`` is empty or the arrays disagree
                on axis 0 or axis 2.
        """
        self.fold_tensor = fold_tensor
        self.lazy_arrays = lazy_arrays
        self.lazy_arr_shapes = [x.shape for x in lazy_arrays]
        assert self.lazy_arr_shapes, "lazy_arrays must not be empty"
        first_shape = self.lazy_arr_shapes[0]
        # NOTE: these checks must use all(); asserting the bare list is always
        # true for a non-empty list and would never catch a mismatch.
        assert all(shape[0] == first_shape[0] for shape in self.lazy_arr_shapes), \
            "all arrays must have the same size along axis 0"
        self.n_samples = first_shape[0]
        assert all(shape[2] == first_shape[2] for shape in self.lazy_arr_shapes), \
            "all arrays must have the same size along axis 2"
        self.delays = first_shape[2]
        self.feature_lens = [shape[1] for shape in self.lazy_arr_shapes]
        self.feature_size = sum(self.feature_lens)
        # Array i occupies columns [start_i, end_i) of the concatenated axis 1.
        self.feature_start_indices = np.cumsum(np.array([0] + self.feature_lens[:-1], dtype=int))
        self.feature_end_indices = np.cumsum(np.array(self.feature_lens, dtype = int))
        if self.fold_tensor:
            self.shape = (self.n_samples, self.feature_size*self.delays)
        else:
            self.shape = (self.n_samples, self.feature_size, self.delays)
        self.final_type = final_type
        self.use_cuda = use_cuda

    def __getitem__(self, idxs):
        """Load and concatenate the rows selected by ``idxs``.

        Args:
            idxs: Sized collection of row indices (must support ``len()`` and
                be accepted as an index by every wrapped array).

        Returns:
            A NumPy array of dtype ``final_type`` with ``len(idxs)`` rows,
            shaped ``(rows, feature_size * delays)`` when ``fold_tensor`` is
            set and ``(rows, feature_size, delays)`` otherwise.
        """
        n_rows = len(idxs)
        spans = zip(self.lazy_arrays, self.feature_start_indices, self.feature_end_indices)
        if self.use_cuda:
            out_features = torch.empty((n_rows, self.feature_size, self.delays), device = "cuda")
            for lazy_array, start, end in spans:
                out_features[:, start:end, :] = torch.from_numpy(lazy_array[idxs]).to("cuda", non_blocking=True)
            if self.fold_tensor:
                # Explicit sizes (rather than -1) keep this valid for zero rows.
                out_features = out_features.reshape(n_rows, self.feature_size * self.delays)
            out_features = out_features.cpu().numpy()
        else:
            # The buffer is sized by the selection, not by the full data set,
            # so that indexing a subset of rows works.
            out_features = np.empty((n_rows, self.feature_size, self.delays))
            for lazy_array, start, end in spans:
                out_features[:, start:end, :] = lazy_array[idxs]
            if self.fold_tensor:
                out_features = out_features.reshape(n_rows, self.feature_size * self.delays)
        return out_features.astype(self.final_type)

    def force_load_array(self):
        """Return the concatenation of every row, i.e. ``self[all indices]``."""
        idxs = list(range(self.n_samples))
        return self.__getitem__(idxs)

    def __len__(self):
        """Return the number of rows (size of axis 0 of the wrapped arrays)."""
        return self.n_samples
    
class LazyFeatureMixer2():
    """Lazily concatenate several 3-D arrays after flattening axes 1 and 2.

    The wrapped arrays must agree on axis 0 (stored as ``n_samples``) but may
    differ in their axis-1 and axis-2 sizes. Each array contributes
    ``shape[1] * shape[2]`` columns to the 2-D result.
    """

    def __init__(self, lazy_arrays, use_cuda = True, final_type = np.float16):
        """Record the layout of ``lazy_arrays`` without loading any data.

        Args:
            lazy_arrays: Non-empty sequence of objects exposing a 3-element
                ``.shape`` and supporting indexing with a collection of indices.
            use_cuda: Accepted but not used by this class; all work is done
                with NumPy.
            final_type: NumPy dtype the returned array is cast to.

        Raises:
            AssertionError: If ``lazy_arrays`` is empty or the arrays disagree
                on axis 0.
        """
        self.lazy_arrays = lazy_arrays
        self.lazy_arr_shapes = [x.shape for x in lazy_arrays]
        assert self.lazy_arr_shapes, "lazy_arrays must not be empty"
        first_shape = self.lazy_arr_shapes[0]
        # NOTE: this check must use all(); asserting the bare list is always
        # true for a non-empty list and would never catch a mismatch.
        assert all(shape[0] == first_shape[0] for shape in self.lazy_arr_shapes), \
            "all arrays must have the same size along axis 0"
        self.n_samples = first_shape[0]
        self.delays = [shape[2] for shape in self.lazy_arr_shapes]
        self.feature_sizes = [shape[1] for shape in self.lazy_arr_shapes]
        self.mixed_sizes = [shape[1]*shape[2] for shape in self.lazy_arr_shapes]
        # Array i occupies columns [start_i, end_i) of the flattened output.
        self.feature_start_indices = np.cumsum([0] + self.mixed_sizes[:-1])
        self.feature_end_indices = np.cumsum(self.mixed_sizes)
        self.feature_size = sum(self.mixed_sizes)
        self.const_delay_feature_size = sum(self.feature_sizes)
        self.shape = (self.n_samples, self.feature_size)
        self.final_type = final_type

    def __getitem__(self, idxs):
        """Load, flatten and concatenate the rows selected by ``idxs``.

        Args:
            idxs: Sized collection of row indices (must support ``len()`` and
                be accepted as an index by every wrapped array).

        Returns:
            A NumPy array of shape ``(len(idxs), feature_size)`` and dtype
            ``final_type``.
        """
        n_rows = len(idxs)
        # Every column is overwritten below, so no initial fill is needed.
        out_features = np.empty((n_rows, self.feature_size))
        spans = zip(self.lazy_arrays, self.mixed_sizes, self.feature_start_indices, self.feature_end_indices)
        for lazy_array, width, start, end in spans:
            to_add = lazy_array[idxs]
            # Explicit width (rather than -1) keeps this valid for zero rows.
            out_features[:, start:end] = to_add.reshape(n_rows, width)
        return out_features.astype(self.final_type)

    def force_load_array(self):
        """Return the concatenation of every row, i.e. ``self[all indices]``."""
        idxs = list(range(self.n_samples))
        return self.__getitem__(idxs)

    def __len__(self):
        """Return the number of rows (size of axis 0 of the wrapped arrays)."""
        return self.n_samples