import argparse

# Declarative table of every command-line option.  The entries are kept in
# registration order, which is also the order in which ``--help`` lists them.
_CLI_OPTIONS = (
    # Mandatory inputs: the script exits with a usage error without them.
    ('--sim_ckpt', dict(type=str, default=None, required=True)),
    ('--config', dict(type=str, default=None, required=True)),
    ('--data_dir', dict(type=str, default=None, required=True)),
    # Optional settings with their defaults.
    ('--data_dir_coarse', dict(type=str, default='/low_resolution_numpy_path')),
    ('--suffix', dict(type=str, default='i40')),
    ('--num_frames', dict(type=int, default=400)),
    ('--threshold', dict(type=int, default=2)),
    ('--num_rollouts', dict(type=int, default=1)),
    ('--xtc', dict(action='store_true')),
    ('--get_score', dict(action='store_true')),
    ('--out_dir', dict(type=str, default=".")),
    ('--split', dict(type=str, default='/mdcath/test.csv')),
    # 'data' is the default inference mode.
    ('--mode', dict(choices=['gradient', 'data'], default='data')),
    ("--overfit", dict(action='store_true')),
    ('--frame_interval', dict(type=int, default=None)),
    ('--crop', dict(type=int, default=240)),
    ('--coarse_start', dict(type=int, default=0)),
    ('--atlas', dict(action='store_true')),
    ('--copy_frames', dict(action='store_true')),
    ('--no_frames', dict(action='store_true')),
)


def _build_parser():
    """Create the argument parser and register every option from the table."""
    cli = argparse.ArgumentParser()
    for flag, options in _CLI_OPTIONS:
        cli.add_argument(flag, **options)
    return cli


# Parsed eagerly at import time: the rest of the script reads ``args`` as a
# module-level global.
parser = _build_parser()
args = parser.parse_args()

import os, torch, mdtraj, time
import numpy as np
from SDE_model.transport.protein_sde import ProteinSDE
from SDE_model.utils import atom4_to_pdb, atom14_to_pdb
from SDE_model.model.config import ModelConfig
from SDE_model.dataset_infer_full import MDGenDataset


# Disabled variant that placed the outputs in a numbered sub-folder derived
# from --coarse_start (one folder per 20 frames of offset):
# folder = int(args.coarse_start/20)
# args.out_dir = args.out_dir + f'/{folder}'
# Make sure the output directory exists before anything is written into it;
# exist_ok=True keeps re-runs into the same directory from failing.
os.makedirs(args.out_dir, exist_ok=True)


def rollout(model, num_frames, batch, threshold):
    """Run one inference pass of ``model`` on ``batch``.

    Forwards ``batch``, ``num_frames`` and ``threshold`` to
    ``model.inference`` together with the mode chosen on the command line
    (module-level ``args.mode``).

    Returns:
        A 2-tuple with the two values produced by ``model.inference``, in
        the same order (the caller treats them as coordinates and loss).
    """
    inference_result = model.inference(batch, num_frames, threshold, args.mode)
    # Unpack explicitly so that anything but exactly two results fails here.
    frames, rollout_loss = inference_result
    return frames, rollout_loss


def _select_gt_frames(gt_arr):
    """Pick the ground-truth frames to export, based on the CLI options."""
    if not args.frame_interval:
        # No sub-sampling: contiguous window of --num_frames frames that
        # starts at --coarse_start.
        frame_start = args.coarse_start
        print('gt frame start is :', frame_start)
        return gt_arr[frame_start: (args.num_frames + frame_start)]

    # NOTE(review): the 400-frame span below is hard-coded; it equals the
    # --num_frames default but does not follow the flag -- confirm intent.
    if args.threshold != 1:
        # Keep the pair of frames (i, i + 1) at every interval step.
        indices = np.concatenate(
            [np.arange(i, i + 2) for i in range(0, 400, args.frame_interval)]
        )
        return gt_arr[indices]

    # threshold == 1: plain strided sub-sampling of the first 400 frames.
    return gt_arr[:400][::args.frame_interval]


def _export_aligned_xtc(pdb_path, xtc_path):
    """Superpose the trajectory stored in ``pdb_path`` and export it as XTC.

    The PDB file is then overwritten with only the first aligned frame.
    """
    traj = mdtraj.load(pdb_path)
    traj.superpose(traj)
    traj.save(xtc_path)
    traj[0].save(pdb_path)


def do(model, data_loader):
    """Run inference over ``data_loader`` and write the results to disk.

    For every batch the function
      1. moves all tensor entries to CUDA,
      2. opens the ground-truth array
         ``{args.data_dir}/{name}_{args.suffix}.npy`` as a read-only memmap,
      3. runs ``rollout`` and writes the prediction to
         ``{args.out_dir}/{name}.pdb``,
      4. writes the selected ground-truth frames to
         ``{args.out_dir}/{name}_gt.pdb`` unless the ground truth has more
         than 240 entries along axis 1,
      5. with ``--xtc``, additionally exports aligned ``.xtc`` trajectories
         for both prediction and ground truth.

    Finally the mean of the per-batch losses is printed; if the loader
    yielded no batch, a notice is printed instead of dividing by zero.

    Args:
        model: Object handed to ``rollout`` (must provide ``inference``).
        data_loader: Iterable of batches; each batch must be a dict or list.

    Raises:
        ValueError: If a batch is neither a list nor a dictionary.
    """
    total_loss = []

    for batch in data_loader:
        if isinstance(batch, list):
            # NOTE(review): a list batch ends up keyed by integer position,
            # so the string-key lookups below (batch['name'], ...) raise
            # KeyError for it -- confirm whether list batches can occur.
            batch = {i: v for i, v in enumerate(batch)}
        elif not isinstance(batch, dict):
            raise ValueError("Batch must be a list or a dictionary.")
        batch = {k: v.to('cuda') if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        # First (and, with batch_size=1, only) sample name of the batch.
        name = batch['name'][0]

        gt_arr = np.lib.format.open_memmap(f'{args.data_dir}/{name}_{args.suffix}.npy', 'r')
        gt_path = os.path.join(args.out_dir, f'{name}_gt.pdb')
        pred_path = os.path.join(args.out_dir, f'{name}.pdb')

        rollout_start = time.time()
        all_atom4, loss = rollout(model, args.num_frames, batch, args.threshold)
        print('time is :', time.time() - rollout_start)
        total_loss.append(loss)
        all_atom4 = all_atom4[0]

        # The number of non-zero mask entries is used as the protein length;
        # prediction and sequence are cut to it along axis 1.
        protein_len = int(torch.sum(batch['mask'] != 0))
        all_atom4 = all_atom4[:, :protein_len, ...]
        batch['seqres'] = batch['seqres'][:, :protein_len, ...]
        atom4_to_pdb(all_atom4[..., :4, :].cpu().numpy(), batch['seqres'][0].cpu().numpy(), pred_path)

        gt_arr = _select_gt_frames(gt_arr)

        # NOTE(review): 240 is hard-coded and equals the --crop default;
        # confirm whether this limit should follow args.crop.
        if gt_arr.shape[1] > 240:
            continue

        atom14_to_pdb(gt_arr, batch['seqres'][0].cpu().numpy(), gt_path)
        if args.xtc:
            _export_aligned_xtc(pred_path, os.path.join(args.out_dir, f'{name}.xtc'))
            _export_aligned_xtc(gt_path, os.path.join(args.out_dir, f'{name}_gt.xtc'))

    if not total_loss:
        # An empty loader would otherwise end in a ZeroDivisionError.
        print('no batches were processed; the average loss is undefined')
        return
    print('the average loss is : ', sum(total_loss) / len(total_loss))
    return
    
@torch.no_grad()
def main():
    """Load the model and run inference over the test split.

    Builds ``ProteinSDE`` from the configuration file ``args.config``,
    restores the weights stored in ``args.sim_ckpt``, moves the model to CUDA
    in eval mode and passes it, together with a batch-size-1 loader over
    ``MDGenDataset(args, split=args.split)``, to ``do``.  The whole run is
    executed with gradient tracking disabled.
    """
    loaded_config = ModelConfig.load_config(args.config)
    model = ProteinSDE(loaded_config)

    # Deserialize onto the CPU first: without map_location the tensors are
    # restored to the device they were saved from, which fails when that
    # device does not exist here and briefly duplicates the weights on GPU.
    state_dict = torch.load(args.sim_ckpt, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval().to('cuda')

    testset = MDGenDataset(args, split=args.split)
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=1,
        num_workers=0,
    )

    do(model, test_loader)


if __name__ == '__main__':
    main()