Statistiques
| Branche: | Tag: | Révision :

dockonsurf / modules / clustering.py @ 4e82c425

Historique | Voir | Annoter | Télécharger (9,37 ko)

1
"""Functions to cluster structures.
2

3
functions:
4
get_rmsd: Computes the rmsd matrix of the conformers in a list of rdkit mol
5
    objects.
6
get_labels_affty: Clusters data in affinity matrix form by assigning labels to
7
    data points.
8
get_labels_vector: Clusters data in vectorial form by assigning labels to
9
    data points.
10
get_clusters: Groups data-points belonging to the same cluster into arrays of
11
    indices.
12
get_exemplars_affty: Computes the exemplars for every cluster and returns a list
13
    of indices.
14
plot_clusters: Plots the clustered data casting a color to every cluster.
15
clustering: Directs the clustering process by calling the relevant functions.
16
"""
17
import logging
18

    
19
import hdbscan
20
import numpy as np
21

    
22
logger = logging.getLogger('DockOnSurf')
23

    
24

    
25
def get_rmsd(mol_list: list, remove_Hs="c"):
    """Computes the rmsd matrix of the conformers in a list of rdkit mol objects

    @param mol_list: list of rdkit mol objects containing the conformers.
    @param remove_Hs: bool or str, which hydrogens are stripped before the
    RMSD is computed:
        - any falsy value (False, None, ""): no hydrogen is removed.
        - True (or any other truthy non-string value) or "all"
          (case-insensitive): every molecule goes through rdkit's RemoveHs.
        - "c" (case-insensitive): every molecule goes through
          modules.isolated.remove_C_linked_Hs.
    @return rmsd_matrix: Matrix containing the rmsd values of every pair of
    conformers. Shape: [n_conformers, n_conformers], symmetric with zeros in
    the main diagonal.
    @raise ValueError: if less than two molecules are provided or if remove_Hs
    is a string other than "all" or "c".

    The RMSD values of every pair of conformers is computed, stored in matrix
    form and returned back. The calculation of rmsd values can take into
    account all hydrogens, none, or only the ones not linked to carbon atoms.
    """
    if len(mol_list) < 2:
        err = "The provided molecule has less than 2 conformers"
        logger.error(err)
        raise ValueError(err)

    # Resolve the hydrogen-removal mode up front. Testing the string values
    # explicitly is required: a non-empty string is always truthy, so a plain
    # `remove_Hs or ...` test would treat "c" (and any typo) as "all".
    if not remove_Hs:
        h_mode = None
    elif not isinstance(remove_Hs, str):
        h_mode = "all"
    else:
        h_mode = remove_Hs.lower()
        if h_mode not in ("all", "c"):
            err = "remove_Hs value does not have an acceptable value"
            logger.error(err)
            raise ValueError(err)

    # Imported only after the arguments are validated so that invalid input
    # is reported without needing the heavy dependency.
    import rdkit.Chem.AllChem as Chem

    if h_mode == "all":
        mol_list = [Chem.RemoveHs(mol) for mol in mol_list]
    elif h_mode == "c":
        from modules.isolated import remove_C_linked_Hs
        mol_list = [remove_C_linked_Hs(mol) for mol in mol_list]

    num_confs = len(mol_list)
    rmsd_mtx = np.zeros((num_confs, num_confs))
    # Only the upper triangle is computed; the value is mirrored to keep the
    # matrix symmetric.
    for id1 in range(num_confs):  # TODO reduce RMSD precision
        for id2 in range(id1 + 1, num_confs):
            rmsd = Chem.GetBestRMS(mol_list[id1], mol_list[id2])
            rmsd_mtx[id1][id2] = rmsd
            rmsd_mtx[id2][id1] = rmsd

    return rmsd_mtx
66

    
67

    
68
def get_labels_affty(affty_mtx, kind="rmsd"):
    """Assigns a cluster label to every data point of an affinity matrix.

    @param affty_mtx: Data to be clustered, it must be an affinity matrix.
    (Eg. Euclidean distances between points, RMSD Matrix, etc.).
    Shape: [n_points, n_points]
    @param kind: Which kind of data the affinity matrix contains.
    @return: cluster labels as returned by HDBSCAN.fit_predict. Every data
    point is assigned a number corresponding to the cluster it belongs to.

    When the matrix holds RMSD values whose average is below 1e-3, HDBSCAN is
    allowed to return a single cluster and the minimum cluster size is set to
    half the number of points. Otherwise a single cluster is not allowed and
    the minimum cluster size is 20.
    """
    single_cluster = bool(np.average(affty_mtx) < 1e-3 and kind == "rmsd")
    min_cluster_size = len(affty_mtx) // 2 if single_cluster else 20

    clusterer = hdbscan.HDBSCAN(
        metric="precomputed",
        min_samples=5,
        min_cluster_size=min_cluster_size,
        allow_single_cluster=single_cluster,
    )
    labels = clusterer.fit_predict(affty_mtx)
    return labels
89

    
90

    
91
def get_labels_vector():
    """Placeholder for the clustering of data in vectorial form.

    @return: list of cluster labels. Every data point is meant to be assigned
    a number corresponding to the cluster it belongs to. Currently an empty
    list is always returned because the clustering is not implemented.
    """
    # TODO Implement it.
    labels = []
    return labels
99

    
100

    
101
def get_clusters(labels):
    """Groups data-points belonging to the same cluster into arrays of indices.

    @param labels: list or array of cluster labels (numbers) corresponding to
    the cluster every data point belongs to. Negative labels are not
    collected into any cluster.
    @return: tuple of arrays. Every array contains the indices (relative to the
    labels list) of the data points belonging to the same cluster. The tuple
    is empty when labels is empty or when its largest label is negative.
    """
    # A plain Python list compared with `==` against a number yields a single
    # bool instead of an element-wise mask, so convert to an array first.
    labels = np.asarray(labels)
    if labels.size == 0:
        return ()
    n_clusters = int(labels.max()) + 1
    return tuple(np.where(labels == clust_num)[0]
                 for clust_num in range(n_clusters))
112

    
113

    
114
def get_exemplars_affty(affty_mtx, clusters):
    """Picks one exemplar per cluster and returns their indices.

    @param affty_mtx: Data structured in form of affinity matrix. eg. Euclidean
    distances between points, RMSD Matrix, etc.) shape: [n_points, n_points].
    @param clusters: tuple of arrays. Every array contains the indices (relative
    to the affinity matrix) of the data points belonging to the same cluster.
    @return: list of indices (relative to the affinity matrix) of the exemplars
    for every cluster.

    For every already existing cluster (i) the rows and columns of its members
    are extracted from the total matrix, (ii) an Affinity Propagation run is
    carried out on that sub-matrix with a large negative preference
    (-1e6 times the largest absolute value of the sub-matrix), intended to
    make it find a single cluster, and (iii) the first cluster center found
    is translated back to an index of the total matrix.

    NOTE(review): scikit-learn documents the 'precomputed' affinity input as a
    matrix of similarities, while this function is described as receiving
    distances/RMSD values -- confirm that the selected exemplar is the
    intended one.
    """
    from sklearn.cluster import AffinityPropagation

    # One square sub-matrix per cluster, cut out of the total matrix.
    sub_matrices = tuple(affty_mtx[np.ix_(members, members)]
                         for members in clusters)

    exemplars = []
    for clust_idx, sub_mtx in enumerate(sub_matrices):
        preference = -1e6 * np.max(np.abs(sub_mtx))
        model = AffinityPropagation(affinity='precomputed',
                                    preference=preference,
                                    damping=0.95,
                                    max_iter=2000,
                                    random_state=None)
        model.fit(sub_mtx)
        # Index of the center inside the sub-matrix; assumes at least one
        # center was found -- TODO confirm behaviour when the fit fails.
        local_center = model.cluster_centers_indices_[0]
        exemplars.append(clusters[clust_idx][local_center])
    return exemplars
145

    
146

    
147
def plot_clusters(labels, x, y, exemplars=None, save=True):
    """Plots the clustered data casting a color to every cluster.

    @param labels: list of cluster labels (numbers) corresponding to the cluster
    it belongs to.
    @param x: list of data of the x axis.
    @param y: list of data of the y axis.
    @param exemplars: list of data point indices (relative to the labels list)
    considered as cluster exemplars, indexed by cluster label. When None or
    empty no exemplar is highlighted.
    @param save: bool, Whether to save the generated plot into a file or not.
    (in the latter case the plot is shown in a new window)

    Every data point is drawn as a dot colored by its label. A point is
    additionally drawn with a big cross when it is the exemplar of its own
    cluster. When saving, the figure is written to 'clusters.png' after
    calling modules.utilities.check_bak on that file name.
    """
    import matplotlib.pyplot as plt
    from matplotlib import cm, colors

    # The default value must be usable with len() and indexing below.
    if exemplars is None:
        exemplars = []

    n_clusters = max(labels) + 1
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap is no longer
    # available in recent Matplotlib releases.
    rb = plt.get_cmap('gist_rainbow', max(n_clusters, 1))
    rb.set_under()
    plt.figure(figsize=(10, 8))
    exemplar_drawn = False
    for i, label in enumerate(labels):
        plt.plot(x[i], y[i], c=rb(label), marker='.')
        # Negative labels must not be used as an index: exemplars[-1] would
        # silently refer to the exemplar of the last cluster.
        if 0 <= label < len(exemplars) and i == exemplars[label]:
            plt.plot(x[i], y[i], c=rb(label), marker="x",
                     markersize=15,
                     label=f"Exemplar cluster {label}")
            exemplar_drawn = True
    plt.title(f'Found {n_clusters} Clusters.')
    plt.xlabel("Energy")
    plt.ylabel("MOI")
    # A legend is only meaningful when at least one labelled artist exists.
    if exemplar_drawn:
        plt.legend()

    bounds = list(range(max(n_clusters, 1)))
    norm = colors.Normalize(vmin=min(labels), vmax=max(labels))
    # The mappable is not attached to any axes, so the axes to take the space
    # from is passed explicitly.
    plt.colorbar(cm.ScalarMappable(norm=norm, cmap=rb), ticks=bounds,
                 ax=plt.gca())
    if save:
        from modules.utilities import check_bak
        check_bak('clusters.png')
        plt.savefig('clusters.png')
        plt.close("all")
    else:
        plt.show()
187

    
188

    
189
def clustering(data, plot=False, x=None, y=None):
    """Directs the clustering process by calling the relevant functions.

    @param data: The data to be clustered. It must be stored in vector form
    [n_features, n_samples] or in affinity matrix form [n_samples, n_samples],
    symmetric and 0 in the main diagonal. (Eg. Euclidean distances between
    points, RMSD Matrix, etc.). Only the affinity matrix form is currently
    supported.
    @param plot: bool, Whether to plot the clustered data.
    @param x: Necessary only if plot is turned on. X values to plot the data.
    @param y: Necessary only if plot is turned on. Y values to plot the data.
    @return: list of exemplars, list of indices (relative to data)
    exemplars for every cluster. If no cluster is found every data point is
    returned as an exemplar.
    @raise ValueError: if plot is on and x or y are not iterables, or if data
    is not two-dimensional.
    @raise NotImplementedError: if data is not a square symmetric matrix.
    """
    from collections.abc import Iterable

    # Note the trailing spaces: adjacent string literals are concatenated
    # verbatim.
    data_err = "Data must be stored in vector form [n_features, n_samples] " \
               "or in affinity matrix form [n_samples, n_samples]: " \
               "symmetric and 0 in the main diagonal. Eg. RMSD matrix"
    debug_err = "On debug mode x and y should be provided"

    if plot and not (isinstance(x, Iterable) and isinstance(y, Iterable)):
        logger.error(debug_err)
        raise ValueError(debug_err)
    if not isinstance(data, np.ndarray):
        data = np.array(data)
    if len(data.shape) != 2:
        logger.error(data_err)
        raise ValueError(data_err)

    # Square and equal to its own transpose: treated as an affinity matrix.
    is_affty_mtx = data.shape[0] == data.shape[1] \
        and (np.tril(data).T == np.triu(data)).all()
    if not is_affty_mtx:
        not_impl_err = 'Clustering not yet implemented for vectorized data'
        logger.error(not_impl_err)
        raise NotImplementedError(not_impl_err)

    logger.info("Clustering using affinity matrix.")
    labels = get_labels_affty(data)
    if max(labels) == -1:
        # Every point got the label -1, i.e. no cluster was found: fall back
        # to considering every data point its own exemplar.
        logger.warning('Clustering of conformers did not converge. Try '
                       "setting a smaller 'min_samples' parameter.")
        exemplars = list(range(data.shape[0]))
    else:
        clusters = get_clusters(labels)
        exemplars = get_exemplars_affty(data, clusters)
    if plot:
        plot_clusters(labels, x, y, exemplars, save=True)
    logger.info(f'Conformers are grouped in {len(exemplars)} clusters.')
    return exemplars