import logging, numpy as np, xarray as xr, rpxdock as rp, rpxdock.homog as hm
from rpxdock.search import hier_search
log = logging.getLogger(__name__)
def helix_get_sample_hierarchy(body, hscore, extent=100):
    """Set up an XformHier with bounds and resolution suited to ``body``.

    The cartesian grid is a cube spanning ``[-extent, extent]`` on each axis.
    Its cell count per axis is derived from the cartesian resolution stored in
    ``hscore.base.attr.xhresl``; the orientation resolution is derived from
    that cartesian step and the body's ``rg()`` and capped at the score
    table's own orientation resolution.

    Args:
        body: object providing ``rg()``.
        hscore: object providing ``base.attr.xhresl`` as a
            ``(cartesian, orientation)`` resolution pair.
        extent: half-width of the cartesian sampling cube.

    Returns:
        The constructed ``rp.sampling.XformHier_f4``.

    Raises:
        AssertionError: if the hierarchy fails its own ``sanity_check()``.
    """
    cart_xhresl, ori_xhresl = hscore.base.attr.xhresl
    rg = body.rg()

    # Sample a bit finer than the score grid; 0.707 is approximately 1/sqrt(2).
    cart_samp_resl = 0.707 * cart_xhresl
    # Angle (in degrees) subtending an arc of cart_samp_resl at radius rg,
    # never coarser than the score table's orientation resolution.
    ori_samp_resl = min(cart_samp_resl / rg * 180 / np.pi, ori_xhresl)

    ncart = np.ceil(extent * 2 / cart_samp_resl)
    cartlb = np.array([-extent] * 3)
    cartub = np.array([extent] * 3)
    cartbs = np.array([ncart] * 3, dtype="i")

    xh = rp.sampling.XformHier_f4(cartlb, cartub, cartbs, ori_samp_resl)
    assert xh.sanity_check(), "bad xform hierarchy"
    # The original message ended with the stray literal text ", body.copy()"
    # (outside any f-string braces); it carried no information and is removed.
    log.info(
        f"XformHier {xh.size(0):,} {xh.cart_bs} {xh.ori_resl} {xh.cart_lb} {xh.cart_ub}"
    )
    return xh
def make_helix(body, hscore, sampler, search=hier_search, **kw):
    """Search for helical arrangements of ``body`` and package the results.

    Runs ``search`` over ``sampler`` with a ``HelixEvaluator``, removes
    redundant hits with ``rp.filter_redundancy``, re-scores the survivors
    separately for the ``rpx`` and ``ncontact`` terms at the last resolution
    stage, and returns everything as an ``rp.Result``.

    Args:
        body: the body to dock against copies of itself.
        hscore: score object; ``hscore.actual_nresl`` is used when ``nresl``
            is not given in ``kw``.
        sampler: transform sampler handed to ``search``; must not be None.
        search: search driver called as ``search(sampler, evaluator, **arg)``
            and returning ``(xforms, scores, extra, stats)``.
        **kw: options collected into an ``rp.Bunch``. Keys read directly here:
            ``nresl``, ``output_prefix``, ``verbose``, ``wts`` (must provide
            ``.sub``) and ``dont_store_body_in_results``. Everything is also
            forwarded to the evaluator, the search and the redundancy filter.

    Returns:
        ``rp.Result`` with per-model scores, transforms, rpx / ncontact
        components and the helix bookkeeping arrays from the evaluator.

    Raises:
        AssertionError: if ``sampler`` is None.
    """
    arg = rp.Bunch(kw)
    # NOTE(review): relies on rp.Bunch yielding None for missing keys -- confirm.
    arg.nresl = hscore.actual_nresl if arg.nresl is None else arg.nresl
    # The original fell back to an undefined name ``sym`` here, which raised
    # NameError whenever no output_prefix was supplied.
    arg.output_prefix = arg.output_prefix if arg.output_prefix else "helix"
    t = rp.Timer().start()
    assert sampler is not None, 'sampler is required'

    evaluator = HelixEvaluator(body, hscore, **arg)
    xforms, scores, extra, stats = search(sampler, evaluator, **arg)

    # Fold the three per-model labels into a single category id so that the
    # redundancy filter only compares models with identical labels.  A fresh
    # array is built so the arrays returned by the search are not modified
    # in place.
    ncat_pri = np.max(extra.helix_n_to_primary) + 1
    ncat_sec = np.max(extra.helix_n_to_secondry) + 1
    cat = (extra.helix_n_to_primary
           + extra.helix_n_to_secondry * ncat_pri
           + extra.helix_cyclic * ncat_pri * ncat_sec)
    ibest = rp.filter_redundancy(xforms, body, scores, categories=cat, **arg)

    if arg.verbose:
        # The original also printed an undefined ``tdump`` value (NameError);
        # nothing is dumped in this function, so that field is dropped.
        print(f"rate: {int(stats.ntot / t.total):,}/s ttot {t.total:7.3f}")
        print("stage time:", " ".join([f"{dt:8.2f}s" for dt, n in stats.neval]))
        print("stage rate: ", " ".join([f"{int(n/dt):7,}/s" for dt, n in stats.neval]))

    xforms = xforms[ibest]
    # Re-evaluate the kept models with each score term isolated.
    wrpx = arg.wts.sub(rpx=1, ncontact=0)
    wnct = arg.wts.sub(rpx=0, ncontact=1)
    rpx, extra = evaluator(xforms, arg.nresl - 1, wrpx)
    ncontact, _ = evaluator(xforms, arg.nresl - 1, wnct)

    return rp.Result(
        body_=None if arg.dont_store_body_in_results else [body, body.copy()],
        attrs=dict(arg=arg, stats=stats, ttotal=t.total),
        scores=(["model"], scores[ibest].astype("f4")),
        xforms=(["model", "hrow", "hcol"], xforms),
        rpx=(["model"], rpx.astype("f4")),
        ncontact=(["model"], ncontact.astype("f4")),
        reslb=(["model"], extra.reslb),
        resub=(["model"], extra.resub),
        helix_n_to_primary=(["model"], extra.helix_n_to_primary),
        helix_n_to_secondry=(["model"], extra.helix_n_to_secondry),
        helix_cyclic=(["model"], extra.helix_cyclic),
        sym=(["model"], np.array(['H' + str(i) for i in extra.helix_cyclic])),
    )
class HelixEvaluator:
    """Callable that scores candidate helical transforms for one body.

    An instance keeps its own copy of the body, the score object and the
    keyword options (as an ``rp.Bunch``).  Calling it returns one score per
    transform plus a Bunch of per-transform bookkeeping arrays.
    """

    def __init__(self, body, hscore, **kw):
        self.hscore = hscore
        self.body = body.copy()
        self.arg = rp.Bunch(kw)

    def __call__(self, xforms, iresl=-1, wts={}, **kw):
        """Score ``xforms`` at resolution stage ``iresl``.

        Args:
            xforms: transforms, reshaped internally to ``(-1, 4, 4)``.
            iresl: index into ``hscore.cart_extent`` / ``hscore.ori_extent``
                and the stage passed to ``hscore.scorepos``.
            wts: score weights substituted into the stored options.  The
                default dict is never mutated in this method.
            **kw: accepted but not used by this method.

        Returns:
            ``(scores, info)``: ``scores`` holds one value per transform and
            ``info`` is an ``rp.Bunch`` with ``helix_n_to_primary``,
            ``helix_n_to_secondry`` (None when the secondary stage is
            skipped), ``reslb``, ``resub`` and ``helix_cyclic``.
        """
        opt = self.arg.sub(wts=wts)
        identity = np.eye(4, dtype="f4")
        mobile = self.body.copy()
        partner = self.body.copy()
        xforms = xforms.reshape(-1, 4, 4)
        nxform = len(xforms)
        cart_slop = self.hscore.cart_extent[iresl]
        ori_slop = self.hscore.ori_extent[iresl]

        # A falsy helix_max_delta_z means "derive the limit from the body".
        max_dz = opt.helix_max_delta_z
        if not max_dz:
            max_dz = mobile.radius_max() * 2 / opt.helix_min_isecond
        n_clash_repeats = int(opt.helix_max_isecond * 1.2 + 3)
        assert opt.symframe_num_helix_repeats >= opt.helix_max_isecond

        # Geometric pre-filter: rotation angle (degrees) and the magnitude of
        # the translation component along the rotation axis must fall inside
        # the requested windows, widened by this stage's grid extents.
        axis, ang = hm.axis_angle_of(xforms)
        ang = ang * 180 / np.pi
        ang_lo = opt.helix_min_primary_angle - ori_slop
        ang_hi = opt.helix_max_primary_angle + ori_slop
        angle_ok = np.logical_and(ang >= ang_lo, ang <= ang_hi)
        dz = np.abs(hm.hdot(axis, xforms[:, :, 3]))
        dz_ok = np.logical_and(dz >= opt.helix_min_delta_z - cart_slop,
                               dz <= max_dz + cart_slop)
        ok = np.logical_and(dz_ok, angle_ok)

        # Column 0: primary (adjacent-copy) score, column 1: best secondary.
        scores = np.zeros((nxform, 2))
        ok[ok] = mobile.clash_ok(partner, xforms[ok], identity, **opt)
        scores[ok, 0] = self.hscore.scorepos(mobile, mobile, xforms[ok], identity, iresl,
                                             **opt)
        # Keep only candidates whose primary score lies inside the window.
        primary = scores[ok, 0]
        ok[ok] = np.logical_and(primary >= opt.helix_min_primary_score,
                                primary <= opt.helix_max_primary_score)

        best_second = None
        # Secondary contacts are examined only once the stage index reaches
        # helix_iresl_second_shift; before that the primary score is returned.
        if iresl >= opt.helix_iresl_second_shift:
            # Clash-check successive powers of each transform.
            repeat_xf = xforms
            for _ in range(2, n_clash_repeats):
                repeat_xf = xforms @ repeat_xf
                ok[ok] = mobile.clash_ok(partner, repeat_xf[ok], identity, **opt)

            # Score against the i2-th copy along the helix for each allowed i2.
            second_resl = iresl - opt.helix_iresl_second_shift
            second = np.zeros((opt.helix_max_isecond, nxform))
            repeat_xf = xforms
            for i2 in range(2, opt.helix_max_isecond):
                repeat_xf = xforms @ repeat_xf
                if i2 >= opt.helix_min_isecond:
                    second[i2, ok] = self.hscore.scorepos(
                        mobile,
                        partner,
                        repeat_xf[ok],
                        identity,
                        second_resl,
                        **opt,
                    )
            second[:, ~ok] = 0
            best_second = np.argmax(second, axis=0).astype('i8')
            scores[:, 1] = second[best_second, range(nxform)]

            # Blend min and max of the two columns; the weight on the max
            # halves with each further stage, approaching the plain minimum.
            mix = 0.5**(second_resl + 1)
            assert 0 <= mix <= 1
            worst = np.min(scores, axis=1)
            best = np.max(scores, axis=1)
            scores = worst * (1 - mix) + best * mix
        else:
            scores = scores[:, 0]

        return scores, rp.Bunch(
            helix_n_to_primary=np.repeat(1, nxform),
            helix_n_to_secondry=best_second,
            reslb=np.tile(0, nxform),
            resub=np.tile(len(mobile), nxform),
            helix_cyclic=np.repeat(1, nxform),
        )