Commit b0eec7cb authored by mlocatelli's avatar mlocatelli
Browse files

Sign0 fit method parametrization; small fixes to diagnostics tasks

parent 001b0c2b
Pipeline #2533 failed with stages
in 8 minutes and 1 second
......@@ -26,7 +26,7 @@ from chemicalchecker.util.sampler.triplets import TripletSampler
class sign0(BaseSignature, DataSignature):
"""Signature type 0 class."""
def __init__(self, signature_path, dataset, **params):
def __init__(self, signature_path, dataset,**params):
"""Initialize a Signature.
Args:
......@@ -327,8 +327,8 @@ class sign0(BaseSignature, DataSignature):
def fit(self, cc_root=None, pairs=None, X=None, keys=None, features=None,
data_file=None, key_type="inchikey", agg_method="average",
do_triplets=False, max_features=10000, chunk_size=10000,
sanitize=True, **params):
do_triplets=False, max_features=None, chunk_size=None,
sanitize=True, trim_features=True,**kwargs):
"""Process the input data.
We produce a sign0 (full) and a sign0 (reference).
......@@ -350,12 +350,12 @@ class sign0(BaseSignature, DataSignature):
should contain the required data in datasets.
do_triplets(boolean): Draw triplets from the CC (default=False).
"""
BaseSignature.fit(self, **params)
BaseSignature.fit(self, **kwargs)
self.clear()
self.update_status("Getting data")
if pairs is None and X is None and data_file is None:
self.__log.debug("Runnning preprocess")
data_file = Preprocess.preprocess(self, **params)
data_file = Preprocess.preprocess(self, **kwargs)
self.__log.debug("data_file is {}".format(data_file))
res = self.get_data(pairs=pairs, X=X, keys=keys, features=features,
......@@ -369,10 +369,7 @@ class sign0(BaseSignature, DataSignature):
if sanitize:
self.update_status("Sanitizing")
# we want to keep exactly 2048 features (Morgan fingerprint) for A1
# FIXME trimFeatures should be a fit parameter
trimFeatures = self.dataset != 'A1.001'
san = Sanitizer(trim=trimFeatures, max_features=max_features,
san = Sanitizer(trim=trim_features, max_features=max_features,
chunk_size=chunk_size)
X, keys, keys_raw, features = san.transform(
V=X, keys=keys, keys_raw=keys_raw, features=features,
......@@ -403,16 +400,16 @@ class sign0(BaseSignature, DataSignature):
self.refresh()
# save reference
overwrite = params.get('overwrite', False)
overwrite = kwargs.get('overwrite', False)
self.save_reference(overwrite=overwrite)
# Making triplets
if do_triplets:
self.update_status("Sampling triplets")
cc = self.get_cc(cc_root)
sampler = TripletSampler(cc, self, save=True)
sampler.sample(**params)
sampler.sample(**kwargs)
# finalize signature
BaseSignature.fit_end(self, **params)
BaseSignature.fit_end(self, **kwargs)
def predict(self, pairs=None, X=None, keys=None, features=None,
data_file=None, key_type=None, merge=False, merge_method="new",
......
......@@ -185,10 +185,15 @@ def main(args):
sign2_list = [cc.get_signature('sign2', 'full', ds)
for ds in cc.datasets_exemplary()]
mfp = cc.get_signature('sign0', 'full', 'A1.001').data_path
# we want to keep exactly 2048 features (Morgan fingerprint) for A1
fit_kwargs['sign0']['A1.001'] = {'trim_features': False}
for ds in datasets:
fit_kwargs['sign0'][ds] = {
'key_type': 'inchikey',
'do_triplets': False,
'max_features': 10000,
'chunk_size': 10000,
'sanitize': True,
'validations': True,
'diagnostics': False
}
......@@ -331,7 +336,7 @@ def main(args):
# by the 'reference_cc' input parameter (only for sign0 case, for other signatures
# the reference is args.cc_root itself)
cctype = 'sign0'
task = PythonCallable(name="diagnostics",
task = PythonCallable(name="diagnostics_sign0_BCDE",
python_callable=Diagnosis.diagnostics_hpc,
op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.reference_cc])
pp.add_task(task)
......@@ -383,7 +388,7 @@ def main(args):
# by the 'reference_cc' input parameter (only for sign0 case, for other signatures
# the reference is args.cc_root itself)
cctype = 'sign0'
task = PythonCallable(name="diagnostics",
task = PythonCallable(name="diagnostics_sign0_A",
python_callable=Diagnosis.diagnostics_hpc,
op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.reference_cc])
pp.add_task(task)
......@@ -409,7 +414,7 @@ def main(args):
# TASK: diagnostic plots for sign1-sign2
cctypes = ['sign1', 'sign2']
for cctype in cctypes:
task = PythonCallable(name="diagnostics",
task = PythonCallable(name="diagnostics_" + cctype,
python_callable=Diagnosis.diagnostics_hpc,
op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.cc_root])
pp.add_task(task)
......@@ -450,7 +455,7 @@ def main(args):
#############################################
# TASK: Calculate diagnostics plots of sign3 for all spaces
cctype = 'sign3'
task = PythonCallable(name="diagnostics",
task = PythonCallable(name="diagnostics_" + cctype,
python_callable=Diagnosis.diagnostics_hpc,
op_args=[pp.tmpdir, args.cc_root, cctype, molset[cctype], dss, args.reference_cc])
pp.add_task(task)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment