dockonsurf / modules / clustering.py @ 5412d6ef
Historique | Voir | Annoter | Télécharger (8,74 ko)
1 |
"""Functions to cluster structures.
|
---|---|
2 |
|
3 |
functions:
|
4 |
get_labels_affty: Clusters data in affinity matrix form by assigning labels to
|
5 |
data points.
|
6 |
get_labels_vector: Clusters data in vectorial form by assigning labels to
|
7 |
data points.
|
8 |
get_clusters: Groups data-points belonging to the same cluster into arrays of
|
9 |
indices.
|
10 |
get_exemplars_affty: Computes the exemplars for every cluster and returns a list
|
11 |
of indices.
|
12 |
plot_clusters: Plots the clustered data, assigning a color to every cluster.
|
13 |
clustering: Directs the clustering process by calling the relevant functions.
|
14 |
"""
|
15 |
import logging |
16 |
import warnings |
17 |
|
18 |
import hdbscan |
19 |
import numpy as np |
20 |
|
21 |
logger = logging.getLogger('DockOnSurf')
|
22 |
|
23 |
|
24 |
def get_rmsd(mol_list: list, remove_Hs="c"):
    """Computes the rmsd matrix of the conformers in a rdkit mol object.

    @param mol_list: list of rdkit mol objects containing the conformers.
    @param remove_Hs: bool or str, which hydrogens are stripped before the
        RMSD calculation:
        - any falsy value (False, None, ""): all hydrogens are kept.
        - True (or any other truthy non-string) or "all": all hydrogens are
          removed.
        - "c": only the hydrogens linked to carbon atoms are removed.
        String values are case-insensitive.
    @return rmsd_matrix: Matrix containing the rmsd values of every pair of
        conformers. Shape: [n_conformers, n_conformers], symmetric and 0 in
        the main diagonal.
    @raise ValueError: if less than 2 conformers are provided or if remove_Hs
        is a string other than "all" or "c".

    The RMSD values of every pair of conformers is computed, stored in matrix
    form and returned back. The calculation of rmsd values can take into
    account all hydrogens, none, or only the ones not linked to carbon atoms.
    """
    import rdkit.Chem.AllChem as Chem

    if len(mol_list) < 2:
        err = "The provided molecule has less than 2 conformers"
        logger.error(err)
        raise ValueError(err)

    # Normalise the argument first. Testing the bare truthiness of remove_Hs
    # before comparing strings would make every non-empty string (including
    # the default "c") behave like "all".
    if not remove_Hs:
        h_mode = None
    elif isinstance(remove_Hs, str):
        h_mode = remove_Hs.lower()
    else:
        h_mode = "all"

    if h_mode is None:
        pass
    elif h_mode == "all":
        mol_list = [Chem.RemoveHs(mol) for mol in mol_list]
    elif h_mode == "c":
        # NOTE(review): module path kept as in the original code; confirm that
        # 'isolated' is importable from where this module is executed.
        from isolated import remove_C_linked_Hs
        mol_list = [remove_C_linked_Hs(mol) for mol in mol_list]
    else:
        err = "remove_Hs value does not have an acceptable value"
        logger.error(err)
        raise ValueError(err)

    num_confs = len(mol_list)
    rmsd_mtx = np.zeros((num_confs, num_confs))
    # Only the upper triangle is computed; the value is mirrored to keep the
    # matrix symmetric. The diagonal stays at 0.
    for id1 in range(num_confs):  # TODO reduce RMSD precision
        for id2 in range(id1 + 1, num_confs):
            rmsd = Chem.GetBestRMS(mol_list[id1], mol_list[id2])
            rmsd_mtx[id1][id2] = rmsd
            rmsd_mtx[id2][id1] = rmsd

    return rmsd_mtx
|
65 |
|
66 |
|
67 |
def get_labels_affty(affty_mtx, kind="rmsd"):
    """Assigns a cluster label to every data point of an affinity matrix.

    @param affty_mtx: Data to be clustered, it must be an affinity matrix.
        (Eg. Euclidean distances between points, RMSD Matrix, etc.).
        Shape: [n_points, n_points]
    @param kind: Which kind of data the affinity matrix contains.
    @return: list of cluster labels. Every data point is assigned a number
        corresponding to the cluster it belongs to.

    The clustering is carried out with HDBSCAN on the precomputed matrix. When
    the matrix is of kind "rmsd" and its average value is below 1e-3, a single
    cluster is allowed and the minimum cluster size is set to half the number
    of points; otherwise the minimum cluster size is 20.
    """
    hdbscan_opts = {"metric": "precomputed", "min_samples": 5}
    if np.average(affty_mtx) < 1e-3 and kind == "rmsd":
        hdbscan_opts["min_cluster_size"] = len(affty_mtx) // 2
        hdbscan_opts["allow_single_cluster"] = True
    else:
        hdbscan_opts["min_cluster_size"] = 20
        hdbscan_opts["allow_single_cluster"] = False
    clusterer = hdbscan.HDBSCAN(**hdbscan_opts)
    return clusterer.fit_predict(affty_mtx)
|
88 |
|
89 |
|
90 |
def get_labels_vector():
    """Assigns a cluster label to every data point of data in vectorial form.

    This function takes no data and performs no clustering: it always returns
    an empty list.

    @return: list of cluster labels. Every data point is assigned a number
        corresponding to the cluster it belongs to. Currently always empty.
    """
    no_labels = list()
    return no_labels
|
97 |
|
98 |
|
99 |
def get_clusters(labels):
    """Groups data-points belonging to the same cluster into arrays of indices.

    @param labels: list (or array) of cluster labels (numbers) corresponding
        to the cluster every data point belongs to. Negative labels are not
        assigned to any cluster.
    @return: tuple of arrays. Every array contains the indices (relative to the
        labels list) of the data points belonging to the same cluster. The
        array at position k holds the points labelled k. An empty tuple is
        returned when there are no labels or no label is >= 0.
    """
    # Coerce to an array: with a plain Python list, `labels == clust_num`
    # would be a single False instead of an element-wise comparison.
    labels = np.asarray(labels)
    if labels.size == 0:
        return ()
    n_clusters = int(labels.max()) + 1
    return tuple(np.where(labels == clust_num)[0]
                 for clust_num in range(n_clusters))
110 |
|
111 |
|
112 |
def get_exemplars_affty(affty_mtx, clusters):
    """Computes the exemplars for every cluster and returns a list of indices.

    @param affty_mtx: Data structured in form of affinity matrix. eg. Euclidean
        distances between points, RMSD Matrix, etc.) shape: [n_points, n_points].
    @param clusters: tuple of arrays. Every array contains the indices (relative
        to the affinity matrix) of the data points belonging to the same cluster.
    @return: list of indices (relative to the affinity matrix) exemplars for
        every cluster. The exemplar at position k belongs to clusters[k].

    For every cluster the corresponding sub-matrix is extracted and fitted with
    sklearn's AffinityPropagation; the first reported cluster center is taken
    as the exemplar. Warnings emitted during the fit are forwarded to the
    logger. If the fit reports no cluster center, the first member of the
    cluster is used as exemplar.
    """
    from sklearn.cluster import AffinityPropagation

    exemplars = []
    for clust in clusters:
        mtx = affty_mtx[np.ix_(clust, clust)]
        pref = -1e6 * np.max(np.abs(mtx))
        # Record warnings inside a context manager so that the process-wide
        # warning filters are left untouched and the fit is never aborted
        # half-way (which would leave no fitted estimator to read from).
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            af = AffinityPropagation(affinity='precomputed', preference=pref,
                                     damping=0.95, max_iter=2000,
                                     random_state=None).fit(mtx)
        for warn in caught:
            logger.warning(str(warn.message))

        centers = af.cluster_centers_indices_
        if len(centers) > 0:
            exemplars.append(clust[centers[0]])
        else:
            # Keep one exemplar per cluster so that positions in the returned
            # list still match the cluster numbering.
            logger.warning("No exemplar could be determined for a cluster; "
                           "using its first member instead.")
            exemplars.append(clust[0])
    return exemplars
|
137 |
|
138 |
|
139 |
def plot_clusters(labels, x, y, exemplars=None, save=True):
    """Plots the clustered data assigning a color to every cluster.

    @param labels: list of cluster labels (numbers) corresponding to the cluster
        it belongs to.
    @param x: list of data of the x axis.
    @param y: list of data of the y axis.
    @param exemplars: list of data point indices (relative to the labels list)
        considered as cluster exemplars. Position k holds the exemplar of the
        cluster labelled k. If None or empty, no exemplar is highlighted.
    @param save: bool, Whether to save the generated plot into a file or not.
        (in the latter case the plot is shown in a new window)

    Every data point is drawn as a dot colored by its label; exemplars are
    additionally drawn with a large "x" marker and listed in the legend. When
    saving, the figure is written to 'clusters.png' after calling check_bak on
    that file name.
    """
    import matplotlib.pyplot as plt
    from matplotlib import cm, colors

    if exemplars is None:
        exemplars = []

    n_clusters = max(labels) + 1
    # pyplot.get_cmap is used because cm.get_cmap is not available in recent
    # matplotlib releases.
    rb = plt.get_cmap('gist_rainbow', max(n_clusters, 1))
    rb.set_under()
    plt.figure(figsize=(10, 8))
    has_exemplar_marker = False
    for i, label in enumerate(labels):
        plt.plot(x[i], y[i], c=rb(label), marker='.')
        # Negative labels do not belong to any cluster: they must not be
        # looked up in exemplars (a negative index would wrap around).
        if 0 <= label < len(exemplars) and i == exemplars[label]:
            plt.plot(x[i], y[i], c=rb(label), marker="x",
                     markersize=15,
                     label=f"Exemplar cluster {label}")
            has_exemplar_marker = True
    plt.title(f'Found {n_clusters} Clusters.')
    plt.xlabel("Energy")
    plt.ylabel("MOI")
    if has_exemplar_marker:
        # Avoids matplotlib's "no labelled artists" warning on an empty legend.
        plt.legend()

    bounds = list(range(max(n_clusters, 1)))
    norm = colors.Normalize(vmin=min(labels), vmax=max(labels))
    # The mappable is not attached to any Axes, so the target Axes must be
    # given explicitly.
    plt.colorbar(cm.ScalarMappable(norm=norm, cmap=rb), ticks=bounds,
                 ax=plt.gca())
    if save:
        from modules.utilities import check_bak
        check_bak('clusters.png')
        plt.savefig('clusters.png')
        plt.close("all")
    else:
        plt.show()
179 |
|
180 |
|
181 |
def clustering(data, plot=False, x=None, y=None):
    """Directs the clustering process by calling the relevant functions.

    @param data: The data to be clustered. It must be stored in vector form
        [n_features, n_samples] or in affinity matrix form [n_samples, n_samples],
        symmetric and 0 in the main diagonal. (Eg. Euclidean distances between
        points, RMSD Matrix, etc.).
    @param plot: bool, Whether to plot the clustered data.
    @param x: Necessary only if plot is turned on. X values to plot the data.
    @param y: Necessary only if plot is turned on. Y values to plot the data.
    @return: list of exemplars, list of indices (relative to data)
        exemplars for every cluster. If no cluster is found, every data point
        is returned as its own exemplar.
    @raise ValueError: if plot is on and x or y is not iterable, or if data is
        not two-dimensional.
    @raise NotImplementedError: if data is not a square symmetric matrix
        (vector form is not supported yet).
    """
    from collections.abc import Iterable

    # Trailing space after "or" is required: adjacent literals are joined
    # without any separator.
    data_err = "Data must be stored in vector form [n_features, n_samples] or " \
               "in affinity matrix form [n_samples, n_samples]: symmetric " \
               "and 0 in the main diagonal. Eg. RMSD matrix"
    debug_err = "On debug mode x and y should be provided"

    if plot and not (isinstance(x, Iterable) and isinstance(y, Iterable)):
        logger.error(debug_err)
        raise ValueError(debug_err)
    if not isinstance(data, np.ndarray):
        data = np.array(data)
    if data.ndim != 2:
        logger.error(data_err)
        raise ValueError(data_err)

    # A square matrix equal to its transpose is treated as an affinity matrix;
    # anything else would be vector data, which is not supported.
    is_affinity = data.shape[0] == data.shape[1] \
        and (np.tril(data).T == np.triu(data)).all()
    if not is_affinity:
        not_impl_err = 'Clustering not yet implemented for vectorized data'
        logger.error(not_impl_err)
        raise NotImplementedError(not_impl_err)

    logger.info("Clustering using affinity matrix.")
    labels = get_labels_affty(data)
    if max(labels) == -1:
        # Every label is -1: no cluster was found, keep all the data points.
        logger.warning('Clustering of conformers did not converge. Try '
                       "setting a smaller 'min_samples' parameter.")
        exemplars = list(range(data.shape[0]))
    else:
        clusters = get_clusters(labels)
        exemplars = get_exemplars_affty(data, clusters)
    if plot:
        plot_clusters(labels, x, y, exemplars, save=True)
    logger.info(f'Conformers are grouped in {len(exemplars)} clusters.')
    return exemplars