

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
import e3nn
from e3nn import o3



def filter_by_l(tensors, irreps, lmax):
    """Keep only the feature columns whose irrep degree ``l`` is at most ``lmax``.

    Args:
        tensors: tensor whose dim 1 holds the flattened irrep components,
            i.e. ``2*l+1`` consecutive columns for each entry of ``irreps.ls``.
        irreps: object exposing ``ls``, the degree of every irrep in order.
        lmax: largest degree to keep (inclusive).

    Returns:
        ``tensors`` restricted (along dim 1) to the columns with ``l <= lmax``.
    """
    # One degree label per feature column: each l occupies 2*l+1 columns.
    degree_per_column = torch.cat(
        [torch.full((2 * l + 1,), l, dtype=torch.long) for l in irreps.ls]
    )
    keep = degree_per_column <= lmax
    return tensors[:, keep]

def filter_by_channels(tensors, mul_rst, channels, ch_to_ind_encoding):
    """Select a subset of channels and pack them back into the flat layout.

    Args:
        tensors: flat feature tensor accepted by ``mul_rst.separate``.
        mul_rst: helper exposing ``separate`` (flat -> per-channel, indexed
            as ``[batch, channel, ...]``) and ``combine`` (the inverse).
        channels: channel names to keep, in the desired output order.
        ch_to_ind_encoding: mapping from channel name to its index along the
            channel axis of ``mul_rst.separate(tensors)``.

    Returns:
        ``(filtered_tensors, filtered_encoding)``. ``filtered_encoding`` maps
        each kept channel name to its new position in ``channels``.
    """
    per_channel = mul_rst.separate(tensors)
    picked = [per_channel[:, ch_to_ind_encoding[ch], :] for ch in channels]
    filtered_encoding = {ch: new_idx for new_idx, ch in enumerate(channels)}
    filtered_tensors = mul_rst.combine(torch.stack(picked, dim=1))
    return filtered_tensors, filtered_encoding


class NeighborhoodsDataset(Dataset):
    """Map-style dataset pairing neighborhood features with targets.

    ``x`` and ``y`` must agree on their first dimension; item ``i`` is
    ``(x[i], y[i])``.
    """

    def __init__(self, x: Tensor, y: Tensor):
        assert x.shape[0] == y.shape[0]
        self.x = x  # [N, dim]
        self.y = y  # [N, num_aa]

    def __len__(self):
        num_samples = self.y.shape[0]
        return num_samples

    def __getitem__(self, idx: int):
        features, target = self.x[idx], self.y[idx]
        return features, target

class WrongOneHotDimension(Exception):
    """Raised when a 2-D label tensor does not have exactly 10 one-hot columns."""

class MNISTDataset(Dataset):
    """Map-style dataset of MNIST features with one-hot digit labels.

    ``y`` may be given either as class indices of shape ``[N]`` or already
    one-hot encoded with shape ``[N, 10]``. Item ``i`` is
    ``(x[i], y_onehot[i])``.

    Raises:
        AssertionError: if ``x`` and ``y`` disagree on the number of samples.
        WrongOneHotDimension: if ``y`` is 2-D but its second axis is not 10.
    """

    def __init__(self, x: Tensor, y: Tensor):
        assert x.shape[0] == y.shape[0]
        self.x = x  # [N, dim]
        if len(y.shape) == 2:
            if y.shape[1] != 10:
                raise WrongOneHotDimension(
                    f"expected one-hot labels of shape [N, 10], got {tuple(y.shape)}"
                )
            self.y = y  # digits are already one-hot encoded; [N, 10]
        else:
            # F.one_hot only accepts int64 index tensors; integer labels stored
            # with a narrower dtype (e.g. int32/uint8 from numpy) are cast first.
            # Float labels are passed through unchanged and still rejected.
            labels = y if y.is_floating_point() else y.long()
            self.y = torch.nn.functional.one_hot(labels, 10)  # [N, 10]

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx: int):
        return self.x[idx], self.y[idx]

class MNISTDatasetWithConditioning(Dataset):
    """Map-style dataset of MNIST features, labels and a conditioning tensor.

    Labels are stored exactly as given: no one-hot encoding or shape check is
    applied. Item ``i`` is ``(x[i], y[i], c[i])``.
    """

    def __init__(self, x: Tensor, y: Tensor, c: Tensor):
        assert x.shape[0] == y.shape[0]
        # NOTE(review): c is not length-checked against x / y.
        self.x = x  # [N, dim]
        self.c = c  # [N, ANY]
        self.y = y

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx: int):
        return self.x[idx], self.y[idx], self.c[idx]

class NeighborhoodsDatasetWithConditioning(Dataset):
    """Map-style dataset of neighborhood features, labels and conditioning.

    ``x`` and ``y`` must agree on their first dimension; item ``i`` is
    ``(x[i], y[i], c[i])``.
    """

    def __init__(self, x: Tensor, y: Tensor, c: Tensor):
        assert x.shape[0] == y.shape[0]
        self.x = x  # [N, dim]
        self.y = y  # [N,]
        self.c = c  # [N, ANY]

    def __len__(self):
        num_samples = self.y.shape[0]
        return num_samples

    def __getitem__(self, idx: int):
        features, label, cond = self.x[idx], self.y[idx], self.c[idx]
        return features, label, cond

def make_dict(tensor, irreps, device):
    """Split a flat batch of irrep features into per-degree blocks.

    Args:
        tensor: tensor indexed as ``[batch, feature]`` whose feature axis holds
            ``2*l+1`` consecutive columns for each entry of ``irreps.ls``.
        irreps: object exposing ``ls``, the degree of every irrep in order.
        device: target device passed to ``Tensor.to``.

    Returns:
        dict mapping each distinct degree ``l`` (inserted in ascending order)
        to a tensor of shape ``[batch, count_l, 2*l+1]`` on ``device``, where
        ``count_l`` is the number of irreps of degree ``l``.
    """
    n = tensor.shape[0]
    # Degree label for every feature column.
    degree_per_column = torch.cat(
        [torch.full((2 * l + 1,), l, dtype=torch.long) for l in irreps.ls]
    )
    return {
        l: tensor[:, degree_per_column == l].reshape(n, -1, 2 * l + 1).to(device)
        for l in sorted(set(irreps.ls))
    }

class MNISTDatasetWithConditioning__fibers(Dataset):
    """MNIST dataset that also returns each sample split into per-degree fibers.

    Item ``i`` is ``(fibers, x[i], y[i], c[i])``. ``fibers`` maps each
    distinct degree ``l`` of ``irreps.ls`` (ascending) to the entries of
    ``x[i]`` belonging to that degree, viewed as ``[-1, 2*l+1]``.
    """

    def __init__(self, x: Tensor, irreps: o3.Irreps, y: Tensor, c: Tensor):
        assert x.shape[0] == y.shape[0]
        self.x = x  # [N, dim]
        self.c = c  # [N, ANY]
        self.y = y

        # Degree label for every feature column (each l spans 2*l+1 columns).
        self.ls_indices = torch.cat(
            [torch.full((2 * l + 1,), l, dtype=torch.long) for l in irreps.ls]
        )
        self.unique_ls = sorted(set(irreps.ls))

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx: int):
        row = self.x[idx]
        x_fiber = {
            l: row[self.ls_indices == l].view(-1, 2 * l + 1)
            for l in self.unique_ls
        }
        return x_fiber, row, self.y[idx], self.c[idx]
    
class Shrec17Dataset__fibers(Dataset):
    """SHREC17 dataset that also returns each sample split into per-degree fibers.

    Item ``i`` is ``(fibers, x[i], y[i], ids[i])``. ``fibers`` maps each
    distinct degree ``l`` of ``irreps.ls`` (ascending) to the entries of
    ``x[i]`` belonging to that degree, viewed as ``[-1, 2*l+1]``.
    """

    def __init__(self, x: Tensor, irreps: o3.Irreps, y: Tensor, ids: Tensor):
        assert x.shape[0] == y.shape[0]
        self.x = x  # [N, dim]
        self.y = y
        self.ids = ids

        # Degree label for every feature column (each l spans 2*l+1 columns).
        self.ls_indices = torch.cat(
            [torch.full((2 * l + 1,), l, dtype=torch.long) for l in irreps.ls]
        )
        self.unique_ls = sorted(set(irreps.ls))

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx: int):
        row = self.x[idx]
        x_fiber = {
            l: row[self.ls_indices == l].view(-1, 2 * l + 1)
            for l in self.unique_ls
        }
        return x_fiber, row, self.y[idx], self.ids[idx]

class NeighborhoodsDatasetWithConditioning__fibers(Dataset):
    """Neighborhood dataset that also returns each sample split into fibers.

    Item ``i`` is ``(fibers, x[i], y[i], c[i])``. ``fibers`` maps each
    distinct degree ``l`` of ``irreps.ls`` (ascending) to the entries of
    ``x[i]`` belonging to that degree, viewed as ``[-1, 2*l+1]``.
    """

    def __init__(self, x: Tensor, irreps: o3.Irreps, y: Tensor, c: Tensor):
        self.x = x  # [N, dim]
        self.y = y  # [N,]
        self.c = c  # [N, ANY]
        assert x.shape[0] == y.shape[0]

        # Degree label for every feature column (each l spans 2*l+1 columns).
        self.ls_indices = torch.cat(
            [torch.full((2 * l + 1,), l, dtype=torch.long) for l in irreps.ls]
        )
        self.unique_ls = sorted(set(irreps.ls))

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx: int):
        row = self.x[idx]
        x_fiber = {
            l: row[self.ls_indices == l].view(-1, 2 * l + 1)
            for l in self.unique_ls
        }
        return x_fiber, row, self.y[idx], self.c[idx]


class NeighborhoodsDatasetWithConditioningAndIds__fibers(Dataset):
    """Neighborhood dataset returning per-degree fibers, conditioning and ids.

    Item ``i`` is ``(fibers, x[i], y[i], c[i], ids[i])``. ``fibers`` maps each
    distinct degree ``l`` of ``irreps.ls`` (ascending) to the entries of
    ``x[i]`` belonging to that degree, viewed as ``[-1, 2*l+1]``.
    """

    # `ids` is annotated as np.ndarray: `np.array` is a factory function, not
    # a type, so the previous annotation was rejected by type checkers.
    def __init__(self, x: Tensor, irreps: o3.Irreps, y: Tensor, c: Tensor, ids: np.ndarray):
        self.x = x  # [N, dim]
        self.y = y  # [N,]
        self.c = c  # [N, ANY]
        self.ids = ids  # [N,] (in string form)
        assert x.shape[0] == y.shape[0]

        # Degree label for every feature column (each l spans 2*l+1 columns).
        self.ls_indices = torch.cat([torch.tensor([l]).repeat(2*l+1) for l in irreps.ls])
        self.unique_ls = sorted(set(irreps.ls))

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx: int):
        # Index the sample once instead of once per degree.
        row = self.x[idx]
        x_fiber = {}
        for l in self.unique_ls:
            x_fiber[l] = row[self.ls_indices == l].view(-1, 2*l+1)

        return x_fiber, row, self.y[idx], self.c[idx], self.ids[idx]