dockonsurf / modules / clustering.py @ 082685ad
Historique | Voir | Annoter | Télécharger (8,6 ko)
1 | 75325cc3 | Carles | """Functions to cluster structures.
|
---|---|---|---|
2 | 75325cc3 | Carles |
|
3 | 75325cc3 | Carles | functions:
|
4 | 75325cc3 | Carles | get_labels_affty: Clusters data in affinity matrix form by assigning labels to
|
5 | 75325cc3 | Carles | data points.
|
6 | 75325cc3 | Carles | get_labels_vector: Clusters data in vectorial form by assigning labels to
|
7 | 75325cc3 | Carles | data points.
|
8 | 75325cc3 | Carles | get_clusters: Groups data-points belonging to the same cluster into arrays of
|
9 | 75325cc3 | Carles | indices.
|
10 | 75325cc3 | Carles | get_exemplars_affty: Computes the exemplars for every cluster and returns a list
|
11 | 75325cc3 | Carles | of indices.
|
12 | 75325cc3 | Carles | plot_clusters: Plots the clustered data casting a color to every cluster.
|
13 | 75325cc3 | Carles | clustering: Directs the clustering process by calling the relevant functions.
|
14 | 75325cc3 | Carles | """
|
15 | 75325cc3 | Carles | import logging |
16 | 75325cc3 | Carles | |
17 | 75325cc3 | Carles | import hdbscan |
18 | 75325cc3 | Carles | import numpy as np |
19 | 75325cc3 | Carles | |
20 | 75325cc3 | Carles | logger = logging.getLogger('DockOnSurf')
|
21 | 75325cc3 | Carles | |
22 | 75325cc3 | Carles | |
23 | fff3aab1 | Carles | def get_rmsd(mol_list: list, remove_Hs="c"): |
24 | fff3aab1 | Carles | """Computes the rmsd matrix of the conformers in a rdkit mol object.
|
25 | fff3aab1 | Carles |
|
26 | fff3aab1 | Carles | @param mol_list: list of rdkit mol objects containing the conformers.
|
27 | fff3aab1 | Carles | @param remove_Hs: bool or str,
|
28 | fff3aab1 | Carles | @return rmsd_matrix: Matrix containing the rmsd values of every pair of
|
29 | fff3aab1 | Carles | conformers.
|
30 | fff3aab1 | Carles |
|
31 | fff3aab1 | Carles | The RMSD values of every pair of conformers is computed, stored in matrix
|
32 | fff3aab1 | Carles | form and returned back. The calculation of rmsd values can take into
|
33 | fff3aab1 | Carles | account all hydrogens, none, or only the ones not linked to carbon atoms.
|
34 | fff3aab1 | Carles | """
|
35 | fff3aab1 | Carles | import rdkit.Chem.AllChem as Chem |
36 | fff3aab1 | Carles | |
37 | fff3aab1 | Carles | if len(mol_list) < 2: |
38 | fff3aab1 | Carles | err = "The provided molecule has less than 2 conformers"
|
39 | fff3aab1 | Carles | logger.error(err) |
40 | fff3aab1 | Carles | raise ValueError(err) |
41 | fff3aab1 | Carles | |
42 | fff3aab1 | Carles | if not remove_Hs: |
43 | fff3aab1 | Carles | pass
|
44 | fff3aab1 | Carles | elif remove_Hs or remove_Hs.lower() == "all": |
45 | fff3aab1 | Carles | mol_list = [Chem.RemoveHs(mol) for mol in mol_list] |
46 | fff3aab1 | Carles | elif remove_Hs.lower() == "c": |
47 | c492296f | Carles Martí | from modules.isolated import remove_C_linked_Hs |
48 | fff3aab1 | Carles | mol_list = [remove_C_linked_Hs(mol) for mol in mol_list] |
49 | fff3aab1 | Carles | else:
|
50 | fff3aab1 | Carles | err = "remove_Hs value does not have an acceptable value"
|
51 | fff3aab1 | Carles | logger.error(err) |
52 | fff3aab1 | Carles | raise ValueError(err) |
53 | fff3aab1 | Carles | |
54 | fff3aab1 | Carles | num_confs = len(mol_list)
|
55 | fff3aab1 | Carles | conf_ids = list(range(num_confs)) |
56 | fff3aab1 | Carles | rmsd_mtx = np.zeros((num_confs, num_confs)) |
57 | f8c4eafe | Carles | for id1 in conf_ids: # TODO reduce RMSD precision |
58 | fff3aab1 | Carles | for id2 in conf_ids[id1 + 1:]: |
59 | fff3aab1 | Carles | rmsd = Chem.GetBestRMS(mol_list[id1], mol_list[id2]) |
60 | fff3aab1 | Carles | rmsd_mtx[id1][id2] = rmsd |
61 | fff3aab1 | Carles | rmsd_mtx[id2][id1] = rmsd |
62 | fff3aab1 | Carles | |
63 | fff3aab1 | Carles | return rmsd_mtx
|
64 | fff3aab1 | Carles | |
65 | fff3aab1 | Carles | |
66 | 75325cc3 | Carles | def get_labels_affty(affty_mtx, kind="rmsd"): |
67 | 75325cc3 | Carles | """Clusters data in affinity matrix form by assigning labels to data points.
|
68 | 75325cc3 | Carles |
|
69 | 75325cc3 | Carles | @param affty_mtx: Data to be clustered, it must be an affinity matrix.
|
70 | 75325cc3 | Carles | (Eg. Euclidean distances between points, RMSD Matrix, etc.).
|
71 | 75325cc3 | Carles | Shape: [n_points, n_points]
|
72 | 75325cc3 | Carles | @param kind: Which kind of data the affinity matrix contains.
|
73 | 75325cc3 | Carles | @return: list of cluster labels. Every data point is assigned a number
|
74 | 75325cc3 | Carles | corresponding to the cluster it belongs to.
|
75 | 75325cc3 | Carles | """
|
76 | 75325cc3 | Carles | if np.average(affty_mtx) < 1e-3 and kind == "rmsd": |
77 | 75325cc3 | Carles | sing_clust = True
|
78 | 75325cc3 | Carles | min_size = int(len(affty_mtx) / 2) |
79 | 75325cc3 | Carles | else:
|
80 | 75325cc3 | Carles | sing_clust = False
|
81 | 75325cc3 | Carles | min_size = 20
|
82 | 75325cc3 | Carles | hdbs = hdbscan.HDBSCAN(metric="precomputed",
|
83 | 75325cc3 | Carles | min_samples=5,
|
84 | 75325cc3 | Carles | min_cluster_size=min_size, |
85 | 75325cc3 | Carles | allow_single_cluster=sing_clust) |
86 | 75325cc3 | Carles | return hdbs.fit_predict(affty_mtx)
|
87 | 75325cc3 | Carles | |
88 | 75325cc3 | Carles | |
89 | 75325cc3 | Carles | def get_labels_vector(): |
90 | 75325cc3 | Carles | """Clusters data in vectorial form by assigning labels to data points.
|
91 | 75325cc3 | Carles |
|
92 | 75325cc3 | Carles | @return: list of cluster labels. Every data point is assigned a number
|
93 | 75325cc3 | Carles | corresponding to the cluster it belongs to.
|
94 | 75325cc3 | Carles | """
|
95 | 75325cc3 | Carles | return []
|
96 | 75325cc3 | Carles | |
97 | 75325cc3 | Carles | |
98 | 75325cc3 | Carles | def get_clusters(labels): |
99 | 75325cc3 | Carles | """Groups data-points belonging to the same cluster into arrays of indices.
|
100 | 75325cc3 | Carles |
|
101 | 75325cc3 | Carles | @param labels: list of cluster labels (numbers) corresponding to the cluster
|
102 | 75325cc3 | Carles | it belongs to.
|
103 | 75325cc3 | Carles | @return: tuple of arrays. Every array contains the indices (relative to the
|
104 | 75325cc3 | Carles | labels list) of the data points belonging to the same cluster.
|
105 | 75325cc3 | Carles | """
|
106 | 75325cc3 | Carles | n_clusters = max(labels) + 1 |
107 | 75325cc3 | Carles | return tuple(np.where(labels == clust_num)[0] |
108 | 75325cc3 | Carles | for clust_num in range(n_clusters)) |
109 | 75325cc3 | Carles | |
110 | 75325cc3 | Carles | |
111 | 75325cc3 | Carles | def get_exemplars_affty(affty_mtx, clusters): |
112 | 75325cc3 | Carles | """Computes the exemplars for every cluster and returns a list of indices.
|
113 | 75325cc3 | Carles |
|
114 | 75325cc3 | Carles | @param affty_mtx: Data structured in form of affinity matrix. eg. Euclidean
|
115 | 75325cc3 | Carles | distances between points, RMSD Matrix, etc.) shape: [n_points, n_points].
|
116 | 75325cc3 | Carles | @param clusters: tuple of arrays. Every array contains the indices (relative
|
117 | 75325cc3 | Carles | to the affinity matrix) of the data points belonging to the same cluster.
|
118 | 75325cc3 | Carles | @return: list of indices (relative to the affinity matrix) exemplars for
|
119 | 75325cc3 | Carles | every cluster.
|
120 | 75325cc3 | Carles | """
|
121 | 75325cc3 | Carles | from sklearn.cluster import AffinityPropagation |
122 | 75325cc3 | Carles | clust_affty_mtcs = tuple(affty_mtx[np.ix_(clust, clust)]
|
123 | 75325cc3 | Carles | for clust in clusters) |
124 | 75325cc3 | Carles | exemplars = [] |
125 | 75325cc3 | Carles | for i, mtx in enumerate(clust_affty_mtcs): |
126 | 75325cc3 | Carles | pref = -1e6 * np.max(np.abs(mtx))
|
127 | 78fcb188 | Carles Martí | af = AffinityPropagation(affinity='precomputed', preference=pref,
|
128 | 78fcb188 | Carles Martí | damping=0.95, max_iter=2000, |
129 | 78fcb188 | Carles Martí | random_state=None).fit(mtx)
|
130 | 75325cc3 | Carles | exemplars.append(clusters[i][af.cluster_centers_indices_[0]])
|
131 | 75325cc3 | Carles | return exemplars
|
132 | 75325cc3 | Carles | |
133 | 75325cc3 | Carles | |
134 | 75325cc3 | Carles | def plot_clusters(labels, x, y, exemplars=None, save=True): |
135 | 75325cc3 | Carles | """Plots the clustered data casting a color to every cluster.
|
136 | 75325cc3 | Carles |
|
137 | 75325cc3 | Carles | @param labels: list of cluster labels (numbers) corresponding to the cluster
|
138 | 75325cc3 | Carles | it belongs to.
|
139 | 75325cc3 | Carles | @param x: list of data of the x axis.
|
140 | 75325cc3 | Carles | @param y: list of data of the y axis.
|
141 | 75325cc3 | Carles | @param exemplars: list of data point indices (relative to the labels list)
|
142 | 75325cc3 | Carles | considered as cluster exemplars.
|
143 | 75325cc3 | Carles | @param save: bool, Whether to save the generated plot into a file or not.
|
144 | 75325cc3 | Carles | (in the latter case the plot is shown in a new window)
|
145 | 75325cc3 | Carles | """
|
146 | 75325cc3 | Carles | import matplotlib.pyplot as plt |
147 | 75325cc3 | Carles | from matplotlib import cm, colors |
148 | 75325cc3 | Carles | |
149 | 75325cc3 | Carles | n_clusters = max(labels) + 1 |
150 | 75325cc3 | Carles | rb = cm.get_cmap('gist_rainbow', max(n_clusters, 1)) |
151 | 75325cc3 | Carles | rb.set_under() |
152 | 75325cc3 | Carles | plt.figure(figsize=(10, 8)) |
153 | 75325cc3 | Carles | for i in range(len(labels)): |
154 | 75325cc3 | Carles | plt.plot(x[i], y[i], c=rb(labels[i]), marker='.')
|
155 | f8c4eafe | Carles | if len(exemplars) > 0 and i == exemplars[labels[i]]: |
156 | 75325cc3 | Carles | plt.plot(x[i], y[i], c=rb(labels[i]), marker="x",
|
157 | 75325cc3 | Carles | markersize=15,
|
158 | 75325cc3 | Carles | label=f"Exemplar cluster {labels[i]}")
|
159 | 75325cc3 | Carles | plt.title(f'Found {n_clusters} Clusters.')
|
160 | 75325cc3 | Carles | plt.xlabel("Energy")
|
161 | 75325cc3 | Carles | plt.ylabel("MOI")
|
162 | 75325cc3 | Carles | plt.legend() |
163 | 75325cc3 | Carles | |
164 | 75325cc3 | Carles | bounds = list(range(max(n_clusters, 1))) |
165 | 75325cc3 | Carles | norm = colors.Normalize(vmin=min(labels), vmax=max(labels)) |
166 | 75325cc3 | Carles | plt.colorbar(cm.ScalarMappable(norm=norm, cmap=rb), ticks=bounds) |
167 | 75325cc3 | Carles | if save:
|
168 | 05464650 | Carles Marti | from modules.utilities import check_bak |
169 | 05464650 | Carles Marti | check_bak('clusters.png')
|
170 | 05464650 | Carles Marti | plt.savefig('clusters.png')
|
171 | 75325cc3 | Carles | plt.close("all")
|
172 | 75325cc3 | Carles | else:
|
173 | 75325cc3 | Carles | plt.show() |
174 | 75325cc3 | Carles | |
175 | 75325cc3 | Carles | |
176 | 05464650 | Carles Marti | def clustering(data, plot=False, x=None, y=None): |
177 | 75325cc3 | Carles | """Directs the clustering process by calling the relevant functions.
|
178 | 75325cc3 | Carles |
|
179 | 75325cc3 | Carles | @param data: The data to be clustered. It must be stored in vector form
|
180 | 75325cc3 | Carles | [n_features, n_samples] or in affinity matrix form [n_samples, n_samples],
|
181 | 75325cc3 | Carles | symmetric and 0 in the main diagonal. (Eg. Euclidean distances between
|
182 | 75325cc3 | Carles | points, RMSD Matrix, etc.).
|
183 | 05464650 | Carles Marti | @param plot: bool, Whether to plot the clustered data.
|
184 | 05464650 | Carles Marti | @param x: Necessary only if plot is turned on. X values to plot the data.
|
185 | 05464650 | Carles Marti | @param y: Necessary only if plot is turned on. Y values to plot the data.
|
186 | 75325cc3 | Carles | @return: list of exemplars, list of indices (relative to data)
|
187 | 75325cc3 | Carles | exemplars for every cluster.
|
188 | 75325cc3 | Carles | """
|
189 | 75325cc3 | Carles | from collections.abc import Iterable |
190 | 75325cc3 | Carles | |
191 | 75325cc3 | Carles | data_err = "Data must be stored in vector form [n_features, n_samples] or" \
|
192 | 75325cc3 | Carles | "in affinity matrix form [n_samples, n_samples]: symmetric " \
|
193 | 75325cc3 | Carles | "and 0 in the main diagonal. Eg. RMSD matrix"
|
194 | 75325cc3 | Carles | debug_err = "On debug mode x and y should be provided"
|
195 | 75325cc3 | Carles | |
196 | 05464650 | Carles Marti | if plot and not (isinstance(x, Iterable) and isinstance(y, Iterable)): |
197 | 75325cc3 | Carles | logger.error(debug_err) |
198 | 75325cc3 | Carles | raise ValueError(debug_err) |
199 | 75325cc3 | Carles | if not isinstance(data, np.ndarray): |
200 | 75325cc3 | Carles | data = np.array(data) |
201 | 75325cc3 | Carles | if len(data.shape) != 2: |
202 | 75325cc3 | Carles | logger.error(data_err) |
203 | 75325cc3 | Carles | raise ValueError(data_err) |
204 | 75325cc3 | Carles | |
205 | 75325cc3 | Carles | if data.shape[0] == data.shape[1] \ |
206 | 75325cc3 | Carles | and (np.tril(data).T == np.triu(data)).all():
|
207 | 695dcff8 | Carles Marti | logger.info("Clustering using affinity matrix.")
|
208 | 75325cc3 | Carles | labels = get_labels_affty(data) |
209 | db7349af | Carles | if max(labels) == -1: |
210 | d0939bb2 | Carles Marti | logger.warning('Clustering of conformers did not converge. Try '
|
211 | 695dcff8 | Carles Marti | "setting a smaller 'min_samples' parameter.")
|
212 | d0939bb2 | Carles Marti | exemplars = list(range(data.shape[0])) |
213 | d0939bb2 | Carles Marti | else:
|
214 | d0939bb2 | Carles Marti | clusters = get_clusters(labels) |
215 | d0939bb2 | Carles Marti | exemplars = get_exemplars_affty(data, clusters) |
216 | 05464650 | Carles Marti | if plot:
|
217 | 75325cc3 | Carles | plot_clusters(labels, x, y, exemplars, save=True)
|
218 | 680944e4 | Carles | logger.info(f'Conformers are grouped in {len(exemplars)} clusters.')
|
219 | 75325cc3 | Carles | return exemplars
|
220 | 75325cc3 | Carles | else:
|
221 | f8c4eafe | Carles | not_impl_err = 'Clustering not yet implemented for vectorized data'
|
222 | f8c4eafe | Carles | logger.error(not_impl_err) |
223 | f8c4eafe | Carles | raise NotImplementedError(not_impl_err) |