import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
import numpy.linalg as la
from tqdm import tqdm
import scipy


class Dataset(Dataset):
    def __init__(self, root_dir, eq=True, log_flag=False):
        """Load one data set from ``root_dir`` and precompute its statistics.

        Files read from ``root_dir``: ``data_edge.npz`` (keys ``coord``,
        ``edge_length``, ``edge_normal``, ``bond_idx0``, ``bond_idx1``),
        ``data_labels.npy``, ``data_CA.npy`` and, when ``eq`` is true,
        ``equilibrium.npz`` (key ``coord``).

        Args:
            root_dir: Directory that contains the files listed above.
            eq: When true, the equilibrium coordinates are loaded into
                ``self.equilibrium_coords``; otherwise that attribute is not
                created.
            log_flag: When true, every positive C-alpha bond length is
                replaced by its natural logarithm before the bond statistics
                are computed.
        """
        # Imported here because a bare ``import scipy`` does not guarantee
        # that the ``scipy.sparse.csgraph`` subpackage has been loaded.
        from scipy.sparse.csgraph import connected_components

        def mean_std(values):
            # Statistics over axis 0 (the frame axis); the epsilon keeps the
            # divisions done during normalisation finite.
            return {'mu': np.mean(values, axis=0),
                    'std': np.std(values, axis=0) + 1E-8}

        self.log_flag = log_flag

        # ``np.load`` on an .npz returns a lazily-read archive that holds an
        # open file handle, so each archive is read inside a ``with`` block.
        if eq:
            with np.load('%s/equilibrium.npz' % root_dir) as equilibrium:
                self.equilibrium_coords = equilibrium['coord'].astype(np.float32)

        with np.load('%s/data_edge.npz' % root_dir) as signal:
            self.coords = signal['coord'].astype(np.float32)
            self.edge_length = signal['edge_length'].astype(np.float32)
            self.edge_normal = signal['edge_normal'].astype(np.float32)
            self.bond_idx0 = signal['bond_idx0']
            self.bond_idx1 = signal['bond_idx1']

        self.labels = np.load('%s/data_labels.npy' % root_dir)

        # Distances between consecutive entries along axis 2 of the C-alpha
        # array; the norm is taken over axis 1 (assumed to be x/y/z -- TODO confirm).
        self.ca_atom_coords = np.load('%s/data_CA.npy' % root_dir)
        self.ca_bonds = la.norm(self.ca_atom_coords[:, :, 1:] - self.ca_atom_coords[:, :, :-1],
                                axis=1, keepdims=True)
        # Lengths above 10 are zeroed; zero is later used as the "no bond"
        # marker (presumably these are padding / fragment gaps -- confirm).
        self.ca_bonds[self.ca_bonds > 10] = 0
        if log_flag:
            valid = self.ca_bonds > 0
            self.ca_bonds[valid] = np.log(self.ca_bonds[valid])
        self.ca_bonds_stats = mean_std(self.ca_bonds)

        self.coords_stats = mean_std(self.coords)

        print('99.5 var', np.percentile(self.edge_length[:, 0], 99.5))
        self.edge_length_stats = mean_std(self.edge_length)
        self.edge_normal_stats = mean_std(self.edge_normal)

        # Adjacency matrix over atoms, weighted by the mean edge length.
        sparse_idx = torch.as_tensor(np.stack([self.bond_idx0, self.bond_idx1]), dtype=torch.long)
        sparse_val = torch.tensor(self.edge_length_stats['mu'][0], dtype=torch.float32)
        print(sparse_val.shape)

        n_atoms = self.coords.shape[-1]
        self.weighted_adj = torch.sparse_coo_tensor(sparse_idx, sparse_val, (n_atoms, n_atoms))

        # Need to get c alpha indices from backbone: within each connected
        # component every 4th atom starting at offset 1 is taken.
        dense_adj = self.weighted_adj.to_dense().numpy()
        n_comp, components = connected_components(dense_adj)
        atom_ids = np.arange(len(components))
        self.ca_idx_list = np.concatenate(
            [atom_ids[components == comp][1::4] for comp in range(n_comp)])

    def __len__(self):
        """Return the number of entries along axis 0 of ``self.edge_length``."""
        n_frames = len(self.edge_length)
        return n_frames
  
    def __getitem__(self, idx, noise=True):
        """Return the signals stored at index ``idx``.

        Edge lengths, C-alpha bonds and coordinates are standardised with
        their stored ``mu`` / ``std``; edge normals and labels are returned
        as stored.

        Args:
            idx: Index into axis 0 of the stored arrays.
            noise: Not used by this method.

        Returns:
            Tuple ``(edge_lens, edge_norms, ca_bonds, labels, coords)``.
        """
        def standardise(values, stats):
            return (values - stats['mu']) / stats['std']

        raw_edge_lens = self.edge_length[idx]
        edge_norms = self.edge_normal[idx]  # passed through without standardisation
        raw_ca_bonds = self.ca_bonds[idx]
        raw_coords = self.coords[idx]
        labels = self.labels[idx]

        return (
            standardise(raw_edge_lens, self.edge_length_stats),
            edge_norms,
            standardise(raw_ca_bonds, self.ca_bonds_stats),
            labels,
            standardise(raw_coords, self.coords_stats),
        )

    def compute_edge_signal(self, coords):
        """Derive standardised edge lengths and unit edge directions from ``coords``.

        Args:
            coords: Tensor indexed as ``coords[:, :, atom]``; the norm is taken
                over dim 1 (assumed layout (batch, 3, n_atoms) -- TODO confirm).

        Returns:
            Tuple ``(edge_lens, edge_norms)``: lengths standardised with
            ``self.edge_length_stats`` and the difference vectors divided by
            their (unstandardised) length.
        """
        bond_vectors = coords[:, :, self.bond_idx1] - coords[:, :, self.bond_idx0]
        raw_lengths = torch.norm(bond_vectors, dim=1, keepdim=True)
        # The tiny offset avoids a division by exactly zero.
        unit_vectors = bond_vectors / (raw_lengths + 1E-16)

        target_device = raw_lengths.device
        length_mu = torch.tensor(self.edge_length_stats['mu'], device=target_device)
        length_std = torch.tensor(self.edge_length_stats['std'], device=target_device)
        return (raw_lengths - length_mu) / length_std, unit_vectors

    def compute_ca_signal(self, coords):
        """Compute standardised C-alpha bond lengths from a batch of coordinates.

        The C-alpha positions are picked out of ``coords`` with
        ``self.ca_idx_list``, scattered into an array shaped like the stored
        ``self.ca_atom_coords`` and differenced along axis 2. Entries whose
        stored bond value is zero are zeroed, the logarithm is applied when
        ``self.log_flag`` is set, and the result is standardised with
        ``self.ca_bonds_stats``.

        Args:
            coords: Tensor indexed as ``coords[:, :, atom]``
                (assumed layout (batch, 3, n_atoms) -- TODO confirm).

        Returns:
            Tensor of standardised bond lengths.
        """
        ca_locations = coords[:, :, self.ca_idx_list]

        # Define footprint for assigning edge lengths: the first
        # ``coords.shape[0]`` stored entries are used as the layout template.
        # NOTE(review): this ties the batch to the first stored frames --
        # confirm that the layout does not vary between frames.
        ca_atom_sub = torch.tensor(self.ca_atom_coords[:coords.shape[0]], device=coords.device)
        ca_bond_sub = torch.tensor(self.ca_bonds[:coords.shape[0]], device=coords.device)
        ca_locations_shaped = torch.zeros(ca_atom_sub.shape, device=coords.device)
        # Axes 2 and 3 are swapped so that ``nonzero`` enumerates the occupied
        # slots with the original axis 3 varying slower than the original axis 2.
        ca_atom_sub = torch.transpose(ca_atom_sub, 2, 3)
        # Occupied slots are identified by a non-zero stored value; a stored
        # coordinate that is exactly 0 would be treated as unoccupied.
        nidx = torch.nonzero(ca_atom_sub, as_tuple=False)

        # Columns 2 and 3 are swapped back to index the untransposed array; the
        # number of non-zero entries must equal ``ca_locations.numel()``.
        ca_locations_shaped[(nidx[:, 0], nidx[:, 1], nidx[:, 3], nidx[:, 2])] = ca_locations.flatten()

        ca_diff = ca_locations_shaped[:, :, 1:] - ca_locations_shaped[:, :, :-1]
        ca_lens = torch.norm(ca_diff, dim=1, keepdim=True)
        # Need to mask these new lengths: wherever the stored bond value is
        # zero the recomputed length is zeroed as well.
        ca_lens_mask = ca_lens.clone()
        ca_lens_mask[ca_bond_sub == 0] = 0
        ca_lens_mask2 = ca_lens_mask.clone()
        if self.log_flag:
            # Only strictly positive lengths are log-transformed; zeros stay zero.
            ca_lens_mask2[ca_lens_mask > 0] = torch.log(ca_lens_mask[ca_lens_mask > 0])

        mu = torch.tensor(self.ca_bonds_stats['mu'], device=coords.device)
        std = torch.tensor(self.ca_bonds_stats['std'], device=coords.device)
        ca_lens_mask2 = (ca_lens_mask2 - mu) / std
        return ca_lens_mask2

    def print_stats(self):
        """Print summary statistics of the stored normalisation arrays.

        For each of the edge-length, edge-normal and coordinate statistics,
        one line is printed for ``mu`` and one for ``std``, each showing the
        minimum, maximum, median and mean taken over axis 1.
        """
        reports = (
            ('edge_length', self.edge_length_stats),
            ('edge_normal', self.edge_normal_stats),
            ('Coordinates', self.coords_stats),
        )
        for label, stats in reports:
            for key in ('mu', 'std'):
                values = stats[key]
                # The edge-length std line used to be labelled 'Bond stats';
                # every line now names the array it actually summarises.
                print('%s stats (min, max, median, mean): %s' % (label, key),
                      np.amin(values, axis=1), np.amax(values, axis=1),
                      np.median(values, axis=1), np.mean(values, axis=1))

    def dataset_split(self):
        """Split the data set in stored order into 70 % / 10 % / 20 % subsets.

        Returns:
            Tuple ``(train_dataset, val_dataset, test_dataset)`` of
            ``torch.utils.data.Subset`` objects over contiguous index ranges.
        """
        total = len(self)
        train_end = int(0.7 * total)
        val_end = int(0.8 * total)
        bounds = ((0, train_end), (train_end, val_end), (val_end, total))
        train_dataset, val_dataset, test_dataset = (
            torch.utils.data.Subset(self, np.arange(lo, hi)) for lo, hi in bounds)
        return train_dataset, val_dataset, test_dataset

    def dataset_split_random(self, seed=1):
        """Randomly split the data set into 70 % / 10 % / remainder subsets.

        The torch CPU and CUDA generators are re-seeded with ``seed`` before
        the split, so the same seed reproduces the same partition.

        Args:
            seed: Seed passed to ``torch.manual_seed`` and
                ``torch.cuda.manual_seed``.

        Returns:
            Tuple ``(train_dataset, val_dataset, test_dataset)``.
        """
        total = len(self)
        train_len = int(0.7 * total)
        val_len = int(0.1 * total)
        # The test subset takes whatever is left after rounding down the others.
        lengths = [train_len, val_len, total - train_len - val_len]

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        parts = torch.utils.data.random_split(self, lengths)
        train_dataset, val_dataset, test_dataset = parts
        return train_dataset, val_dataset, test_dataset

    # 1668 is the number of frames per trajectory for the 'da_10906555' and 'ace2' data sets.
    def dataset_split_traj_end(self, dataname):
        """Split every trajectory chronologically into 70 % / 10 % / 20 % parts.

        Frames are assumed to be stored trajectory after trajectory, each with
        the same number of frames. Within each trajectory the first 70 % of
        the frames go to training, the next 10 % to validation and the rest
        to testing.

        Args:
            dataname: One of ``'da_10906555'``, ``'ace2'``,
                ``'sarscov2_closed'``, ``'sarscov2_partialopen'`` or
                ``'protease'``.

        Returns:
            Tuple ``(train_dataset, val_dataset, test_dataset)`` of
            ``torch.utils.data.Subset`` objects.

        Raises:
            ValueError: If ``dataname`` is not one of the known data sets.
        """
        # dataname -> (number of trajectories, frames per trajectory)
        layouts = {
            'da_10906555': (50, 1668),
            'ace2': (75, 1668),
            'sarscov2_closed': (1, 2001),
            'sarscov2_partialopen': (1, 2001),
            'protease': (1, 10001),
        }
        try:
            num_trajs, num_frames = layouts[dataname]
        except KeyError:
            raise ValueError('Unknown dataname %r; expected one of %s'
                             % (dataname, sorted(layouts))) from None

        train_parts, val_parts, test_parts = [], [], []
        for traj in range(num_trajs):
            start = traj * num_frames
            train_end = int(start + 0.7 * num_frames)
            val_end = int(start + 0.8 * num_frames)
            train_parts.append(np.arange(start, train_end))
            val_parts.append(np.arange(train_end, val_end))
            test_parts.append(np.arange(val_end, start + num_frames))

        # ``np.int`` was removed from NumPy; the builtin ``int`` is the same dtype.
        train_idx = np.concatenate(train_parts).astype(int)
        val_idx = np.concatenate(val_parts).astype(int)
        test_idx = np.concatenate(test_parts).astype(int)
        train_dataset = torch.utils.data.Subset(self, train_idx)
        val_dataset = torch.utils.data.Subset(self, val_idx)
        test_dataset = torch.utils.data.Subset(self, test_idx)
        return train_dataset, val_dataset, test_dataset

    def normalize_coordinates(self, pts):
        """Standardise coordinates with the stored statistics.

        When the instance carries a ``stats`` mapping (keys ``'mu'`` and
        ``'sigma2'``) it is used, as before. ``__init__`` does not create
        that attribute, so otherwise ``self.coords_stats`` is used, which makes
        this method the inverse of ``unnormalize_coordinates``.

        Args:
            pts: Array broadcastable against the stored statistics.

        Returns:
            The standardised coordinates.
        """
        legacy_stats = getattr(self, 'stats', None)
        if legacy_stats is not None:
            return (pts - legacy_stats['mu']) / legacy_stats['sigma2'] ** 0.5
        return (pts - self.coords_stats['mu']) / self.coords_stats['std']

    def normalize_conformal(self, signal):
        """Standardise ``signal`` with the conformal statistics.

        The statistics (keys ``'mu'`` and ``'sigma2'``) are taken from
        ``self.stats_transformed`` when that attribute exists; otherwise they
        are read from ``<root_dir>/<dataset>/conf_stats.npz``. ``__init__``
        does not assign ``self.root_dir`` or ``self.dataset``, so they must be
        set on the instance before the file fallback can be used.

        Args:
            signal: Array broadcastable against the statistics.

        Returns:
            ``(signal - mu) / sqrt(sigma2)``.
        """
        if hasattr(self, 'stats_transformed'):
            stats = self.stats_transformed
        else:
            stats_path = '%s/%s/conf_stats.npz' % (self.root_dir, self.dataset)
            # NOTE(security): allow_pickle=True lets the archive run arbitrary
            # code while loading; only use it with trusted files.
            # The archive is closed as soon as the two arrays have been read.
            with np.load(stats_path, allow_pickle=True) as archive:
                stats = {key: archive[key] for key in ('mu', 'sigma2')}
        signal_norm = (signal - stats['mu']) / stats['sigma2'] ** 0.5
        return signal_norm

    def normalize_conformal_torch(self, signal):
        """Standardise a tensor ``signal`` with the conformal statistics.

        The statistics (keys ``'mu'`` and ``'sigma2'``) are taken from
        ``self.stats_transformed`` when that attribute exists; otherwise they
        are read from ``<root_dir>/<dataset>/conf_stats.npz``. ``__init__``
        does not assign ``self.root_dir`` or ``self.dataset``, so they must be
        set on the instance before the file fallback can be used.

        Args:
            signal: Tensor broadcastable against the transposed statistics.

        Returns:
            ``(signal - mu.T) / sqrt(sigma2).T`` as a tensor on the device of
            ``signal``.
        """
        if hasattr(self, 'stats_transformed'):
            stats = self.stats_transformed
        else:
            stats_path = '%s/%s/conf_stats.npz' % (self.root_dir, self.dataset)
            # NOTE(security): allow_pickle=True lets the archive run arbitrary
            # code while loading; only use it with trusted files.
            # The archive is closed as soon as the two arrays have been read.
            with np.load(stats_path, allow_pickle=True) as archive:
                stats = {key: archive[key] for key in ('mu', 'sigma2')}
        # The stored statistics are transposed (``.t()``) before broadcasting.
        mu = torch.tensor(stats['mu'], device=signal.device, dtype=torch.float32).t()
        std = torch.tensor(stats['sigma2'] ** 0.5, device=signal.device, dtype=torch.float32).t()
        signal_norm = (signal - mu) / std
        return signal_norm

    def unnormalize_coordinates(self, pts_norm):
        """Map standardised coordinates back to the original scale.

        Args:
            pts_norm: Standardised coordinates, broadcastable against
                ``self.coords_stats``.

        Returns:
            ``pts_norm * std + mu`` using ``self.coords_stats``.
        """
        scale = self.coords_stats['std']
        shift = self.coords_stats['mu']
        rescaled = pts_norm * scale
        return rescaled + shift

    def unnormalize_coordinates_torch(self, pts_norm):
        """Map a tensor of standardised coordinates back to the original scale.

        Args:
            pts_norm: Tensor of standardised coordinates; the statistics are
                created on its device.

        Returns:
            ``pts_norm * std + mu`` using ``self.coords_stats``.
        """
        target_device = pts_norm.device
        shift = torch.tensor(self.coords_stats['mu'], device=target_device)
        scale = torch.tensor(self.coords_stats['std'], device=target_device)
        return pts_norm * scale + shift

    def unnormalize_edgelens_torch(self, edge_lens):
        """Map a tensor of standardised edge lengths back to the original scale.

        Args:
            edge_lens: Tensor of standardised edge lengths; the statistics are
                created on its device.

        Returns:
            ``edge_lens * std + mu`` using ``self.edge_length_stats``.
        """
        target_device = edge_lens.device
        shift = torch.tensor(self.edge_length_stats['mu'], device=target_device)
        scale = torch.tensor(self.edge_length_stats['std'], device=target_device)
        return edge_lens * scale + shift

    def unnormalize_calpha_torch(self, calpha):
        """Map standardised C-alpha bond values back to bond lengths.

        The values are rescaled with ``self.ca_bonds_stats``. When
        ``self.log_flag`` is set the result is exponentiated, and entries
        whose stored spread is only the 1E-8 stabiliser (i.e. whose bond value
        never varied, such as all-zero padding) are set to zero.

        Args:
            calpha: Tensor of standardised values, broadcastable against the
                stored statistics.

        Returns:
            Tensor of bond lengths (or of rescaled values when ``log_flag`` is
            false).
        """
        mu = torch.tensor(self.ca_bonds_stats['mu'], device=calpha.device)
        sigma = torch.tensor(self.ca_bonds_stats['std'], device=calpha.device)
        edge_lens_trans = calpha * sigma + mu
        if self.log_flag:
            edge_lens_trans = torch.exp(edge_lens_trans)
            # The stored std always carries a +1E-8 stabiliser, so comparing it
            # with 0 never masks anything and padded entries would come back as
            # exp(0) = 1. Compare against the stabiliser itself, built in the
            # dtype of sigma so that float32 rounding cannot flip the test.
            stabiliser = torch.tensor(1E-8, dtype=sigma.dtype, device=sigma.device)
            edge_lens_trans = edge_lens_trans * (sigma > stabiliser)

        return edge_lens_trans

    def compute_local_frames(self):
        """Build a local frame of three vectors at every interior point.

        With ``u_i`` the unit vector from point ``i - 1`` to point ``i``
        (taken along the last axis of ``self.coords``), the frame at an
        interior point ``i`` is

        * ``b``  -- normalised ``u_i - u_{i+1}``,
        * ``n``  -- normalised ``u_i x u_{i+1}``,
        * ``bn`` -- ``b x n``.

        The first and last points keep all-zero frames. Coincident
        consecutive points or collinear triples lead to divisions by zero and
        therefore NaN entries.

        Returns:
            Tensor ``torch.stack([b, n, bn], dim=1)``; for coordinates of
            shape (frames, 3, n_points) this is (frames, 3, 3, n_points).
        """
        # ``self.coords`` is stored as a NumPy array; torch operations need a tensor.
        coords = torch.as_tensor(self.coords)

        u = torch.zeros(coords.shape)
        b = torch.zeros(coords.shape)
        n = torch.zeros(coords.shape)
        bn = torch.zeros(coords.shape)

        steps = coords[:, :, 1:] - coords[:, :, :-1]
        # u has 0 at first entry
        u[:, :, 1:] = steps / torch.norm(steps, dim=1, keepdim=True)

        # b has 0 at the endpoints
        turn = u[:, :, 1:-1] - u[:, :, 2:]
        b[:, :, 1:-1] = turn / torch.norm(turn, dim=1, keepdim=True)

        normal = torch.cross(u[:, :, 1:-1], u[:, :, 2:], dim=1)
        n[:, :, 1:-1] = normal / torch.norm(normal, dim=1, keepdim=True)

        bn[:, :, 1:-1] = torch.cross(b[:, :, 1:-1], n[:, :, 1:-1], dim=1)
        loc_frame = torch.stack([b, n, bn], dim=1)
        return loc_frame

    def compute_dist_for_spatial(self, loc_frame, k=30):
        """Compute pairwise spatial features between all points of every frame.

        Args:
            loc_frame: Tensor of local frames with the layout produced by
                ``compute_local_frames`` (frames, 3 frame vectors, 3
                components, n_points).
            k: Number of nearest neighbours (the point itself included); must
                not exceed the number of points.

        Returns:
            Tuple ``(rbf, orient, rot, kidx)``:

            * ``rbf`` -- ``exp(-d_ij**2 / (sigma_i * sigma_j))`` where
              ``sigma_i`` is the median distance from point ``i`` to its ``k``
              nearest neighbours,
            * ``orient`` -- the unit vector from point ``i`` to point ``j``
              expressed in the local frame of ``i`` (NaN for ``i == j``, where
              the vector has zero length),
            * ``rot`` -- four quaternion-style components derived from the
              product of the frames of ``i`` and ``j``,
            * ``kidx`` -- neighbour indices of the first frame only.
        """
        # ``self.coords`` is stored as a NumPy array; torch operations need a tensor.
        coords = torch.as_tensor(self.coords)
        n_frames, n_points = coords.shape[0], coords.shape[-1]

        # diff[b, :, i, j] = coords[b, :, j] - coords[b, :, i]
        diff = torch.unsqueeze(coords, dim=2) - torch.unsqueeze(coords, dim=3)
        dist = torch.sum(diff ** 2, dim=1, keepdim=True) ** 0.5

        loc_frame_p = loc_frame.permute(0, 3, 1, 2)
        diff_p = diff.permute(0, 2, 1, 3)
        # The permuted difference must be the one that is normalised: its shape
        # matches the norm, and it lines up with ``loc_frame_p`` for the matmul.
        unit_diff = diff_p / torch.norm(diff_p, dim=2, keepdim=True)
        orient = torch.matmul(loc_frame_p, unit_diff).permute(0, 2, 1, 3)

        rot = torch.zeros([n_frames, 4, n_points, n_points])
        loc_frame_p1 = loc_frame_p.unsqueeze(dim=1)
        loc_frame_p2 = loc_frame_p.unsqueeze(dim=2)
        rot_d = torch.matmul(loc_frame_p1, loc_frame_p2.transpose(3, 4))
        rot_d = rot_d.permute(0, 3, 4, 1, 2)
        rot[:, 0] = 0.5 * (1 + rot_d[:, 0, 0] - rot_d[:, 1, 1] - rot_d[:, 2, 2]) ** 0.5
        rot[:, 1] = 1 / (4 * rot[:, 0]) * (rot_d[:, 0, 1] + rot_d[:, 1, 0])
        rot[:, 2] = 1 / (4 * rot[:, 0]) * (rot_d[:, 0, 2] + rot_d[:, 2, 0])
        rot[:, 3] = 1 / (4 * rot[:, 0]) * (rot_d[:, 2, 1] - rot_d[:, 1, 2])

        kdist, kidx = torch.topk(dist, k, largest=False)
        sigma = torch.median(kdist, dim=-1, keepdim=True)[0]
        rbf = torch.exp(- dist ** 2 / (sigma * sigma.transpose(2, 3)))
        return rbf, orient, rot, kidx[0]

    def get_level_masks(self, num_levels, k=30):
        """Return nearest-neighbour indices for successively coarser point sets.

        Level ``l`` keeps every ``2 ** l``-th point along the last axis of
        ``self.coords`` and, for each kept point, finds its ``k`` nearest kept
        points (the point itself included).

        Args:
            num_levels: Number of levels to build.
            k: Number of neighbours per point; must not exceed the number of
                points kept at the coarsest level.

        Returns:
            List with one entry per level holding the neighbour indices of the
            first frame only.
        """
        # ``self.coords`` is stored as a NumPy array; torch operations need a tensor.
        coords = torch.as_tensor(self.coords)

        masks = []
        for level in range(num_levels):
            pts = coords[:, :, ::2 ** level]

            diff = torch.unsqueeze(pts, dim=2) - torch.unsqueeze(pts, dim=3)
            dist = torch.sum(diff ** 2, dim=1, keepdim=True) ** 0.5

            _, kidx = torch.topk(dist, k, largest=False)
            masks.append(kidx[0])
        return masks
