import copy, logging
from collections import OrderedDict, abc, defaultdict
import numpy as np, xarray as xr, rpxdock as rp
from rpxdock.util import sanitize_for_pickle, num_digits
log = logging.getLogger(__name__)
[docs]class Result:
def __init__(self, data_or_file=None, body_=[], body_label_=None, **kw):
if isinstance(body_, rp.Body): body_ = [body_]
self.bodies = [body_]
self.body_label_ = body_label_ if body_label_ else ['body%i' % i for i in range(len(body_))]
self.pdb_extra = None
if len(self.body_label_) != len(body_):
raise ValueError('body_label_ must match number of bodies')
if data_or_file:
assert len(kw) is 0
if isinstance(data_or_file, xr.Dataset):
self.data = data_or_file
else:
self.load(file_)
else:
attrs = OrderedDict(kw['attrs']) if 'attrs' in kw else None
if attrs: del kw['attrs']
attrs = sanitize_for_pickle(attrs)
self.data = xr.Dataset(dict(**kw), attrs=attrs)
# b/c I always mistype these
self.dump_pdb_top_score = self.dump_pdbs_top_score
self.dump_pdb_top_score_each = self.dump_pdbs_top_score_each
[docs] def sortby(self, *args, **kw):
r = copy.copy(self)
r.data = self.data.sortby(*args, **kw)
return r
def __getattr__(self, name):
if name == "data":
raise AttributeError
return getattr(self.data, name)
def __getitem__(self, name):
return self.data[name]
def __setitem__(self, name, val):
self.data[name] = val
def __str__(self):
return "Result with data = " + str(self.data).replace("\n", "\n ")
[docs] def copy(self):
return Result(self.data.copy())
[docs] def getstate(self):
return self.data.to_dict()
[docs] def setstate(self, state):
self.data = xr.Dataset.from_dict(state)
[docs] def sel(self, *args, **kw):
r = copy.copy(self)
r.data = r.data.sel(*args, **kw)
return r
[docs] def dump_pdbs_top_score(self, nout_top=10, **kw):
best = np.argsort(-self.scores)
return self.dump_pdbs(best[:nout_top], lbl='top', **kw)
[docs] def top_each(self, neach=1):
which = dict()
for ijob, imodel in self.scores.groupby(self.ijob).groups.items():
imodel = np.array(imodel)
ibest = np.argsort(-self.scores.data[imodel])
which[ijob] = imodel[ibest[:neach]]
return which
[docs] def dump_pdbs_top_score_each(self, nout_each=1, **kw):
if nout_each is 0: return
which = self.top_each(nout_each)
ndigijob = num_digits(len(which) - 1)
dumped = set()
ndigmdl = max(np.max(num_digits(v)) for v in which.values())
for ijob, imodel in which.items():
dumped |= self.dump_pdbs(imodel, lbl=f'job{ijob:0{ndigijob}}_top', ndigmdl=ndigmdl, **kw)
return dumped
[docs] def dump_pdbs(self, which, ndigwhich=None, ndigmdl=None, lbl='', skip=[], output_prefix='rpx',
**kw):
if len(which) is 0: return set()
if isinstance(which, abc.Mapping):
raise ValueError('dump_pdbs takes sequence not mapping')
if 'fname' in kw and kw['fname'] is not None:
raise ArgumentError('fname is not a valid argument for dump_pdbs, because multiple files')
if not output_prefix and 'output_prefix' in self.attrs:
output_prefix = self.output_prefix
if ndigwhich is None: ndigwhich = num_digits(len(which) - 1)
if ndigmdl is None: ndigmdl = num_digits(max(which))
dumped = set()
for i, imodel in enumerate(which):
assert not isinstance(imodel, np.ndarray) or len(imodel) == 1
if not imodel in skip:
dumped.add(int(imodel))
prefix_tmp = f'{output_prefix}_{lbl}{i:0{ndigwhich}}_{int(imodel):0{ndigmdl}}_'
self.dump_pdb(imodel, output_prefix=prefix_tmp, **kw)
return dumped
[docs] def dump_pdb(self, imodel, output_prefix='', output_suffix='', fname=None, output_body='ALL',
sym='', sep='_', skip=[], hscore=None, output_asym_only=False, **kw):
if not sym and 'sym' in self.attrs: sym = self.attrs['sym']
if not sym and 'sym' in self.data: sym = self.data.sym.data[imodel]
sym = sym if sym else "C1"
if not output_prefix and 'output_prefix' in self.attrs:
output_prefix = self.output_prefix
bod = self.bodies[0]
if 'ijob' in self.data: bod = self.bodies[self.ijob[imodel].values]
multipos = self.xforms.ndim == 4
if multipos and self.xforms.shape[1] != len(bod):
raise ValueError("number of positions doesn't match number of bodies")
if str(output_body).upper() == 'ALL': output_body = list(range(len(bod)))
if isinstance(output_body, int): output_body = [output_body]
if not all(w < len(bod) for w in output_body):
raise ValueError(f'output_body ouf of bounds {output_body}')
bod = [bod[i] for i in output_body]
bodlab = None
if self.xforms.ndim == 4:
for x, b in zip(self.xforms[imodel], bod):
b.move_to(x.data)
else:
bod[0].move_to(self.xforms[imodel].data)
if not fname:
output_prefix = output_prefix + sep if output_prefix else ''
body_names = [b.label for b in bod]
if len(output_body) > 1 and self.body_label_:
bodlab = [self.body_label_[i] for i in output_body]
body_names = [bl + '_' + lbl for bl, lbl in zip(bodlab, body_names)]
middle = '__'.join(body_names)
output_suffix = sep + output_suffix if output_suffix else ''
fname = output_prefix + middle + output_suffix + '.pdb'
log.info(f'dumping pdb {fname} score {self.scores.data[imodel]}')
bfactor = None
if hscore and len(bod) == 2:
sm = hscore.score_matrix_inter(
bod[0],
bod[1],
symframes=rp.geom.symframes(sym, pos=self.xforms.data[imodel], **kw),
wts=kw['wts'],
)
bfactor = [sm.sum(axis=1), sm.sum(axis=0)]
bounds = np.tile([[-9e9], [9e9]], len(bod)).T
if 'reslb' in self.data and 'resub' in self.data:
bounds = np.stack([self.reslb[imodel], self.resub[imodel]], axis=-1)
symframes = rp.geom.symframes(sym, pos=self.xforms.data[imodel], **kw)
if output_asym_only: symframes = [np.eye(4)]
rp.io.dump_pdb_from_bodies(fname, bod, symframes=symframes, resbounds=bounds,
bfactor=bfactor, **kw)
if self.pdb_extra is not None:
with open(fname, 'a') as out:
out.write(self.pdb_extra[int(imodel)])
if hasattr(self.data, 'helix_n_to_primary'):
symframes = symframes[np.array(
[0, self.data.helix_n_to_primary[imodel], self.data.helix_n_to_secondry[imodel]])]
rp.io.dump_pdb_from_bodies(fname + '_hbase.pdb', bod, symframes=symframes,
resbounds=bounds, bfactor=bfactor, **kw)
# assert 0, 'testing helix dump'
def __len__(self):
return len(self.model)
@property
def ndocks(self):
return len(self.dockinfo)
def __eq__(self, other):
return self.data.equals(other.data)
[docs]def dict_coherent_entries(alldicts):
sets = defaultdict(set)
badkeys = set()
for d in alldicts:
for k, v in d.items():
try:
sets[k].add(v)
except:
badkeys.add(k)
return {k: v.pop() for k, v in sets.items() if len(v) is 1 and not k in badkeys}
[docs]def concat_results(results, **kw):
if isinstance(results, Result): results = [results]
assert len(results) > 0
ijob = np.repeat(np.arange(len(results)), [len(r) for r in results])
assert max(len(r.bodies) for r in results) == 1
assert all(r.body_label_ == results[0].body_label_ for r in results)
allattrs = [r.attrs for r in results]
common = dict_coherent_entries(allattrs)
r = Result(xr.concat([r.data for r in results], dim='model', **kw))
r.bodies = [r.bodies[0] for r in results]
r.data['ijob'] = (['model'], ijob)
r.data.attrs = OrderedDict(dockinfo=allattrs, **common)
r.body_label_ = results[0].body_label_
if results[0].pdb_extra is not None:
r.pdb_extra = list()
for x in results:
assert len(x.pdb_extra) == len(x.data.scores)
r.pdb_extra.extend(x.pdb_extra)
return r
[docs]def dummy_result(size=1000):
from rpxdock.homog import rand_xform
return Result(
ijob=(['model'], np.repeat([3, 1, 2, 4, 0], size / 5).astype('i8')),
scores=(["model"], np.random.rand(size).astype('f4')),
xforms=(["model", "hrow", "hcol"], rand_xform(size).astype('f4')),
rpx_plug=(["model"], np.random.rand(size).astype('f4')),
rpx_hole=(["model"], np.random.rand(size).astype('f4')),
ncontact_plug=(["model"], np.random.rand(size).astype('f4')),
ncontact_hole=(["model"], np.random.rand(size).astype('f4')),
reslb=(["model"], np.random.randint(0, 100, size)),
resub=(["model"], np.random.randint(100, 200, size)),
)
[docs]def assert_results_close(r, s, n=-1):
if set(r.keys()) != set(s.keys()):
print(list(r.keys()))
print(list(s.keys()))
assert set(r.keys()) == set(s.keys()), 'results must have same fields'
assert np.allclose(r.scores[:n], s.scores[:n])
assert np.allclose(r.xforms[:n], s.xforms[:n], atol=1e-3)
for k in r.data:
assert np.allclose(r[k][:n], s[k][:n])