signaturizer.py 8.26 KB
Newer Older
Martino Bertoni's avatar
Martino Bertoni committed
1
2
3
# using the module
import os
import h5py
4
import shutil
Martino Bertoni's avatar
Martino Bertoni committed
5
6
import numpy as np
from tqdm import tqdm
7
import tensorflow.compat.v1 as tf
Martino Bertoni's avatar
Martino Bertoni committed
8
9
10
11
12
13
14
15
16
import tensorflow_hub as hub
try:
    from rdkit import Chem
    from rdkit.Chem import AllChem
except ImportError:
    raise ImportError("requires RDKit " +
                      "https://www.rdkit.org/docs/Install.html")


17
class Signaturizer(object):
    """Class loading TF-hub module and performing predictions."""

    def __init__(self, model_name,
                 base_url="http://chemicalchecker.com/api/db/getSignaturizer/",
                 version='v1', local=False, tf_version='1', verbose=True):
        """Initialize the Signaturizer.

        Args:
            model_name(str or list): The model to load. Possible values:
                - the model name (the bioactivity space (e.g. "B1") )
                - the model path (the directory containing 'saved_model.pb')
                - a list of models names or paths (e.g. ["B1", "B2", "E5"])
                - 'GLOBAL' to get the global (i.e. horizontally stacked)
                    bioactivity signature.
            base_url(str): The ChemicalChecker getModel API URL.
            version(str): Signaturizer version (appended to `base_url`).
            local(bool): Whether the specified model_name should be
                interpreted as a path to a local model.
            tf_version(str): The Tensorflow version. Currently not used by
                this class.
            verbose(bool): If True some more information will be printed.

        Raises:
            Exception: If `local` is True and a model path is not an
                existing directory.
        """
        self.verbose = verbose
        # normalize the request to a list of model names / paths
        if isinstance(model_name, (list, tuple)):
            models = list(model_name)
        elif model_name.upper() == 'GLOBAL':
            # all 25 spaces: A1..A5, B1..B5, ..., E1..E5
            models = [y + x for y in 'ABCDE' for x in '12345']
        else:
            models = [model_name]
        # load modules, all within one dedicated graph
        self.model_names = list()
        self.modules = list()
        self.graph = tf.Graph()
        with self.graph.as_default():
            for model in models:
                if local:
                    if not os.path.isdir(model):
                        raise Exception('Module path not found!')
                    if self.verbose:
                        print('LOADING local:', model)
                    spec = hub.create_module_spec_from_saved_model(model)
                    module = hub.Module(spec, tags=['serve'])
                else:
                    url = base_url + '%s/%s' % (version, model)
                    if self.verbose:
                        print('LOADING remote:', url)
                    module = hub.Module(url, tags=['serve'])
                self.modules.append(module)
                self.model_names.append(model)

    @staticmethod
    def _compute_sign0(chunk_smiles):
        """Compute the input fingerprints for a chunk of SMILES.

        Args:
            chunk_smiles(list): SMILES strings of the current chunk.
        Returns:
            tuple: a (len(chunk_smiles), 2048) float32 matrix of Morgan
                fingerprint bits, and the list of chunk-relative indexes of
                the SMILES that could not be processed (their rows are NaN).
        """
        sign0s = list()
        failed = list()
        for idx, mol_smiles in enumerate(chunk_smiles):
            try:
                # read SMILES as molecules
                mol = Chem.MolFromSmiles(mol_smiles)
                if mol is None:
                    raise Exception("Cannot get molecule from smiles.")
                fp = AllChem.GetMorganFingerprintAsBitVect(
                    mol, 2, nBits=2048)
                bits = [fp.GetBit(bit) for bit in range(fp.GetNumBits())]
                calc_s0 = np.array(bits).astype(np.float32)
            except Exception as err:
                # in case of failure save idx to fill NaNs
                print("SKIPPING %s: %s" % (mol_smiles, str(err)))
                failed.append(idx)
                # same dtype as the valid rows, so the stacked matrix
                # keeps a single dtype whether or not something failed
                calc_s0 = np.full((2048, ), np.nan, dtype=np.float32)
            sign0s.append(calc_s0)
        return np.vstack(sign0s), failed

    def predict(self, smiles, destination=None, chunk_size=1000):
        """Predict signatures for given SMILES.

        Args:
            smiles(list): List of SMILES strings.
            destination(str): File path where to save predictions.
            chunk_size(int): Perform prediction on chunks of this size.
        Returns:
            results: `SignaturizerResult` class.
        Raises:
            Exception: If `destination` already exists.
        """
        # Prepare result object (128 columns for each loaded module)
        features = len(self.modules) * 128
        results = SignaturizerResult(len(smiles), destination, features)
        # check before any write: an existing file is opened read-only
        if results.readonly:
            # release the read-only handle before giving up
            results.h5.close()
            raise Exception(
                'Destination file already exists, ' +
                'delete or rename to proceed.')
        results.dataset[:] = self.model_names
        with self.graph.as_default():
            with tf.Session() as sess:
                sess.run(tf.tables_initializer())
                sess.run(tf.global_variables_initializer())
                # predict by chunk
                all_chunks = range(0, len(smiles), chunk_size)
                for start in tqdm(all_chunks, disable=not self.verbose):
                    chunk = slice(start, start + chunk_size)
                    # stack input fingerprints and run predictor
                    sign0s, failed = self._compute_sign0(smiles[chunk])
                    for idx, module in enumerate(self.modules):
                        # NOTE(review): the module is applied anew for each
                        # chunk; presumably this adds ops to the graph every
                        # time. Consider building it once on a placeholder.
                        pred = module(sign0s, signature='serving_default')
                        preds = sess.run(pred)
                        # add NaN where SMILES conversion failed
                        if failed:
                            preds[np.array(failed)] = np.full((128, ), np.nan)
                        # save chunk in the columns reserved to this module
                        mdl_cols = slice(idx * 128, (idx + 1) * 128)
                        results.signature[chunk, mdl_cols] = preds
        results.close()
        if self.verbose:
            print('PREDICTION complete!')
        return results

    @staticmethod
    def _clear_tfhub_cache():
        """Empty the TF-hub cache directory.

        The directory is taken from the TFHUB_CACHE_DIR environment variable,
        falling back to '/tmp/tfhub_modules/'. It is removed with all its
        content and recreated empty.

        Raises:
            Exception: If the cache directory does not exist.
        """
        cache_dir = os.getenv('TFHUB_CACHE_DIR', '/tmp/tfhub_modules/')
        if not os.path.isdir(cache_dir):
            raise Exception('Cannot find tfhub cache directory, ' +
                            'please set TFHUB_CACHE_DIR variable')
        shutil.rmtree(cache_dir)
        os.mkdir(cache_dir)
144

Martino Bertoni's avatar
Martino Bertoni committed
145
146
147
148

class SignaturizerResult():
    """Container for the outcome of a prediction.

    Two attributes hold the data:
        signature: float32 matrix, one row per molecule and `features`
            columns (128 per model).
        dataset: one variable-length string per model.

    Without a destination both are plain numpy arrays; with a destination
    they are HDF5 datasets of the same names stored in that file.
    """

    def __init__(self, size, destination, features=128):
        """Allocate the result containers.

        Args:
            size(int): The number of molecules being signaturized.
            destination(str): Path to HDF5 file where prediction results will
                be saved. If None results are kept in memory.
            features(int): Total number of signature columns.
        """
        self.dst = destination
        self.readonly = False
        n_models = int(features / 128)
        if destination is None:
            # in-memory mode: no file, only numpy arrays
            self.h5 = None
            self.signature = np.zeros((size, features), dtype=np.float32)
            self.dataset = np.zeros(
                (n_models,), dtype=h5py.special_dtype(vlen=str))
            return
        if os.path.isfile(destination):
            # never overwrite an existing file: open it read-only instead
            print('HDF5 file %s exists, opening in read-only.' % self.dst)
            self.h5 = h5py.File(destination, 'r')
            self.readonly = True
        else:
            # brand new file: create both datasets
            self.h5 = h5py.File(destination, 'w')
            self.h5.create_dataset(
                'signature', (size, features), dtype=np.float32)
            self.h5.create_dataset(
                'dataset', (n_models,), dtype=h5py.special_dtype(vlen=str))
        self._expose_datasets()

    def _expose_datasets(self):
        """Bind the HDF5 datasets of the open file to the attributes."""
        self.signature = self.h5['signature']
        self.dataset = self.h5['dataset']

    def close(self):
        """Close the HDF5 file and reopen it read-only.

        Nothing is done when results are kept in memory.
        """
        if self.h5 is not None:
            self.h5.close()
            # leave it open for reading
            self.h5 = h5py.File(self.dst, 'r')
            self._expose_datasets()