import math
from functools import partial

import numpy as np
import torch
from torch import LongTensor, Tensor
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from e3nn.o3 import matrix_x, matrix_y, matrix_z
from e3nn.o3 import rand_matrix, axis_angle_to_matrix

from utils.seed import fix_seed



class KFoldDataset(Dataset):
    """Rotated copies of a planar k-fold "star" graph.

    The base graph (see ``get_base_fold``) is a centre node at the origin
    joined to ``fold`` spoke nodes evenly spaced on the unit circle in the
    xy-plane. Each sample applies one rotation from ``get_rot`` to the node
    positions; node features and edges are shared by every sample.

    Samples are ordered axis-major: index ``i`` uses random axis
    ``i // num_res`` and angle step ``i % num_res``.
    """

    def __init__(
        self,
        fold: int = 3,
        num_normal: int = 6,
        num_res: int = 36,
        seed: int = 42,
        verbose: bool = True,
    ) -> None:
        """Build the base graph and the full table of rotations.

        Args:
            fold: Number of spokes k (rotational order of the star); must be >= 1.
            num_normal: Number of random rotation axes to sample; must be >= 1.
            num_res: Number of evenly spaced spin angles per axis over one
                symmetry period ``[0, 2*pi/fold)``; must be >= 1.
            seed: Value passed to ``fix_seed`` before any random draws.
            verbose: If True (the default, matching earlier behaviour), print
                the dataset length and the first sample after construction.

        Raises:
            ValueError: If ``fold``, ``num_normal`` or ``num_res`` is below 1.
                ``fold == 0`` would otherwise divide by zero in ``get_rot``, and
                an empty dataset would crash on the sample print below.
        """
        # Validate first so bad arguments fail fast, before fix_seed runs.
        for name, value in (("fold", fold), ("num_normal", num_normal), ("num_res", num_res)):
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
        fix_seed(seed)
        self.fold = fold
        self.node_feat_base, self.node_pos_base, self.edge_index_base = self.get_base_fold(fold)
        # get_rot reads self.fold, so it must run after the assignment above.
        self.rot, self.normal_idx, self.angle = self.get_rot(num_normal, num_res)
        if verbose:
            print(f'Total length: {len(self)}')
            print(self[0])

    def __len__(self):
        """Return the number of samples (one per precomputed rotation)."""
        return self.rot.size(0)

    def __getitem__(self, i):
        """Return sample ``i`` as a ``Data`` object.

        ``node_pos`` is the base positions rotated by ``self.rot[i]``.
        ``node_feat`` and ``edge_index`` are the shared base tensors; they are
        not copied.
        """
        return Data(
            node_feat=self.node_feat_base,
            node_pos=torch.einsum('ij,nj->ni', self.rot[i], self.node_pos_base),
            edge_index=self.edge_index_base,
            normal_idx=self.normal_idx[i],
            angle=self.angle[i],
        )

    def get_base_fold(self, k):
        """Return ``(node_feat, node_pos, edge_index)`` for a k-spoke star graph.

        Node 0 sits at the origin. Nodes ``1..k`` sit on the unit circle in the
        xy-plane at angles ``2*pi*i/k`` for ``i = 0..k-1``.

        Returns:
            node_feat: ``(k + 1, 1)`` tensor of zeros (all nodes look alike).
            node_pos: ``(k + 1, 3)`` tensor of node coordinates.
            edge_index: ``(2, k)`` long tensor. Row 0 is the centre node and
                row 1 the spoke node. No reverse edges are included.
        """
        node_feat = torch.zeros(k + 1, 1)
        # Angles are formed in Python floats, then stored as a tensor.
        theta = torch.tensor([2 * math.pi / k * i for i in range(k)])
        # Rotating the x unit vector about z by theta gives (cos, sin, 0).
        ring = torch.stack([theta.cos(), theta.sin(), torch.zeros(k)], dim=-1)
        node_pos = torch.cat([torch.zeros(1, 3), ring], dim=0)
        edge_index = torch.stack([
            torch.zeros(k, dtype=torch.long),
            torch.arange(1, k + 1, dtype=torch.long),
        ])
        return node_feat, node_pos, edge_index

    def get_rot(self, num_normal, num_res):
        """Build every rotation used by the dataset.

        For each of ``num_normal`` random matrices ``R0``, the axis ``R0 @ z`` is
        formed. ``num_res`` spin angles are evenly spaced over one symmetry
        period ``[0, 2*pi/self.fold)``. Each rotation is ``R_axis @ R0``: first
        place the graph with ``R0``, then spin it about that axis. For a proper
        rotation ``R0`` this equals ``R0 @ Rz(angle)``.

        Returns:
            rot: ``(num_normal * num_res, 3, 3)`` rotation matrices.
            normal_idx: ``(num_normal * num_res,)`` index of the random axis used.
            angle: ``(num_normal * num_res,)`` spin angle as a fraction of the
                symmetry period, in ``[0, 1)``.
        """
        R0 = rand_matrix(num_normal)                          # (N, 3, 3)
        z = torch.tensor([0., 0., 1.])                        # (3,)
        axes = R0 @ z                                         # (N, 3)
        # rand_matrix yields rotations, so norms are ~1; renormalise for rounding.
        axes = axes / axes.norm(dim=-1, keepdim=True)
        period = 2 * np.pi / self.fold
        angles = torch.tensor(
            np.linspace(0, period, num_res, endpoint=False),
            dtype=torch.get_default_dtype(),
        )
        # Axis-major layout: all angles for axis 0, then axis 1, and so on.
        axes_expand = axes.repeat_interleave(num_res, dim=0)
        angles_expand = angles.repeat(num_normal)
        R_axis = axis_angle_to_matrix(axes_expand, angles_expand)
        R_total = R_axis @ R0.repeat_interleave(num_res, dim=0)
        normal_idx = torch.arange(num_normal).repeat_interleave(num_res, dim=0)
        return R_total, normal_idx, angles_expand / period
        

if __name__ == '__main__':
    # Smoke test: build the dataset and a TFN model, then iterate all batches.
    kfold_dataset = KFoldDataset(fold=3, num_normal=6, num_res=36)
    make_loader = partial(DataLoader, batch_size=100, drop_last=False, num_workers=4)
    loader = make_loader(dataset=kfold_dataset, shuffle=False)

    from models import TFNModel
    import e3nn

    max_ell = 2
    model_name = 'tfn'
    output_irreps = e3nn.o3.Irreps('2e')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = {
        "tfn": partial(TFNModel, max_ell=max_ell),
    }[model_name](node_input_dim=1, output_irreps=output_irreps)
    model.to(device)
    model.eval()

    for batch in loader:
        # Rebind the result instead of relying on .to() mutating in place.
        batch = batch.to(device)
        # NOTE(review): the model is constructed but never called here; add the
        # forward pass once TFNModel's expected batch fields are confirmed.
        print(batch.normal_idx, batch.angle)
