Commit 8dfbcf5f authored by Martino Bertoni 🌋
Browse files

moved parameters to fit kwargs

parent 8fa3391c
......@@ -47,16 +47,10 @@ class sign2(BaseSignature, DataSignature):
DataSignature.__init__(self, self.data_path)
# assign dataset
self.dataset = dataset
# get parameters or default values
self.params = dict()
self.params['graph'] = params.get('graph', None)
self.params['node2vec'] = params.get('node2vec', None)
self.params['adanet'] = params.get('adanet', None)
if self.params['adanet'] is not None:
self.cpu = self.params['adanet'].get('cpu', 1)
def fit(self, sign1=None, neig1=None, reuse=True, compare_nn=False,
oos_predictor=True, **params):
oos_predictor=True, graph_kwargs={}, node2vec_kwargs={},
adanet_kwargs={'cpu': 1}, **kwargs):
"""Fit signature 2 given signature 1 and its nearest neighbors.
Node2vec embeddings are computed using the graph derived from sign1.
......@@ -75,7 +69,7 @@ class sign2(BaseSignature, DataSignature):
from chemicalchecker.tool.node2vec import Node2Vec
except ImportError as err:
raise err
BaseSignature.fit(self, **params)
BaseSignature.fit(self, **kwargs)
#########
# step 1: Node2Vec (learn graph embedding) input is neig1
#########
......@@ -99,16 +93,12 @@ class sign2(BaseSignature, DataSignature):
self.__log.debug('Node2Vec on %s' % sign1)
n2v = Node2Vec(executable=Config().TOOLS.node2vec_exec)
# use neig1 to generate the Node2Vec input graph (as edgelist)
graph_params = self.params['graph']
node2vec_path = os.path.join(self.model_path, 'node2vec')
if not os.path.isdir(node2vec_path):
os.makedirs(node2vec_path)
graph_file = os.path.join(node2vec_path, 'graph.edgelist')
if not reuse or not os.path.isfile(graph_file):
if graph_params:
n2v.to_edgelist(sign1, neig1, graph_file, **graph_params)
else:
n2v.to_edgelist(sign1, neig1, graph_file)
n2v.to_edgelist(sign1, neig1, graph_file, **graph_kwargs)
# check that all molecules are considered in the graph
with open(graph_file, 'r') as fh:
lines = fh.readlines()
......@@ -123,13 +113,9 @@ class sign2(BaseSignature, DataSignature):
graph = SNAPNetwork.from_file(graph_file)
graph.stats_toJSON(graph_stat_file)
# run Node2Vec to generate embeddings
node2vec_params = self.params['node2vec']
emb_file = os.path.join(node2vec_path, 'n2v.emb')
if not reuse or not os.path.isfile(emb_file):
if node2vec_params:
n2v.run(graph_file, emb_file, **node2vec_params)
else:
n2v.run(graph_file, emb_file)
n2v.run(graph_file, emb_file, **node2vec_kwargs)
# convert to signature h5 format
if not reuse or not os.path.isfile(self.data_path):
n2v.emb_to_h5(sign1.keys, emb_file, self.data_path)
......@@ -157,28 +143,19 @@ class sign2(BaseSignature, DataSignature):
self.update_status("Training out-of-sample predictor")
self.__log.debug('AdaNet fit %s with Node2Vec output' % sign1)
# get params and set folder
adanet_params = self.params['adanet']
adanet_path = os.path.join(self.model_path, 'adanet')
if adanet_params:
if 'model_dir' in adanet_params:
adanet_path = adanet_params.pop('model_dir')
adanet_path = adanet_kwargs.get('model_dir', adanet_path)
if not reuse or not os.path.isdir(adanet_path):
os.makedirs(adanet_path, exist_ok=True)
# prepare train-test file
traintest_file = os.path.join(adanet_path, 'traintest.h5')
if adanet_params:
traintest_file = adanet_params.pop(
'traintest_file', traintest_file)
traintest_file = adanet_kwargs.get(
'traintest_file', traintest_file)
if not reuse or not os.path.isfile(traintest_file):
Traintest.create_signature_file(
sign1.data_path, self.data_path, traintest_file)
if adanet_params:
ada = AdaNet(model_dir=adanet_path,
traintest_file=traintest_file, **adanet_params)
else:
ada = AdaNet(model_dir=adanet_path,
traintest_file=traintest_file)
ada = AdaNet(model_dir=adanet_path,
traintest_file=traintest_file, **adanet_kwargs)
# learn NN with AdaNet
self.__log.debug('AdaNet training on %s' % traintest_file)
ada.train_and_evaluate()
......@@ -207,11 +184,11 @@ class sign2(BaseSignature, DataSignature):
else:
self.map(sign2_full.data_path)
# finalize signature
BaseSignature.fit_end(self, **params)
BaseSignature.fit_end(self, **kwargs)
def predict(self, sign1, destination=None):
"""Use the learned model to predict the signature.
Args:
sign1(signature): A valid Signature type 1
destination(None|path|signature): If None the prediction results are
......@@ -388,7 +365,8 @@ class sign2(BaseSignature, DataSignature):
lr_pred['name'] = "LinearRegression"
return lr_pred
def eval_node2vec(self, sign1, neig1, reuse=True):
def eval_node2vec(self, sign1, neig1, reuse=True, graph_kwargs={},
node2vec_kwargs={}):
"""Evaluate node2vec performances.
Node2vec embeddings are computed using the graph derived from sign1.
......@@ -414,22 +392,15 @@ class sign2(BaseSignature, DataSignature):
self.__log.debug('Node2Vec on %s' % sign1)
n2v = Node2Vec(executable=Config().TOOLS.node2vec_exec)
# define the n2v model path
node2vec_params = self.params['node2vec']
node2vec_path = os.path.join(self.model_path, 'node2vec_eval')
if node2vec_params:
if 'model_dir' in node2vec_params:
node2vec_path = node2vec_params.pop('model_dir')
node2vec_path = node2vec_kwargs.get('model_dir', node2vec_path)
if not reuse or not os.path.isdir(node2vec_path):
os.makedirs(node2vec_path)
# use neig1 to generate the Node2Vec input graph (as edgelist)
graph_params = self.params['graph']
graph_file = os.path.join(
self.model_path, 'node2vec', 'graph.edgelist')
if not reuse or not os.path.isfile(graph_file):
if graph_params:
n2v.to_edgelist(sign1, neig1, graph_file, **graph_params)
else:
n2v.to_edgelist(sign1, neig1, graph_file)
n2v.to_edgelist(sign1, neig1, graph_file, **graph_kwargs)
# split graph in train and test
graph_train = graph_file + ".train"
graph_test = graph_file + ".test"
......@@ -440,10 +411,7 @@ class sign2(BaseSignature, DataSignature):
# run Node2Vec to generate embeddings based on train
emb_file = os.path.join(node2vec_path, 'n2v.emb')
if not reuse or not os.path.isfile(emb_file):
if node2vec_params:
n2v.run(graph_train, emb_file, **node2vec_params)
else:
n2v.run(graph_train, emb_file)
n2v.run(graph_train, emb_file, **node2vec_kwargs)
# create evaluation sign2
eval_s2 = sign2(node2vec_path, self.dataset)
# convert to signature h5 format
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment