source: trunk/src/allmydata/test/test_spans.py

Last change on this file was 2243ce3, checked in by Alexandre Detiste <alexandre.detiste@…>, at 2024-02-28T00:07:08Z

remove "from past.builtins import long"

  • Property mode set to 100644
File size: 21.0 KB
Line 
1"""
2Tests for allmydata.util.spans.
3"""
4
5import binascii
6import hashlib
7
8from twisted.trial import unittest
9
10from allmydata.util.spans import Spans, overlap, DataSpans
11
12
13def sha256(data):
14    """
15    :param bytes data: data to hash
16
17    :returns: a hex-encoded SHA256 hash of the data
18    """
19    return binascii.hexlify(hashlib.sha256(data).digest())
20
21
22class SimpleSpans(object):
23    # this is a simple+inefficient form of util.spans.Spans . We compare the
24    # behavior of this reference model against the real (efficient) form.
25
26    def __init__(self, _span_or_start=None, length=None):
27        self._have = set()
28        if length is not None:
29            for i in range(_span_or_start, _span_or_start+length):
30                self._have.add(i)
31        elif _span_or_start:
32            for (start,length) in _span_or_start:
33                self.add(start, length)
34
35    def add(self, start, length):
36        for i in range(start, start+length):
37            self._have.add(i)
38        return self
39
40    def remove(self, start, length):
41        for i in range(start, start+length):
42            self._have.discard(i)
43        return self
44
45    def each(self):
46        return sorted(self._have)
47
48    def __iter__(self):
49        items = sorted(self._have)
50        prevstart = None
51        prevend = None
52        for i in items:
53            if prevstart is None:
54                prevstart = prevend = i
55                continue
56            if i == prevend+1:
57                prevend = i
58                continue
59            yield (prevstart, prevend-prevstart+1)
60            prevstart = prevend = i
61        if prevstart is not None:
62            yield (prevstart, prevend-prevstart+1)
63
64    def __bool__(self): # this gets us bool()
65        return bool(self.len())
66
67    def len(self):
68        return len(self._have)
69
70    def __add__(self, other):
71        s = self.__class__(self)
72        for (start, length) in other:
73            s.add(start, length)
74        return s
75
76    def __sub__(self, other):
77        s = self.__class__(self)
78        for (start, length) in other:
79            s.remove(start, length)
80        return s
81
82    def __iadd__(self, other):
83        for (start, length) in other:
84            self.add(start, length)
85        return self
86
87    def __isub__(self, other):
88        for (start, length) in other:
89            self.remove(start, length)
90        return self
91
92    def __and__(self, other):
93        s = self.__class__()
94        for i in other.each():
95            if i in self._have:
96                s.add(i, 1)
97        return s
98
99    def __contains__(self, start_and_length):
100        (start, length) = start_and_length
101        for i in range(start, start+length):
102            if i not in self._have:
103                return False
104        return True
105
106class ByteSpans(unittest.TestCase):
107    def test_basic(self):
108        s = Spans()
109        self.failUnlessEqual(list(s), [])
110        self.failIf(s)
111        self.failIf((0,1) in s)
112        self.failUnlessEqual(s.len(), 0)
113
114        s1 = Spans(3, 4) # 3,4,5,6
115        self._check1(s1)
116
117        s2 = Spans(s1)
118        self._check1(s2)
119
120        s2.add(10,2) # 10,11
121        self._check1(s1)
122        self.failUnless((10,1) in s2)
123        self.failIf((10,1) in s1)
124        self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11])
125        self.failUnlessEqual(s2.len(), 6)
126
127        s2.add(15,2).add(20,2)
128        self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11,15,16,20,21])
129        self.failUnlessEqual(s2.len(), 10)
130
131        s2.remove(4,3).remove(15,1)
132        self.failUnlessEqual(list(s2.each()), [3,10,11,16,20,21])
133        self.failUnlessEqual(s2.len(), 6)
134
135        s1 = SimpleSpans(3, 4) # 3 4 5 6
136        s2 = SimpleSpans(5, 4) # 5 6 7 8
137        i = s1 & s2
138        self.failUnlessEqual(list(i.each()), [5, 6])
139
140    def _check1(self, s):
141        self.failUnlessEqual(list(s), [(3,4)])
142        self.failUnless(s)
143        self.failUnlessEqual(s.len(), 4)
144        self.failIf((0,1) in s)
145        self.failUnless((3,4) in s)
146        self.failUnless((3,1) in s)
147        self.failUnless((5,2) in s)
148        self.failUnless((6,1) in s)
149        self.failIf((6,2) in s)
150        self.failIf((7,1) in s)
151        self.failUnlessEqual(list(s.each()), [3,4,5,6])
152
153    def test_large(self):
154        s = Spans(4, 2**65) # don't do this with a SimpleSpans
155        self.failUnlessEqual(list(s), [(4, 2**65)])
156        self.failUnless(s)
157        self.failUnlessEqual(s.len(), 2**65)
158        self.failIf((0,1) in s)
159        self.failUnless((4,2) in s)
160        self.failUnless((2**65,2) in s)
161
162    def test_math(self):
163        s1 = Spans(0, 10) # 0,1,2,3,4,5,6,7,8,9
164        s2 = Spans(5, 3) # 5,6,7
165        s3 = Spans(8, 4) # 8,9,10,11
166
167        s = s1 - s2
168        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
169        s = s1 - s3
170        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
171        s = s2 - s3
172        self.failUnlessEqual(list(s.each()), [5,6,7])
173        s = s1 & s2
174        self.failUnlessEqual(list(s.each()), [5,6,7])
175        s = s2 & s1
176        self.failUnlessEqual(list(s.each()), [5,6,7])
177        s = s1 & s3
178        self.failUnlessEqual(list(s.each()), [8,9])
179        s = s3 & s1
180        self.failUnlessEqual(list(s.each()), [8,9])
181        s = s2 & s3
182        self.failUnlessEqual(list(s.each()), [])
183        s = s3 & s2
184        self.failUnlessEqual(list(s.each()), [])
185        s = Spans() & s3
186        self.failUnlessEqual(list(s.each()), [])
187        s = s3 & Spans()
188        self.failUnlessEqual(list(s.each()), [])
189
190        s = s1 + s2
191        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
192        s = s1 + s3
193        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
194        s = s2 + s3
195        self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])
196
197        s = Spans(s1)
198        s -= s2
199        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
200        s = Spans(s1)
201        s -= s3
202        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
203        s = Spans(s2)
204        s -= s3
205        self.failUnlessEqual(list(s.each()), [5,6,7])
206
207        s = Spans(s1)
208        s += s2
209        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
210        s = Spans(s1)
211        s += s3
212        self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
213        s = Spans(s2)
214        s += s3
215        self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])
216
217    def test_random(self):
218        # attempt to increase coverage of corner cases by comparing behavior
219        # of a simple-but-slow model implementation against the
220        # complex-but-fast actual implementation, in a large number of random
221        # operations
222        S1 = SimpleSpans
223        S2 = Spans
224        s1 = S1(); s2 = S2()
225        seed = b""
226        def _create(subseed):
227            ns1 = S1(); ns2 = S2()
228            for i in range(10):
229                what = sha256(subseed+bytes(i))
230                start = int(what[2:4], 16)
231                length = max(1,int(what[5:6], 16))
232                ns1.add(start, length); ns2.add(start, length)
233            return ns1, ns2
234
235        #print()
236        for i in range(1000):
237            what = sha256(seed+bytes(i))
238            op = what[0:1]
239            subop = what[1:2]
240            start = int(what[2:4], 16)
241            length = max(1,int(what[5:6], 16))
242            #print(what)
243            if op in b"0":
244                if subop in b"01234":
245                    s1 = S1(); s2 = S2()
246                elif subop in b"5678":
247                    s1 = S1(start, length); s2 = S2(start, length)
248                else:
249                    s1 = S1(s1); s2 = S2(s2)
250                #print("s2 = %s" % s2.dump())
251            elif op in b"123":
252                #print("s2.add(%d,%d)" % (start, length))
253                s1.add(start, length); s2.add(start, length)
254            elif op in b"456":
255                #print("s2.remove(%d,%d)" % (start, length))
256                s1.remove(start, length); s2.remove(start, length)
257            elif op in b"78":
258                ns1, ns2 = _create(what[7:11])
259                #print("s2 + %s" % ns2.dump())
260                s1 = s1 + ns1; s2 = s2 + ns2
261            elif op in b"9a":
262                ns1, ns2 = _create(what[7:11])
263                #print("%s - %s" % (s2.dump(), ns2.dump()))
264                s1 = s1 - ns1; s2 = s2 - ns2
265            elif op in b"bc":
266                ns1, ns2 = _create(what[7:11])
267                #print("s2 += %s" % ns2.dump())
268                s1 += ns1; s2 += ns2
269            elif op in b"de":
270                ns1, ns2 = _create(what[7:11])
271                #print("%s -= %s" % (s2.dump(), ns2.dump()))
272                s1 -= ns1; s2 -= ns2
273            else:
274                ns1, ns2 = _create(what[7:11])
275                #print("%s &= %s" % (s2.dump(), ns2.dump()))
276                s1 = s1 & ns1; s2 = s2 & ns2
277            #print("s2 now %s" % s2.dump())
278            self.failUnlessEqual(list(s1.each()), list(s2.each()))
279            self.failUnlessEqual(s1.len(), s2.len())
280            self.failUnlessEqual(bool(s1), bool(s2))
281            self.failUnlessEqual(list(s1), list(s2))
282            for j in range(10):
283                what = sha256(what[12:14]+bytes(j))
284                start = int(what[2:4], 16)
285                length = max(1, int(what[5:6], 16))
286                span = (start, length)
287                self.failUnlessEqual(bool(span in s1), bool(span in s2))
288
289
290    # s()
291    # s(start,length)
292    # s(s0)
293    # s.add(start,length) : returns s
294    # s.remove(start,length)
295    # s.each() -> list of byte offsets, mostly for testing
296    # list(s) -> list of (start,length) tuples, one per span
297    # (start,length) in s -> True if (start..start+length-1) are all members
298    #  NOT equivalent to x in list(s)
299    # s.len() -> number of bytes, for testing, bool(), and accounting/limiting
300    # bool(s)  (__nonzeron__)
301    # s = s1+s2, s1-s2, +=s1, -=s1
302
303    def test_overlap(self):
304        for a in range(20):
305            for b in range(10):
306                for c in range(20):
307                    for d in range(10):
308                        self._test_overlap(a,b,c,d)
309
310    def _test_overlap(self, a, b, c, d):
311        s1 = set(range(a,a+b))
312        s2 = set(range(c,c+d))
313        #print("---")
314        #self._show_overlap(s1, "1")
315        #self._show_overlap(s2, "2")
316        o = overlap(a,b,c,d)
317        expected = s1.intersection(s2)
318        if not expected:
319            self.failUnlessEqual(o, None)
320        else:
321            start,length = o
322            so = set(range(start,start+length))
323            #self._show(so, "o")
324            self.failUnlessEqual(so, expected)
325
326    def _show_overlap(self, s, c):
327        import sys
328        out = sys.stdout
329        if s:
330            for i in range(max(s)):
331                if i in s:
332                    out.write(c)
333                else:
334                    out.write(" ")
335        out.write("\n")
336
337def extend(s, start, length, fill):
338    if len(s) >= start+length:
339        return s
340    assert len(fill) == 1
341    return s + fill*(start+length-len(s))
342
343def replace(s, start, data):
344    assert len(s) >= start+len(data)
345    return s[:start] + data + s[start+len(data):]
346
347class SimpleDataSpans(object):
348    def __init__(self, other=None):
349        self.missing = "" # "1" where missing, "0" where found
350        self.data = b""
351        if other:
352            for (start, data) in other.get_chunks():
353                self.add(start, data)
354
355    def __bool__(self): # this gets us bool()
356        return bool(self.len())
357
358    def len(self):
359        return len(self.missing.replace("1", ""))
360
361    def _dump(self):
362        return [i for (i,c) in enumerate(self.missing) if c == "0"]
363
364    def _have(self, start, length):
365        m = self.missing[start:start+length]
366        if not m or len(m)<length or int(m):
367            return False
368        return True
369    def get_chunks(self):
370        for i in self._dump():
371            yield (i, self.data[i:i+1])
372    def get_spans(self):
373        return SimpleSpans([(start,len(data))
374                            for (start,data) in self.get_chunks()])
375    def get(self, start, length):
376        if self._have(start, length):
377            return self.data[start:start+length]
378        return None
379    def pop(self, start, length):
380        data = self.get(start, length)
381        if data:
382            self.remove(start, length)
383        return data
384    def remove(self, start, length):
385        self.missing = replace(extend(self.missing, start, length, "1"),
386                               start, "1"*length)
387    def add(self, start, data):
388        self.missing = replace(extend(self.missing, start, len(data), "1"),
389                               start, "0"*len(data))
390        self.data = replace(extend(self.data, start, len(data), b" "),
391                            start, data)
392
393
394class StringSpans(unittest.TestCase):
395    def do_basic(self, klass):
396        ds = klass()
397        self.failUnlessEqual(ds.len(), 0)
398        self.failUnlessEqual(list(ds._dump()), [])
399        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 0)
400        s1 = ds.get_spans()
401        self.failUnlessEqual(ds.get(0, 4), None)
402        self.failUnlessEqual(ds.pop(0, 4), None)
403        ds.remove(0, 4)
404
405        ds.add(2, b"four")
406        self.failUnlessEqual(ds.len(), 4)
407        self.failUnlessEqual(list(ds._dump()), [2,3,4,5])
408        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)
409        s1 = ds.get_spans()
410        self.failUnless((2,2) in s1)
411        self.failUnlessEqual(ds.get(0, 4), None)
412        self.failUnlessEqual(ds.pop(0, 4), None)
413        self.failUnlessEqual(ds.get(4, 4), None)
414
415        ds2 = klass(ds)
416        self.failUnlessEqual(ds2.len(), 4)
417        self.failUnlessEqual(list(ds2._dump()), [2,3,4,5])
418        self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 4)
419        self.failUnlessEqual(ds2.get(0, 4), None)
420        self.failUnlessEqual(ds2.pop(0, 4), None)
421        self.failUnlessEqual(ds2.pop(2, 3), b"fou")
422        self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 1)
423        self.failUnlessEqual(ds2.get(2, 3), None)
424        self.failUnlessEqual(ds2.get(5, 1), b"r")
425        self.failUnlessEqual(ds.get(2, 3), b"fou")
426        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)
427
428        ds.add(0, b"23")
429        self.failUnlessEqual(ds.len(), 6)
430        self.failUnlessEqual(list(ds._dump()), [0,1,2,3,4,5])
431        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 6)
432        self.failUnlessEqual(ds.get(0, 4), b"23fo")
433        self.failUnlessEqual(ds.pop(0, 4), b"23fo")
434        self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 2)
435        self.failUnlessEqual(ds.get(0, 4), None)
436        self.failUnlessEqual(ds.pop(0, 4), None)
437
438        ds = klass()
439        ds.add(2, b"four")
440        ds.add(3, b"ea")
441        self.failUnlessEqual(ds.get(2, 4), b"fear")
442
443        ds = klass()
444        ds.add(2, b"four")
445        ds.add(3, b"ea")
446        self.failUnlessEqual(ds.get(2, 4), b"fear")
447
448
449    def do_scan(self, klass):
450        # do a test with gaps and spans of size 1 and 2
451        #  left=(1,11) * right=(1,11) * gapsize=(1,2)
452        # 111, 112, 121, 122, 211, 212, 221, 222
453        #    211
454        #      121
455        #         112
456        #            212
457        #               222
458        #                   221
459        #                      111
460        #                        122
461        #  11 1  1 11 11  11  1 1  111
462        # 0123456789012345678901234567
463        # abcdefghijklmnopqrstuvwxyz-=
464        pieces = [(1, b"bc"),
465                  (4, b"e"),
466                  (7, b"h"),
467                  (9, b"jk"),
468                  (12, b"mn"),
469                  (16, b"qr"),
470                  (20, b"u"),
471                  (22, b"w"),
472                  (25, b"z-="),
473                  ]
474        p_elements = set([1,2,4,7,9,10,12,13,16,17,20,22,25,26,27])
475        S = b"abcdefghijklmnopqrstuvwxyz-="
476        # TODO: when adding data, add capital letters, to make sure we aren't
477        # just leaving the old data in place
478        l = len(S)
479        def base():
480            ds = klass()
481            for start, data in pieces:
482                ds.add(start, data)
483            return ds
484        def dump(s):
485            p = set(s._dump())
486            d = b"".join([((i not in p) and b" " or S[i]) for i in range(l)])
487            assert len(d) == l
488            return d
489        DEBUG = False
490        for start in range(0, l):
491            for end in range(start+1, l):
492                # add [start-end) to the baseline
493                which = "%d-%d" % (start, end-1)
494                p_added = set(range(start, end))
495                b = base()
496                if DEBUG:
497                    print()
498                    print(dump(b), which)
499                    add = klass(); add.add(start, S[start:end])
500                    print(dump(add))
501                b.add(start, S[start:end])
502                if DEBUG:
503                    print(dump(b))
504                # check that the new span is there
505                d = b.get(start, end-start)
506                self.failUnlessEqual(d, S[start:end], which)
507                # check that all the original pieces are still there
508                for t_start, t_data in pieces:
509                    t_len = len(t_data)
510                    self.failUnlessEqual(b.get(t_start, t_len),
511                                         S[t_start:t_start+t_len],
512                                         "%s %d+%d" % (which, t_start, t_len))
513                # check that a lot of subspans are mostly correct
514                for t_start in range(l):
515                    for t_len in range(1,4):
516                        d = b.get(t_start, t_len)
517                        if d is not None:
518                            which2 = "%s+(%d-%d)" % (which, t_start,
519                                                     t_start+t_len-1)
520                            self.failUnlessEqual(d, S[t_start:t_start+t_len],
521                                                 which2)
522                        # check that removing a subspan gives the right value
523                        b2 = klass(b)
524                        b2.remove(t_start, t_len)
525                        removed = set(range(t_start, t_start+t_len))
526                        for i in range(l):
527                            exp = (((i in p_elements) or (i in p_added))
528                                   and (i not in removed))
529                            which2 = "%s-(%d-%d)" % (which, t_start,
530                                                     t_start+t_len-1)
531                            self.failUnlessEqual(bool(b2.get(i, 1)), exp,
532                                                 which2+" %d" % i)
533
534    def test_test(self):
535        self.do_basic(SimpleDataSpans)
536        self.do_scan(SimpleDataSpans)
537
538    def test_basic(self):
539        self.do_basic(DataSpans)
540        self.do_scan(DataSpans)
541
542    def test_random(self):
543        # attempt to increase coverage of corner cases by comparing behavior
544        # of a simple-but-slow model implementation against the
545        # complex-but-fast actual implementation, in a large number of random
546        # operations
547        S1 = SimpleDataSpans
548        S2 = DataSpans
549        s1 = S1(); s2 = S2()
550        seed = b""
551        def _randstr(length, seed):
552            created = 0
553            pieces = []
554            while created < length:
555                piece = sha256(seed + bytes(created))
556                pieces.append(piece)
557                created += len(piece)
558            return b"".join(pieces)[:length]
559        def _create(subseed):
560            ns1 = S1(); ns2 = S2()
561            for i in range(10):
562                what = sha256(subseed+bytes(i))
563                start = int(what[2:4], 16)
564                length = max(1,int(what[5:6], 16))
565                ns1.add(start, _randstr(length, what[7:9]));
566                ns2.add(start, _randstr(length, what[7:9]))
567            return ns1, ns2
568
569        #print()
570        for i in range(1000):
571            what = sha256(seed+bytes(i))
572            op = what[0:1]
573            subop = what[1:2]
574            start = int(what[2:4], 16)
575            length = max(1,int(what[5:6], 16))
576            #print(what)
577            if op in b"0":
578                if subop in b"0123456":
579                    s1 = S1(); s2 = S2()
580                else:
581                    s1, s2 = _create(what[7:11])
582                #print("s2 = %s" % list(s2._dump()))
583            elif op in b"123456":
584                #print("s2.add(%d,%d)" % (start, length))
585                s1.add(start, _randstr(length, what[7:9]));
586                s2.add(start, _randstr(length, what[7:9]))
587            elif op in b"789abc":
588                #print("s2.remove(%d,%d)" % (start, length))
589                s1.remove(start, length); s2.remove(start, length)
590            else:
591                #print("s2.pop(%d,%d)" % (start, length))
592                d1 = s1.pop(start, length); d2 = s2.pop(start, length)
593                self.failUnlessEqual(d1, d2)
594            #print("s1 now %s" % list(s1._dump()))
595            #print("s2 now %s" % list(s2._dump()))
596            self.failUnlessEqual(s1.len(), s2.len())
597            self.failUnlessEqual(list(s1._dump()), list(s2._dump()))
598            for j in range(100):
599                what = sha256(what[12:14]+bytes(j))
600                start = int(what[2:4], 16)
601                length = max(1, int(what[5:6], 16))
602                d1 = s1.get(start, length); d2 = s2.get(start, length)
603                self.failUnlessEqual(d1, d2, "%d+%d" % (start, length))
Note: See TracBrowser for help on using the repository browser.