diff options
author | The Android Open Source Project <initial-contribution@android.com> | 2008-10-21 07:00:00 -0700 |
---|---|---|
committer | The Android Open Source Project <initial-contribution@android.com> | 2008-10-21 07:00:00 -0700 |
commit | cf31fe9b4fb650b27e19f5d7ee7297e383660caf (patch) | |
tree | d04ca6a45d579dca5e5469606c48c405aee68f4b /froofle/protobuf/reflection.py | |
download | git-repo-cf31fe9b4fb650b27e19f5d7ee7297e383660caf.tar.gz |
Initial Contributionv1.0
Diffstat (limited to 'froofle/protobuf/reflection.py')
-rw-r--r-- | froofle/protobuf/reflection.py | 1653 |
1 files changed, 1653 insertions, 0 deletions
diff --git a/froofle/protobuf/reflection.py b/froofle/protobuf/reflection.py new file mode 100644 index 00000000..e2abff04 --- /dev/null +++ b/froofle/protobuf/reflection.py | |||
@@ -0,0 +1,1653 @@ | |||
1 | # Protocol Buffers - Google's data interchange format | ||
2 | # Copyright 2008 Google Inc. All rights reserved. | ||
3 | # http://code.google.com/p/protobuf/ | ||
4 | # | ||
5 | # Redistribution and use in source and binary forms, with or without | ||
6 | # modification, are permitted provided that the following conditions are | ||
7 | # met: | ||
8 | # | ||
9 | # * Redistributions of source code must retain the above copyright | ||
10 | # notice, this list of conditions and the following disclaimer. | ||
11 | # * Redistributions in binary form must reproduce the above | ||
12 | # copyright notice, this list of conditions and the following disclaimer | ||
13 | # in the documentation and/or other materials provided with the | ||
14 | # distribution. | ||
15 | # * Neither the name of Google Inc. nor the names of its | ||
16 | # contributors may be used to endorse or promote products derived from | ||
17 | # this software without specific prior written permission. | ||
18 | # | ||
19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | ||
20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | ||
21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | ||
22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | ||
23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | ||
24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | ||
25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | ||
26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | ||
27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
30 | |||
31 | # This code is meant to work on Python 2.4 and above only. | ||
32 | # | ||
33 | # TODO(robinson): Helpers for verbose, common checks like seeing if a | ||
34 | # descriptor's cpp_type is CPPTYPE_MESSAGE. | ||
35 | |||
36 | """Contains a metaclass and helper functions used to create | ||
37 | protocol message classes from Descriptor objects at runtime. | ||
38 | |||
39 | Recall that a metaclass is the "type" of a class. | ||
40 | (A class is to a metaclass what an instance is to a class.) | ||
41 | |||
42 | In this case, we use the GeneratedProtocolMessageType metaclass | ||
43 | to inject all the useful functionality into the classes | ||
44 | output by the protocol compiler at compile-time. | ||
45 | |||
46 | The upshot of all this is that the real implementation | ||
47 | details for ALL pure-Python protocol buffers are *here in | ||
48 | this file*. | ||
49 | """ | ||
50 | |||
51 | __author__ = 'robinson@google.com (Will Robinson)' | ||
52 | |||
53 | import heapq | ||
54 | import threading | ||
55 | import weakref | ||
56 | # We use "as" to avoid name collisions with variables. | ||
57 | from froofle.protobuf.internal import decoder | ||
58 | from froofle.protobuf.internal import encoder | ||
59 | from froofle.protobuf.internal import message_listener as message_listener_mod | ||
60 | from froofle.protobuf.internal import type_checkers | ||
61 | from froofle.protobuf.internal import wire_format | ||
62 | from froofle.protobuf import descriptor as descriptor_mod | ||
63 | from froofle.protobuf import message as message_mod | ||
64 | |||
65 | _FieldDescriptor = descriptor_mod.FieldDescriptor | ||
66 | |||
67 | |||
68 | class GeneratedProtocolMessageType(type): | ||
69 | |||
70 | """Metaclass for protocol message classes created at runtime from Descriptors. | ||
71 | |||
72 | We add implementations for all methods described in the Message class. We | ||
73 | also create properties to allow getting/setting all fields in the protocol | ||
74 | message. Finally, we create slots to prevent users from accidentally | ||
75 | "setting" nonexistent fields in the protocol message, which then wouldn't get | ||
76 | serialized / deserialized properly. | ||
77 | |||
78 | The protocol compiler currently uses this metaclass to create protocol | ||
79 | message classes at runtime. Clients can also manually create their own | ||
80 | classes at runtime, as in this example: | ||
81 | |||
82 | mydescriptor = Descriptor(.....) | ||
83 | class MyProtoClass(Message): | ||
84 | __metaclass__ = GeneratedProtocolMessageType | ||
85 | DESCRIPTOR = mydescriptor | ||
86 | myproto_instance = MyProtoClass() | ||
87 | myproto.foo_field = 23 | ||
88 | ... | ||
89 | """ | ||
90 | |||
91 | # Must be consistent with the protocol-compiler code in | ||
92 | # proto2/compiler/internal/generator.*. | ||
93 | _DESCRIPTOR_KEY = 'DESCRIPTOR' | ||
94 | |||
95 | def __new__(cls, name, bases, dictionary): | ||
96 | """Custom allocation for runtime-generated class types. | ||
97 | |||
98 | We override __new__ because this is apparently the only place | ||
99 | where we can meaningfully set __slots__ on the class we're creating(?). | ||
100 | (The interplay between metaclasses and slots is not very well-documented). | ||
101 | |||
102 | Args: | ||
103 | name: Name of the class (ignored, but required by the | ||
104 | metaclass protocol). | ||
105 | bases: Base classes of the class we're constructing. | ||
106 | (Should be message.Message). We ignore this field, but | ||
107 | it's required by the metaclass protocol | ||
108 | dictionary: The class dictionary of the class we're | ||
109 | constructing. dictionary[_DESCRIPTOR_KEY] must contain | ||
110 | a Descriptor object describing this protocol message | ||
111 | type. | ||
112 | |||
113 | Returns: | ||
114 | Newly-allocated class. | ||
115 | """ | ||
116 | descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] | ||
117 | _AddSlots(descriptor, dictionary) | ||
118 | _AddClassAttributesForNestedExtensions(descriptor, dictionary) | ||
119 | superclass = super(GeneratedProtocolMessageType, cls) | ||
120 | return superclass.__new__(cls, name, bases, dictionary) | ||
121 | |||
122 | def __init__(cls, name, bases, dictionary): | ||
123 | """Here we perform the majority of our work on the class. | ||
124 | We add enum getters, an __init__ method, implementations | ||
125 | of all Message methods, and properties for all fields | ||
126 | in the protocol type. | ||
127 | |||
128 | Args: | ||
129 | name: Name of the class (ignored, but required by the | ||
130 | metaclass protocol). | ||
131 | bases: Base classes of the class we're constructing. | ||
132 | (Should be message.Message). We ignore this field, but | ||
133 | it's required by the metaclass protocol | ||
134 | dictionary: The class dictionary of the class we're | ||
135 | constructing. dictionary[_DESCRIPTOR_KEY] must contain | ||
136 | a Descriptor object describing this protocol message | ||
137 | type. | ||
138 | """ | ||
139 | descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] | ||
140 | # We act as a "friend" class of the descriptor, setting | ||
141 | # its _concrete_class attribute the first time we use a | ||
142 | # given descriptor to initialize a concrete protocol message | ||
143 | # class. | ||
144 | concrete_class_attr_name = '_concrete_class' | ||
145 | if not hasattr(descriptor, concrete_class_attr_name): | ||
146 | setattr(descriptor, concrete_class_attr_name, cls) | ||
147 | cls._known_extensions = [] | ||
148 | _AddEnumValues(descriptor, cls) | ||
149 | _AddInitMethod(descriptor, cls) | ||
150 | _AddPropertiesForFields(descriptor, cls) | ||
151 | _AddStaticMethods(cls) | ||
152 | _AddMessageMethods(descriptor, cls) | ||
153 | _AddPrivateHelperMethods(cls) | ||
154 | superclass = super(GeneratedProtocolMessageType, cls) | ||
155 | superclass.__init__(cls, name, bases, dictionary) | ||
156 | |||
157 | |||
158 | # Stateless helpers for GeneratedProtocolMessageType below. | ||
159 | # Outside clients should not access these directly. | ||
160 | # | ||
161 | # I opted not to make any of these methods on the metaclass, to make it more | ||
162 | # clear that I'm not really using any state there and to keep clients from | ||
163 | # thinking that they have direct access to these construction helpers. | ||
164 | |||
165 | |||
166 | def _PropertyName(proto_field_name): | ||
167 | """Returns the name of the public property attribute which | ||
168 | clients can use to get and (in some cases) set the value | ||
169 | of a protocol message field. | ||
170 | |||
171 | Args: | ||
172 | proto_field_name: The protocol message field name, exactly | ||
173 | as it appears (or would appear) in a .proto file. | ||
174 | """ | ||
175 | # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. | ||
176 | # nnorwitz makes my day by writing: | ||
177 | # """ | ||
178 | # FYI. See the keyword module in the stdlib. This could be as simple as: | ||
179 | # | ||
180 | # if keyword.iskeyword(proto_field_name): | ||
181 | # return proto_field_name + "_" | ||
182 | # return proto_field_name | ||
183 | # """ | ||
184 | return proto_field_name | ||
185 | |||
186 | |||
187 | def _ValueFieldName(proto_field_name): | ||
188 | """Returns the name of the (internal) instance attribute which objects | ||
189 | should use to store the current value for a given protocol message field. | ||
190 | |||
191 | Args: | ||
192 | proto_field_name: The protocol message field name, exactly | ||
193 | as it appears (or would appear) in a .proto file. | ||
194 | """ | ||
195 | return '_value_' + proto_field_name | ||
196 | |||
197 | |||
198 | def _HasFieldName(proto_field_name): | ||
199 | """Returns the name of the (internal) instance attribute which | ||
200 | objects should use to store a boolean telling whether this field | ||
201 | is explicitly set or not. | ||
202 | |||
203 | Args: | ||
204 | proto_field_name: The protocol message field name, exactly | ||
205 | as it appears (or would appear) in a .proto file. | ||
206 | """ | ||
207 | return '_has_' + proto_field_name | ||
208 | |||
209 | |||
210 | def _AddSlots(message_descriptor, dictionary): | ||
211 | """Adds a __slots__ entry to dictionary, containing the names of all valid | ||
212 | attributes for this message type. | ||
213 | |||
214 | Args: | ||
215 | message_descriptor: A Descriptor instance describing this message type. | ||
216 | dictionary: Class dictionary to which we'll add a '__slots__' entry. | ||
217 | """ | ||
218 | field_names = [_ValueFieldName(f.name) for f in message_descriptor.fields] | ||
219 | field_names.extend(_HasFieldName(f.name) for f in message_descriptor.fields | ||
220 | if f.label != _FieldDescriptor.LABEL_REPEATED) | ||
221 | field_names.extend(('Extensions', | ||
222 | '_cached_byte_size', | ||
223 | '_cached_byte_size_dirty', | ||
224 | '_called_transition_to_nonempty', | ||
225 | '_listener', | ||
226 | '_lock', '__weakref__')) | ||
227 | dictionary['__slots__'] = field_names | ||
228 | |||
229 | |||
230 | def _AddClassAttributesForNestedExtensions(descriptor, dictionary): | ||
231 | extension_dict = descriptor.extensions_by_name | ||
232 | for extension_name, extension_field in extension_dict.iteritems(): | ||
233 | assert extension_name not in dictionary | ||
234 | dictionary[extension_name] = extension_field | ||
235 | |||
236 | |||
237 | def _AddEnumValues(descriptor, cls): | ||
238 | """Sets class-level attributes for all enum fields defined in this message. | ||
239 | |||
240 | Args: | ||
241 | descriptor: Descriptor object for this message type. | ||
242 | cls: Class we're constructing for this message type. | ||
243 | """ | ||
244 | for enum_type in descriptor.enum_types: | ||
245 | for enum_value in enum_type.values: | ||
246 | setattr(cls, enum_value.name, enum_value.number) | ||
247 | |||
248 | |||
249 | def _DefaultValueForField(message, field): | ||
250 | """Returns a default value for a field. | ||
251 | |||
252 | Args: | ||
253 | message: Message instance containing this field, or a weakref proxy | ||
254 | of same. | ||
255 | field: FieldDescriptor object for this field. | ||
256 | |||
257 | Returns: A default value for this field. May refer back to |message| | ||
258 | via a weak reference. | ||
259 | """ | ||
260 | # TODO(robinson): Only the repeated fields need a reference to 'message' (so | ||
261 | # that they can set the 'has' bit on the containing Message when someone | ||
262 | # append()s a value). We could special-case this, and avoid an extra | ||
263 | # function call on __init__() and Clear() for non-repeated fields. | ||
264 | |||
265 | # TODO(robinson): Find a better place for the default value assertion in this | ||
266 | # function. No need to repeat them every time the client calls Clear('foo'). | ||
267 | # (We should probably just assert these things once and as early as possible, | ||
268 | # by tightening checking in the descriptor classes.) | ||
269 | if field.label == _FieldDescriptor.LABEL_REPEATED: | ||
270 | if field.default_value != []: | ||
271 | raise ValueError('Repeated field default value not empty list: %s' % ( | ||
272 | field.default_value)) | ||
273 | listener = _Listener(message, None) | ||
274 | if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | ||
275 | # We can't look at _concrete_class yet since it might not have | ||
276 | # been set. (Depends on order in which we initialize the classes). | ||
277 | return _RepeatedCompositeFieldContainer(listener, field.message_type) | ||
278 | else: | ||
279 | return _RepeatedScalarFieldContainer( | ||
280 | listener, type_checkers.GetTypeChecker(field.cpp_type, field.type)) | ||
281 | |||
282 | if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | ||
283 | assert field.default_value is None | ||
284 | |||
285 | return field.default_value | ||
286 | |||
287 | |||
288 | def _AddInitMethod(message_descriptor, cls): | ||
289 | """Adds an __init__ method to cls.""" | ||
290 | fields = message_descriptor.fields | ||
291 | def init(self): | ||
292 | self._cached_byte_size = 0 | ||
293 | self._cached_byte_size_dirty = False | ||
294 | self._listener = message_listener_mod.NullMessageListener() | ||
295 | self._called_transition_to_nonempty = False | ||
296 | # TODO(robinson): We should only create a lock if we really need one | ||
297 | # in this class. | ||
298 | self._lock = threading.Lock() | ||
299 | for field in fields: | ||
300 | default_value = _DefaultValueForField(self, field) | ||
301 | python_field_name = _ValueFieldName(field.name) | ||
302 | setattr(self, python_field_name, default_value) | ||
303 | if field.label != _FieldDescriptor.LABEL_REPEATED: | ||
304 | setattr(self, _HasFieldName(field.name), False) | ||
305 | self.Extensions = _ExtensionDict(self, cls._known_extensions) | ||
306 | |||
307 | init.__module__ = None | ||
308 | init.__doc__ = None | ||
309 | cls.__init__ = init | ||
310 | |||
311 | |||
312 | def _AddPropertiesForFields(descriptor, cls): | ||
313 | """Adds properties for all fields in this protocol message type.""" | ||
314 | for field in descriptor.fields: | ||
315 | _AddPropertiesForField(field, cls) | ||
316 | |||
317 | |||
318 | def _AddPropertiesForField(field, cls): | ||
319 | """Adds a public property for a protocol message field. | ||
320 | Clients can use this property to get and (in the case | ||
321 | of non-repeated scalar fields) directly set the value | ||
322 | of a protocol message field. | ||
323 | |||
324 | Args: | ||
325 | field: A FieldDescriptor for this field. | ||
326 | cls: The class we're constructing. | ||
327 | """ | ||
328 | # Catch it if we add other types that we should | ||
329 | # handle specially here. | ||
330 | assert _FieldDescriptor.MAX_CPPTYPE == 10 | ||
331 | |||
332 | if field.label == _FieldDescriptor.LABEL_REPEATED: | ||
333 | _AddPropertiesForRepeatedField(field, cls) | ||
334 | elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | ||
335 | _AddPropertiesForNonRepeatedCompositeField(field, cls) | ||
336 | else: | ||
337 | _AddPropertiesForNonRepeatedScalarField(field, cls) | ||
338 | |||
339 | |||
340 | def _AddPropertiesForRepeatedField(field, cls): | ||
341 | """Adds a public property for a "repeated" protocol message field. Clients | ||
342 | can use this property to get the value of the field, which will be either a | ||
343 | _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see | ||
344 | below). | ||
345 | |||
346 | Note that when clients add values to these containers, we perform | ||
347 | type-checking in the case of repeated scalar fields, and we also set any | ||
348 | necessary "has" bits as a side-effect. | ||
349 | |||
350 | Args: | ||
351 | field: A FieldDescriptor for this field. | ||
352 | cls: The class we're constructing. | ||
353 | """ | ||
354 | proto_field_name = field.name | ||
355 | python_field_name = _ValueFieldName(proto_field_name) | ||
356 | property_name = _PropertyName(proto_field_name) | ||
357 | |||
358 | def getter(self): | ||
359 | return getattr(self, python_field_name) | ||
360 | getter.__module__ = None | ||
361 | getter.__doc__ = 'Getter for %s.' % proto_field_name | ||
362 | |||
363 | # We define a setter just so we can throw an exception with a more | ||
364 | # helpful error message. | ||
365 | def setter(self, new_value): | ||
366 | raise AttributeError('Assignment not allowed to repeated field ' | ||
367 | '"%s" in protocol message object.' % proto_field_name) | ||
368 | |||
369 | doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | ||
370 | setattr(cls, property_name, property(getter, setter, doc=doc)) | ||
371 | |||
372 | |||
373 | def _AddPropertiesForNonRepeatedScalarField(field, cls): | ||
374 | """Adds a public property for a nonrepeated, scalar protocol message field. | ||
375 | Clients can use this property to get and directly set the value of the field. | ||
376 | Note that when the client sets the value of a field by using this property, | ||
377 | all necessary "has" bits are set as a side-effect, and we also perform | ||
378 | type-checking. | ||
379 | |||
380 | Args: | ||
381 | field: A FieldDescriptor for this field. | ||
382 | cls: The class we're constructing. | ||
383 | """ | ||
384 | proto_field_name = field.name | ||
385 | python_field_name = _ValueFieldName(proto_field_name) | ||
386 | has_field_name = _HasFieldName(proto_field_name) | ||
387 | property_name = _PropertyName(proto_field_name) | ||
388 | type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) | ||
389 | |||
390 | def getter(self): | ||
391 | return getattr(self, python_field_name) | ||
392 | getter.__module__ = None | ||
393 | getter.__doc__ = 'Getter for %s.' % proto_field_name | ||
394 | def setter(self, new_value): | ||
395 | type_checker.CheckValue(new_value) | ||
396 | setattr(self, has_field_name, True) | ||
397 | self._MarkByteSizeDirty() | ||
398 | self._MaybeCallTransitionToNonemptyCallback() | ||
399 | setattr(self, python_field_name, new_value) | ||
400 | setter.__module__ = None | ||
401 | setter.__doc__ = 'Setter for %s.' % proto_field_name | ||
402 | |||
403 | # Add a property to encapsulate the getter/setter. | ||
404 | doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | ||
405 | setattr(cls, property_name, property(getter, setter, doc=doc)) | ||
406 | |||
407 | |||
408 | def _AddPropertiesForNonRepeatedCompositeField(field, cls): | ||
409 | """Adds a public property for a nonrepeated, composite protocol message field. | ||
410 | A composite field is a "group" or "message" field. | ||
411 | |||
412 | Clients can use this property to get the value of the field, but cannot | ||
413 | assign to the property directly. | ||
414 | |||
415 | Args: | ||
416 | field: A FieldDescriptor for this field. | ||
417 | cls: The class we're constructing. | ||
418 | """ | ||
419 | # TODO(robinson): Remove duplication with similar method | ||
420 | # for non-repeated scalars. | ||
421 | proto_field_name = field.name | ||
422 | python_field_name = _ValueFieldName(proto_field_name) | ||
423 | has_field_name = _HasFieldName(proto_field_name) | ||
424 | property_name = _PropertyName(proto_field_name) | ||
425 | message_type = field.message_type | ||
426 | |||
427 | def getter(self): | ||
428 | # TODO(robinson): Appropriately scary note about double-checked locking. | ||
429 | field_value = getattr(self, python_field_name) | ||
430 | if field_value is None: | ||
431 | self._lock.acquire() | ||
432 | try: | ||
433 | field_value = getattr(self, python_field_name) | ||
434 | if field_value is None: | ||
435 | field_class = message_type._concrete_class | ||
436 | field_value = field_class() | ||
437 | field_value._SetListener(_Listener(self, has_field_name)) | ||
438 | setattr(self, python_field_name, field_value) | ||
439 | finally: | ||
440 | self._lock.release() | ||
441 | return field_value | ||
442 | getter.__module__ = None | ||
443 | getter.__doc__ = 'Getter for %s.' % proto_field_name | ||
444 | |||
445 | # We define a setter just so we can throw an exception with a more | ||
446 | # helpful error message. | ||
447 | def setter(self, new_value): | ||
448 | raise AttributeError('Assignment not allowed to composite field ' | ||
449 | '"%s" in protocol message object.' % proto_field_name) | ||
450 | |||
451 | # Add a property to encapsulate the getter. | ||
452 | doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name | ||
453 | setattr(cls, property_name, property(getter, setter, doc=doc)) | ||
454 | |||
455 | |||
456 | def _AddStaticMethods(cls): | ||
457 | # TODO(robinson): This probably needs to be thread-safe(?) | ||
458 | def RegisterExtension(extension_handle): | ||
459 | extension_handle.containing_type = cls.DESCRIPTOR | ||
460 | cls._known_extensions.append(extension_handle) | ||
461 | cls.RegisterExtension = staticmethod(RegisterExtension) | ||
462 | |||
463 | |||
464 | def _AddListFieldsMethod(message_descriptor, cls): | ||
465 | """Helper for _AddMessageMethods().""" | ||
466 | |||
467 | # Ensure that we always list in ascending field-number order. | ||
468 | # For non-extension fields, we can do the sort once, here, at import-time. | ||
469 | # For extensions, we sort on each ListFields() call, though | ||
470 | # we could do better if we have to. | ||
471 | fields = sorted(message_descriptor.fields, key=lambda f: f.number) | ||
472 | has_field_names = (_HasFieldName(f.name) for f in fields) | ||
473 | value_field_names = (_ValueFieldName(f.name) for f in fields) | ||
474 | triplets = zip(has_field_names, value_field_names, fields) | ||
475 | |||
476 | def ListFields(self): | ||
477 | # We need to list all extension and non-extension fields | ||
478 | # together, in sorted order by field number. | ||
479 | |||
480 | # Step 0: Get an iterator over all "set" non-extension fields, | ||
481 | # sorted by field number. | ||
482 | # This iterator yields (field_number, field_descriptor, value) tuples. | ||
483 | def SortedSetFieldsIter(): | ||
484 | # Note that triplets is already sorted by field number. | ||
485 | for has_field_name, value_field_name, field_descriptor in triplets: | ||
486 | if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: | ||
487 | value = getattr(self, _ValueFieldName(field_descriptor.name)) | ||
488 | if len(value) > 0: | ||
489 | yield (field_descriptor.number, field_descriptor, value) | ||
490 | elif getattr(self, _HasFieldName(field_descriptor.name)): | ||
491 | value = getattr(self, _ValueFieldName(field_descriptor.name)) | ||
492 | yield (field_descriptor.number, field_descriptor, value) | ||
493 | sorted_fields = SortedSetFieldsIter() | ||
494 | |||
495 | # Step 1: Get an iterator over all "set" extension fields, | ||
496 | # sorted by field number. | ||
497 | # This iterator ALSO yields (field_number, field_descriptor, value) tuples. | ||
498 | # TODO(robinson): It's not necessary to repeat this with each | ||
499 | # serialization call. We can do better. | ||
500 | sorted_extension_fields = sorted( | ||
501 | [(f.number, f, v) for f, v in self.Extensions._ListSetExtensions()]) | ||
502 | |||
503 | # Step 2: Create a composite iterator that merges the extension- | ||
504 | # and non-extension fields, and that still yields fields in | ||
505 | # sorted order. | ||
506 | all_set_fields = _ImergeSorted(sorted_fields, sorted_extension_fields) | ||
507 | |||
508 | # Step 3: Strip off the field numbers and return. | ||
509 | return [field[1:] for field in all_set_fields] | ||
510 | |||
511 | cls.ListFields = ListFields | ||
512 | |||
513 | def _AddHasFieldMethod(cls): | ||
514 | """Helper for _AddMessageMethods().""" | ||
515 | def HasField(self, field_name): | ||
516 | try: | ||
517 | return getattr(self, _HasFieldName(field_name)) | ||
518 | except AttributeError: | ||
519 | raise ValueError('Protocol message has no "%s" field.' % field_name) | ||
520 | cls.HasField = HasField | ||
521 | |||
522 | |||
523 | def _AddClearFieldMethod(cls): | ||
524 | """Helper for _AddMessageMethods().""" | ||
525 | def ClearField(self, field_name): | ||
526 | try: | ||
527 | field = self.DESCRIPTOR.fields_by_name[field_name] | ||
528 | except KeyError: | ||
529 | raise ValueError('Protocol message has no "%s" field.' % field_name) | ||
530 | proto_field_name = field.name | ||
531 | python_field_name = _ValueFieldName(proto_field_name) | ||
532 | has_field_name = _HasFieldName(proto_field_name) | ||
533 | default_value = _DefaultValueForField(self, field) | ||
534 | if field.label == _FieldDescriptor.LABEL_REPEATED: | ||
535 | self._MarkByteSizeDirty() | ||
536 | else: | ||
537 | if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | ||
538 | old_field_value = getattr(self, python_field_name) | ||
539 | if old_field_value is not None: | ||
540 | # Snip the old object out of the object tree. | ||
541 | old_field_value._SetListener(None) | ||
542 | if getattr(self, has_field_name): | ||
543 | setattr(self, has_field_name, False) | ||
544 | # Set dirty bit on ourself and parents only if | ||
545 | # we're actually changing state. | ||
546 | self._MarkByteSizeDirty() | ||
547 | setattr(self, python_field_name, default_value) | ||
548 | cls.ClearField = ClearField | ||
549 | |||
550 | |||
551 | def _AddClearExtensionMethod(cls): | ||
552 | """Helper for _AddMessageMethods().""" | ||
553 | def ClearExtension(self, extension_handle): | ||
554 | self.Extensions._ClearExtension(extension_handle) | ||
555 | cls.ClearExtension = ClearExtension | ||
556 | |||
557 | |||
558 | def _AddClearMethod(cls): | ||
559 | """Helper for _AddMessageMethods().""" | ||
560 | def Clear(self): | ||
561 | # Clear fields. | ||
562 | fields = self.DESCRIPTOR.fields | ||
563 | for field in fields: | ||
564 | self.ClearField(field.name) | ||
565 | # Clear extensions. | ||
566 | extensions = self.Extensions._ListSetExtensions() | ||
567 | for extension in extensions: | ||
568 | self.ClearExtension(extension[0]) | ||
569 | cls.Clear = Clear | ||
570 | |||
571 | |||
572 | def _AddHasExtensionMethod(cls): | ||
573 | """Helper for _AddMessageMethods().""" | ||
574 | def HasExtension(self, extension_handle): | ||
575 | return self.Extensions._HasExtension(extension_handle) | ||
576 | cls.HasExtension = HasExtension | ||
577 | |||
578 | |||
579 | def _AddEqualsMethod(message_descriptor, cls): | ||
580 | """Helper for _AddMessageMethods().""" | ||
581 | def __eq__(self, other): | ||
582 | if self is other: | ||
583 | return True | ||
584 | |||
585 | # Compare all fields contained directly in this message. | ||
586 | for field_descriptor in message_descriptor.fields: | ||
587 | label = field_descriptor.label | ||
588 | property_name = _PropertyName(field_descriptor.name) | ||
589 | # Non-repeated field equality requires matching "has" bits as well | ||
590 | # as having an equal value. | ||
591 | if label != _FieldDescriptor.LABEL_REPEATED: | ||
592 | self_has = self.HasField(property_name) | ||
593 | other_has = other.HasField(property_name) | ||
594 | if self_has != other_has: | ||
595 | return False | ||
596 | if not self_has: | ||
597 | # If the "has" bit for this field is False, we must stop here. | ||
598 | # Otherwise we will recurse forever on recursively-defined protos. | ||
599 | continue | ||
600 | if getattr(self, property_name) != getattr(other, property_name): | ||
601 | return False | ||
602 | |||
603 | # Compare the extensions present in both messages. | ||
604 | return self.Extensions == other.Extensions | ||
605 | cls.__eq__ = __eq__ | ||
606 | |||
607 | |||
608 | def _AddSetListenerMethod(cls): | ||
609 | """Helper for _AddMessageMethods().""" | ||
610 | def SetListener(self, listener): | ||
611 | if listener is None: | ||
612 | self._listener = message_listener_mod.NullMessageListener() | ||
613 | else: | ||
614 | self._listener = listener | ||
615 | cls._SetListener = SetListener | ||
616 | |||
617 | |||
618 | def _BytesForNonRepeatedElement(value, field_number, field_type): | ||
619 | """Returns the number of bytes needed to serialize a non-repeated element. | ||
620 | The returned byte count includes space for tag information and any | ||
621 | other additional space associated with serializing value. | ||
622 | |||
623 | Args: | ||
624 | value: Value we're serializing. | ||
625 | field_number: Field number of this value. (Since the field number | ||
626 | is stored as part of a varint-encoded tag, this has an impact | ||
627 | on the total bytes required to serialize the value). | ||
628 | field_type: The type of the field. One of the TYPE_* constants | ||
629 | within FieldDescriptor. | ||
630 | """ | ||
631 | try: | ||
632 | fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] | ||
633 | return fn(field_number, value) | ||
634 | except KeyError: | ||
635 | raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) | ||
636 | |||
637 | |||
638 | def _AddByteSizeMethod(message_descriptor, cls): | ||
639 | """Helper for _AddMessageMethods().""" | ||
640 | |||
641 | def BytesForField(message, field, value): | ||
642 | """Returns the number of bytes required to serialize a single field | ||
643 | in message. The field may be repeated or not, composite or not. | ||
644 | |||
645 | Args: | ||
646 | message: The Message instance containing a field of the given type. | ||
647 | field: A FieldDescriptor describing the field of interest. | ||
648 | value: The value whose byte size we're interested in. | ||
649 | |||
650 | Returns: The number of bytes required to serialize the current value | ||
651 | of "field" in "message", including space for tags and any other | ||
652 | necessary information. | ||
653 | """ | ||
654 | |||
655 | if _MessageSetField(field): | ||
656 | return wire_format.MessageSetItemByteSize(field.number, value) | ||
657 | |||
658 | field_number, field_type = field.number, field.type | ||
659 | |||
660 | # Repeated fields. | ||
661 | if field.label == _FieldDescriptor.LABEL_REPEATED: | ||
662 | elements = value | ||
663 | else: | ||
664 | elements = [value] | ||
665 | |||
666 | size = sum(_BytesForNonRepeatedElement(element, field_number, field_type) | ||
667 | for element in elements) | ||
668 | return size | ||
669 | |||
670 | fields = message_descriptor.fields | ||
671 | has_field_names = (_HasFieldName(f.name) for f in fields) | ||
672 | zipped = zip(has_field_names, fields) | ||
673 | |||
674 | def ByteSize(self): | ||
675 | if not self._cached_byte_size_dirty: | ||
676 | return self._cached_byte_size | ||
677 | |||
678 | size = 0 | ||
679 | # Hardcoded fields first. | ||
680 | for has_field_name, field in zipped: | ||
681 | if (field.label == _FieldDescriptor.LABEL_REPEATED | ||
682 | or getattr(self, has_field_name)): | ||
683 | value = getattr(self, _ValueFieldName(field.name)) | ||
684 | size += BytesForField(self, field, value) | ||
685 | # Extensions next. | ||
686 | for field, value in self.Extensions._ListSetExtensions(): | ||
687 | size += BytesForField(self, field, value) | ||
688 | |||
689 | self._cached_byte_size = size | ||
690 | self._cached_byte_size_dirty = False | ||
691 | return size | ||
692 | cls.ByteSize = ByteSize | ||
693 | |||
694 | |||
695 | def _MessageSetField(field_descriptor): | ||
696 | """Checks if a field should be serialized using the message set wire format. | ||
697 | |||
698 | Args: | ||
699 | field_descriptor: Descriptor of the field. | ||
700 | |||
701 | Returns: | ||
702 | True if the field should be serialized using the message set wire format, | ||
703 | false otherwise. | ||
704 | """ | ||
705 | return (field_descriptor.is_extension and | ||
706 | field_descriptor.label != _FieldDescriptor.LABEL_REPEATED and | ||
707 | field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and | ||
708 | field_descriptor.containing_type.GetOptions().message_set_wire_format) | ||
709 | |||
710 | |||
711 | def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder): | ||
712 | """Appends the serialization of a single value to encoder. | ||
713 | |||
714 | Args: | ||
715 | value: Value to serialize. | ||
716 | field_number: Field number of this value. | ||
717 | field_descriptor: Descriptor of the field to serialize. | ||
718 | encoder: encoder.Encoder object to which we should serialize this value. | ||
719 | """ | ||
720 | if _MessageSetField(field_descriptor): | ||
721 | encoder.AppendMessageSetItem(field_number, value) | ||
722 | return | ||
723 | |||
724 | try: | ||
725 | method = type_checkers.TYPE_TO_SERIALIZE_METHOD[field_descriptor.type] | ||
726 | method(encoder, field_number, value) | ||
727 | except KeyError: | ||
728 | raise message_mod.EncodeError('Unrecognized field type: %d' % | ||
729 | field_descriptor.type) | ||
730 | |||
731 | |||
732 | def _ImergeSorted(*streams): | ||
733 | """Merges N sorted iterators into a single sorted iterator. | ||
734 | Each element in streams must be an iterable that yields | ||
735 | its elements in sorted order, and the elements contained | ||
736 | in each stream must all be comparable. | ||
737 | |||
738 | There may be repeated elements in the component streams or | ||
739 | across the streams; the repeated elements will all be repeated | ||
740 | in the merged iterator as well. | ||
741 | |||
742 | I believe that the heapq module at HEAD in the Python | ||
743 | sources has a method like this, but for now we roll our own. | ||
744 | """ | ||
745 | iters = [iter(stream) for stream in streams] | ||
746 | heap = [] | ||
747 | for index, it in enumerate(iters): | ||
748 | try: | ||
749 | heap.append((it.next(), index)) | ||
750 | except StopIteration: | ||
751 | pass | ||
752 | heapq.heapify(heap) | ||
753 | |||
754 | while heap: | ||
755 | smallest_value, idx = heap[0] | ||
756 | yield smallest_value | ||
757 | try: | ||
758 | next_element = iters[idx].next() | ||
759 | heapq.heapreplace(heap, (next_element, idx)) | ||
760 | except StopIteration: | ||
761 | heapq.heappop(heap) | ||
762 | |||
763 | |||
764 | def _AddSerializeToStringMethod(message_descriptor, cls): | ||
765 | """Helper for _AddMessageMethods().""" | ||
766 | |||
767 | def SerializeToString(self): | ||
768 | # Check if the message has all of its required fields set. | ||
769 | errors = [] | ||
770 | if not _InternalIsInitialized(self, errors): | ||
771 | raise message_mod.EncodeError('\n'.join(errors)) | ||
772 | return self.SerializePartialToString() | ||
773 | cls.SerializeToString = SerializeToString | ||
774 | |||
775 | |||
776 | def _AddSerializePartialToStringMethod(message_descriptor, cls): | ||
777 | """Helper for _AddMessageMethods().""" | ||
778 | Encoder = encoder.Encoder | ||
779 | |||
780 | def SerializePartialToString(self): | ||
781 | encoder = Encoder() | ||
782 | # We need to serialize all extension and non-extension fields | ||
783 | # together, in sorted order by field number. | ||
784 | for field_descriptor, field_value in self.ListFields(): | ||
785 | if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: | ||
786 | repeated_value = field_value | ||
787 | else: | ||
788 | repeated_value = [field_value] | ||
789 | for element in repeated_value: | ||
790 | _SerializeValueToEncoder(element, field_descriptor.number, | ||
791 | field_descriptor, encoder) | ||
792 | return encoder.ToString() | ||
793 | cls.SerializePartialToString = SerializePartialToString | ||
794 | |||
795 | |||
796 | def _WireTypeForFieldType(field_type): | ||
797 | """Given a field type, returns the expected wire type.""" | ||
798 | try: | ||
799 | return type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type] | ||
800 | except KeyError: | ||
801 | raise message_mod.DecodeError('Unknown field type: %d' % field_type) | ||
802 | |||
803 | |||
804 | def _RecursivelyMerge(field_number, field_type, decoder, message): | ||
805 | """Decodes a message from decoder into message. | ||
806 | message is either a group or a nested message within some containing | ||
807 | protocol message. If it's a group, we use the group protocol to | ||
808 | deserialize, and if it's a nested message, we use the nested-message | ||
809 | protocol. | ||
810 | |||
811 | Args: | ||
812 | field_number: The field number of message in its enclosing protocol buffer. | ||
813 | field_type: The field type of message. Must be either TYPE_MESSAGE | ||
814 | or TYPE_GROUP. | ||
815 | decoder: Decoder to read from. | ||
816 | message: Message to deserialize into. | ||
817 | """ | ||
818 | if field_type == _FieldDescriptor.TYPE_MESSAGE: | ||
819 | decoder.ReadMessageInto(message) | ||
820 | elif field_type == _FieldDescriptor.TYPE_GROUP: | ||
821 | decoder.ReadGroupInto(field_number, message) | ||
822 | else: | ||
823 | raise message_mod.DecodeError('Unexpected field type: %d' % field_type) | ||
824 | |||
825 | |||
826 | def _DeserializeScalarFromDecoder(field_type, decoder): | ||
827 | """Deserializes a scalar of the requested type from decoder. field_type must | ||
828 | be a scalar (non-group, non-message) FieldDescriptor.FIELD_* constant. | ||
829 | """ | ||
830 | try: | ||
831 | method = type_checkers.TYPE_TO_DESERIALIZE_METHOD[field_type] | ||
832 | return method(decoder) | ||
833 | except KeyError: | ||
834 | raise message_mod.DecodeError('Unrecognized field type: %d' % field_type) | ||
835 | |||
836 | |||
837 | def _SkipField(field_number, wire_type, decoder): | ||
838 | """Skips a field with the specified wire type. | ||
839 | |||
840 | Args: | ||
841 | field_number: Tag number of the field to skip. | ||
842 | wire_type: Wire type of the field to skip. | ||
843 | decoder: Decoder used to deserialize the messsage. It must be positioned | ||
844 | just after reading the the tag and wire type of the field. | ||
845 | """ | ||
846 | if wire_type == wire_format.WIRETYPE_VARINT: | ||
847 | decoder.ReadUInt64() | ||
848 | elif wire_type == wire_format.WIRETYPE_FIXED64: | ||
849 | decoder.ReadFixed64() | ||
850 | elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: | ||
851 | decoder.SkipBytes(decoder.ReadInt32()) | ||
852 | elif wire_type == wire_format.WIRETYPE_START_GROUP: | ||
853 | _SkipGroup(field_number, decoder) | ||
854 | elif wire_type == wire_format.WIRETYPE_END_GROUP: | ||
855 | pass | ||
856 | elif wire_type == wire_format.WIRETYPE_FIXED32: | ||
857 | decoder.ReadFixed32() | ||
858 | else: | ||
859 | raise message_mod.DecodeError('Unexpected wire type: %d' % wire_type) | ||
860 | |||
861 | |||
862 | def _SkipGroup(group_number, decoder): | ||
863 | """Skips a nested group from the decoder. | ||
864 | |||
865 | Args: | ||
866 | group_number: Tag number of the group to skip. | ||
867 | decoder: Decoder used to deserialize the message. It must be positioned | ||
868 | exactly at the beginning of the message that should be skipped. | ||
869 | """ | ||
870 | while True: | ||
871 | field_number, wire_type = decoder.ReadFieldNumberAndWireType() | ||
872 | if (wire_type == wire_format.WIRETYPE_END_GROUP and | ||
873 | field_number == group_number): | ||
874 | return | ||
875 | _SkipField(field_number, wire_type, decoder) | ||
876 | |||
877 | |||
878 | def _DeserializeMessageSetItem(message, decoder): | ||
879 | """Deserializes a message using the message set wire format. | ||
880 | |||
881 | Args: | ||
882 | message: Message to be parsed to. | ||
883 | decoder: The decoder to be used to deserialize encoded data. Note that the | ||
884 | decoder should be positioned just after reading the START_GROUP tag that | ||
885 | began the messageset item. | ||
886 | """ | ||
887 | field_number, wire_type = decoder.ReadFieldNumberAndWireType() | ||
888 | if wire_type != wire_format.WIRETYPE_VARINT or field_number != 2: | ||
889 | raise message_mod.DecodeError( | ||
890 | 'Incorrect message set wire format. ' | ||
891 | 'wire_type: %d, field_number: %d' % (wire_type, field_number)) | ||
892 | |||
893 | type_id = decoder.ReadInt32() | ||
894 | field_number, wire_type = decoder.ReadFieldNumberAndWireType() | ||
895 | if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED or field_number != 3: | ||
896 | raise message_mod.DecodeError( | ||
897 | 'Incorrect message set wire format. ' | ||
898 | 'wire_type: %d, field_number: %d' % (wire_type, field_number)) | ||
899 | |||
900 | extension_dict = message.Extensions | ||
901 | extensions_by_number = extension_dict._AllExtensionsByNumber() | ||
902 | if type_id not in extensions_by_number: | ||
903 | _SkipField(field_number, wire_type, decoder) | ||
904 | return | ||
905 | |||
906 | field_descriptor = extensions_by_number[type_id] | ||
907 | value = extension_dict[field_descriptor] | ||
908 | decoder.ReadMessageInto(value) | ||
909 | # Read the END_GROUP tag. | ||
910 | field_number, wire_type = decoder.ReadFieldNumberAndWireType() | ||
911 | if wire_type != wire_format.WIRETYPE_END_GROUP or field_number != 1: | ||
912 | raise message_mod.DecodeError( | ||
913 | 'Incorrect message set wire format. ' | ||
914 | 'wire_type: %d, field_number: %d' % (wire_type, field_number)) | ||
915 | |||
916 | |||
917 | def _DeserializeOneEntity(message_descriptor, message, decoder): | ||
918 | """Deserializes the next wire entity from decoder into message. | ||
919 | The next wire entity is either a scalar or a nested message, | ||
920 | and may also be an element in a repeated field (the wire encoding | ||
921 | is the same). | ||
922 | |||
923 | Args: | ||
924 | message_descriptor: A Descriptor instance describing all fields | ||
925 | in message. | ||
926 | message: The Message instance into which we're decoding our fields. | ||
927 | decoder: The Decoder we're using to deserialize encoded data. | ||
928 | |||
929 | Returns: The number of bytes read from decoder during this method. | ||
930 | """ | ||
931 | initial_position = decoder.Position() | ||
932 | field_number, wire_type = decoder.ReadFieldNumberAndWireType() | ||
933 | extension_dict = message.Extensions | ||
934 | extensions_by_number = extension_dict._AllExtensionsByNumber() | ||
935 | if field_number in message_descriptor.fields_by_number: | ||
936 | # Non-extension field. | ||
937 | field_descriptor = message_descriptor.fields_by_number[field_number] | ||
938 | value = getattr(message, _PropertyName(field_descriptor.name)) | ||
939 | def nonextension_setter_fn(scalar): | ||
940 | setattr(message, _PropertyName(field_descriptor.name), scalar) | ||
941 | scalar_setter_fn = nonextension_setter_fn | ||
942 | elif field_number in extensions_by_number: | ||
943 | # Extension field. | ||
944 | field_descriptor = extensions_by_number[field_number] | ||
945 | value = extension_dict[field_descriptor] | ||
946 | def extension_setter_fn(scalar): | ||
947 | extension_dict[field_descriptor] = scalar | ||
948 | scalar_setter_fn = extension_setter_fn | ||
949 | elif wire_type == wire_format.WIRETYPE_END_GROUP: | ||
950 | # We assume we're being parsed as the group that's ended. | ||
951 | return 0 | ||
952 | elif (wire_type == wire_format.WIRETYPE_START_GROUP and | ||
953 | field_number == 1 and | ||
954 | message_descriptor.GetOptions().message_set_wire_format): | ||
955 | # A Message Set item. | ||
956 | _DeserializeMessageSetItem(message, decoder) | ||
957 | return decoder.Position() - initial_position | ||
958 | else: | ||
959 | _SkipField(field_number, wire_type, decoder) | ||
960 | return decoder.Position() - initial_position | ||
961 | |||
962 | # If we reach this point, we've identified the field as either | ||
963 | # hardcoded or extension, and set |field_descriptor|, |scalar_setter_fn|, | ||
964 | # and |value| appropriately. Now actually deserialize the thing. | ||
965 | # | ||
966 | # field_descriptor: Describes the field we're deserializing. | ||
967 | # value: The value currently stored in the field to deserialize. | ||
968 | # Used only if the field is composite and/or repeated. | ||
969 | # scalar_setter_fn: A function F such that F(scalar) will | ||
970 | # set a nonrepeated scalar value for this field. Used only | ||
971 | # if this field is a nonrepeated scalar. | ||
972 | |||
973 | field_number = field_descriptor.number | ||
974 | field_type = field_descriptor.type | ||
975 | expected_wire_type = _WireTypeForFieldType(field_type) | ||
976 | if wire_type != expected_wire_type: | ||
977 | # Need to fill in uninterpreted_bytes. Work for the next CL. | ||
978 | raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.') | ||
979 | |||
980 | property_name = _PropertyName(field_descriptor.name) | ||
981 | label = field_descriptor.label | ||
982 | cpp_type = field_descriptor.cpp_type | ||
983 | |||
984 | # Nonrepeated scalar. Just set the field directly. | ||
985 | if (label != _FieldDescriptor.LABEL_REPEATED | ||
986 | and cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): | ||
987 | scalar_setter_fn(_DeserializeScalarFromDecoder(field_type, decoder)) | ||
988 | return decoder.Position() - initial_position | ||
989 | |||
990 | # Nonrepeated composite. Recursively deserialize. | ||
991 | if label != _FieldDescriptor.LABEL_REPEATED: | ||
992 | composite = value | ||
993 | _RecursivelyMerge(field_number, field_type, decoder, composite) | ||
994 | return decoder.Position() - initial_position | ||
995 | |||
996 | # Now we know we're dealing with a repeated field of some kind. | ||
997 | element_list = value | ||
998 | |||
999 | if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: | ||
1000 | # Repeated scalar. | ||
1001 | element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) | ||
1002 | return decoder.Position() - initial_position | ||
1003 | else: | ||
1004 | # Repeated composite. | ||
1005 | composite = element_list.add() | ||
1006 | _RecursivelyMerge(field_number, field_type, decoder, composite) | ||
1007 | return decoder.Position() - initial_position | ||
1008 | |||
1009 | |||
1010 | def _FieldOrExtensionValues(message, field_or_extension): | ||
1011 | """Retrieves the list of values for the specified field or extension. | ||
1012 | |||
1013 | The target field or extension can be optional, required or repeated, but it | ||
1014 | must have value(s) set. The assumption is that the target field or extension | ||
1015 | is set (e.g. _HasFieldOrExtension holds true). | ||
1016 | |||
1017 | Args: | ||
1018 | message: Message which contains the target field or extension. | ||
1019 | field_or_extension: Field or extension for which the list of values is | ||
1020 | required. Must be an instance of FieldDescriptor. | ||
1021 | |||
1022 | Returns: | ||
1023 | A list of values for the specified field or extension. This list will only | ||
1024 | contain a single element if the field is non-repeated. | ||
1025 | """ | ||
1026 | if field_or_extension.is_extension: | ||
1027 | value = message.Extensions[field_or_extension] | ||
1028 | else: | ||
1029 | value = getattr(message, _ValueFieldName(field_or_extension.name)) | ||
1030 | if field_or_extension.label != _FieldDescriptor.LABEL_REPEATED: | ||
1031 | return [value] | ||
1032 | else: | ||
1033 | # In this case value is a list or repeated values. | ||
1034 | return value | ||
1035 | |||
1036 | |||
1037 | def _HasFieldOrExtension(message, field_or_extension): | ||
1038 | """Checks if a message has the specified field or extension set. | ||
1039 | |||
1040 | The field or extension specified can be optional, required or repeated. If | ||
1041 | it is repeated, this function returns True. Otherwise it checks the has bit | ||
1042 | of the field or extension. | ||
1043 | |||
1044 | Args: | ||
1045 | message: Message which contains the target field or extension. | ||
1046 | field_or_extension: Field or extension to check. This must be a | ||
1047 | FieldDescriptor instance. | ||
1048 | |||
1049 | Returns: | ||
1050 | True if the message has a value set for the specified field or extension, | ||
1051 | or if the field or extension is repeated. | ||
1052 | """ | ||
1053 | if field_or_extension.label == _FieldDescriptor.LABEL_REPEATED: | ||
1054 | return True | ||
1055 | if field_or_extension.is_extension: | ||
1056 | return message.HasExtension(field_or_extension) | ||
1057 | else: | ||
1058 | return message.HasField(field_or_extension.name) | ||
1059 | |||
1060 | |||
1061 | def _IsFieldOrExtensionInitialized(message, field, errors=None): | ||
1062 | """Checks if a message field or extension is initialized. | ||
1063 | |||
1064 | Args: | ||
1065 | message: The message which contains the field or extension. | ||
1066 | field: Field or extension to check. This must be a FieldDescriptor instance. | ||
1067 | errors: Errors will be appended to it, if set to a meaningful value. | ||
1068 | |||
1069 | Returns: | ||
1070 | True if the field/extension can be considered initialized. | ||
1071 | """ | ||
1072 | # If the field is required and is not set, it isn't initialized. | ||
1073 | if field.label == _FieldDescriptor.LABEL_REQUIRED: | ||
1074 | if not _HasFieldOrExtension(message, field): | ||
1075 | if errors is not None: | ||
1076 | errors.append('Required field %s is not set.' % field.full_name) | ||
1077 | return False | ||
1078 | |||
1079 | # If the field is optional and is not set, or if it | ||
1080 | # isn't a submessage then the field is initialized. | ||
1081 | if field.label == _FieldDescriptor.LABEL_OPTIONAL: | ||
1082 | if not _HasFieldOrExtension(message, field): | ||
1083 | return True | ||
1084 | if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: | ||
1085 | return True | ||
1086 | |||
1087 | # The field is set and is either a single or a repeated submessage. | ||
1088 | messages = _FieldOrExtensionValues(message, field) | ||
1089 | # If all submessages in this field are initialized, the field is | ||
1090 | # considered initialized. | ||
1091 | for message in messages: | ||
1092 | if not _InternalIsInitialized(message, errors): | ||
1093 | return False | ||
1094 | return True | ||
1095 | |||
1096 | |||
1097 | def _InternalIsInitialized(message, errors=None): | ||
1098 | """Checks if all required fields of a message are set. | ||
1099 | |||
1100 | Args: | ||
1101 | message: The message to check. | ||
1102 | errors: If set, initialization errors will be appended to it. | ||
1103 | |||
1104 | Returns: | ||
1105 | True iff the specified message has all required fields set. | ||
1106 | """ | ||
1107 | fields_and_extensions = [] | ||
1108 | fields_and_extensions.extend(message.DESCRIPTOR.fields) | ||
1109 | fields_and_extensions.extend( | ||
1110 | [extension[0] for extension in message.Extensions._ListSetExtensions()]) | ||
1111 | for field_or_extension in fields_and_extensions: | ||
1112 | if not _IsFieldOrExtensionInitialized(message, field_or_extension, errors): | ||
1113 | return False | ||
1114 | return True | ||
1115 | |||
1116 | |||
1117 | def _AddMergeFromStringMethod(message_descriptor, cls): | ||
1118 | """Helper for _AddMessageMethods().""" | ||
1119 | Decoder = decoder.Decoder | ||
1120 | def MergeFromString(self, serialized): | ||
1121 | decoder = Decoder(serialized) | ||
1122 | byte_count = 0 | ||
1123 | while not decoder.EndOfStream(): | ||
1124 | bytes_read = _DeserializeOneEntity(message_descriptor, self, decoder) | ||
1125 | if not bytes_read: | ||
1126 | break | ||
1127 | byte_count += bytes_read | ||
1128 | return byte_count | ||
1129 | cls.MergeFromString = MergeFromString | ||
1130 | |||
1131 | |||
1132 | def _AddIsInitializedMethod(cls): | ||
1133 | """Adds the IsInitialized method to the protocol message class.""" | ||
1134 | cls.IsInitialized = _InternalIsInitialized | ||
1135 | |||
1136 | |||
1137 | def _MergeFieldOrExtension(destination_msg, field, value): | ||
1138 | """Merges a specified message field into another message.""" | ||
1139 | property_name = _PropertyName(field.name) | ||
1140 | is_extension = field.is_extension | ||
1141 | |||
1142 | if not is_extension: | ||
1143 | destination = getattr(destination_msg, property_name) | ||
1144 | elif (field.label == _FieldDescriptor.LABEL_REPEATED or | ||
1145 | field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): | ||
1146 | destination = destination_msg.Extensions[field] | ||
1147 | |||
1148 | # Case 1 - a composite field. | ||
1149 | if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | ||
1150 | if field.label == _FieldDescriptor.LABEL_REPEATED: | ||
1151 | for v in value: | ||
1152 | destination.add().MergeFrom(v) | ||
1153 | else: | ||
1154 | destination.MergeFrom(value) | ||
1155 | return | ||
1156 | |||
1157 | # Case 2 - a repeated field. | ||
1158 | if field.label == _FieldDescriptor.LABEL_REPEATED: | ||
1159 | for v in value: | ||
1160 | destination.append(v) | ||
1161 | return | ||
1162 | |||
1163 | # Case 3 - a singular field. | ||
1164 | if is_extension: | ||
1165 | destination_msg.Extensions[field] = value | ||
1166 | else: | ||
1167 | setattr(destination_msg, property_name, value) | ||
1168 | |||
1169 | |||
1170 | def _AddMergeFromMethod(cls): | ||
1171 | def MergeFrom(self, msg): | ||
1172 | assert msg is not self | ||
1173 | for field in msg.ListFields(): | ||
1174 | _MergeFieldOrExtension(self, field[0], field[1]) | ||
1175 | cls.MergeFrom = MergeFrom | ||
1176 | |||
1177 | |||
1178 | def _AddMessageMethods(message_descriptor, cls): | ||
1179 | """Adds implementations of all Message methods to cls.""" | ||
1180 | _AddListFieldsMethod(message_descriptor, cls) | ||
1181 | _AddHasFieldMethod(cls) | ||
1182 | _AddClearFieldMethod(cls) | ||
1183 | _AddClearExtensionMethod(cls) | ||
1184 | _AddClearMethod(cls) | ||
1185 | _AddHasExtensionMethod(cls) | ||
1186 | _AddEqualsMethod(message_descriptor, cls) | ||
1187 | _AddSetListenerMethod(cls) | ||
1188 | _AddByteSizeMethod(message_descriptor, cls) | ||
1189 | _AddSerializeToStringMethod(message_descriptor, cls) | ||
1190 | _AddSerializePartialToStringMethod(message_descriptor, cls) | ||
1191 | _AddMergeFromStringMethod(message_descriptor, cls) | ||
1192 | _AddIsInitializedMethod(cls) | ||
1193 | _AddMergeFromMethod(cls) | ||
1194 | |||
1195 | |||
1196 | def _AddPrivateHelperMethods(cls): | ||
1197 | """Adds implementation of private helper methods to cls.""" | ||
1198 | |||
1199 | def MaybeCallTransitionToNonemptyCallback(self): | ||
1200 | """Calls self._listener.TransitionToNonempty() the first time this | ||
1201 | method is called. On all subsequent calls, this is a no-op. | ||
1202 | """ | ||
1203 | if not self._called_transition_to_nonempty: | ||
1204 | self._listener.TransitionToNonempty() | ||
1205 | self._called_transition_to_nonempty = True | ||
1206 | cls._MaybeCallTransitionToNonemptyCallback = ( | ||
1207 | MaybeCallTransitionToNonemptyCallback) | ||
1208 | |||
1209 | def MarkByteSizeDirty(self): | ||
1210 | """Sets the _cached_byte_size_dirty bit to true, | ||
1211 | and propagates this to our listener iff this was a state change. | ||
1212 | """ | ||
1213 | if not self._cached_byte_size_dirty: | ||
1214 | self._cached_byte_size_dirty = True | ||
1215 | self._listener.ByteSizeDirty() | ||
1216 | cls._MarkByteSizeDirty = MarkByteSizeDirty | ||
1217 | |||
1218 | |||
1219 | class _Listener(object): | ||
1220 | |||
1221 | """MessageListener implementation that a parent message registers with its | ||
1222 | child message. | ||
1223 | |||
1224 | In order to support semantics like: | ||
1225 | |||
1226 | foo.bar.baz = 23 | ||
1227 | assert foo.HasField('bar') | ||
1228 | |||
1229 | ...child objects must have back references to their parents. | ||
1230 | This helper class is at the heart of this support. | ||
1231 | """ | ||
1232 | |||
1233 | def __init__(self, parent_message, has_field_name): | ||
1234 | """Args: | ||
1235 | parent_message: The message whose _MaybeCallTransitionToNonemptyCallback() | ||
1236 | and _MarkByteSizeDirty() methods we should call when we receive | ||
1237 | TransitionToNonempty() and ByteSizeDirty() messages. | ||
1238 | has_field_name: The name of the "has" field that we should set in | ||
1239 | the parent message when we receive a TransitionToNonempty message, | ||
1240 | or None if there's no "has" field to set. (This will be the case | ||
1241 | for child objects in "repeated" fields). | ||
1242 | """ | ||
1243 | # This listener establishes a back reference from a child (contained) object | ||
1244 | # to its parent (containing) object. We make this a weak reference to avoid | ||
1245 | # creating cyclic garbage when the client finishes with the 'parent' object | ||
1246 | # in the tree. | ||
1247 | if isinstance(parent_message, weakref.ProxyType): | ||
1248 | self._parent_message_weakref = parent_message | ||
1249 | else: | ||
1250 | self._parent_message_weakref = weakref.proxy(parent_message) | ||
1251 | self._has_field_name = has_field_name | ||
1252 | |||
1253 | def TransitionToNonempty(self): | ||
1254 | try: | ||
1255 | if self._has_field_name is not None: | ||
1256 | setattr(self._parent_message_weakref, self._has_field_name, True) | ||
1257 | # Propagate the signal to our parents iff this is the first field set. | ||
1258 | self._parent_message_weakref._MaybeCallTransitionToNonemptyCallback() | ||
1259 | except ReferenceError: | ||
1260 | # We can get here if a client has kept a reference to a child object, | ||
1261 | # and is now setting a field on it, but the child's parent has been | ||
1262 | # garbage-collected. This is not an error. | ||
1263 | pass | ||
1264 | |||
1265 | def ByteSizeDirty(self): | ||
1266 | try: | ||
1267 | self._parent_message_weakref._MarkByteSizeDirty() | ||
1268 | except ReferenceError: | ||
1269 | # Same as above. | ||
1270 | pass | ||
1271 | |||
1272 | |||
1273 | # TODO(robinson): Move elsewhere? | ||
1274 | # TODO(robinson): Provide a clear() method here in addition to ClearField()? | ||
1275 | class _RepeatedScalarFieldContainer(object): | ||
1276 | |||
1277 | """Simple, type-checked, list-like container for holding repeated scalars.""" | ||
1278 | |||
1279 | # Minimizes memory usage and disallows assignment to other attributes. | ||
1280 | __slots__ = ['_message_listener', '_type_checker', '_values'] | ||
1281 | |||
1282 | def __init__(self, message_listener, type_checker): | ||
1283 | """ | ||
1284 | Args: | ||
1285 | message_listener: A MessageListener implementation. | ||
1286 | The _RepeatedScalarFieldContaininer will call this object's | ||
1287 | TransitionToNonempty() method when it transitions from being empty to | ||
1288 | being nonempty. | ||
1289 | type_checker: A _ValueChecker instance to run on elements inserted | ||
1290 | into this container. | ||
1291 | """ | ||
1292 | self._message_listener = message_listener | ||
1293 | self._type_checker = type_checker | ||
1294 | self._values = [] | ||
1295 | |||
1296 | def append(self, elem): | ||
1297 | self._type_checker.CheckValue(elem) | ||
1298 | self._values.append(elem) | ||
1299 | self._message_listener.ByteSizeDirty() | ||
1300 | if len(self._values) == 1: | ||
1301 | self._message_listener.TransitionToNonempty() | ||
1302 | |||
1303 | def remove(self, elem): | ||
1304 | self._values.remove(elem) | ||
1305 | self._message_listener.ByteSizeDirty() | ||
1306 | |||
1307 | # List-like __getitem__() support also makes us iterable (via "iter(foo)" | ||
1308 | # or implicitly via "for i in mylist:") for free. | ||
1309 | def __getitem__(self, key): | ||
1310 | return self._values[key] | ||
1311 | |||
1312 | def __setitem__(self, key, value): | ||
1313 | # No need to call TransitionToNonempty(), since if we're able to | ||
1314 | # set the element at this index, we were already nonempty before | ||
1315 | # this method was called. | ||
1316 | self._message_listener.ByteSizeDirty() | ||
1317 | self._type_checker.CheckValue(value) | ||
1318 | self._values[key] = value | ||
1319 | |||
1320 | def __len__(self): | ||
1321 | return len(self._values) | ||
1322 | |||
1323 | def __eq__(self, other): | ||
1324 | if self is other: | ||
1325 | return True | ||
1326 | # Special case for the same type which should be common and fast. | ||
1327 | if isinstance(other, self.__class__): | ||
1328 | return other._values == self._values | ||
1329 | # We are presumably comparing against some other sequence type. | ||
1330 | return other == self._values | ||
1331 | |||
1332 | def __ne__(self, other): | ||
1333 | # Can't use != here since it would infinitely recurse. | ||
1334 | return not self == other | ||
1335 | |||
1336 | |||
1337 | # TODO(robinson): Move elsewhere? | ||
1338 | # TODO(robinson): Provide a clear() method here in addition to ClearField()? | ||
1339 | # TODO(robinson): Unify common functionality with | ||
1340 | # _RepeatedScalarFieldContaininer? | ||
1341 | class _RepeatedCompositeFieldContainer(object): | ||
1342 | |||
1343 | """Simple, list-like container for holding repeated composite fields.""" | ||
1344 | |||
1345 | # Minimizes memory usage and disallows assignment to other attributes. | ||
1346 | __slots__ = ['_values', '_message_descriptor', '_message_listener'] | ||
1347 | |||
1348 | def __init__(self, message_listener, message_descriptor): | ||
1349 | """Note that we pass in a descriptor instead of the generated directly, | ||
1350 | since at the time we construct a _RepeatedCompositeFieldContainer we | ||
1351 | haven't yet necessarily initialized the type that will be contained in the | ||
1352 | container. | ||
1353 | |||
1354 | Args: | ||
1355 | message_listener: A MessageListener implementation. | ||
1356 | The _RepeatedCompositeFieldContainer will call this object's | ||
1357 | TransitionToNonempty() method when it transitions from being empty to | ||
1358 | being nonempty. | ||
1359 | message_descriptor: A Descriptor instance describing the protocol type | ||
1360 | that should be present in this container. We'll use the | ||
1361 | _concrete_class field of this descriptor when the client calls add(). | ||
1362 | """ | ||
1363 | self._message_listener = message_listener | ||
1364 | self._message_descriptor = message_descriptor | ||
1365 | self._values = [] | ||
1366 | |||
1367 | def add(self): | ||
1368 | new_element = self._message_descriptor._concrete_class() | ||
1369 | new_element._SetListener(self._message_listener) | ||
1370 | self._values.append(new_element) | ||
1371 | self._message_listener.ByteSizeDirty() | ||
1372 | self._message_listener.TransitionToNonempty() | ||
1373 | return new_element | ||
1374 | |||
1375 | def __delitem__(self, key): | ||
1376 | self._message_listener.ByteSizeDirty() | ||
1377 | del self._values[key] | ||
1378 | |||
1379 | # List-like __getitem__() support also makes us iterable (via "iter(foo)" | ||
1380 | # or implicitly via "for i in mylist:") for free. | ||
1381 | def __getitem__(self, key): | ||
1382 | return self._values[key] | ||
1383 | |||
1384 | def __len__(self): | ||
1385 | return len(self._values) | ||
1386 | |||
1387 | def __eq__(self, other): | ||
1388 | if self is other: | ||
1389 | return True | ||
1390 | if not isinstance(other, self.__class__): | ||
1391 | raise TypeError('Can only compare repeated composite fields against ' | ||
1392 | 'other repeated composite fields.') | ||
1393 | return self._values == other._values | ||
1394 | |||
1395 | def __ne__(self, other): | ||
1396 | # Can't use != here since it would infinitely recurse. | ||
1397 | return not self == other | ||
1398 | |||
1399 | # TODO(robinson): Implement, document, and test slicing support. | ||
1400 | |||
1401 | |||
1402 | # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... | ||
1403 | # TODO(robinson): Unify error handling of "unknown extension" crap. | ||
1404 | # TODO(robinson): There's so much similarity between the way that | ||
1405 | # extensions behave and the way that normal fields behave that it would | ||
1406 | # be really nice to unify more code. It's not immediately obvious | ||
1407 | # how to do this, though, and I'd rather get the full functionality | ||
1408 | # implemented (and, crucially, get all the tests and specs fleshed out | ||
1409 | # and passing), and then come back to this thorny unification problem. | ||
1410 | # TODO(robinson): Support iteritems()-style iteration over all | ||
1411 | # extensions with the "has" bits turned on? | ||
1412 | class _ExtensionDict(object): | ||
1413 | |||
1414 | """Dict-like container for supporting an indexable "Extensions" | ||
1415 | field on proto instances. | ||
1416 | |||
1417 | Note that in all cases we expect extension handles to be | ||
1418 | FieldDescriptors. | ||
1419 | """ | ||
1420 | |||
1421 | class _ExtensionListener(object): | ||
1422 | |||
1423 | """Adapts an _ExtensionDict to behave as a MessageListener.""" | ||
1424 | |||
1425 | def __init__(self, extension_dict, handle_id): | ||
1426 | self._extension_dict = extension_dict | ||
1427 | self._handle_id = handle_id | ||
1428 | |||
1429 | def TransitionToNonempty(self): | ||
1430 | self._extension_dict._SubmessageTransitionedToNonempty(self._handle_id) | ||
1431 | |||
1432 | def ByteSizeDirty(self): | ||
1433 | self._extension_dict._SubmessageByteSizeBecameDirty() | ||
1434 | |||
1435 | # TODO(robinson): Somewhere, we need to blow up if people | ||
1436 | # try to register two extensions with the same field number. | ||
1437 | # (And we need a test for this of course). | ||
1438 | |||
1439 | def __init__(self, extended_message, known_extensions): | ||
1440 | """extended_message: Message instance for which we are the Extensions dict. | ||
1441 | known_extensions: Iterable of known extension handles. | ||
1442 | These must be FieldDescriptors. | ||
1443 | """ | ||
1444 | # We keep a weak reference to extended_message, since | ||
1445 | # it has a reference to this instance in turn. | ||
1446 | self._extended_message = weakref.proxy(extended_message) | ||
1447 | # We make a deep copy of known_extensions to avoid any | ||
1448 | # thread-safety concerns, since the argument passed in | ||
1449 | # is the global (class-level) dict of known extensions for | ||
1450 | # this type of message, which could be modified at any time | ||
1451 | # via a RegisterExtension() call. | ||
1452 | # | ||
1453 | # This dict maps from handle id to handle (a FieldDescriptor). | ||
1454 | # | ||
1455 | # XXX | ||
1456 | # TODO(robinson): This isn't good enough. The client could | ||
1457 | # instantiate an object in module A, then afterward import | ||
1458 | # module B and pass the instance to B.Foo(). If B imports | ||
1459 | # an extender of this proto and then tries to use it, B | ||
1460 | # will get a KeyError, even though the extension *is* registered | ||
1461 | # at the time of use. | ||
1462 | # XXX | ||
1463 | self._known_extensions = dict((id(e), e) for e in known_extensions) | ||
1464 | # Read lock around self._values, which may be modified by multiple | ||
1465 | # concurrent readers in the conceptually "const" __getitem__ method. | ||
1466 | # So, we grab this lock in every "read-only" method to ensure | ||
1467 | # that concurrent read access is safe without external locking. | ||
1468 | self._lock = threading.Lock() | ||
1469 | # Maps from extension handle ID to current value of that extension. | ||
1470 | self._values = {} | ||
1471 | # Maps from extension handle ID to a boolean "has" bit, but only | ||
1472 | # for non-repeated extension fields. | ||
1473 | keys = (id for id, extension in self._known_extensions.iteritems() | ||
1474 | if extension.label != _FieldDescriptor.LABEL_REPEATED) | ||
1475 | self._has_bits = dict.fromkeys(keys, False) | ||
1476 | |||
1477 | def __getitem__(self, extension_handle): | ||
1478 | """Returns the current value of the given extension handle.""" | ||
1479 | # We don't care as much about keeping critical sections short in the | ||
1480 | # extension support, since it's presumably much less of a common case. | ||
1481 | self._lock.acquire() | ||
1482 | try: | ||
1483 | handle_id = id(extension_handle) | ||
1484 | if handle_id not in self._known_extensions: | ||
1485 | raise KeyError('Extension not known to this class') | ||
1486 | if handle_id not in self._values: | ||
1487 | self._AddMissingHandle(extension_handle, handle_id) | ||
1488 | return self._values[handle_id] | ||
1489 | finally: | ||
1490 | self._lock.release() | ||
1491 | |||
1492 | def __eq__(self, other): | ||
1493 | # We have to grab read locks since we're accessing _values | ||
1494 | # in a "const" method. See the comment in the constructor. | ||
1495 | if self is other: | ||
1496 | return True | ||
1497 | self._lock.acquire() | ||
1498 | try: | ||
1499 | other._lock.acquire() | ||
1500 | try: | ||
1501 | if self._has_bits != other._has_bits: | ||
1502 | return False | ||
1503 | # If there's a "has" bit, then only compare values where it is true. | ||
1504 | for k, v in self._values.iteritems(): | ||
1505 | if self._has_bits.get(k, False) and v != other._values[k]: | ||
1506 | return False | ||
1507 | return True | ||
1508 | finally: | ||
1509 | other._lock.release() | ||
1510 | finally: | ||
1511 | self._lock.release() | ||
1512 | |||
1513 | def __ne__(self, other): | ||
1514 | return not self == other | ||
1515 | |||
1516 | # Note that this is only meaningful for non-repeated, scalar extension | ||
1517 | # fields. Note also that we may have to call | ||
1518 | # MaybeCallTransitionToNonemptyCallback() when we do successfully set a field | ||
1519 | # this way, to set any necssary "has" bits in the ancestors of the extended | ||
1520 | # message. | ||
1521 | def __setitem__(self, extension_handle, value): | ||
1522 | """If extension_handle specifies a non-repeated, scalar extension | ||
1523 | field, sets the value of that field. | ||
1524 | """ | ||
1525 | handle_id = id(extension_handle) | ||
1526 | if handle_id not in self._known_extensions: | ||
1527 | raise KeyError('Extension not known to this class') | ||
1528 | field = extension_handle # Just shorten the name. | ||
1529 | if (field.label == _FieldDescriptor.LABEL_OPTIONAL | ||
1530 | and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): | ||
1531 | # It's slightly wasteful to lookup the type checker each time, | ||
1532 | # but we expect this to be a vanishingly uncommon case anyway. | ||
1533 | type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) | ||
1534 | type_checker.CheckValue(value) | ||
1535 | self._values[handle_id] = value | ||
1536 | self._has_bits[handle_id] = True | ||
1537 | self._extended_message._MarkByteSizeDirty() | ||
1538 | self._extended_message._MaybeCallTransitionToNonemptyCallback() | ||
1539 | else: | ||
1540 | raise TypeError('Extension is repeated and/or a composite type.') | ||
1541 | |||
1542 | def _AddMissingHandle(self, extension_handle, handle_id): | ||
1543 | """Helper internal to ExtensionDict.""" | ||
1544 | # Special handling for non-repeated message extensions, which (like | ||
1545 | # normal fields of this kind) are initialized lazily. | ||
1546 | # REQUIRES: _lock already held. | ||
1547 | cpp_type = extension_handle.cpp_type | ||
1548 | label = extension_handle.label | ||
1549 | if (cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE | ||
1550 | and label != _FieldDescriptor.LABEL_REPEATED): | ||
1551 | self._AddMissingNonRepeatedCompositeHandle(extension_handle, handle_id) | ||
1552 | else: | ||
1553 | self._values[handle_id] = _DefaultValueForField( | ||
1554 | self._extended_message, extension_handle) | ||
1555 | |||
1556 | def _AddMissingNonRepeatedCompositeHandle(self, extension_handle, handle_id): | ||
1557 | """Helper internal to ExtensionDict.""" | ||
1558 | # REQUIRES: _lock already held. | ||
1559 | value = extension_handle.message_type._concrete_class() | ||
1560 | value._SetListener(_ExtensionDict._ExtensionListener(self, handle_id)) | ||
1561 | self._values[handle_id] = value | ||
1562 | |||
1563 | def _SubmessageTransitionedToNonempty(self, handle_id): | ||
1564 | """Called when a submessage with a given handle id first transitions to | ||
1565 | being nonempty. Called by _ExtensionListener. | ||
1566 | """ | ||
1567 | assert handle_id in self._has_bits | ||
1568 | self._has_bits[handle_id] = True | ||
1569 | self._extended_message._MaybeCallTransitionToNonemptyCallback() | ||
1570 | |||
1571 | def _SubmessageByteSizeBecameDirty(self): | ||
1572 | """Called whenever a submessage's cached byte size becomes invalid | ||
1573 | (goes from being "clean" to being "dirty"). Called by _ExtensionListener. | ||
1574 | """ | ||
1575 | self._extended_message._MarkByteSizeDirty() | ||
1576 | |||
1577 | # We may wish to widen the public interface of Message.Extensions | ||
1578 | # to expose some of this private functionality in the future. | ||
1579 | # For now, we make all this functionality module-private and just | ||
1580 | # implement what we need for serialization/deserialization, | ||
1581 | # HasField()/ClearField(), etc. | ||
1582 | |||
1583 | def _HasExtension(self, extension_handle): | ||
1584 | """Method for internal use by this module. | ||
1585 | Returns true iff we "have" this extension in the sense of the | ||
1586 | "has" bit being set. | ||
1587 | """ | ||
1588 | handle_id = id(extension_handle) | ||
1589 | # Note that this is different from the other checks. | ||
1590 | if handle_id not in self._has_bits: | ||
1591 | raise KeyError('Extension not known to this class, or is repeated field.') | ||
1592 | return self._has_bits[handle_id] | ||
1593 | |||
1594 | # Intentionally pretty similar to ClearField() above. | ||
1595 | def _ClearExtension(self, extension_handle): | ||
1596 | """Method for internal use by this module. | ||
1597 | Clears the specified extension, unsetting its "has" bit. | ||
1598 | """ | ||
1599 | handle_id = id(extension_handle) | ||
1600 | if handle_id not in self._known_extensions: | ||
1601 | raise KeyError('Extension not known to this class') | ||
1602 | default_value = _DefaultValueForField(self._extended_message, | ||
1603 | extension_handle) | ||
1604 | if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: | ||
1605 | self._extended_message._MarkByteSizeDirty() | ||
1606 | else: | ||
1607 | cpp_type = extension_handle.cpp_type | ||
1608 | if cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: | ||
1609 | if handle_id in self._values: | ||
1610 | # Future modifications to this object shouldn't set any | ||
1611 | # "has" bits here. | ||
1612 | self._values[handle_id]._SetListener(None) | ||
1613 | if self._has_bits[handle_id]: | ||
1614 | self._has_bits[handle_id] = False | ||
1615 | self._extended_message._MarkByteSizeDirty() | ||
1616 | if handle_id in self._values: | ||
1617 | del self._values[handle_id] | ||
1618 | |||
1619 | def _ListSetExtensions(self): | ||
1620 | """Method for internal use by this module. | ||
1621 | |||
1622 | Returns an sequence of all extensions that are currently "set" | ||
1623 | in this extension dict. A "set" extension is a repeated extension, | ||
1624 | or a non-repeated extension with its "has" bit set. | ||
1625 | |||
1626 | The returned sequence contains (field_descriptor, value) pairs, | ||
1627 | where value is the current value of the extension with the given | ||
1628 | field descriptor. | ||
1629 | |||
1630 | The sequence values are in arbitrary order. | ||
1631 | """ | ||
1632 | self._lock.acquire() # Read-only methods must lock around self._values. | ||
1633 | try: | ||
1634 | set_extensions = [] | ||
1635 | for handle_id, value in self._values.iteritems(): | ||
1636 | handle = self._known_extensions[handle_id] | ||
1637 | if (handle.label == _FieldDescriptor.LABEL_REPEATED | ||
1638 | or self._has_bits[handle_id]): | ||
1639 | set_extensions.append((handle, value)) | ||
1640 | return set_extensions | ||
1641 | finally: | ||
1642 | self._lock.release() | ||
1643 | |||
1644 | def _AllExtensionsByNumber(self): | ||
1645 | """Method for internal use by this module. | ||
1646 | |||
1647 | Returns: A dict mapping field_number to (handle, field_descriptor), | ||
1648 | for *all* registered extensions for this dict. | ||
1649 | """ | ||
1650 | # TODO(robinson): Precompute and store this away. Note that we'll have to | ||
1651 | # be careful when we move away from having _known_extensions as a | ||
1652 | # deep-copied member of this object. | ||
1653 | return dict((f.number, f) for f in self._known_extensions.itervalues()) | ||