Statistiques
| Révision :

root / ase / dft / wannier.py @ 1

Historique | Voir | Annoter | Télécharger (28,24 ko)

1
""" Maximally localized Wannier Functions
2

3
    Find the set of maximally localized Wannier functions
4
    using the spread functional of Marzari and Vanderbilt
5
    (PRB 56, 1997 page 12847). 
6
"""
7
import numpy as np
8
from time import time
9
from math import sqrt, pi
10
from pickle import dump, load
11
from ase.parallel import paropen
12
from ase.calculators.dacapo import Dacapo
13
from ase.dft.kpoints import get_monkhorst_shape
14
from ase.transport.tools import dagger, normalize
15

    
16
dag = dagger
17

    
18

    
19
def gram_schmidt(U):
20
    """Orthonormalize columns of U according to the Gram-Schmidt procedure."""
21
    for i, col in enumerate(U.T):
22
        for col2 in U.T[:i]:
23
            col -= col2 * np.dot(col2.conj(), col)
24
        col /= np.linalg.norm(col)
25

    
26

    
27
def lowdin(U, S=None):
28
    """Orthonormalize columns of U according to the Lowdin procedure.
29
    
30
    If the overlap matrix is know, it can be specified in S.
31
    """
32
    if S is None:
33
        S = np.dot(dag(U), U)
34
    eig, rot = np.linalg.eigh(S)
35
    rot = np.dot(rot / np.sqrt(eig), dag(rot))
36
    U[:] = np.dot(U, rot)
37

    
38

    
39
def neighbor_k_search(k_c, G_c, kpt_kc, tol=1e-4):
40
    # search for k1 (in kpt_kc) and k0 (in alldir), such that
41
    # k1 - k - G + k0 = 0
42
    alldir_dc = np.array([[0,0,0],[1,0,0],[0,1,0],[0,0,1],
43
                           [1,1,0],[1,0,1],[0,1,1]], int)
44
    for k0_c in alldir_dc:
45
        for k1, k1_c in enumerate(kpt_kc):
46
            if np.linalg.norm(k1_c - k_c - G_c + k0_c) < tol:
47
                return k1, k0_c
48

    
49
    print 'Wannier: Did not find matching kpoint for kpt=', k_c
50
    print 'Probably non-uniform k-point grid'
51
    raise NotImplementedError
52

    
53

    
54
def calculate_weights(cell_cc):
55
    """ Weights are used for non-cubic cells, see PRB **61**, 10040"""
56
    alldirs_dc = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1],
57
                           [1, 1, 0], [1, 0, 1], [0, 1, 1]], dtype=int)
58
    g = np.dot(cell_cc, cell_cc.T)
59
    # NOTE: Only first 3 of following 6 weights are presently used:
60
    w = np.zeros(6)              
61
    w[0] = g[0, 0] - g[0, 1] - g[0, 2]
62
    w[1] = g[1, 1] - g[0, 1] - g[1, 2]
63
    w[2] = g[2, 2] - g[0, 2] - g[1, 2]
64
    w[3] = g[0, 1]
65
    w[4] = g[0, 2]
66
    w[5] = g[1, 2]
67
    # Make sure that first 3 Gdir vectors are included - 
68
    # these are used to calculate Wanniercenters.
69
    Gdir_dc = alldirs_dc[:3]
70
    weight_d = w[:3]
71
    for d in range(3, 6):
72
        if abs(w[d]) > 1e-5:
73
            Gdir_dc = np.concatenate(Gdir_dc, alldirs_dc[d])
74
            weight_d = np.concatenate(weight_d, w[d])
75
    weight_d /= max(abs(weight_d))
76
    return weight_d, Gdir_dc
77

    
78

    
79
def random_orthogonal_matrix(dim, seed=None, real=False):
80
    """Generate a random orthogonal matrix"""
81
    if seed is not None:
82
        np.random.seed(seed)
83

    
84
    H = np.random.rand(dim, dim)
85
    np.add(dag(H), H, H)
86
    np.multiply(.5, H, H)
87

    
88
    if real:
89
        gram_schmidt(H)
90
        return H
91
    else: 
92
        val, vec = np.linalg.eig(H)
93
        return np.dot(vec * np.exp(1.j * val), dag(vec))
94

    
95

    
96
def steepest_descent(func, step=.005, tolerance=1e-6, **kwargs):
97
    fvalueold = 0.
98
    fvalue = fvalueold + 10
99
    count=0
100
    while abs((fvalue - fvalueold) / fvalue) > tolerance:
101
        fvalueold = fvalue
102
        dF = func.get_gradients()
103
        func.step(dF * step, **kwargs)
104
        fvalue = func.get_functional_value()
105
        count += 1
106
        print 'SteepestDescent: iter=%s, value=%s' % (count, fvalue)
107

    
108

    
109
def md_min(func, step=.25, tolerance=1e-6, verbose=False, **kwargs):
110
    if verbose:
111
        print 'Localize with step =', step, 'and tolerance =', tolerance
112
        t = -time()
113
    fvalueold = 0.
114
    fvalue = fvalueold + 10
115
    count = 0
116
    V = np.zeros(func.get_gradients().shape, dtype=complex)
117
    while abs((fvalue - fvalueold) / fvalue) > tolerance:
118
        fvalueold = fvalue
119
        dF = func.get_gradients()
120
        V *= (dF * V.conj()).real > 0
121
        V += step * dF
122
        func.step(V, **kwargs)
123
        fvalue = func.get_functional_value()
124
        if fvalue < fvalueold:
125
            step *= 0.5
126
        count += 1
127
        if verbose:
128
            print 'MDmin: iter=%s, step=%s, value=%s' % (count, step, fvalue)
129
    if verbose:
130
        t += time()
131
        print '%d iterations in %0.2f seconds (%0.2f ms/iter), endstep = %s' %(
132
            count, t, t * 1000. / count, step)
133

    
134

    
135
def rotation_from_projection(proj_nw, fixed, ortho=True):
136
    """Determine rotation and coefficient matrices from projections
137
    
138
    proj_nw = <psi_n|p_w>
139
    psi_n: eigenstates
140
    p_w: localized function
141
    
142
    Nb (n) = Number of bands
143
    Nw (w) = Number of wannier functions
144
    M  (f) = Number of fixed states
145
    L  (l) = Number of extra degrees of freedom
146
    U  (u) = Number of non-fixed states
147
    """
148

    
149
    Nb, Nw = proj_nw.shape
150
    M = fixed
151
    L = Nw - M
152

    
153
    U_ww = np.empty((Nw, Nw), dtype=proj_nw.dtype)
154
    U_ww[:M] = proj_nw[:M]
155

    
156
    if L > 0:
157
        proj_uw = proj_nw[M:]
158
        eig_w, C_ww = np.linalg.eigh(np.dot(dag(proj_uw), proj_uw))
159
        C_ul = np.dot(proj_uw, C_ww[:, np.argsort(-eig_w.real)[:L]])
160
        #eig_u, C_uu = np.linalg.eigh(np.dot(proj_uw, dag(proj_uw)))
161
        #C_ul = C_uu[:, np.argsort(-eig_u.real)[:L]]
162

    
163
        U_ww[M:] = np.dot(dag(C_ul), proj_uw)
164
    else:
165
        C_ul = np.empty((Nb - M, 0))
166

    
167
    normalize(C_ul)
168
    if ortho:
169
        lowdin(U_ww)
170
    else:
171
        normalize(U_ww)
172

    
173
    return U_ww, C_ul
174

    
175

    
176
class Wannier:
177
    """Maximally localized Wannier Functions
178

179
    Find the set of maximally localized Wannier functions using the
180
    spread functional of Marzari and Vanderbilt (PRB 56, 1997 page
181
    12847).
182
    """
183

    
184
    def __init__(self, nwannier, calc,
185
                 file=None,
186
                 nbands=None,
187
                 fixedenergy=None,
188
                 fixedstates=None,
189
                 spin=0,
190
                 initialwannier='random',
191
                 seed=None,
192
                 verbose=False):
193
        """
194
        Required arguments:
195

196
          ``nwannier``: The number of Wannier functions you wish to construct.
197
            This must be at least half the number of electrons in the system
198
            and at most equal to the number of bands in the calculation.
199

200
          ``calc``: A converged DFT calculator class.
201
            If ``file`` arg. is not provided, the calculator *must* provide the
202
            method ``get_wannier_localization_matrix``, and contain the
203
            wavefunctions (save files with only the density is not enough).
204
            If the localization matrix is read from file, this is not needed,
205
            unless ``get_function`` or ``write_cube`` is called.
206
          
207
        Optional arguments:
208

209
          ``nbands``: Bands to include in localization.
210
            The number of bands considered by Wannier can be smaller than the
211
            number of bands in the calculator. This is useful if the highest
212
            bands of the DFT calculation are not well converged.
213

214
          ``spin``: The spin channel to be considered.
215
            The Wannier code treats each spin channel independently.
216

217
          ``fixedenergy`` / ``fixedstates``: Fixed part of Heilbert space.
218
            Determine the fixed part of Hilbert space by either a maximal
219
            energy *or* a number of bands (possibly a list for multiple
220
            k-points).
221
            Default is None meaning that the number of fixed states is equated
222
            to ``nwannier``.
223

224
          ``file``: Read localization and rotation matrices from this file.
225

226
          ``initialwannier``: Initial guess for Wannier rotation matrix.
227
            Can be 'bloch' to start from the Bloch states, 'random' to be
228
            randomized, or a list passed to calc.get_initial_wannier.
229

230
          ``seed``: Seed for random ``initialwannier``.
231

232
          ``verbose``: True / False level of verbosity.
233
          """
234
        # Bloch phase sign convention
235
        sign = -1
236
        classname = calc.__class__.__name__
237
        if classname in ['Dacapo', 'Jacapo']:
238
            print 'Using ' + classname
239
            sign = +1
240
            
241
        self.nwannier = nwannier
242
        self.calc = calc
243
        self.spin = spin
244
        self.verbose = verbose
245
        self.kpt_kc = sign * calc.get_ibz_k_points()
246
        assert len(calc.get_bz_k_points()) == len(self.kpt_kc)
247
        
248
        self.kptgrid = get_monkhorst_shape(self.kpt_kc)
249
        self.Nk = len(self.kpt_kc)
250
        self.unitcell_cc = calc.get_atoms().get_cell()
251
        self.largeunitcell_cc = (self.unitcell_cc.T * self.kptgrid).T
252
        self.weight_d, self.Gdir_dc = calculate_weights(self.largeunitcell_cc)
253
        self.Ndir = len(self.weight_d) # Number of directions
254

    
255
        if nbands is not None:
256
            self.nbands = nbands
257
        else:
258
            self.nbands = calc.get_number_of_bands()
259
        if fixedenergy is None:
260
            if fixedstates is None:
261
                self.fixedstates_k = np.array([nwannier] * self.Nk, int)
262
            else:
263
                if type(fixedstates) is int:
264
                    fixedstates = [fixedstates] * self.Nk
265
                self.fixedstates_k = np.array(fixedstates, int)
266
        else:
267
            # Setting number of fixed states and EDF from specified energy.
268
            # All states below this energy (relative to Fermi level) are fixed.
269
            fixedenergy += calc.get_fermi_level()
270
            print fixedenergy
271
            self.fixedstates_k = np.array(
272
                [calc.get_eigenvalues(k, spin).searchsorted(fixedenergy)
273
                 for k in range(self.Nk)], int)
274
        self.edf_k = self.nwannier - self.fixedstates_k
275
        if verbose:
276
            print 'Wannier: Fixed states            : %s' % self.fixedstates_k
277
            print 'Wannier: Extra degrees of freedom: %s' % self.edf_k
278

    
279
        # Set the list of neighboring k-points k1, and the "wrapping" k0,
280
        # such that k1 - k - G + k0 = 0
281
        #
282
        # Example: kpoints = (-0.375,-0.125,0.125,0.375), dir=0
283
        # G = [0.25,0,0]
284
        # k=0.375, k1= -0.375 : -0.375-0.375-0.25 => k0=[1,0,0]
285
        #
286
        # For a gamma point calculation k1 = k = 0,  k0 = [1,0,0] for dir=0
287
        if self.Nk == 1:
288
            self.kklst_dk = np.zeros((self.Ndir, 1), int)
289
            k0_dkc = self.Gdir_dc.reshape(-1, 1, 3)
290
        else:
291
            self.kklst_dk = np.empty((self.Ndir, self.Nk), int)
292
            k0_dkc = np.empty((self.Ndir, self.Nk, 3), int)
293

    
294
            # Distance between kpoints
295
            kdist_c = np.empty(3)
296
            for c in range(3):
297
                # make a sorted list of the kpoint values in this direction
298
                slist = np.argsort(self.kpt_kc[:, c], kind='mergesort')
299
                skpoints_kc = np.take(self.kpt_kc, slist, axis=0)
300
                kdist_c[c] = max([skpoints_kc[n + 1, c] - skpoints_kc[n, c]
301
                                  for n in range(self.Nk - 1)])               
302

    
303
            for d, Gdir_c in enumerate(self.Gdir_dc):
304
                for k, k_c in enumerate(self.kpt_kc):
305
                    # setup dist vector to next kpoint
306
                    G_c = np.where(Gdir_c > 0, kdist_c, 0)
307
                    if max(G_c) < 1e-4:
308
                        self.kklst_dk[d, k] = k
309
                        k0_dkc[d, k] = Gdir_c
310
                    else:
311
                        self.kklst_dk[d, k], k0_dkc[d, k] = \
312
                                       neighbor_k_search(k_c, G_c, self.kpt_kc)
313

    
314
        # Set the inverse list of neighboring k-points
315
        self.invkklst_dk = np.empty((self.Ndir, self.Nk), int)
316
        for d in range(self.Ndir):
317
            for k1 in range(self.Nk):
318
                self.invkklst_dk[d, k1] = self.kklst_dk[d].tolist().index(k1)
319

    
320
        Nw = self.nwannier
321
        Nb = self.nbands
322
        self.Z_dkww = np.empty((self.Ndir, self.Nk, Nw, Nw), complex)
323
        self.V_knw = np.zeros((self.Nk, Nb, Nw), complex)
324
        if file is None:
325
            self.Z_dknn = np.empty((self.Ndir, self.Nk, Nb, Nb), complex)
326
            for d, dirG in enumerate(self.Gdir_dc):
327
                for k in range(self.Nk):
328
                    k1 = self.kklst_dk[d, k]
329
                    k0_c = k0_dkc[d, k]
330
                    self.Z_dknn[d, k] = calc.get_wannier_localization_matrix(
331
                        nbands=Nb, dirG=dirG, kpoint=k, nextkpoint=k1,
332
                        G_I=k0_c, spin=self.spin)
333
        self.initialize(file=file, initialwannier=initialwannier, seed=seed)
334

    
335
    def initialize(self, file=None, initialwannier='random', seed=None):
336
        """Re-initialize current rotation matrix.
337

338
        Keywords are identical to those of the constructor.
339
        """
340
        Nw = self.nwannier
341
        Nb = self.nbands
342

    
343
        if file is not None:
344
            self.Z_dknn, self.U_kww, self.C_kul = load(paropen(file))
345
        elif initialwannier == 'bloch':
346
            # Set U and C to pick the lowest Bloch states
347
            self.U_kww = np.zeros((self.Nk, Nw, Nw), complex)
348
            self.C_kul = []
349
            for U, M, L in zip(self.U_kww, self.fixedstates_k, self.edf_k):
350
                U[:] = np.identity(Nw, complex)
351
                if L > 0:
352
                    self.C_kul.append(
353
                        np.identity(Nb - M, complex)[:, :L])
354
                else:
355
                    self.C_kul.append([])
356
        elif initialwannier == 'random':
357
            # Set U and C to random (orthogonal) matrices
358
            self.U_kww = np.zeros((self.Nk, Nw, Nw), complex)
359
            self.C_kul = []
360
            for U, M, L in zip(self.U_kww, self.fixedstates_k, self.edf_k):
361
                U[:] = random_orthogonal_matrix(Nw, seed, real=False)
362
                if L > 0:
363
                    self.C_kul.append(random_orthogonal_matrix(
364
                        Nb - M, seed=seed, real=False)[:, :L])
365
                else:
366
                    self.C_kul.append(np.array([]))        
367
        else:
368
            # Use initial guess to determine U and C
369
            self.C_kul, self.U_kww = self.calc.initial_wannier(
370
                initialwannier, self.kptgrid, self.fixedstates_k,
371
                self.edf_k, self.spin)
372
        self.update()
373

    
374
    def save(self, file):
375
        """Save information on localization and rotation matrices to file."""
376
        dump((self.Z_dknn, self.U_kww, self.C_kul), paropen(file, 'w'))
377

    
378
    def update(self):
379
        # Update large rotation matrix V (from rotation U and coeff C)
380
        for k, M in enumerate(self.fixedstates_k):
381
            self.V_knw[k, :M] = self.U_kww[k, :M]
382
            if M < self.nwannier:
383
                self.V_knw[k, M:] = np.dot(self.C_kul[k], self.U_kww[k, M:])
384
            # else: self.V_knw[k, M:] = 0.0
385

    
386
        # Calculate the Zk matrix from the large rotation matrix:
387
        # Zk = V^d[k] Zbloch V[k1]
388
        for d in range(self.Ndir):
389
            for k in range(self.Nk):
390
                k1 = self.kklst_dk[d, k]
391
                self.Z_dkww[d, k] = np.dot(dag(self.V_knw[k]), np.dot(
392
                    self.Z_dknn[d, k], self.V_knw[k1]))
393

    
394
        # Update the new Z matrix
395
        self.Z_dww = self.Z_dkww.sum(axis=1) / self.Nk
396

    
397
    def get_centers(self, scaled=False):
398
        """Calculate the Wannier centers
399

400
        ::
401
        
402
          pos =  L / 2pi * phase(diag(Z))
403
        """
404
        coord_wc = np.angle(self.Z_dww[:3].diagonal(0, 1, 2)).T / (2 * pi) % 1
405
        if not scaled:
406
            coord_wc = np.dot(coord_wc, self.largeunitcell_cc)
407
        return coord_wc
408

    
409
    def get_radii(self):
410
        """Calculate the spread of the Wannier functions.
411

412
        ::
413
          
414
                        --  /  L  \ 2       2
415
          radius**2 = - >   | --- |   ln |Z| 
416
                        --d \ 2pi /
417
        """
418
        r2 = -np.dot(self.largeunitcell_cc.diagonal()**2 / (2 * pi)**2,
419
                     np.log(abs(self.Z_dww[:3].diagonal(0, 1, 2))**2))
420
        return np.sqrt(r2)
421

    
422
    def get_spectral_weight(self, w):
423
        return abs(self.V_knw[:, :, w])**2 / self.Nk
424

    
425
    def get_pdos(self, w, energies, width):
426
        """Projected density of states (PDOS).
427

428
        Returns the (PDOS) for Wannier function ``w``. The calculation
429
        is performed over the energy grid specified in energies. The
430
        PDOS is produced as a sum of Gaussians centered at the points
431
        of the energy grid and with the specified width.
432
        """
433
        spec_kn = self.get_spectral_weight(w)
434
        dos = np.zeros(len(energies))
435
        for k, spec_n in enumerate(spec_kn):
436
            eig_n = self.calc.get_eigenvalues(k=kpt, s=self.spin)
437
            for weight, eig in zip(spec_n, eig):
438
                # Add gaussian centered at the eigenvalue
439
                x = ((energies - center) / width)**2
440
                dos += weight * np.exp(-x.clip(0., 40.)) / (sqrt(pi) * width)
441
        return dos
442

    
443
    def max_spread(self, directions=[0, 1, 2]):
444
        """Returns the index of the most delocalized Wannier function
445
        together with the value of the spread functional"""
446
        d = np.zeros(self.nwannier)
447
        for dir in directions:
448
            d[dir] = np.abs(self.Z_dww[dir].diagonal())**2 *self.weight_d[dir]
449
        index = np.argsort(d)[0]
450
        print 'Index:', index
451
        print 'Spread:', d[index]           
452

    
453
    def translate(self, w, R):
454
        """Translate the w'th Wannier function
455

456
        The distance vector R = [n1, n2, n3], is in units of the basis
457
        vectors of the small cell.
458
        """
459
        for kpt_c, U_ww in zip(self.kpt_kc, self.U_kww):
460
            U_ww[:, w] *= np.exp(2.j * pi * np.dot(np.array(R), kpt_c))
461
        self.update()
462

    
463
    def translate_to_cell(self, w, cell):
464
        """Translate the w'th Wannier function to specified cell"""
465
        scaled_c = np.angle(self.Z_dww[:3, w, w]) * self.kptgrid / (2 * pi)
466
        trans = np.array(cell) - np.floor(scaled_c)
467
        self.translate(w, trans)
468

    
469
    def translate_all_to_cell(self, cell=[0, 0, 0]):
470
        """Translate all Wannier functions to specified cell.
471

472
        Move all Wannier orbitals to a specific unit cell.  There
473
        exists an arbitrariness in the positions of the Wannier
474
        orbitals relative to the unit cell. This method can move all
475
        orbitals to the unit cell specified by ``cell``.  For a
476
        `\Gamma`-point calculation, this has no effect. For a
477
        **k**-point calculation the periodicity of the orbitals are
478
        given by the large unit cell defined by repeating the original
479
        unitcell by the number of **k**-points in each direction.  In
480
        this case it is usefull to move the orbitals away from the
481
        boundaries of the large cell before plotting them. For a bulk
482
        calculation with, say 10x10x10 **k** points, one could move
483
        the orbitals to the cell [2,2,2].  In this way the pbc
484
        boundary conditions will not be noticed.
485
        """
486
        scaled_wc = np.angle(self.Z_dww[:3].diagonal(0, 1, 2)).T  * \
487
                    self.kptgrid / (2 * pi)
488
        trans_wc =  np.array(cell)[None] - np.floor(scaled_wc)
489
        for kpt_c, U_ww in zip(self.kpt_kc, self.U_kww):
490
            U_ww *= np.exp(2.j * pi * np.dot(trans_wc, kpt_c))
491
        self.update()
492

    
493
    def distances(self, R):
494
        Nw = self.nwannier
495
        cen = self.get_centers()
496
        r1 = cen.repeat(Nw, axis=0).reshape(Nw, Nw, 3)
497
        r2 = cen.copy()
498
        for i in range(3):
499
            r2 += self.unitcell_cc[i] * R[i]
500

    
501
        r2 = np.swapaxes(r2.repeat(Nw, axis=0).reshape(Nw, Nw, 3), 0, 1)
502
        return np.sqrt(np.sum((r1 - r2)**2, axis=-1))
503

    
504
    def get_hopping(self, R):
505
        """Returns the matrix H(R)_nm=<0,n|H|R,m>.
506

507
        ::
508
        
509
                                1   _   -ik.R 
510
          H(R) = <0,n|H|R,m> = --- >_  e      H(k)
511
                                Nk  k         
512

513
        where R is the cell-distance (in units of the basis vectors of
514
        the small cell) and n,m are indices of the Wannier functions.
515
        """
516
        H_ww = np.zeros([self.nwannier, self.nwannier], complex)
517
        for k, kpt_c in enumerate(self.kpt_kc):
518
            phase = np.exp(-2.j * pi * np.dot(np.array(R), kpt_c))
519
            H_ww += self.get_hamiltonian(k) * phase
520
        return H_ww / self.Nk
521

    
522
    def get_hamiltonian(self, k=0):
523
        """Get Hamiltonian at existing k-vector of index k
524

525
        ::
526
        
527
                  dag
528
          H(k) = V    diag(eps )  V
529
                  k           k    k
530
        """
531
        eps_n = self.calc.get_eigenvalues(kpt=k, spin=self.spin)
532
        return np.dot(dag(self.V_knw[k]) * eps_n, self.V_knw[k])
533

    
534
    def get_hamiltonian_kpoint(self, kpt_c):
535
        """Get Hamiltonian at some new arbitrary k-vector
536

537
        ::
538
        
539
                  _   ik.R 
540
          H(k) = >_  e     H(R)
541
                  R         
542

543
        Warning: This method moves all Wannier functions to cell (0, 0, 0)
544
        """
545
        if self.verbose:
546
            print 'Translating all Wannier functions to cell (0, 0, 0)'
547
        self.translate_all_to_cell()
548
        max = (self.kptgrid - 1) / 2
549
        max += max > 0
550
        N1, N2, N3 = max
551
        Hk = np.zeros([self.nwannier, self.nwannier], complex)
552
        for n1 in xrange(-N1, N1 + 1):
553
            for n2 in xrange(-N2, N2 + 1):
554
                for n3 in xrange(-N3, N3 + 1):
555
                    R = np.array([n1, n2, n3], float)
556
                    hop_ww = self.get_hopping(R)
557
                    phase = np.exp(+2.j * pi * np.dot(R, kpt_c))
558
                    Hk += hop_ww * phase
559
        return Hk
560

    
561
    def get_function(self, index, repeat=None):
562
        """Get Wannier function on grid.
563

564
        Returns an array with the funcion values of the indicated Wannier
565
        function on a grid with the size of the *repeated* unit cell.
566
       
567
        For a calculation using **k**-points the relevant unit cell for
568
        eg. visualization of the Wannier orbitals is not the original unit
569
        cell, but rather a larger unit cell defined by repeating the
570
        original unit cell by the number of **k**-points in each direction.
571
        Note that for a `\Gamma`-point calculation the large unit cell
572
        coinsides with the original unit cell.
573
        The large unitcell also defines the periodicity of the Wannier
574
        orbitals.
575

576
        ``index`` can be either a single WF or a coordinate vector in terms
577
        of the WFs.
578
        """
579

    
580
        # Default size of plotting cell is the one corresponding to k-points.
581
        if repeat is None:
582
            repeat = self.kptgrid
583
        N1, N2, N3 = repeat
584

    
585
        dim = self.calc.get_number_of_grid_points()
586
        largedim = dim * [N1, N2, N3]
587
        
588
        wanniergrid = np.zeros(largedim, dtype=complex)
589
        for k, kpt_c in enumerate(self.kpt_kc):
590
            # The coordinate vector of wannier functions
591
            if type(index) == int:
592
                vec_n = self.V_knw[k, :, index]
593
            else:   
594
                vec_n = np.dot(self.V_knw[k], index)
595

    
596
            wan_G = np.zeros(dim, complex)
597
            for n, coeff in enumerate(vec_n):
598
                wan_G += coeff * self.calc.get_pseudo_wave_function(
599
                    n, k, self.spin, pad=True)
600

    
601
            # Distribute the small wavefunction over large cell:
602
            for n1 in xrange(N1):
603
                for n2 in xrange(N2):
604
                    for n3 in xrange(N3): # sign?
605
                        e = np.exp(-2.j * pi * np.dot([n1, n2, n3], kpt_c))
606
                        wanniergrid[n1 * dim[0]:(n1 + 1) * dim[0],
607
                                    n2 * dim[1]:(n2 + 1) * dim[1],
608
                                    n3 * dim[2]:(n3 + 1) * dim[2]] += e * wan_G
609

    
610
        # Normalization
611
        wanniergrid /= np.sqrt(self.Nk)
612
        return wanniergrid
613

    
614
    def write_cube(self, index, fname, repeat=None, real=True):
615
        """Dump specified Wannier function to a cube file"""
616
        from ase.io.cube import write_cube
617

    
618
        # Default size of plotting cell is the one corresponding to k-points.
619
        if repeat is None:
620
            repeat = self.kptgrid
621
        atoms = self.calc.get_atoms() * repeat
622
        func = self.get_function(index, repeat)
623

    
624
        # Handle separation of complex wave into real parts
625
        if real:
626
            if self.Nk == 1:
627
                func *= np.exp(-1.j * np.angle(func.max()))
628
                if 0: assert max(abs(func.imag).flat) < 1e-4
629
                func = func.real
630
            else:
631
                func = abs(func)
632
        else:
633
            phase_fname = fname.split('.')
634
            phase_fname.insert(1, 'phase')
635
            phase_fname = '.'.join(phase_fname)
636
            write_cube(phase_fname, atoms, data=np.angle(func))
637
            func = abs(func)
638

    
639
        write_cube(fname, atoms, data=func)
640

    
641
    def localize(self, step=0.25, tolerance=1e-08,
642
                 updaterot=True, updatecoeff=True):
643
        """Optimize rotation to give maximal localization"""
644
        md_min(self, step, tolerance, verbose=self.verbose,
645
               updaterot=updaterot, updatecoeff=updatecoeff)
646

    
647
    def get_functional_value(self): 
648
        """Calculate the value of the spread functional.
649

650
        ::
651

652
          Tr[|ZI|^2]=sum(I)sum(n) w_i|Z_(i)_nn|^2,
653

654
        where w_i are weights."""
655
        a_d = np.sum(np.abs(self.Z_dww.diagonal(0, 1, 2))**2, axis=1)
656
        return np.dot(a_d, self.weight_d).real
657

    
658
    def get_gradients(self):
659
        # Determine gradient of the spread functional.
660
        # 
661
        # The gradient for a rotation A_kij is::
662
        # 
663
        #    dU = dRho/dA_{k,i,j} = sum(I) sum(k')
664
        #            + Z_jj Z_kk',ij^* - Z_ii Z_k'k,ij^*
665
        #            - Z_ii^* Z_kk',ji + Z_jj^* Z_k'k,ji
666
        # 
667
        # The gradient for a change of coefficients is::
668
        # 
669
        #   dRho/da^*_{k,i,j} = sum(I) [[(Z_0)_{k} V_{k'} diag(Z^*) +
670
        #                                (Z_0_{k''})^d V_{k''} diag(Z)] *
671
        #                                U_k^d]_{N+i,N+j}
672
        # 
673
        # where diag(Z) is a square,diagonal matrix with Z_nn in the diagonal, 
674
        # k' = k + dk and k = k'' + dk.
675
        # 
676
        # The extra degrees of freedom chould be kept orthonormal to the fixed
677
        # space, thus we introduce lagrange multipliers, and minimize instead::
678
        # 
679
        #     Rho_L=Rho- sum_{k,n,m} lambda_{k,nm} <c_{kn}|c_{km}>
680
        # 
681
        # for this reason the coefficient gradients should be multiplied
682
        # by (1 - c c^d).
683
        
684
        Nb = self.nbands
685
        Nw = self.nwannier
686
        
687
        dU = []
688
        dC = []
689
        for k in xrange(self.Nk):
690
            M = self.fixedstates_k[k]
691
            L = self.edf_k[k]
692
            U_ww = self.U_kww[k]
693
            C_ul = self.C_kul[k]
694
            Utemp_ww = np.zeros((Nw, Nw), complex)
695
            Ctemp_nw = np.zeros((Nb, Nw), complex)
696

    
697
            for d, weight in enumerate(self.weight_d):
698
                if abs(weight) < 1.0e-6:
699
                    continue
700

    
701
                Z_knn = self.Z_dknn[d]
702
                diagZ_w = self.Z_dww[d].diagonal()
703
                Zii_ww = np.repeat(diagZ_w, Nw).reshape(Nw, Nw)
704
                k1 = self.kklst_dk[d, k]
705
                k2 = self.invkklst_dk[d, k]
706
                V_knw = self.V_knw
707
                Z_kww = self.Z_dkww[d]
708
                
709
                if L > 0:
710
                    Ctemp_nw += weight * np.dot(
711
                        np.dot(Z_knn[k], V_knw[k1]) * diagZ_w.conj() +
712
                        np.dot(dag(Z_knn[k2]), V_knw[k2]) * diagZ_w,
713
                        dag(U_ww))
714

    
715
                temp = Zii_ww.T * Z_kww[k].conj() - Zii_ww * Z_kww[k2].conj()
716
                Utemp_ww += weight * (temp - dag(temp))
717
            dU.append(Utemp_ww.ravel())
718
            if L > 0:
719
                # Ctemp now has same dimension as V, the gradient is in the
720
                # lower-right (Nb-M) x L block
721
                Ctemp_ul = Ctemp_nw[M:, M:]
722
                G_ul = Ctemp_ul - np.dot(np.dot(C_ul, dag(C_ul)), Ctemp_ul)
723
                dC.append(G_ul.ravel())
724

    
725
        return np.concatenate(dU + dC)
726
                        
727
    def step(self, dX, updaterot=True, updatecoeff=True):
728
        # dX is (A, dC) where U->Uexp(-A) and C->C+dC
729
        Nw = self.nwannier
730
        Nk = self.Nk
731
        M_k = self.fixedstates_k
732
        L_k = self.edf_k
733
        if updaterot:
734
            A_kww = dX[:Nk * Nw**2].reshape(Nk, Nw, Nw)
735
            for U, A in zip(self.U_kww, A_kww):
736
                H = -1.j * A.conj()
737
                epsilon, Z = np.linalg.eigh(H)
738
                # Z contains the eigenvectors as COLUMNS.
739
                # Since H = iA, dU = exp(-A) = exp(iH) = ZDZ^d
740
                dU = np.dot(Z * np.exp(1.j * epsilon), dag(Z))
741
                U[:] = np.dot(U, dU)
742

    
743
        if updatecoeff:
744
            start = 0
745
            for C, unocc, L in zip(self.C_kul, self.nbands - M_k, L_k):
746
                if L == 0 or unocc == 0:
747
                    continue
748
                Ncoeff = L * unocc
749
                deltaC = dX[Nk * Nw**2 + start: Nk * Nw**2 + start + Ncoeff]
750
                C += deltaC.reshape(unocc, L)
751
                gram_schmidt(C)
752
                start += Ncoeff
753

    
754
        self.update()