Commit b0eec7cb authored by mlocatelli's avatar mlocatelli
Browse files

Sign0 fit method parametrization; small fixes to diagnostics tasks

parent 001b0c2b
Pipeline #2533 failed with stages
in 8 minutes and 1 second
...@@ -26,7 +26,7 @@ from chemicalchecker.util.sampler.triplets import TripletSampler ...@@ -26,7 +26,7 @@ from chemicalchecker.util.sampler.triplets import TripletSampler
class sign0(BaseSignature, DataSignature): class sign0(BaseSignature, DataSignature):
"""Signature type 0 class.""" """Signature type 0 class."""
def __init__(self, signature_path, dataset, **params): def __init__(self, signature_path, dataset,**params):
"""Initialize a Signature. """Initialize a Signature.
Args: Args:
...@@ -327,8 +327,8 @@ class sign0(BaseSignature, DataSignature): ...@@ -327,8 +327,8 @@ class sign0(BaseSignature, DataSignature):
def fit(self, cc_root=None, pairs=None, X=None, keys=None, features=None, def fit(self, cc_root=None, pairs=None, X=None, keys=None, features=None,
data_file=None, key_type="inchikey", agg_method="average", data_file=None, key_type="inchikey", agg_method="average",
do_triplets=False, max_features=10000, chunk_size=10000, do_triplets=False, max_features=None, chunk_size=None,
sanitize=True, **params): sanitize=True, trim_features=True,**kwargs):
"""Process the input data. """Process the input data.
We produce a sign0 (full) and a sign0 (reference). We produce a sign0 (full) and a sign0 (reference).
...@@ -350,12 +350,12 @@ class sign0(BaseSignature, DataSignature): ...@@ -350,12 +350,12 @@ class sign0(BaseSignature, DataSignature):
should contain the required data in datasets. should contain the required data in datasets.
do_triplets(boolean): Draw triplets from the CC (default=True). do_triplets(boolean): Draw triplets from the CC (default=True).
""" """
BaseSignature.fit(self, **params) BaseSignature.fit(self, **kwargs)
self.clear() self.clear()
self.update_status("Getting data") self.update_status("Getting data")
if pairs is None and X is None and data_file is None: if pairs is None and X is None and data_file is None:
self.__log.debug("Runnning preprocess") self.__log.debug("Runnning preprocess")
data_file = Preprocess.preprocess(self, **params) data_file = Preprocess.preprocess(self, **kwargs)
self.__log.debug("data_file is {}".format(data_file)) self.__log.debug("data_file is {}".format(data_file))
res = self.get_data(pairs=pairs, X=X, keys=keys, features=features, res = self.get_data(pairs=pairs, X=X, keys=keys, features=features,
...@@ -369,10 +369,7 @@ class sign0(BaseSignature, DataSignature): ...@@ -369,10 +369,7 @@ class sign0(BaseSignature, DataSignature):
if sanitize: if sanitize:
self.update_status("Sanitizing") self.update_status("Sanitizing")
# we want to keep exactly 2048 features (Morgan fingerprint) for A1 san = Sanitizer(trim=trim_features, max_features=max_features,
# FIXME trimFeatures should be a fit parameter
trimFeatures = self.dataset != 'A1.001'
san = Sanitizer(trim=trimFeatures, max_features=max_features,
chunk_size=chunk_size) chunk_size=chunk_size)
X, keys, keys_raw, features = san.transform( X, keys, keys_raw, features = san.transform(
V=X, keys=keys, keys_raw=keys_raw, features=features, V=X, keys=keys, keys_raw=keys_raw, features=features,
...@@ -403,16 +400,16 @@ class sign0(BaseSignature, DataSignature): ...@@ -403,16 +400,16 @@ class sign0(BaseSignature, DataSignature):
self.refresh() self.refresh()
# save reference # save reference
overwrite = params.get('overwrite', False) overwrite = kwargs.get('overwrite', False)
self.save_reference(overwrite=overwrite) self.save_reference(overwrite=overwrite)
# Making triplets # Making triplets
if do_triplets: if do_triplets:
self.update_status("Sampling triplets") self.update_status("Sampling triplets")
cc = self.get_cc(cc_root) cc = self.get_cc(cc_root)
sampler = TripletSampler(cc, self, save=True) sampler = TripletSampler(cc, self, save=True)
sampler.sample(**params) sampler.sample(**kwargs)
# finalize signature # finalize signature
BaseSignature.fit_end(self, **params) BaseSignature.fit_end(self, **kwargs)
def predict(self, pairs=None, X=None, keys=None, features=None, def predict(self, pairs=None, X=None, keys=None, features=None,
data_file=None, key_type=None, merge=False, merge_method="new", data_file=None, key_type=None, merge=False, merge_method="new",
......
...@@ -185,10 +185,15 @@ def main(args): ...@@ -185,10 +185,15 @@ def main(args):
sign2_list = [cc.get_signature('sign2', 'full', ds) sign2_list = [cc.get_signature('sign2', 'full', ds)
for ds in cc.datasets_exemplary()] for ds in cc.datasets_exemplary()]
mfp = cc.get_signature('sign0', 'full', 'A1.001').data_path mfp = cc.get_signature('sign0', 'full', 'A1.001').data_path
# we want to keep exactly 2048 features (Morgan fingerprint) for A1
fit_kwargs['sign0']['A1.001'] = {'trim_features': False}
for ds in datasets: for ds in datasets:
fit_kwargs['sign0'][ds] = { fit_kwargs['sign0'][ds] = {
'key_type': 'inchikey', 'key_type': 'inchikey',
'do_triplets': False, 'do_triplets': False,
'max_features': 10000,
'chunk_size': 10000,
'sanitize': True,
'validations': True, 'validations': True,
'diagnostics': False 'diagnostics': False
} }
...@@ -331,7 +336,7 @@ def main(args): ...@@ -331,7 +336,7 @@ def main(args):
# by the 'reference_cc' input parameter (only for sign0 case, for other signatures # by the 'reference_cc' input parameter (only for sign0 case, for other signatures
# the reference is args.cc_root itself) # the reference is args.cc_root itself)
cctype = 'sign0' cctype = 'sign0'
task = PythonCallable(name="diagnostics", task = PythonCallable(name="diagnostics_sign0_BCDE",
python_callable=Diagnosis.diagnostics_hpc, python_callable=Diagnosis.diagnostics_hpc,
op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.reference_cc]) op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.reference_cc])
pp.add_task(task) pp.add_task(task)
...@@ -383,7 +388,7 @@ def main(args): ...@@ -383,7 +388,7 @@ def main(args):
# by the 'reference_cc' input parameter (only for sign0 case, for other signatures # by the 'reference_cc' input parameter (only for sign0 case, for other signatures
# the reference is args.cc_root itself) # the reference is args.cc_root itself)
cctype = 'sign0' cctype = 'sign0'
task = PythonCallable(name="diagnostics", task = PythonCallable(name="diagnostics_sign0_A",
python_callable=Diagnosis.diagnostics_hpc, python_callable=Diagnosis.diagnostics_hpc,
op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.reference_cc]) op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.reference_cc])
pp.add_task(task) pp.add_task(task)
...@@ -409,7 +414,7 @@ def main(args): ...@@ -409,7 +414,7 @@ def main(args):
# TASK: diagonistc plots for sign1-sign2 # TASK: diagonistc plots for sign1-sign2
cctypes = ['sign1', 'sign2'] cctypes = ['sign1', 'sign2']
for cctype in cctypes: for cctype in cctypes:
task = PythonCallable(name="diagnostics", task = PythonCallable(name="diagnostics_" + cctype,
python_callable=Diagnosis.diagnostics_hpc, python_callable=Diagnosis.diagnostics_hpc,
op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.cc_root]) op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.cc_root])
pp.add_task(task) pp.add_task(task)
...@@ -450,7 +455,7 @@ def main(args): ...@@ -450,7 +455,7 @@ def main(args):
############################################# #############################################
# TASK: Calculate diagnostics plots of sign3 for all spaces # TASK: Calculate diagnostics plots of sign3 for all spaces
cctype = 'sign3' cctype = 'sign3'
task = PythonCallable(name="diagnostics", task = PythonCallable(name="diagnostics_" + cctype,
python_callable=Diagnosis.diagnostics_hpc, python_callable=Diagnosis.diagnostics_hpc,
op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.reference_cc]) op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.reference_cc])
pp.add_task(task) pp.add_task(task)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment