Statistiques
| Révision :

root / ase / optimize / bfgslinesearch.py @ 1

Historique | Voir | Annoter | Télécharger (14,89 ko)

1
#__docformat__ = "restructuredtext en"
2
# ******NOTICE***************
3
# optimize.py module by Travis E. Oliphant
4
#
5
# You may copy and use this module as you see fit with no
6
# guarantee implied provided you keep this notice in all copies.
7
# *****END NOTICE************
8

    
9
import numpy as np
10
from numpy import atleast_1d, eye, mgrid, argmin, zeros, shape, empty, \
11
     squeeze, vectorize, asarray, absolute, sqrt, Inf, asfarray, isinf
12
from ase.utils.linesearch import LineSearch
13
from ase.optimize.optimize import Optimizer
14
from numpy import arange
15

    
16

    
17
# These have been copied from Numeric's MLab.py
18
# I don't think they made the transition to scipy_core
19

    
20
# Modified from scipy_optimize
21
# From here on ``abs`` in this module is NumPy's element-wise ``absolute``
# rather than the builtin.
abs = absolute

# Keep handles on the builtin min/max under explicit names.  The module
# holding the builtins is called ``__builtin__`` on Python 2 and
# ``builtins`` on Python 3, so fall back to the latter when needed.
try:
    import __builtin__
except ImportError:
    import builtins as __builtin__
pymin = __builtin__.min
pymax = __builtin__.max
__version__ = "0.1"
26

    
27
class BFGSLineSearch(Optimizer):
    """BFGS quasi-Newton optimizer that picks its step length by line search.

    The inverse Hessian approximation ``self.H`` is updated with the BFGS
    formula in :meth:`update`; :meth:`step` turns it into a search
    direction and delegates the step length to ``LineSearch._line_search``.
    Energies and gradients are divided by ``alpha`` before use.
    """

    def __init__(self, atoms, restart=None, logfile='-', maxstep=.2,
                 trajectory=None, c1=.23, c2=0.46, alpha=10., stpmax=50.):
        """Minimize a function using the BFGS algorithm.

        Notes:

            Optimize the function, f, whose gradient is given by fprime
            using the quasi-Newton method of Broyden, Fletcher, Goldfarb,
            and Shanno (BFGS).  See Wright and Nocedal, 'Numerical
            Optimization', 1999, pg. 198.

        Parameters:

            atoms :
                Object to optimize; forwarded to ``Optimizer.__init__``.
                The methods of this class call ``get_positions``,
                ``set_positions``, ``get_forces`` and
                ``get_potential_energy`` on ``self.atoms``.
            restart, logfile, trajectory :
                Forwarded unchanged to ``Optimizer.__init__``.
            maxstep, c1, c2, stpmax :
                Forwarded as keyword arguments to
                ``LineSearch._line_search`` on every step.
            alpha : float
                Scaling factor; energies and forces are divided by it
                before they are handed to the BFGS machinery.

        *See Also*:

          scikits.openopt : SciKit which offers a unified syntax to call
                            this and other solvers.

        """
        self.maxstep = maxstep
        self.stpmax = stpmax
        self.alpha = alpha
        self.H = None            # inverse Hessian approximation
        self.c1 = c1
        self.c2 = c2
        self.force_calls = 0     # incremented by fprime()
        self.function_calls = 0  # incremented by func()
        self.r0 = None           # flattened positions of the previous step
        self.g0 = None           # scaled gradient of the previous step
        self.e0 = None
        self.load_restart = False
        self.task = 'START'
        self.rep_count = 0
        self.p = None            # current search direction
        self.alpha_k = None      # step length returned by the line search
        self.no_update = False
        self.replay = False

        # The attributes above are set before the base constructor runs.
        # NOTE(review): presumably because Optimizer.__init__ may call
        # read() when a restart file exists -- confirm.
        Optimizer.__init__(self, atoms, restart, logfile, trajectory)

    def read(self):
        """Restore r0, g0, e0, task and H from ``self.load()``."""
        self.r0, self.g0, self.e0, self.task, self.H = self.load()
        self.load_restart = True

    def reset(self):
        """Forget the Hessian approximation and the previous step."""
        print('reset')
        self.H = None
        self.r0 = None
        self.g0 = None
        self.e0 = None
        self.rep_count = 0

    def step(self, f):
        """Take one BFGS step.

        ``f`` holds the current forces; it is flattened, negated and
        divided by ``self.alpha`` to give the gradient.  The atoms are
        moved by ``alpha_k * p`` where ``p = -H.g`` and ``alpha_k`` comes
        from the line search.

        Raises RuntimeError if the line search returns no step length.
        """
        atoms = self.atoms
        r = atoms.get_positions().reshape(-1)
        g = -f.reshape(-1) / self.alpha
        p0 = self.p
        self.update(r, g, self.r0, self.g0, p0)
        e = atoms.get_potential_energy() / self.alpha

        self.p = -np.dot(self.H, g)
        p_size = np.sqrt((self.p ** 2).sum())
        # Very short search directions are stretched to a minimum length.
        # A direction of exactly zero length is left alone: rescaling it
        # would divide by zero and fill self.p with NaN.
        p_min = np.sqrt(len(atoms) * 1e-10)
        if 0 < p_size <= p_min:
            self.p /= (p_size / p_min)

        ls = LineSearch()
        self.alpha_k, e, self.e0, self.no_update = \
            ls._line_search(self.func, self.fprime, r, self.p, g, e, self.e0,
                            maxstep=self.maxstep, c1=self.c1,
                            c2=self.c2, stpmax=self.stpmax)
        if self.alpha_k is None:
            # Without a step length the positions cannot be updated.
            # NOTE(review): assumes None signals a failed line search --
            # confirm against LineSearch._line_search.
            raise RuntimeError('LineSearch failed!')

        dr = self.alpha_k * self.p
        atoms.set_positions((r + dr).reshape(len(atoms), -1))
        self.r0 = r
        self.g0 = g
        self.dump((self.r0, self.g0, self.e0, self.task, self.H))

    def update(self, r, g, r0, g0, p0):
        """Apply the BFGS update to the inverse Hessian ``self.H``.

        ``r``/``g`` are the current flattened positions and scaled
        gradient, ``r0``/``g0`` those of the previous step and ``p0`` the
        previous search direction.  On the first call ``self.H`` is set to
        the identity.  Outside of replay mode the update is skipped unless
        the last step length was positive and the magnitude of the
        gradient along ``p0`` decreased; it is also skipped when
        ``self.no_update`` is set.
        """
        self.I = eye(len(self.atoms) * 3, dtype=int)
        if self.H is None:
            self.H = eye(3 * len(self.atoms))
            return

        dr = r - r0
        dg = g - g0
        # alpha_k is None until the first line search has run; treat that
        # like a non-positive step.
        step_improved = (self.alpha_k is not None and self.alpha_k > 0 and
                         abs(np.dot(g, p0)) - abs(np.dot(g0, p0)) < 0)
        if not (step_improved or self.replay):
            return
        if self.no_update:
            print('skip update')
            return

        try:
            rhok = 1.0 / (np.dot(dg, dr))
        except ZeroDivisionError:
            # Plain Python floats raise; NumPy scalars yield inf instead.
            rhok = np.inf
        if isinf(rhok):
            rhok = 1000.0
            print("Divide-by-zero encountered: rhok assumed large")

        A1 = self.I - dr[:, np.newaxis] * dg[np.newaxis, :] * rhok
        A2 = self.I - dg[:, np.newaxis] * dr[np.newaxis, :] * rhok
        self.H = (np.dot(A1, np.dot(self.H, A2)) +
                  rhok * dr[:, np.newaxis] * dr[np.newaxis, :])

    def func(self, x):
        """Objective function for use of the optimizers"""
        self.atoms.set_positions(x.reshape(-1, 3))
        self.function_calls += 1
        # Scale the problem as SciPy uses I as initial Hessian.
        return self.atoms.get_potential_energy() / self.alpha

    def fprime(self, x):
        """Gradient of the objective function for use of the optimizers"""
        self.atoms.set_positions(x.reshape(-1, 3))
        self.force_calls += 1
        # Remember that forces are minus the gradient!
        # Scale the problem as SciPy uses I as initial Hessian.
        return - self.atoms.get_forces().reshape(-1) / self.alpha

    def replay_trajectory(self, traj):
        """Initialize hessian from old trajectory.

        ``traj`` is either a trajectory object or a file name, in which
        case it is opened with ``PickleTrajectory``.  Every frame except
        the last one is fed through :meth:`update`.  ``self.replay`` is
        set to True and is not reset afterwards.
        """
        self.replay = True
        if isinstance(traj, str):
            from ase.io.trajectory import PickleTrajectory
            traj = PickleTrajectory(traj, 'r')
        r0 = None
        g0 = None
        for i in range(len(traj) - 1):
            r = traj[i].get_positions().ravel()
            g = - traj[i].get_forces().ravel() / self.alpha
            self.update(r, g, r0, g0, self.p)
            self.p = -np.dot(self.H, g)
            r0 = r.copy()
            g0 = g.copy()
        self.r0 = r0
        self.g0 = g0
193

    
194
def wrap_function(function, args):
    """Wrap ``function`` so that calls to it are counted.

    Returns the pair ``(ncalls, function_wrapper)``.  ``ncalls`` is a
    one-element list whose entry is incremented each time
    ``function_wrapper`` is called, and ``function_wrapper(x)`` evaluates
    ``function(x, *args)``.
    """
    counter = [0]

    def function_wrapper(x):
        # Count first, so the call is recorded even if ``function`` raises.
        counter[0] = counter[0] + 1
        return function(x, *args)

    return counter, function_wrapper
200

    
201
def _cubicmin(a, fa, fpa, b, fb, c, fc):
    """Return the minimizer of a cubic interpolant, or None.

    The cubic

        f(x) = A*(x-a)**3 + B*(x-a)**2 + C*(x-a) + D

    goes through the points (a, fa), (b, fb) and (c, fc) and has
    derivative fpa at a.

    None is returned when no minimizer can be found: the abscissae are
    not distinct, the cubic has no stationary point, the cubic term
    vanishes, or the computed value is not finite (e.g. NaN/inf inputs).
    """
    C = fpa
    db = b - a
    dc = c - a
    if db == 0 or dc == 0 or b == c:
        return None

    # Solve the 2x2 linear system for the cubic and quadratic coefficients.
    denom = (db * dc) ** 2 * (db - dc)
    d1 = np.empty((2, 2))
    d1[0, 0] = dc ** 2
    d1[0, 1] = -db ** 2
    d1[1, 0] = -dc ** 3
    d1[1, 1] = db ** 3
    A, B = np.dot(d1, np.asarray([fb - fa - C * db,
                                  fc - fa - C * dc]).flatten())
    A /= denom
    B /= denom

    radical = B * B - 3 * A * C
    if radical < 0 or A == 0:
        return None
    xmin = a + (-B + np.sqrt(radical)) / (3 * A)
    # Comparisons with NaN are all False, so NaN/inf coefficients slip
    # past the guards above; reject a non-finite result explicitly.
    if not np.isfinite(xmin):
        return None
    return xmin
228

    
229
def _quadmin(a, fa, fpa, b, fb):
    """Return the minimizer of a quadratic interpolant, or None.

    The quadratic

        f(x) = B*(x-a)**2 + C*(x-a) + D

    goes through the points (a, fa) and (b, fb) and has derivative fpa
    at a.

    None is returned when no minimizer can be found: a and b coincide,
    the curvature B is not positive, the arithmetic fails (e.g. the
    squared interval underflows to zero), or the result is not finite.
    """
    D = fa
    C = fpa
    db = b - a * 1.0
    if db == 0:
        return None
    try:
        B = (fb - D - C * db) / (db * db)
        if B <= 0:
            return None
        xmin = a - C / (2.0 * B)
    except ArithmeticError:
        return None
    # A NaN curvature passes the ``B <= 0`` test, so check the result.
    if not np.isfinite(xmin):
        return None
    return xmin
241

    
242
def zoom(a_lo, a_hi, phi_lo, phi_hi, derphi_lo,
         phi, derphi, phi0, derphi0, c1, c2):
    """Narrow the bracket [a_lo, a_hi] down to an acceptable step length.

    ``phi`` and ``derphi`` evaluate the line function and its derivative;
    ``phi0`` and ``derphi0`` are their values at step length zero.  Each
    iteration picks a trial step by cubic interpolation, falling back to
    quadratic interpolation and finally bisection when the candidate is
    missing or too close to (or outside of) the interval ends.

    Returns ``(a_star, val_star, valprime_star)``.  ``valprime_star`` is
    None when the loop stops because the iteration limit was exceeded
    instead of because the curvature condition was met.
    """
    maxiter = 10
    delta1 = 0.2  # cubic interpolant check
    delta2 = 0.1  # quadratic interpolant check
    phi_rec = phi0
    a_rec = 0
    iteration = 0
    while True:
        dalpha = a_hi - a_lo
        if dalpha < 0:
            lower, upper = a_hi, a_lo
        else:
            lower, upper = a_lo, a_hi

        # NOTE(review): when a_hi < a_lo, dalpha (and hence cchk/qchk) is
        # negative, so the margins below widen the accepted interval
        # rather than shrink it -- confirm whether abs(dalpha) was meant.

        # First choice: cubic through lo, hi and the most recently
        # discarded point (only available after the first iteration).
        a_j = None
        if iteration > 0:
            cchk = delta1 * dalpha
            a_j = _cubicmin(a_lo, phi_lo, derphi_lo, a_hi, phi_hi,
                            a_rec, phi_rec)
            if a_j is not None and (a_j > upper - cchk or
                                    a_j < lower + cchk):
                a_j = None

        # Second choice: quadratic through lo and hi; last resort: bisect.
        if a_j is None:
            qchk = delta2 * dalpha
            a_j = _quadmin(a_lo, phi_lo, derphi_lo, a_hi, phi_hi)
            if a_j is None or a_j > upper - qchk or a_j < lower + qchk:
                a_j = a_lo + 0.5 * dalpha

        # Check new value of a_j
        phi_aj = phi(a_j)
        if phi_aj > phi0 + c1 * a_j * derphi0 or phi_aj >= phi_lo:
            # Trial point is not good enough: it becomes the new hi end.
            phi_rec, a_rec = phi_hi, a_hi
            a_hi, phi_hi = a_j, phi_aj
        else:
            derphi_aj = derphi(a_j)
            if abs(derphi_aj) <= -c2 * derphi0:
                return a_j, phi_aj, derphi_aj
            if derphi_aj * (a_hi - a_lo) >= 0:
                phi_rec, a_rec = phi_hi, a_hi
                a_hi, phi_hi = a_lo, phi_lo
            else:
                phi_rec, a_rec = phi_lo, a_lo
            a_lo, phi_lo, derphi_lo = a_j, phi_aj, derphi_aj

        iteration += 1
        if iteration > maxiter:
            return a_j, phi_aj, None
315

    
316
def line_search(f, myfprime, xk, pk, gfk, old_fval, old_old_fval,
                args=(), c1=1e-4, c2=0.9, amax=50):
    """Find alpha that satisfies strong Wolfe conditions.

    Parameters:

        f : callable f(x,*args)
            Objective function.
        myfprime : callable f'(x,*args) or tuple (fprime, eps)
            Objective function gradient.  When a tuple is given, its
            first item is called as ``fprime(x, f, eps, *args)``.
        xk : ndarray
            Starting point.
        pk : ndarray
            Search direction.
        gfk : ndarray
            Gradient value for x=xk (xk being the current parameter
            estimate).
        old_fval : float
            Function value at xk.
        old_old_fval : float or None
            Function value at the point preceding xk, used for the
            initial step guess.  With None the initial guess is 1.0.
        args : tuple
            Additional arguments passed to objective function.
        c1 : float
            Parameter for Armijo condition rule.
        c2 : float
            Parameter for curvature condition rule.
        amax : float
            Accepted for interface compatibility; not used.

    Returns:

        alpha_star : float or None
            Alpha for which ``x_new = x0 + alpha * pk``, or None when
            the initial step guess is exactly zero.
        fc : int
            Number of function evaluations made.
        gc : int
            Number of gradient evaluations made.
        fval_star : float
            Function value at the returned step (``old_fval`` when
            alpha_star is None).
        old_fval : float
            ``old_fval`` as passed in (``old_old_fval`` when alpha_star
            is None).
        fprime_star : ndarray or None
            Most recently evaluated gradient when the search ended on a
            gradient evaluation, otherwise None.

    Notes:

        Uses the line search algorithm to enforce strong Wolfe
        conditions.  See Wright and Nocedal, 'Numerical Optimization',
        1999, pg. 59-60.

        For the zoom phase it uses an algorithm by [...].

    """
    # Evaluation counters and the latest gradient are kept in a local
    # dict so the nested closures can update them without relying on
    # module-level globals.
    state = {'fc': 0, 'gc': 0, 'gfk': None}

    def phi(alpha):
        state['fc'] += 1
        return f(xk + alpha * pk, *args)

    if isinstance(myfprime, tuple):
        fprime = myfprime[0]
        eps = myfprime[1]
        newargs = (f, eps) + args

        def phiprime(alpha):
            # This gradient is counted as len(xk) + 1 function evaluations.
            state['fc'] += len(xk) + 1
            state['gfk'] = fprime(xk + alpha * pk, *newargs)
            return np.dot(state['gfk'], pk)
    else:
        fprime = myfprime

        def phiprime(alpha):
            state['gc'] += 1
            state['gfk'] = fprime(xk + alpha * pk, *args)
            return np.dot(state['gfk'], pk)

    alpha0 = 0
    phi0 = old_fval
    derphi0 = np.dot(gfk, pk)

    # Initial step guess from the previous decrease of the function.
    # Without a previous value, or with a vanishing slope, fall back to
    # a unit step instead of failing on None / dividing by zero.
    if old_old_fval is not None and derphi0 != 0:
        guess = 1.01 * 2 * (phi0 - old_old_fval) / derphi0
        alpha1 = guess if guess < 1.0 else 1.0
    else:
        alpha1 = 1.0
    # The guess is negative when the function value rose in the previous
    # step; searching backwards along pk makes no sense, so use a unit
    # step then.
    if alpha1 < 0:
        alpha1 = 1.0

    if alpha1 == 0:
        # This shouldn't happen. Perhaps the increment has slipped below
        # machine precision?  Return without searching (and without
        # spending a function evaluation on a zero-length step).
        return None, state['fc'], state['gc'], old_fval, old_old_fval, None

    phi_a1 = phi(alpha1)
    phi_a0 = phi0
    derphi_a0 = derphi0

    i = 1
    maxiter = 10
    while True:         # bracketing phase
        if (phi_a1 > phi0 + c1 * alpha1 * derphi0) or \
           ((phi_a1 >= phi_a0) and (i > 1)):
            alpha_star, fval_star, fprime_star = \
                zoom(alpha0, alpha1, phi_a0,
                     phi_a1, derphi_a0, phi, phiprime,
                     phi0, derphi0, c1, c2)
            break

        derphi_a1 = phiprime(alpha1)
        if abs(derphi_a1) <= -c2 * derphi0:
            alpha_star = alpha1
            fval_star = phi_a1
            fprime_star = derphi_a1
            break

        if derphi_a1 >= 0:
            alpha_star, fval_star, fprime_star = \
                zoom(alpha1, alpha0, phi_a1,
                     phi_a0, derphi_a1, phi, phiprime,
                     phi0, derphi0, c1, c2)
            break

        # increase by factor of two on each iteration
        i = i + 1
        alpha0 = alpha1
        alpha1 = 2 * alpha1
        phi_a0 = phi_a1
        phi_a1 = phi(alpha1)
        derphi_a0 = derphi_a1

        # stopping test if lower function not found
        if i > maxiter:
            alpha_star = alpha1
            fval_star = phi_a1
            fprime_star = None
            break

    if fprime_star is not None:
        # fprime_star is a number (derphi) -- so use the most recently
        # calculated gradient used in computing it derphi = gfk*pk
        # this is the gradient at the next step no need to compute it
        # again in the outer loop.
        fprime_star = state['gfk']

    return (alpha_star, state['fc'], state['gc'], fval_star, old_fval,
            fprime_star)
457