import logging
import numpy as np, xarray as xr, rpxdock as rp
from rpxdock.search import hier_search
log = logging.getLogger(__name__)
[docs]def make_plugs(plug, hole, hscore, search=hier_search, sampler=None, **kw):
arg = rp.Bunch(kw)
arg.nresl = hscore.actual_nresl if arg.nresl is None else arg.nresl
arg.output_prefix = "plug" if arg.output_prefix is None else arg.output_prefix
t = rp.Timer().start()
evaluator = PlugEvaluator(plug, hole, hscore, **arg)
if sampler is None: sampler = _default_samplers[search](plug, hole, hscore)
xforms, scores, extra, stats = search(sampler, evaluator, **arg)
ibest = rp.filter_redundancy(xforms, plug, scores, **arg)
tdump = _debug_dump_plugs(xforms, plug, hole, scores, ibest, evaluator, **arg)
log.debug(f"rate: {int(stats.ntot / t.total):,}/s ttot {t.total:7.3f} tdump {tdump:7.3f}")
log.debug("stage time:", " ".join([f"{t:8.2f}s" for t, n in stats.neval]))
log.debug("stage rate: ", " ".join([f"{int(n/t):7,}/s" for t, n in stats.neval]))
xforms = xforms[ibest]
scores = scores[ibest]
wrpx = arg.wts.sub(ncontact=0)
wnct = arg.wts.sub(rpx=0)
rpx, extra = evaluator.iface_scores(xforms, arg.nresl - 1, wrpx)
ncontact, *_ = evaluator.iface_scores(xforms, arg.nresl - 1, wnct)
ifacescores = rpx + ncontact
assert np.allclose(np.min(rpx + ncontact, axis=1), scores)
data = dict(
attrs=dict(arg=arg, stats=stats, sym=hole.sym, ttotal=t.total, tdump=tdump,
output_body='all'),
scores=(["model"], scores.astype("f4")),
xforms=(["model", "hrow", "hcol"], xforms),
tot_plug=(["model"], ifacescores[:, 0].astype("f4")),
tot_hole=(["model"], ifacescores[:, 1].astype("f4")),
rpx_plug=(["model"], rpx[:, 0].astype("f4")),
rpx_hole=(["model"], rpx[:, 1].astype("f4")),
ncontact_plug=(["model"], ncontact[:, 0].astype("f4")),
ncontact_hole=(["model"], ncontact[:, 1].astype("f4")),
)
for k, v in extra.items():
if not isinstance(v, (list, tuple)) or len(v) > 3: v = ['model'], v
data[k] = v
return rp.Result(
body_=[] if arg.dont_store_body_in_results else [plug, hole],
body_label_=[] if arg.dont_store_body_in_results else ['plug', 'hole'],
**data,
)
[docs]class PlugEvaluator:
def __init__(self, plug, hole, hscore, **kw):
self.arg = rp.Bunch(kw)
self.plug = plug
self.hole = hole
self.hscore = hscore
self.symrot = rp.homog.hrot([0, 0, 1], 360 / int(hole.sym[1:]), degrees=True)
def __call__(self, xforms, iresl=-1, wts={}, **_):
wts = self.arg.wts.sub(wts)
wts_ph = wts.plug, wts.hole
iface_scores, extra = self.iface_scores(xforms, iresl, wts)
scores = self.arg.iface_summary(iface_scores * wts_ph, axis=1)
return scores, extra
[docs] def iface_scores(self, xforms, iresl=-1, wts={}, **_):
wts = self.arg.wts.sub(wts)
xeye = np.eye(4, dtype="f4")
xforms = xforms.reshape(-1, 4, 4)
plug, hole, sfxn = self.plug, self.hole, self.hscore.scorepos
dclsh, max_trim = self.arg.clashdis, self.arg.max_trim
xsym = self.symrot @ xforms
# check for "flatness"
ok = np.abs((xforms @ plug.pcavecs[0])[:, 2]) <= self.arg.max_longaxis_dot_z
if not self.arg.plug_fixed_olig: # check chash in formed oligomer
ok[ok] &= plug.clash_ok(plug, xforms[ok], xsym[ok], mindis=dclsh)
if max_trim > 0: # get non-clash range
trim = plug.intersect_range(hole, xforms[ok], max_trim=max_trim, mindis=dclsh)
trim, trimok = rp.search.trim_ok(trim, plug.nres, max_trim)
ok[ok] &= trimok
else: # check clash olig vs hole
ok[ok] &= plug.clash_ok(hole, xforms[ok], xeye, mindis=dclsh)
trim = [0], [plug.nres - 1]
# score everything that didn't clash
xok = xforms[ok]
scores = np.zeros((len(xforms), 2))
scores[ok, 0] = 9999
if not self.arg.plug_fixed_olig:
bounds = (*trim, -1, *trim, -1)
scores[ok, 0] = sfxn(plug, plug, xok, xsym[ok], iresl, bounds=bounds, wts=wts)
scores[ok, 1] = sfxn(plug, hole, xok, xeye[:, ], iresl, bounds=trim, wts=wts)
# record ranges used
plb = np.zeros(len(scores), dtype="i4")
pub = np.ones(len(scores), dtype="i4") * (plug.nres - 1)
if trim:
plb[ok], pub[ok] = trim[0], trim[1]
return scores, rp.Bunch(reslb=plb, resub=pub)
def _debug_dump_plugs(xforms, plug, hole, scores, ibest, evaluator, **kw):
arg = rp.Bunch(kw)
t = rp.Timer().start()
fname_prefix = "plug" if arg.output_prefix is None else arg.output_prefix
nout_debug = min(10 if arg.nout_debug is None else arg.nout_debug, len(ibest))
for i in range(nout_debug):
plug.move_to(xforms[ibest[i]])
wrpx, wnct = (arg.wts.sub(rpx=1, ncontact=0), arg.wts.sub(rpx=0, ncontact=1))
scoreme = evaluator.iface_scores
((pscr, hscr), ), extra = scoreme(xforms[ibest[i]], arg.nresl - 1, wrpx)
((pcnt, hcnt), ), extra = scoreme(xforms[ibest[i]], arg.nresl - 1, wnct)
fn = fname_prefix + "_%02i.pdb" % i
log.info(f"{fn} score {scores[ibest[i]]:7.3f} olig: {pscr:7.3f} hole: {hscr:7.3f}" +
f"resi {extra.reslb[0]}-{extra.resub[0]} {pcnt:7.0f} {hcnt:7.0f}")
# print('_debug_dump_plugs', i, scores[ibest[i]], extra.reslb.data[0], extra.resub.data[0])
rp.io.dump_pdb_from_bodies(fn, [plug], rp.geom.symframes(hole.sym), resbounds=extra)
return t.total
[docs]def plug_get_sample_hierarchy(plug, hole, hscore):
"set up XformHier with appropriate bounds and resolution"
cart_samp_resl, ori_samp_resl = hscore.base.attr.xhresl
r0 = max(hole.rg_xy(), 2 * plug.radius_max())
nr1 = np.ceil(r0 / cart_samp_resl)
r1 = nr1 * cart_samp_resl
nr2 = np.ceil(r0 / cart_samp_resl * 2)
r2 = nr2 * cart_samp_resl / 2
nh = np.ceil(3 * hole.rg_z() / cart_samp_resl)
h = nh * cart_samp_resl / 2
cartub = np.array([+r2, +r2, +h])
cartlb = np.array([-r2, -r2, -h])
cartbs = np.array([nr2, nr2, nh], dtype="i")
xh = rp.sampling.XformHier_f4(cartlb, cartub, cartbs, ori_samp_resl)
assert xh.sanity_check(), "bad xform hierarchy"
log.info(f"XformHier {xh.size(0):,} {xh.cart_bs} {xh.ori_resl} {xh.cart_lb} {xh.cart_ub}")
return xh
_default_samplers = {hier_search: plug_get_sample_hierarchy}
[docs]def plug_test_hier_sampler(plug, hole, hscore, n=6):
r, rori = hscore.base.attr.xhresl
cartub = np.array([n * r, r, r])
cartlb = np.array([-n * r, 0, 0])
cartbs = np.array([2 * n, 1, 1], dtype="i")
xh = rp.sampling.XformHier_f4(cartlb, cartub, cartbs, rori)
assert xh.sanity_check(), "bad xform hierarchy"
# print(f"XformHier {xh.size(0):,}", xh.cart_bs, xh.ori_resl, xh.cart_lb, xh.cart_ub)
return xh
### below is junk?
def __make_plugs_hier_sample_test__(plug, hole, hscore, **kw):
arg = rp.Bunch(kw)
sampler = plug_get_sample_hierarchy(plug, hole, hscore)
sampler = plug_test_hier_sampler(plug, hole, hscore)
nresl = kw["nresl"]
for rpx in [0, 1]:
arg.wts = rp.Bunch(plug=1.0, hole=1.0, ncontact=1.0, rpx=rpx)
evaluator = PlugEvaluator(plug, hole, hscore, **arg)
iresl = 0
indices, xforms = expand_samples(**arg.sub(vars()))
scores, *resbound, t = hier_evaluate(**arg.sub(vars()))
iroot = np.argsort(-scores)[:10]
xroot = xforms[iroot]
sroot = scores[iroot]
for ibeam in range(6, 27):
beam_size = 2**ibeam
indices, xforms, scores = iroot, xroot, sroot
for iresl in range(1, nresl):
indices, xforms = expand_samples(**arg.sub(vars()))
scores, *resbound, t = hier_evaluate(**arg.sub(vars()))
print(
f"rpx {rpx} beam {beam_size:9,}",
f"iresl {iresl} ntot {len(scores):11,} nonzero {np.sum(scores > 0):5,}",
f"best {np.max(scores)}",
)
import _pickle
fn = "make_plugs_hier_sample_test_rpx_%i_ibeam_%i.pickle" % (rpx, ibeam)
with open(fn, "wb") as out:
_pickle.dump((ibeam, iresl, indices, scores), out)
print()
assert 0