# Source code for rpxdock.search.cyclic

import numpy as np, xarray as xr, rpxdock as rp, rpxdock.homog as hm
from rpxdock.search import hier_search

def make_cyclic_hier_sampler(monomer, hscore):
    """Build the default hierarchical sampler used for cyclic docking.

    The Cartesian range is ``[0, ncart * cart_resl]``. Here ``ncart`` is the
    number of ``cart_resl``-sized steps needed to cover twice the monomer's
    maximum radius, rounded up.

    Args:
        monomer: body being docked; must provide ``radius_max()``.
        hscore: score object; ``hscore.base.attr.xhresl`` must unpack into
            ``(cart_resl, ori_resl)``.

    Returns:
        An ``rp.sampling.OriCart1Hier_f4`` built from one-element Cartesian
        lower/upper/count lists plus ``ori_resl`` (presumably a single
        Cartesian axis plus orientations, per the class name -- confirm).
    """
    base_resl = hscore.base.attr.xhresl
    cart_resl, ori_resl = base_resl
    span = 2 * monomer.radius_max()
    ncells = int(np.ceil(span / cart_resl))
    upper = ncells * cart_resl
    return rp.sampling.OriCart1Hier_f4([0.0], [upper], [ncells], ori_resl)
# Search function -> factory building that search's default sampler.
# make_cyclic looks up ``_default_samplers[search]`` when no sampler is given.
_default_samplers = {}
_default_samplers[hier_search] = make_cyclic_hier_sampler
def make_cyclic(monomer, sym, hscore, search=hier_search, sampler=None, **kw):
    """Dock a monomer into a cyclic (Cn) assembly and package the results.

    Args:
        monomer: body to dock.
        sym: symmetry name such as ``"C3"``, or an integer order such as ``3``
            (converted to ``"C3"``).
        hscore: hierarchical score object; its ``actual_nresl`` is used when
            the ``nresl`` option is not given.
        search: search function, default ``hier_search``.
        sampler: sampler handed to ``search``. When None, it is built with
            ``_default_samplers[search](monomer, hscore)``.
        **kw: options collected into an ``rp.Bunch``. This function reads
            ``nresl``, ``output_prefix``, ``verbose``, ``wts`` and
            ``dont_store_body_in_results``. All options are forwarded to the
            evaluator, the search, the redundancy filter and the debug dump.

    Returns:
        ``rp.Result`` holding per-model search scores, transforms, rpx-only and
        ncontact-only rescores, and the residue bounds (``reslb``/``resub``)
        reported by the evaluator.
    """
    arg = rp.Bunch(kw)
    t = rp.Timer().start()
    # Accept an integer order (e.g. 3 -> "C3"); CyclicEvaluator parses the
    # order back out of ``sym[1:]``, so it needs the string form.
    if isinstance(sym, (int, np.integer)):
        sym = "C%i" % sym
    arg.nresl = hscore.actual_nresl if arg.nresl is None else arg.nresl
    arg.output_prefix = arg.output_prefix if arg.output_prefix else sym
    if sampler is None:
        sampler = _default_samplers[search](monomer, hscore)

    evaluator = CyclicEvaluator(monomer, sym, hscore, **arg)
    xforms, scores, _, stats = search(sampler, evaluator, **arg)
    ibest = rp.filter_redundancy(xforms, monomer, scores, **arg)
    tdump = _debug_dump_cyclic(xforms, monomer, sym, scores, ibest, evaluator, **arg)

    if arg.verbose:
        print(f"rate: {int(stats.ntot / t.total):,}/s ttot {t.total:7.3f} tdump {tdump:7.3f}")
        print("stage time:", " ".join([f"{tstage:8.2f}s" for tstage, n in stats.neval]))
        print("stage rate: ", " ".join([f"{int(n/tstage):7,}/s" for tstage, n in stats.neval]))

    # Rescore the kept (non-redundant) models with rpx-only and
    # ncontact-only weights at the finest resolution.
    xforms = xforms[ibest]
    wrpx = arg.wts.sub(rpx=1, ncontact=0)
    wnct = arg.wts.sub(rpx=0, ncontact=1)
    rpx, extra = evaluator(xforms, arg.nresl - 1, wrpx)
    ncontact, _ = evaluator(xforms, arg.nresl - 1, wnct)
    return rp.Result(
        body_=None if arg.dont_store_body_in_results else [monomer],
        attrs=dict(arg=arg, stats=stats, ttotal=t.total, tdump=tdump, sym=sym),
        scores=(["model"], scores[ibest].astype("f4")),
        xforms=(["model", "hrow", "hcol"], xforms),
        rpx=(["model"], rpx.astype("f4")),
        ncontact=(["model"], ncontact.astype("f4")),
        reslb=(["model"], extra.reslb),
        resub=(["model"], extra.resub),
    )
class CyclicEvaluator:
    """Score placements of a body against its cyclic (Cn) symmetry mate.

    For each candidate transform, the symmetry mate is placed at
    ``symrot @ xform``. ``symrot`` is built with ``hm.hrot`` about axis
    ``[0, 0, 1]`` by ``360 / n`` degrees, where ``n`` is parsed from ``sym``
    (e.g. ``"C3"`` -> 3).

    A placement is rejected in any of these cases:

    - it fails the "flatness" test on ``body.pcavecs[0]``;
    - it clashes, when ``max_trim <= 0``;
    - it cannot be trimmed acceptably, when ``max_trim > 0``.

    Surviving placements are scored with ``hscore.scorepos``.
    """

    # Private sentinel: lets ``__call__`` hand out a fresh ``{}`` per call
    # instead of sharing one mutable default dict, while still passing an
    # explicit ``wts=None`` through unchanged.
    _NO_WTS = object()

    def __init__(self, body, sym, hscore, **kw):
        """Store the body, score function and options.

        Args:
            body: body scored against its own symmetry mate.
            sym: symmetry name like ``"C3"``; ``int(sym[1:])`` gives the order.
            hscore: object providing ``scorepos``.
            **kw: options stored as ``self.arg`` (an ``rp.Bunch``). ``__call__``
                reads ``max_longaxis_dot_z`` and ``max_trim`` from it and
                forwards everything to the clash / trim / score helpers.
        """
        self.arg = rp.Bunch(kw)
        self.body = body
        self.hscore = hscore
        self.symrot = hm.hrot([0, 0, 1], 360 / int(sym[1:]), degrees=True)

    def __call__(self, xforms, iresl=-1, wts=_NO_WTS, **kw):
        """Score a batch of transforms.

        Args:
            xforms: transforms; reshaped to ``(-1, 4, 4)``.
            iresl: resolution index forwarded to ``hscore.scorepos``.
            wts: score weights, merged into the options via
                ``self.arg.sub(wts=wts)``. Defaults to a fresh empty dict.
            **kw: accepted for call compatibility; unused.

        Returns:
            ``(scores, rp.Bunch(reslb=lb, resub=ub))``.

            - ``scores`` has one entry per transform and is 0 for rejected
              placements.
            - ``lb``/``ub`` are int32 residue bounds per transform. They
              default to ``0``/``nres - 1`` and are overwritten for accepted
              placements with the bounds actually used.
        """
        if wts is CyclicEvaluator._NO_WTS:
            wts = {}
        arg = self.arg.sub(wts=wts)
        body, sfxn = self.body, self.hscore.scorepos
        xforms = xforms.reshape(-1, 4, 4)
        xsym = self.symrot @ xforms

        # "Flatness": |z component| of the transformed body.pcavecs[0]
        # (presumably the principal long axis as a homogeneous 4-vector --
        # confirm) must not exceed max_longaxis_dot_z.
        ok = np.abs((xforms @ body.pcavecs[0])[:, 2]) <= self.arg.max_longaxis_dot_z

        if arg.max_trim > 0:
            # Find residue ranges that avoid intersection, then drop
            # placements whose required trimming is not acceptable.
            trim = body.intersect_range(body, xforms[ok], xsym[ok], **arg)
            trim, trimok = rp.search.trim_ok(trim, body.nres, **arg)
            ok[ok] &= trimok
        else:
            ok[ok] &= body.clash_ok(body, xforms[ok], xsym[ok], **arg)
            # No trimming: use the full residue range for both copies.
            trim = [0], [body.nres - 1]

        # Score only the placements that survived the filters.
        scores = np.zeros(len(xforms))
        # Same bounds for both bodies; the meaning of the -1 entries is
        # defined by scorepos -- NOTE(review): confirm.
        bounds = (*trim, -1, *trim, -1)
        scores[ok] = sfxn(body, body, xforms[ok], xsym[ok], iresl, bounds, **arg)

        # Record the residue ranges used for each placement.
        lb = np.zeros(len(scores), dtype="i4")
        ub = np.full(len(scores), body.nres - 1, dtype="i4")
        if trim:
            lb[ok], ub[ok] = trim[0], trim[1]
        return scores, rp.Bunch(reslb=lb, resub=ub)
def _debug_dump_cyclic(xforms, body, sym, scores, ibest, evaluator, **kw):
    """Write PDB files for the top-ranked docks and print their scores.

    Args:
        xforms: candidate transforms from the search.
        body: the monomer. It is moved in place (``body.move_to``) to each
            dumped transform and left at the last one.
        sym: symmetry name passed to ``rp.symframes``.
        scores: search scores, indexed like ``xforms``.
        ibest: indices into ``xforms``/``scores``, in output order.
        evaluator: callable used to rescore each dumped dock with rpx-only and
            ncontact-only weights.
        **kw: options. Reads ``nout_debug`` (default 10, capped at
            ``len(ibest)``), ``wts``, ``nresl`` and ``output_prefix``.

    Returns:
        ``rp.Timer().total`` measured over this call (presumably seconds).
    """
    arg = rp.Bunch(kw)
    t = rp.Timer().start()
    nout_debug = min(10 if arg.nout_debug is None else arg.nout_debug, len(ibest))
    if nout_debug <= 0:
        # Nothing to dump; avoid touching arg.wts or building frames.
        return t.total

    # These are identical for every dumped model, so build them once.
    wrpx = arg.wts.sub(rpx=1, ncontact=0)
    wnct = arg.wts.sub(rpx=0, ncontact=1)
    frames = rp.symframes(sym)

    for iout in range(nout_debug):
        i = ibest[iout]
        body.move_to(xforms[i])
        scr, _ = evaluator(xforms[i], arg.nresl - 1, wrpx)
        cnt, extra = evaluator(xforms[i], arg.nresl - 1, wnct)
        fn = arg.output_prefix + "_%02i.pdb" % iout
        print(
            f"{fn} score {scores[i]:7.3f} rpx {scr[0]:7.3f} cnt {cnt[0]:4}",
            f"resi {extra.reslb[0]}-{extra.resub[0]}",
        )
        rp.dump_pdb_from_bodies(fn, [body], frames, resbounds=extra)
    return t.total