Source code for rpxdock.motif.pairdat

import os, time, logging
import xarray as xr, numpy as np

# Module-level logger; ResPairData.sanity_check reports its iteration count through it.
log = logging.getLogger(__name__)

class ResPairData:
    """Wrapper around an ``xarray.Dataset`` holding pdb, residue and residue-pair data.

    The dataset is stored in ``self.data``. Attribute and item access that the
    wrapper itself does not define are forwarded to that dataset, so
    ``rp.nres`` is ``rp.data.nres`` and ``rp["p_resi"]`` is ``rp.data["p_resi"]``.

    The ``subset_by_*`` methods return new ``ResPairData`` objects in which the
    relational fields (``r_pdbid``, ``p_pdbid``, ``p_resi``, ``p_resj``) and the
    ``pdb_res_offsets`` / ``pdb_pair_offsets`` attributes are renumbered to
    match the reduced data.
    """

    def __init__(self, data, sanity_check=None):
        """Wrap an existing dataset / ResPairData, or build one from raw dicts.

        Args:
            data: an ``xr.Dataset`` (used as-is), another ``ResPairData`` (its
                dataset is shared, not copied), or anything else, which is
                handed to ``_rp_from_raw_dicts`` (that function asserts it is
                a dict).
            sanity_check: for Dataset / ResPairData input the check runs only
                when this is exactly ``True``; for raw-dict input the value is
                forwarded to ``_rp_from_raw_dicts``.
        """
        if isinstance(data, xr.Dataset):
            self.data = data
        elif isinstance(data, ResPairData):
            self.data = data.data

        if hasattr(self, "data"):
            if sanity_check is True:
                self.sanity_check()
            return
        _rp_from_raw_dicts(self, data, sanity_check)

    def __getattr__(self, k):
        """Forward unknown attribute lookups to the wrapped dataset."""
        # Without this guard a missing ``data`` attribute would recurse forever,
        # because ``self.data`` below would re-enter __getattr__.
        if k == "data":
            raise AttributeError
        return getattr(self.data, k)

    def __getitem__(self, k):
        """Return ``self.data[k]``."""
        return self.data[k]

    def __str__(self):
        return "ResPairData with data = " + str(self.data).replace("\n", "\n ")

    def _get_keepers_by_pdb(self, keep, random=True, seed=None, **kw):
        """Normalize a pdb selection into a sorted array of pdb indices.

        Args:
            keep: an integer count (choose that many pdbs), a boolean ndarray
                mask, or an iterable of pdb indices.
            random: for an integer ``keep``, pick the pdbs at random without
                replacement instead of taking the first ``keep``.
            seed: if given, seeds numpy's global RNG before the random pick.
            **kw: accepted and ignored.

        Returns:
            ``np.ndarray`` of the selected indices in ascending order.
        """
        if isinstance(keep, (int, np.integer)):
            if random:
                if seed is not None:
                    np.random.seed(seed)
                keep = np.random.choice(len(self.pdb), keep, replace=False)
            else:
                keep = np.arange(keep)
        # ``np.bool`` was removed in NumPy 1.24; comparing against the builtin
        # ``bool`` works on every NumPy version.
        if isinstance(keep, np.ndarray) and keep.dtype == bool:
            keep = np.where(keep)[0]
        return np.array(sorted(keep))

    def subset_by_pdb(self, keep, sanity_check=False, update_p_res=True, **kw):
        """Keep a subset of pdbs, in the same order as the original.

        Args:
            keep: selection understood by ``_get_keepers_by_pdb``.
            sanity_check: run ``sanity_check`` on the result before returning.
            update_p_res: forwarded to ``_update_relational_data``.
            **kw: forwarded to ``_get_keepers_by_pdb`` (``random``, ``seed``).

        Returns:
            A new ``ResPairData`` restricted to the selected pdbs.

        Raises:
            ValueError: if the selection is empty.
        """
        keepers = self._get_keepers_by_pdb(keep, **kw)
        # ``keepers`` holds indices, so emptiness must be tested by length:
        # a sum of zero is also what keeping only pdb 0 produces.
        if len(keepers) == 0:
            raise ValueError('no pdbs remain')
        residx = np.isin(self.data.r_pdbid, keepers)
        pairidx = np.isin(self.data.p_pdbid, keepers)
        rpsub = self.data.sel(pdbid=keepers, resid=residx, pairid=pairidx)
        _update_relational_data(rpsub, self.data.pdb_res_offsets, keepers, update_p_res)
        new = ResPairData(rpsub)
        if sanity_check:
            new.sanity_check()
        return new

    def subset_by_aa(self, aas, sanity_check=False, return_keepers=False):
        """Keep only residues whose amino-acid code is in ``aas``.

        Args:
            aas: a string (split into its characters) or a collection of codes.
            sanity_check: forwarded to ``subset_by_res``.
            return_keepers: also return the boolean residue mask that was used.

        Returns:
            The subset, or ``(subset, keepers)`` when ``return_keepers`` is set.
        """
        if isinstance(aas, str):
            aas = tuple(aas)
        keepers = np.isin(self.id2aa[self.aaid].data, aas)
        subset = self.subset_by_res(keepers, sanity_check)
        if return_keepers:
            return subset, keepers
        return subset

    def subset_by_ss(self, ss, sanity_check=False, return_keepers=False):
        """Keep only residues whose secondary-structure code is in ``ss``.

        Args:
            ss: a string or collection made up of the codes ``E``, ``H``, ``L``.
            sanity_check: forwarded to ``subset_by_res``.
            return_keepers: also return the boolean residue mask that was used.

        Returns:
            The subset, or ``(subset, keepers)`` when ``return_keepers`` is set.

        Raises:
            ValueError: if ``ss`` contains anything other than ``E``/``H``/``L``.
        """
        if set(ss) - set("EHL"):
            raise ValueError('ss must be EHL')
        if isinstance(ss, str):
            ss = tuple(ss)
        keepers = np.isin(self.id2ss[self.ssid].data, ss)
        subset = self.subset_by_res(keepers, sanity_check)
        if return_keepers:
            return subset, keepers
        return subset

    def subset_by_res(self, keepers, sanity_check=False):
        """Keep only the residues selected by ``keepers`` and renumber everything.

        Pairs are kept only when both of their residues survive; pdbs left
        with no residues are dropped.

        Args:
            keepers: per-residue selection with one entry per residue
                (its length must equal ``len(self.data.resid)``).
            sanity_check: run ``sanity_check`` on the result before returning.

        Returns:
            A new ``ResPairData`` with the reduced, renumbered data.

        Raises:
            ValueError: if no residues are selected or no pairs survive.
        """
        assert len(keepers) == len(self.data.resid)
        if np.sum(keepers) == 0:
            raise ValueError('no residues remain')

        # Step 1: remove residues without removing pdbs.
        rp_res = ResPairData(self.data.sel(resid=keepers))
        g_nres = rp_res.r_pdbid.groupby(rp_res.r_pdbid).count()
        # Zero-filled so that pdbs which lost all residues get a count of 0.
        nres = np.zeros(len(self.nres), dtype=np.int32)
        nres[g_nres.r_pdbid.values] = g_nres.values
        rp_res.data.attrs["pdb_res_offsets"] = np.concatenate([[0], np.cumsum(nres)])
        rp_res.data.nres.values = nres

        # Step 2: remove pairs that reference a removed residue.
        keep_idx = np.where(keepers)[0]
        p_resi_ok = np.isin(rp_res.p_resi, keep_idx)
        p_resj_ok = np.isin(rp_res.p_resj, keep_idx)
        pair_keep = np.logical_and(p_resi_ok, p_resj_ok)
        if np.sum(pair_keep) == 0:
            raise ValueError('no pairs left')
        rp_pair = ResPairData(rp_res.data.sel(pairid=pair_keep))

        # Step 3: map old residue indices to their position among the kept ones.
        noldi = np.max(rp_pair.data.p_resi).values + 1
        noldj = np.max(rp_pair.data.p_resj).values + 1
        nikeep = np.max(keep_idx) + 1
        old2new = np.zeros(max(nikeep, noldi, noldj), "i4") - 1
        old2new[keep_idx] = np.arange(len(keep_idx))
        rp_pair.data.p_resi.values = old2new[rp_pair.data.p_resi.values]
        rp_pair.data.p_resj.values = old2new[rp_pair.data.p_resj.values]
        assert np.all(rp_pair.p_resi >= 0)
        assert np.all(rp_pair.p_resj >= 0)

        # Step 4: drop pdbs that have no residues left and renumber pdb ids.
        keep_pdb = np.where(rp_pair.nres > 0)[0]
        old2new = np.zeros(len(rp_pair.nres), 'i4') - 1
        old2new[keep_pdb] = np.arange(len(keep_pdb))
        rp_pair.data = rp_pair.data.sel(pdbid=keep_pdb)
        rp_pair.r_pdbid.values = old2new[rp_pair.r_pdbid.values]
        rp_pair.p_pdbid.values = old2new[rp_pair.p_pdbid.values]

        # Step 5: recompute residue and pair offsets for the remaining pdbs.
        rp_pair.data.attrs["pdb_res_offsets"] = np.concatenate(
            [[0], np.cumsum(rp_pair.nres)])
        g_npair = rp_pair.p_pdbid.groupby(rp_pair.p_pdbid).count()
        # Zero-filled so that pdbs without any pair get an empty pair range.
        npair = np.zeros(len(rp_pair.nres), dtype=np.int32)
        npair[g_npair.p_pdbid.values] = g_npair.values
        rp_pair.data.attrs["pdb_pair_offsets"] = np.concatenate(
            [[0], np.cumsum(npair)])

        if sanity_check:
            rp_pair.sanity_check()
        return rp_pair

    def subset_by_pair(self, keepers, sanity_check=False, update_p_res=True):
        """Keep only the pairs selected by ``keepers``.

        Args:
            keepers: selection passed to ``self.data.sel(pairid=...)``.
            sanity_check: run ``sanity_check`` on the result before returning.
            update_p_res: forwarded to ``subset_by_pdb`` when whole pdbs vanish.

        Returns:
            A new ``ResPairData`` with the selected pairs.

        Raises:
            ValueError: if no pairs are selected.
        """
        rpsub = ResPairData(self.data.sel(pairid=keepers))
        pdb_keep = np.unique(rpsub.p_pdbid)
        if len(pdb_keep) == 0:
            raise ValueError('no pairs remain')
        if len(pdb_keep) != len(rpsub.pdbid):
            # Some pdbs lost all their pairs: do a full pdb-level update.
            rpsub = rpsub.subset_by_pdb(pdb_keep, update_p_res=update_p_res)
        else:
            # Every pdb still has pairs: only the pair offsets change.
            counts = np.cumsum(rpsub.p_pdbid.groupby(rpsub.p_pdbid).count())
            rpsub.data.attrs["pdb_pair_offsets"] = np.concatenate([[0], counts])
        if sanity_check:
            rpsub.sanity_check()
        return rpsub

    def split_by_pdb(self, frac, random=True, **kw):
        """Split the pdbs into two disjoint subsets.

        Args:
            frac: fraction of pdbs for the first part; its size is
                ``int(len(self.pdb) * frac)``.
            random: pick the first part at random (numpy global RNG) instead of
                taking the leading pdbs.
            **kw: forwarded to ``subset_by_pdb``.

        Returns:
            ``[part1, part2]`` as two ``ResPairData`` objects.
        """
        npdb = len(self.pdb)
        n1 = int(npdb * frac)
        if random:
            part1 = np.random.choice(npdb, n1, replace=False)
            part2 = np.array(list(set(range(npdb)) - set(part1)))
            # The order is discarded by the sort below; the shuffle is kept so
            # the global RNG stream advances exactly as it always has.
            np.random.shuffle(part2)
        else:
            part1 = np.arange(n1)
            part2 = np.arange(n1, npdb)
        return [self.subset_by_pdb(part, **kw) for part in (sorted(part1), sorted(part2))]

    def sanity_check(self):
        """Assert that offsets, pdb ids and pair residue indices are consistent.

        Up to 100 randomly chosen pdbs (numpy global RNG) are checked in
        detail; the coordinate/distance check (only when ``cb`` is present)
        and the global index-bound checks cover all pairs.

        Raises:
            AssertionError: on the first inconsistency found.
        """
        rp = self.data
        Npdb = len(self.pdb)
        nsample = min(Npdb, 100)
        log.debug(f'ResPaisDat sanity_check iters {nsample}')

        for ipdb in np.random.choice(Npdb, nsample, replace=False):
            # Residue range [rlb, rub) of this pdb must be labelled with ipdb.
            rlb, rub = rp.pdb_res_offsets[ipdb:ipdb + 2]
            assert rp.r_pdbid[rlb] == ipdb
            assert rp.r_pdbid[rub - 1] == ipdb
            if ipdb + 1 < len(rp.pdb):
                assert rp.r_pdbid[rub] == ipdb + 1
            p_resno = rp.resno[rlb:rub]
            assert np.all(p_resno >= rp.resid[rlb:rub] - rlb)

            # Pair range [plb, pub) of this pdb; nothing to check if it is empty.
            plb = rp.pdb_pair_offsets[ipdb]
            pub = rp.pdb_pair_offsets[ipdb + 1]
            if plb == pub:
                continue
            # Pair residue indices relative to the start of this pdb.
            p_resi = rp.p_resi[plb:pub] - rp.pdb_res_offsets[ipdb]
            p_resj = rp.p_resj[plb:pub] - rp.pdb_res_offsets[ipdb]
            if np.min(p_resi) < 0 or np.max(p_resi) >= rp.nres[ipdb]:
                # Print diagnostics, tagged with a random id, before asserting.
                r = str(np.random.randint(10**9))
                print(r, "sanity_check fail")
                print(r, "pdb", ipdb, rp.pdb[ipdb].values, rp.nres[ipdb].values)
                print(r, "offset res", rp.pdb_res_offsets[ipdb])
                print(r, "pair_range", plb, pub)
                print(r, p_resi)
                print(r, p_resj)
            assert np.min(p_resi) >= 0
            assert np.max(p_resi) < rp.nres[ipdb]
            # NOTE(review): unlike p_resi, this requires min(p_resj) > 0 and
            # bounds only the minimum from above; presumably pairs satisfy
            # resi < resj -- confirm whether max(p_resj) should be checked too.
            assert 0 < np.min(p_resj) < rp.nres[ipdb]
            resi_this_ipdb = rp.p_resi[rp.p_pdbid == ipdb]
            assert np.all(rp.r_pdbid[resi_this_ipdb] == ipdb)

        # Stored pair distances must match the cb-cb distances of the residues.
        if 'cb' in rp:
            cbi = rp.cb[rp.p_resi]
            cbj = rp.cb[rp.p_resj]
            dhat = np.linalg.norm(cbi - cbj, axis=1)
            assert np.allclose(dhat, rp.p_dist, atol=1e-3)

        assert 0 <= np.max(rp.p_resi) < len(rp.resid)
        assert 0 <= np.max(rp.p_resj) < len(rp.resid)

    def only_whats_needed(self, task):
        """Return a copy of the data without the variables ``task`` does not use.

        Args:
            task: ``"seqproftest"`` or ``"respairscore"``.

        Returns:
            A new ``ResPairData`` with the unneeded variables dropped.
            Variables that are already absent are reported on stdout and
            skipped.

        Raises:
            AssertionError: if ``task`` is not one of the known tasks.
        """
        not_needed = dict(
            seqproftest=(
                "phi psi omega chi1 chi2 chi3 chi4 chain r_fa_sol "
                "r_fa_intra_atr_xover4 r_fa_intra_rep_xover4 r_fa_intra_sol_xover4 "
                "r_lk_ball r_lk_ball_iso r_lk_ball_bridge r_lk_ball_bridge_uncpl "
                "r_fa_elec r_fa_intra_elec r_pro_close r_hbond_sr_bb r_hbond_lr_bb "
                "r_hbond_bb_sc r_hb_sc r_dslf_fa13 r_rama_prepro r_omega r_p_aa_pp "
                "r_fa_dun_rot r_fa_dun_dev r_fa_dun_semi r_hxl_tors r_ref "
                "sasa2 sasa4 nnb6 nnb8 nnb12 nnb14 "
                "p_hb_bb_bb p_hb_bb_sc p_hb_sc_bb p_hb_sc_sc "
                "p_fa_atr p_fa_rep p_fa_sol p_lk_ball p_fa_elec "
                "p_hbond_sr_bb p_hbond_lr_bb"
            ).split(),
            respairscore=(
                "r_fa_intra_atr_xover4 r_fa_intra_rep_xover4 r_fa_intra_sol_xover4 "
                "r_lk_ball_iso r_lk_ball_bridge r_lk_ball_bridge_uncpl "
                "r_fa_intra_elec r_pro_close r_rama_prepro r_omega r_p_aa_pp "
                "r_hxl_tors r_ref sasa2 sasa4 nnb6 nnb8 nnb12 nnb14 "
                "p_hb_bb_bb p_hb_bb_sc p_hb_sc_bb "
                "p_fa_atr p_fa_rep p_fa_sol p_lk_ball p_fa_elec"
            ).split(),
        )
        assert task in not_needed
        present = []
        for v in not_needed[task]:
            if v in self.data:
                present.append(v)
            else:
                print("ResPairData.only_whats_needed: missing:", v)
        # Drop only what exists: passing an absent name to ``drop`` raises,
        # which would turn the "missing" notice above into a crash.
        return ResPairData(self.data.drop(present))
def _get_pdb_names(files): base = [os.path.basename(f) for f in files] assert all(b[4:] == "_0001.pdb" for b in base) return [b[:4] for b in base] def _update_relational_data(data, prev_pdb_res_offsets, pdb_subset=None, update_p_res=True): # update relational stuff if pdb_subset is None: pdb_subset = np.arange(len(data.pdbid)) # update pdbid references nold = np.max(data.r_pdbid).values + 1 old2new = np.zeros(nold, "i4") - 1 old2new[pdb_subset] = np.arange(len(pdb_subset)) data.r_pdbid.values = old2new[data.r_pdbid.values] data.p_pdbid.values = old2new[data.p_pdbid.values] new_pdb_res_offsets = np.concatenate([[0], np.cumsum(data.nres)]) data.attrs["pdb_res_offsets"] = new_pdb_res_offsets tmp = np.cumsum(data.p_pdbid.groupby(data.p_pdbid).count()) new_pdb_pair_offsets = np.concatenate([[0], tmp]) data.attrs["pdb_pair_offsets"] = new_pdb_pair_offsets old_res_ofst = prev_pdb_res_offsets[pdb_subset[data.p_pdbid.values]] new_res_ofst = new_pdb_res_offsets[data.p_pdbid.values] if update_p_res: data["p_resi"] += new_res_ofst - old_res_ofst data["p_resj"] += new_res_ofst - old_res_ofst assert np.all(data.p_resi >= 0) assert np.all(data.p_resj >= 0) def _change_seq_ss_to_ids(rp): dat = rp.data id2aa = np.array(list("ACDEFGHIKLMNPQRSTVWY")) aa2id = xr.DataArray(np.arange(20, dtype="i4"), [("aa", id2aa)]) aaid = aa2id.sel(aa=dat.seq).values.astype("i4") dat["aaid"] = xr.DataArray(aaid, dims=["resid"]) dat = dat.drop("seq") dat["id2aa"] = xr.DataArray(id2aa, [aa2id], ["aai"]) dat["aa2id"] = xr.DataArray(aa2id, [id2aa], ["aa"]) id2ss = np.array(list("EHL")) ss2id = xr.DataArray(np.arange(3, dtype="i4"), [("ss", id2ss)]) ssid = ss2id.sel(ss=dat.ss).values.astype("i4") dat["ssid"] = xr.DataArray(ssid, dims=["resid"]) dat = dat.drop("ss") dat["id2ss"] = xr.DataArray(id2ss, [ss2id], ["ssi"]) dat["ss2id"] = xr.DataArray(ss2id, [id2ss], ["ss"]) rp.data = dat def _rp_from_raw_dicts(self, data, sanity_check): # loading from raw dicts assert isinstance(data, dict) raw = data pdbdata = 
raw["pdbdata"] coords = raw["coords"] resdata = raw["resdata"] pairdata = raw["pairdata"] pdb_res_offsets = raw["pdb_res_offsets"] pdb_pair_offsets = raw["pdb_pair_offsets"] bin_params = raw["bin_params"] # put this stuff in dataset self.res_ofst = pdb_res_offsets self.pair_ofst = pdb_pair_offsets self.bin_params = bin_params pdbdata["file"] = pdbdata["pdb"] pdbdata["pdb"] = _get_pdb_names(pdbdata["file"]) for k, v in pdbdata.items(): pdbdata[k] = (["pdbid"], v) pdbdata["com"] = (["pdbid", "xyzw"], pdbdata["com"][1]) for k, v in resdata.items(): resdata[k] = (["resid"], v) # resdata["n"] = (["resid", "xyzw"], coords["ncac"][:, 0]) # resdata["ca"] = (["resid", "xyzw"], coords["ncac"][:, 1]) # resdata["c"] = (["resid", "xyzw"], coords["ncac"][:, 2]) resdata["n"] = (["resid", "xyzw"], coords["n"]) resdata["ca"] = (["resid", "xyzw"], coords["ca"]) resdata["c"] = (["resid", "xyzw"], coords["c"]) resdata["o"] = (["resid", "xyzw"], coords["o"]) resdata["cb"] = (["resid", "xyzw"], coords["cb"]) resdata["stub"] = (["resid", "hrow", "hcol"], coords["stubs"]) resdata["r_pdbid"] = resdata["pdbno"] del resdata["pdbno"] for k, v in pairdata.items(): pairdata[k] = (["pairid"], v) pairdata["p_pdbid"] = pairdata["pdbno"] del pairdata["pdbno"] if len(pdbdata.keys() & resdata.keys()): print(pdbdata.keys() & resdata.keys()) assert 0 == len(pdbdata.keys() & resdata.keys()) if len(pdbdata.keys() & pairdata.keys()): print(pdbdata.keys() & pairdata.keys()) assert 0 == len(pdbdata.keys() & pairdata.keys()) if len(pairdata.keys() & resdata.keys()): print(pairdata.keys() & resdata.keys()) assert 0 == len(pairdata.keys() & resdata.keys()) data = {**pdbdata, **resdata, **pairdata} assert len(data) == len(pdbdata) + len(resdata) + len(pairdata) self.data = xr.Dataset( data, coords=dict( xyzw=["x", "y", "z", "w"], hrow=["x", "y", "z", "w"], hcol=["x", "y", "z", "t"], ), attrs=dict( pdb_res_offsets=pdb_res_offsets, pdb_pair_offsets=pdb_pair_offsets, xbin_params=bin_params, 
xbin_types=raw["xbin_types"], xbin_swap_type=raw["xbin_swap_type"], eweights=raw["eweights"], ), ) _change_seq_ss_to_ids(self) res_ofst = self.pdb_res_offsets[self.p_pdbid] self.data["p_resi"] += res_ofst self.data["p_resj"] += res_ofst assert self.data.stub.sel(hcol="t").shape[1] == 4 assert np.all(self.data.ca.sel(xyzw="w") == 1) assert np.all(self.data.cb.sel(xyzw="w") == 1) if sanity_check is not False: self.sanity_check()