"""Script to generate samples from a trained model.
NOTE: the pickle of starting conformers that ``main`` loads is hard-coded to
``experimental_result/rdkit.pkl``. Please change that path to point to your
own generated pkl before running this script.
"""

import argparse
import datetime
import os
import os.path as osp
import time
import re

import numpy as np
import torch


# from lightning import seed_everything
from loguru import logger as log
from torch_geometric.data import Batch, Data
from tqdm import tqdm
from utils.utils import instantiate_model, read_yaml
import pickle
from utils.commons.io import save_pkl
from dataset.dataset import InferenceDataset
from refiner.models.model_GlobalRefiner import GlobalRefiner

from rdkit import Chem
from rdkit.Chem import AllChem

# Module-level side effect: set PyTorch's global float32 matmul precision to
# "high". Per the PyTorch docs this lets float32 matrix multiplications use
# faster reduced-precision internals (e.g. TF32) where the hardware supports it.
torch.set_float32_matmul_precision("high")

def open_pickle(mol_path):
    """Deserialize and return the object stored in the pickle at ``mol_path``.

    NOTE: unpickling can execute arbitrary code, so only point this at files
    you trust.
    """
    with open(mol_path, "rb") as handle:
        return pickle.load(handle)

def get_datatime():
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    now = datetime.datetime.now()
    return f"{now:%Y-%m-%d_%H-%M-%S}"

def _load_starting_positions(pickle_path):
    """Map each SMILES string to its list of starting conformer coordinates.

    Args:
        pickle_path: Path to a pickle holding an iterable of records, each
            exposing ``smiles`` and ``pos_gen`` attributes.

    Returns:
        A dict mapping ``record.smiles`` to a list of ``torch.float32``
        tensors, one per entry of ``record.pos_gen``. If the same SMILES
        appears more than once, the last record wins.
    """
    return {
        record.smiles: [
            torch.tensor(pos, dtype=torch.float32) for pos in record.pos_gen
        ]
        for record in open_pickle(pickle_path)
    }


def main(
    config: dict,
    checkpoint_path: str,
    output_dir: str,
    debug: bool,
    starter_path: str = "experimental_result/rdkit.pkl",
):
    """Refine pre-generated conformers with a trained model and save the results.

    For every molecule of the test split that also appears in the starter
    pickle, the starting conformers are fed to ``model.sample`` in batches and
    the returned positions are collected.

    Args:
        config: Parsed configuration. Reads ``datamodule_args.partition``,
            ``model``, ``model_args`` and the optional ``eval_args``
            (``batch_size``, default 32, and ``sampler_args``).
        checkpoint_path: Checkpoint file containing a ``"state_dict"`` entry.
        output_dir: Directory the result pickles are written to. It is created
            if it does not exist.
        debug: If true, stop after the first successfully processed molecule
            and only write ``velocity_pred.pkl``.
        starter_path: Pickle with the starting conformers (records exposing
            ``smiles`` and ``pos_gen``). Defaults to the previously hard-coded
            location.

    Side effects:
        Writes ``velocity_pred.pkl`` and, unless ``debug`` is set,
        ``generated_files.pkl`` into ``output_dir``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info(f"Using {device} for sampling.")

    # instantiate datamodule and model
    dataset = InferenceDataset(
        partition=config["datamodule_args"]["partition"],
        split="test",
    )
    model = instantiate_model(config["model"], config["model_args"])
    print(dataset)

    # load model weights
    log.info(f"Loading model weights from {checkpoint_path}")
    state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    model.load_state_dict(state_dict)

    # move to device
    model: GlobalRefiner = model.to(device)
    model.eval()

    # evaluation settings (looked up once; they do not change per batch)
    eval_args = config.get("eval_args", {})
    max_batch_size = eval_args.get("batch_size", 32)
    sampler_args = eval_args.get("sampler_args", {})

    data_list = {}
    v_total_dict = {}
    times = []

    # starting conformers, keyed by SMILES
    starting_positions = _load_starting_positions(starter_path)

    for idx in tqdm(range(len(dataset))):
        data = dataset[idx]

        smiles = data.smiles
        log.info(f"Generating conformers for molecule: {smiles}")

        # Skip molecules that have no (or an empty list of) starting
        # conformers; an empty list could not be stacked below.
        mols_pos = starting_positions.get(smiles)
        if not mols_pos:
            print(f'fail {smiles}')
            continue
        print(f'success {smiles}')

        # assumes data.pos is (num_conformers, num_atoms, 3) — TODO confirm
        pos_ref: np.ndarray = data.pos.cpu().numpy()
        pos_prior = torch.stack(mols_pos)

        # calculate number of samples to generate
        count = pos_ref.shape[0]  # number of conformers
        num_samples = 2 * count
        num_available = pos_prior.shape[0]
        if num_available < num_samples:
            # Without clamping, the final batch would hold more graphs than
            # there are starting conformers to assign to them.
            log.warning(
                f"Only {num_available} starting conformers available for "
                f"{smiles} (wanted {num_samples}); generating {num_available}."
            )
            num_samples = num_available

        pos_gen = []
        # keep a single conformer on the data object that gets batched
        data.pos = data.pos[0]
        for batch_start in range(0, num_samples, max_batch_size):
            # get batch_size (the last batch may be smaller)
            batch_size = min(max_batch_size, num_samples - batch_start)
            batch_end = batch_start + batch_size

            # flatten the selected starting conformers to (batch_size * num_atoms, 3)
            pos_start = pos_prior[batch_start:batch_end, :, :].view(-1, 3)

            # batch the data
            batched_data = Batch.from_data_list([data] * batch_size)
            batched_data.start_pos = pos_start

            z = batched_data["atomic_numbers"].to(device)
            edge_index = batched_data["edge_index"].to(device)
            batch = batched_data["batch"].to(device)
            node_attr = batched_data["node_attr"].to(device)
            chiral_index = batched_data["chiral_index"].to(device)
            chiral_nbr_index = batched_data["chiral_nbr_index"].to(device)
            chiral_tag = batched_data["chiral_tag"].to(device)
            start_pos = batched_data["start_pos"].to(device)

            # get time-estimate
            start = time.time()
            with torch.no_grad():
                # generate samples
                pos, v_dict = model.sample(
                    z,
                    edge_index,
                    batch,
                    node_attr=node_attr,
                    chiral_index=chiral_index,
                    chiral_nbr_index=chiral_nbr_index,
                    chiral_tag=chiral_tag,
                    start_pos=start_pos,
                    **sampler_args,
                )
            end = time.time()
            times.append((end - start) / batch_size)  # store time per conformer

            # accumulate the per-key values returned by the sampler
            for key, values in v_dict.items():
                v_total_dict.setdefault(key, []).extend(values)

            # reshape to (batch_size, num_atoms, 3)
            pos_gen.append(pos.view(batch_size, -1, 3).cpu().detach().numpy())

        # concatenate generated_positions: (num_samples, num_atoms, 3)
        pos_gen = np.concatenate(pos_gen, axis=0)

        data_list[smiles] = Data(
            smiles=smiles, pos_ref=pos_ref, rdmol=data.mol, pos_gen=pos_gen
        )
        if debug:
            break

    if times:
        log.info(
            "Average sampling time per conformer (mean over batches): "
            f"{np.mean(times):.4f} s"
        )

    # The script entry point only creates output_dir outside debug mode, yet
    # velocity_pred.pkl is written in both modes; make sure the target exists.
    os.makedirs(output_dir, exist_ok=True)
    save_pkl(
        os.path.join(output_dir, "velocity_pred.pkl"), v_total_dict
    )

    if not debug:
        save_pkl(
            os.path.join(output_dir, "generated_files.pkl"), list(data_list.values())
        )



if __name__ == "__main__":
    # Command-line interface: config and checkpoint are mandatory, the output
    # root defaults to "logs/", and --debug switches on debug mode.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", type=str, required=True)
    parser.add_argument("--checkpoint", "-k", type=str, required=True)
    parser.add_argument("--output_dir", "-o", type=str, required=False, default="logs/")
    parser.add_argument("--debug", "-d", action="store_true")

    args = parser.parse_args()

    # debug mode
    debug = args.debug
    log.info(f"Debug mode: {debug}")

    # read config
    # Validate with parser.error instead of ``assert`` so the check is not
    # stripped when Python runs with -O.
    if not osp.exists(args.config):
        parser.error("Config path does not exist.")
    log.info(f"Loading config from: {args.config}")
    config = read_yaml(args.config)
    task_name = config.get("task_name", "default")

    # get checkpoint path
    checkpoint_path = args.checkpoint
    if not osp.exists(checkpoint_path):
        parser.error("Checkpoint path does not exist.")

    # setup output directory for storing samples:
    # <output_dir>/samples/<task_name>/<timestamp>
    output_dir = osp.join(
        args.output_dir,
        f"samples/{task_name}/{get_datatime()}",
    )
    if not debug:
        os.makedirs(output_dir, exist_ok=True)

    main(config, checkpoint_path, output_dir, debug=debug)
