# Exported from a repository viewer page
# (navigation: Statistics | Branch | Tag | Revision | History | View | Annotate | Download, 27.71 kB)
#
# dockonsurf / modules / calculate_rmsd @ 86112fec

1
#!/usr/bin/env python
# The module description is assigned explicitly because it is reused as the
# command-line help text in main().
__doc__ = \
"""
Calculate Root-mean-square deviation (RMSD) between structure A and B, in XYZ
or PDB format, using transformation and rotation.

For more information, usage, example and citation read more at
https://github.com/charnley/rmsd
"""

__version__ = '1.3.2'
12

    
13
import copy
14
import re
15

    
16
import numpy as np
17
from scipy.optimize import linear_sum_assignment
18
from scipy.spatial.distance import cdist
19

    
20

    
21
# All six permutations of the three coordinate axes; each row is a column
# index order that can be applied as coord[:, row].
AXIS_SWAPS = np.array([
    [0, 1, 2],
    [0, 2, 1],
    [1, 0, 2],
    [1, 2, 0],
    [2, 1, 0],
    [2, 0, 1]])
28

    
29
# All eight sign combinations for the three coordinate axes; each row is a
# per-axis multiplier (+1 keeps the axis, -1 mirrors it).
AXIS_REFLECTIONS = np.array([
    [1, 1, 1],
    [-1, 1, 1],
    [1, -1, 1],
    [1, 1, -1],
    [-1, -1, 1],
    [-1, 1, -1],
    [1, -1, -1],
    [-1, -1, -1]])
38

    
39

    
40
def rmsd(V, W):
    """
    Calculate Root-mean-square deviation from two sets of vectors V and W.

    Parameters
    ----------
    V : array
        (N,D) matrix, where N is points and D is dimension.
    W : array
        (N,D) matrix, where N is points and D is dimension.

    Returns
    -------
    rmsd : float
        Root-mean-square deviation between the two vectors
    """
    # Vectorised form of sum_i sum_d (V[i, d] - W[i, d])**2 / N.
    diff = np.asarray(V, dtype=float) - np.asarray(W, dtype=float)
    return np.sqrt(np.sum(diff * diff) / len(diff))
62

    
63

    
64
def kabsch_rmsd(P, Q, translate=False):
    """
    Rotate matrix P unto Q using Kabsch algorithm and calculate the RMSD.

    Parameters
    ----------
    P : array
        (N,D) matrix, where N is points and D is dimension.
    Q : array
        (N,D) matrix, where N is points and D is dimension.
    translate : bool
        Use centroids to translate vector P and Q unto each other.

    Returns
    -------
    rmsd : float
        root-mean squared deviation
    """
    if translate:
        # Non-in-place subtraction, so the caller's arrays are left untouched.
        Q = Q - centroid(Q)
        P = P - centroid(P)

    return rmsd(kabsch_rotate(P, Q), Q)
88

    
89

    
90
def kabsch_rotate(P, Q):
    """
    Rotate matrix P unto matrix Q using Kabsch algorithm.

    Parameters
    ----------
    P : array
        (N,D) matrix, where N is points and D is dimension.
    Q : array
        (N,D) matrix, where N is points and D is dimension.

    Returns
    -------
    P : array
        (N,D) matrix, where N is points and D is dimension,
        rotated

    """
    # Points are row vectors, so the rotation is applied from the right.
    return np.dot(P, kabsch(P, Q))
113

    
114

    
115
def kabsch(P, Q):
    """
    Using the Kabsch algorithm with two sets of paired point P and Q, centered
    around the centroid. Each vector set is represented as an NxD
    matrix, where D is the dimension of the space.

    The algorithm works in three steps:
    - a centroid translation of P and Q (assumed done before this function
      call)
    - the computation of a covariance matrix C
    - computation of the optimal rotation matrix U

    For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm

    Parameters
    ----------
    P : array
        (N,D) matrix, where N is points and D is dimension.
    Q : array
        (N,D) matrix, where N is points and D is dimension.

    Returns
    -------
    U : matrix
        Rotation matrix (D,D)
    """

    # Computation of the covariance matrix
    C = np.dot(np.transpose(P), Q)

    # Computation of the optimal rotation matrix via singular value
    # decomposition (SVD). The singular values themselves are not needed.
    V, _, W = np.linalg.svd(C)

    # The sign of det(V) * det(W) tells whether the product V.W would be an
    # improper rotation (a reflection). If so, flip the last column of V so
    # that U describes a right-handed coordinate system.
    if (np.linalg.det(V) * np.linalg.det(W)) < 0.0:
        V[:, -1] = -V[:, -1]

    # Create Rotation matrix U
    return np.dot(V, W)
163

    
164

    
165
def quaternion_rmsd(P, Q):
    """
    Rotate matrix P unto Q and calculate the RMSD
    based on doi:10.1016/1049-9660(91)90036-O

    Parameters
    ----------
    P : array
        (N,D) matrix, where N is points and D is dimension.
    Q : array
        (N,D) matrix, where N is points and D is dimension.

    Returns
    -------
    rmsd : float
    """
    rotation = quaternion_rotate(P, Q)
    return rmsd(np.dot(P, rotation), Q)
184

    
185

    
186
def quaternion_transform(r):
    """
    Get optimal rotation
    note: translation will be zero when the centroids of each molecule are the
    same

    Parameters
    ----------
    r : sequence
        Quaternion components, unpacked as (r1, r2, r3[, r4]) into makeW and
        makeQ.

    Returns
    -------
    rot : array
        Upper-left (3,3) block of W(r)^T . Q(r)
    """
    Wt_r = makeW(*r).T
    Q_r = makeQ(*r)
    return Wt_r.dot(Q_r)[:3, :3]
196

    
197

    
198
def makeW(r1, r2, r3, r4=0):
    """
    matrix involved in quaternion rotation

    Returns the (4,4) array built from the components (r1, r2, r3, r4); r4
    defaults to 0 so a plain 3-D coordinate can be passed.
    """
    return np.asarray([
        [r4, r3, -r2, r1],
        [-r3, r4, r1, r2],
        [r2, -r1, r4, r3],
        [-r1, -r2, -r3, r4]])
208

    
209

    
210
def makeQ(r1, r2, r3, r4=0):
    """
    matrix involved in quaternion rotation

    Returns the (4,4) array built from the components (r1, r2, r3, r4); r4
    defaults to 0 so a plain 3-D coordinate can be passed.
    """
    return np.asarray([
        [r4, -r3, r2, r1],
        [r3, r4, -r1, r2],
        [-r2, r1, r4, r3],
        [-r1, -r2, -r3, r4]])
220

    
221

    
222
def quaternion_rotate(X, Y):
    """
    Calculate the rotation

    Parameters
    ----------
    X : array
        (N,D) matrix, where N is points and D is dimension.
    Y: array
        (N,D) matrix, where N is points and D is dimension.

    Returns
    -------
    rot : matrix
        Rotation matrix (D,D)
    """
    # Accumulate Q(x_k)^T . W(y_k) over all paired points.
    A = np.sum(
        [np.dot(makeQ(*x).T, makeW(*y)) for x, y in zip(X, Y)],
        axis=0)

    # The eigenvector belonging to the largest eigenvalue holds the
    # quaternion components that are turned into the rotation matrix.
    eigenvalues, eigenvectors = np.linalg.eigh(A)
    r = eigenvectors[:, eigenvalues.argmax()]
    return quaternion_transform(r)
248

    
249

    
250
def centroid(X):
    """
    Centroid is the mean position of all the points in all of the coordinate
    directions, from a vectorset X.

    https://en.wikipedia.org/wiki/Centroid

    C = sum(X)/len(X)

    Parameters
    ----------
    X : array
        (N,D) matrix, where N is points and D is dimension.

    Returns
    -------
    C : array
        (D,) vector, the mean of X over its first axis
    """
    return X.mean(axis=0)
271

    
272

    
273
def reorder_distance(p_atoms, q_atoms, p_coord, q_coord):
    """
    Re-orders the input atom list and xyz coordinates by atom type and then by
    distance of each atom from the origin (the coordinates are expected to be
    centered on their centroid before this call).

    Parameters
    ----------
    p_atoms : array
        (N,1) matrix, where N is points holding the atoms' names
    q_atoms : array
        (N,1) matrix, where N is points holding the atoms' names
    p_coord : array
        (N,D) matrix, where N is points and D is dimension
    q_coord : array
        (N,D) matrix, where N is points and D is dimension

    Returns
    -------
    view_reorder : array
        (N,1) matrix, indices into Q such that q_atoms[view_reorder] and
        q_coord[view_reorder] follow the ordering of P
    """

    # Find unique atoms
    unique_atoms = np.unique(p_atoms)

    # generate full view from q shape to fill in atom view on the fly;
    # -1 marks "not assigned", as in the other reorder_* functions
    view_reorder = np.full(q_atoms.shape, -1, dtype=int)

    for atom in unique_atoms:

        p_atom_idx, = np.where(p_atoms == atom)
        q_atom_idx, = np.where(q_atoms == atom)

        # Norm of each coordinate row, i.e. its distance from the origin
        p_norms = np.linalg.norm(p_coord[p_atom_idx], axis=1)
        q_norms = np.linalg.norm(q_coord[q_atom_idx], axis=1)

        p_order = np.argsort(p_norms)
        q_order = np.argsort(q_norms)

        # Project the order of P onto Q: the rank of each P atom selects the
        # Q atom holding the same rank.
        p_rank = np.argsort(p_order)
        view_reorder[p_atom_idx] = q_atom_idx[q_order[p_rank]]

    return view_reorder
320

    
321

    
322
def hungarian(A, B):
    """
    Hungarian reordering.

    Assume A and B are coordinates for atoms of SAME type only

    Parameters
    ----------
    A : array
        (N,D) matrix, where N is points and D is dimension
    B : array
        (N,D) matrix, where N is points and D is dimension

    Returns
    -------
    indices_b : array
        Row indices of B assigned to the rows of A by minimising the summed
        Euclidean distance
    """

    # NOTE: original author remark -- "should be kabsch here i think"
    distances = cdist(A, B, 'euclidean')

    # Perform Hungarian analysis on distance matrix between atoms of 1st
    # structure and trial structure
    _, indices_b = linear_sum_assignment(distances)

    return indices_b
337

    
338

    
339
def reorder_hungarian(p_atoms, q_atoms, p_coord, q_coord):
    """
    Re-orders the input atom list and xyz coordinates using the Hungarian
    method (using optimized column results)

    Parameters
    ----------
    p_atoms : array
        (N,1) matrix, where N is points holding the atoms' names
    q_atoms : array
        (N,1) matrix, where N is points holding the atoms' names
    p_coord : array
        (N,D) matrix, where N is points and D is dimension
    q_coord : array
        (N,D) matrix, where N is points and D is dimension

    Returns
    -------
    view_reorder : array
             (N,1) matrix, reordered indexes of atom alignment based on the
             coordinates of the atoms

    """

    # Find unique atoms
    unique_atoms = np.unique(p_atoms)

    # generate full view from q shape to fill in atom view on the fly;
    # -1 marks positions that have not been assigned
    view_reorder = np.full(q_atoms.shape, -1, dtype=int)

    # Atoms are only ever matched against atoms of the same type.
    for atom in unique_atoms:
        p_atom_idx, = np.where(p_atoms == atom)
        q_atom_idx, = np.where(q_atoms == atom)

        view = hungarian(p_coord[p_atom_idx], q_coord[q_atom_idx])
        view_reorder[p_atom_idx] = q_atom_idx[view]

    return view_reorder
381

    
382

    
383
def generate_permutations(elements, n):
    """
    Heap's algorithm for generating all n! permutations in a list
    https://en.wikipedia.org/wiki/Heap%27s_algorithm

    Parameters
    ----------
    elements : list
        Items to permute; the list is permuted IN PLACE.
    n : int
        Number of leading items of `elements` to permute.

    Yields
    ------
    elements : list
        The same list object on every iteration, in its current order. Copy
        it if a permutation must be kept after advancing the generator.
    """
    # counters[i] counts how many swaps have been done at depth i
    counters = [0] * n
    yield elements
    i = 0
    while i < n:
        if counters[i] < i:
            # Even depth swaps with the first item, odd depth with the item
            # selected by the counter.
            j = 0 if i % 2 == 0 else counters[i]
            elements[j], elements[i] = elements[i], elements[j]
            yield elements
            counters[i] += 1
            i = 0
        else:
            counters[i] = 0
            i += 1
404

    
405

    
406
def brute_permutation(A, B):
    """
    Re-orders the input atom list and xyz coordinates using the brute force
    method of permuting all rows of the input coordinates

    Parameters
    ----------
    A : array
        (N,D) matrix, where N is points and D is dimension
    B : array
        (N,D) matrix, where N is points and D is dimension

    Returns
    -------
    view : list
        Row order of B giving the lowest Kabsch RMSD against A
    """

    rmsd_min = np.inf
    view_min = None

    # Sets initial ordering for row indices to [0, 1, 2, ..., len(A) - 1],
    # used in brute-force method
    num_atoms = A.shape[0]
    initial_order = list(range(num_atoms))

    for reorder_indices in generate_permutations(initial_order, num_atoms):

        # Calculate the RMSD between structure 1 and the permuted structure 2
        rmsd_temp = kabsch_rmsd(A, B[reorder_indices])

        # Keep the current permutation if the RMSD is lower. The generator
        # reuses one list object, so a copy has to be stored.
        if rmsd_temp < rmsd_min:
            rmsd_min = rmsd_temp
            view_min = list(reorder_indices)

    return view_min
449

    
450

    
451
def reorder_brute(p_atoms, q_atoms, p_coord, q_coord):
    """
    Re-orders the input atom list and xyz coordinates using all permutation of
    rows (using optimized column results)

    Parameters
    ----------
    p_atoms : array
        (N,1) matrix, where N is points holding the atoms' names
    q_atoms : array
        (N,1) matrix, where N is points holding the atoms' names
    p_coord : array
        (N,D) matrix, where N is points and D is dimension
    q_coord : array
        (N,D) matrix, where N is points and D is dimension

    Returns
    -------
    view_reorder : array
        (N,1) matrix, reordered indexes of atom alignment based on the
        coordinates of the atoms

    """

    # Find unique atoms
    unique_atoms = np.unique(p_atoms)

    # generate full view from q shape to fill in atom view on the fly;
    # -1 marks positions that have not been assigned
    view_reorder = np.full(q_atoms.shape, -1, dtype=int)

    # Permutations are searched per atom type, never across types.
    for atom in unique_atoms:
        p_atom_idx, = np.where(p_atoms == atom)
        q_atom_idx, = np.where(q_atoms == atom)

        view = brute_permutation(p_coord[p_atom_idx], q_coord[q_atom_idx])
        view_reorder[p_atom_idx] = q_atom_idx[view]

    return view_reorder
493

    
494

    
495
def check_reflections(p_atoms, q_atoms, p_coord, q_coord,
                      reorder_method=reorder_hungarian,
                      rotation_method=kabsch_rmsd,
                      keep_stereo=False):
    """
    Minimize RMSD using reflection planes for molecule P and Q

    Warning: This will affect stereo-chemistry

    Parameters
    ----------
    p_atoms : array
        (N,1) matrix, where N is points holding the atoms' names
    q_atoms : array
        (N,1) matrix, where N is points holding the atoms' names
    p_coord : array
        (N,D) matrix, where N is points and D is dimension
    q_coord : array
        (N,D) matrix, where N is points and D is dimension
    reorder_method : callable or None
        Called as reorder_method(p_atoms, q_atoms, p_coord, q_coord) to get
        an index view of Q; None skips reordering.
    rotation_method : callable or None
        Called as rotation_method(p_coord, q_coord) to get the RMSD; None
        uses the plain rmsd without rotation.
    keep_stereo : bool
        Skip every swap/reflection combination whose combined parity is -1
        (those turn the molecule into its mirror image).

    Returns
    -------
    min_rmsd : float
        Lowest RMSD found
    min_swap : array
        Row of AXIS_SWAPS giving the lowest RMSD
    min_reflection : array
        Row of AXIS_REFLECTIONS giving the lowest RMSD
    min_review : array or None
        Reorder view of Q for the best combination (None without reordering)

    """

    min_rmsd = np.inf
    min_swap = None
    min_reflection = None
    min_review = None
    tmp_review = None

    # Parity (+1 even / -1 odd) of each row of AXIS_SWAPS and
    # AXIS_REFLECTIONS, in the same order as those tables.
    swap_mask = [1, -1, -1, 1, -1, 1]
    reflection_mask = [1, -1, -1, -1, 1, 1, 1, -1]

    for swap, swap_parity in zip(AXIS_SWAPS, swap_mask):
        for reflection, reflection_parity in zip(AXIS_REFLECTIONS,
                                                 reflection_mask):
            if keep_stereo and swap_parity * reflection_parity == -1:
                continue  # skip enantiomers

            # Fancy indexing already yields a new array, so q_coord itself is
            # never modified. Multiplying by the sign row is the same as the
            # product with diag(reflection).
            tmp_atoms = q_atoms
            tmp_coord = q_coord[:, swap] * reflection
            tmp_coord -= centroid(tmp_coord)

            # Reorder
            if reorder_method is not None:
                tmp_review = reorder_method(p_atoms, tmp_atoms,
                                            p_coord, tmp_coord)
                tmp_coord = tmp_coord[tmp_review]
                tmp_atoms = tmp_atoms[tmp_review]

            # Rotation
            if rotation_method is None:
                this_rmsd = rmsd(p_coord, tmp_coord)
            else:
                this_rmsd = rotation_method(p_coord, tmp_coord)

            if this_rmsd < min_rmsd:
                min_rmsd = this_rmsd
                min_swap = swap
                min_reflection = reflection
                min_review = tmp_review

    if not (p_atoms == q_atoms[min_review]).all():
        print("error: Not aligned")
        # Same effect as quit(), without relying on the site module helper.
        raise SystemExit

    return min_rmsd, min_swap, min_reflection, min_review
565

    
566

    
567
def set_coordinates(atoms, V, title="", decimals=8):
    """
    Format coordinates V with corresponding atoms as a string in XYZ format.

    Parameters
    ----------
    atoms : list
        List of atomic types
    V : array
        (N,3) matrix of atomic coordinates
    title : string (optional)
        Title of molecule
    decimals : int (optional)
        number of decimals for the coordinates

    Return
    ------
    output : str
        Molecule in XYZ format

    """
    fmt = "{:2s}" + (" {:15."+str(decimals)+"f}")*3

    # Header: atom count, then the title line.
    out = [str(V.shape[0]), title]

    for atom, xyz in zip(atoms, V):
        # Only the first letter is upper-cased; the rest is kept as given.
        atom = atom[0].upper() + atom[1:]
        out.append(fmt.format(atom, xyz[0], xyz[1], xyz[2]))

    return "\n".join(out)
601

    
602

    
603
def print_coordinates(atoms, V, title=""):
    """
    Print coordinates V with corresponding atoms to stdout in XYZ format.

    Parameters
    ----------
    atoms : list
        List of element types
    V : array
        (N,3) matrix of atomic coordinates
    title : string (optional)
        Title of molecule

    """
    print(set_coordinates(atoms, V, title=title))
621

    
622

    
623
def get_coordinates(filename, fmt):
    """
    Get coordinates from filename in format fmt. Supports XYZ and PDB.

    Parameters
    ----------
    filename : string
        Filename to read
    fmt : string
        Format of filename. Either xyz or pdb.

    Returns
    -------
    atoms : list
        List of atomic types
    V : array
        (N,3) where N is number of atoms

    Raises
    ------
    SystemExit
        If fmt is neither "xyz" nor "pdb".
    """
    if fmt == "xyz":
        get_func = get_coordinates_xyz
    elif fmt == "pdb":
        get_func = get_coordinates_pdb
    else:
        # Raised directly rather than through the site-provided exit().
        raise SystemExit("Could not recognize file format: {:s}".format(fmt))

    return get_func(filename)
647

    
648

    
649
def get_coordinates_pdb(filename):
    """
    Get coordinates from the first chain in a pdb file
    and return a vectorset with all the coordinates.

    Parameters
    ----------
    filename : string
        Filename to read

    Returns
    -------
    atoms : array
        Atomic types, one single-letter symbol per ATOM record
    V : array
        (N,3) where N is number of atoms

    Raises
    ------
    SystemExit
        If the atom type or the coordinates of an ATOM record cannot be
        parsed.
    """

    # PDB files tend to be a bit of a mess. The x, y and z coordinates
    # are supposed to be in column 31-38, 39-46 and 47-54, but this is
    # not always the case.
    # Because of this the three first columns containing a decimal is used.
    # Since the format doesn't require a space between columns, we use the
    # above column indices as a fallback.

    x_column = None
    V = []

    # Same with atoms and atom naming.
    # The most robust way to do this is probably
    # to assume that the atomtype is given in column 3.

    atoms = []

    with open(filename, 'r') as f:
        for line in f:
            # Reading stops at the first TER/END record.
            if line.startswith(("TER", "END")):
                break
            if not line.startswith("ATOM"):
                continue

            tokens = line.split()

            # Try to get the atomtype
            try:
                atom = tokens[2][0]
                if atom not in ("H", "C", "N", "O", "S", "P"):
                    # e.g. 1HD1
                    atom = tokens[2][1]
                    if atom != "H":
                        raise ValueError(tokens[2])
            except (IndexError, ValueError):
                raise SystemExit("error: Parsing atomtype for the following line: \n{0:s}".format(line))
            atoms.append(atom)

            if x_column is None:
                try:
                    # look for x column: the first of three consecutive
                    # tokens that all contain a decimal point
                    for i, x in enumerate(tokens):
                        if "." in x and "." in tokens[i + 1] and "." in tokens[i + 2]:
                            x_column = i
                            break
                except IndexError:
                    raise SystemExit("error: Parsing coordinates for the following line: \n{0:s}".format(line))

            # Try to read the coordinates
            try:
                # TypeError: x_column is still None; ValueError: a token is
                # not a number or fewer than three tokens are left.
                xyz = np.asarray(tokens[x_column:x_column + 3], dtype=float)
                if xyz.size != 3:
                    raise ValueError("expected three coordinates")
            except (TypeError, ValueError):
                # If that doesn't work, use hardcoded indices
                try:
                    xyz = np.asarray(
                        [line[30:38], line[38:46], line[46:54]], dtype=float)
                except ValueError:
                    raise SystemExit("error: Parsing input for the following line: \n{0:s}".format(line))
            V.append(xyz)

    V = np.asarray(V)
    atoms = np.asarray(atoms)

    assert V.shape[0] == atoms.size

    return atoms, V
734

    
735

    
736
def get_coordinates_xyz(filename):
    """
    Get coordinates from filename and return a vectorset with all the
    coordinates, in XYZ format.

    Parameters
    ----------
    filename : string
        Filename to read

    Returns
    -------
    atoms : array
        Atomic types, upper-cased
    V : array
        (N,3) where N is number of atoms

    Raises
    ------
    SystemExit
        If the atom count on the first line or an atom line cannot be parsed.
    """

    V = []
    atoms = []

    # The context manager closes the file on every path, including the
    # SystemExit raised for malformed input.
    with open(filename, 'r') as f:

        # Read the first line to obtain the number of atoms to read
        try:
            n_atoms = int(f.readline())
        except ValueError:
            raise SystemExit("error: Could not obtain the number of atoms in the .xyz file.")

        # Skip the title line
        f.readline()

        # Use the number of atoms to not read beyond the end of a file
        for lines_read, line in enumerate(f):

            if lines_read == n_atoms:
                break

            # First alphabetic run on the line is taken as the element symbol
            symbols = re.findall(r'[a-zA-Z]+', line)

            # Only numbers written with a decimal point are recognised
            numbers = [
                float(number) for number in
                re.findall(r'[-]?\d+\.\d*(?:[Ee][-\+]\d+)?', line)]

            # An atom line needs a symbol and at least three numbers; the
            # first three are used as x, y and z.
            if not symbols or len(numbers) < 3:
                raise SystemExit("Reading the .xyz file failed in line {0}. Please check the format.".format(lines_read + 2))

            atoms.append(symbols[0].upper())
            V.append(np.array(numbers)[:3])

    atoms = np.array(atoms)
    V = np.array(V)
    return atoms, V
791

    
792

    
793
def main():
    """
    Command-line entry point.

    Parses the arguments, reads both structures, optionally filters, reorders
    and reflects the atoms, then prints either the RMSD or structure B aligned
    onto structure A in XYZ format.
    """

    import argparse
    import sys

    description = __doc__

    version_msg = """
rmsd {}

See https://github.com/charnley/rmsd for citation information

"""
    version_msg = version_msg.format(__version__)

    epilog = """
"""

    parser = argparse.ArgumentParser(
        usage='calculate_rmsd [options] FILE_A FILE_B',
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog)

    # Input structures
    parser.add_argument('structure_a', metavar='FILE_A', type=str, help='structures in .xyz or .pdb format')
    parser.add_argument('structure_b', metavar='FILE_B', type=str)

    # Admin
    parser.add_argument('-v', '--version', action='version', version=version_msg)

    # Rotation
    parser.add_argument('-r', '--rotation', action='store', default="kabsch", help='select rotation method. "kabsch" (default), "quaternion" or "none"', metavar="METHOD")

    # Reorder arguments
    parser.add_argument('-e', '--reorder', action='store_true', help='align the atoms of molecules (default: Hungarian)')
    parser.add_argument('--reorder-method', action='store', default="hungarian", metavar="METHOD", help='select which reorder method to use; hungarian (default), brute, distance')
    parser.add_argument('--use-reflections', action='store_true', help='scan through reflections in planes (eg Y transformed to -Y -> X, -Y, Z) and axis changes, (eg X and Z coords exchanged -> Z, Y, X). This will affect stereo-chemistry.')
    parser.add_argument('--use-reflections-keep-stereo', action='store_true', help='scan through reflections in planes (eg Y transformed to -Y -> X, -Y, Z) and axis changes, (eg X and Z coords exchanged -> Z, Y, X). Stereo-chemistry will be kept.')

    # Filter
    index_group = parser.add_mutually_exclusive_group()
    index_group.add_argument('-nh', '--no-hydrogen', action='store_true', help='ignore hydrogens when calculating RMSD')
    index_group.add_argument('--remove-idx', nargs='+', type=int, help='index list of atoms NOT to consider', metavar='IDX')
    index_group.add_argument('--add-idx', nargs='+', type=int, help='index list of atoms to consider', metavar='IDX')

    # format and print
    parser.add_argument('--format', action='store', help='format of input files. valid format are xyz and pdb', metavar='FMT')
    parser.add_argument('-p', '--output', '--print', action='store_true', help='print out structure B, centered and rotated unto structure A\'s coordinates in XYZ format')

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()

    # As default, load the extension as format
    if args.format is None:
        args.format = args.structure_a.split('.')[-1]

    p_all_atoms, p_all = get_coordinates(args.structure_a, args.format)
    q_all_atoms, q_all = get_coordinates(args.structure_b, args.format)

    p_size = p_all.shape[0]
    q_size = q_all.shape[0]

    if p_size != q_size:
        print("error: Structures not same size")
        sys.exit()

    if np.count_nonzero(p_all_atoms != q_all_atoms) and not args.reorder:
        msg = """
error: Atoms are not in the same order.

Use --reorder to align the atoms (can be expensive for large structures).

Please see --help or documentation for more information or
https://github.com/charnley/rmsd for further examples.
"""
        print(msg)
        sys.exit()

    # Select the atoms that take part in the RMSD (None means all atoms)
    p_view = None
    q_view = None

    if args.no_hydrogen:
        p_view = np.where(p_all_atoms != 'H')
        q_view = np.where(q_all_atoms != 'H')

    elif args.remove_idx:
        index = sorted(set(range(p_size)) - set(args.remove_idx))
        p_view = index
        q_view = index

    elif args.add_idx:
        p_view = args.add_idx
        q_view = args.add_idx

    # Work on copies so the full structures stay available for --print
    if p_view is None:
        p_coord = copy.deepcopy(p_all)
        q_coord = copy.deepcopy(q_all)
        p_atoms = copy.deepcopy(p_all_atoms)
        q_atoms = copy.deepcopy(q_all_atoms)

    else:

        if args.reorder and args.output:
            print("error: Cannot reorder atoms and print structure, when excluding atoms (such as --no-hydrogen)")
            sys.exit()

        if args.use_reflections and args.output:
            print("error: Cannot use reflections on atoms and print, when excluding atoms (such as --no-hydrogen)")
            sys.exit()

        p_coord = copy.deepcopy(p_all[p_view])
        q_coord = copy.deepcopy(q_all[q_view])
        p_atoms = copy.deepcopy(p_all_atoms[p_view])
        q_atoms = copy.deepcopy(q_all_atoms[q_view])

    # Create the centroid of P and Q which is the geometric center of a
    # N-dimensional region and translate P and Q onto that center.
    # http://en.wikipedia.org/wiki/Centroid
    p_cent = centroid(p_coord)
    q_cent = centroid(q_coord)
    p_coord -= p_cent
    q_coord -= q_cent

    # set rotation method
    rotation_name = args.rotation.lower()
    if rotation_name == "kabsch":
        rotation_method = kabsch_rmsd

    elif rotation_name == "quaternion":
        rotation_method = quaternion_rmsd

    elif rotation_name == "none":
        rotation_method = None

    else:
        print("error: Unknown rotation method:", args.rotation)
        sys.exit()

    # set reorder method
    # This must be a single if/elif chain: with a separate `if` the default
    # --reorder-method would override the None chosen when --reorder is absent.
    if not args.reorder:
        reorder_method = None

    elif args.reorder_method == "hungarian":
        reorder_method = reorder_hungarian

    elif args.reorder_method == "brute":
        reorder_method = reorder_brute

    elif args.reorder_method == "distance":
        reorder_method = reorder_distance

    else:
        print("error: Unknown reorder method:", args.reorder_method)
        sys.exit()

    # Save the resulting RMSD
    result_rmsd = None

    if args.use_reflections or args.use_reflections_keep_stereo:

        # --use-reflections takes precedence and does not preserve
        # stereo-chemistry; the keep-stereo variant skips mirror images.
        result_rmsd, q_swap, q_reflection, q_review = check_reflections(
            p_atoms,
            q_atoms,
            p_coord,
            q_coord,
            reorder_method=reorder_method,
            rotation_method=rotation_method,
            keep_stereo=not args.use_reflections)

    elif args.reorder:

        q_review = reorder_method(p_atoms, q_atoms, p_coord, q_coord)
        q_coord = q_coord[q_review]
        q_atoms = q_atoms[q_review]

        if not all(p_atoms == q_atoms):
            print("error: Structure not aligned")
            sys.exit()

    # print result
    if args.output:

        # NOTE(review): when reflections are used, the swap/reflection found
        # above is not applied to the printed structure -- confirm intent.
        if args.reorder:

            if q_review.shape[0] != q_all.shape[0]:
                print("error: Reorder length error. Full atom list needed for --print")
                sys.exit()

            q_all = q_all[q_review]
            q_all_atoms = q_all_atoms[q_review]

        # Get rotation matrix
        U = kabsch(q_coord, p_coord)

        # recenter all atoms and rotate all atoms
        q_all -= q_cent
        q_all = np.dot(q_all, U)

        # center q on p's original coordinates
        q_all += p_cent

        # done and done
        xyz = set_coordinates(q_all_atoms, q_all, title="{} - modified".format(args.structure_b))
        print(xyz)

    else:

        # `is None` rather than truthiness: an RMSD of exactly 0.0 from the
        # reflection scan is a valid result and must not be recomputed.
        if result_rmsd is None:
            if rotation_method is None:
                result_rmsd = rmsd(p_coord, q_coord)
            else:
                result_rmsd = rotation_method(p_coord, q_coord)

        print("{0}".format(result_rmsd))
1040

    
1041
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()