Statistiques
| Révision :

root / ase / visualize / vtk / data.py @ 1

Historique | Voir | Annoter | Télécharger (7,12 ko)

1

    
2
import numpy as np
3
from numpy.ctypeslib import ctypes
4

    
5
from vtk import vtkDataArray, vtkFloatArray, vtkDoubleArray
6

    
7
if ctypes is None:
8
    class CTypesEmulator:
9
        def __init__(self):
10
            self._SimpleCData = np.number
11
            self.c_float = np.float32
12
            self.c_double = np.float64
13
    try:
14
        import ctypes
15
    except ImportError:
16
        ctypes = CTypesEmulator()
17

    
18
# -------------------------------------------------------------------
19

    
20
class vtkNumPyBuffer:
21
    def __init__(self, data):
22
        self.strbuf = data.tostring()
23
        self.nitems = len(data.flat)
24

    
25
    def __len__(self):
26
        return self.nitems
27

    
28
    def get_pointer(self):
29
        # Any C/C++ method that requires a void * can be passed a Python
30
        # string. No check is done to ensure that the string is the correct
31
        # size, and the string's reference count is not incremented. Extreme
32
        # caution should be applied when using this feature.
33
        return self.strbuf
34

    
35
    def notify(self, obj, event):
36
        if event == 'DeleteEvent':
37
            del self.strbuf
38
        else:
39
            raise RuntimeError('Event not recognized.')
40

    
41
class vtkDataArrayFromNumPyBuffer:
42
    def __init__(self, vtk_class, ctype, data=None):
43

    
44
        assert issubclass(ctype, ctypes._SimpleCData)
45
        self.ctype = ctype
46

    
47
        self.vtk_da = vtk_class()
48
        assert isinstance(self.vtk_da, vtkDataArray)
49
        assert self.vtk_da.GetDataTypeSize() == np.nbytes[np.dtype(self.ctype)]
50

    
51
        if data is not None:
52
            self.read_numpy_array(data)
53

    
54
    def read_numpy_array(self, data):
55

    
56
        if not isinstance(data, np.ndarray):
57
            data = np.array(data, dtype=self.ctype)
58

    
59
        if data.dtype != self.ctype: # NB: "is not" gets it wrong
60
            data = data.astype(self.ctype)
61

    
62
        self.vtk_da.SetNumberOfComponents(data.shape[-1])
63

    
64
        # Passing the void* buffer to the C interface does not increase
65
        # its reference count, hence the buffer is deleted by Python when
66
        # the reference count of the string from tostring reaches zero.
67
        # Also, the boolean True tells VTK to save (not delete) the buffer
68
        # when the VTK data array is deleted - we want Python to do this.
69
        npybuf = vtkNumPyBuffer(data)
70
        self.vtk_da.SetVoidArray(npybuf.get_pointer(), len(npybuf), True)
71
        self.vtk_da.AddObserver('DeleteEvent', npybuf.notify)
72

    
73
    def get_output(self):
74
        return self.vtk_da
75

    
76
    def copy(self):
77
        vtk_da_copy = self.vtk_da.NewInstance()
78
        vtk_da_copy.SetNumberOfComponents(self.vtk_da.GetNumberOfComponents())
79
        vtk_da_copy.SetNumberOfTuples(self.vtk_da.GetNumberOfTuples())
80

    
81
        assert vtk_da_copy.GetSize() == self.vtk_da.GetSize()
82

    
83
        vtk_da_copy.DeepCopy(self.vtk_da)
84

    
85
        return vtk_da_copy
86

    
87
# -------------------------------------------------------------------
88

    
89
class vtkDataArrayFromNumPyArray(vtkDataArrayFromNumPyBuffer):
90
    """Class for reading vtkDataArray from 1D or 2D NumPy array.
91

92
    This class can be used to generate a vtkDataArray from a NumPy array.
93
    The NumPy array should be of the form <entries> x <number of components>
94
    where 'number of components' indicates the number of components in 
95
    each entry in the vtkDataArray. Note that this form is also expected
96
    even in the case of only a single component.
97
    """
98
    def __init__(self, vtk_class, ctype, data=None, buffered=True):
99

    
100
        self.buffered = buffered
101

    
102
        vtkDataArrayFromNumPyBuffer.__init__(self, vtk_class, ctype, data)
103

    
104
    def read_numpy_array(self, data):
105
        """Read vtkDataArray from NumPy array"""
106

    
107
        if not isinstance(data, np.ndarray):
108
            data = np.array(data, dtype=self.ctype)
109

    
110
        if data.dtype != self.ctype: # NB: "is not" gets it wrong
111
            data = data.astype(self.ctype)
112

    
113
        if data.ndim == 1:
114
            data = data[:, np.newaxis]
115
        elif data.ndim != 2:
116
            raise ValueError('Data must be a 1D or 2D NumPy array.')
117

    
118
        if self.buffered:
119
            vtkDataArrayFromNumPyBuffer.read_numpy_array(self, data)
120
        else:
121
            self.vtk_da.SetNumberOfComponents(data.shape[-1])
122
            self.vtk_da.SetNumberOfTuples(data.shape[0])
123

    
124
            for i, d_c in enumerate(data):
125
                for c, d in enumerate(d_c):
126
                    self.vtk_da.SetComponent(i, c, d)
127

    
128
class vtkFloatArrayFromNumPyArray(vtkDataArrayFromNumPyArray):
129
    def __init__(self, data):
130
        vtkDataArrayFromNumPyArray.__init__(self, vtkFloatArray,
131
                                            ctypes.c_float, data)
132

    
133
class vtkDoubleArrayFromNumPyArray(vtkDataArrayFromNumPyArray):
134
    def __init__(self, data):
135
        vtkDataArrayFromNumPyArray.__init__(self, vtkDoubleArray,
136
                                            ctypes.c_double, data)
137

    
138
# -------------------------------------------------------------------
139

    
140
class vtkDataArrayFromNumPyMultiArray(vtkDataArrayFromNumPyBuffer):
141
    """Class for reading vtkDataArray from a multi-dimensional NumPy array.
142

143
    This class can be used to generate a vtkDataArray from a NumPy array.
144
    The NumPy array should be of the form <gridsize> x <number of components>
145
    where 'number of components' indicates the number of components in 
146
    each gridpoint in the vtkDataArray. Note that this form is also expected
147
    even in the case of only a single component.
148
    """
149
    def __init__(self, vtk_class, ctype, data=None, buffered=True):
150

    
151
        self.buffered = buffered
152

    
153
        vtkDataArrayFromNumPyBuffer.__init__(self, vtk_class, ctype, data)
154

    
155
    def read_numpy_array(self, data):
156
        """Read vtkDataArray from NumPy array"""
157

    
158
        if not isinstance(data, np.ndarray):
159
            data = np.array(data, dtype=self.ctype)
160

    
161
        if data.dtype != self.ctype: # NB: "is not" gets it wrong
162
            data = data.astype(self.ctype)
163

    
164
        if data.ndim <=2:
165
            raise Warning('This is inefficient for 1D and 2D NumPy arrays. ' +
166
                          'Use a vtkDataArrayFromNumPyArray subclass instead.')
167

    
168
        if self.buffered:
169
            # This is less than ideal, but will not copy data (uses views).
170
            # To get the correct ordering, the grid dimensions have to be
171
            # transposed without moving the last dimension (the components).
172
            n = data.ndim-1
173
            for c in range(n//2):
174
                data = data.swapaxes(c,n-1-c)
175

    
176
            vtkDataArrayFromNumPyBuffer.read_numpy_array(self, data)
177
        else:
178
            self.vtk_da.SetNumberOfComponents(data.shape[-1])
179
            self.vtk_da.SetNumberOfTuples(np.prod(data.shape[:-1]))
180

    
181
            for c, d_T in enumerate(data.T):
182
                for i, d in enumerate(d_T.flat):
183
                    self.vtk_da.SetComponent(i, c, d)
184

    
185
class vtkFloatArrayFromNumPyMultiArray(vtkDataArrayFromNumPyMultiArray):
186
    def __init__(self, data):
187
        vtkDataArrayFromNumPyMultiArray.__init__(self, vtkFloatArray,
188
                                                 ctypes.c_float, data)
189

    
190
class vtkDoubleArrayFromNumPyMultiArray(vtkDataArrayFromNumPyMultiArray):
191
    def __init__(self, data):
192
        vtkDataArrayFromNumPyMultiArray.__init__(self, vtkDoubleArray,
193
                                                 ctypes.c_double, data)
194