import copy
import logging
import os

import mdtraj
import numpy as np
import openmm as mm
import openmm.app as app
import torch
from openmm import unit
from openmmtools.testsystems import TestSystem
from torch.utils.data import random_split
from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from utils.transform import CompleteInternalCoordinateTransform

from utils.plotting import save_ramachandran_plot

# Configure the root logger at import time so that INFO-level messages are
# emitted, and create the named logger used by the code in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('ALDP-Dataset')

# TODO: implement remove_hydrogens
# TODO: implement atom position representation
class ALDPDataset(Dataset):
    """Dataset of alanine dipeptide MD frames expressed in internal coordinates.

    On construction the dataset
      1. generates (or loads from ``args.data_dir``) an MD trajectory and its
         potential energies,
      2. converts every frame to bond lengths, bond angles and torsion angles
         with ``CompleteInternalCoordinateTransform``,
      3. optionally rescales those features (``transform``), and
      4. builds one template graph (``self.graph``) that is deep-copied and
         filled with per-frame values in ``get``.

    Attributes read from ``args``: ``tau``, ``no_propagator``,
    ``graph_representation``, ``data_dir``, ``data_temperature``,
    ``data_save_frequency``, ``data_size``, ``md_device`` and ``save_pdb``
    (``args`` is also forwarded to ``save_ramachandran_plot``).
    """

    def __init__(self, args, target_distribution: TestSystem, transform="normalize"):
        """Load or generate the MD data and prepare features and the graph.

        Args:
            args: Run configuration (see the class docstring for used fields).
            target_distribution: Test system providing ``topology``,
                ``system`` and ``positions`` for the simulation.
            transform: Feature scaling to apply: ``"normalize"`` (min/max
                rescaling), ``"standardize"`` (mean/std rescaling) or anything
                else for no scaling.
        """
        # `transform` names a feature scaling handled below, it is not a
        # torch_geometric transform callable. Forwarding it positionally (as the
        # previous code did) binds it to the base-class `root` argument, so the
        # base class is initialised without arguments instead.
        super().__init__()

        self.tau = args.tau
        self.no_propagator = args.no_propagator
        self.graph_representation = args.graph_representation
        self.target = target_distribution

        mdtraj_topology = mdtraj.Topology.from_openmm(self.target.topology)
        data_path = os.path.join(args.data_dir, f'aldp_md_data_{args.data_temperature}_{args.data_save_frequency}_{args.data_size}')
        frames_path = data_path + "_frames.npy"
        energies_path = data_path + "_potEnergies.npy"

        # Generate MD data if it does not exist. Both files are loaded below,
        # so regenerate when either of them is missing.
        if not (os.path.exists(frames_path) and os.path.exists(energies_path)):
            self._generate_md_data(args, data_path, frames_path, energies_path)

        logger.info(f"Time difference x_t <-> x_t+tau: {self.tau * args.data_save_frequency * unit.femtosecond.conversion_factor_to(unit.picosecond)}ps")
        logger.info(f'Loading data from {data_path}')
        frames = np.load(frames_path)
        potEnergies = np.load(energies_path)
        logger.info(f'Loaded {len(frames)} frames')
        logger.info(f"Starting state energy: {potEnergies[0]:.2f} kJ/mol")

        # NOTE(review): the frames are stored in angstrom (see
        # `_generate_md_data`) while mdtraj presumably interprets `xyz` as
        # nanometers. Dihedral angles do not depend on the length scale, but
        # the coordinates in the saved PDB may be scaled by 10 -- TODO confirm.
        traj = mdtraj.Trajectory(frames, mdtraj_topology)
        if args.save_pdb:
            logger.info("Saving pdb file")
            # RMSD align the trajectory to the first frame and save it as pdb
            traj.superpose(traj).save_pdb(data_path + ".pdb")
        psi = list(mdtraj.compute_psi(traj)[1].flat)
        phi = list(mdtraj.compute_phi(traj)[1].flat)

        # save ramachandran plot
        save_ramachandran_plot(phi, psi, 0, args, name='dataset')
        self.frames = torch.from_numpy(frames).float()
        self.potEnergies = torch.from_numpy(potEnergies).float()

        # set coordinate transformations to change between internal and cartesian coordinates
        ndim = 66
        # internal coordinate mapping for aldp
        z_matrix = [
            (0, [1, 4, 6]),
            (1, [4, 6, 8]),
            (2, [1, 4, 0]),
            (3, [1, 4, 0]),
            (4, [6, 8, 14]),
            (5, [4, 6, 8]),
            (7, [6, 8, 4]),
            (9, [8, 6, 4]),
            (10, [8, 6, 4]),
            (11, [10, 8, 6]),
            (12, [10, 8, 11]),
            (13, [10, 8, 11]),
            (15, [14, 8, 16]),
            (16, [14, 8, 6]),
            (17, [16, 14, 15]),
            (18, [16, 14, 8]),
            (19, [18, 16, 14]),
            (20, [18, 16, 19]),
            (21, [18, 16, 19]),
        ]
        # cartesian indices for aldp defining the "base" of the molecule
        cart_indices = [8, 6, 14]

        flat_frames = self.frames.view(-1, ndim)
        self.coordinate_transform = CompleteInternalCoordinateTransform(ndim, z_matrix, cart_indices, flat_frames)
        _, (self.bond_lengths, self.bond_angles, self.torsion_angles) = self.coordinate_transform(flat_frames)

        self._apply_feature_transform(transform)

        # log dataset info
        logger.info(f"bond_lengths: mean {self.bond_lengths.mean():.2f}, std {self.bond_lengths.std():.2f} , min {self.bond_lengths.min():.2f}, max {self.bond_lengths.max():.2f}")
        logger.info(f"bond_angles: mean {self.bond_angles.mean():.2f}, std {self.bond_angles.std():.2f} , min {self.bond_angles.min():.2f}, max {self.bond_angles.max():.2f}")
        logger.info(f"torsion_angles: mean {self.torsion_angles.mean():.2f}, std {self.torsion_angles.std():.2f} , min {self.torsion_angles.min():.2f}, max {self.torsion_angles.max():.2f}")

        # compute graph representation
        self.construct_graph(mdtraj_topology, traj, graph_representation=args.graph_representation)

    def _generate_md_data(self, args, data_path, frames_path, energies_path):
        """Run a Langevin simulation and save positions and potential energies.

        ``args.data_size`` snapshots are recorded, one every
        ``args.data_save_frequency`` integrator steps (1 fs step, 1/ps
        friction, temperature ``args.data_temperature`` K). Positions are saved
        in angstrom to ``frames_path`` and energies in kJ/mol to
        ``energies_path``.
        """
        logger.info(f'Generating MD data for {args.data_size} steps at {args.data_temperature}K at {data_path}')
        os.makedirs(args.data_dir, exist_ok=True)
        integrator = mm.LangevinIntegrator(args.data_temperature * unit.kelvin,
                                           1.0 / unit.picosecond,
                                           1.0 * unit.femtosecond)
        sim = app.Simulation(self.target.topology, self.target.system, integrator,
                             platform=mm.Platform.getPlatformByName(args.md_device))
        sim.context.setPositions(self.target.positions)

        frames = []
        pot_energies = []
        for _ in tqdm(range(args.data_size)):
            sim.step(args.data_save_frequency)
            # a single state query provides both the positions and the energy
            state = sim.context.getState(getPositions=True, getEnergy=True)
            frames.append(state.getPositions(asNumpy=True).value_in_unit(unit.angstrom))
            pot_energies.append(state.getPotentialEnergy().value_in_unit(unit.kilojoule_per_mole))

        logger.info(f'Saving generated dataset at {data_path}')
        np.save(frames_path, np.array(frames))
        np.save(energies_path, np.array(pot_energies))

    def _apply_feature_transform(self, transform):
        """Rescale the internal coordinates in place and set ``inverse_transform``.

        ``"normalize"`` maps each feature group to [0, 1] using the global
        minimum and maximum of that group. ``"standardize"`` subtracts a
        per-column mean (a circular mean for torsions) and divides by the
        per-column standard deviation. Any other value leaves the data
        untouched and sets ``inverse_transform`` to ``None``.
        """
        if transform == "normalize":
            logger.info("Normalizing data")
            # rescale bond_lengths, bond_angles, torsion_angles to be in the range [0,1]
            self.bond_lengths_min, self.bond_lengths_max = self.bond_lengths.min(), self.bond_lengths.max()
            self.bond_angles_min, self.bond_angles_max = self.bond_angles.min(), self.bond_angles.max()
            self.torsion_angles_min, self.torsion_angles_max = self.torsion_angles.min(), self.torsion_angles.max()

            self.bond_lengths = (self.bond_lengths - self.bond_lengths_min) / (self.bond_lengths_max - self.bond_lengths_min)
            self.bond_angles = (self.bond_angles - self.bond_angles_min) / (self.bond_angles_max - self.bond_angles_min)
            self.torsion_angles = (self.torsion_angles - self.torsion_angles_min) / (self.torsion_angles_max - self.torsion_angles_min)

            self.inverse_transform = self.normalize_internal_coordinates_inverse
        elif transform == "standardize":
            logger.info("Standardizing data")
            # standardize bond_lengths, bond_angles, torsion_angles
            self.bond_lengths_mean, self.bond_lengths_std = self.bond_lengths.mean(dim=0), self.bond_lengths.std(dim=0)
            self.bond_angles_mean, self.bond_angles_std = self.bond_angles.mean(dim=0), self.bond_angles.std(dim=0)
            # circular mean of the torsions via the mean sine and cosine
            sin = torch.mean(torch.sin(self.torsion_angles), dim=0)
            cos = torch.mean(torch.cos(self.torsion_angles), dim=0)
            self.torsion_angles_mean = torch.atan2(sin, cos)
            # NOTE(review): the spread is a plain (non-circular) standard
            # deviation and the centred torsions are not wrapped -- confirm
            # this is intended.
            self.torsion_angles_std = self.torsion_angles.std(dim=0)

            self.bond_lengths = (self.bond_lengths - self.bond_lengths_mean) / self.bond_lengths_std
            self.bond_angles = (self.bond_angles - self.bond_angles_mean) / self.bond_angles_std
            self.torsion_angles = (self.torsion_angles - self.torsion_angles_mean) / self.torsion_angles_std

            self.inverse_transform = self.standardize_internal_coordinates_inverse
        else:
            logger.info("No transformation applied to the dataset")
            self.inverse_transform = None

    def normalize_internal_coordinates_inverse(self, bond_lengths, bond_angles, torsion_angles):
        """Undo the ``"normalize"`` scaling.

        The stored minima/maxima are moved to the device of ``bond_lengths``
        before use. Returns the rescaled ``(bond_lengths, bond_angles,
        torsion_angles)``.
        """
        device = bond_lengths.device
        bond_lengths = bond_lengths * (self.bond_lengths_max.to(device) - self.bond_lengths_min.to(device)) + self.bond_lengths_min.to(device)
        bond_angles = bond_angles * (self.bond_angles_max.to(device) - self.bond_angles_min.to(device)) + self.bond_angles_min.to(device)
        torsion_angles = torsion_angles * (self.torsion_angles_max.to(device) - self.torsion_angles_min.to(device)) + self.torsion_angles_min.to(device)
        return bond_lengths, bond_angles, torsion_angles

    def standardize_internal_coordinates_inverse(self, bond_lengths, bond_angles, torsion_angles):
        """Undo the ``"standardize"`` scaling.

        The stored means/standard deviations are moved to the device of
        ``bond_lengths`` before use. Returns the rescaled ``(bond_lengths,
        bond_angles, torsion_angles)``.
        """
        device = bond_lengths.device
        bond_lengths = bond_lengths * self.bond_lengths_std.to(device) + self.bond_lengths_mean.to(device)
        bond_angles = bond_angles * self.bond_angles_std.to(device) + self.bond_angles_mean.to(device)
        torsion_angles = torsion_angles * self.torsion_angles_std.to(device) + self.torsion_angles_mean.to(device)
        return bond_lengths, bond_angles, torsion_angles

    def len(self):
        """Return the number of frames that have a partner ``tau`` frames later."""
        return len(self.frames) - self.tau

    def _frame_graph(self, idx):
        """Return a deep copy of the template graph filled with frame ``idx``."""
        data = copy.deepcopy(self.graph)
        if self.graph_representation == "internal":
            num_bond_angles = self.bond_angles[0].shape[0]
            # set correct bond lengths (column 2 of the node features)
            data.x[:, 2] = self.bond_lengths[idx]
            # set correct bond angles and torsion angles; bond-angle edges come
            # first in `edge_attr`, followed by the valid torsion edges
            data.edge_attr[:num_bond_angles, 0] = self.bond_angles[idx]
            data.edge_attr[num_bond_angles:, 0] = self.torsion_angles[idx][self.valid_torsion_indices]
        elif self.graph_representation in {"extrinsic", "simple_internal"}:
            # set the correct atom positions of current frame
            data.pos = self.frames[idx]
            # update bond lengths
            data.edge_attr = self.bond_lengths[idx]
        return data

    def no_propagator_get(self, idx):
        """Return graph(x_t) together with the targets taken at ``t + tau``.

        Returns:
            ``(graph, bond_lengths, bond_angles, torsion_angles, frame,
            potential_energy)`` where the graph is built from frame ``idx``
            and every other entry is read at ``idx + self.tau``.
        """
        target = idx + self.tau
        data = self._frame_graph(idx)
        return (data, self.bond_lengths[target], self.bond_angles[target],
                self.torsion_angles[target], self.frames[target], self.potEnergies[target])

    def propagator_get(self, idx):
        """Return graph(x_t) together with the internal coordinates of x_t.

        Returns:
            ``(graph, bond_lengths, bond_angles, torsion_angles, frame)``, all
            taken at frame ``idx``.
        """
        data = self._frame_graph(idx)
        return (data, self.bond_lengths[idx], self.bond_angles[idx],
                self.torsion_angles[idx], self.frames[idx])

    def get(self, idx):
        """Dispatch to ``no_propagator_get`` or ``propagator_get``."""
        if self.no_propagator:
            return self.no_propagator_get(idx)
        return self.propagator_get(idx)

    @staticmethod
    def _find_bond_index(atom_pair, bonds_indices):
        """Return the row of ``bonds_indices`` equal to ``atom_pair``.

        The pair is tried in the given order first and reversed second. A
        direction only counts when exactly one row matches; ``None`` is
        returned when neither direction yields a unique match.
        """
        for candidate in (atom_pair, atom_pair.flip(0)):
            matches = torch.where((candidate == bonds_indices).all(dim=1))[0]
            if matches.numel() == 1:
                return int(matches[0])
        return None

    def construct_graph(self, mdtraj_topology, traj, graph_representation="internal"):
        """Build the template graph stored in ``self.graph``.

        Args:
            mdtraj_topology: mdtraj topology supplying atoms and bonds.
            traj: mdtraj trajectory; only the positions of frame 0 are used
                (``"extrinsic"`` / ``"simple_internal"``).
            graph_representation: ``"internal"`` (nodes are bonds, edges are
                bond angles and torsions) or ``"extrinsic"`` /
                ``"simple_internal"`` (nodes are atoms, edges are bonds).

        Raises:
            ValueError: For an unsupported representation, or when a bond
                needed for a node or a bond-angle edge cannot be matched.
        """
        logger.info("Constructing graph")

        # NOTE(review): the column order of `rev_z_indices` is taken over from
        # the original code; presumably column 0 is the placed atom followed by
        # its reference atoms -- confirm against the coordinate transform.
        rev_z = self.coordinate_transform.ic_transform.rev_z_indices

        # get bond indices as defined in the z-matrix of the coordinate transform
        bonds_indices = torch.cat([rev_z[:, 1][:, None], rev_z[:, 0][:, None]], dim=1)
        # prepend b1 and b2 from complete internal coordinate transform that cover the bonds between 6,8,14
        bonds_indices = torch.cat([torch.tensor([[6, 8]]), torch.tensor([[8, 14]]), bonds_indices], dim=0)

        # get angle indices from internal coordinate transform and prepend the
        # angle spanned by the cartesian base atoms
        angle_indices = torch.cat([rev_z[:, 2][:, None], rev_z[:, 1][:, None], rev_z[:, 0][:, None]], dim=1)
        angle_indices = torch.cat([torch.tensor([[6, 8, 14]]), angle_indices], dim=0)

        # for reference, the important torsions are
        #   [6, 8, 14, 16] (psi), [8, 14, 16, 18], [4, 6, 8, 14] (phi), [1, 4, 6, 8]
        torsion_indices = torch.cat([rev_z[:, 3][:, None], rev_z[:, 2][:, None], rev_z[:, 1][:, None], rev_z[:, 0][:, None]], dim=1)

        if graph_representation == "internal":
            # Graph whose nodes are bonds with features
            #   [atom1 atomic number, atom2 atomic number, bond length, atom1 mass, atom2 mass]
            # and whose edges are bond angles or torsion angles between two
            # bonds with features [angle value, flag] (flag 0 = bond angle,
            # 1 = torsion angle).

            # Map every node (row of `bonds_indices`, i.e. the order of the
            # bond lengths) to the topology bond it describes. The previous
            # code applied this permutation in the wrong direction, pairing the
            # atoms of one bond with the length of another.
            node_to_topology_bond = {}
            for bond in mdtraj_topology.bonds:
                atom_pair = torch.tensor([bond.atom1.index, bond.atom2.index])
                node_idx = self._find_bond_index(atom_pair, bonds_indices)
                if node_idx is None:
                    raise ValueError(f"Topology bond {atom_pair.tolist()} is not part of the internal coordinate bonds")
                node_to_topology_bond[node_idx] = bond

            # TODO: should we switch bond order if we used the reversed(flipped) order above?
            # TODO: maybe expand with torsional bond nodes at some point
            node_rows = []
            for node_idx in range(bonds_indices.shape[0]):
                bond = node_to_topology_bond.get(node_idx)
                if bond is None:
                    raise ValueError(f"Internal coordinate bond {bonds_indices[node_idx].tolist()} is not a bond of the topology")
                node_rows.append([
                    bond.atom1.element.atomic_number,
                    bond.atom2.element.atomic_number,
                    float(self.bond_lengths[0][node_idx]),
                    bond.atom1.element.mass,
                    bond.atom2.element.mass,
                ])
            node_attributes = torch.tensor(node_rows)

            # connect two nodes if they define a bond angle
            bond_angle_edge_idx = []
            for angle in angle_indices:
                first_bond = self._find_bond_index(angle[0:2], bonds_indices)
                second_bond = self._find_bond_index(angle[1:], bonds_indices)
                if first_bond is None or second_bond is None:
                    raise ValueError(f"Bond angle {angle.tolist()} is not spanned by two known bonds")
                bond_angle_edge_idx.append([first_bond, second_bond])
            bond_angle_edge_idx = torch.tensor(bond_angle_edge_idx).T

            # bond angle edge attributes : [bond_angle, 0]
            bond_angle_edge_attr = torch.cat([self.bond_angles[0][:, None], torch.zeros_like(self.bond_angles[0])[:, None]], dim=1)

            # connect two nodes if they define a torsion angle; torsions whose
            # outer atom pairs are not both known bonds are skipped
            torsion_angle_edge_idx = []
            valid_torsion_indices = []
            for i, torsion in enumerate(torsion_indices):
                first_bond = self._find_bond_index(torsion[0:2], bonds_indices)
                second_bond = self._find_bond_index(torsion[2:], bonds_indices)
                if first_bond is None or second_bond is None:
                    continue
                torsion_angle_edge_idx.append([first_bond, second_bond])
                valid_torsion_indices.append(i)
            # store the indices of the valid torsion angles, to select them out of the complete internal coordinate transform torsion angles
            self.valid_torsion_indices = valid_torsion_indices
            torsion_angle_edge_idx = torch.tensor(torsion_angle_edge_idx).T

            # torsion angle edge attributes : [torsion_angle, 1]
            torsion_angle_edge_attr = torch.cat([self.torsion_angles[0][self.valid_torsion_indices, None], torch.ones_like(self.torsion_angles[0])[self.valid_torsion_indices, None]], dim=1)

            # combine bond angle and torsion angle edges (bond angles first)
            edge_idx = torch.cat([bond_angle_edge_idx, torsion_angle_edge_idx], dim=1)
            edge_attr = torch.cat([bond_angle_edge_attr, torsion_angle_edge_attr], dim=0)

            # TODO: print graph and verify
            self.graph = Data(x=node_attributes, edge_index=edge_idx, edge_attr=edge_attr)

        elif graph_representation == "extrinsic" or graph_representation == "simple_internal":
            # Graph whose nodes are atoms with features [atomic number, mass]
            # and whose edges are the bonds of `bonds_indices`, carrying the
            # bond length as edge attribute.
            node_attributes = torch.tensor([[atom.element.atomic_number, atom.element.mass] for atom in mdtraj_topology.atoms])

            # get atom positions of the first frame
            atom_positions = torch.from_numpy(traj.xyz[0])

            # TODO: get bond type ?
            edge_attr = self.bond_lengths[0]

            # get the indices of the atoms that define a bond
            edge_idx = bonds_indices.T

            self.graph = Data(x=node_attributes, edge_index=edge_idx, edge_attr=edge_attr, pos=atom_positions, torsion_indices=torsion_indices)

        else:
            raise ValueError("Graph representation not supported")
        

def construct_loader(args):
    """Build the training and validation loaders for the ALDP dataset.

    The dataset is split 70/20/10 into train/validation/test subsets with a
    fixed seed so the split is reproducible. Only the train and validation
    subsets get a loader.

    Returns:
        ``(train_loader, val_loader, coordinate_transform)`` where
        ``coordinate_transform`` is the transform held by the dataset.
    """
    full_dataset = ALDPDataset(args, args.testsystem, transform=args.feature_transform)

    # fix random seed for reproducibility of the split
    split_generator = torch.Generator().manual_seed(199)
    train_set, val_set, test_set = random_split(full_dataset, [0.7, 0.2, 0.1], generator=split_generator)

    logger.info(f"Number of training frames: {len(train_set)} | Number of validation frames: {len(val_set)} | Number of test frames: {len(test_set)}")

    # both loaders share the same configuration
    loader_options = dict(
        batch_size=args.batch_size,
        num_workers=args.num_dataloader_workers,
        shuffle=True,
        drop_last=args.dataloader_drop_last,
        pin_memory=True,
    )
    train_loader = DataLoader(dataset=train_set, **loader_options)
    val_loader = DataLoader(dataset=val_set, **loader_options)

    return train_loader, val_loader, full_dataset.coordinate_transform

            

        

   

        
        





        
    

       
    


        

      
       