import asyncio
import logging
import os
from concurrent.futures import Executor, ProcessPoolExecutor
from datetime import datetime, timezone
from functools import cache, partial
from multiprocessing import freeze_support

try:
    from aiohttp import web
    from multidict import MultiMapping

    from .middlewares import cors
except ImportError as ie:
    raise ImportError(
        f"aiohttp dependency is not installed: {ie}. "
        + "Please re-install black with the '[d]' extra install "
        + "to obtain aiohttp_cors: `pip install black[d]`"
    ) from None

import click

import black
from _black_version import version as __version__
from black.concurrency import maybe_use_uvloop

# This is used internally by tests to shut down the server prematurely
_stop_signal = asyncio.Event()

# Request headers
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
LINE_LENGTH_HEADER = "X-Line-Length"
PYTHON_VARIANT_HEADER = "X-Python-Variant"
SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
PREVIEW = "X-Preview"
UNSTABLE = "X-Unstable"
ENABLE_UNSTABLE_FEATURE = "X-Enable-Unstable-Feature"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
DIFF_HEADER = "X-Diff"

# Every request header above; advertised to the CORS middleware as allowed.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_SOURCE_FIRST_LINE,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    PREVIEW,
    UNSTABLE,
    ENABLE_UNSTABLE_FEATURE,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# Response headers
BLACK_VERSION_HEADER = "X-Black-Version"

DEFAULT_MAX_BODY_SIZE = 5 * 1024 * 1024
# os.cpu_count() may return None, so fall back to a single worker.
DEFAULT_WORKERS = os.cpu_count() or 1


class HeaderError(Exception):
    """Raised by ``parse_mode`` when a request header has an invalid value."""


class InvalidVariantHeader(Exception):
    """Raised by ``parse_python_variant_header`` for a malformed value."""


@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host",
    type=str,
    help="Address to bind the server to.",
    default="localhost",
    show_default=True,
)
@click.option(
    "--bind-port", type=int, help="Port to listen on", default=45484, show_default=True
)
@click.option(
    "--cors-allow-origin",
    "cors_allow_origins",
    multiple=True,
    help="Origin allowed to access blackd over CORS. Can be passed multiple times.",
)
@click.option(
    "--max-body-size",
    type=click.IntRange(min=1),
    default=DEFAULT_MAX_BODY_SIZE,
    show_default=True,
    help="Maximum request body size in bytes.",
)
@click.version_option(version=black.__version__)
def main(
    bind_host: str,
    bind_port: int,
    cors_allow_origins: tuple[str, ...],
    max_body_size: int,
) -> None:
    # CLI entry point: build the app and serve it with aiohttp's run_app.
    # (Intentionally no docstring: click would use it as the --help text.)
    logging.basicConfig(level=logging.INFO)
    app = make_app(cors_allow_origins=cors_allow_origins, max_body_size=max_body_size)
    ver = black.__version__
    black.out(f"blackd version {ver} listening on {bind_host} port {bind_port}")
    loop = maybe_use_uvloop()
    try:
        web.run_app(
            app,
            host=bind_host,
            port=bind_port,
            handle_signals=True,
            print=None,
            loop=loop,
        )
    finally:
        if not loop.is_closed():
            loop.close()


@cache
def executor() -> Executor:
    """Return the process pool shared by all requests (created on first call)."""
    return ProcessPoolExecutor(max_workers=DEFAULT_WORKERS)


def make_app(
    *,
    cors_allow_origins: tuple[str, ...] = (),
    max_body_size: int = DEFAULT_MAX_BODY_SIZE,
) -> web.Application:
    """Build the blackd aiohttp application with a single ``POST /`` route.

    Args:
        cors_allow_origins: origins passed to the CORS middleware as allowed.
        max_body_size: largest accepted request body in bytes, used as the
            application's ``client_max_size``.
    """
    app = web.Application(
        client_max_size=max_body_size,
        middlewares=[
            cors(
                allow_headers=(*BLACK_HEADERS, "Content-Type"),
                allow_origins=frozenset(cors_allow_origins),
                expose_headers=(BLACK_VERSION_HEADER,),
            )
        ],
    )
    app.add_routes(
        [
            web.post(
                "/",
                partial(
                    handle,
                    executor=executor(),
                    # Sized like the process pool: at most one in-flight
                    # formatting job per worker process.
                    executor_semaphore=asyncio.BoundedSemaphore(DEFAULT_WORKERS),
                ),
            )
        ]
    )
    return app


async def handle(
    request: web.Request,
    executor: Executor,
    executor_semaphore: asyncio.BoundedSemaphore,
) -> web.Response:
    """Format the Python source in the request body and return it.

    Formatting options come from the ``X-*`` request headers (see
    ``parse_mode``). Responses:

    * 200 with the formatted source, or with a unified diff when ``X-Diff``
      has a non-empty value;
    * 204 when black reports that nothing changed;
    * 400 for an invalid header, an undecodable body, or input black rejects;
    * 501 for an unsupported ``X-Protocol-Version``;
    * 500 (and a logged traceback) for any other exception.

    aiohttp ``HTTPException``s (e.g. a body larger than the application's
    ``client_max_size``) are re-raised for aiohttp to handle.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501, text="This server only supports protocol version 1"
            )

        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        try:
            mode = parse_mode(request.headers)
        except HeaderError as e:
            return web.Response(status=400, text=e.args[0])

        req_bytes = await request.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError) as e:
            # An unknown charset, or a body that is not valid in it, is a
            # client error rather than a server failure.
            return web.Response(status=400, headers=headers, text=str(e))
        then = datetime.now(timezone.utc)

        header = ""
        if mode.skip_source_first_line:
            # Everything up to and including the first "\n" is passed through
            # untouched; if there is no newline, find() gives -1 and nothing
            # is skipped.
            first_newline_position: int = req_str.find("\n") + 1
            header = req_str[:first_newline_position]
            req_str = req_str[first_newline_position:]

        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        # format_code puts the skipped header back itself, so that a diff
        # response stays a well-formed diff of the whole request body.
        response_text = await format_code(
            req_str=req_str,
            fast=fast,
            mode=mode,
            then=then,
            only_diff=only_diff,
            executor=executor,
            executor_semaphore=executor_semaphore,
            header=header,
        )
        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=response_text,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except web.HTTPException:
        raise
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))


async def format_code(
    *,
    req_str: str,
    fast: bool,
    mode: black.FileMode,
    then: datetime,
    only_diff: bool,
    executor: Executor,
    executor_semaphore: asyncio.BoundedSemaphore,
    header: str = "",
) -> str:
    """Format *req_str* on *executor* and return the response body text.

    Args:
        req_str: source to format, without *header*.
        fast: forwarded to ``black.format_file_contents``.
        mode: black mode built from the request headers.
        then: request timestamp, used in the diff's "In" label.
        only_diff: return a unified diff instead of the formatted source.
        executor: executor that runs formatting and diffing.
        executor_semaphore: limits how many requests use *executor* at once.
        header: verbatim text to restore in front of the source (the line
            skipped via ``X-Skip-Source-First-Line``). It is prepended to the
            formatted source, or to both sides of the diff so hunk line
            numbers match the original request body. Defaults to ``""``.

    Exceptions from ``black.format_file_contents`` (such as
    ``black.NothingChanged`` / ``black.InvalidInput``) propagate to the caller.
    """
    async with executor_semaphore:
        loop = asyncio.get_running_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )
        if not only_diff:
            return header + formatted_str

        now = datetime.now(timezone.utc)
        src_name = f"In\t{then}"
        dst_name = f"Out\t{now}"
        return await loop.run_in_executor(
            executor,
            partial(
                black.diff,
                header + req_str,
                header + formatted_str,
                src_name,
                dst_name,
            ),
        )


def parse_mode(headers: MultiMapping[str]) -> black.Mode:
    """Build a ``black.Mode`` from the blackd request headers.

    The boolean options (skip string normalization, skip magic trailing
    comma, skip source first line, preview, unstable) are enabled by any
    non-empty header value.

    Raises:
        HeaderError: for a non-integer line length, an invalid
            ``X-Python-Variant`` value, or an unknown feature name in
            ``X-Enable-Unstable-Feature``.
    """
    try:
        line_length = int(headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH))
    except ValueError:
        raise HeaderError("Invalid line length header value") from None

    if PYTHON_VARIANT_HEADER in headers:
        value = headers[PYTHON_VARIANT_HEADER]
        try:
            pyi, versions = parse_python_variant_header(value)
        except InvalidVariantHeader as e:
            raise HeaderError(
                f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
            ) from None
    else:
        pyi = False
        versions = set()

    skip_string_normalization = bool(
        headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
    )
    skip_magic_trailing_comma = bool(headers.get(SKIP_MAGIC_TRAILING_COMMA, False))
    skip_source_first_line = bool(headers.get(SKIP_SOURCE_FIRST_LINE, False))

    preview = bool(headers.get(PREVIEW, False))
    unstable = bool(headers.get(UNSTABLE, False))
    enable_features: set[black.Preview] = set()
    enable_unstable_features = headers.get(ENABLE_UNSTABLE_FEATURE, "").split(",")
    for piece in enable_unstable_features:
        piece = piece.strip()
        if piece:
            try:
                enable_features.add(black.Preview[piece])
            except KeyError:
                raise HeaderError(
                    f"Invalid value for {ENABLE_UNSTABLE_FEATURE}: {piece}",
                ) from None

    return black.FileMode(
        target_versions=versions,
        is_pyi=pyi,
        line_length=line_length,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        preview=preview,
        unstable=unstable,
        enabled_features=enable_features,
    )


def parse_python_variant_header(value: str) -> tuple[bool, set[black.TargetVersion]]:
    """Parse an ``X-Python-Variant`` header value.

    ``"pyi"`` yields ``(True, set())``. Otherwise *value* is a comma-separated
    list of versions such as ``"3.8"``, ``"py3.9"``, ``"py310"`` or ``"3"``,
    and ``(False, versions)`` is returned. Whitespace around items is ignored
    and empty items are skipped, matching how ``X-Enable-Unstable-Feature``
    is parsed.

    Raises:
        InvalidVariantHeader: if an item does not name a supported
            ``black.TargetVersion``.
    """
    if value == "pyi":
        return True, set()

    versions = set()
    for version in value.split(","):
        version = version.strip()
        if not version:
            # Stray commas / an empty value previously crashed on version[0]
            # and surfaced as a 500.
            continue
        if version.startswith("py"):
            version = version[len("py") :]
        try:
            if "." in version:
                major_str, *rest = version.split(".")
            else:
                # Compact form such as "38" or "310": first digit is the major.
                # A bare "py" leaves "" here and raises IndexError.
                major_str = version[0]
                rest = [version[1:]] if len(version) > 1 else []
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if len(rest) > 0:
                minor = int(rest[0])
                if major == 2:
                    raise InvalidVariantHeader("Python 2 is not supported")
            else:
                # Default to lowest supported minor version.
                minor = 7 if major == 2 else 3
            version_str = f"PY{major}{minor}"
            if major == 3 and not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError, IndexError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions


def patched_main() -> None:
    """Console entry point: call ``freeze_support()`` and then run ``main``."""
    # blackd uses a process pool; freeze_support() is required for that to
    # work when the program is frozen into a Windows executable.
    freeze_support()
    main()


if __name__ == "__main__":
    patched_main()