from os.path import join

from sklearn import tree
from tqdm import tqdm
import rdkit.Chem as Chem
import torch
import numpy as np

from bond_type_prediction.average_model import process_xyz_no_smiles


class TreeEdgeEModel:
    """Decision-tree baseline that predicts bond orders from pairwise atom features.

    Every atom pair ``(i, j)`` with ``i < j`` of a molecule becomes one sample with
    features ``[distance_ij, atom_type_i, atom_type_j]``. Here ``distance_ij`` is the
    Euclidean distance between the xyz coordinates, and the atom types are taken from
    ``molecule['atomic_numbers']``. The label is the integer entry of RDKit's
    bond-order adjacency matrix (``useBO=True``) of the kekulized, hydrogen-added
    reference molecule, so it is 0 for non-bonded pairs.

    Files read under ``data_dir``:
      * ``smiles/{split}.txt``: one reference SMILES per line.
      * ``synthetic_coords_rdkit/{split}/mol_{idx}.xyz``: coordinates for the SMILES
        on line ``idx`` of the split file.
    """

    _SPLITS = ('train', 'valid', 'test')

    def __init__(self, data_dir) -> None:
        """Load the reference SMILES strings of every split.

        Args:
            data_dir: Root directory containing ``smiles/`` and
                ``synthetic_coords_rdkit/``.
        """
        self.data_dir = data_dir

        # We start by loading the SMILES. Every line is kept (even an empty one) so
        # that line index ``idx`` stays aligned with ``mol_{idx}.xyz``.
        self.all_smiles = {}
        for split in self._SPLITS:
            smiles_path = join(data_dir, f'smiles/{split}.txt')
            with open(smiles_path, "rb") as file:
                # rstrip("\r\n") also removes the '\r' left by CRLF line endings.
                self.all_smiles[split] = [line.decode("utf-8").rstrip("\r\n") for line in file]

    @staticmethod
    def _reference_mol(smiles):
        """Return the kekulized, hydrogen-added RDKit molecule for ``smiles``.

        Returns None when the SMILES cannot be parsed or kekulized; callers skip it.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles reports a parse failure by returning None.
                return None
            Chem.Kekulize(mol)
        except Exception:
            # Catch Exception (not everything) so Ctrl-C / SystemExit still propagate.
            return None
        return Chem.AddHs(mol)

    @staticmethod
    def _canonical_smiles(smiles):
        """Return RDKit's canonical SMILES for ``smiles``, or None if it does not parse."""
        mol = Chem.MolFromSmiles(smiles)
        return None if mol is None else Chem.MolToSmiles(mol)

    def _load_xyz(self, split, idx):
        """Parse ``synthetic_coords_rdkit/{split}/mol_{idx}.xyz`` with process_xyz_no_smiles."""
        xyz_path = join(self.data_dir, f'synthetic_coords_rdkit/{split}/mol_{idx}.xyz')
        with open(xyz_path, 'r') as file:
            return process_xyz_no_smiles(file)

    @staticmethod
    def _edge_features(positions, atom_type, adj_matrix):
        """Build pair features and bond-order labels for one molecule.

        Args:
            positions: Tensor of atom coordinates passed to ``torch.cdist``
                (presumably ``(n_atoms, 3)``; TODO confirm).
            atom_type: Per-atom values, each supporting ``.item()``.
            adj_matrix: Bond-order matrix indexable as ``adj_matrix[i, j]``.
                NOTE(review): assumes the xyz atom order matches the RDKit
                ``AddHs`` atom order; confirm.

        Returns:
            ``(X, y)``: lists with one entry per pair ``i < j`` in row-major order.
            ``X[k] = [distance_ij, atom_type_i, atom_type_j]`` and
            ``y[k] = int(adj_matrix[i, j])``. Both are empty for fewer than two atoms.
        """
        # Convert to Python scalars once instead of calling .item() O(n^2) times.
        distances = torch.cdist(positions, positions, p=2).tolist()
        types = [t.item() for t in atom_type]

        X, y = [], []
        # Only the upper triangle (i < j) is used; the matrices are symmetric.
        for i in range(len(types)):
            for j in range(i + 1, len(types)):
                X.append([distances[i][j], types[i], types[j]])
                y.append(int(adj_matrix[i, j]))
        return X, y

    def prepare_data(self):
        """Build pair features and labels for every split into ``self.data``.

        Stores ``self.data['X_{split}']`` and ``self.data['y_{split}']`` as lists that
        concatenate the pairs of all usable molecules of the split. Molecules whose
        reference SMILES cannot be parsed or kekulized are skipped.

        Raises:
            ValueError: If the SMILES returned with an xyz file does not canonicalize
                to the same string as the reference SMILES on the matching line.
        """
        self.data = {}
        for split in self._SPLITS:
            print(f'Preparing data for split {split}')
            X, y = [], []

            for idx, smiles in enumerate(tqdm(self.all_smiles[split])):
                # Reference 2D molecule: source of the true bonds.
                mol = self._reference_mol(smiles)
                if mol is None:
                    continue
                adj_matrix = Chem.rdmolops.GetAdjacencyMatrix(mol, useBO=True)

                # 3D molecule: source of the distances and atom types.
                molecule = self._load_xyz(split, idx)
                xyz_canonical = self._canonical_smiles(molecule['smiles'])
                if xyz_canonical is None or xyz_canonical != self._canonical_smiles(smiles):
                    # Explicit raise: an assert would vanish under `python -O`.
                    raise ValueError(f"Smiles do not match ({split} molecule {idx})")

                mol_X, mol_y = self._edge_features(
                    molecule['positions'], molecule['atomic_numbers'], adj_matrix)
                X.extend(mol_X)
                y.extend(mol_y)

            self.data[f'X_{split}'] = X
            self.data[f'y_{split}'] = y

    def fit_tree(self, max_depth=None, min_samples_leaf=1):
        """Fit a class-balanced DecisionTreeClassifier on the training pairs.

        Requires ``prepare_data`` to have been called.

        Args:
            max_depth: Maximum tree depth, forwarded to scikit-learn (None = unlimited).
            min_samples_leaf: Minimum number of samples per leaf, forwarded to scikit-learn.
        """
        # class_weight="balanced" reweights classes by inverse frequency. Non-bonded
        # pairs presumably dominate the training labels.
        self.model = tree.DecisionTreeClassifier(max_depth=max_depth, min_samples_leaf=min_samples_leaf, class_weight="balanced")
        self.model = self.model.fit(self.data['X_train'], self.data['y_train'])

    def eval(self):
        """Return the validation pair (edge) accuracy of the fitted tree, in percent."""
        y_pred = self.model.predict(self.data['X_valid'])
        edge_accuracy = np.mean(y_pred == np.asarray(self.data['y_valid'])) * 100
        return edge_accuracy

    def get_molecule_accuracy(self):
        """Return the percentage of validation molecules whose every pair is predicted correctly.

        Molecules whose reference SMILES cannot be parsed or kekulized are skipped and
        not counted. A molecule with no atom pairs has nothing to mispredict and
        counts as correct. Returns 0.0 when no validation molecule was evaluated.
        """
        n_correct_molecules = 0
        n_total_molecules = 0
        for idx, smiles in enumerate(tqdm(self.all_smiles['valid'])):
            mol = self._reference_mol(smiles)
            if mol is None:
                continue
            adj_matrix = Chem.rdmolops.GetAdjacencyMatrix(mol, useBO=True)

            molecule = self._load_xyz('valid', idx)
            X, y = self._edge_features(
                molecule['positions'], molecule['atomic_numbers'], adj_matrix)

            if X:
                y_pred = self.model.predict(X)
                molecule_correct = bool(np.all(y_pred == np.asarray(y)))
            else:
                # scikit-learn's predict() rejects empty input.
                molecule_correct = True

            n_correct_molecules += int(molecule_correct)
            n_total_molecules += 1

        if n_total_molecules == 0:
            return 0.0
        return (n_correct_molecules / n_total_molecules) * 100
