Statistiques
| Révision :

root / ase / test / neighbor.py @ 13

Historique | Voir | Annoter | Télécharger (1,94 ko)

1
import numpy.random as random
2
import numpy as np
3
from ase import Atoms
4
from ase.calculators.neighborlist import NeighborList
5

    
6
atoms = Atoms(numbers=range(10),
7
              cell=[(0.2, 1.2, 1.4),
8
                    (1.4, 0.1, 1.6),
9
                    (1.3, 2.0, -0.1)])
10
atoms.set_scaled_positions(3 * random.random((10, 3)) - 1)
11

    
12
def count(nl, atoms):
13
    c = np.zeros(len(atoms), int)
14
    R = atoms.get_positions()
15
    cell = atoms.get_cell()
16
    d = 0.0
17
    for a in range(len(atoms)):
18
        i, offsets = nl.get_neighbors(a)
19
        for j in i:
20
            c[j] += 1
21
        c[a] += len(i)
22
        d += (((R[i] + np.dot(offsets, cell) - R[a])**2).sum(1)**0.5).sum()
23
    return d, c
24

    
25
for sorted in [False, True]:
26
    for p1 in range(2):
27
        for p2 in range(2):
28
            for p3 in range(2):
29
                print p1, p2, p3
30
                atoms.set_pbc((p1, p2, p3))
31
                nl = NeighborList(atoms.numbers * 0.2 + 0.5,
32
                                  skin=0.0, sorted=sorted)
33
                nl.update(atoms)
34
                d, c = count(nl, atoms)
35
                atoms2 = atoms.repeat((p1 + 1, p2 + 1, p3 + 1))
36
                nl2 = NeighborList(atoms2.numbers * 0.2 + 0.5,
37
                                   skin=0.0, sorted=sorted)
38
                nl2.update(atoms2)
39
                d2, c2 = count(nl2, atoms2)
40
                c2.shape = (-1, 10)
41
                dd = d * (p1 + 1) * (p2 + 1) * (p3 + 1) - d2
42
                print dd
43
                print c2 - c
44
                assert abs(dd) < 1e-10
45
                assert not (c2 - c).any()
46

    
47
h2 = Atoms('H2', positions=[(0, 0, 0), (0, 0, 1)])
48
nl = NeighborList([0.5, 0.5], skin=0.1, sorted=True, self_interaction=False)
49
assert nl.update(h2)
50
assert not nl.update(h2)
51
assert (nl.get_neighbors(0)[0] == [1]).all()
52

    
53
h2[1].z += 0.09
54
assert not nl.update(h2)
55
assert (nl.get_neighbors(0)[0] == [1]).all()
56

    
57
h2[1].z += 0.09
58
assert nl.update(h2)
59
assert (nl.get_neighbors(0)[0] == []).all()
60
assert nl.nupdates == 2