source: trunk/src/allmydata/mutable/retrieve.py

Last change on this file was 81a5ae6, checked in by Itamar Turner-Trauring <itamar@…>, at 2023-12-06T21:01:14Z

Simplify

1"""
2Ported to Python 3.
3"""
4from __future__ import annotations
5
6import time
7from itertools import count
8
9from zope.interface import implementer
10from twisted.internet import defer
11from twisted.python import failure
12from twisted.internet.interfaces import IPushProducer, IConsumer
13from foolscap.api import eventually, fireEventually, DeadReferenceError, \
14     RemoteException
15
16from allmydata.crypto import aes
17from allmydata.crypto import rsa
18from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
19     DownloadStopped, MDMF_VERSION, SDMF_VERSION
20from allmydata.util.assertutil import _assert, precondition
21from allmydata.util import hashutil, log, mathutil, deferredutil
22from allmydata.util.dictutil import DictOfSets
23from allmydata.util.cputhreadpool import defer_to_thread
24from allmydata import hashtree, codec
25from allmydata.storage.server import si_b2a
26
27from allmydata.mutable.common import CorruptShareError, BadShareError, \
28     UncoordinatedWriteError, decrypt_privkey
29from allmydata.mutable.layout import MDMFSlotReadProxy
30
31@implementer(IRetrieveStatus)
32class RetrieveStatus(object):
33    statusid_counter = count(0)
34    def __init__(self):
35        self.timings = {}
36        self.timings["fetch_per_server"] = {}
37        self.timings["decode"] = 0.0
38        self.timings["decrypt"] = 0.0
39        self.timings["cumulative_verify"] = 0.0
40        self._problems = {}
41        self.active = True
42        self.storage_index = None
43        self.helper = False
44        self.encoding = ("?","?")
45        self.size = None
46        self.status = "Not started"
47        self.progress = 0.0
48        self.counter = next(self.statusid_counter)
49        self.started = time.time()
50
51    def get_started(self):
52        return self.started
53    def get_storage_index(self):
54        return self.storage_index
55    def get_encoding(self):
56        return self.encoding
57    def using_helper(self):
58        return self.helper
59    def get_size(self):
60        return self.size
61    def get_status(self):
62        return self.status
63    def get_progress(self):
64        return self.progress
65    def get_active(self):
66        return self.active
67    def get_counter(self):
68        return self.counter
69    def get_problems(self):
70        return self._problems
71
72    def add_fetch_timing(self, server, elapsed):
73        if server not in self.timings["fetch_per_server"]:
74            self.timings["fetch_per_server"][server] = []
75        self.timings["fetch_per_server"][server].append(elapsed)
76    def accumulate_decode_time(self, elapsed):
77        self.timings["decode"] += elapsed
78    def accumulate_decrypt_time(self, elapsed):
79        self.timings["decrypt"] += elapsed
80    def set_storage_index(self, si):
81        self.storage_index = si
82    def set_helper(self, helper):
83        self.helper = helper
84    def set_encoding(self, k, n):
85        self.encoding = (k, n)
86    def set_size(self, size):
87        self.size = size
88    def set_status(self, status):
89        self.status = status
90    def set_progress(self, value):
91        self.progress = value
92    def set_active(self, value):
93        self.active = value
94    def add_problem(self, server, f):
95        serverid = server.get_serverid()
96        self._problems[serverid] = f
97
98class Marker(object):
99    pass
100
101@implementer(IPushProducer)
102class Retrieve(object):
103    # this class is currently single-use. Eventually (in MDMF) we will make
104    # it multi-use, in which case you can call download(range) multiple
105    # times, and each will have a separate response chain. However the
106    # Retrieve object will remain tied to a specific version of the file, and
107    # will use a single ServerMap instance.
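    #
    # Typical use (illustrative sketch, not an excerpt from elsewhere in the
    # codebase):
    #
    #     r = Retrieve(filenode, storage_broker, servermap, verinfo)
    #     d = r.download(consumer, offset=0, size=None)
    #
    # download() returns a Deferred that fires with the consumer (or, when
    # verify=True, with the set of bad shares) once every needed segment has
    # been fetched, validated, decoded, and decrypted.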

    def __init__(self, filenode, storage_broker, servermap, verinfo,
                 fetch_privkey=False, verify=False):
        self._node = filenode
        _assert(self._node.get_pubkey())
        self._storage_broker = storage_broker
        self._storage_index = filenode.get_storage_index()
        _assert(self._node.get_readkey())
        self._last_failure = None
        prefix = si_b2a(self._storage_index)[:5]
        self._log_number = log.msg("Retrieve(%r): starting" % prefix)
        self._running = True
        self._decoding = False
        self._bad_shares = set()

        self.servermap = servermap
        self.verinfo = verinfo
        # TODO: make it possible to use self.verinfo.datalength instead
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        self._data_length = datalength
        # during repair, we may be called upon to grab the private key, since
        # it wasn't picked up during a verify=False checker run, and we'll
        # need it for repair to generate a new version.
        self._need_privkey = verify or (fetch_privkey
                                        and not self._node.get_privkey())

        if self._need_privkey:
            # TODO: Evaluate the need for this. We'll use it if we want
            # to limit how many queries are on the wire for the privkey
            # at once.
            self._privkey_query_markers = [] # one Marker for each time we've
                                             # tried to get the privkey.

        # verify means that we are using the downloader logic to verify all
        # of our shares. This tells the downloader a few things.
        #
        # 1. We need to download all of the shares.
        # 2. We don't need to decode or decrypt the shares, since our
        #    caller doesn't care about the plaintext, only the
        #    information about which shares are or are not valid.
        # 3. When we are validating readers, we need to validate the
        #    signature on the prefix. Do we? We already do this in the
        #    servermap update?
        self._verify = verify

        self._status = RetrieveStatus()
        self._status.set_storage_index(self._storage_index)
        self._status.set_helper(False)
        self._status.set_progress(0.0)
        self._status.set_active(True)
        self._status.set_size(datalength)
        self._status.set_encoding(k, N)
        self.readers = {}
        self._stopped = False
        self._pause_deferred = None
        self._offset = None
        self._read_length = None
        self.log("got seqnum %d" % self.verinfo[0])


    def get_status(self):
        return self._status

    def log(self, *args, **kwargs):
        if "parent" not in kwargs:
            kwargs["parent"] = self._log_number
        if "facility" not in kwargs:
            kwargs["facility"] = "tahoe.mutable.retrieve"
        return log.msg(*args, **kwargs)

    def _set_current_status(self, state):
        seg = "%d/%d" % (self._current_segment, self._last_segment)
        self._status.set_status("segment %s (%s)" % (seg, state))

    ###################
    # IPushProducer

    def pauseProducing(self):
        """
        I am called by my download target if we have produced too much
        data for it to handle. I make the downloader stop producing new
        data until my resumeProducing method is called.
        """
        if self._pause_deferred is not None:
            return

        # fired when the download is unpaused.
        self._old_status = self._status.get_status()
        self._set_current_status("paused")

        self._pause_deferred = defer.Deferred()


    def resumeProducing(self):
        """
        I am called by my download target once it is ready to begin
        receiving data again.
        """
        if self._pause_deferred is None:
            return

        p = self._pause_deferred
        self._pause_deferred = None
        self._status.set_status(self._old_status)

        eventually(p.callback, None)

    def stopProducing(self):
        self._stopped = True
        self.resumeProducing()


    def _check_for_paused(self, res):
        """
        I am called just before a write to the consumer. I return a
        Deferred that eventually fires with the data that is to be
        written to the consumer. If the download has not been paused,
        the Deferred fires immediately. Otherwise, the Deferred fires
        when the downloader is unpaused.
        """
        if self._pause_deferred is not None:
            d = defer.Deferred()
            self._pause_deferred.addCallback(lambda ignored: d.callback(res))
            return d
        return res

    def _check_for_stopped(self, res):
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res


    def download(self, consumer=None, offset=0, size=None):
        precondition(self._verify or IConsumer.providedBy(consumer))
        if size is None:
            size = self._data_length - offset
        if self._verify:
            _assert(size == self._data_length, (size, self._data_length))
        self.log("starting download")
        self._done_deferred = defer.Deferred()
        if consumer:
            self._consumer = consumer
            # we provide IPushProducer, so streaming=True, per IConsumer.
            self._consumer.registerProducer(self, streaming=True)
        self._started = time.time()
        self._started_fetching = time.time()
        if size == 0:
            # short-circuit the rest of the process
            self._done()
        else:
            self._start_download(consumer, offset, size)
        return self._done_deferred

    def _start_download(self, consumer, offset, size):
        precondition((0 <= offset < self._data_length)
                     and (size > 0)
                     and (offset+size <= self._data_length),
                     (offset, size, self._data_length))

        self._offset = offset
        self._read_length = size
        self._setup_encoding_parameters()
        self._setup_download()

        # The download process beyond this is a state machine.
        # _activate_enough_servers will select the servers that we want to use
        # for the download, and then attempt to start downloading. After
        # each segment, it will check for doneness, reacting to broken
        # servers and corrupt shares as necessary. If it runs out of good
        # servers before downloading all of the segments, _done_deferred
        # will errback.  Otherwise, it will eventually callback with the
        # contents of the mutable file.
        self.loop()

    def loop(self):
        d = fireEventually(None) # avoid #237 recursion limit problem
        d.addCallback(lambda ign: self._activate_enough_servers())
        d.addCallback(lambda ign: self._download_current_segment())
        # when we're done, _download_current_segment will call _done. If we
        # aren't, it will call loop() again.
        d.addErrback(self._error)

    def _setup_download(self):
        self._status.set_status("Retrieving Shares")

        # how many shares do we need?
        (seqnum,
         root_hash,
         IV,
         segsize,
         datalength,
         k,
         N,
         prefix,
         offsets_tuple) = self.verinfo

        # first, which servers can we use?
        versionmap = self.servermap.make_versionmap()
        shares = versionmap[self.verinfo]
        # this sharemap is consumed as we decide to send requests
        self.remaining_sharemap = DictOfSets()
        for (shnum, server, timestamp) in shares:
            self.remaining_sharemap.add(shnum, server)
            # Reuse the SlotReader from the servermap.
            key = (self.verinfo, server.get_serverid(),
                   self._storage_index, shnum)
            if key in self.servermap.proxies:
                reader = self.servermap.proxies[key]
            else:
                reader = MDMFSlotReadProxy(server.get_storage_server(),
                                           self._storage_index, shnum, None)
            reader.server = server
            self.readers[shnum] = reader

        if len(self.remaining_sharemap) < k:
            self._raise_notenoughshareserror()

        self.shares = {} # maps shnum to validated blocks
        self._active_readers = [] # list of active readers for this dl.
        self._block_hash_trees = {} # shnum => hashtree

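        # One IncompleteHashTree per share, with one leaf per segment: the
        # leaves are filled in by _validate_block() as blocks arrive, so a
        # corrupt block is rejected before it reaches the decoder.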
        for i in range(self._total_shares):
            # So we don't have to do this later.
            self._block_hash_trees[i] = hashtree.IncompleteHashTree(self._num_segments)

        # We need one share hash tree for the entire file; its leaves
        # are the roots of the block hash trees for the shares that
        # comprise it, and its root is in the verinfo.
        self.share_hash_tree = hashtree.IncompleteHashTree(N)
        self.share_hash_tree.set_hashes({0: root_hash})

    def decode(self, blocks_and_salts, segnum):
        """
        I am a helper method that the mutable file update process uses
        as a shortcut to decode and decrypt the segments that it needs
        to fetch in order to perform a file update. I take in a
        collection of blocks and salts, and pick some of those to make a
        segment with. I return the plaintext associated with that
        segment.
        """
        # We don't need the block hash trees in this case.
        self._block_hash_trees = None
        self._offset = 0
        self._read_length = self._data_length
        self._setup_encoding_parameters()

        # _decode_blocks() expects the output of a gatherResults that
        # contains the outputs of _validate_block() (each of which is a dict
        # mapping shnum to (block,salt) bytestrings).
        d = self._decode_blocks([blocks_and_salts], segnum)
        d.addCallback(self._decrypt_segment)
        return d


    def _setup_encoding_parameters(self):
        """
        I set up the encoding parameters, including k, n, the number
        of segments associated with this file, and the segment decoders.
        """
        (seqnum,
         root_hash,
         IV,
         segsize,
         datalength,
         k,
         n,
         known_prefix,
         offsets_tuple) = self.verinfo
        self._required_shares = k
        self._total_shares = n
        self._segment_size = segsize
        #self._data_length = datalength # set during __init__()

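        # SDMF records a single file-wide IV in the verinfo, while MDMF uses
        # a per-segment salt instead, so an empty IV field marks the share
        # format as MDMF.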
        if not IV:
            self._version = MDMF_VERSION
        else:
            self._version = SDMF_VERSION

        if datalength and segsize:
            self._num_segments = mathutil.div_ceil(datalength, segsize)
            self._tail_data_size = datalength % segsize
        else:
            self._num_segments = 0
            self._tail_data_size = 0

        self._segment_decoder = codec.CRSDecoder()
        self._segment_decoder.set_params(segsize, k, n)

        if not self._tail_data_size:
            self._tail_data_size = segsize

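        # The erasure coding splits each segment into k equal-sized shares,
        # so the (possibly short) tail segment is sized up to the next
        # multiple of k here; the padding is trimmed away again after
        # decoding, when _decode_blocks truncates the tail segment to
        # self._tail_data_size.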
        self._tail_segment_size = mathutil.next_multiple(self._tail_data_size,
                                                         self._required_shares)
        if self._tail_segment_size == self._segment_size:
            self._tail_decoder = self._segment_decoder
        else:
            self._tail_decoder = codec.CRSDecoder()
            self._tail_decoder.set_params(self._tail_segment_size,
                                          self._required_shares,
                                          self._total_shares)

        self.log("got encoding parameters: "
                 "k: %d "
                 "n: %d "
                 "%d segments of %d bytes each (%d byte tail segment)" % \
                 (k, n, self._num_segments, self._segment_size,
                  self._tail_segment_size))

        # Our last task is to tell the downloader where to start and
        # where to stop. We use three parameters for that:
        #   - self._start_segment: the segment that we need to start
        #     downloading from.
        #   - self._current_segment: the next segment that we need to
        #     download.
        #   - self._last_segment: The last segment that we were asked to
        #     download.
        #
        #  We say that the download is complete when
        #  self._current_segment > self._last_segment. We use
        #  self._start_segment and self._last_segment to know when to
        #  strip things off of segments, and how much to strip.
        if self._offset:
            self.log("got offset: %d" % self._offset)
            # our start segment is the first segment containing the
            # offset we were given.
            start = self._offset // self._segment_size

            _assert(start <= self._num_segments,
                    start=start, num_segments=self._num_segments,
                    offset=self._offset, segment_size=self._segment_size)
            self._start_segment = start
            self.log("got start segment: %d" % self._start_segment)
        else:
            self._start_segment = 0

        # We might want to read only part of the file, and need to figure out
        # where to stop reading. Our end segment is the last segment
        # containing part of the segment that we were asked to read.
        _assert(self._read_length > 0, self._read_length)
        end_data = self._offset + self._read_length

        # We don't actually need to read the byte at end_data, but the one
        # before it.
        end = (end_data - 1) // self._segment_size
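        # Illustrative numbers: with segment_size=131072, offset=200000 and
        # read_length=100000, we get start = 200000 // 131072 = 1, end_data =
        # 300000, and end = 299999 // 131072 = 2, so segments 1 and 2 are
        # fetched and later trimmed in _set_segment().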
        _assert(0 <= end < self._num_segments,
                end=end, num_segments=self._num_segments,
                end_data=end_data, offset=self._offset,
                read_length=self._read_length, segment_size=self._segment_size)
        self._last_segment = end
        self.log("got end segment: %d" % self._last_segment)

        self._current_segment = self._start_segment

    def _activate_enough_servers(self):
        """
        I populate self._active_readers with enough active readers to
        retrieve the contents of this mutable file. I am called before
        downloading starts, and (eventually) after each validation
        error, connection error, or other problem in the download.
        """
        # TODO: It would be cool to investigate other heuristics for
        # reader selection. For instance, the cost (in time the user
        # spends waiting for their file) of selecting a really slow server
        # that happens to have a primary share is probably more than
        # selecting a really fast server that doesn't have a primary
        # share. Maybe the servermap could be extended to provide this
        # information; it could keep track of latency information while
        # it gathers more important data, and then this routine could
        # use that to select active readers.
        #
        # (these and other questions would be easier to answer with a
        #  robust, configurable tahoe-lafs simulator, which modeled node
        #  failures, differences in node speed, and other characteristics
        #  that we expect storage servers to have.  You could have
        #  presets for really stable grids (like allmydata.com),
        #  friendnets, make it easy to configure your own settings, and
        #  then simulate the effect of big changes on these use cases
        #  instead of just reasoning about what the effect might be. Out
        #  of scope for MDMF, though.)

        # XXX: Why don't format= log messages work here?

        known_shnums = set(self.remaining_sharemap.keys())
        used_shnums = set([r.shnum for r in self._active_readers])
        unused_shnums = known_shnums - used_shnums

        if self._verify:
            new_shnums = unused_shnums # use them all
        elif len(self._active_readers) < self._required_shares:
            # need more shares
            more = self._required_shares - len(self._active_readers)
            # We favor lower numbered shares, since FEC is faster with
            # primary shares than with other shares, and lower-numbered
            # shares are more likely to be primary than higher numbered
            # shares.
            new_shnums = sorted(unused_shnums)[:more]
            if len(new_shnums) < more:
                # We don't have enough readers to retrieve the file; fail.
                self._raise_notenoughshareserror()
        else:
            new_shnums = []

        self.log("adding %d new servers to the active list" % len(new_shnums))
        for shnum in new_shnums:
            reader = self.readers[shnum]
            self._active_readers.append(reader)
            self.log("added reader for share %d" % shnum)
            # Each time we add a reader, we check to see if we need the
            # private key. If we do, we politely ask for it and then continue
            # computing. If we find that we haven't gotten it at the end of
            # segment decoding, then we'll take more drastic measures.
            if self._need_privkey and not self._node.is_readonly():
                d = reader.get_encprivkey()
                d.addCallback(self._try_to_validate_privkey, reader, reader.server)
                # XXX: don't just drop the Deferred. We need error-reporting
                # but not flow-control here.

    def _try_to_validate_prefix(self, prefix, reader):
        """
        I check that the prefix returned by a candidate server for
        retrieval matches the prefix that the servermap knows about
        (and, hence, the prefix that was validated earlier). If it does,
        I return True, which means that I approve of the use of the
        candidate server for segment retrieval. If it doesn't, I return
        False, which means that another server must be chosen.
        """
        (seqnum,
         root_hash,
         IV,
         segsize,
         datalength,
         k,
         N,
         known_prefix,
         offsets_tuple) = self.verinfo
        if known_prefix != prefix:
            self.log("prefix from share %d doesn't match" % reader.shnum)
            raise UncoordinatedWriteError("Mismatched prefix -- this could "
                                          "indicate an uncoordinated write")
        # Otherwise, we're okay -- no issues.

    def _mark_bad_share(self, server, shnum, reader, f):
        """
        I mark the given (server, shnum) as a bad share, which means that it
        will not be used anywhere else.

        There are several reasons to want to mark something as a bad
        share. These include:

            - A connection error to the server.
            - A mismatched prefix (that is, a prefix that does not match
              our local conception of the version information string).
            - A failing block hash, salt hash, share hash, or other
              integrity check.

        This method will ensure that readers that we wish to mark bad
        (for these reasons or other reasons) are not used for the rest
        of the download. Additionally, it will attempt to tell the
        remote server (with no guarantee of success) that its share is
        corrupt.
        """
        self.log("marking share %d on server %r as bad" % \
                 (shnum, server.get_name()))
        prefix = self.verinfo[-2]
        self.servermap.mark_bad_share(server, shnum, prefix)
        self._bad_shares.add((server, shnum, f))
        self._status.add_problem(server, f)
        self._last_failure = f

        # Remove the reader from _active_readers
        self._active_readers.remove(reader)
        for shnum in list(self.remaining_sharemap.keys()):
            self.remaining_sharemap.discard(shnum, reader.server)

        if f.check(BadShareError):
            self.notify_server_corruption(server, shnum, str(f.value))

    def _download_current_segment(self):
        """
        I download, validate, decode, decrypt, and assemble the segment
        that this Retrieve is currently responsible for downloading.
        """

        if self._current_segment > self._last_segment:
            # No more segments to download, we're done.
            self.log("got plaintext, done")
            return self._done()
        elif self._verify and len(self._active_readers) == 0:
            self.log("no more good shares, no need to keep verifying")
            return self._done()
        self.log("on segment %d of %d" %
                 (self._current_segment + 1, self._num_segments))
        d = self._process_segment(self._current_segment)
        d.addCallback(lambda ign: self.loop())
        return d

    def _process_segment(self, segnum):
        """
        I download, validate, decode, and decrypt one segment of the
        file that this Retrieve is retrieving. This means coordinating
        the process of getting k blocks of that file, validating them,
        assembling them into one segment with the decoder, and then
        decrypting them.
        """
        self.log("processing segment %d" % segnum)

        # TODO: The old code uses a marker. Should this code do that
        # too? What did the Marker do?

        # We need to ask each of our active readers for its block and
        # salt. We will then validate those. If validation is
        # successful, we will assemble the results into plaintext.
        ds = []
        for reader in self._active_readers:
            started = time.time()
            d1 = reader.get_block_and_salt(segnum)
            d2,d3 = self._get_needed_hashes(reader, segnum)
            d = deferredutil.gatherResults([d1,d2,d3])
            d.addCallback(self._validate_block, segnum, reader, reader.server, started)
            # _handle_bad_share takes care of recoverable errors (by dropping
            # that share and returning None). Any other errors (i.e. code
            # bugs) are passed through and cause the retrieve to fail.
            d.addErrback(self._handle_bad_share, [reader])
            ds.append(d)
        dl = deferredutil.gatherResults(ds)
        if self._verify:
            dl.addCallback(lambda ignored: "")
            dl.addCallback(self._set_segment)
        else:
            dl.addCallback(self._maybe_decode_and_decrypt_segment, segnum)
        return dl


    def _maybe_decode_and_decrypt_segment(self, results, segnum):
        """
        I take the results of fetching and validating the blocks from
        _process_segment. If validation and fetching succeeded without
        incident, I will proceed with decoding and decryption. Otherwise, I
        will do nothing.
        """
        self.log("trying to decode and decrypt segment %d" % segnum)

        # 'results' is the output of a gatherResults set up in
        # _process_segment(). Each component Deferred will either contain the
        # non-Failure output of _validate_block() for a single block (i.e.
        # {shnum: (block, salt)}), or None if _validate_block threw an
        # exception and _handle_bad_share handled it (by dropping that
        # server).

        if None in results:
            self.log("some validation operations failed; not proceeding")
            return defer.succeed(None)
        self.log("everything looks ok, building segment %d" % segnum)
        d = self._decode_blocks(results, segnum)
        d.addCallback(self._decrypt_segment)
        # check to see whether we've been paused before writing
        # anything.
        d.addCallback(self._check_for_paused)
        d.addCallback(self._check_for_stopped)
        d.addCallback(self._set_segment)
        return d


    def _set_segment(self, segment):
        """
        Given a plaintext segment, I register that segment with the
        target that is handling the file download.
        """
        self.log("got plaintext for segment %d" % self._current_segment)

        if self._read_length == 0:
            self.log("on first+last segment, size=0, using 0 bytes")
            segment = b""

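        # Continuing the illustrative numbers from _setup_encoding_parameters
        # (segment_size=131072, offset=200000, read_length=100000): the tail
        # trim below keeps the first 300000 % 131072 = 37856 bytes of segment
        # 2, and the head trim keeps everything after the first
        # 200000 % 131072 = 68928 bytes of segment 1.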
        if self._current_segment == self._last_segment:
            # trim off the tail
            wanted = (self._offset + self._read_length) % self._segment_size
            if wanted != 0:
                self.log("on the last segment: using first %d bytes" % wanted)
                segment = segment[:wanted]
            else:
                self.log("on the last segment: using all %d bytes" %
                         len(segment))

        if self._current_segment == self._start_segment:
            # Trim off the head, if offset != 0. This should also work if
            # start==last, because we trim the tail first.
            skip = self._offset % self._segment_size
            self.log("on the first segment: skipping first %d bytes" % skip)
            segment = segment[skip:]

        if not self._verify:
            self._consumer.write(segment)
        else:
            # we don't care about the plaintext if we are doing a verify.
            segment = None
        self._current_segment += 1


    def _handle_bad_share(self, f, readers):
        """
        I am called when a block or a salt fails to correctly validate, or when
        the decryption or decoding operation fails for some reason.  I react to
        this failure by notifying the remote server of corruption, and then
        removing the remote server from further activity.
        """
        # these are the errors we can tolerate: by giving up on this share
        # and finding others to replace it. Any other errors (i.e. coding
        # bugs) are re-raised, causing the download to fail.
        f.trap(DeadReferenceError, RemoteException, BadShareError)

        # DeadReferenceError happens when we try to fetch data from a server
        # that has gone away. RemoteException happens if the server had an
        # internal error. BadShareError encompasses: (UnknownVersionError,
        # LayoutInvalid, struct.error) which happen when we get obviously
        # wrong data, and CorruptShareError which happens later, when we
        # perform integrity checks on the data.

        precondition(isinstance(readers, list), readers)
        bad_shnums = [reader.shnum for reader in readers]

        self.log("validation or decoding failed on share(s) %s, server(s) %s "
                 ", segment %d: %s" % \
                 (bad_shnums, readers, self._current_segment, str(f)))
        for reader in readers:
            self._mark_bad_share(reader.server, reader.shnum, reader, f)
        return None


    @deferredutil.async_to_deferred
    async def _validate_block(self, results, segnum, reader, server, started):
        """
        I validate a block from one share on a remote server.
        """
        # Grab the part of the block hash tree that is necessary to
        # validate this block, then generate the block hash root.
        self.log("validating share %d for segment %d" % (reader.shnum,
                                                             segnum))
        elapsed = time.time() - started
        self._status.add_fetch_timing(server, elapsed)
        self._set_current_status("validating blocks")

        block_and_salt, blockhashes, sharehashes = results
        block, salt = block_and_salt
        _assert(isinstance(block, bytes), (block, salt))

        blockhashes = dict(enumerate(blockhashes))
        self.log("the reader gave me the following blockhashes: %s" % \
                 list(blockhashes.keys()))
        self.log("the reader gave me the following sharehashes: %s" % \
                 list(sharehashes.keys()))
        bht = self._block_hash_trees[reader.shnum]

        if bht.needed_hashes(segnum, include_leaf=True):
            try:
                bht.set_hashes(blockhashes)
            except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                    IndexError) as e:
                raise CorruptShareError(server,
                                        reader.shnum,
                                        "block hash tree failure: %s" % e)

        if self._version == MDMF_VERSION:
            blockhash = await defer_to_thread(hashutil.block_hash, salt + block)
        else:
            blockhash = await defer_to_thread(hashutil.block_hash, block)
        # If this works without an error, then validation is
        # successful.
        try:
            bht.set_hashes(leaves={segnum: blockhash})
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                IndexError) as e:
            raise CorruptShareError(server,
                                    reader.shnum,
                                    "block hash tree failure: %s" % e)

        # Reaching this point means that we know that this segment
        # is correct. Now we need to check to see whether the share
        # hash chain is also correct.
        # SDMF wrote share hash chains that didn't contain the
        # leaves, which would be produced from the block hash tree.
        # So we need to validate the block hash tree first. If
        # successful, then bht[0] will contain the root for the
        # shnum, which will be a leaf in the share hash tree, which
        # will allow us to validate the rest of the tree.
        try:
            self.share_hash_tree.set_hashes(hashes=sharehashes,
                                        leaves={reader.shnum: bht[0]})
        except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                IndexError) as e:
            raise CorruptShareError(server,
                                    reader.shnum,
                                    "corrupt hashes: %s" % e)

        self.log('share %d is valid for segment %d' % (reader.shnum,
                                                       segnum))
        return {reader.shnum: (block, salt)}


    def _get_needed_hashes(self, reader, segnum):
        """
        I get the hashes needed to validate segnum from the reader, then return
        to my caller when this is done.
        """
        bht = self._block_hash_trees[reader.shnum]
        needed = bht.needed_hashes(segnum, include_leaf=True)
        # The root of the block hash tree is also a leaf in the share
        # hash tree. So we don't need to fetch it from the remote
        # server. In the case of files with one segment, this means that
        # we won't fetch any block hash tree from the remote server,
        # since the hash of each share of the file is the entire block
        # hash tree, and is a leaf in the share hash tree. This is fine,
        # since any share corruption will be detected in the share hash
        # tree.
        #needed.discard(0)
        self.log("getting blockhashes for segment %d, share %d: %s" % \
                 (segnum, reader.shnum, str(needed)))
        # TODO is force_remote necessary here?
        d1 = reader.get_blockhashes(needed, force_remote=False)
        if self.share_hash_tree.needed_hashes(reader.shnum):
            need = self.share_hash_tree.needed_hashes(reader.shnum)
            self.log("also need sharehashes for share %d: %s" % (reader.shnum,
                                                                 str(need)))
            d2 = reader.get_sharehashes(need, force_remote=False)
        else:
            d2 = defer.succeed({}) # the logic in the next method
                                   # expects a dict
        return d1,d2


    def _decode_blocks(self, results, segnum):
        """
        I take a list of k blocks and salts, and decode that into a
        single encrypted segment.
        """
        # 'results' is one or more dicts (each {shnum:(block,salt)}), and we
        # want to merge them all
        blocks_and_salts = {}
        for d in results:
            blocks_and_salts.update(d)

        # All of these blocks should have the same salt; in SDMF, it is
        # the file-wide IV, while in MDMF it is the per-segment salt. In
        # either case, we just need to get one of them and use it.
        #
        # d.items()[0] is like (shnum, (block, salt))
        # d.items()[0][1] is like (block, salt)
        # d.items()[0][1][1] is the salt.
        salt = list(blocks_and_salts.items())[0][1][1]
        # Next, extract just the blocks from the dict. We'll use the
        # salt in the next step.
        share_and_shareids = [(k, v[0]) for k, v in blocks_and_salts.items()]
        d2 = dict(share_and_shareids)
        shareids = []
        shares = []
        for shareid, share in d2.items():
            shareids.append(shareid)
            shares.append(share)

        self._set_current_status("decoding")
        started = time.time()
        _assert(len(shareids) >= self._required_shares, len(shareids))
        # zfec really doesn't want extra shares
        shareids = shareids[:self._required_shares]
        shares = shares[:self._required_shares]
        self.log("decoding segment %d" % segnum)
        if segnum == self._num_segments - 1:
            d = self._tail_decoder.decode(shares, shareids)
        else:
            d = self._segment_decoder.decode(shares, shareids)

        # For larger shares, this can take a few milliseconds. As such, we want
        # to unblock the event loop. In newer Python b"".join() will release
        # the GIL: https://github.com/python/cpython/issues/80232
        @deferredutil.async_to_deferred
        async def _got_buffers(buffers):
            return await defer_to_thread(lambda: b"".join(buffers))

        d.addCallback(_got_buffers)

        def _process(segment):
            self.log(format="now decoding segment %(segnum)s of %(numsegs)s",
                     segnum=segnum,
                     numsegs=self._num_segments,
                     level=log.NOISY)
            self.log(" joined length %d, datalength %d" %
                     (len(segment), self._data_length))
            if segnum == self._num_segments - 1:
                size_to_use = self._tail_data_size
            else:
                size_to_use = self._segment_size
            segment = segment[:size_to_use]
            self.log(" segment len=%d" % len(segment))
            self._status.accumulate_decode_time(time.time() - started)
            return segment, salt
        d.addCallback(_process)
        return d

    @deferredutil.async_to_deferred
    async def _decrypt_segment(self, segment_and_salt):
        """
        I take a single segment and its salt, and decrypt it. I return
        the plaintext of the segment that is in my argument.
        """
        segment, salt = segment_and_salt
        self._set_current_status("decrypting")
        self.log("decrypting segment %d" % self._current_segment)
        started = time.time()
        readkey = self._node.get_readkey()

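        # The AES key for this segment is derived from the per-segment salt
        # (the file-wide IV, for SDMF) and the file's readkey, via
        # hashutil.ssk_readkey_data_hash().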
        def decrypt():
            key = hashutil.ssk_readkey_data_hash(salt, readkey)
            decryptor = aes.create_decryptor(key)
            return aes.decrypt_data(decryptor, segment)

        plaintext = await defer_to_thread(decrypt)
        self._status.accumulate_decrypt_time(time.time() - started)
        return plaintext


    def notify_server_corruption(self, server, shnum, reason):
        if isinstance(reason, str):
            reason = reason.encode("utf-8")
        storage_server = server.get_storage_server()
        storage_server.advise_corrupt_share(
            b"mutable",
            self._storage_index,
            shnum,
            reason,
        )

    @deferredutil.async_to_deferred
    async def _try_to_validate_privkey(self, enc_privkey, reader, server):
        node_writekey = self._node.get_writekey()

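        # The writekey is derived by hashing the signing key, so decrypting
        # the claimed privkey and re-deriving the writekey lets us detect a
        # corrupt or forged encrypted privkey without trusting the server.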
        def get_privkey():
            alleged_privkey_s = decrypt_privkey(node_writekey, enc_privkey)
            alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
            if alleged_writekey != node_writekey:
                return None
            privkey, _ = rsa.create_signing_keypair_from_string(alleged_privkey_s)
            return privkey

        privkey = await defer_to_thread(get_privkey)
        if privkey is None:
            self.log("invalid privkey from %s shnum %d" %
                     (reader, reader.shnum),
                     level=log.WEIRD, umid="YIw4tA")
            if self._verify:
                self.servermap.mark_bad_share(server, reader.shnum,
                                              self.verinfo[-2])
                e = CorruptShareError(server,
                                      reader.shnum,
                                      "invalid privkey")
                f = failure.Failure(e)
                self._bad_shares.add((server, reader.shnum, f))
            return

        # it's good
        self.log("got valid privkey from shnum %d on reader %s" %
                 (reader.shnum, reader))
        self._node._populate_encprivkey(enc_privkey)
        self._node._populate_privkey(privkey)
        self._need_privkey = False

    def _done(self):
        """
        I am called by _download_current_segment when the download process
        has finished successfully. After making some useful logging
        statements, I return the decrypted contents to the owner of this
        Retrieve object through self._done_deferred.
        """
        self._running = False
        self._status.set_active(False)
        now = time.time()
        self._status.timings['total'] = now - self._started
        self._status.timings['fetch'] = now - self._started_fetching
        self._status.set_status("Finished")
        self._status.set_progress(1.0)

        # remember the encoding parameters, use them again next time
        (seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
         offsets_tuple) = self.verinfo
        self._node._populate_required_shares(k)
        self._node._populate_total_shares(N)

        if self._verify:
            ret = self._bad_shares
            self.log("done verifying, found %d bad shares" % len(ret))
        else:
            # TODO: upload status here?
            ret = self._consumer
            self._consumer.unregisterProducer()
        eventually(self._done_deferred.callback, ret)

    def _raise_notenoughshareserror(self):
        """
        I am called when there are not enough active servers left to complete
        the download. After making some useful logging statements, I throw an
        exception to that effect to the caller of this Retrieve object through
        self._done_deferred.
        """

        format = ("ran out of servers: "
                  "have %(have)d of %(total)d segments; "
                  "found %(bad)d bad shares; "
                  "have %(remaining)d remaining shares of the right version; "
                  "encoding %(k)d-of-%(n)d")
        args = {"have": self._current_segment,
                "total": self._num_segments,
                "need": self._last_segment,
                "k": self._required_shares,
                "n": self._total_shares,
                "bad": len(self._bad_shares),
                "remaining": len(self.remaining_sharemap),
               }
        raise NotEnoughSharesError("%s, last failure: %s" %
                                   (format % args, str(self._last_failure)))

    def _error(self, f):
        # all errors, including NotEnoughSharesError, land here
        self._running = False
        self._status.set_active(False)
        now = time.time()
        self._status.timings['total'] = now - self._started
        self._status.timings['fetch'] = now - self._started_fetching
        self._status.set_status("Failed")
        eventually(self._done_deferred.errback, f)