import argparse
import numpy as np
import importlib
from datetime import datetime
import os
import shutil
import copy
import torch
import GPUtil
import pickle
import scipy as sp
from scipy import ndimage
import time
import jax
import matplotlib.pyplot as plt
import matplotlib as mpl
import plotly.graph_objects as go
import tree
from plotly.subplots import make_subplots

from data import residue_constants
from data import diffuser
from data import utils as du
from data import rocklin_dataset
from data import top7_dataset
from data import pdb_dataset
from data import protein
from data import denovo_dataset
from model import reverse_diffusion

from inpainting import inpaint, particle_filter, inpaint_experiment, motif_problems


from experiments import torch_train_diffusion
from analysis import plotting
from analysis import utils as au

def gather_args(argv=None):
    """Parse the command-line arguments of the inpainting script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, matching the previous
            zero-argument behaviour.

    Returns:
        argparse.Namespace with attributes ``output_dir``, ``run_name``,
        ``target``, ``motif_range`` and ``num_samples``.
    """
    parser = argparse.ArgumentParser(
        description='Inpainting script.')
    parser.add_argument(
        '--output_dir',
        help='path to directory to save files into',
        type=str)
    parser.add_argument(
        '--run_name',
        help='additional string to add to saved filenames',
        default='',
        type=str)
    parser.add_argument(
        '--target',
        # Keep this list in sync with the targets accepted by run().
        help=('Target for inpainting, must be one of '
              '["rsv", "EFHand", "6exz", "6e6r", "5trv", "5ci9"]'),
        type=str)
    parser.add_argument(
        '--motif_range',
        help=('motif range -- either (motif_start,motif_end) or '
              '(motif_start,motif_end,motif2_start,motif2_end); the '
              'four-value form is used for the "EFHand" target'),
        type=str)
    parser.add_argument(
        '--num_samples',
        help='Number of independent inpainting repetitions per inpainting method.',
        type=int,
        default=1)
    return parser.parse_args(argv)

def _select_gpu():
    """Restrict CUDA to the GPU(s) GPUtil reports as available, ordered by memory.

    Sets CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES and returns the device-id
    string that was written.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # CUDA_VISIBLE_DEVICES expects a comma-separated list; joining with ''
    # would fuse ids such as [0, 1] into "01" if several were returned.
    chosen_gpu = ','.join(
        str(x) for x in GPUtil.getAvailable(order='memory'))
    os.environ["CUDA_VISIBLE_DEVICES"] = chosen_gpu
    print("chosen GPU: ", chosen_gpu)
    return chosen_gpu


def _find_base_dir(base_dirs=('./protein_diffusion',)):
    """Return the first entry of ``base_dirs`` that is an existing directory.

    Raises:
        FileNotFoundError: if none of the candidates exist. Previously this
            case surfaced later as an unbound-variable NameError.
    """
    for base_dir in base_dirs:
        if os.path.isdir(base_dir):
            print(base_dir)
            return base_dir
    raise FileNotFoundError(
        f'None of the candidate base directories exist: {list(base_dirs)}')


def _find_ckpt_path(ckpt_dir):
    """Return the path of the ``.pkl`` checkpoint that lives in ``ckpt_dir``.

    Only ``.pth``/``.pkl`` files are considered (sorted, so the choice is
    deterministic rather than depending on ``os.listdir`` order). The
    extension is swapped to ``.pkl`` on the file name only, so directory
    names that contain ".pth" are left intact.

    Raises:
        FileNotFoundError: if ``ckpt_dir`` holds no ``.pth``/``.pkl`` file.
    """
    ckpt_files = sorted(
        f for f in os.listdir(ckpt_dir) if f.endswith(('.pth', '.pkl')))
    if not ckpt_files:
        raise FileNotFoundError(f'No .pth/.pkl checkpoint found in {ckpt_dir}')
    stem, _ = os.path.splitext(ckpt_files[0])
    return os.path.join(ckpt_dir, stem + '.pkl')


def _load_experiment(base_dir, ckpt_basename):
    """Build a torch_train_diffusion Experiment from a saved checkpoint.

    The checkpoint pickle is expected to hold 'cfg' and 'exp_state' keys.
    Its experiment/model config is merged over the default 'pdb' config,
    the weights are loaded and the model is put in eval mode.

    Returns:
        (exp, exp_cfg): the experiment and its ``cfg.experiment`` section.
    """
    ckpt_dir = os.path.join(base_dir, 'torch_train_diffusion', ckpt_basename)
    ckpt_path = _find_ckpt_path(ckpt_dir)
    print("ckpt path: ", ckpt_path)
    # NOTE: read_pkl presumably unpickles; only point this at trusted checkpoints.
    ckpt_pkl = du.read_pkl(ckpt_path)
    ckpt_cfg = ckpt_pkl['cfg']
    ckpt_state = ckpt_pkl['exp_state']

    data_setting = 'pdb'
    cfg = torch_train_diffusion.get_config(data_setting=data_setting)
    cfg.experiment.update(ckpt_cfg.experiment)
    cfg.experiment.data_setting = data_setting
    cfg.model.update(ckpt_cfg.model)
    cfg.data.max_len = ckpt_cfg.data.max_len
    cfg.data.inpainting_training = False

    exp = torch_train_diffusion.Experiment(cfg)
    exp.model.load_state_dict(ckpt_state)
    exp.model.eval()
    return exp, cfg.experiment


def _parse_motif_range(motif_range, n_values):
    """Parse a comma-separated motif range into exactly ``n_values`` ints.

    Raises:
        ValueError: if ``motif_range`` is missing or has the wrong arity.
    """
    if motif_range is None:
        raise ValueError('--motif_range is required')
    values = [int(motif_pos) for motif_pos in motif_range.split(",")]
    if len(values) != n_values:
        raise ValueError(
            f'--motif_range must contain {n_values} comma-separated integers, '
            f'got {motif_range!r}')
    return values


def _load_motif_problem(pdb_name, motif_range):
    """Dispatch to the motif_problems loader matching ``pdb_name``.

    Returns:
        The 6-tuple produced by the loader: (pdb_name, target_len,
        motif_ca_xyz, full_ca_xyz_true, motif_idcs, inpainting_task_name).

    Raises:
        ValueError: for an unknown target or a malformed motif range.
    """
    print('loading pdb and motif: ', pdb_name, motif_range)
    if pdb_name == "rsv":
        motif_start, motif_end = _parse_motif_range(motif_range, 2)
        return motif_problems.load_rsv_motif_problem(motif_start, motif_end)
    if pdb_name == "EFHand":
        motif_start, motif_end, motif2_start, motif2_end = _parse_motif_range(
            motif_range, 4)
        return motif_problems.load_EFHand_motif_problem(
            motif_start, motif_end, motif2_start, motif2_end)
    if pdb_name in ("6exz", "6e6r", "5trv", "5ci9"):
        motif_start, motif_end = _parse_motif_range(motif_range, 2)
        return motif_problems.load_pdb_motif_problem(
            motif_start, motif_end, pdb_name)
    raise ValueError(
        f'Unknown --target {pdb_name!r}; expected one of '
        '"rsv", "EFHand", "6exz", "6e6r", "5trv", "5ci9"')


def run():
    """Run the full inpainting experiment described by the command line.

    Steps:
      1. Select a GPU.
      2. Load the diffusion model from a fixed checkpoint.
      3. Load the requested motif problem.
      4. Save the target and motif backbones as PDB files.
      5. Run ``--num_samples`` repetitions of each inpainting method
         ("replacement", "fixed", "particle").

    Raises:
        ValueError: if --output_dir is missing, or if --target or
            --motif_range is missing or invalid.
        FileNotFoundError: if the base directory or checkpoint cannot be found.
    """
    args = gather_args()
    N_forward_diffusions = args.num_samples
    N_samples_per_diffusion = 2
    N_particles = 64

    # Fail fast, before any expensive model loading.
    if args.output_dir is None:
        raise ValueError('--output_dir is required')
    # Normalise to a trailing separator so both os.path.join and any plain
    # string concatenation done downstream produce paths inside the directory.
    output_dir = os.path.join(args.output_dir, '')
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    _select_gpu()

    protein_diffusion_base_dir = _find_base_dir()

    # Checkpoint to load: the May 26 monomer model. Earlier checkpoints were
    # '18D_04M_2022Y_11h_31m_36s' (used for rsv and EF-Hand) and
    # '30D_04M_2022Y_12h_22m_41s' (April 30 monomer model).
    ckpt_basename = '26D_05M_2022Y_20h_10m_17s'
    exp, exp_cfg = _load_experiment(protein_diffusion_base_dir, ckpt_basename)

    (pdb_name, target_len, motif_ca_xyz, full_ca_xyz_true,
     motif_idcs, inpainting_task_name) = _load_motif_problem(
         args.target, args.motif_range)

    # run_name defaults to '', so test truthiness; an "is not None" check
    # would always append a dangling "_".
    if args.run_name:
        inpainting_task_name += "_" + args.run_name

    # Separately save target and motif.
    du.save_bb_as_pdb(full_ca_xyz_true[:target_len],
                      os.path.join(output_dir, pdb_name + '.pdb'))
    du.save_bb_as_pdb(motif_ca_xyz,
                      os.path.join(output_dir, inpainting_task_name + '.pdb'))

    prot_diffuser, T = exp.diffuser, exp_cfg.T

    # "replacement" and "fixed" run without a particle filter and draw
    # N_samples_per_diffusion samples. "particle" uses the particle filter
    # and is given N_particles in that position. All methods save
    # N_samples_per_diffusion results per rep.
    methods = (
        ("replacement", N_samples_per_diffusion),
        ("fixed", N_samples_per_diffusion),
        ("particle", N_particles),
    )
    for inpaint_method, num_samples in methods:
        print("\n\ninpaint method: ", inpaint_method)
        for j in range(N_forward_diffusions):
            print("\n\nrep: %d" % j)
            batch_name = inpainting_task_name + "_rep_%02d" % j
            inpaint_experiment.run_inpainting(
                exp, target_len, motif_ca_xyz, motif_idcs, prot_diffuser, T,
                num_samples, batch_name, output_dir,
                inpaint_method=inpaint_method,
                num_save=N_samples_per_diffusion)


if __name__ == '__main__':
    # Run the inpainting pipeline only when executed as a script, not on import.
    run()
