| # |
| # SPDX-License-Identifier: GPL-2.0-only |
| # |
| |
| from pathlib import Path |
| import bb.compress.lz4 |
| import bb.compress.zstd |
| import contextlib |
| import os |
| import shutil |
| import tempfile |
| import unittest |
| import subprocess |
| |
| |
class CompressionTests:
    """Backend-agnostic round-trip tests for the ``bb.compress`` modules.

    This is a mixin. A concrete subclass must also inherit from
    ``unittest.TestCase``, which supplies ``addCleanup``, ``assertEqual`` and
    ``assertRaises``. It must also implement ``do_open(*args, **kwargs)``: a
    context manager that opens a compressed stream from a path or file
    object, with a ``mode=`` keyword such as ``"wt"``, ``"rb"``, etc.
    """

    def setUp(self):
        # Give every test a private scratch directory. It is removed through
        # addCleanup, so removal happens even when the test fails.
        self._t = tempfile.TemporaryDirectory()
        self.tmpdir = Path(self._t.name)
        self.addCleanup(self._t.cleanup)

    def _file_helper(self, mode_suffix, data):
        """Write *data* to a file via ``do_open`` and check it reads back.

        :param mode_suffix: ``"t"`` or ``"b"``, appended to the ``"w"``/``"r"``
            mode letters.
        :param data: ``str`` for text mode, ``bytes`` for binary mode.
        """
        tmp_file = self.tmpdir / "compressed"

        with self.do_open(tmp_file, mode="w" + mode_suffix) as f:
            f.write(data)

        with self.do_open(tmp_file, mode="r" + mode_suffix) as f:
            read_data = f.read()

        self.assertEqual(read_data, data)

    def test_text_file(self):
        self._file_helper("t", "Hello")

    def test_binary_file(self):
        self._file_helper("b", "Hello".encode("utf-8"))

    def _pipe_helper(self, mode_suffix, data):
        """Compress *data* into an OS pipe and decompress it from the other end.

        The decompressor is opened on the read end first. The compressor then
        writes to the write end and is closed before ``read()`` is called, so
        the compressed stream is complete by the time it is consumed.

        :param mode_suffix: ``"t"`` or ``"b"``.
        :param data: ``str`` for text mode, ``bytes`` for binary mode.
        """
        rfd, wfd = os.pipe()
        with open(rfd, "rb") as r, open(wfd, "wb") as w:
            with self.do_open(r, mode="r" + mode_suffix) as decompress:
                with self.do_open(w, mode="w" + mode_suffix) as compress:
                    compress.write(data)
                # NOTE(review): read() can only finish once the decompressor
                # sees EOF on the pipe. That appears to rely on closing
                # ``compress`` also releasing ``w``; confirm in bb.compress.
                read_data = decompress.read()

        self.assertEqual(read_data, data)

    def test_text_pipe(self):
        self._pipe_helper("t", "Hello")

    def test_binary_pipe(self):
        self._pipe_helper("b", "Hello".encode("utf-8"))

    def test_bad_decompress(self):
        """Decompressing invalid input must surface as an ``OSError``."""
        tmp_file = self.tmpdir / "compressed"
        # A lone NUL byte is not a valid compressed stream; the test expects
        # the backend to reject it.
        tmp_file.write_bytes(b"\x00")

        with self.assertRaises(OSError):
            # stderr=DEVNULL presumably keeps the external tool's diagnostics
            # out of the test output.
            with self.do_open(tmp_file, mode="rb", stderr=subprocess.DEVNULL) as f:
                f.read()
| |
| |
class LZ4Tests(CompressionTests, unittest.TestCase):
    """Run the shared ``CompressionTests`` suite against ``bb.compress.lz4``."""

    def setUp(self):
        # These tests need the ``lz4c`` executable on PATH; skip otherwise.
        lz4c_path = shutil.which("lz4c")
        if lz4c_path is None:
            self.skipTest("'lz4c' not found")
        super().setUp()

    @contextlib.contextmanager
    def do_open(self, *args, **kwargs):
        """Yield a stream from ``bb.compress.lz4.open``, closing it on exit."""
        with bb.compress.lz4.open(*args, **kwargs) as stream:
            yield stream
| |
| |
class ZStdTests(CompressionTests, unittest.TestCase):
    """Run the shared ``CompressionTests`` suite against ``bb.compress.zstd``."""

    def setUp(self):
        # These tests need the ``zstd`` executable on PATH; skip otherwise.
        zstd_path = shutil.which("zstd")
        if zstd_path is None:
            self.skipTest("'zstd' not found")
        super().setUp()

    @contextlib.contextmanager
    def do_open(self, *args, **kwargs):
        """Yield a stream from ``bb.compress.zstd.open``, closing it on exit."""
        with bb.compress.zstd.open(*args, **kwargs) as stream:
            yield stream
| |
| |
class PZStdTests(CompressionTests, unittest.TestCase):
    """Run the shared ``CompressionTests`` suite through ``bb.compress.zstd``
    with ``num_threads=2``."""

    def setUp(self):
        # These tests need the ``pzstd`` executable on PATH; skip otherwise.
        pzstd_path = shutil.which("pzstd")
        if pzstd_path is None:
            self.skipTest("'pzstd' not found")
        super().setUp()

    @contextlib.contextmanager
    def do_open(self, *args, **kwargs):
        """Yield a two-thread ``bb.compress.zstd.open`` stream; close it on exit."""
        # NOTE(review): presumably num_threads > 1 makes bb.compress.zstd use
        # the parallel ``pzstd`` tool (hence the check in setUp); confirm.
        with bb.compress.zstd.open(*args, num_threads=2, **kwargs) as stream:
            yield stream