import os

from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Geometry import Point3D
from rdkit import RDLogger
import imageio
import networkx as nx
import numpy as np
import rdkit.Chem
import wandb
import matplotlib.pyplot as plt




class MolecularVisualization:
    """Draw molecular graphs with RDKit and optionally log the images to wandb.

    A graph is a pair ``(node_list, adjacency_matrix)``: node entries index
    into ``dataset_infos.atom_decoder`` (``-1`` marks a padding node) and
    adjacency entries encode bond types (1 single, 2 double, 3 triple,
    4 aromatic; any other value means no bond).
    """

    # TODO: turn this into a training callback.
    def __init__(self, remove_h, dataset_infos):
        # Stored for callers; not read by any method of this class.
        self.remove_h = remove_h
        # Must provide ``atom_decoder`` (atom-type index -> element symbol).
        self.dataset_infos = dataset_infos

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """Convert one graph to an RDKit molecule.

        Args:
            node_list: length-n sequence of atom-type indices; ``-1`` entries
                are padding and produce no atom.
            adjacency_matrix: n x n bond-type matrix. Only the strict upper
                triangle is read. Bonds touching a padding node are ignored.

        Returns:
            The RDKit ``Mol``, or ``None`` if RDKit raises a
            ``KekulizeException`` while finalising it.
        """
        atom_decoder = self.dataset_infos.atom_decoder
        # Matched with ``==`` (not a dict lookup) so numpy/torch scalars of
        # any numeric dtype compare by value.
        bond_codes = (
            (1, Chem.rdchem.BondType.SINGLE),
            (2, Chem.rdchem.BondType.DOUBLE),
            (3, Chem.rdchem.BondType.TRIPLE),
            (4, Chem.rdchem.BondType.AROMATIC),
        )

        mol = Chem.RWMol()

        # Graph node index -> RDKit atom index (they diverge once padding
        # nodes are skipped).
        node_to_idx = {}
        for i, node in enumerate(node_list):
            if node == -1:
                continue
            node_to_idx[i] = mol.AddAtom(Chem.Atom(atom_decoder[int(node)]))

        for ix, row in enumerate(adjacency_matrix):
            # The matrix is symmetric: only visit entries above the diagonal.
            for iy in range(ix + 1, len(row)):
                bond = row[iy]
                bond_type = next((bt for code, bt in bond_codes if bond == code), None)
                if bond_type is None:
                    continue
                if ix not in node_to_idx or iy not in node_to_idx:
                    # A bond to a padding node has no atom to attach to.
                    continue
                mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        try:
            return mol.GetMol()
        except rdkit.Chem.KekulizeException:
            print("Can't kekulize molecule")
            return None

    def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph',epoch=None):
        """Draw the first molecules of ``molecules`` to PNG files under ``path``.

        Args:
            path: output directory (created if missing); files are named
                ``molecule_{i}.png``.
            molecules: sequence of ``(node_list, adjacency_matrix)`` pairs;
                both elements must provide ``.numpy()`` (e.g. CPU tensors).
            num_molecules_to_visualize: how many molecules to draw; clamped to
                ``len(molecules)``.
            log: wandb key for each image; ``None`` disables wandb logging.
            epoch: unused.

        Returns:
            List of ``wandb.Image`` objects for the logged images (empty when
            no wandb run is active or ``log`` is ``None``).
        """
        os.makedirs(path, exist_ok=True)

        print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}")
        if num_molecules_to_visualize > len(molecules):
            print(f"Shortening to {len(molecules)}")
            num_molecules_to_visualize = len(molecules)

        table = []
        for i in range(num_molecules_to_visualize):
            file_path = os.path.join(path, 'molecule_{}.png'.format(i))
            mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
            if mol is None:
                # Conversion failed (already reported); RDKit cannot draw None.
                continue
            try:
                Draw.MolToFile(mol, file_path)
            except rdkit.Chem.KekulizeException:
                print("Can't kekulize molecule")
                continue
            if wandb.run and log is not None:
                print(f"Saving {file_path} to wandb")
                wandb.log({log: wandb.Image(file_path)}, commit=True)
                table.append(wandb.Image(file_path))
        return table

    @staticmethod
    def _align_to_final(mols):
        """Give every molecule 2D coordinates copied from the last one.

        Atom ``j`` of each molecule is placed where atom ``j`` of ``mols[-1]``
        sits; atoms beyond the final molecule's atom count keep their own
        computed 2D coordinates. ``None`` entries are skipped, and if the
        final molecule is ``None`` each molecule keeps its own layout.
        """
        final_molecule = mols[-1]
        coords = []
        if final_molecule is not None:
            AllChem.Compute2DCoords(final_molecule)
            final_conf = final_molecule.GetConformer()
            for i in range(final_molecule.GetNumAtoms()):
                pos = final_conf.GetAtomPosition(i)
                coords.append((pos.x, pos.y, pos.z))

        for mol in mols:
            if mol is None:
                continue
            AllChem.Compute2DCoords(mol)
            conf = mol.GetConformer()
            for j in range(min(mol.GetNumAtoms(), len(coords))):
                conf.SetAtomPosition(j, Point3D(*coords[j]))

    def visualize_chain(self, path, nodes_list, adjacency_matrix, epoch=None):
        """Render a generation chain as PNG frames, a GIF and a grid image.

        Step ``i`` is the graph ``(nodes_list[i], adjacency_matrix[i])``.
        Frames go to ``path/fram_{i}.png``; the GIF goes to
        ``<dirname(path)>/gifs/<basename(path)>.gif`` and is logged to wandb
        under ``"chain"`` when a run is active; the grid image is saved both
        in ``path`` and in ``<dirname(path)>/grid_images``.

        All frames reuse the 2D atom coordinates of the final molecule so
        atoms stay in place across the animation. Steps that could not be
        converted to a molecule are drawn as empty frames.

        Returns:
            ``(mols, [[gif_video]], ["Generation Path"])`` where ``mols`` holds
            one RDKit molecule (or ``None``) per step.
        """
        RDLogger.DisableLog('rdApp.*')
        os.makedirs(path, exist_ok=True)
        name = os.path.basename(path)

        mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
        self._align_to_final(mols)
        # RDKit cannot draw ``None``; an empty molecule renders a blank frame.
        drawable = [mol if mol is not None else Chem.Mol() for mol in mols]

        save_paths = []
        for frame, mol in enumerate(drawable):
            file_name = os.path.join(path, 'fram_{}.png'.format(frame))
            Draw.MolToFile(mol, file_name, size=(300, 300), legend=f"Frame {frame}")
            save_paths.append(file_name)

        imgs = [imageio.imread(fn) for fn in save_paths]
        # GIFs of all chains are collected in a sibling ``gifs`` folder; the
        # file is always rewritten, so reruns overwrite older results.
        gif_parent_path = os.path.join(os.path.dirname(path), 'gifs')
        os.makedirs(gif_parent_path, exist_ok=True)
        gif_path = os.path.join(gif_parent_path, '{}.gif'.format(name))
        # Repeat the last frame so the final molecule lingers before looping.
        imgs.extend([imgs[-1]] * 10)
        imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)

        gifs = wandb.Video(gif_path, fps=5, format="gif")
        columns = ["Generation Path"]
        if wandb.run:
            print(f"Saving {gif_path} to wandb")
            wandb.log({"chain": gifs}, commit=True)

        try:
            img = Draw.MolsToGridImage(drawable, molsPerRow=10, subImgSize=(200, 200))
            img.save(os.path.join(path, '{}_grid_image.png'.format(name)))
            # Also collect the grid images of all chains in a sibling folder.
            grid_parent_path = os.path.join(os.path.dirname(path), 'grid_images')
            os.makedirs(grid_parent_path, exist_ok=True)
            img.save(os.path.join(grid_parent_path, '{}_grid_image.png'.format(name)))
        except Chem.rdchem.KekulizeException:
            print("Can't kekulize molecule")
        return mols, [[gifs]], columns

    def visualize_input_output_dist(self, result_path, input_sample_path, output_sample_path, epoch=None,
                                    num_chain_frames=50):
        """Save a two-row comparison of input vs. output chain frames.

        Reads ``fram_{i}.png`` for ``i < num_chain_frames`` from both sample
        directories (the naming used by ``visualize_chain``) and saves the
        ``FramesToComparisonGrid`` image to ``result_path``. ``epoch`` is
        unused.
        """
        input_frame_paths = [os.path.join(input_sample_path, f"fram_{i}.png") for i in range(num_chain_frames)]
        output_frame_paths = [os.path.join(output_sample_path, f"fram_{i}.png") for i in range(num_chain_frames)]
        comparison_img = self.FramesToComparisonGrid(input_frame_paths, output_frame_paths)
        comparison_img.save(result_path)

    @staticmethod
    def _comparison_frame_indices(num_available, num_frames):
        """Return the frame indices shown by ``FramesToComparisonGrid``.

        The first ``num_frames - 1`` indices step by
        ``num_available // (num_frames - 1)`` starting at 0; the last index is
        always the final frame ``num_available - 1``.

        Raises:
            ValueError: if ``num_frames < 1`` or ``num_available < num_frames``.
        """
        if num_frames < 1:
            raise ValueError("num_frames must be at least 1")
        if num_available < num_frames:
            raise ValueError("num_frames must not be greater than the number of available frames")
        if num_frames == 1:
            return [num_available - 1]
        step = num_available // (num_frames - 1)
        return [i * step for i in range(num_frames - 1)] + [num_available - 1]

    def FramesToComparisonGrid(self, input_frame_paths, output_frame_paths, num_frames=11):
        """Build a 2-row image: input frames on top, output frames below.

        ``num_frames`` columns are sampled at a fixed stride, the last column
        always showing the final frame. Every cell has the size of the first
        input frame; the canvas is transparent where nothing is pasted.

        Raises:
            ValueError: if the two lists differ in length, hold fewer than
                ``num_frames`` frames, or their first frames differ in size.
        """
        from PIL import Image

        if len(input_frame_paths) != len(output_frame_paths):
            raise ValueError("input and output frame lists must have the same length")
        indices = self._comparison_frame_indices(len(input_frame_paths), num_frames)

        with Image.open(input_frame_paths[0]) as first_in, Image.open(output_frame_paths[0]) as first_out:
            sub_w, sub_h = first_in.size
            if first_out.size != (sub_w, sub_h):
                raise ValueError("output and input sample should, by default, have the same frame size")

        res = Image.new("RGBA", (num_frames * sub_w, 2 * sub_h), (255, 255, 255, 0))
        for col, frame_number in enumerate(indices):
            for row, paths in enumerate((input_frame_paths, output_frame_paths)):
                # ``with`` closes the file handle Image.open keeps open.
                with Image.open(paths[frame_number]) as frame:
                    res.paste(frame, (col * sub_w, row * sub_h))
        return res

    def plot_entropy(self, result_path, entropy_list, legend):
        """Plot mean entropy per timestep into ``result_path/input_dist_entropy``.

        Uses a fresh figure that is closed afterwards, so repeated calls
        neither draw over a previous plot nor leak open figures.
        """
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        ax.plot(entropy_list)
        ax.set_xlabel('Timestep')
        ax.set_ylabel('Mean Entropy')
        ax.set_title('Mean Entropy vs. Timestep')
        ax.legend([legend])
        fig.savefig(os.path.join(result_path, "input_dist_entropy"))
        plt.close(fig)


class NonMolecularVisualization:
    """Draw generic (non-molecular) graphs with networkx/matplotlib and
    optionally log the images to wandb.

    A graph is a pair ``(node_list, adjacency_matrix)``: ``-1`` entries in
    ``node_list`` mark padding nodes and adjacency entries ``>= 1`` are edges.
    """

    def to_networkx(self, node_list, adjacency_matrix):
        """Convert one graph to an undirected ``networkx.Graph``.

        Args:
            node_list: length-n sequence of node labels; ``-1`` marks padding.
                Each kept node ``i`` gets attributes ``number=i`` and
                ``symbol``/``color_val`` equal to its label.
            adjacency_matrix: n x n array; every entry ``>= 1`` becomes an edge
                with ``color=float(value)`` and ``weight=3 * value``. Edges
                touching a padding node are dropped.

        Returns:
            The ``networkx.Graph``.
        """
        graph = nx.Graph()
        for i, label in enumerate(node_list):
            if label == -1:
                continue
            graph.add_node(i, number=i, symbol=label, color_val=label)

        rows, cols = np.where(adjacency_matrix >= 1)
        for r, c in zip(rows.tolist(), cols.tolist()):
            if r not in graph or c not in graph:
                # add_edge would otherwise create an attribute-less node that
                # is also missing from any precomputed layout.
                continue
            edge_type = adjacency_matrix[r][c]
            graph.add_edge(r, c, color=float(edge_type), weight=3 * edge_type)
        return graph

    @staticmethod
    def _eigenvector_colors(laplacian):
        """Return ``(colors, vmin, vmax)`` used to colour the nodes.

        ``colors`` is column 1 of the eigenvectors from ``np.linalg.eigh``
        (eigenvalues ascending), i.e. the eigenvector of the second-smallest
        eigenvalue. ``vmin``/``vmax`` are symmetric around 0 so that 0 maps to
        the centre of a diverging colormap. Matrices smaller than 2 x 2 have
        no such eigenvector: all nodes get colour 0 with range (-1, 1).
        """
        n = laplacian.shape[0]
        if n < 2:
            return np.zeros(n), -1.0, 1.0
        _, eigvecs = np.linalg.eigh(laplacian)
        colors = eigvecs[:, 1]
        m = max(np.abs(np.min(colors)), np.max(colors))
        return colors, -m, m

    def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=100, largest_component=False):
        """Draw ``graph`` to the image file ``path``.

        Nodes are coloured by the second eigenvector of the normalized
        Laplacian using the ``coolwarm`` colormap.

        Args:
            graph: the ``networkx`` graph to draw.
            pos: node -> position mapping; ``None`` computes a spring layout.
            path: output image file.
            iterations: spring-layout iterations when ``pos`` is ``None``.
            node_size: node marker size passed to ``nx.draw``.
            largest_component: draw only the largest connected component.
        """
        if largest_component and graph.number_of_nodes() > 0:
            # max() keeps the first of equally large components, as a stable
            # descending sort would.
            graph = graph.subgraph(max(nx.connected_components(graph), key=len))

        if pos is None:
            pos = nx.spring_layout(graph, iterations=iterations)

        # networkx refuses to build a matrix for a graph without nodes.
        if graph.number_of_nodes() > 0:
            laplacian = nx.normalized_laplacian_matrix(graph).toarray()
        else:
            laplacian = np.zeros((0, 0))
        colors, vmin, vmax = self._eigenvector_colors(laplacian)

        plt.figure()
        nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=colors,
                cmap=plt.cm.coolwarm, vmin=vmin, vmax=vmax, edge_color='grey')
        plt.tight_layout()
        plt.savefig(path)
        plt.close("all")

    def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph', epoch=None):
        """Draw the first graphs of ``graphs`` to ``path/graph_{i}.png``.

        Args:
            path: output directory (created if missing).
            graphs: sequence of ``(node_list, adjacency_matrix)`` pairs whose
                elements provide ``.numpy()`` (e.g. CPU tensors).
            num_graphs_to_visualize: how many graphs to draw; clamped to
                ``len(graphs)``.
            log: wandb key for the images; ``None`` disables wandb logging.
            epoch: unused.
        """
        os.makedirs(path, exist_ok=True)

        num_graphs_to_visualize = min(num_graphs_to_visualize, len(graphs))
        for i in range(num_graphs_to_visualize):
            file_path = os.path.join(path, 'graph_{}.png'.format(i))
            graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
            self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
            if wandb.run and log is not None:
                im = plt.imread(file_path)
                wandb.log({log: [wandb.Image(im, caption=file_path)]})

    def visualize_chain(self, path, nodes_list, adjacency_matrix, epoch=None):
        """Render a generation chain as PNG frames, a GIF and a grid image.

        Step ``i`` is the graph ``(nodes_list[i], adjacency_matrix[i])``.
        Every frame reuses the spring layout (seed 0) of the final graph so
        nodes stay in place. Frames go to ``path/fram_{i}.png``; the GIF goes
        to ``<dirname(path)>/gifs`` and is logged to wandb under ``"chain"``
        when a run is active; the grid image is saved in ``path`` and in
        ``<dirname(path)>/grid_images``.

        Returns:
            ``(graphs, [[[gif_video]]], ["Generation Path"])``.
        """
        os.makedirs(path, exist_ok=True)
        name = os.path.basename(path)

        graphs = [self.to_networkx(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
        final_pos = nx.spring_layout(graphs[-1], seed=0)

        save_paths = []
        for frame, graph in enumerate(graphs):
            file_name = os.path.join(path, 'fram_{}.png'.format(frame))
            self.visualize_non_molecule(graph=graph, pos=final_pos, path=file_name)
            save_paths.append(file_name)

        imgs = [imageio.imread(fn) for fn in save_paths]
        gif_parent_path = os.path.join(os.path.dirname(path), 'gifs')
        os.makedirs(gif_parent_path, exist_ok=True)
        gif_path = os.path.join(gif_parent_path, '{}.gif'.format(name))
        # Repeat the last frame so the final graph lingers before looping.
        imgs.extend([imgs[-1]] * 10)
        imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)

        gifs = [wandb.Video(gif_path, caption=gif_path, format="gif")]
        columns = ["Generation Path"]
        if wandb.run:
            print(f"Saving {gif_path} to wandb")
            wandb.log({'chain': gifs}, commit=True)

        grid_img = self.FramesToGridImage(save_paths)
        grid_img.save(os.path.join(path, '{}_grid_image.png'.format(name)))
        grid_parent_path = os.path.join(os.path.dirname(path), 'grid_images')
        os.makedirs(grid_parent_path, exist_ok=True)
        grid_img.save(os.path.join(grid_parent_path, '{}_grid_image.png'.format(name)))

        return graphs, [[gifs]], columns

    def visualize_input_output_dist(self, result_path, input_sample_path, output_sample_path, epoch=None,
                                    num_chain_frames=50):
        """Save a two-row comparison of input vs. output chain frames.

        Reads ``fram_{i}.png`` for ``i < num_chain_frames`` from both sample
        directories (the naming used by ``visualize_chain``) and saves the
        ``FramesToComparisonGrid`` image to ``result_path``. ``epoch`` is
        unused.
        """
        input_frame_paths = [os.path.join(input_sample_path, f"fram_{i}.png") for i in range(num_chain_frames)]
        output_frame_paths = [os.path.join(output_sample_path, f"fram_{i}.png") for i in range(num_chain_frames)]
        comparison_img = self.FramesToComparisonGrid(input_frame_paths, output_frame_paths)
        comparison_img.save(result_path)

    @staticmethod
    def _comparison_frame_indices(num_available, num_frames):
        """Return the frame indices shown by ``FramesToComparisonGrid``.

        The first ``num_frames - 1`` indices step by
        ``num_available // (num_frames - 1)`` starting at 0; the last index is
        always the final frame ``num_available - 1``.

        Raises:
            ValueError: if ``num_frames < 1`` or ``num_available < num_frames``.
        """
        if num_frames < 1:
            raise ValueError("num_frames must be at least 1")
        if num_available < num_frames:
            raise ValueError("num_frames must not be greater than the number of available frames")
        if num_frames == 1:
            return [num_available - 1]
        step = num_available // (num_frames - 1)
        return [i * step for i in range(num_frames - 1)] + [num_available - 1]

    def FramesToComparisonGrid(self, input_frame_paths, output_frame_paths, num_frames=11):
        """Build a 2-row image: input frames on top, output frames below.

        ``num_frames`` columns are sampled at a fixed stride, the last column
        always showing the final frame. Every cell has the size of the first
        input frame; the canvas is transparent where nothing is pasted.

        Raises:
            ValueError: if the two lists differ in length, hold fewer than
                ``num_frames`` frames, or their first frames differ in size.
        """
        from PIL import Image

        if len(input_frame_paths) != len(output_frame_paths):
            raise ValueError("input and output frame lists must have the same length")
        indices = self._comparison_frame_indices(len(input_frame_paths), num_frames)

        with Image.open(input_frame_paths[0]) as first_in, Image.open(output_frame_paths[0]) as first_out:
            sub_w, sub_h = first_in.size
            if first_out.size != (sub_w, sub_h):
                raise ValueError("output and input sample should, by default, have the same frame size")

        res = Image.new("RGBA", (num_frames * sub_w, 2 * sub_h), (255, 255, 255, 0))
        for col, frame_number in enumerate(indices):
            for row, paths in enumerate((input_frame_paths, output_frame_paths)):
                # ``with`` closes the file handle Image.open keeps open.
                with Image.open(paths[frame_number]) as frame:
                    res.paste(frame, (col * sub_w, row * sub_h))
        return res

    def FramesToGridImage(self, frame_paths, framesPerRow=10):
        """Tile the images in ``frame_paths`` row by row, ``framesPerRow`` per row.

        Every cell has the size of the first image; the canvas is transparent
        where no frame is pasted.

        Raises:
            ValueError: if ``frame_paths`` is empty.
        """
        from PIL import Image

        if not frame_paths:
            raise ValueError("frame_paths must contain at least one image")
        with Image.open(frame_paths[0]) as first:
            sub_w, sub_h = first.size
        num_rows = (len(frame_paths) + framesPerRow - 1) // framesPerRow
        res = Image.new("RGBA", (framesPerRow * sub_w, num_rows * sub_h), (255, 255, 255, 0))
        for i, fn in enumerate(frame_paths):
            row, col = divmod(i, framesPerRow)
            with Image.open(fn) as frame:
                res.paste(frame, (col * sub_w, row * sub_h))
        return res

    def plot_entropy(self, result_path, entropy_list, legend):
        """Plot mean entropy per timestep into ``result_path/input_dist_entropy``.

        Uses a fresh figure that is closed afterwards, so repeated calls
        neither draw over a previous plot nor leak open figures.
        """
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        ax.plot(entropy_list)
        ax.set_xlabel('Timestep')
        ax.set_ylabel('Mean Entropy')
        ax.set_title('Mean Entropy vs. Timestep')
        ax.legend([legend])
        fig.savefig(os.path.join(result_path, "input_dist_entropy"))
        plt.close(fig)

