blob: d3ddf67f1ccee102c350f275e077ad41386a4d61 [file] [log] [blame]
Andrew Geissler5f350902021-07-23 13:09:54 -04001#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5from pathlib import Path
6import bb.compress.lz4
7import bb.compress.zstd
8import contextlib
9import os
10import shutil
11import tempfile
12import unittest
13import subprocess
14
15
class CompressionTests:
    """Shared round-trip tests for a compressed-file ``open`` implementation.

    This class is a mixin and is not a test case on its own. It relies on the
    class it is combined with to provide:

    * ``do_open(file, mode=..., **kwargs)`` -- a context manager yielding a
      file-like object that compresses on write and decompresses on read.
    * the ``unittest.TestCase`` helpers used below (``addCleanup``,
      ``assertEqual`` and ``assertRaises``).
    """

    def setUp(self):
        """Create a per-test scratch directory and register its removal."""
        self._t = tempfile.TemporaryDirectory()
        self.tmpdir = Path(self._t.name)
        self.addCleanup(self._t.cleanup)

    def _file_helper(self, mode_suffix, data):
        """Write ``data`` to a file through ``do_open`` and read it back.

        ``mode_suffix`` is appended to ``"w"`` and ``"r"`` to build the open
        modes (the tests below pass ``"t"`` and ``"b"``). The value read back
        must compare equal to ``data``.
        """
        tmp_file = self.tmpdir / "compressed"

        with self.do_open(tmp_file, mode="w" + mode_suffix) as f:
            f.write(data)

        with self.do_open(tmp_file, mode="r" + mode_suffix) as f:
            read_data = f.read()

        self.assertEqual(read_data, data)

    def test_text_file(self):
        """Round-trip a ``str`` through a file opened in text mode."""
        self._file_helper("t", "Hello")

    def test_binary_file(self):
        """Round-trip ``bytes`` through a file opened in binary mode."""
        self._file_helper("b", "Hello".encode("utf-8"))

    def _pipe_helper(self, mode_suffix, data):
        """Round-trip ``data`` through ``do_open`` over an ``os.pipe()`` pair.

        Both ends of the pipe are wrapped in binary file objects and handed to
        ``do_open`` in place of a path. ``mode_suffix`` is used as in
        ``_file_helper``.
        """
        rfd, wfd = os.pipe()
        with open(rfd, "rb") as r, open(wfd, "wb") as w:
            with self.do_open(r, mode="r" + mode_suffix) as decompress:
                with self.do_open(w, mode="w" + mode_suffix) as compress:
                    compress.write(data)
                # The writer context has exited before the read starts;
                # presumably that is what completes the compressed stream so
                # read() can reach end-of-file -- confirm against do_open.
                read_data = decompress.read()

        self.assertEqual(read_data, data)

    def test_text_pipe(self):
        """Round-trip a ``str`` through a pipe in text mode."""
        self._pipe_helper("t", "Hello")

    def test_binary_pipe(self):
        """Round-trip ``bytes`` through a pipe in binary mode."""
        self._pipe_helper("b", "Hello".encode("utf-8"))

    def test_bad_decompress(self):
        """Reading a file that holds a single NUL byte must raise ``OSError``."""
        tmp_file = self.tmpdir / "compressed"
        with tmp_file.open("wb") as f:
            f.write(b"\x00")

        with self.assertRaises(OSError):
            # stderr=subprocess.DEVNULL is forwarded to do_open; presumably it
            # silences the decompressor's own error output -- confirm.
            with self.do_open(tmp_file, mode="rb", stderr=subprocess.DEVNULL) as f:
                # Only the exception matters, so the result is not kept.
                f.read()
63
64
class LZ4Tests(CompressionTests, unittest.TestCase):
    """Run the shared compression tests against ``bb.compress.lz4``."""

    def setUp(self):
        """Skip the test when ``lz4c`` is not on PATH, then do the common setup."""
        lz4c_path = shutil.which("lz4c")
        if lz4c_path is None:
            self.skipTest("'lz4c' not found")
        super().setUp()

    @contextlib.contextmanager
    def do_open(self, *args, **kwargs):
        """Open via ``bb.compress.lz4.open``, forwarding every argument unchanged."""
        opened = bb.compress.lz4.open(*args, **kwargs)
        with opened as stream:
            yield stream
75
76
class ZStdTests(CompressionTests, unittest.TestCase):
    """Run the shared compression tests against ``bb.compress.zstd``."""

    def setUp(self):
        """Skip the test when ``zstd`` is not on PATH, then do the common setup."""
        zstd_path = shutil.which("zstd")
        if zstd_path is None:
            self.skipTest("'zstd' not found")
        super().setUp()

    @contextlib.contextmanager
    def do_open(self, *args, **kwargs):
        """Open via ``bb.compress.zstd.open``, forwarding every argument unchanged."""
        opened = bb.compress.zstd.open(*args, **kwargs)
        with opened as stream:
            yield stream
87
88
class PZStdTests(CompressionTests, unittest.TestCase):
    """Run the shared compression tests against ``bb.compress.zstd`` with two threads."""

    def setUp(self):
        """Skip the test when ``pzstd`` is not on PATH, then do the common setup."""
        pzstd_path = shutil.which("pzstd")
        if pzstd_path is None:
            self.skipTest("'pzstd' not found")
        super().setUp()

    @contextlib.contextmanager
    def do_open(self, *args, **kwargs):
        """Open via ``bb.compress.zstd.open`` with ``num_threads=2`` added.

        All other arguments are forwarded unchanged. Presumably the thread
        count is what makes the backend use ``pzstd`` -- confirm against
        ``bb.compress.zstd``.
        """
        opened = bb.compress.zstd.open(*args, num_threads=2, **kwargs)
        with opened as stream:
            yield stream
98 yield f