source: trunk/src/allmydata/test/test_encode.py

Last change on this file was 4da491a, checked in by Alexandre Detiste <alexandre.detiste@…>, at 2024-03-11T20:37:27Z

remove more usage of "future"

  • Property mode set to 100644
File size: 15.4 KB
Line 
1"""
2Ported to Python 3.
3"""
4
5from past.builtins import chr as byteschr
6
7from zope.interface import implementer
8from twisted.trial import unittest
9from twisted.internet import defer
10from twisted.python.failure import Failure
11from foolscap.api import fireEventually
12from allmydata import uri
13from allmydata.immutable import encode, upload, checker
14from allmydata.util import hashutil
15from allmydata.util.assertutil import _assert
16from allmydata.util.consumer import download_to_data
17from allmydata.interfaces import IStorageBucketWriter, IStorageBucketReader
18from allmydata.test.no_network import GridTestMixin
19
class LostPeerError(Exception):
    """Signals that a fake storage peer has (deliberately) gone away."""
22
def flip_bit(good):
    """Return a copy of ``good`` with the lowest bit of its last byte inverted.

    Indexing ``bytes`` on Python 3 already yields an ``int``, so the last
    byte is XORed directly; passing it through ``ord()`` raises ``TypeError``.

    :param good: a non-empty ``bytes`` value
    :return: ``bytes`` of the same length that differs from ``good`` only in
        the final bit
    :raises IndexError: if ``good`` is empty
    """
    return good[:-1] + bytes([good[-1] ^ 0x01])
25
@implementer(IStorageBucketWriter, IStorageBucketReader)
class FakeBucketReaderWriterProxy(object):
    """In-memory fake bucket that serves as both the writer and the reader.

    Everything written through the ``put_*`` methods is kept on the instance
    and handed back by the matching ``get_*`` methods. ``mode`` selects a
    misbehaviour to simulate:

    * "good": store and return data unchanged
    * "lost-early": every start / put_block fails with LostPeerError
    * "lost": put_block fails with LostPeerError for segment 1 and later
    * "bad block", "bad crypttext hashroot", "bad crypttext hash",
      "bad blockhash", "bad sharehash", "bad uri_extension": the matching
      getter returns data with one bit flipped
    * "missing sharehash": get_share_hashes returns an empty list
    """

    def __init__(self, mode="good", peerid="peer"):
        self.peerid = peerid
        self.mode = mode
        self.closed = False
        self.blocks = {}
        self.plaintext_hashes = []
        self.crypttext_hashes = []
        self.block_hashes = None
        self.share_hashes = None

    def get_peerid(self):
        """Return the peer id given at construction time."""
        return self.peerid

    def _start(self):
        """Return a Deferred for this proxy, or an eventual failure in "lost-early" mode."""
        if self.mode != "lost-early":
            return defer.succeed(self)
        return fireEventually(Failure(LostPeerError("I went away early")))

    def put_header(self):
        """Writing the header only exercises the start-up behaviour."""
        return self._start()

    def put_block(self, segmentnum, data):
        """Store one block, unless the simulated peer has gone away."""
        if self.mode == "lost-early":
            return fireEventually(Failure(LostPeerError("I went away early")))

        def _store():
            assert not self.closed
            assert segmentnum not in self.blocks
            if self.mode == "lost" and segmentnum >= 1:
                raise LostPeerError("I'm going away now")
            self.blocks[segmentnum] = data
        return defer.maybeDeferred(_store)

    def put_crypttext_hashes(self, hashes):
        """Record the crypttext hashes; they may only be written once."""
        def _store():
            assert not self.closed
            assert not self.crypttext_hashes
            self.crypttext_hashes = hashes
        return defer.maybeDeferred(_store)

    def put_block_hashes(self, blockhashes):
        """Record the block hash tree; it may only be written once."""
        def _store():
            assert not self.closed
            assert self.block_hashes is None
            self.block_hashes = blockhashes
        return defer.maybeDeferred(_store)

    def put_share_hashes(self, sharehashes):
        """Record the share hash chain; it may only be written once."""
        def _store():
            assert not self.closed
            assert self.share_hashes is None
            self.share_hashes = sharehashes
        return defer.maybeDeferred(_store)

    def put_uri_extension(self, uri_extension):
        """Record the URI extension block."""
        def _store():
            assert not self.closed
            self.uri_extension = uri_extension
        return defer.maybeDeferred(_store)

    def close(self):
        """Mark the bucket closed; closing twice trips an assertion."""
        def _close():
            assert not self.closed
            self.closed = True
        return defer.maybeDeferred(_close)

    def abort(self):
        """Aborting is a no-op that succeeds immediately."""
        return defer.succeed(None)

    def get_block_data(self, blocknum, blocksize, size):
        """Fetch a stored block, corrupted when mode is "bad block"."""
        def _fetch(_ignored=None):
            assert isinstance(blocknum, int)
            if self.mode == "bad block":
                return flip_bit(self.blocks[blocknum])
            return self.blocks[blocknum]
        return self._start().addCallback(_fetch)

    def get_plaintext_hashes(self):
        """Fetch a copy of the plaintext hashes."""
        def _fetch(_ignored=None):
            return self.plaintext_hashes[:]
        return self._start().addCallback(_fetch)

    def get_crypttext_hashes(self):
        """Fetch a copy of the crypttext hashes, corrupted in the "bad crypttext ..." modes."""
        def _fetch(_ignored=None):
            result = self.crypttext_hashes[:]
            # each corrupting mode damages exactly one entry of the copy
            for bad_mode, index in (("bad crypttext hashroot", 0),
                                    ("bad crypttext hash", 1)):
                if self.mode == bad_mode:
                    result[index] = flip_bit(result[index])
            return result
        return self._start().addCallback(_fetch)

    def get_block_hashes(self, at_least_these=()):
        """Fetch the block hashes, with entry 1 corrupted when mode is "bad blockhash"."""
        def _fetch(_ignored=None):
            if self.mode != "bad blockhash":
                return self.block_hashes
            corrupted = self.block_hashes[:]
            corrupted[1] = flip_bit(corrupted[1])
            return corrupted
        return self._start().addCallback(_fetch)

    def get_share_hashes(self, at_least_these=()):
        """Fetch the share hash chain, possibly corrupted or withheld."""
        def _fetch(_ignored=None):
            if self.mode == "bad sharehash":
                corrupted = self.share_hashes[:]
                entry = corrupted[1]
                corrupted[1] = (entry[0], flip_bit(entry[1]))
                return corrupted
            if self.mode == "missing sharehash":
                # one sneaky attack would be to pretend we don't know our own
                # sharehash, which could manage to frame someone else.
                # download.py is supposed to guard against this case.
                return []
            return self.share_hashes
        return self._start().addCallback(_fetch)

    def get_uri_extension(self):
        """Fetch the URI extension block, corrupted when mode is "bad uri_extension"."""
        def _fetch(_ignored=None):
            if self.mode == "bad uri_extension":
                return flip_bit(self.uri_extension)
            return self.uri_extension
        return self._start().addCallback(_fetch)
164
165
def make_data(length):
    """Return the first ``length`` bytes of a fixed, repeating test pattern.

    The pattern pool is 100 repetitions of b"happy happy joy joy"; asking for
    more than the pool holds trips an assertion.
    """
    pool = 100 * b"happy happy joy joy"
    assert len(pool) >= length
    return pool[:length]
170
class ValidatedExtendedURIProxy(unittest.TestCase):
    """Exercise checker.ValidatedExtendedURIProxy with various URI extension dicts.

    Each case packs a dict into a URI extension block, serves it from a
    FakeBucketReaderWriterProxy, and checks whether validation against a
    matching verify-cap is accepted or rejected.
    """
    K = 4
    M = 10
    SIZE = 200
    SEGSIZE = 72
    # The tail segment holds what is left over after the full segments (or a
    # whole segment when SIZE divides evenly), rounded up to a multiple of K.
    _TMP = SIZE%SEGSIZE
    if _TMP == 0:
        _TMP = SEGSIZE
    if _TMP % K != 0:
        _TMP += (K - (_TMP % K))
    TAIL_SEGSIZE = _TMP
    # Number of segments is SIZE / SEGSIZE, rounded up.
    _TMP = SIZE // SEGSIZE
    if SIZE % SEGSIZE != 0:
        _TMP += 1
    NUM_SEGMENTS = _TMP
    # The smallest dict that test_accept_minimal expects to be accepted.
    mindict = { 'segment_size': SEGSIZE,
                'crypttext_root_hash': b'0'*hashutil.CRYPTO_VAL_SIZE,
                'share_root_hash': b'1'*hashutil.CRYPTO_VAL_SIZE }
    # Optional fields whose values agree with the parameters above.
    optional_consistent = { 'crypttext_hash': b'2'*hashutil.CRYPTO_VAL_SIZE,
                            'codec_name': b"crs",
                            'codec_params': b"%d-%d-%d" % (SEGSIZE, K, M),
                            'tail_codec_params': b"%d-%d-%d" % (TAIL_SEGSIZE, K, M),
                            'num_segments': NUM_SEGMENTS,
                            'size': SIZE,
                            'needed_shares': K,
                            'total_shares': M,
                            'plaintext_hash': b"anything",
                            'plaintext_root_hash': b"anything", }
    # Optional fields mapped to tuples of values that disagree with the
    # parameters above; every single value is expected to be rejected.
    # optional_inconsistent = { 'crypttext_hash': ('2'*(hashutil.CRYPTO_VAL_SIZE-1), "", 77),
    optional_inconsistent = { 'crypttext_hash': (77,),
                              'codec_name': (b"digital fountain", b""),
                              'codec_params': (b"%d-%d-%d" % (SEGSIZE, K-1, M),
                                               b"%d-%d-%d" % (SEGSIZE-1, K, M),
                                               b"%d-%d-%d" % (SEGSIZE, K, M-1)),
                              'tail_codec_params': (b"%d-%d-%d" % (TAIL_SEGSIZE, K-1, M),
                                               b"%d-%d-%d" % (TAIL_SEGSIZE-1, K, M),
                                               b"%d-%d-%d" % (TAIL_SEGSIZE, K, M-1)),
                              'num_segments': (NUM_SEGMENTS-1,),
                              'size': (SIZE-1,),
                              'needed_shares': (K-1,),
                              'total_shares': (M-1,), }

    def _test(self, uebdict):
        """Pack ``uebdict``, serve it from a fake bucket, and start validation.

        Returns whatever ``ValidatedExtendedURIProxy.start()`` returns (used
        as a Deferred by the callers below).
        """
        uebstring = uri.pack_extension(uebdict)
        uebhash = hashutil.uri_extension_hash(uebstring)
        fb = FakeBucketReaderWriterProxy()
        fb.put_uri_extension(uebstring)
        verifycap = uri.CHKFileVerifierURI(storage_index=b'x'*16, uri_extension_hash=uebhash, needed_shares=self.K, total_shares=self.M, size=self.SIZE)
        vup = checker.ValidatedExtendedURIProxy(fb, verifycap)
        return vup.start()

    def _test_accept(self, uebdict):
        """Validation of ``uebdict`` is expected to succeed."""
        return self._test(uebdict)

    def _should_fail(self, res, expected_failures):
        """Callback/errback: pass only if ``res`` is a Failure of an expected type."""
        if isinstance(res, Failure):
            res.trap(*expected_failures)
        else:
            self.fail("was supposed to raise %s, not get '%s'" % (expected_failures, res))

    def _test_reject(self, uebdict):
        """Validation of ``uebdict`` is expected to fail with KeyError or BadURIExtension."""
        d = self._test(uebdict)
        d.addBoth(self._should_fail, (KeyError, checker.BadURIExtension))
        return d

    def test_accept_minimal(self):
        return self._test_accept(self.mindict)

    def test_reject_insufficient(self):
        # Removing any single mandatory key must cause a rejection.
        dl = []
        for k in self.mindict:
            insuffdict = self.mindict.copy()
            del insuffdict[k]
            # collect every case, not just the last one of the loop
            dl.append(self._test_reject(insuffdict))
        # fireOnOneErrback makes any failing case fail this test directly
        return defer.DeferredList(dl, fireOnOneErrback=True, consumeErrors=True)

    def test_accept_optional(self):
        # Adding any single consistent optional field must still be accepted.
        dl = []
        for k in self.optional_consistent:
            mydict = self.mindict.copy()
            mydict[k] = self.optional_consistent[k]
            # collect every case, not just the last one of the loop
            dl.append(self._test_accept(mydict))
        # fireOnOneErrback makes any failing case fail this test directly
        return defer.DeferredList(dl, fireOnOneErrback=True, consumeErrors=True)

    def test_reject_optional(self):
        # Adding any single inconsistent optional value must be rejected.
        dl = []
        for k in self.optional_inconsistent:
            for v in self.optional_inconsistent[k]:
                mydict = self.mindict.copy()
                mydict[k] = v
                dl.append(self._test_reject(mydict))
        # fireOnOneErrback makes any failing case fail this test directly
        return defer.DeferredList(dl, fireOnOneErrback=True, consumeErrors=True)
266
class Encode(unittest.TestCase):
    def do_encode(self, max_segment_size, datalen, NUM_SHARES, NUM_SEGMENTS,
                  expected_block_hashes, expected_share_hashes):
        """Encode ``datalen`` bytes to fake buckets and check what each one received.

        Every fake peer must end up closed, holding NUM_SEGMENTS blocks,
        ``expected_block_hashes`` block hashes and ``expected_share_hashes``
        (hashnum, hash) pairs, all hashes being 32 bytes long.
        """
        data = make_data(datalen)
        e = encode.Encoder()
        u = upload.Data(data, convergence=b"some convergence string")
        # a small max_segment_size forces the use of multiple segments
        params = {'max_segment_size': max_segment_size,
                  'k': 25, 'happy': 75, 'n': 100}
        u.set_default_encoding_parameters(params)
        d = e.set_encrypted_uploadable(upload.EncryptAnUploadable(u))

        all_shareholders = []

        def _ready(res):
            _k, _happy, n = e.get_param("share_counts")
            _assert(n == NUM_SHARES) # else we'll be completely confused
            numsegs = e.get_param("num_segments")
            _assert(numsegs == NUM_SEGMENTS, numsegs, NUM_SEGMENTS)
            segsize = e.get_param("segment_size")
            lower = (NUM_SEGMENTS-1)*segsize
            upper = NUM_SEGMENTS*segsize
            _assert(lower < len(data) <= upper,
                    NUM_SEGMENTS, segsize, lower, len(data), upper)

            # one fresh fake bucket per share number
            peers = [FakeBucketReaderWriterProxy() for _ in range(NUM_SHARES)]
            all_shareholders.extend(peers)
            shareholders = dict(enumerate(peers))
            servermap = {shnum: {peer.get_peerid()}
                         for shnum, peer in shareholders.items()}
            e.set_shareholders(shareholders, servermap)
            return e.start()
        d.addCallback(_ready)

        def _check(verifycap):
            ueb_hash = verifycap.uri_extension_hash
            self.failUnless(isinstance(ueb_hash, bytes))
            self.failUnlessEqual(len(ueb_hash), 32)
            for peer in all_shareholders:
                self.failUnless(peer.closed)
                self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
                # every peer receives the complete block hash tree: 7 hashes
                # for 3 or 4 segments, 15 hashes for 5 segments.
                self.failUnlessEqual(len(peer.block_hashes),
                                     expected_block_hashes)
                for block_hash in peer.block_hashes:
                    self.failUnlessEqual(len(block_hash), 32)
                # every peer also receives its chain of share hashes. With
                # 100 shares (rounded up to 128 leaves) that is 8 hashes.
                self.failUnlessEqual(len(peer.share_hashes),
                                     expected_share_hashes)
                for (hashnum, share_hash) in peer.share_hashes:
                    self.failUnless(isinstance(hashnum, int))
                    self.failUnlessEqual(len(share_hash), 32)
        d.addCallback(_check)

        return d

    # 3 segments

    def test_send_74(self):
        # segments of 25, 25, 24
        return self.do_encode(25, 74, 100, 3, 7, 8)

    def test_send_75(self):
        # segments of 25, 25, 25
        return self.do_encode(25, 75, 100, 3, 7, 8)

    def test_send_51(self):
        # segments of 25, 25, 1
        return self.do_encode(25, 51, 100, 3, 7, 8)

    # 4 segments

    def test_send_76(self):
        # a 76 byte file (segments of 25, 25, 25, 1) encoded to 100 shares
        return self.do_encode(25, 76, 100, 4, 7, 8)

    def test_send_99(self):
        # segments of 25, 25, 25, 24
        return self.do_encode(25, 99, 100, 4, 7, 8)

    def test_send_100(self):
        # segments of 25, 25, 25, 25
        return self.do_encode(25, 100, 100, 4, 7, 8)

    # 5 segments

    def test_send_124(self):
        # segments of 25, 25, 25, 25, 24
        return self.do_encode(25, 124, 100, 5, 15, 8)

    def test_send_125(self):
        # segments of 25, 25, 25, 25, 25
        return self.do_encode(25, 125, 100, 5, 15, 8)

    def test_send_101(self):
        # segments of 25, 25, 25, 25, 1
        return self.do_encode(25, 101, 100, 5, 15, 8)
354
355
class Roundtrip(GridTestMixin, unittest.TestCase):

    # A 3*3 grid of tests probing edge conditions. The first axis is how the
    # plaintext splits into segments: kn+(-1,0,1), i.e. n%k == -1 or 0 or 1.
    # With 25-byte segments that means sizes such as 74, 75 and 76 bytes.

    # The second axis is the number of leaves in the block hash tree relative
    # to a power of 2: 2^a+(-1,0,1). Each segment becomes one leaf, so we
    # want e.g. 3 segments, 4 segments and 5 segments.

    # Together these give the following data lengths:
    #  3 segs: 74, 75, 51
    #  4 segs: 99, 100, 76
    #  5 segs: 124, 125, 101

    # Every test encodes to 100 shares, so the share hash tree has 128
    # leaves and each bucket is given an 8-long share hash chain.

    # 3-segment and 4-segment files both get a 4-leaf block hash tree and
    # therefore 7 blockhashes; 5-segment files get an 8-leaf block hash
    # tree, which yields 15 blockhashes.

    def test_74(self):
        return self.do_test_size(74)

    def test_75(self):
        return self.do_test_size(75)

    def test_51(self):
        return self.do_test_size(51)

    def test_99(self):
        return self.do_test_size(99)

    def test_100(self):
        return self.do_test_size(100)

    def test_76(self):
        return self.do_test_size(76)

    def test_124(self):
        return self.do_test_size(124)

    def test_125(self):
        return self.do_test_size(125)

    def test_101(self):
        return self.do_test_size(101)

    def upload(self, data):
        """Upload ``data`` via client 0; the Deferred fires with a FileNode."""
        uploadable = upload.Data(data, None)
        uploadable.max_segment_size = 25
        uploadable.encoding_param_k = 25
        uploadable.encoding_param_happy = 1
        uploadable.encoding_param_n = 100

        def _make_node(results):
            return self.c0.create_node_from_uri(results.get_uri())
        d = self.c0.upload(uploadable)
        d.addCallback(_make_node)
        return d

    def do_test_size(self, size):
        """Upload ``size`` bytes, download them again, and compare."""
        self.basedir = self.mktemp()
        self.set_up_grid()
        self.c0 = self.g.clients[0]
        DATA = b"p"*size

        def _download(node):
            return download_to_data(node)

        def _compare(newdata):
            self.failUnlessEqual(newdata, DATA)

        d = self.upload(DATA)
        d.addCallback(_download)
        d.addCallback(_compare)
        return d
Note: See TracBrowser for help on using the repository browser.