Statistiques
| Révision :

root / ase / utils / linesearch.py @ 18

Historique | Voir | Annoter | Télécharger (13,96 ko)

1 1 tkerber
import numpy as np
2 1 tkerber
import __builtin__
3 1 tkerber
pymin = __builtin__.min
4 1 tkerber
pymax = __builtin__.max
5 1 tkerber
6 1 tkerber
class LineSearch:
7 1 tkerber
    def __init__(self,  xtol=1e-14):
8 1 tkerber
9 1 tkerber
        self.xtol = xtol
10 1 tkerber
        self.task = 'START'
11 1 tkerber
        self.isave = np.zeros((2,), np.intc)
12 1 tkerber
        self.dsave = np.zeros((13,), float)
13 1 tkerber
        self.fc = 0
14 1 tkerber
        self.gc = 0
15 1 tkerber
        self.case = 0
16 1 tkerber
        self.old_stp = 0
17 1 tkerber
18 1 tkerber
    def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
19 1 tkerber
                     maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
20 1 tkerber
                     stpmax=50., stpmin=1e-8, args=()):
21 1 tkerber
        self.stpmin = stpmin
22 1 tkerber
        self.pk = pk
23 1 tkerber
        p_size = np.sqrt((pk **2).sum())
24 1 tkerber
        self.stpmax = stpmax
25 1 tkerber
        self.xtrapl = xtrapl
26 1 tkerber
        self.xtrapu = xtrapu
27 1 tkerber
        self.maxstep = maxstep
28 1 tkerber
        phi0 = old_fval
29 1 tkerber
        derphi0 = np.dot(gfk,pk)
30 1 tkerber
        self.dim = len(pk)
31 1 tkerber
        self.gms = np.sqrt(self.dim) * maxstep
32 1 tkerber
        #alpha1 = pymin(maxstep,1.01*2*(phi0-old_old_fval)/derphi0)
33 1 tkerber
        alpha1 = 1.
34 1 tkerber
        self.no_update = False
35 1 tkerber
36 1 tkerber
        if isinstance(myfprime,type(())):
37 1 tkerber
            eps = myfprime[1]
38 1 tkerber
            fprime = myfprime[0]
39 1 tkerber
            newargs = (f,eps) + args
40 1 tkerber
            gradient = False
41 1 tkerber
        else:
42 1 tkerber
            fprime = myfprime
43 1 tkerber
            newargs = args
44 1 tkerber
            gradient = True
45 1 tkerber
46 1 tkerber
        fval = old_fval
47 1 tkerber
        gval = gfk
48 1 tkerber
        self.steps=[]
49 1 tkerber
50 1 tkerber
        while 1:
51 1 tkerber
            stp = self.step(alpha1, phi0, derphi0, c1, c2,
52 1 tkerber
                                             self.xtol,
53 1 tkerber
                                             self.isave, self.dsave)
54 1 tkerber
55 1 tkerber
            if self.task[:2] == 'FG':
56 1 tkerber
                alpha1 = stp
57 1 tkerber
                fval = func(xk + stp * pk, *args)
58 1 tkerber
                self.fc += 1
59 1 tkerber
                gval = fprime(xk + stp * pk, *newargs)
60 1 tkerber
                if gradient: self.gc += 1
61 1 tkerber
                else: self.fc += len(xk) + 1
62 1 tkerber
                phi0 = fval
63 1 tkerber
                derphi0 = np.dot(gval,pk)
64 1 tkerber
                self.old_stp = alpha1
65 1 tkerber
                if self.no_update == True:
66 1 tkerber
                    break
67 1 tkerber
            else:
68 1 tkerber
                break
69 1 tkerber
70 1 tkerber
        if self.task[:5] == 'ERROR' or self.task[1:4] == 'WARN':
71 1 tkerber
            stp = None  # failed
72 1 tkerber
        return stp, fval, old_fval, self.no_update
73 1 tkerber
74 1 tkerber
    def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
75 1 tkerber
        if self.task[:5] == 'START':
76 1 tkerber
            # Check the input arguments for errors.
77 1 tkerber
            if stp < self.stpmin:
78 1 tkerber
                self.task = 'ERROR: STP .LT. minstep'
79 1 tkerber
            if stp > self.stpmax:
80 1 tkerber
                self.task = 'ERROR: STP .GT. maxstep'
81 1 tkerber
            if g >= 0:
82 1 tkerber
                self.task = 'ERROR: INITIAL G >= 0'
83 1 tkerber
            if c1 < 0:
84 1 tkerber
                self.task = 'ERROR: c1 .LT. 0'
85 1 tkerber
            if c2 < 0:
86 1 tkerber
                self.task = 'ERROR: c2 .LT. 0'
87 1 tkerber
            if xtol < 0:
88 1 tkerber
                self.task = 'ERROR: XTOL .LT. 0'
89 1 tkerber
            if self.stpmin < 0:
90 1 tkerber
                self.task = 'ERROR: minstep .LT. 0'
91 1 tkerber
            if self.stpmax < self.stpmin:
92 1 tkerber
                self.task = 'ERROR: maxstep .LT. minstep'
93 1 tkerber
            if self.task[:5] == 'ERROR':
94 1 tkerber
                return stp
95 1 tkerber
96 1 tkerber
            # Initialize local variables.
97 1 tkerber
            self.bracket = False
98 1 tkerber
            stage = 1
99 1 tkerber
            finit = f
100 1 tkerber
            ginit = g
101 1 tkerber
            gtest = c1 * ginit
102 1 tkerber
            width = self.stpmax - self.stpmin
103 1 tkerber
            width1 = width / .5
104 1 tkerber
#           The variables stx, fx, gx contain the values of the step,
105 1 tkerber
#           function, and derivative at the best step.
106 1 tkerber
#           The variables sty, fy, gy contain the values of the step,
107 1 tkerber
#           function, and derivative at sty.
108 1 tkerber
#           The variables stp, f, g contain the values of the step,
109 1 tkerber
#           function, and derivative at stp.
110 1 tkerber
            stx = 0
111 1 tkerber
            fx = finit
112 1 tkerber
            gx = ginit
113 1 tkerber
            sty = 0
114 1 tkerber
            fy = finit
115 1 tkerber
            gy = ginit
116 1 tkerber
            stmin = 0
117 1 tkerber
            stmax = stp + self.xtrapu * stp
118 1 tkerber
            self.task = 'FG'
119 1 tkerber
            self.save((stage, ginit, gtest, gx,
120 1 tkerber
                       gy, finit, fx, fy, stx, sty,
121 1 tkerber
                       stmin, stmax, width, width1))
122 1 tkerber
            stp = self.determine_step(stp)
123 1 tkerber
            #return stp, f, g
124 1 tkerber
            return stp
125 1 tkerber
        else:
126 1 tkerber
            if self.isave[0] == 1:
127 1 tkerber
                self.bracket = True
128 1 tkerber
            else:
129 1 tkerber
                self.bracket = False
130 1 tkerber
            stage = self.isave[1]
131 1 tkerber
            (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax, \
132 1 tkerber
            width, width1) =self.dsave
133 1 tkerber
134 1 tkerber
#           If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the
135 1 tkerber
#           algorithm enters the second stage.
136 1 tkerber
            ftest = finit + stp * gtest
137 1 tkerber
            if stage == 1 and f < ftest and g >= 0.:
138 1 tkerber
                stage = 2
139 1 tkerber
140 1 tkerber
#           Test for warnings.
141 1 tkerber
            if self.bracket and (stp <= stmin or stp >= stmax):
142 1 tkerber
                self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS'
143 1 tkerber
            if self.bracket and stmax - stmin <= self.xtol * stmax:
144 1 tkerber
                self.task = 'WARNING: XTOL TEST SATISFIED'
145 1 tkerber
            if stp == self.stpmax and f <= ftest and g <= gtest:
146 1 tkerber
                self.task = 'WARNING: STP = maxstep'
147 1 tkerber
            if stp == self.stpmin and (f > ftest or g >= gtest):
148 1 tkerber
                self.task = 'WARNING: STP = minstep'
149 1 tkerber
150 1 tkerber
#           Test for convergence.
151 1 tkerber
            if f <= ftest and abs(g) <= c2 * (- ginit):
152 1 tkerber
                self.task = 'CONVERGENCE'
153 1 tkerber
154 1 tkerber
#           Test for termination.
155 1 tkerber
            if self.task[:4] == 'WARN' or self.task[:4] == 'CONV':
156 1 tkerber
                self.save((stage, ginit, gtest, gx,
157 1 tkerber
                           gy, finit, fx, fy, stx, sty,
158 1 tkerber
                           stmin, stmax, width, width1))
159 1 tkerber
                #return stp, f, g
160 1 tkerber
                return stp
161 1 tkerber
162 1 tkerber
#              A modified function is used to predict the step during the
163 1 tkerber
#              first stage if a lower function value has been obtained but
164 1 tkerber
#              the decrease is not sufficient.
165 1 tkerber
            #if stage == 1 and f <= fx and f > ftest:
166 1 tkerber
#           #    Define the modified function and derivative values.
167 1 tkerber
            #    fm =f - stp * gtest
168 1 tkerber
            #    fxm = fx - stx * gtest
169 1 tkerber
            #    fym = fy - sty * gtest
170 1 tkerber
            #    gm = g - gtest
171 1 tkerber
            #    gxm = gx - gtest
172 1 tkerber
            #    gym = gy - gtest
173 1 tkerber
174 1 tkerber
#               Call step to update stx, sty, and to compute the new step.
175 1 tkerber
            #    stx, sty, stp, gxm, fxm, gym, fym = self.update (stx, fxm, gxm, sty,
176 1 tkerber
            #                                        fym, gym, stp, fm, gm,
177 1 tkerber
            #                                        stmin, stmax)
178 1 tkerber
179 1 tkerber
#           #    Reset the function and derivative values for f.
180 1 tkerber
181 1 tkerber
            #    fx = fxm + stx * gtest
182 1 tkerber
            #    fy = fym + sty * gtest
183 1 tkerber
            #    gx = gxm + gtest
184 1 tkerber
            #    gy = gym + gtest
185 1 tkerber
186 1 tkerber
            #else:
187 1 tkerber
#           Call step to update stx, sty, and to compute the new step.
188 1 tkerber
189 1 tkerber
            stx, sty, stp, gx, fx, gy, fy= self.update(stx, fx, gx, sty,
190 1 tkerber
                                               fy, gy, stp, f, g,
191 1 tkerber
                                               stmin, stmax)
192 1 tkerber
193 1 tkerber
194 1 tkerber
#           Decide if a bisection step is needed.
195 1 tkerber
196 1 tkerber
            if self.bracket:
197 1 tkerber
                if abs(sty-stx) >= .66 * width1:
198 1 tkerber
                    stp = stx + .5 * (sty - stx)
199 1 tkerber
                width1 = width
200 1 tkerber
                width = abs(sty - stx)
201 1 tkerber
202 1 tkerber
#           Set the minimum and maximum steps allowed for stp.
203 1 tkerber
204 1 tkerber
            if self.bracket:
205 1 tkerber
                stmin = min(stx, sty)
206 1 tkerber
                stmax = max(stx, sty)
207 1 tkerber
            else:
208 1 tkerber
                stmin = stp + self.xtrapl * (stp - stx)
209 1 tkerber
                stmax = stp + self.xtrapu * (stp - stx)
210 1 tkerber
211 1 tkerber
#           Force the step to be within the bounds maxstep and minstep.
212 1 tkerber
213 1 tkerber
            stp = max(stp, self.stpmin)
214 1 tkerber
            stp = min(stp, self.stpmax)
215 1 tkerber
216 1 tkerber
            if (stx == stp and stp == self.stpmax and stmin > self.stpmax):
217 1 tkerber
                self.no_update = True
218 1 tkerber
#           If further progress is not possible, let stp be the best
219 1 tkerber
#           point obtained during the search.
220 1 tkerber
221 1 tkerber
            if (self.bracket and stp < stmin or stp >= stmax) \
222 1 tkerber
               or (self.bracket and stmax - stmin < self.xtol * stmax):
223 1 tkerber
                stp = stx
224 1 tkerber
225 1 tkerber
#           Obtain another function and derivative.
226 1 tkerber
227 1 tkerber
            self.task = 'FG'
228 1 tkerber
            self.save((stage, ginit, gtest, gx,
229 1 tkerber
                       gy, finit, fx, fy, stx, sty,
230 1 tkerber
                       stmin, stmax, width, width1))
231 1 tkerber
            return stp
232 1 tkerber
233 1 tkerber
    def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp,
234 1 tkerber
               stpmin, stpmax):
235 1 tkerber
        sign = gp * (gx / abs(gx))
236 1 tkerber
237 1 tkerber
#       First case: A higher function value. The minimum is bracketed.
238 1 tkerber
#       If the cubic step is closer to stx than the quadratic step, the
239 1 tkerber
#       cubic step is taken, otherwise the average of the cubic and
240 1 tkerber
#       quadratic steps is taken.
241 1 tkerber
        if fp > fx:  #case1
242 1 tkerber
            self.case = 1
243 1 tkerber
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
244 1 tkerber
            s = max(abs(theta), abs(gx), abs(gp))
245 1 tkerber
            gamma = s * np.sqrt((theta / s) ** 2. - (gx / s) * (gp / s))
246 1 tkerber
            if stp < stx:
247 1 tkerber
                gamma = -gamma
248 1 tkerber
            p = (gamma - gx) + theta
249 1 tkerber
            q = ((gamma - gx) + gamma) + gp
250 1 tkerber
            r = p / q
251 1 tkerber
            stpc = stx + r * (stp - stx)
252 1 tkerber
            stpq = stx + ((gx / ((fx - fp) / (stp-stx) + gx)) / 2.) \
253 1 tkerber
                   * (stp - stx)
254 1 tkerber
            if (abs(stpc - stx) < abs(stpq - stx)):
255 1 tkerber
               stpf = stpc
256 1 tkerber
            else:
257 1 tkerber
               stpf = stpc + (stpq - stpc) / 2.
258 1 tkerber
259 1 tkerber
            self.bracket = True
260 1 tkerber
261 1 tkerber
#       Second case: A lower function value and derivatives of opposite
262 1 tkerber
#       sign. The minimum is bracketed. If the cubic step is farther from
263 1 tkerber
#       stp than the secant step, the cubic step is taken, otherwise the
264 1 tkerber
#       secant step is taken.
265 1 tkerber
266 1 tkerber
        elif sign < 0:  #case2
267 1 tkerber
            self.case = 2
268 1 tkerber
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
269 1 tkerber
            s = max(abs(theta), abs(gx), abs(gp))
270 1 tkerber
            gamma = s * np.sqrt((theta / s) ** 2 - (gx / s) * (gp / s))
271 1 tkerber
            if stp > stx:
272 1 tkerber
                 gamma = -gamma
273 1 tkerber
            p = (gamma - gp) + theta
274 1 tkerber
            q = ((gamma - gp) + gamma) + gx
275 1 tkerber
            r = p / q
276 1 tkerber
            stpc = stp + r * (stx - stp)
277 1 tkerber
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
278 1 tkerber
            if (abs(stpc - stp) > abs(stpq - stp)):
279 1 tkerber
               stpf = stpc
280 1 tkerber
            else:
281 1 tkerber
               stpf = stpq
282 1 tkerber
            self.bracket = True
283 1 tkerber
284 1 tkerber
#       Third case: A lower function value, derivatives of the same sign,
285 1 tkerber
#       and the magnitude of the derivative decreases.
286 1 tkerber
287 1 tkerber
        elif abs(gp) < abs(gx):  #case3
288 1 tkerber
            self.case = 3
289 1 tkerber
#           The cubic step is computed only if the cubic tends to infinity
290 1 tkerber
#           in the direction of the step or if the minimum of the cubic
291 1 tkerber
#           is beyond stp. Otherwise the cubic step is defined to be the
292 1 tkerber
#           secant step.
293 1 tkerber
294 1 tkerber
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
295 1 tkerber
            s = max(abs(theta), abs(gx), abs(gp))
296 1 tkerber
297 1 tkerber
#           The case gamma = 0 only arises if the cubic does not tend
298 1 tkerber
#           to infinity in the direction of the step.
299 1 tkerber
300 1 tkerber
            gamma = s * np.sqrt(max(0.,(theta / s) ** 2-(gx / s) * (gp / s)))
301 1 tkerber
            if stp > stx:
302 1 tkerber
                gamma = -gamma
303 1 tkerber
            p = (gamma - gp) + theta
304 1 tkerber
            q = (gamma + (gx - gp)) + gamma
305 1 tkerber
            r = p / q
306 1 tkerber
            if r < 0. and gamma != 0:
307 1 tkerber
               stpc = stp + r * (stx - stp)
308 1 tkerber
            elif stp > stx:
309 1 tkerber
               stpc = stpmax
310 1 tkerber
            else:
311 1 tkerber
               stpc = stpmin
312 1 tkerber
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
313 1 tkerber
314 1 tkerber
            if self.bracket:
315 1 tkerber
316 1 tkerber
#               A minimizer has been bracketed. If the cubic step is
317 1 tkerber
#               closer to stp than the secant step, the cubic step is
318 1 tkerber
#               taken, otherwise the secant step is taken.
319 1 tkerber
320 1 tkerber
                if abs(stpc - stp) < abs(stpq - stp):
321 1 tkerber
                    stpf = stpc
322 1 tkerber
                else:
323 1 tkerber
                    stpf = stpq
324 1 tkerber
                if stp > stx:
325 1 tkerber
                    stpf = min(stp + .66 * (sty - stp), stpf)
326 1 tkerber
                else:
327 1 tkerber
                    stpf = max(stp + .66 * (sty - stp), stpf)
328 1 tkerber
            else:
329 1 tkerber
330 1 tkerber
#               A minimizer has not been bracketed. If the cubic step is
331 1 tkerber
#               farther from stp than the secant step, the cubic step is
332 1 tkerber
#               taken, otherwise the secant step is taken.
333 1 tkerber
334 1 tkerber
                if abs(stpc - stp) > abs(stpq - stp):
335 1 tkerber
                   stpf = stpc
336 1 tkerber
                else:
337 1 tkerber
                   stpf = stpq
338 1 tkerber
                stpf = min(stpmax, stpf)
339 1 tkerber
                stpf = max(stpmin, stpf)
340 1 tkerber
341 1 tkerber
#       Fourth case: A lower function value, derivatives of the same sign,
342 1 tkerber
#       and the magnitude of the derivative does not decrease. If the
343 1 tkerber
#       minimum is not bracketed, the step is either minstep or maxstep,
344 1 tkerber
#       otherwise the cubic step is taken.
345 1 tkerber
346 1 tkerber
        else:  #case4
347 1 tkerber
            self.case = 4
348 1 tkerber
            if self.bracket:
349 1 tkerber
                theta = 3. * (fp - fy) / (sty - stp) + gy + gp
350 1 tkerber
                s = max(abs(theta), abs(gy), abs(gp))
351 1 tkerber
                gamma = s * np.sqrt((theta / s) ** 2 - (gy / s) * (gp / s))
352 1 tkerber
                if stp > sty:
353 1 tkerber
                    gamma = -gamma
354 1 tkerber
                p = (gamma - gp) + theta
355 1 tkerber
                q = ((gamma - gp) + gamma) + gy
356 1 tkerber
                r = p / q
357 1 tkerber
                stpc = stp + r * (sty - stp)
358 1 tkerber
                stpf = stpc
359 1 tkerber
            elif stp > stx:
360 1 tkerber
                stpf = stpmax
361 1 tkerber
            else:
362 1 tkerber
                stpf = stpmin
363 1 tkerber
364 1 tkerber
#       Update the interval which contains a minimizer.
365 1 tkerber
366 1 tkerber
        if fp > fx:
367 1 tkerber
            sty = stp
368 1 tkerber
            fy = fp
369 1 tkerber
            gy = gp
370 1 tkerber
        else:
371 1 tkerber
            if sign < 0:
372 1 tkerber
                sty = stx
373 1 tkerber
                fy = fx
374 1 tkerber
                gy = gx
375 1 tkerber
            stx = stp
376 1 tkerber
            fx = fp
377 1 tkerber
            gx = gp
378 1 tkerber
#       Compute the new step.
379 1 tkerber
380 1 tkerber
        stp = self.determine_step(stpf)
381 1 tkerber
382 1 tkerber
        return stx, sty, stp, gx, fx, gy, fy
383 1 tkerber
384 1 tkerber
    def determine_step(self, stp):
385 1 tkerber
        dr = stp - self.old_stp
386 1 tkerber
        if abs(pymax(self.pk) * dr) > self.maxstep:
387 1 tkerber
            dr /= abs((pymax(self.pk) * dr) / self.maxstep)
388 1 tkerber
        stp = self.old_stp + dr
389 1 tkerber
        return stp
390 1 tkerber
391 1 tkerber
    def save(self, data):
392 1 tkerber
        if self.bracket:
393 1 tkerber
            self.isave[0] = 1
394 1 tkerber
        else:
395 1 tkerber
            self.isave[0] = 0
396 1 tkerber
        self.isave[1] = data[0]
397 1 tkerber
        self.dsave = data[1:]