Statistiques
| Révision :

root / ase / dft / stm.py

Historique | Voir | Annoter | Télécharger (4,71 ko)

1
from math import exp, sqrt
2

    
3
import numpy as np
4

    
5
from ase.atoms import Atoms
6

    
7

    
8
class STM:
    """Simulated constant-current STM images from a DFT calculation.

    The local density of states (LDOS) near the Fermi level is built from
    the calculator's pseudo wave functions.  It can then be scanned at
    constant current over the whole surface cell (`scan`) or along a line
    (`linescan`).
    """

    def __init__(self, atoms, symmetries=None):
        """Collect eigenvalues, k-point weights and the cell.

        atoms: Atoms object with an attached calculator, or the calculator
            itself (the atoms are then taken from ``calc.get_atoms()``).
        symmetries: iterable of integers selecting symmetrizations applied
            to the LDOS: 0 mirrors x, 1 mirrors y, 2 swaps x and y.
        """
        if isinstance(atoms, Atoms):
            calc = atoms.get_calculator()
        else:
            calc = atoms
            atoms = calc.get_atoms()
        self.nbands = calc.get_number_of_bands()
        self.weights = calc.get_k_point_weights()
        self.nkpts = len(self.weights)
        self.nspins = calc.get_number_of_spins()
        # eigs[s, k, n], measured relative to the Fermi level.
        self.eigs = np.array([[calc.get_eigenvalues(k, s)
                               for k in range(self.nkpts)]
                              for s in range(self.nspins)])
        self.eigs -= calc.get_fermi_level()
        self.calc = calc
        self.cell = atoms.get_cell()
        # The third cell vector must be along z and orthogonal to the first
        # two, so the height can be treated separately from (x, y).
        assert not self.cell[2, :2].any() and not self.cell[:2, 2].any()
        self.ldos = None
        self.width = None
        self.symmetries = symmetries or []

    def calculate_ldos(self, width=None):
        """Compute and cache the LDOS at the Fermi level.

        Every state contributes |psi|**2, weighted by its k-point weight and
        by exp(-(e / width)**2), where e is its energy relative to the Fermi
        level.  The result is cached and recomputed only when a different
        width is requested.

        width: Gaussian width in the calculator's energy units (default 0.1).
        """
        if width is None:
            width = 0.1
        # Apply the default before comparing, so that width=None and
        # width=0.1 both reuse the same cached LDOS.
        if self.ldos is not None and width == self.width:
            return

        ldos = None
        for s in range(self.nspins):
            for k in range(self.nkpts):
                for n in range(self.nbands):
                    psi = self.calc.get_pseudo_wave_function(n, k, s)
                    if ldos is None:
                        # Real accumulator: |psi|**2 is real even when psi
                        # is complex.
                        ldos = np.zeros(np.shape(psi))
                    f = (exp(-(self.eigs[s, k, n] / width)**2) *
                         self.weights[k])
                    ldos += f * (psi * np.conj(psi)).real

        if 0 in self.symmetries:
            # (x,y) -> (-x,y): grid index i maps to N - i (index 0 is fixed).
            ldos[1:] += ldos[:0:-1].copy()
            ldos[1:] *= 0.5

        if 1 in self.symmetries:
            # (x,y) -> (x,-y)
            ldos[:, 1:] += ldos[:, :0:-1].copy()
            ldos[:, 1:] *= 0.5

        if 2 in self.symmetries:
            # (x,y) -> (y,x)
            ldos += ldos.transpose((1, 0, 2)).copy()
            ldos *= 0.5

        self.ldos = ldos
        self.width = width

    def get_averaged_current(self, z, width=None):
        """Return the LDOS averaged over (x, y) at height z.

        The plane averages of the two z-grid planes bracketing z are
        linearly interpolated.  z is treated periodically in the cell.
        """
        self.calculate_ldos(width)
        nz = self.ldos.shape[2]

        # Fractional grid index of z.  floor() rather than int() keeps the
        # lower bracketing plane correct for negative z.
        x = z / self.cell[2, 2] * nz
        n = int(np.floor(x))
        dn = x - n
        n %= nz

        return ((1 - dn) * self.ldos[:, :, n].mean() +
                dn * self.ldos[:, :, (n + 1) % nz].mean())

    def _initial_grid_point(self, z):
        """Return (z, n, nz, h) describing where a tip scan starts.

        h is the z-grid spacing and n is the index (modulo nz) of the grid
        plane nearest to z (default: half the cell height).  The returned z
        is snapped onto that plane (z == m * h with n == m % nz).  This is
        required because `find_height` moves z and n together in steps of
        h, so they must refer to the same point.
        """
        L = self.cell[2, 2]
        if z is None:
            z = L / 2
        nz = self.ldos.shape[2]
        h = L / nz
        m = int(round(z / L * nz))
        return m * h, m % nz, nz, h

    def _interpolate_column(self, q):
        """Bilinearly interpolate the LDOS z-column at grid coordinates q.

        q = (qx, qy) is a position measured in (x, y) grid points.  It is
        wrapped periodically onto the LDOS grid.
        """
        shape = np.array(self.ldos.shape[:2])
        qi = np.floor(q).astype(int)
        f = q - qi
        g = 1 - f
        n0, m0 = qi % shape
        n1, m1 = (qi + 1) % shape
        return (g[0] * g[1] * self.ldos[n0, m0] +
                f[0] * g[1] * self.ldos[n1, m0] +
                g[0] * f[1] * self.ldos[n0, m1] +
                f[0] * f[1] * self.ldos[n1, m1])

    def scan(self, current, z=None, width=None):
        """Compute the constant-current topography over the surface cell.

        For every (x, y) grid point, the height at which the LDOS equals
        `current` is located.  Each search starts where the previous one
        ended; the first starts at height z (default: half the cell height).

        Returns an array of heights with the (x, y) shape of the LDOS grid.
        """
        self.calculate_ldos(width)
        z, n, nz, h = self._initial_grid_point(z)

        columns = self.ldos.reshape((-1, nz))
        heights = np.empty(columns.shape[0])
        for i, column in enumerate(columns):
            heights[i], z, n = find_height(column, current, z, n, nz, h)

        return heights.reshape(self.ldos.shape[:2])

    def linescan(self, current, p1, p2, npoints=None, z=None, width=None):
        """Compute a constant-current height profile from p1 to p2.

        p1, p2: Cartesian (x, y) end points, in the length units of the cell.
        npoints: number of samples; by default about three per z-grid
            spacing along the line, plus two.

        Returns (distances along the line, heights).
        """
        self.calculate_ldos(width)
        z, n, nz, h = self._initial_grid_point(z)

        p1 = np.asarray(p1, float)
        p2 = np.asarray(p2, float)
        d = p2 - p1
        s = sqrt(np.dot(d, d))

        if npoints is None:
            npoints = int(3 * s / h + 2)

        # Rows of the cell are the lattice vectors (cf. cell[2, 2] as the
        # height), so fractional coordinates are p . inv(cell).
        M = np.linalg.inv(np.asarray(self.cell)[:2, :2])
        shape = np.array(self.ldos.shape[:2], float)
        step = d / max(npoints - 1, 1)
        heights = np.empty(npoints)
        for i in range(npoints):
            q = np.dot(p1 + i * step, M) * shape
            column = self._interpolate_column(q)
            heights[i], z, n = find_height(column, current, z, n, nz, h)
        return np.linspace(0, s, npoints), heights

    def cube(self, filename, atoms=None):
        """Write the LDOS to a cube file (not implemented; does nothing)."""
        pass
142

    
143

    
144
def find_height(array, current, z, n, nz, h):
    """Locate the height at which a periodic LDOS column equals `current`.

    The search starts at grid index n, which must correspond to height z.
    The tip is stepped by one grid spacing h at a time: upwards (+1) while
    the value is above `current`, downwards (-1) while it is below.  The
    index wraps periodically over the nz points of `array`.  Once the value
    crosses `current`, the height is linearly interpolated between the last
    two grid points.

    Returns (height, z, n).  z and n are the grid height and index reached,
    so the search for a neighbouring column can start from there.

    Raises RuntimeError('Tip crash! ...') if no crossing is found within nz
    steps.
    """
    c1 = array[n]
    # Portable replacement for Python 2's cmp(): +1, 0 or -1.  int() is
    # needed because numpy does not allow subtracting booleans.
    sign = int(c1 > current) - int(c1 < current)
    if sign == 0:
        # Already exactly on the iso-surface.  Stepping would never see a
        # sign change, and the interpolation below would divide by zero.
        return z, z, n

    for _ in range(nz):
        n = (n + sign) % nz
        z += sign * h
        c2 = array[n]
        if int(c2 > current) - int(c2 < current) != sign:
            break
        c1 = c2
    else:
        raise RuntimeError('Tip crash! No crossing of current=%r found '
                           'within %d grid steps' % (current, nz))

    # c1 lies at z - sign*h and c2 at z; interpolate linearly between them.
    return z - sign * h * (current - c2) / (c1 - c2), z, n
162

    
163
                
164
            
165
        
166
    
167