ase/optimize/fmin_bfgs.py

#__docformat__ = "restructuredtext en"
# ******NOTICE***************
# optimize.py module by Travis E. Oliphant
#
# You may copy and use this module as you see fit with no
# guarantee implied provided you keep this notice in all copies.
# *****END NOTICE************

import numpy
from numpy import atleast_1d, eye, mgrid, argmin, zeros, shape, empty, \
     squeeze, vectorize, asarray, absolute, sqrt, Inf, asfarray, isinf
from ase.utils.linesearch import LineSearch

# These have been copied from Numeric's MLab.py
# I don't think they made the transition to scipy_core

# Copied and modified from scipy_optimize
abs = absolute
import __builtin__
pymin = __builtin__.min
pymax = __builtin__.max
__version__ = "0.7"
_epsilon = sqrt(numpy.finfo(float).eps)

def fmin_bfgs(f, x0, fprime=None, args=(), gtol=1e-5, norm=Inf,
              epsilon=_epsilon, maxiter=None, full_output=0, disp=1,
              retall=0, callback=None, maxstep=0.2):
    """Minimize a function using the BFGS algorithm.

    Parameters:

      f : callable f(x,*args)
          Objective function to be minimized.
      x0 : ndarray
          Initial guess.
      fprime : callable f'(x,*args)
          Gradient of f.
      args : tuple
          Extra arguments passed to f and fprime.
      gtol : float
          Gradient norm must be less than gtol before successful termination.
      norm : float
          Order of norm (Inf is max, -Inf is min).
      epsilon : float or ndarray
          If fprime is approximated, use this value for the step size.
      callback : callable
          An optional user-supplied function to call after each
          iteration.  Called as callback(xk), where xk is the
          current parameter vector.

    Returns: (xopt, {fopt, gopt, Hopt, func_calls, grad_calls, warnflag}, <allvecs>)

        xopt : ndarray
            Parameters which minimize f, i.e. f(xopt) == fopt.
        fopt : float
            Minimum value.
        gopt : ndarray
            Value of the gradient at the minimum, f'(xopt), which should be
            near 0.
        Hopt : ndarray
            Value of 1/f''(xopt), i.e. the inverse Hessian matrix.
        func_calls : int
            Number of function calls made.
        grad_calls : int
            Number of gradient calls made.
        warnflag : integer
            1 : Maximum number of iterations exceeded.
            2 : Gradient and/or function calls not changing.
        allvecs : list
            Results at each iteration.  Only returned if retall is True.

    *Other Parameters*:
        maxiter : int
            Maximum number of iterations to perform.
        full_output : bool
            If True, return fopt, func_calls, grad_calls, and warnflag
            in addition to xopt.
        disp : bool
            Print convergence message if True.
        retall : bool
            Return a list of results at each iteration if True.

    Notes:

        Optimize the function f, whose gradient is given by fprime,
        using the quasi-Newton method of Broyden, Fletcher, Goldfarb,
        and Shanno (BFGS).  See Wright and Nocedal, 'Numerical
        Optimization', 1999, p. 198.

    *See Also*:

      scikits.openopt : SciKit which offers a unified syntax to call
                        this and other solvers.

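    *Example*:

        An illustrative sketch (not part of the original documentation):
        minimize a simple quadratic, letting the gradient be approximated
        by finite differences.

        >>> def f(x):
        ...     return (x[0] - 1.0)**2 + (x[1] + 2.5)**2
        >>> xopt = fmin_bfgs(f, [0.0, 0.0], disp=0)
        >>> # xopt should end up close to [1.0, -2.5]
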
    """
95
    x0 = asarray(x0).squeeze()
96
    if x0.ndim == 0:
97
        x0.shape = (1,)
98
    if maxiter is None:
99
        maxiter = len(x0)*200
100
    func_calls, f = wrap_function(f, args)
101
    if fprime is None:
102
        grad_calls, myfprime = wrap_function(approx_fprime, (f, epsilon))
103
    else:
104
        grad_calls, myfprime = wrap_function(fprime, args)
105
    gfk = myfprime(x0)
106
    k = 0
107
    N = len(x0)
108
    I = numpy.eye(N,dtype=int)
109
    Hk = I
110
    old_fval = f(x0)
111
    old_old_fval = old_fval + 5000
112
    xk = x0
113
    if retall:
114
        allvecs = [x0]
115
    sk = [2*gtol]
116
    warnflag = 0
117
    gnorm = vecnorm(gfk,ord=norm)
118
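    # Main BFGS loop: step along the quasi-Newton direction pk = -Hk*gfk
    # chosen by a Wolfe line search, then update the inverse Hessian Hk.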
    while (gnorm > gtol) and (k < maxiter):
        pk = -numpy.dot(Hk,gfk)
        ls = LineSearch()
        alpha_k, fc, gc, old_fval, old_old_fval, gfkp1 = \
           ls._line_search(f,myfprime,xk,pk,gfk,
                                  old_fval,old_old_fval,maxstep=maxstep)
        if alpha_k is None:  # line search failed; try a different one.
            alpha_k, fc, gc, old_fval, old_old_fval, gfkp1 = \
                     line_search(f,myfprime,xk,pk,gfk,
                                 old_fval,old_old_fval)
            if alpha_k is None:
                # This line search also failed to find a better solution.
                warnflag = 2
                break
        xkp1 = xk + alpha_k * pk
        if retall:
            allvecs.append(xkp1)
        sk = xkp1 - xk
        xk = xkp1
        if gfkp1 is None:
            gfkp1 = myfprime(xkp1)

        yk = gfkp1 - gfk
        gfk = gfkp1
        if callback is not None:
            callback(xk)
        k += 1
        gnorm = vecnorm(gfk,ord=norm)
        if (gnorm <= gtol):
            break

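        # Update the inverse Hessian approximation with the BFGS formula
        # (see Wright and Nocedal, 'Numerical Optimization', 1999):
        #     Hk <- (I - rhok*sk*yk^T) Hk (I - rhok*yk*sk^T) + rhok*sk*sk^T
        # where rhok = 1/(yk^T sk); the guards below keep rhok finite when
        # yk^T sk is numerically zero.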
        try:  # this was handled in Numeric; keep it here for extra safety
            rhok = 1.0 / (numpy.dot(yk,sk))
        except ZeroDivisionError:
            rhok = 1000.0
            print "Divide-by-zero encountered: rhok assumed large"
        if isinf(rhok):  # numpy returns inf on divide-by-zero instead of raising
            rhok = 1000.0
            print "Divide-by-zero encountered: rhok assumed large"
        A1 = I - sk[:,numpy.newaxis] * yk[numpy.newaxis,:] * rhok
        A2 = I - yk[:,numpy.newaxis] * sk[numpy.newaxis,:] * rhok
        Hk = numpy.dot(A1,numpy.dot(Hk,A2)) + rhok * sk[:,numpy.newaxis] \
                 * sk[numpy.newaxis,:]

    if disp or full_output:
        fval = old_fval
    if warnflag == 2:
        if disp:
            print "Warning: Desired error not necessarily achieved " \
                  "due to precision loss"
            print "         Current function value: %f" % fval
            print "         Iterations: %d" % k
            print "         Function evaluations: %d" % func_calls[0]
            print "         Gradient evaluations: %d" % grad_calls[0]

    elif k >= maxiter:
        warnflag = 1
        if disp:
            print "Warning: Maximum number of iterations has been exceeded"
            print "         Current function value: %f" % fval
            print "         Iterations: %d" % k
            print "         Function evaluations: %d" % func_calls[0]
            print "         Gradient evaluations: %d" % grad_calls[0]
    else:
        if disp:
            print "Optimization terminated successfully."
            print "         Current function value: %f" % fval
            print "         Iterations: %d" % k
            print "         Function evaluations: %d" % func_calls[0]
            print "         Gradient evaluations: %d" % grad_calls[0]

    if full_output:
        retlist = xk, fval, gfk, Hk, func_calls[0], grad_calls[0], warnflag
        if retall:
            retlist += (allvecs,)
    else:
        retlist = xk
        if retall:
            retlist = (xk, allvecs)

    return retlist

def vecnorm(x, ord=2):
    if ord == Inf:
        return numpy.amax(abs(x))
    elif ord == -Inf:
        return numpy.amin(abs(x))
    else:
        return numpy.sum(abs(x)**ord,axis=0)**(1.0/ord)

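# Wrap a callable so every evaluation is counted: the count lives in the
# mutable one-element list ncalls, which the enclosed wrapper can update.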
def wrap_function(function, args):
    ncalls = [0]
    def function_wrapper(x):
        ncalls[0] += 1
        return function(x, *args)
    return ncalls, function_wrapper

def _cubicmin(a,fa,fpa,b,fb,c,fc):
    # finds the minimizer for a cubic polynomial that goes through the
    #  points (a,fa), (b,fb), and (c,fc) with derivative at a of fpa.
    #
    # if no minimizer can be found return None
    #
    # f(x) = A*(x-a)^3 + B*(x-a)^2 + C*(x-a) + D

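    # C and D are fixed by the value and slope at a; A and B then follow from
    # requiring f to pass through (b,fb) and (c,fc).  The 2x2 solve below is
    # that linear system written out, and xmin is the root of
    # f'(x) = 3*A*(x-a)^2 + 2*B*(x-a) + C that yields a minimum.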
    C = fpa
    D = fa
    db = b-a
    dc = c-a
    if (db == 0) or (dc == 0) or (b == c): return None
    denom = (db*dc)**2 * (db-dc)
    d1 = empty((2,2))
    d1[0,0] = dc**2
    d1[0,1] = -db**2
    d1[1,0] = -dc**3
    d1[1,1] = db**3
    [A,B] = numpy.dot(d1, asarray([fb-fa-C*db, fc-fa-C*dc]).flatten())
    A /= denom
    B /= denom
    radical = B*B - 3*A*C
    if radical < 0: return None
    if (A == 0): return None
    xmin = a + (-B + sqrt(radical))/(3*A)
    return xmin

def _quadmin(a,fa,fpa,b,fb):
    # finds the minimizer for a quadratic polynomial that goes through
    #  the points (a,fa), (b,fb) with derivative at a of fpa
    # f(x) = B*(x-a)^2 + C*(x-a) + D
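    # Setting f'(x) = 2*B*(x-a) + C = 0 gives xmin = a - C/(2.0*B); B > 0 is
    # required for the stationary point to be a minimum.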
    D = fa
    C = fpa
    db = b - a*1.0
    if (db == 0): return None
    B = (fb - D - C*db)/(db*db)
    if (B <= 0): return None
    xmin = a - C / (2.0*B)
    return xmin

def zoom(a_lo, a_hi, phi_lo, phi_hi, derphi_lo,
         phi, derphi, phi0, derphi0, c1, c2):
    maxiter = 10
    i = 0
    delta1 = 0.2  # cubic interpolant check
    delta2 = 0.1  # quadratic interpolant check
    phi_rec = phi0
    a_rec = 0
    while 1:
        # Interpolate to find a trial step length between a_lo and a_hi.
        # Try cubic interpolation first; if the result is within
        # delta * dalpha of an end point, or outside the interval bounded
        # by a_lo and a_hi, fall back to quadratic interpolation, and if
        # that is still too close to an end point, use bisection.

        dalpha = a_hi - a_lo
        if dalpha < 0: a,b = a_hi,a_lo
        else: a,b = a_lo,a_hi

        # minimizer of the cubic interpolant
        #    (uses phi_lo, derphi_lo, phi_hi, and the most recent value of phi)
        #      if the result is too close to the end points (or out of the interval)
        #         then use quadratic interpolation with phi_lo, derphi_lo and phi_hi
        #      if the result is still too close to the end points (or out of the interval)
        #         then use bisection

        if (i > 0):
            cchk = delta1*dalpha
            a_j = _cubicmin(a_lo, phi_lo, derphi_lo, a_hi, phi_hi, a_rec, phi_rec)
        if (i == 0) or (a_j is None) or (a_j > b-cchk) or (a_j < a+cchk):
            qchk = delta2*dalpha
            a_j = _quadmin(a_lo, phi_lo, derphi_lo, a_hi, phi_hi)
            if (a_j is None) or (a_j > b-qchk) or (a_j < a+qchk):
                a_j = a_lo + 0.5*dalpha
#                print "Using bisection."
#            else: print "Using quadratic."
#        else: print "Using cubic."

        # Check new value of a_j
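        # Sufficient-decrease (Armijo) test first: if a_j fails it, or does
        # not improve on phi_lo, it becomes the new upper end of the
        # interval; otherwise a_j is accepted once the strong curvature
        # condition |phi'(a_j)| <= -c2*phi'(0) also holds.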
        phi_aj = phi(a_j)
        if (phi_aj > phi0 + c1*a_j*derphi0) or (phi_aj >= phi_lo):
            phi_rec = phi_hi
            a_rec = a_hi
            a_hi = a_j
            phi_hi = phi_aj
        else:
            derphi_aj = derphi(a_j)
            if abs(derphi_aj) <= -c2*derphi0:
                a_star = a_j
                val_star = phi_aj
                valprime_star = derphi_aj
                break
            if derphi_aj*(a_hi - a_lo) >= 0:
                phi_rec = phi_hi
                a_rec = a_hi
                a_hi = a_lo
                phi_hi = phi_lo
            else:
                phi_rec = phi_lo
                a_rec = a_lo
            a_lo = a_j
            phi_lo = phi_aj
            derphi_lo = derphi_aj
        i += 1
        if (i > maxiter):
            a_star = a_j
            val_star = phi_aj
            valprime_star = None
            break
    return a_star, val_star, valprime_star

def line_search(f, myfprime, xk, pk, gfk, old_fval, old_old_fval,
                args=(), c1=1e-4, c2=0.9, amax=50):
    """Find alpha that satisfies strong Wolfe conditions.

    Parameters:

        f : callable f(x,*args)
            Objective function.
        myfprime : callable f'(x,*args)
            Objective function gradient (can be None).
        xk : ndarray
            Starting point.
        pk : ndarray
            Search direction.
        gfk : ndarray
            Gradient value for x=xk (xk being the current parameter
            estimate).
        args : tuple
            Additional arguments passed to objective function.
        c1 : float
            Parameter for Armijo condition rule.
        c2 : float
            Parameter for curvature condition rule.

    Returns:

        alpha0 : float
            Alpha for which ``x_new = x0 + alpha * pk``.
        fc : int
            Number of function evaluations made.
        gc : int
            Number of gradient evaluations made.

    Notes:

        Uses the line search algorithm to enforce strong Wolfe
        conditions.  See Wright and Nocedal, 'Numerical Optimization',
        1999, pp. 59-60.

        For the zoom phase it uses an algorithm by [...].

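    *Example*:

        An illustrative sketch (not part of the original documentation):
        search along the steepest-descent direction of a simple quadratic;
        the exact minimizer along pk is alpha = 0.5.  The last argument
        mirrors the ``old_fval + 5000`` initialization used by fmin_bfgs.

        >>> f = lambda x: numpy.dot(x, x)
        >>> fp = lambda x: 2*x
        >>> xk = numpy.array([1.0, 1.0])
        >>> pk = -fp(xk)
        >>> alpha, fc, gc, fval, old_fval, g = line_search(
        ...     f, fp, xk, pk, fp(xk), f(xk), f(xk) + 5000)
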
    """
369

    
370
    global _ls_fc, _ls_gc, _ls_ingfk
371
    _ls_fc = 0
372
    _ls_gc = 0
373
    _ls_ingfk = None
374
    def phi(alpha):
375
        global _ls_fc
376
        _ls_fc += 1
377
        return f(xk+alpha*pk,*args)
378

    
379
    if isinstance(myfprime,type(())):
380
        def phiprime(alpha):
381
            global _ls_fc, _ls_ingfk
382
            _ls_fc += len(xk)+1
383
            eps = myfprime[1]
384
            fprime = myfprime[0]
385
            newargs = (f,eps) + args
386
            _ls_ingfk = fprime(xk+alpha*pk,*newargs)  # store for later use
387
            return numpy.dot(_ls_ingfk,pk)
388
    else:
389
        fprime = myfprime
390
        def phiprime(alpha):
391
            global _ls_gc, _ls_ingfk
392
            _ls_gc += 1
393
            _ls_ingfk = fprime(xk+alpha*pk,*args)  # store for later use
394
            return numpy.dot(_ls_ingfk,pk)
395

    
396
    alpha0 = 0
397
    phi0 = old_fval
398
    derphi0 = numpy.dot(gfk,pk)
399

    
400
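    # Initial trial step: pick alpha1 so the expected decrease matches the
    # decrease obtained in the previous outer iteration,
    # alpha1 ~ 2*(phi0 - old_old_fval)/derphi0, capped at 1.0 (cf. the
    # first-step rule suggested by Wright and Nocedal).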
    alpha1 = pymin(1.0, 1.01*2*(phi0-old_old_fval)/derphi0)

    if alpha1 == 0:
        # This shouldn't happen. Perhaps the increment has slipped below
        # machine precision?  For now, set the return variables, skip the
        # useless while loop, and raise warnflag=2 due to possible imprecision.
        alpha_star = None
        fval_star = old_fval
        old_fval = old_old_fval
        fprime_star = None

    phi_a1 = phi(alpha1)
    #derphi_a1 = phiprime(alpha1)  evaluated below

    phi_a0 = phi0
    derphi_a0 = derphi0

    i = 1
    maxiter = 10
    while 1:         # bracketing phase
        if alpha1 == 0:
            break
        if (phi_a1 > phi0 + c1*alpha1*derphi0) or \
           ((phi_a1 >= phi_a0) and (i > 1)):
            alpha_star, fval_star, fprime_star = \
                        zoom(alpha0, alpha1, phi_a0,
                             phi_a1, derphi_a0, phi, phiprime,
                             phi0, derphi0, c1, c2)
            break

        derphi_a1 = phiprime(alpha1)
        if (abs(derphi_a1) <= -c2*derphi0):
            alpha_star = alpha1
            fval_star = phi_a1
            fprime_star = derphi_a1
            break

        if (derphi_a1 >= 0):
            alpha_star, fval_star, fprime_star = \
                        zoom(alpha1, alpha0, phi_a1,
                             phi_a0, derphi_a1, phi, phiprime,
                             phi0, derphi0, c1, c2)
            break

        alpha2 = 2 * alpha1   # increase by factor of two on each iteration
        i = i + 1
        alpha0 = alpha1
        alpha1 = alpha2
        phi_a0 = phi_a1
        phi_a1 = phi(alpha1)
        derphi_a0 = derphi_a1

        # stopping test: give up if a lower function value was not found
        if (i > maxiter):
            alpha_star = alpha1
            fval_star = phi_a1
            fprime_star = None
            break

    if fprime_star is not None:
        # fprime_star is a number (derphi), so replace it with the most
        # recently computed gradient: this is the gradient at the next step,
        # so there is no need to compute it again in the outer loop.
        fprime_star = _ls_ingfk

    return alpha_star, _ls_fc, _ls_gc, fval_star, old_fval, fprime_star

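# Forward-difference approximation of the gradient:
#     grad[k] ~ (f(xk + epsilon*e_k) - f(xk)) / epsilon
# with e_k the k-th unit vector; costs len(xk) + 1 function evaluations.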
def approx_fprime(xk,f,epsilon,*args):
    f0 = f(*((xk,)+args))
    grad = numpy.zeros((len(xk),), float)
    ei = numpy.zeros((len(xk),), float)
    for k in range(len(xk)):
        ei[k] = epsilon
        grad[k] = (f(*((xk+ei,)+args)) - f0)/epsilon
        ei[k] = 0.0
    return grad