source: trunk/src/allmydata/storage/http_client.py

Last change on this file was fced1ab0, checked in by Itamar Turner-Trauring <itamar@…>, at 2024-01-24T18:50:55Z

Switch to using pycddl for CBOR decoding.

  • Property mode set to 100644
File size: 41.5 KB
Line 
1"""
2HTTP client that talks to the HTTP storage server.
3"""
4
5from __future__ import annotations
6
7
8from typing import (
9    Union,
10    Optional,
11    Sequence,
12    Mapping,
13    BinaryIO,
14    cast,
15    TypedDict,
16    Set,
17    Dict,
18    Callable,
19    ClassVar,
20)
21from base64 import b64encode
22from io import BytesIO
23from os import SEEK_END
24
25from attrs import define, asdict, frozen, field
26from eliot import start_action, register_exception_extractor
27from eliot.twisted import DeferredContext
28
29from pycddl import Schema
30from collections_extended import RangeMap
31from werkzeug.datastructures import Range, ContentRange
32from twisted.web.http_headers import Headers
33from twisted.web import http
34from twisted.web.iweb import IPolicyForHTTPS, IResponse, IAgent
35from twisted.internet.defer import Deferred, succeed
36from twisted.internet.interfaces import (
37    IOpenSSLClientConnectionCreator,
38    IReactorTime,
39    IDelayedCall,
40)
41from twisted.internet.ssl import CertificateOptions
42from twisted.protocols.tls import TLSMemoryBIOProtocol
43from twisted.web.client import Agent, HTTPConnectionPool
44from zope.interface import implementer
45from hyperlink import DecodedURL
46import treq
47from treq.client import HTTPClient
48from treq.testing import StubTreq
49from OpenSSL import SSL
50from werkzeug.http import parse_content_range_header
51
52from .http_common import (
53    swissnum_auth_header,
54    Secrets,
55    get_content_type,
56    CBOR_MIME_TYPE,
57    get_spki_hash,
58    response_is_not_html,
59)
60from ..interfaces import VersionMessage
61from .common import si_b2a, si_to_human_readable
62from ..util.hashutil import timing_safe_compare
63from ..util.deferredutil import async_to_deferred
64from ..util.tor_provider import _Provider as TorProvider
65from ..util.cputhreadpool import defer_to_thread
66from ..util.cbor import dumps
67
# txtorcon is an optional dependency.  If it's missing, fall back to a stub
# class so that type annotations referencing ``Tor`` (e.g. in
# ``StorageClientFactory``) still resolve; the stub is never instantiated.
try:
    from txtorcon import Tor  # type: ignore
except ImportError:

    class Tor:  # type: ignore[no-redef]
        """Stand-in for ``txtorcon.Tor`` when txtorcon is not installed."""

        pass
75
def _encode_si(si: bytes) -> str:
    """Encode the storage index as an ASCII text string (for use in URLs)."""
    encoded = si_b2a(si)
    return encoded.decode("ascii")
80
class ClientException(Exception):
    """An unexpected response code from the server."""

    def __init__(
        self, code: int, message: Optional[str] = None, body: Optional[bytes] = None
    ):
        # Pass everything to the base class so str()/args include the details.
        super().__init__(code, message, body)
        # Expose the pieces as attributes for callers (and Eliot logging).
        self.code = code
        self.message = message
        self.body = body
92
93register_exception_extractor(ClientException, lambda e: {"response_code": e.code})
94
95
# Schemas for server responses.
#
# Each entry is keyed by endpoint name and is used by
# ``StorageClient.decode_cbor()`` to validate and decode the CBOR body of the
# corresponding response.
#
# Tags are of the form #6.nnn, where the number is documented at
# https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml. Notably, #6.258
# indicates a set.
_SCHEMAS: Mapping[str, Schema] = {
    "get_version": Schema(
        # Note that the single-quoted (`'`) string keys in this schema
        # represent *byte* strings - per the CDDL specification.  Text strings
        # are represented using strings with *double* quotes (`"`).
        """
        response = {'http://allmydata.org/tahoe/protocols/storage/v1' => {
                 'maximum-immutable-share-size' => uint
                 'maximum-mutable-share-size' => uint
                 'available-space' => uint
                 }
                 'application-version' => bstr
              }
    """
    ),
    "allocate_buckets": Schema(
        """
    response = {
      already-have: #6.258([0*256 uint])
      allocated: #6.258([0*256 uint])
    }
    """
    ),
    "immutable_write_share_chunk": Schema(
        """
    response = {
      required: [0* {begin: uint, end: uint}]
    }
    """
    ),
    "list_shares": Schema(
        """
    response = #6.258([0*256 uint])
    """
    ),
    "mutable_read_test_write": Schema(
        """
        response = {
          "success": bool,
          "data": {0*256 share_number: [0* bstr]}
        }
        share_number = uint
        """
    ),
    "mutable_list_shares": Schema(
        """
        response = #6.258([0*256 uint])
        """
    ),
}
151
152
@define
class _LengthLimitedCollector:
    """
    A ``treq.collect()`` callback that buffers the response body, enforcing a
    cap on total length and pushing back an inactivity timeout on each chunk.
    """

    remaining_length: int
    timeout_on_silence: IDelayedCall
    f: BytesIO = field(factory=BytesIO)

    def __call__(self, data: bytes) -> None:
        # Data arrived, so postpone the no-activity timeout by another minute.
        self.timeout_on_silence.reset(60)
        self.remaining_length = self.remaining_length - len(data)
        if self.remaining_length >= 0:
            self.f.write(data)
        else:
            raise ValueError("Response length was too long")
169
170
def limited_content(
    response: IResponse,
    clock: IReactorTime,
    max_length: int = 30 * 1024 * 1024,
) -> Deferred[BinaryIO]:
    """
    Like ``treq.content()``, but limit data read from the response to a set
    length.  If the response is longer than the max allowed length, the result
    fails with a ``ValueError``.

    A potentially useful future improvement would be using a temporary file to
    store the content; since filesystem buffering means that would use memory
    for small responses and disk for large responses.

    This will time out if no data is received for 60 seconds; so long as a
    trickle of data continues to arrive, it will continue to run.

    :param response: The HTTP response whose body should be collected.
    :param clock: Reactor used to schedule the inactivity timeout.
    :param max_length: Maximum number of body bytes to accept.
    :return: A ``Deferred`` firing with a seekable binary file positioned at
        the start of the collected body.
    """
    result_deferred = succeed(None)

    # Sadly, addTimeout() won't work because we need access to the IDelayedCall
    # in order to reset it on each data chunk received.
    # If the timeout ever fires, cancelling result_deferred makes the whole
    # chain fail (with a CancelledError).
    timeout = clock.callLater(60, result_deferred.cancel)
    collector = _LengthLimitedCollector(max_length, timeout)

    with start_action(
        action_type="allmydata:storage:http-client:limited-content",
        max_length=max_length,
    ).context():
        d = DeferredContext(result_deferred)

    # Make really sure everything gets called in Deferred context, treq might
    # call collector directly...
    d.addCallback(lambda _: treq.collect(response, collector))

    def done(_: object) -> BytesIO:
        # Success: stop the inactivity timer and rewind the buffer so callers
        # can read from the beginning.
        timeout.cancel()
        collector.f.seek(0)
        return collector.f

    def failed(f):
        # Failure (including cancellation caused by the timeout itself): make
        # sure the timer won't fire later, then propagate the failure.
        if timeout.active():
            timeout.cancel()
        return f

    result = d.addCallbacks(done, failed)
    return result.addActionFinish()
217
218
@define
class ImmutableCreateResult(object):
    """Result of creating a storage index for an immutable."""

    # Share numbers the server reports it already has.
    already_have: set[int]
    # Share numbers the server allocated in response to this request.
    allocated: set[int]
225
226
class _TLSContextFactory(CertificateOptions):
    """
    Create a context that validates the way Tahoe-LAFS wants to: based on a
    pinned certificate hash, rather than a certificate authority.

    Originally implemented as part of Foolscap.  To comply with the license,
    here's the original licensing terms:

    Copyright (c) 2006-2008 Brian Warner

    Permission is hereby granted, free of charge, to any person obtaining a
    copy of this software and associated documentation files (the "Software"),
    to deal in the Software without restriction, including without limitation
    the rights to use, copy, modify, merge, publish, distribute, sublicense,
    and/or sell copies of the Software, and to permit persons to whom the
    Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in
    all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
    THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
    DEALINGS IN THE SOFTWARE.
    """

    def __init__(self, expected_spki_hash: bytes):
        """
        :param expected_spki_hash: Hash of the SubjectPublicKeyInfo that the
            server's certificate must match for validation to succeed.
        """
        self.expected_spki_hash = expected_spki_hash
        CertificateOptions.__init__(self)

    def getContext(self) -> SSL.Context:
        """Return an OpenSSL context whose peer validation compares the
        certificate's SPKI hash against the pinned expected hash."""

        def always_validate(conn, cert, errno, depth, preverify_ok):
            # This function is called to validate the certificate received by
            # the other end. OpenSSL calls it multiple times, for each errno
            # for each certificate.

            # We do not care about certificate authorities or revocation
            # lists, we just want to know that the certificate has a valid
            # signature and follow the chain back to one which is
            # self-signed. We need to protect against forged signatures, but
            # not the usual TLS concerns about invalid CAs or revoked
            # certificates.
            things_are_ok = (
                SSL.X509VerificationCodes.OK,
                SSL.X509VerificationCodes.ERR_CERT_NOT_YET_VALID,
                SSL.X509VerificationCodes.ERR_CERT_HAS_EXPIRED,
                SSL.X509VerificationCodes.ERR_DEPTH_ZERO_SELF_SIGNED_CERT,
                SSL.X509VerificationCodes.ERR_SELF_SIGNED_CERT_IN_CHAIN,
            )
            # TODO can we do this once instead of multiple times?
            if errno in things_are_ok and timing_safe_compare(
                get_spki_hash(cert.to_cryptography()), self.expected_spki_hash
            ):
                return 1
            # TODO: log the details of the error, because otherwise they get
            # lost in the PyOpenSSL exception that will eventually be raised
            # (possibly OpenSSL.SSL.Error: certificate verify failed)
            return 0

        ctx = CertificateOptions.getContext(self)

        # VERIFY_PEER means we ask the other end for their certificate.
        ctx.set_verify(SSL.VERIFY_PEER, always_validate)
        return ctx
294
295
@implementer(IPolicyForHTTPS)
@implementer(IOpenSSLClientConnectionCreator)
@define
class _StorageClientHTTPSPolicy:
    """
    A HTTPS policy that ensures the SPKI hash of the public key matches a known
    hash, i.e. pinning-based validation.
    """

    expected_spki_hash: bytes

    # IPolicyForHTTPS
    def creatorForNetloc(self, hostname: str, port: int) -> _StorageClientHTTPSPolicy:
        """Return the connection creator for this host; the same pinned
        validation is applied regardless of hostname or port."""
        return self

    # IOpenSSLClientConnectionCreator
    def clientConnectionForTLS(
        self, tlsProtocol: TLSMemoryBIOProtocol
    ) -> SSL.Connection:
        """Create a TLS connection whose context only accepts certificates
        matching the pinned SPKI hash."""
        return SSL.Connection(
            _TLSContextFactory(self.expected_spki_hash).getContext(), None
        )
318
319
@define
class StorageClientFactory:
    """
    Create ``StorageClient`` instances, using appropriate
    ``twisted.web.iweb.IAgent`` for different connection methods: normal TCP,
    Tor, and eventually I2P.

    There is some caching involved since there might be shared setup work, e.g.
    connecting to the local Tor service only needs to happen once.
    """

    # Tahoe's connection-handler configuration; only the "tcp" entry is
    # consulted here, to decide whether TCP connections go direct or via Tor.
    _default_connection_handlers: dict[str, str]
    _tor_provider: Optional[TorProvider]
    # Cache the Tor instance created by the provider, if relevant.
    _tor_instance: Optional[Tor] = None

    # If set, we're doing unit testing and we should call this with any
    # HTTPConnectionPool that gets passed/created to ``create_agent()``.
    TEST_MODE_REGISTER_HTTP_POOL: ClassVar[
        Optional[Callable[[HTTPConnectionPool], None]]
    ] = None

    @classmethod
    def start_test_mode(cls, callback: Callable[[HTTPConnectionPool], None]) -> None:
        """Switch to testing mode.

        In testing mode we register the pool with test system using the given
        callback so it can Do Things, most notably killing off idle HTTP
        connections at test shutdown and, in some tests, in the middle of the
        test.
        """
        cls.TEST_MODE_REGISTER_HTTP_POOL = callback

    @classmethod
    def stop_test_mode(cls) -> None:
        """Stop testing mode."""
        cls.TEST_MODE_REGISTER_HTTP_POOL = None

    async def _create_agent(
        self,
        nurl: DecodedURL,
        reactor: object,
        tls_context_factory: IPolicyForHTTPS,
        pool: HTTPConnectionPool,
    ) -> IAgent:
        """Create a new ``IAgent``, possibly using Tor."""
        if self.TEST_MODE_REGISTER_HTTP_POOL is not None:
            self.TEST_MODE_REGISTER_HTTP_POOL(pool)

        # TODO default_connection_handlers should really be an object, not a
        # dict, so we can ask "is this using Tor" without poking at a
        # dictionary with arbitrary strings... See
        # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/4032
        handler = self._default_connection_handlers["tcp"]

        if handler == "tcp":
            return Agent(reactor, tls_context_factory, pool=pool)
        if handler == "tor" or nurl.scheme == "pb+tor":
            assert self._tor_provider is not None
            if self._tor_instance is None:
                # Lazily create and cache the Tor instance; setup only needs
                # to happen once per factory.
                self._tor_instance = await self._tor_provider.get_tor_instance(reactor)
            return self._tor_instance.web_agent(
                pool=pool, tls_context_factory=tls_context_factory
            )
        else:
            # I2P support will be added here. See
            # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/4037
            raise RuntimeError(f"Unsupported tcp connection handler: {handler}")

    async def create_storage_client(
        self,
        nurl: DecodedURL,
        reactor: IReactorTime,
        pool: Optional[HTTPConnectionPool] = None,
    ) -> StorageClient:
        """Create a new ``StorageClient`` for the given NURL."""
        assert nurl.fragment == "v=1"
        assert nurl.scheme in ("pb", "pb+tor")
        if pool is None:
            pool = HTTPConnectionPool(reactor)
            pool.maxPersistentPerHost = 10

        # The NURL's user field carries the expected SPKI hash of the server's
        # certificate, used for pinning-based TLS validation.
        certificate_hash = nurl.user.encode("ascii")
        agent = await self._create_agent(
            nurl,
            reactor,
            _StorageClientHTTPSPolicy(expected_spki_hash=certificate_hash),
            pool,
        )
        treq_client = HTTPClient(agent)
        https_url = DecodedURL().replace(scheme="https", host=nurl.host, port=nurl.port)
        # The first path segment of the NURL is the swissnum, the shared
        # secret used to authorize requests.
        swissnum = nurl.path[0].encode("ascii")
        response_check = lambda _: None
        if self.TEST_MODE_REGISTER_HTTP_POOL is not None:
            response_check = response_is_not_html

        return StorageClient(
            https_url,
            swissnum,
            treq_client,
            pool,
            reactor,
            response_check,
        )
424
425
@define(hash=True)
class StorageClient(object):
    """
    Low-level HTTP client that talks to the HTTP storage server.

    Create using a ``StorageClientFactory`` instance.
    """

    # The URL should be a HTTPS URL ("https://...")
    _base_url: DecodedURL
    # Shared secret sent in the Authorization header of every request.
    _swissnum: bytes
    _treq: Union[treq, StubTreq, HTTPClient]
    _pool: HTTPConnectionPool
    _clock: IReactorTime
    # Are we running unit tests?
    _analyze_response: Callable[[IResponse], None] = lambda _: None

    def relative_url(self, path: str) -> DecodedURL:
        """Get a URL relative to the base URL."""
        return self._base_url.click(path)

    def _get_headers(self, headers: Optional[Headers]) -> Headers:
        """Return the basic headers to be used by default.

        Adds the swissnum-based Authorization header to the given headers (or
        to a fresh ``Headers`` if none were given).
        """
        if headers is None:
            headers = Headers()
        headers.addRawHeader(
            "Authorization",
            swissnum_auth_header(self._swissnum),
        )
        return headers

    @async_to_deferred
    async def request(
        self,
        method: str,
        url: DecodedURL,
        lease_renew_secret: Optional[bytes] = None,
        lease_cancel_secret: Optional[bytes] = None,
        upload_secret: Optional[bytes] = None,
        write_enabler_secret: Optional[bytes] = None,
        headers: Optional[Headers] = None,
        message_to_serialize: object = None,
        timeout: float = 60,
        **kwargs,
    ) -> IResponse:
        """
        Like ``treq.request()``, but with optional secrets that get translated
        into corresponding HTTP headers.

        If ``message_to_serialize`` is set, it will be serialized (by default
        with CBOR) and set as the request body.  It should not be mutated
        during execution of this function!

        Default timeout is 60 seconds.
        """
        with start_action(
            action_type="allmydata:storage:http-client:request",
            method=method,
            url=url.to_text(),
            timeout=timeout,
        ) as ctx:
            response = await self._request(
                method,
                url,
                lease_renew_secret,
                lease_cancel_secret,
                upload_secret,
                write_enabler_secret,
                headers,
                message_to_serialize,
                timeout,
                **kwargs,
            )
            # Record the response code on the Eliot action for logging.
            ctx.add_success_fields(response_code=response.code)
            return response

    async def _request(
        self,
        method: str,
        url: DecodedURL,
        lease_renew_secret: Optional[bytes] = None,
        lease_cancel_secret: Optional[bytes] = None,
        upload_secret: Optional[bytes] = None,
        write_enabler_secret: Optional[bytes] = None,
        headers: Optional[Headers] = None,
        message_to_serialize: object = None,
        timeout: float = 60,
        **kwargs,
    ) -> IResponse:
        """The implementation of request()."""
        headers = self._get_headers(headers)

        # Add secrets: each secret that is set becomes an
        # "X-Tahoe-Authorization: <type> <base64-value>" header.
        for secret, value in [
            (Secrets.LEASE_RENEW, lease_renew_secret),
            (Secrets.LEASE_CANCEL, lease_cancel_secret),
            (Secrets.UPLOAD, upload_secret),
            (Secrets.WRITE_ENABLER, write_enabler_secret),
        ]:
            if value is None:
                continue
            headers.addRawHeader(
                "X-Tahoe-Authorization",
                b"%s %s" % (secret.value.encode("ascii"), b64encode(value).strip()),
            )

        # Note we can accept CBOR:
        headers.addRawHeader("Accept", CBOR_MIME_TYPE)

        # If there's a request message, serialize it and set the Content-Type
        # header:
        if message_to_serialize is not None:
            if "data" in kwargs:
                raise TypeError(
                    "Can't use both `message_to_serialize` and `data` "
                    "as keyword arguments at the same time"
                )
            # Serialization may be expensive, so do it on a thread rather
            # than blocking the reactor.
            kwargs["data"] = await defer_to_thread(dumps, message_to_serialize)
            headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)

        response = await self._treq.request(
            method, url, headers=headers, timeout=timeout, **kwargs
        )
        self._analyze_response(response)

        return response

    async def decode_cbor(self, response: IResponse, schema: Schema) -> object:
        """Given HTTP response, return decoded CBOR body.

        Raises ``ClientException`` if the response code is not a 2xx success
        or the body isn't CBOR; raises ``ValueError`` (via
        ``limited_content``) if the body is too long.
        """
        with start_action(action_type="allmydata:storage:http-client:decode-cbor"):
            # Only 2xx success responses are expected to carry a CBOR body.
            if response.code > 199 and response.code < 300:
                content_type = get_content_type(response.headers)
                if content_type == CBOR_MIME_TYPE:
                    f = await limited_content(response, self._clock)
                    data = f.read()

                    def validate_and_decode():
                        return schema.validate_cbor(data, True)

                    # Validation/decoding can be CPU-intensive, so run it on
                    # a thread rather than blocking the reactor.
                    return await defer_to_thread(validate_and_decode)
                else:
                    raise ClientException(
                        -1,
                        "Server didn't send CBOR, content type is {}".format(
                            content_type
                        ),
                    )
            else:
                # Error response: read a bounded amount of the body to attach
                # to the exception for debugging.
                data = (
                    await limited_content(response, self._clock, max_length=10_000)
                ).read()
                raise ClientException(response.code, response.phrase, data)

    def shutdown(self) -> Deferred[object]:
        """Shutdown any connections."""
        return self._pool.closeCachedConnections()
582
583
@define(hash=True)
class StorageClientGeneral(object):
    """
    High-level HTTP APIs that aren't immutable- or mutable-specific.
    """

    _client: StorageClient

    @async_to_deferred
    async def get_version(self) -> VersionMessage:
        """
        Return the version metadata for the server.
        """
        with start_action(
            action_type="allmydata:storage:http-client:get-version",
        ):
            return await self._get_version()

    async def _get_version(self) -> VersionMessage:
        """Implementation of get_version()."""
        url = self._client.relative_url("/storage/v1/version")
        response = await self._client.request("GET", url)
        decoded_response = cast(
            Dict[bytes, object],
            await self._client.decode_cbor(response, _SCHEMAS["get_version"]),
        )
        # Add some features we know are true because the HTTP API
        # specification requires them and because other parts of the storage
        # client implementation assumes they will be present.
        cast(
            Dict[bytes, object],
            decoded_response[b"http://allmydata.org/tahoe/protocols/storage/v1"],
        ).update(
            {
                b"tolerates-immutable-read-overrun": True,
                b"delete-mutable-shares-with-zero-length-writev": True,
                b"fills-holes-with-zero-bytes": True,
                b"prevents-read-past-end-of-share-data": True,
            }
        )
        return decoded_response

    @async_to_deferred
    async def add_or_renew_lease(
        self, storage_index: bytes, renew_secret: bytes, cancel_secret: bytes
    ) -> None:
        """
        Add or renew a lease.

        If the renewal secret matches an existing lease, it is renewed.
        Otherwise a new lease is added.
        """
        with start_action(
            action_type="allmydata:storage:http-client:add-or-renew-lease",
            storage_index=si_to_human_readable(storage_index),
        ):
            return await self._add_or_renew_lease(
                storage_index, renew_secret, cancel_secret
            )

    async def _add_or_renew_lease(
        self, storage_index: bytes, renew_secret: bytes, cancel_secret: bytes
    ) -> None:
        """Implementation of ``add_or_renew_lease()``."""
        url = self._client.relative_url(
            "/storage/v1/lease/{}".format(_encode_si(storage_index))
        )
        response = await self._client.request(
            "PUT",
            url,
            lease_renew_secret=renew_secret,
            lease_cancel_secret=cancel_secret,
        )

        # The server indicates success with 204 No Content.
        if response.code == http.NO_CONTENT:
            return
        else:
            raise ClientException(response.code)
661
662
@define
class UploadProgress(object):
    """
    Progress of immutable upload, per the server.

    Returned by ``StorageClientImmutables.write_share_chunk()``.
    """

    # True when upload has finished.
    finished: bool
    # Remaining ranges to upload.
    required: RangeMap
673
674
@async_to_deferred
async def read_share_chunk(
    client: StorageClient,
    share_type: str,
    storage_index: bytes,
    share_number: int,
    offset: int,
    length: int,
) -> bytes:
    """
    Download a chunk of data from a share.

    :param client: Low-level client used to issue the HTTP request.
    :param share_type: URL path segment identifying the kind of share, e.g.
        ``"immutable"`` or ``"mutable"``.
    :param storage_index: Storage index of the share.
    :param share_number: Which share to read from.
    :param offset: Byte offset at which to start reading.
    :param length: Number of bytes to read.

    TODO https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3857 Failed downloads
    should be transparently retried and redownloaded by the implementation a
    few times so that if a failure percolates up, the caller can assume the
    failure isn't a short-term blip.

    NOTE: the underlying HTTP protocol is somewhat more flexible than this API,
    insofar as it doesn't always require a range.  In practice a range is
    always provided by the current callers.
    """
    url = client.relative_url(
        "/storage/v1/{}/{}/{}".format(
            share_type, _encode_si(storage_index), share_number
        )
    )
    # The default 60 second timeout is for getting the response, so it doesn't
    # include the time it takes to download the body... so we will deal with
    # that later, via limited_content().
    response = await client.request(
        "GET",
        url,
        headers=Headers(
            # Ranges in HTTP are _inclusive_, Python's convention is exclusive,
            # but the Range constructor does the conversion for us.
            {"range": [Range("bytes", [(offset, offset + length)]).to_header()]}
        ),
        unbuffered=True,  # Don't buffer the response in memory.
    )

    if response.code == http.NO_CONTENT:
        return b""

    content_type = get_content_type(response.headers)
    if content_type != "application/octet-stream":
        raise ValueError(
            f"Content-type was wrong: {content_type}, should be application/octet-stream"
        )

    if response.code == http.PARTIAL_CONTENT:
        # Supply a default so that a 206 response missing the Content-Range
        # header hits the ValueError below instead of a TypeError from
        # subscripting None (getRawHeaders returns None for absent headers).
        content_range = parse_content_range_header(
            response.headers.getRawHeaders("content-range", [""])[0]
        )
        if (
            content_range is None
            or content_range.stop is None
            or content_range.start is None
        ):
            raise ValueError(
                "Content-Range was missing, invalid, or in format we don't support"
            )
        supposed_length = content_range.stop - content_range.start
        if supposed_length > length:
            raise ValueError("Server sent more than we asked for?!")
        # It might also send less than we asked for. That's (probably) OK, e.g.
        # if we went past the end of the file.
        body = await limited_content(response, client._clock, supposed_length)
        body.seek(0, SEEK_END)
        actual_length = body.tell()
        if actual_length != supposed_length:
            # Most likely a mutable that got changed out from under us, but
            # conceivably could be a bug...
            raise ValueError(
                f"Length of response sent from server ({actual_length}) "
                + f"didn't match Content-Range header ({supposed_length})"
            )
        body.seek(0)
        return body.read()
    else:
        # Technically HTTP allows sending an OK with full body under these
        # circumstances, but the server is not designed to do that so we ignore
        # that possibility for now...
        raise ClientException(response.code)
758
759
@async_to_deferred
async def advise_corrupt_share(
    client: StorageClient,
    share_type: str,
    storage_index: bytes,
    share_number: int,
    reason: str,
) -> None:
    """
    Tell the server that the given share is corrupt, with a human-readable
    explanation in ``reason``.

    Raises ``ClientException`` if the server responds with anything other
    than 200 OK.
    """
    assert isinstance(reason, str)
    url = client.relative_url(
        "/storage/v1/{}/{}/{}/corrupt".format(
            share_type, _encode_si(storage_index), share_number
        )
    )
    message = {"reason": reason}
    response = await client.request("POST", url, message_to_serialize=message)
    if response.code == http.OK:
        return
    else:
        raise ClientException(
            response.code,
        )
782
783
784@define(hash=True)
785class StorageClientImmutables(object):
786    """
787    APIs for interacting with immutables.
788    """
789
790    _client: StorageClient
791
    @async_to_deferred
    async def create(
        self,
        storage_index: bytes,
        share_numbers: set[int],
        allocated_size: int,
        upload_secret: bytes,
        lease_renew_secret: bytes,
        lease_cancel_secret: bytes,
    ) -> ImmutableCreateResult:
        """
        Create a new storage index for an immutable.

        TODO https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3857 retry
        internally on failure, to ensure the operation fully succeeded.  If
        sufficient number of failures occurred, the result may fire with an
        error, but there's no expectation that user code needs to have a
        recovery codepath; it will most likely just report an error to the
        user.

        Result fires when creating the storage index succeeded, if creating the
        storage index failed the result will fire with an exception.
        """
        with start_action(
            action_type="allmydata:storage:http-client:immutable:create",
            storage_index=si_to_human_readable(storage_index),
            share_numbers=share_numbers,
            allocated_size=allocated_size,
        ) as ctx:
            result = await self._create(
                storage_index,
                share_numbers,
                allocated_size,
                upload_secret,
                lease_renew_secret,
                lease_cancel_secret,
            )
            # Record the server's answer on the Eliot action for logging.
            ctx.add_success_fields(
                already_have=result.already_have, allocated=result.allocated
            )
            return result
833
    async def _create(
        self,
        storage_index: bytes,
        share_numbers: set[int],
        allocated_size: int,
        upload_secret: bytes,
        lease_renew_secret: bytes,
        lease_cancel_secret: bytes,
    ) -> ImmutableCreateResult:
        """Implementation of create()."""
        url = self._client.relative_url(
            "/storage/v1/immutable/" + _encode_si(storage_index)
        )
        message = {"share-numbers": share_numbers, "allocated-size": allocated_size}

        response = await self._client.request(
            "POST",
            url,
            lease_renew_secret=lease_renew_secret,
            lease_cancel_secret=lease_cancel_secret,
            upload_secret=upload_secret,
            message_to_serialize=message,
        )
        # The response says which shares the server already has and which it
        # allocated for this request (see the "allocate_buckets" schema).
        decoded_response = cast(
            Mapping[str, Set[int]],
            await self._client.decode_cbor(response, _SCHEMAS["allocate_buckets"]),
        )
        return ImmutableCreateResult(
            already_have=decoded_response["already-have"],
            allocated=decoded_response["allocated"],
        )
865
866    @async_to_deferred
867    async def abort_upload(
868        self, storage_index: bytes, share_number: int, upload_secret: bytes
869    ) -> None:
870        """Abort the upload."""
871        with start_action(
872            action_type="allmydata:storage:http-client:immutable:abort-upload",
873            storage_index=si_to_human_readable(storage_index),
874            share_number=share_number,
875        ):
876            return await self._abort_upload(storage_index, share_number, upload_secret)
877
878    async def _abort_upload(
879        self, storage_index: bytes, share_number: int, upload_secret: bytes
880    ) -> None:
881        """Implementation of ``abort_upload()``."""
882        url = self._client.relative_url(
883            "/storage/v1/immutable/{}/{}/abort".format(
884                _encode_si(storage_index), share_number
885            )
886        )
887        response = await self._client.request(
888            "PUT",
889            url,
890            upload_secret=upload_secret,
891        )
892
893        if response.code == http.OK:
894            return
895        else:
896            raise ClientException(
897                response.code,
898            )
899
900    @async_to_deferred
901    async def write_share_chunk(
902        self,
903        storage_index: bytes,
904        share_number: int,
905        upload_secret: bytes,
906        offset: int,
907        data: bytes,
908    ) -> UploadProgress:
909        """
910        Upload a chunk of data for a specific share.
911
912        TODO https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3857 The
913        implementation should retry failed uploads transparently a number of
914        times, so that if a failure percolates up, the caller can assume the
915        failure isn't a short-term blip.
916
917        Result fires when the upload succeeded, with a boolean indicating
918        whether the _complete_ share (i.e. all chunks, not just this one) has
919        been uploaded.
920        """
921        with start_action(
922            action_type="allmydata:storage:http-client:immutable:write-share-chunk",
923            storage_index=si_to_human_readable(storage_index),
924            share_number=share_number,
925            offset=offset,
926            data_len=len(data),
927        ) as ctx:
928            result = await self._write_share_chunk(
929                storage_index, share_number, upload_secret, offset, data
930            )
931            ctx.add_success_fields(finished=result.finished)
932            return result
933
934    async def _write_share_chunk(
935        self,
936        storage_index: bytes,
937        share_number: int,
938        upload_secret: bytes,
939        offset: int,
940        data: bytes,
941    ) -> UploadProgress:
942        """Implementation of ``write_share_chunk()``."""
943        url = self._client.relative_url(
944            "/storage/v1/immutable/{}/{}".format(
945                _encode_si(storage_index), share_number
946            )
947        )
948        response = await self._client.request(
949            "PATCH",
950            url,
951            upload_secret=upload_secret,
952            data=data,
953            headers=Headers(
954                {
955                    "content-range": [
956                        ContentRange("bytes", offset, offset + len(data)).to_header()
957                    ]
958                }
959            ),
960        )
961
962        if response.code == http.OK:
963            # Upload is still unfinished.
964            finished = False
965        elif response.code == http.CREATED:
966            # Upload is done!
967            finished = True
968        else:
969            raise ClientException(
970                response.code,
971            )
972        body = cast(
973            Mapping[str, Sequence[Mapping[str, int]]],
974            await self._client.decode_cbor(
975                response, _SCHEMAS["immutable_write_share_chunk"]
976            ),
977        )
978        remaining = RangeMap()
979        for chunk in body["required"]:
980            remaining.set(True, chunk["begin"], chunk["end"])
981        return UploadProgress(finished=finished, required=remaining)
982
983    @async_to_deferred
984    async def read_share_chunk(
985        self, storage_index: bytes, share_number: int, offset: int, length: int
986    ) -> bytes:
987        """
988        Download a chunk of data from a share.
989        """
990        with start_action(
991            action_type="allmydata:storage:http-client:immutable:read-share-chunk",
992            storage_index=si_to_human_readable(storage_index),
993            share_number=share_number,
994            offset=offset,
995            length=length,
996        ) as ctx:
997            result = await read_share_chunk(
998                self._client, "immutable", storage_index, share_number, offset, length
999            )
1000            ctx.add_success_fields(data_len=len(result))
1001            return result
1002
1003    @async_to_deferred
1004    async def list_shares(self, storage_index: bytes) -> Set[int]:
1005        """
1006        Return the set of shares for a given storage index.
1007        """
1008        with start_action(
1009            action_type="allmydata:storage:http-client:immutable:list-shares",
1010            storage_index=si_to_human_readable(storage_index),
1011        ) as ctx:
1012            result = await self._list_shares(storage_index)
1013            ctx.add_success_fields(shares=result)
1014            return result
1015
1016    async def _list_shares(self, storage_index: bytes) -> Set[int]:
1017        """Implementation of ``list_shares()``."""
1018        url = self._client.relative_url(
1019            "/storage/v1/immutable/{}/shares".format(_encode_si(storage_index))
1020        )
1021        response = await self._client.request(
1022            "GET",
1023            url,
1024        )
1025        if response.code == http.OK:
1026            return cast(
1027                Set[int],
1028                await self._client.decode_cbor(response, _SCHEMAS["list_shares"]),
1029            )
1030        else:
1031            raise ClientException(response.code)
1032
1033    @async_to_deferred
1034    async def advise_corrupt_share(
1035        self,
1036        storage_index: bytes,
1037        share_number: int,
1038        reason: str,
1039    ) -> None:
1040        """Indicate a share has been corrupted, with a human-readable message."""
1041        with start_action(
1042            action_type="allmydata:storage:http-client:immutable:advise-corrupt-share",
1043            storage_index=si_to_human_readable(storage_index),
1044            share_number=share_number,
1045            reason=reason,
1046        ):
1047            await advise_corrupt_share(
1048                self._client, "immutable", storage_index, share_number, reason
1049            )
1050
1051
@frozen
class WriteVector:
    """Data to write to a chunk."""

    # Byte offset within the share where ``data`` should be written.
    offset: int
    # The bytes to write starting at ``offset``.
    data: bytes
1058
1059
@frozen
class TestVector:
    """Checks to make on a chunk before writing to it."""

    # Byte offset within the share where the check starts.
    offset: int
    # Number of bytes the check covers.
    size: int
    # Bytes to compare against; presumably the write proceeds only if the
    # share's contents at [offset, offset+size) match — confirm against the
    # server-side read-test-write implementation.
    specimen: bytes
1067
1068
@frozen
class ReadVector:
    """
    Reads to do on chunks, as part of a read/test/write operation.
    """

    # Byte offset within the share where the read starts.
    offset: int
    # Number of bytes to read.
    size: int
1078
@frozen
class TestWriteVectors:
    """Test and write vectors for a specific share."""

    test_vectors: Sequence[TestVector] = field(factory=list)
    write_vectors: Sequence[WriteVector] = field(factory=list)
    new_length: Optional[int] = None

    def asdict(self) -> dict:
        """Return dictionary suitable for sending over CBOR."""
        # Re-key the attrs-generated dict to the wire-format names.
        serialized = asdict(self)
        return {
            "test": serialized["test_vectors"],
            "write": serialized["write_vectors"],
            "new-length": serialized["new_length"],
        }
1094
1095
@frozen
class ReadTestWriteResult:
    """Result of sending read-test-write vectors."""

    # Whether the test vectors passed and the writes were applied.
    success: bool
    # Map share numbers to reads corresponding to the request's list of
    # ReadVectors:
    reads: Mapping[int, Sequence[bytes]]
1104
1105
# Result type for mutable read/test/write HTTP response. Can't just use
# dict[int,list[bytes]] because on Python 3.8 that will error out.
# "data" maps share numbers to the bytes read for each requested ReadVector.
MUTABLE_RTW = TypedDict(
    "MUTABLE_RTW", {"success": bool, "data": Mapping[int, Sequence[bytes]]}
)
1111
1112
@frozen
class StorageClientMutables:
    """
    APIs for interacting with mutables.
    """

    _client: StorageClient

    @async_to_deferred
    async def read_test_write_chunks(
        self,
        storage_index: bytes,
        write_enabler_secret: bytes,
        lease_renew_secret: bytes,
        lease_cancel_secret: bytes,
        testwrite_vectors: dict[int, TestWriteVectors],
        read_vector: list[ReadVector],
    ) -> ReadTestWriteResult:
        """
        Read, test, and possibly write chunks to a particular mutable storage
        index.

        Reads are done before writes.

        Given a mapping between share numbers and test/write vectors, the tests
        are done and if they are valid the writes are done.
        """
        action = start_action(
            action_type="allmydata:storage:http-client:mutable:read-test-write",
            storage_index=si_to_human_readable(storage_index),
        )
        with action:
            return await self._read_test_write_chunks(
                storage_index,
                write_enabler_secret,
                lease_renew_secret,
                lease_cancel_secret,
                testwrite_vectors,
                read_vector,
            )

    async def _read_test_write_chunks(
        self,
        storage_index: bytes,
        write_enabler_secret: bytes,
        lease_renew_secret: bytes,
        lease_cancel_secret: bytes,
        testwrite_vectors: dict[int, TestWriteVectors],
        read_vector: list[ReadVector],
    ) -> ReadTestWriteResult:
        """Implementation of ``read_test_write_chunks()``."""
        url = self._client.relative_url(
            "/storage/v1/mutable/{}/read-test-write".format(_encode_si(storage_index))
        )
        # Serialize per-share test/write vectors and the shared read vector
        # into the wire-format message body.
        serialized_vectors = {
            share_number: twv.asdict()
            for (share_number, twv) in testwrite_vectors.items()
        }
        message = {
            "test-write-vectors": serialized_vectors,
            "read-vector": [asdict(r) for r in read_vector],
        }
        response = await self._client.request(
            "POST",
            url,
            write_enabler_secret=write_enabler_secret,
            lease_renew_secret=lease_renew_secret,
            lease_cancel_secret=lease_cancel_secret,
            message_to_serialize=message,
        )
        if response.code != http.OK:
            raise ClientException(response.code, (await response.content()))
        decoded = cast(
            MUTABLE_RTW,
            await self._client.decode_cbor(
                response, _SCHEMAS["mutable_read_test_write"]
            ),
        )
        return ReadTestWriteResult(success=decoded["success"], reads=decoded["data"])

    @async_to_deferred
    async def read_share_chunk(
        self,
        storage_index: bytes,
        share_number: int,
        offset: int,
        length: int,
    ) -> bytes:
        """
        Download a chunk of data from a share.
        """
        action = start_action(
            action_type="allmydata:storage:http-client:mutable:read-share-chunk",
            storage_index=si_to_human_readable(storage_index),
            share_number=share_number,
            offset=offset,
            length=length,
        )
        with action as ctx:
            data = await read_share_chunk(
                self._client, "mutable", storage_index, share_number, offset, length
            )
            ctx.add_success_fields(data_len=len(data))
            return data

    @async_to_deferred
    async def list_shares(self, storage_index: bytes) -> Set[int]:
        """
        List the share numbers for a given storage index.
        """
        action = start_action(
            action_type="allmydata:storage:http-client:mutable:list-shares",
            storage_index=si_to_human_readable(storage_index),
        )
        with action as ctx:
            shares = await self._list_shares(storage_index)
            ctx.add_success_fields(shares=shares)
            return shares

    async def _list_shares(self, storage_index: bytes) -> Set[int]:
        """Implementation of ``list_shares()``."""
        url = self._client.relative_url(
            "/storage/v1/mutable/{}/shares".format(_encode_si(storage_index))
        )
        response = await self._client.request("GET", url)
        if response.code != http.OK:
            raise ClientException(response.code)
        return cast(
            Set[int],
            await self._client.decode_cbor(
                response,
                _SCHEMAS["mutable_list_shares"],
            ),
        )

    @async_to_deferred
    async def advise_corrupt_share(
        self,
        storage_index: bytes,
        share_number: int,
        reason: str,
    ) -> None:
        """Indicate a share has been corrupted, with a human-readable message."""
        action = start_action(
            action_type="allmydata:storage:http-client:mutable:advise-corrupt-share",
            storage_index=si_to_human_readable(storage_index),
            share_number=share_number,
            reason=reason,
        )
        with action:
            await advise_corrupt_share(
                self._client, "mutable", storage_index, share_number, reason
            )
# Note: See TracBrowser for help on using the repository browser.