Statistiques
| Révision :

root / ase / utils / linesearch.py @ 19

Historique | Voir | Annoter | Télécharger (13,96 ko)

1
import numpy as np
2
import __builtin__
3
pymin = __builtin__.min
4
pymax = __builtin__.max
5

    
6
class LineSearch:
7
    def __init__(self,  xtol=1e-14):
8

    
9
        self.xtol = xtol
10
        self.task = 'START'
11
        self.isave = np.zeros((2,), np.intc)
12
        self.dsave = np.zeros((13,), float)
13
        self.fc = 0
14
        self.gc = 0
15
        self.case = 0
16
        self.old_stp = 0
17

    
18
    def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
19
                     maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
20
                     stpmax=50., stpmin=1e-8, args=()):
21
        self.stpmin = stpmin
22
        self.pk = pk
23
        p_size = np.sqrt((pk **2).sum())
24
        self.stpmax = stpmax
25
        self.xtrapl = xtrapl
26
        self.xtrapu = xtrapu
27
        self.maxstep = maxstep
28
        phi0 = old_fval
29
        derphi0 = np.dot(gfk,pk)
30
        self.dim = len(pk)
31
        self.gms = np.sqrt(self.dim) * maxstep
32
        #alpha1 = pymin(maxstep,1.01*2*(phi0-old_old_fval)/derphi0)
33
        alpha1 = 1.
34
        self.no_update = False
35

    
36
        if isinstance(myfprime,type(())):
37
            eps = myfprime[1]
38
            fprime = myfprime[0]
39
            newargs = (f,eps) + args
40
            gradient = False
41
        else:
42
            fprime = myfprime
43
            newargs = args
44
            gradient = True
45

    
46
        fval = old_fval
47
        gval = gfk
48
        self.steps=[]
49

    
50
        while 1:
51
            stp = self.step(alpha1, phi0, derphi0, c1, c2,
52
                                             self.xtol,
53
                                             self.isave, self.dsave)
54

    
55
            if self.task[:2] == 'FG':
56
                alpha1 = stp
57
                fval = func(xk + stp * pk, *args)
58
                self.fc += 1
59
                gval = fprime(xk + stp * pk, *newargs)
60
                if gradient: self.gc += 1
61
                else: self.fc += len(xk) + 1
62
                phi0 = fval
63
                derphi0 = np.dot(gval,pk)
64
                self.old_stp = alpha1
65
                if self.no_update == True:
66
                    break
67
            else:
68
                break
69

    
70
        if self.task[:5] == 'ERROR' or self.task[1:4] == 'WARN':
71
            stp = None  # failed
72
        return stp, fval, old_fval, self.no_update
73

    
74
    def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
75
        if self.task[:5] == 'START':
76
            # Check the input arguments for errors.
77
            if stp < self.stpmin:
78
                self.task = 'ERROR: STP .LT. minstep'
79
            if stp > self.stpmax:
80
                self.task = 'ERROR: STP .GT. maxstep'
81
            if g >= 0:
82
                self.task = 'ERROR: INITIAL G >= 0'
83
            if c1 < 0:
84
                self.task = 'ERROR: c1 .LT. 0'
85
            if c2 < 0:
86
                self.task = 'ERROR: c2 .LT. 0'
87
            if xtol < 0:
88
                self.task = 'ERROR: XTOL .LT. 0'
89
            if self.stpmin < 0:
90
                self.task = 'ERROR: minstep .LT. 0'
91
            if self.stpmax < self.stpmin:
92
                self.task = 'ERROR: maxstep .LT. minstep'
93
            if self.task[:5] == 'ERROR':
94
                return stp
95

    
96
            # Initialize local variables.
97
            self.bracket = False
98
            stage = 1
99
            finit = f
100
            ginit = g
101
            gtest = c1 * ginit
102
            width = self.stpmax - self.stpmin
103
            width1 = width / .5
104
#           The variables stx, fx, gx contain the values of the step,
105
#           function, and derivative at the best step.
106
#           The variables sty, fy, gy contain the values of the step,
107
#           function, and derivative at sty.
108
#           The variables stp, f, g contain the values of the step,
109
#           function, and derivative at stp.
110
            stx = 0
111
            fx = finit
112
            gx = ginit
113
            sty = 0
114
            fy = finit
115
            gy = ginit
116
            stmin = 0
117
            stmax = stp + self.xtrapu * stp
118
            self.task = 'FG'
119
            self.save((stage, ginit, gtest, gx,
120
                       gy, finit, fx, fy, stx, sty,
121
                       stmin, stmax, width, width1))
122
            stp = self.determine_step(stp)
123
            #return stp, f, g
124
            return stp
125
        else:
126
            if self.isave[0] == 1:
127
                self.bracket = True
128
            else:
129
                self.bracket = False
130
            stage = self.isave[1]
131
            (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax, \
132
            width, width1) =self.dsave
133

    
134
#           If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the
135
#           algorithm enters the second stage.
136
            ftest = finit + stp * gtest
137
            if stage == 1 and f < ftest and g >= 0.:
138
                stage = 2
139

    
140
#           Test for warnings.
141
            if self.bracket and (stp <= stmin or stp >= stmax):
142
                self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS'
143
            if self.bracket and stmax - stmin <= self.xtol * stmax:
144
                self.task = 'WARNING: XTOL TEST SATISFIED'
145
            if stp == self.stpmax and f <= ftest and g <= gtest:
146
                self.task = 'WARNING: STP = maxstep'
147
            if stp == self.stpmin and (f > ftest or g >= gtest):
148
                self.task = 'WARNING: STP = minstep'
149

    
150
#           Test for convergence.
151
            if f <= ftest and abs(g) <= c2 * (- ginit):
152
                self.task = 'CONVERGENCE'
153

    
154
#           Test for termination.
155
            if self.task[:4] == 'WARN' or self.task[:4] == 'CONV':
156
                self.save((stage, ginit, gtest, gx,
157
                           gy, finit, fx, fy, stx, sty,
158
                           stmin, stmax, width, width1))
159
                #return stp, f, g
160
                return stp
161

    
162
#              A modified function is used to predict the step during the
163
#              first stage if a lower function value has been obtained but
164
#              the decrease is not sufficient.
165
            #if stage == 1 and f <= fx and f > ftest:
166
#           #    Define the modified function and derivative values.
167
            #    fm =f - stp * gtest
168
            #    fxm = fx - stx * gtest
169
            #    fym = fy - sty * gtest
170
            #    gm = g - gtest
171
            #    gxm = gx - gtest
172
            #    gym = gy - gtest
173

    
174
#               Call step to update stx, sty, and to compute the new step.
175
            #    stx, sty, stp, gxm, fxm, gym, fym = self.update (stx, fxm, gxm, sty,
176
            #                                        fym, gym, stp, fm, gm,
177
            #                                        stmin, stmax)
178

    
179
#           #    Reset the function and derivative values for f.
180

    
181
            #    fx = fxm + stx * gtest
182
            #    fy = fym + sty * gtest
183
            #    gx = gxm + gtest
184
            #    gy = gym + gtest
185

    
186
            #else:
187
#           Call step to update stx, sty, and to compute the new step.
188

    
189
            stx, sty, stp, gx, fx, gy, fy= self.update(stx, fx, gx, sty,
190
                                               fy, gy, stp, f, g,
191
                                               stmin, stmax)
192

    
193

    
194
#           Decide if a bisection step is needed.
195

    
196
            if self.bracket:
197
                if abs(sty-stx) >= .66 * width1:
198
                    stp = stx + .5 * (sty - stx)
199
                width1 = width
200
                width = abs(sty - stx)
201

    
202
#           Set the minimum and maximum steps allowed for stp.
203

    
204
            if self.bracket:
205
                stmin = min(stx, sty)
206
                stmax = max(stx, sty)
207
            else:
208
                stmin = stp + self.xtrapl * (stp - stx)
209
                stmax = stp + self.xtrapu * (stp - stx)
210

    
211
#           Force the step to be within the bounds maxstep and minstep.
212

    
213
            stp = max(stp, self.stpmin)
214
            stp = min(stp, self.stpmax)
215

    
216
            if (stx == stp and stp == self.stpmax and stmin > self.stpmax):
217
                self.no_update = True
218
#           If further progress is not possible, let stp be the best
219
#           point obtained during the search.
220

    
221
            if (self.bracket and stp < stmin or stp >= stmax) \
222
               or (self.bracket and stmax - stmin < self.xtol * stmax):
223
                stp = stx
224

    
225
#           Obtain another function and derivative.
226

    
227
            self.task = 'FG'
228
            self.save((stage, ginit, gtest, gx,
229
                       gy, finit, fx, fy, stx, sty,
230
                       stmin, stmax, width, width1))
231
            return stp
232

    
233
    def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp,
234
               stpmin, stpmax):
235
        sign = gp * (gx / abs(gx))
236

    
237
#       First case: A higher function value. The minimum is bracketed.
238
#       If the cubic step is closer to stx than the quadratic step, the
239
#       cubic step is taken, otherwise the average of the cubic and
240
#       quadratic steps is taken.
241
        if fp > fx:  #case1
242
            self.case = 1
243
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
244
            s = max(abs(theta), abs(gx), abs(gp))
245
            gamma = s * np.sqrt((theta / s) ** 2. - (gx / s) * (gp / s))
246
            if stp < stx:
247
                gamma = -gamma
248
            p = (gamma - gx) + theta
249
            q = ((gamma - gx) + gamma) + gp
250
            r = p / q
251
            stpc = stx + r * (stp - stx)
252
            stpq = stx + ((gx / ((fx - fp) / (stp-stx) + gx)) / 2.) \
253
                   * (stp - stx)
254
            if (abs(stpc - stx) < abs(stpq - stx)):
255
               stpf = stpc
256
            else:
257
               stpf = stpc + (stpq - stpc) / 2.
258

    
259
            self.bracket = True
260

    
261
#       Second case: A lower function value and derivatives of opposite
262
#       sign. The minimum is bracketed. If the cubic step is farther from
263
#       stp than the secant step, the cubic step is taken, otherwise the
264
#       secant step is taken.
265

    
266
        elif sign < 0:  #case2
267
            self.case = 2
268
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
269
            s = max(abs(theta), abs(gx), abs(gp))
270
            gamma = s * np.sqrt((theta / s) ** 2 - (gx / s) * (gp / s))
271
            if stp > stx:
272
                 gamma = -gamma
273
            p = (gamma - gp) + theta
274
            q = ((gamma - gp) + gamma) + gx
275
            r = p / q
276
            stpc = stp + r * (stx - stp)
277
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
278
            if (abs(stpc - stp) > abs(stpq - stp)):
279
               stpf = stpc
280
            else:
281
               stpf = stpq
282
            self.bracket = True
283

    
284
#       Third case: A lower function value, derivatives of the same sign,
285
#       and the magnitude of the derivative decreases.
286

    
287
        elif abs(gp) < abs(gx):  #case3
288
            self.case = 3
289
#           The cubic step is computed only if the cubic tends to infinity
290
#           in the direction of the step or if the minimum of the cubic
291
#           is beyond stp. Otherwise the cubic step is defined to be the
292
#           secant step.
293

    
294
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
295
            s = max(abs(theta), abs(gx), abs(gp))
296

    
297
#           The case gamma = 0 only arises if the cubic does not tend
298
#           to infinity in the direction of the step.
299

    
300
            gamma = s * np.sqrt(max(0.,(theta / s) ** 2-(gx / s) * (gp / s)))
301
            if stp > stx:
302
                gamma = -gamma
303
            p = (gamma - gp) + theta
304
            q = (gamma + (gx - gp)) + gamma
305
            r = p / q
306
            if r < 0. and gamma != 0:
307
               stpc = stp + r * (stx - stp)
308
            elif stp > stx:
309
               stpc = stpmax
310
            else:
311
               stpc = stpmin
312
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
313

    
314
            if self.bracket:
315

    
316
#               A minimizer has been bracketed. If the cubic step is
317
#               closer to stp than the secant step, the cubic step is
318
#               taken, otherwise the secant step is taken.
319

    
320
                if abs(stpc - stp) < abs(stpq - stp):
321
                    stpf = stpc
322
                else:
323
                    stpf = stpq
324
                if stp > stx:
325
                    stpf = min(stp + .66 * (sty - stp), stpf)
326
                else:
327
                    stpf = max(stp + .66 * (sty - stp), stpf)
328
            else:
329

    
330
#               A minimizer has not been bracketed. If the cubic step is
331
#               farther from stp than the secant step, the cubic step is
332
#               taken, otherwise the secant step is taken.
333

    
334
                if abs(stpc - stp) > abs(stpq - stp):
335
                   stpf = stpc
336
                else:
337
                   stpf = stpq
338
                stpf = min(stpmax, stpf)
339
                stpf = max(stpmin, stpf)
340

    
341
#       Fourth case: A lower function value, derivatives of the same sign,
342
#       and the magnitude of the derivative does not decrease. If the
343
#       minimum is not bracketed, the step is either minstep or maxstep,
344
#       otherwise the cubic step is taken.
345

    
346
        else:  #case4
347
            self.case = 4
348
            if self.bracket:
349
                theta = 3. * (fp - fy) / (sty - stp) + gy + gp
350
                s = max(abs(theta), abs(gy), abs(gp))
351
                gamma = s * np.sqrt((theta / s) ** 2 - (gy / s) * (gp / s))
352
                if stp > sty:
353
                    gamma = -gamma
354
                p = (gamma - gp) + theta
355
                q = ((gamma - gp) + gamma) + gy
356
                r = p / q
357
                stpc = stp + r * (sty - stp)
358
                stpf = stpc
359
            elif stp > stx:
360
                stpf = stpmax
361
            else:
362
                stpf = stpmin
363

    
364
#       Update the interval which contains a minimizer.
365

    
366
        if fp > fx:
367
            sty = stp
368
            fy = fp
369
            gy = gp
370
        else:
371
            if sign < 0:
372
                sty = stx
373
                fy = fx
374
                gy = gx
375
            stx = stp
376
            fx = fp
377
            gx = gp
378
#       Compute the new step.
379

    
380
        stp = self.determine_step(stpf)
381

    
382
        return stx, sty, stp, gx, fx, gy, fy
383

    
384
    def determine_step(self, stp):
385
        dr = stp - self.old_stp
386
        if abs(pymax(self.pk) * dr) > self.maxstep:
387
            dr /= abs((pymax(self.pk) * dr) / self.maxstep)
388
        stp = self.old_stp + dr
389
        return stp
390

    
391
    def save(self, data):
392
        if self.bracket:
393
            self.isave[0] = 1
394
        else:
395
            self.isave[0] = 0
396
        self.isave[1] = data[0]
397
        self.dsave = data[1:]