Statistiques
| Branche: | Tag: | Révision :

dockonsurf / modules / clustering.py @ 695dcff8

Historique | Voir | Annoter | Télécharger (8,61 ko)

1
"""Functions to cluster structures.
2

3
functions:
4
get_labels_affty: Clusters data in affinity matrix form by assigning labels to
5
data points.
6
get_labels_vector: Clusters data in vectorial form by assigning labels to
7
data points.
8
get_clusters: Groups data-points belonging to the same cluster into arrays of
9
indices.
10
get_exemplars_affty: Computes the exemplars for every cluster and returns a list
11
of indices.
12
plot_clusters: Plots the clustered data casting a color to every cluster.
13
clustering: Directs the clustering process by calling the relevant functions.
14
"""
15
import logging
16

    
17
import hdbscan
18
import numpy as np
19

    
20
logger = logging.getLogger('DockOnSurf')
21

    
22

    
23
def get_rmsd(mol_list: list, remove_Hs="c"):
24
    """Computes the rmsd matrix of the conformers in a rdkit mol object.
25

26
    @param mol_list: list of rdkit mol objects containing the conformers.
27
    @param remove_Hs: bool or str,
28
    @return rmsd_matrix: Matrix containing the rmsd values of every pair of
29
    conformers.
30

31
    The RMSD values of every pair of conformers is computed, stored in matrix
32
    form and returned back. The calculation of rmsd values can take into
33
    account all hydrogens, none, or only the ones not linked to carbon atoms.
34
    """
35
    import rdkit.Chem.AllChem as Chem
36

    
37
    if len(mol_list) < 2:
38
        err = "The provided molecule has less than 2 conformers"
39
        logger.error(err)
40
        raise ValueError(err)
41

    
42
    if not remove_Hs:
43
        pass
44
    elif remove_Hs or remove_Hs.lower() == "all":
45
        mol_list = [Chem.RemoveHs(mol) for mol in mol_list]
46
    elif remove_Hs.lower() == "c":
47
        from isolated import remove_C_linked_Hs
48
        mol_list = [remove_C_linked_Hs(mol) for mol in mol_list]
49
    else:
50
        err = "remove_Hs value does not have an acceptable value"
51
        logger.error(err)
52
        raise ValueError(err)
53

    
54
    num_confs = len(mol_list)
55
    conf_ids = list(range(num_confs))
56
    rmsd_mtx = np.zeros((num_confs, num_confs))
57
    for id1 in conf_ids:  # TODO reduce RMSD precision
58
        for id2 in conf_ids[id1 + 1:]:
59
            rmsd = Chem.GetBestRMS(mol_list[id1], mol_list[id2])
60
            rmsd_mtx[id1][id2] = rmsd
61
            rmsd_mtx[id2][id1] = rmsd
62

    
63
    return rmsd_mtx
64

    
65

    
66
def get_labels_affty(affty_mtx, kind="rmsd"):
67
    """Clusters data in affinity matrix form by assigning labels to data points.
68

69
    @param affty_mtx: Data to be clustered, it must be an affinity matrix.
70
    (Eg. Euclidean distances between points, RMSD Matrix, etc.).
71
    Shape: [n_points, n_points]
72
    @param kind: Which kind of data the affinity matrix contains.
73
    @return: list of cluster labels. Every data point is assigned a number
74
    corresponding to the cluster it belongs to.
75
    """
76
    if np.average(affty_mtx) < 1e-3 and kind == "rmsd":
77
        sing_clust = True
78
        min_size = int(len(affty_mtx) / 2)
79
    else:
80
        sing_clust = False
81
        min_size = 20
82
    hdbs = hdbscan.HDBSCAN(metric="precomputed",
83
                           min_samples=5,
84
                           min_cluster_size=min_size,
85
                           allow_single_cluster=sing_clust)
86
    return hdbs.fit_predict(affty_mtx)
87

    
88

    
89
def get_labels_vector():
90
    """Clusters data in vectorial form by assigning labels to data points.
91

92
    @return: list of cluster labels. Every data point is assigned a number
93
    corresponding to the cluster it belongs to.
94
    """
95
    return []
96

    
97

    
98
def get_clusters(labels):
99
    """Groups data-points belonging to the same cluster into arrays of indices.
100

101
    @param labels: list of cluster labels (numbers) corresponding to the cluster
102
    it belongs to.
103
    @return: tuple of arrays. Every array contains the indices (relative to the
104
    labels list) of the data points belonging to the same cluster.
105
    """
106
    n_clusters = max(labels) + 1
107
    return tuple(np.where(labels == clust_num)[0]
108
                 for clust_num in range(n_clusters))
109

    
110

    
111
def get_exemplars_affty(affty_mtx, clusters):
112
    """Computes the exemplars for every cluster and returns a list of indices.
113

114
    @param affty_mtx: Data structured in form of affinity matrix. eg. Euclidean
115
    distances between points, RMSD Matrix, etc.) shape: [n_points, n_points].
116
    @param clusters: tuple of arrays. Every array contains the indices (relative
117
    to the affinity matrix) of the data points belonging to the same cluster.
118
    @return: list of indices (relative to the affinity matrix) exemplars for
119
    every cluster.
120
    """
121
    from sklearn.cluster import AffinityPropagation
122
    clust_affty_mtcs = tuple(affty_mtx[np.ix_(clust, clust)]
123
                             for clust in clusters)
124
    exemplars = []
125
    for i, mtx in enumerate(clust_affty_mtcs):
126
        pref = -1e6 * np.max(np.abs(mtx))
127
        af = AffinityPropagation(affinity='precomputed',
128
                                 preference=pref,
129
                                 damping=0.95,
130
                                 max_iter=2000).fit(mtx)
131
        exemplars.append(clusters[i][af.cluster_centers_indices_[0]])
132
    return exemplars
133

    
134

    
135
def plot_clusters(labels, x, y, exemplars=None, save=True):
136
    """Plots the clustered data casting a color to every cluster.
137

138
    @param labels: list of cluster labels (numbers) corresponding to the cluster
139
    it belongs to.
140
    @param x: list of data of the x axis.
141
    @param y: list of data of the y axis.
142
    @param exemplars: list of data point indices (relative to the labels list)
143
    considered as cluster exemplars.
144
    @param save: bool, Whether to save the generated plot into a file or not.
145
    (in the latter case the plot is shown in a new window)
146
    """
147
    import matplotlib.pyplot as plt
148
    from matplotlib import cm, colors
149

    
150
    n_clusters = max(labels) + 1
151
    rb = cm.get_cmap('gist_rainbow', max(n_clusters, 1))
152
    rb.set_under()
153
    plt.figure(figsize=(10, 8))
154
    for i in range(len(labels)):
155
        plt.plot(x[i], y[i], c=rb(labels[i]), marker='.')
156
        if len(exemplars) > 0 and i == exemplars[labels[i]]:
157
            plt.plot(x[i], y[i], c=rb(labels[i]), marker="x",
158
                     markersize=15,
159
                     label=f"Exemplar cluster {labels[i]}")
160
    plt.title(f'Found {n_clusters} Clusters.')
161
    plt.xlabel("Energy")
162
    plt.ylabel("MOI")
163
    plt.legend()
164

    
165
    bounds = list(range(max(n_clusters, 1)))
166
    norm = colors.Normalize(vmin=min(labels), vmax=max(labels))
167
    plt.colorbar(cm.ScalarMappable(norm=norm, cmap=rb), ticks=bounds)
168
    if save:
169
        from modules.utilities import check_bak
170
        check_bak('clusters.png')
171
        plt.savefig('clusters.png')
172
        plt.close("all")
173
    else:
174
        plt.show()
175

    
176

    
177
def clustering(data, plot=False, x=None, y=None):
178
    """Directs the clustering process by calling the relevant functions.
179

180
    @param data: The data to be clustered. It must be stored in vector form
181
    [n_features, n_samples] or in affinity matrix form [n_samples, n_samples],
182
    symmetric and 0 in the main diagonal. (Eg. Euclidean distances between
183
    points, RMSD Matrix, etc.).
184
    @param plot: bool, Whether to plot the clustered data.
185
    @param x: Necessary only if plot is turned on. X values to plot the data.
186
    @param y: Necessary only if plot is turned on. Y values to plot the data.
187
    @return: list of exemplars, list of indices (relative to data)
188
    exemplars for every cluster.
189
    """
190
    from collections.abc import Iterable
191

    
192
    data_err = "Data must be stored in vector form [n_features, n_samples] or" \
193
               "in affinity matrix form [n_samples, n_samples]: symmetric " \
194
               "and 0 in the main diagonal. Eg. RMSD matrix"
195
    debug_err = "On debug mode x and y should be provided"
196

    
197
    if plot and not (isinstance(x, Iterable) and isinstance(y, Iterable)):
198
        logger.error(debug_err)
199
        raise ValueError(debug_err)
200
    if not isinstance(data, np.ndarray):
201
        data = np.array(data)
202
    if len(data.shape) != 2:
203
        logger.error(data_err)
204
        raise ValueError(data_err)
205

    
206
    if data.shape[0] == data.shape[1] \
207
            and (np.tril(data).T == np.triu(data)).all():
208
        logger.info("Clustering using affinity matrix.")
209
        labels = get_labels_affty(data)
210
        if max(labels) == -1:
211
            logger.warning('Clustering of conformers did not converge. Try '
212
                           "setting a smaller 'min_samples' parameter.")
213
            exemplars = list(range(data.shape[0]))
214
        else:
215
            clusters = get_clusters(labels)
216
            exemplars = get_exemplars_affty(data, clusters)
217
        if plot:
218
            plot_clusters(labels, x, y, exemplars, save=True)
219
        logger.info(f'Conformers are grouped in {len(exemplars)} clusters.')
220
        return exemplars
221
    else:
222
        not_impl_err = 'Clustering not yet implemented for vectorized data'
223
        logger.error(not_impl_err)
224
        raise NotImplementedError(not_impl_err)