Statistiques
| Branche: | Tag: | Révision :

dockonsurf / modules / clustering.py @ 5412d6ef

Historique | Voir | Annoter | Télécharger (8,74 ko)

1
"""Functions to cluster structures.
2

3
functions:
get_rmsd: Computes the RMSD matrix of the conformers in a list of rdkit mol
objects.
get_labels_affty: Clusters data in affinity matrix form by assigning labels to
data points.
get_labels_vector: Clusters data in vectorial form by assigning labels to
data points.
get_clusters: Groups data-points belonging to the same cluster into arrays of
indices.
get_exemplars_affty: Computes the exemplars for every cluster and returns a list
of indices.
plot_clusters: Plots the clustered data casting a color to every cluster.
clustering: Directs the clustering process by calling the relevant functions.
14
"""
15
import logging
16
import warnings
17

    
18
import hdbscan
19
import numpy as np
20

    
21
# Module-level logger, registered under the name 'DockOnSurf'.
logger = logging.getLogger('DockOnSurf')
22

    
23

    
24
def get_rmsd(mol_list: list, remove_Hs="c"):
    """Computes the rmsd matrix of the conformers in a list of rdkit mol objects.

    @param mol_list: list of rdkit mol objects containing the conformers.
    @param remove_Hs: bool or str, which hydrogens are discarded before the
    rmsd calculation. A falsy value keeps every hydrogen; "all" (or any truthy
    non-string value such as True) removes every hydrogen; "c" removes only
    the hydrogens handled by remove_C_linked_Hs (the ones linked to carbon
    atoms). String values are case-insensitive.
    @return rmsd_matrix: Matrix containing the rmsd values of every pair of
    conformers. Shape: [n_conformers, n_conformers], symmetric and with zeros
    in the main diagonal.
    @raise ValueError: if less than 2 conformers are provided or if remove_Hs
    is a string other than "all" or "c".

    The RMSD values of every pair of conformers is computed, stored in matrix
    form and returned back. The calculation of rmsd values can take into
    account all hydrogens, none, or only the ones not linked to carbon atoms.
    """
    import rdkit.Chem.AllChem as Chem

    if len(mol_list) < 2:
        err = "The provided molecule has less than 2 conformers"
        logger.error(err)
        raise ValueError(err)

    # Normalise string options once so every comparison is case-insensitive.
    # Non-string values are kept as they are and judged by truthiness.
    option = remove_Hs.lower() if isinstance(remove_Hs, str) else remove_Hs

    if not option:
        pass  # Keep every hydrogen.
    elif option == "all" or not isinstance(option, str):
        # The string check must come before the plain truthiness check:
        # every non-empty string is truthy, which would make the "c" branch
        # and the error branch below unreachable.
        mol_list = [Chem.RemoveHs(mol) for mol in mol_list]
    elif option == "c":
        # NOTE(review): other local imports in this file use the `modules.`
        # package prefix; confirm that `isolated` is importable as written.
        from isolated import remove_C_linked_Hs
        mol_list = [remove_C_linked_Hs(mol) for mol in mol_list]
    else:
        err = "remove_Hs value does not have an acceptable value"
        logger.error(err)
        raise ValueError(err)

    num_confs = len(mol_list)
    rmsd_mtx = np.zeros((num_confs, num_confs))
    # Only the upper triangle is computed; the matrix is filled symmetrically.
    for id1 in range(num_confs):  # TODO reduce RMSD precision
        for id2 in range(id1 + 1, num_confs):
            rmsd = Chem.GetBestRMS(mol_list[id1], mol_list[id2])
            rmsd_mtx[id1][id2] = rmsd
            rmsd_mtx[id2][id1] = rmsd

    return rmsd_mtx
65

    
66

    
67
def get_labels_affty(affty_mtx, kind="rmsd"):
    """Assigns a cluster label to every data point of an affinity matrix.

    @param affty_mtx: Data to be clustered, it must be an affinity matrix.
    (Eg. Euclidean distances between points, RMSD Matrix, etc.).
    Shape: [n_points, n_points]
    @param kind: Which kind of data the affinity matrix contains.
    @return: list of cluster labels. Every data point is assigned a number
    corresponding to the cluster it belongs to.
    """
    # An rmsd matrix whose average value is below 1e-3 is clustered allowing
    # a single cluster, with half of the points as the minimum cluster size.
    mean_affty = np.average(affty_mtx)
    one_cluster = bool(mean_affty < 1e-3) and kind == "rmsd"
    smallest_cluster = int(len(affty_mtx) / 2) if one_cluster else 20

    clusterer = hdbscan.HDBSCAN(
        metric="precomputed",
        min_samples=5,
        min_cluster_size=smallest_cluster,
        allow_single_cluster=one_cluster,
    )
    return clusterer.fit_predict(affty_mtx)
88

    
89

    
90
def get_labels_vector():
    """Assigns a cluster label to every data point of data in vectorial form.

    Currently a placeholder: it takes no data and always yields an empty list.

    @return: list of cluster labels. Every data point is assigned a number
    corresponding to the cluster it belongs to.
    """
    labels = []
    return labels
97

    
98

    
99
def get_clusters(labels):
    """Groups data-points belonging to the same cluster into arrays of indices.

    @param labels: list (or array) of cluster labels (numbers) corresponding to
    the cluster every data point belongs to. Negative labels are not assigned
    to any cluster.
    @return: tuple of arrays. Every array contains the indices (relative to the
    labels list) of the data points belonging to the same cluster. The tuple is
    empty when labels is empty or when no label is 0 or larger.
    """
    # Comparing a plain Python list against a number yields a single False
    # instead of an element-wise mask, so the labels are converted first.
    labels = np.asarray(labels)
    if labels.size == 0:
        return ()
    n_clusters = int(labels.max()) + 1
    return tuple(np.where(labels == clust_num)[0]
                 for clust_num in range(n_clusters))
110

    
111

    
112
def get_exemplars_affty(affty_mtx, clusters):
    """Computes the exemplars for every cluster and returns a list of indices.

    @param affty_mtx: Data structured in form of affinity matrix. eg. Euclidean
    distances between points, RMSD Matrix, etc.) shape: [n_points, n_points].
    @param clusters: tuple of arrays. Every array contains the indices (relative
    to the affinity matrix) of the data points belonging to the same cluster.
    @return: list of indices (relative to the affinity matrix) exemplars for
    every cluster. One exemplar per cluster, in the same order as clusters.

    Warnings issued while fitting are forwarded to the module logger. If the
    fit yields no exemplar for a cluster, the first member of that cluster is
    used instead so that the result always has one entry per cluster.
    """
    from sklearn.cluster import AffinityPropagation

    exemplars = []
    for clust in clusters:
        # Sub-matrix holding only the affinities among members of the cluster.
        mtx = affty_mtx[np.ix_(clust, clust)]
        # A very negative preference, used to favour a single exemplar.
        # NOTE(review): with affinity='precomputed' scikit-learn interprets
        # the matrix as similarities (larger = more alike) while the docstring
        # mentions distance-like data; confirm the sign convention.
        pref = -1e6 * np.max(np.abs(mtx))
        af = AffinityPropagation(affinity='precomputed', preference=pref,
                                 damping=0.95, max_iter=2000,
                                 random_state=None)
        # Record warnings locally instead of turning them into errors
        # process-wide: the global filters stay untouched, the fit is not
        # aborted half-way and every warning category reaches the logger.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            af.fit(mtx)
        for warn in caught:
            logger.warning(str(warn.message))

        centers = getattr(af, "cluster_centers_indices_", ())
        if len(centers) > 0:
            exemplars.append(clust[centers[0]])
        else:
            logger.warning("No exemplar could be determined for a cluster; "
                           "using its first member instead.")
            exemplars.append(clust[0])
    return exemplars
137

    
138

    
139
def plot_clusters(labels, x, y, exemplars=None, save=True):
    """Plots the clustered data casting a color to every cluster.

    @param labels: list of cluster labels (numbers) corresponding to the cluster
    it belongs to.
    @param x: list of data of the x axis.
    @param y: list of data of the y axis.
    @param exemplars: list of data point indices (relative to the labels list)
    considered as cluster exemplars, indexed by cluster label. If None, no
    exemplar is highlighted.
    @param save: bool, Whether to save the generated plot into a file or not.
    (in the latter case the plot is shown in a new window)
    """
    import matplotlib.pyplot as plt
    from matplotlib import cm, colors

    if exemplars is None:
        exemplars = []

    n_clusters = max(labels) + 1
    # pyplot.get_cmap is used instead of matplotlib.cm.get_cmap, which was
    # removed in Matplotlib 3.9.
    rb = plt.get_cmap('gist_rainbow', max(n_clusters, 1))
    rb.set_under()
    plt.figure(figsize=(10, 8))
    for i, label in enumerate(labels):
        plt.plot(x[i], y[i], c=rb(label), marker='.')
        # Only labels that index into exemplars can own one; a negative label
        # would otherwise silently index exemplars from its end.
        if 0 <= label < len(exemplars) and i == exemplars[label]:
            plt.plot(x[i], y[i], c=rb(label), marker="x",
                     markersize=15,
                     label=f"Exemplar cluster {label}")
    plt.title(f'Found {n_clusters} Clusters.')
    plt.xlabel("Energy")
    plt.ylabel("MOI")
    plt.legend()

    bounds = list(range(max(n_clusters, 1)))
    norm = colors.Normalize(vmin=min(labels), vmax=max(labels))
    # The ScalarMappable is not attached to any Axes, so the Axes the colorbar
    # takes its space from has to be given explicitly.
    plt.colorbar(cm.ScalarMappable(norm=norm, cmap=rb), ax=plt.gca(),
                 ticks=bounds)
    if save:
        from modules.utilities import check_bak
        check_bak('clusters.png')
        plt.savefig('clusters.png')
        plt.close("all")
    else:
        plt.show()
179

    
180

    
181
def clustering(data, plot=False, x=None, y=None):
    """Directs the clustering process by calling the relevant functions.

    @param data: The data to be clustered. It must be stored in vector form
    [n_features, n_samples] or in affinity matrix form [n_samples, n_samples],
    symmetric and 0 in the main diagonal. (Eg. Euclidean distances between
    points, RMSD Matrix, etc.).
    @param plot: bool, Whether to plot the clustered data.
    @param x: Necessary only if plot is turned on. X values to plot the data.
    @param y: Necessary only if plot is turned on. Y values to plot the data.
    @return: list of exemplars, list of indices (relative to data)
    exemplars for every cluster. If no cluster is found, every data point is
    returned as its own exemplar.
    @raise ValueError: if plot is on and x or y are not iterables, or if data
    is not two-dimensional.
    @raise NotImplementedError: if data is not a square symmetric matrix
    (clustering of vectorized data is not implemented).
    """
    from collections.abc import Iterable

    # The first literal ends with a space so the joined message reads
    # "... or in affinity matrix form ..." rather than "... orin ...".
    data_err = ("Data must be stored in vector form [n_features, n_samples] "
                "or in affinity matrix form [n_samples, n_samples]: symmetric "
                "and 0 in the main diagonal. Eg. RMSD matrix")
    debug_err = "On debug mode x and y should be provided"

    if plot and not (isinstance(x, Iterable) and isinstance(y, Iterable)):
        logger.error(debug_err)
        raise ValueError(debug_err)
    if not isinstance(data, np.ndarray):
        data = np.array(data)
    if len(data.shape) != 2:
        logger.error(data_err)
        raise ValueError(data_err)

    # Only square symmetric matrices are treated as affinity matrices.
    is_square = data.shape[0] == data.shape[1]
    if not (is_square and (np.tril(data).T == np.triu(data)).all()):
        not_impl_err = 'Clustering not yet implemented for vectorized data'
        logger.error(not_impl_err)
        raise NotImplementedError(not_impl_err)

    logger.info("Clustering using affinity matrix.")
    labels = get_labels_affty(data)
    if max(labels) == -1:
        # Every label is -1: no cluster was found, keep all data points.
        logger.warning('Clustering of conformers did not converge. Try '
                       "setting a smaller 'min_samples' parameter.")
        exemplars = list(range(data.shape[0]))
    else:
        clusters = get_clusters(labels)
        exemplars = get_exemplars_affty(data, clusters)
    if plot:
        plot_clusters(labels, x, y, exemplars, save=True)
    logger.info(f'Conformers are grouped in {len(exemplars)} clusters.')
    return exemplars