root / ase / dft / stm.py @ 1
Historique | Voir | Annoter | Télécharger (4,71 ko)
1 |
from math import exp, sqrt |
---|---|
2 |
|
3 |
import numpy as np |
4 |
|
5 |
from ase.atoms import Atoms |
6 |
|
7 |
|
8 |
class STM: |
9 |
def __init__(self, atoms, symmetries=None): |
10 |
if isinstance(atoms, Atoms): |
11 |
calc = atoms.get_calculator() |
12 |
else:
|
13 |
calc = atoms |
14 |
atoms = calc.get_atoms() |
15 |
self.nbands = calc.get_number_of_bands()
|
16 |
self.weights = calc.get_k_point_weights()
|
17 |
self.nkpts = len(self.weights) |
18 |
self.nspins = calc.get_number_of_spins()
|
19 |
self.eigs = np.array([[calc.get_eigenvalues(k, s)
|
20 |
for k in range(self.nkpts)] |
21 |
for s in range(self.nspins)]) |
22 |
self.eigs -= calc.get_fermi_level()
|
23 |
self.calc = calc
|
24 |
self.cell = atoms.get_cell()
|
25 |
assert not self.cell[2, :2].any() and not self.cell[:2, 2].any() |
26 |
self.ldos = None |
27 |
self.symmetries = symmetries or [] |
28 |
|
29 |
def calculate_ldos(self, width=None): |
30 |
if self.ldos is not None and width == self.width: |
31 |
return
|
32 |
|
33 |
if width is None: |
34 |
width = 0.1
|
35 |
|
36 |
ldos = None
|
37 |
for s in range(self.nspins): |
38 |
for k in range(self.nkpts): |
39 |
for n in range(self.nbands): |
40 |
psi = self.calc.get_pseudo_wave_function(n, k, s)
|
41 |
if ldos is None: |
42 |
ldos = np.zeros_like(psi) |
43 |
f = (exp(-(self.eigs[s, k, n] / width)**2) * |
44 |
self.weights[k])
|
45 |
ldos += f * (psi * np.conj(psi)).real |
46 |
|
47 |
if 0 in self.symmetries: |
48 |
# (x,y) -> (-x,y)
|
49 |
ldos[1:] += ldos[:0:-1].copy() |
50 |
ldos[1:] *= 0.5 |
51 |
|
52 |
if 1 in self.symmetries: |
53 |
# (x,y) -> (x,-y)
|
54 |
ldos[:, 1:] += ldos[:, :0:-1].copy() |
55 |
ldos[:, 1:] *= 0.5 |
56 |
|
57 |
if 2 in self.symmetries: |
58 |
# (x,y) -> (y,x)
|
59 |
ldos += ldos.transpose((1, 0, 2)).copy() |
60 |
ldos *= 0.5
|
61 |
|
62 |
self.ldos = ldos
|
63 |
self.width = width
|
64 |
|
65 |
#def save_ldos(self, filename='ldos.pckl'):
|
66 |
|
67 |
|
68 |
def get_averaged_current(self, z, width=None): |
69 |
self.calculate_ldos(width)
|
70 |
nz = self.ldos.shape[2] |
71 |
|
72 |
# Find grid point:
|
73 |
n = z / self.cell[2, 2] * nz |
74 |
dn = n - np.floor(n) |
75 |
n = int(n) % nz
|
76 |
print n,dn
|
77 |
|
78 |
# Average and do linear interpolation:
|
79 |
return ((1 - dn) * self.ldos[:, :, n].mean() + |
80 |
dn * self.ldos[:, :, (n + 1) % nz].mean()) |
81 |
|
82 |
def scan(self, current, z=None, width=None): |
83 |
self.calculate_ldos(width)
|
84 |
|
85 |
L = self.cell[2, 2] |
86 |
if z is None: |
87 |
z = L / 2
|
88 |
|
89 |
nz = self.ldos.shape[2] |
90 |
n = int(round(z / L * nz)) % nz |
91 |
h = L / nz |
92 |
|
93 |
ldos = self.ldos.reshape((-1, nz)) |
94 |
|
95 |
heights = np.empty(ldos.shape[0])
|
96 |
for i, a in enumerate(ldos): |
97 |
heights[i], z, n = find_height(a, current, z, n, nz, h) |
98 |
|
99 |
heights.shape = self.ldos.shape[:2] |
100 |
return heights
|
101 |
|
102 |
def linescan(self, current, p1, p2, npoints=None, z=None, width=None): |
103 |
self.calculate_ldos(width)
|
104 |
|
105 |
L = self.cell[2, 2] |
106 |
if z is None: |
107 |
z = L / 2
|
108 |
|
109 |
nz = self.ldos.shape[2] |
110 |
n = int(round(z / L * nz)) % nz |
111 |
h = L / nz |
112 |
ldos = self.ldos.reshape((-1, nz)) |
113 |
|
114 |
p1 = np.asarray(p1) |
115 |
p2 = np.asarray(p2) |
116 |
d = p2 - p1 |
117 |
s = sqrt(np.dot(d, d)) |
118 |
|
119 |
if npints == None: |
120 |
npoints = int(3 * s / h + 2) |
121 |
|
122 |
cell = self.cell[:2, :2] |
123 |
shape = np.array(self.ldos.shape[:2], float) |
124 |
M = cell.I |
125 |
heights = np.empty(npoints) |
126 |
for i in range(npoints): |
127 |
p = p1 + i * d / (npoints - 1)
|
128 |
q = np.dot(M, p) * shape |
129 |
qi = q.astype(int)
|
130 |
n0, n1 = qi |
131 |
f = q - qi |
132 |
g = 1 - f
|
133 |
a = (g[0] * g[0] * ldos[n0, n1 ] + |
134 |
f[0] * g[0] * ldos[n0 + 1, n1 ] + |
135 |
g[0] * f[0] * ldos[n0, n1 + 1] + |
136 |
f[0] * f[0] * ldos[n0 + 1, n1 + 1]) |
137 |
heights[i], z, n = find_height(a, current, z, n, nz, h) |
138 |
return np.linspace(0, s, npoints), heights |
139 |
|
140 |
def cube(self, filename, atoms=None): |
141 |
pass
|
142 |
|
143 |
|
144 |
def find_height(array, current, z, n, nz, h): |
145 |
c1 = array[n] |
146 |
sign = cmp(c1, current)
|
147 |
m = 0
|
148 |
while m < nz:
|
149 |
n = (n + sign) % nz |
150 |
z += sign * h |
151 |
c2 = array[n] |
152 |
if cmp(c2, current) != sign: |
153 |
break
|
154 |
c1 = c2 |
155 |
m += 1
|
156 |
|
157 |
if m == nz:
|
158 |
print z, n, nz, h, current, array
|
159 |
raise RuntimeError('Tip crash!') |
160 |
|
161 |
return z - sign * h * (current - c2) / (c1 - c2), z, n
|
162 |
|
163 |
|
164 |
|
165 |
|
166 |
|
167 |
|