import os
import random
import numpy as np
from scipy.sparse import coo_matrix
import torch
from pytorch3d import transforms
from torch.utils.data import Dataset

from MDAnalysisData import datasets
import MDAnalysis as mda
from MDAnalysis import transformations
from MDAnalysis.analysis import distances
from dataset.mol_protein_graph import atom_mapping, one_of_k_encoding_unk_indices



class MDAnalysisDataset(Dataset):
    """
    Frame-pair dataset built on an MDAnalysisData trajectory.

    Only the ``'adk'`` dataset name is handled by the non-cached path. Each
    item is a pair of atom-position tensors taken ``delta_frame`` frames
    apart, returned together with the per-atom feature matrix, a constant
    ``0`` and an all-ones node mask.

    Two modes are supported:

    * ``load_cached=False``: frames are read from the MDAnalysis ``Universe``.
    * ``load_cached=True``: frames are read with ``torch.load`` from
      ``<tmp_dir>/adk_processed`` (or ``<tmp_dir>/adk_backbone_processed``
      when ``backbone=True``), one ``<dataset_name>_<frame>.pkl`` per frame.

    Args:
        dataset_name: Name of the dataset; also the prefix of the cached files.
        partition: ``'train'``, ``'valid'`` or ``'test'``.
        tmp_dir: Data directory. Required when ``load_cached=True``.
        delta_frame: Frame gap between the two positions of a sample.
        train_valid_test_ratio: Three fractions summing to at most 1;
            defaults to ``[0.6, 0.2, 0.2]``.
        test_rot: Apply a random rotation to samples of the test partition.
        test_trans: Apply a random translation to samples of the test partition.
        load_cached: Read pre-processed frames instead of the raw trajectory.
        cut_off: Stored on the instance; not used inside this class.
        backbone: Use the backbone cache directory (cached mode only).

    Raises:
        ValueError: If the ratios sum to more than 1, or ``load_cached`` is
            set without ``tmp_dir``.
        NotImplementedError: If ``backbone`` is requested without
            ``load_cached``, or the dataset name is not ``'adk'``.
    """

    def __init__(self, dataset_name, partition='train', tmp_dir=None, delta_frame=1, train_valid_test_ratio=None,
                 test_rot=False, test_trans=False, load_cached=False, cut_off=6, backbone=False):
        super().__init__()
        self.delta_frame = delta_frame
        self.dataset = dataset_name
        self.partition = partition
        self.load_cached = load_cached
        self.test_rot = test_rot
        self.test_trans = test_trans
        self.cut_off = cut_off
        self.backbone = backbone
        if load_cached:
            if tmp_dir is None:
                raise ValueError('tmp_dir must be provided when load_cached=True.')
            print(f'Loading {dataset_name} from cached data for {partition}...')
            tmp_dir = os.path.join(tmp_dir, 'adk_backbone_processed' if backbone else 'adk_processed')
        self.tmp_dir = tmp_dir
        if train_valid_test_ratio is None:
            train_valid_test_ratio = [0.6, 0.2, 0.2]
        # Explicit raise (not assert) so the check survives `python -O`; the small
        # tolerance absorbs floating-point rounding when summing the fractions.
        if sum(train_valid_test_ratio) > 1 + 1e-8:
            raise ValueError(f'train_valid_test_ratio must sum to at most 1, got {train_valid_test_ratio}.')

        if load_cached:
            # NOTE: torch.load unpickles the file; only point tmp_dir at trusted caches.
            edges, self.edge_attr, self.charges, self.n_frames = torch.load(os.path.join(tmp_dir,
                                                                                         f'{dataset_name}.pkl'))
            self.edges = torch.stack(edges, dim=0)
            self.train_valid_test = self._split_bounds(self.n_frames - delta_frame, train_valid_test_ratio)

            adk = datasets.fetch_adk_equilibrium(data_home=tmp_dir)
            data = mda.Universe(adk.topology, adk.trajectory)
            # NOTE(review): backbone atoms are selected here even when backbone=False;
            # confirm this matches the atoms stored in the 'adk_processed' cache.
            self.node_feats = self._encode_atom_names(data.select_atoms('backbone').names)
            return

        if self.backbone:
            raise NotImplementedError("Use load_cached for backbone case.")
        if dataset_name.lower() == 'adk':
            adk = datasets.fetch_adk_equilibrium(data_home=tmp_dir)
            self.data = mda.Universe(adk.topology, adk.trajectory)
        else:
            raise NotImplementedError(f'{dataset_name} is not available in MDAnalysisData.')

        # Local graph information. Each attribute is best-effort: on OSError a
        # message is printed and the attribute is left unset.
        try:
            self.charges = torch.tensor(self.data.atoms.charges)
        except OSError:
            print('Charge error')
        try:
            bond_indices = self.data.bonds.indices
            self.edges = torch.stack([torch.tensor(bond_indices[:, 0], dtype=torch.long),
                                      torch.tensor(bond_indices[:, 1], dtype=torch.long)], dim=0)
        except OSError:
            print('edges error')
        try:
            self.edge_attr = torch.tensor([bond.length() for bond in self.data.bonds])
        except OSError:
            print('edge_attr error')
        self.node_feats = self._encode_atom_names(self.data.atoms.names)
        self.train_valid_test = self._split_bounds(len(self.data.trajectory) - delta_frame, train_valid_test_ratio)

    @staticmethod
    def _split_bounds(n_pairs, ratio):
        """Return ``[train_end, valid_end]`` start-frame indices for ``n_pairs`` frame pairs."""
        return [int(ratio[0] * n_pairs), int(sum(ratio[:2]) * n_pairs)]

    @staticmethod
    def _encode_atom_names(atom_names):
        """Build the float node-feature tensor from atom names, keyed on each name's first character."""
        atom_encoder = {value: key for key, value in atom_mapping.items()}
        elements = [atom_encoder[atom_name[0]] for atom_name in atom_names]
        return torch.FloatTensor([one_of_k_encoding_unk_indices(e - 1, atom_mapping) for e in elements])

    def _pack_sample(self, pos_curr, pos_next):
        """Assemble the 5-tuple returned by ``__getitem__`` from the two position tensors."""
        pro_mol_cutoff = 0
        node_mask = torch.ones_like(pos_curr[:, 0]).unsqueeze(-1)
        return pos_curr, pos_next, self.node_feats, pro_mol_cutoff, node_mask

    def __getitem__(self, i):
        """
        Return the sample whose first frame is the ``i``-th frame of the partition.

        Returns:
            ``(pos_curr, pos_next, node_feats, pro_mol_cutoff, node_mask)`` where
            ``pro_mol_cutoff`` is always ``0`` and ``node_mask`` is a column of
            ones with one row per position row.

        Raises:
            NotImplementedError: If ``backbone`` is set without ``load_cached``.
        """
        # Shift the partition-local index to an absolute start frame.
        if self.partition == "valid":
            i = i + self.train_valid_test[0]
        elif self.partition == "test":
            i = i + self.train_valid_test[1]

        frame_0, frame_t = i, i + self.delta_frame

        if self.load_cached:
            pos_curr, _, _, _ = torch.load(os.path.join(self.tmp_dir, f'{self.dataset}_{frame_0}.pkl'))
            pos_next, _, _, _ = torch.load(os.path.join(self.tmp_dir, f'{self.dataset}_{frame_t}.pkl'))
            if self.test_rot and self.partition == 'test':
                # The same rotation is applied to both frames.
                rot = transforms.random_rotation().detach().numpy()
                pos_curr = torch.tensor(np.matmul(pos_curr.detach().numpy(), rot))
                pos_next = torch.tensor(np.matmul(pos_next.detach().numpy(), rot))
            if self.test_trans and self.partition == 'test':
                # Translation is scaled by half the per-axis extent of the next frame.
                dimension = pos_next.max(dim=0)[0] - pos_next.min(dim=0)[0]
                trans = torch.randn(3) * dimension / 2
                pos_curr += trans
                pos_next += trans
            return self._pack_sample(pos_curr, pos_next)

        if self.backbone:
            raise NotImplementedError("Use load_cached for backbone case.")

        # If a frame cannot be read (OSError), slide the pair forward by one
        # frame and retry.
        while True:
            try:
                ts_0 = self.data.trajectory[frame_0].copy()
                ts_t = self.data.trajectory[frame_t].copy()
                break
            except OSError:
                frame_0 += 1
                frame_t += 1

        # Rotations and translations (test partition only); both frames get
        # the same transform.
        if self.test_rot and self.partition == "test":
            d = np.random.randn(3)
            d = d / np.linalg.norm(d)
            angle = random.randint(0, 360)
            ts_0 = transformations.rotate.rotateby(angle, direction=d, ag=self.data.atoms)(ts_0)
            ts_t = transformations.rotate.rotateby(angle, direction=d, ag=self.data.atoms)(ts_t)
        if self.test_trans and self.partition == 'test':
            trans = np.random.randn(3) * ts_0.dimensions[:3] / 2
            ts_0 = transformations.translate(trans)(ts_0)
            ts_t = transformations.translate(trans)(ts_t)
        pos_curr = torch.tensor(ts_0.positions)
        pos_next = torch.tensor(ts_t.positions)
        return self._pack_sample(pos_curr, pos_next)

    def __len__(self):
        """
        Return the number of samples in the configured partition.

        Raises:
            ValueError: If ``partition`` is not ``'train'``, ``'valid'`` or ``'test'``.
        """
        if self.load_cached:
            total_len = max(0, self.n_frames - self.delta_frame)
        else:
            # NOTE(review): this path subtracts an extra 1 that the split
            # boundaries computed in __init__ do not; kept as-is.
            total_len = max(0, len(self.data.trajectory) - 1 - self.delta_frame)
        train_end, valid_end = self.train_valid_test[0], self.train_valid_test[1]
        if self.partition == 'train':
            return min(total_len, train_end)
        if self.partition == 'valid':
            return max(0, min(total_len, valid_end) - train_end)
        if self.partition == 'test':
            return max(0, total_len - valid_end)
        raise ValueError(f"Unknown partition {self.partition!r}; expected 'train', 'valid' or 'test'.")

    @staticmethod
    def get_cfg(batch_size, n_nodes, cfg):
        """
        Shift per-sample node indices in ``cfg`` to batch-global indices.

        Every value of ``cfg`` is indexed as ``[B, n_type, node_per_type]``;
        sample ``b`` is offset by ``b * n_nodes`` and the first two axes are
        merged. The ``'Isolated'`` entry additionally has its last axis
        squeezed. ``cfg`` is updated in place and returned.
        """
        offset = torch.arange(batch_size) * n_nodes
        for name in list(cfg):
            index = cfg[name]  # [B, n_type, node_per_type]
            shifted = index + offset.unsqueeze(-1).unsqueeze(-1).expand_as(index)
            cfg[name] = shifted.reshape(-1, index.shape[-1])
            if name == 'Isolated':
                cfg[name] = cfg[name].squeeze(-1)
        return cfg


def collate_mda(data):
    """
    Merge a list of 9-field samples into one batched graph.

    Each sample is ``(loc_0, vel_0, edge_global, edge_global_attr, edges,
    edge_attr, charges, loc_t, vel_t)``. Edge index tensors are shifted by the
    cumulative node count of the preceding samples and concatenated along the
    last axis; edge attributes are concatenated along the first axis. ``vel_0``,
    ``loc_t``, ``vel_t`` and ``charges`` are stacked and flattened to
    ``(-1, last_dim)``, while ``loc_0`` is only stacked. All non-index outputs
    are cast to float.

    NOTE(review): ``loc_0`` keeps its batch axis whereas the other per-node
    tensors are flattened, and ``MDAnalysisDataset.__getitem__`` in this file
    returns 5 fields, not 9 -- confirm which dataset this collate function is
    meant for.
    """
    (loc_0, vel_0, edge_global, edge_global_attr, edges,
     edge_attr, charges, loc_t, vel_t) = zip(*data)

    # Per-sample node offsets: entry k is the number of nodes before sample k.
    node_counts = [sample_loc.size(0) for sample_loc in loc_0]
    offset = torch.cumsum(torch.tensor([0] + node_counts, dtype=torch.long), dim=0)

    def _shift_and_join(index_tensors):
        # zip stops at the shorter input, so the trailing total in `offset` is ignored.
        return torch.cat([idx + shift for idx, shift in zip(index_tensors, offset)], dim=-1)

    def _stack_flat(items):
        stacked = torch.stack(items, dim=0)
        return stacked.view(-1, items[0].size(-1)).float()

    edge_global = _shift_and_join(edge_global)
    edges = _shift_and_join(edges)
    edge_global_attr = torch.cat(edge_global_attr, dim=0).float()
    edge_attr = torch.cat(edge_attr, dim=0).float()

    loc_0 = torch.stack(loc_0, dim=0).float()
    vel_0 = _stack_flat(vel_0)
    loc_t = _stack_flat(loc_t)
    vel_t = _stack_flat(vel_t)
    charges = _stack_flat(charges)

    return loc_0, vel_0, edge_global, edge_global_attr, edges, edge_attr, charges, loc_t, vel_t
