
import os
import time
import pyrosetta
from pyrosetta.rosetta.protocols.relax import FastRelax
from pyrosetta.rosetta.core.pack.task import TaskFactory
from pyrosetta.rosetta.core.pack.task import operation
from pyrosetta.rosetta.core.select import residue_selector as selections
from pyrosetta.rosetta.core.select.movemap import MoveMapFactory, move_map_action
# Initialise Rosetta once, at import time. The flags below are concatenated into a
# single space-separated option string and handed verbatim to Rosetta's option parser
# (``-mute all`` silences Rosetta's console logging).
pyrosetta.init(
    '-mute all '
    '-use_input_sc '
    '-ignore_unrecognized_res '
    '-ignore_zero_occupancy false '
    '-load_PDB_components false '
    '-relax:default_repeats 2 '
    '-no_fconfig'
)

from diffab.tools.relax.base import RelaxTask


def current_milli_time():
    """Return the current wall-clock time as whole milliseconds since the epoch."""
    milliseconds = time.time() * 1000
    return round(milliseconds)


def parse_residue_position(p):
    """Split a compact residue label such as ``'H100'`` or ``'H100A'``.

    The label is ``<chain><resseq>[<icode>]``: every character before the first
    digit is the chain ID, the digits that follow are the residue number, and a
    single trailing non-digit character (if any) is the insertion code.

    Args:
        p: Residue label string.

    Returns:
        ``(chain, resseq)`` when there is no insertion code, otherwise
        ``(chain, resseq, icode)``. ``resseq`` is an ``int``.

    Raises:
        ValueError: if ``p`` is empty, contains no residue number, or the
            residue number is not a valid integer.
    """
    if not p:
        raise ValueError('Empty residue position')

    icode = None
    body = p
    if not p[-1].isnumeric():
        icode = p[-1]
        # Drop the insertion code before parsing the number; otherwise
        # int() would see e.g. '100A' and fail.
        body = p[:-1]

    for i, c in enumerate(body):
        if c.isnumeric():
            break
    else:
        # No digit found (also covers body == '').
        raise ValueError(f'No residue number found in position {p!r}')

    chain = body[:i]
    resseq = int(body[i:])

    if icode is not None:
        return chain, resseq, icode
    return chain, resseq


def get_scorefxn(scorefxn_name:str):
    """Create a PyRosetta score function after setting the matching ``corrections:*`` options.

    Each call writes all four correction options listed in ``corrections``.
    The ones matching ``scorefxn_name`` are set to True and the rest to False,
    so settings from a previous call do not leak into this one.

    Args:
        scorefxn_name: Name passed to ``pyrosetta.create_score_function``
            (e.g. ``'ref2015'``, ``'beta_nov16'``, ``'beta_genpot'``).

    Returns:
        The score function returned by ``pyrosetta.create_score_function``.
    """
    import pyrosetta

    corrections = {
        'beta_july15': False,
        'beta_nov16': False,
        'gen_potential': False,
        'restore_talaris_behavior': False,
    }
    if 'beta_july15' in scorefxn_name or 'beta_nov15' in scorefxn_name:
        # beta_nov15 shares the beta_july15 correction.
        corrections['beta_july15'] = True
    elif 'beta_nov16' in scorefxn_name:
        corrections['beta_nov16'] = True
    elif 'genpot' in scorefxn_name:
        corrections['gen_potential'] = True
        # genpot is run with the beta_july15 correction as well. This must go in
        # the table: the loop below writes every entry, so setting the option
        # directly here was immediately reset to False.
        # NOTE(review): confirm beta_july15 (not beta_nov16) is the intended pairing.
        corrections['beta_july15'] = True
    elif 'talaris' in scorefxn_name:
        # talaris2013 / talaris2014
        corrections['restore_talaris_behavior'] = True

    for corr, value in corrections.items():
        pyrosetta.rosetta.basic.options.set_boolean_option(f'corrections:{corr}', value)
    return pyrosetta.create_score_function(scorefxn_name)


class RelaxRegion(object):
    """Run FastRelax on a contiguous stretch of residues and, optionally, its surroundings.

    Args:
        scorefxn: Score function name, resolved with ``get_scorefxn``.
        max_iter: Passed to ``FastRelax.max_iter``.
        subset: Which residues may repack and have their side-chain torsions moved.
            ``'target'`` means only the flexible region. ``'nbrs'`` means the
            flexible region plus the residues picked by a
            ``NeighborhoodResidueSelector`` focused on it. ``'all'`` means every
            residue in the pose.
        move_bb: If True, backbone torsions of the flexible region are enabled in
            the MoveMap.

    Raises:
        ValueError: if ``subset`` is not one of the accepted values.
    """

    def __init__(self, scorefxn='ref2015', max_iter=1000, subset='nbrs', move_bb=True):
        super().__init__()
        # Validate before get_scorefxn mutates global Rosetta options.
        if subset not in ('all', 'target', 'nbrs'):
            raise ValueError(f"subset must be one of 'all', 'target', 'nbrs'; got {subset!r}")
        self.scorefxn = get_scorefxn(scorefxn)
        self.fast_relax = FastRelax()
        self.fast_relax.set_scorefxn(self.scorefxn)
        self.fast_relax.max_iter(max_iter)
        self.subset = subset
        self.move_bb = move_bb

    @staticmethod
    def _strip_blank_icode(residue):
        # A trailing ' ' insertion code means "no insertion code"; drop it so
        # pdb2pose falls back to its default icode.
        if residue[-1] == ' ':
            return residue[:-1]
        return residue

    @staticmethod
    def _pdb_to_pose_index(pose, residue):
        # Map a PDB residue identifier to a pose index; pdb2pose reports a
        # missing residue as 0, which would silently build a bogus selection.
        index = pose.pdb_info().pdb2pose(*residue)
        if index == 0:
            raise ValueError(f'Residue {residue!r} not found in pose')
        return index

    def __call__(self, pdb_path, flexible_residue_first, flexible_residue_last):
        """Relax the structure in ``pdb_path`` around the given flexible region.

        Args:
            pdb_path: Path to the input PDB file, loaded with ``pyrosetta.pose_from_pdb``.
            flexible_residue_first: First residue of the flexible region. It is
                unpacked into ``PDBInfo.pdb2pose``, so it is presumably a
                ``(chain, resseq[, icode])`` tuple (TODO confirm with callers).
                A trailing ``' '`` icode is dropped.
            flexible_residue_last: Last residue (inclusive) of the flexible
                region, in the same format.

        Returns:
            ``(pose, e_before, e_relax)``: the relaxed pose, and the scores of the
            input pose and the relaxed pose under ``self.scorefxn``.

        Raises:
            ValueError: if either flexible residue is not present in the pose.
        """
        pose = pyrosetta.pose_from_pdb(pdb_path)
        original_pose = pose.clone()

        first = self._pdb_to_pose_index(pose, self._strip_blank_icode(flexible_residue_first))
        last = self._pdb_to_pose_index(pose, self._strip_blank_icode(flexible_residue_last))

        # The flexible region itself; always needed for the backbone MoveMap.
        gen_selector = selections.ResidueIndexSelector()
        gen_selector.set_index_range(first, last)

        tf = TaskFactory()
        tf.push_back(operation.InitializeFromCommandline())
        tf.push_back(operation.RestrictToRepacking())   # repack only, never redesign

        if self.subset == 'all':
            # Every residue may repack / move side chains; no restriction operation.
            subset_selector = selections.TrueResidueSelector()
        else:
            if self.subset == 'nbrs':
                subset_selector = selections.NeighborhoodResidueSelector()
                subset_selector.set_focus_selector(gen_selector)
                subset_selector.set_include_focus_in_subset(True)
            else:  # 'target'
                subset_selector = gen_selector
            # flip_subset=True: prevent repacking of everything OUTSIDE the subset.
            tf.push_back(operation.OperateOnResidueSubset(
                operation.PreventRepackingRLT(),
                subset_selector,
                flip_subset=True,
            ))

        mmf = MoveMapFactory()
        if self.move_bb:
            mmf.add_bb_action(move_map_action.mm_enable, gen_selector)
        mmf.add_chi_action(move_map_action.mm_enable, subset_selector)
        mm = mmf.create_movemap_from_pose(pose)

        fr = self.fast_relax
        fr.set_movemap(mm)
        fr.set_task_factory(tf)
        fr.apply(pose)   # relaxes `pose` in place; original_pose is untouched

        e_before = self.scorefxn(original_pose)
        e_relax = self.scorefxn(pose)
        return pose, e_before, e_relax


def run_pyrosetta(task: RelaxTask):
    """Relax the task's flexible region (backbone enabled) and write the result to PDB.

    The task is returned unchanged in two cases. One is when
    ``task.can_proceed()`` is false. The other is when
    ``task.update_if_finished('rosetta')`` returns a truthy value.
    Otherwise the relaxed pose is written to the path returned by
    ``task.set_current_path_tag('rosetta')``, and the task is marked successful.
    """
    # Short-circuit keeps the original order: update_if_finished is only
    # consulted when the task can proceed.
    if not task.can_proceed() or task.update_if_finished('rosetta'):
        return task

    relaxer = RelaxRegion()
    relaxed_pose = relaxer(
        pdb_path=task.current_path,
        flexible_residue_first=task.flexible_residue_first,
        flexible_residue_last=task.flexible_residue_last,
    )[0]

    destination = task.set_current_path_tag('rosetta')
    relaxed_pose.dump_pdb(destination)
    task.mark_success()
    return task


def run_pyrosetta_fixbb(task: RelaxTask):
    """Relax the task's flexible region with a fixed backbone and write the result to PDB.

    This is the same flow as ``run_pyrosetta``, with two differences: it uses
    ``RelaxRegion(move_bb=False)``, and it uses the ``'fixbb'`` tag. The task is
    returned unchanged when it cannot proceed, or when
    ``task.update_if_finished('fixbb')`` returns a truthy value.
    """
    # Short-circuit keeps the original order: update_if_finished is only
    # consulted when the task can proceed.
    if not task.can_proceed() or task.update_if_finished('fixbb'):
        return task

    relaxer = RelaxRegion(move_bb=False)
    relaxed_pose = relaxer(
        pdb_path=task.current_path,
        flexible_residue_first=task.flexible_residue_first,
        flexible_residue_last=task.flexible_residue_last,
    )[0]

    destination = task.set_current_path_tag('fixbb')
    relaxed_pose.dump_pdb(destination)
    task.mark_success()
    return task

    

