Skip to content

Commit 9ea2eee

Browse files
committed
chore: add type hints and refactor microbenchmark helpers
Refactored _aggregate_download_results to return a DownloadResult and added type hints to helper functions in test_reads.py for better clarity and maintainability.
1 parent 4face09 commit 9ea2eee

1 file changed

Lines changed: 21 additions & 17 deletions

File tree

  • packages/google-cloud-storage/tests/perf/microbenchmarks/time_based/reads

packages/google-cloud-storage/tests/perf/microbenchmarks/time_based/reads/test_reads.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import random
2121
import time
2222
from io import BytesIO
23-
from typing import NamedTuple
23+
from typing import List, NamedTuple, Optional
2424

2525
import pytest
2626

@@ -51,21 +51,26 @@ async def create_client():
5151
return AsyncGrpcClient()
5252

5353

54-
def _aggregate_download_results(results):
54+
def _aggregate_download_results(results: List[DownloadResult]) -> DownloadResult:
5555
if not results:
5656
raise ValueError("At least one download result is required.")
5757

5858
total_bytes = sum(result.total_bytes for result in results)
5959
measured_start_time = min(result.measured_start_time for result in results)
6060
measured_end_time = max(result.measured_end_time for result in results)
61-
measured_elapsed_time = measured_end_time - measured_start_time
62-
if measured_elapsed_time <= 0:
61+
if measured_end_time <= measured_start_time:
6362
raise ValueError("Measured elapsed time must be positive.")
6463

65-
return total_bytes, measured_elapsed_time
64+
return DownloadResult(
65+
total_bytes=total_bytes,
66+
measured_start_time=measured_start_time,
67+
measured_end_time=measured_end_time,
68+
)
6669

6770

68-
def _calculate_average_throughput_mib_s(download_bytes_list, download_elapsed_times):
71+
def _calculate_average_throughput_mib_s(
72+
download_bytes_list: List[int], download_elapsed_times: List[float]
73+
) -> float:
6974
total_bytes_downloaded = sum(download_bytes_list)
7075
total_elapsed_time = sum(download_elapsed_times)
7176
if total_elapsed_time <= 0:
@@ -74,15 +79,19 @@ def _calculate_average_throughput_mib_s(download_bytes_list, download_elapsed_ti
7479
return (total_bytes_downloaded / total_elapsed_time) / (1024 * 1024)
7580

7681

77-
def _record_measured_start(measured_start_time, current_time):
82+
def _record_measured_start(
83+
measured_start_time: Optional[float], current_time: float
84+
) -> float:
7885
if measured_start_time is None:
7986
return current_time
8087
return measured_start_time
8188

8289

8390
def _build_download_result(
84-
total_bytes_downloaded, measured_start_time, measured_end_time
85-
):
91+
total_bytes_downloaded: int,
92+
measured_start_time: Optional[float],
93+
measured_end_time: Optional[float],
94+
) -> DownloadResult:
8695
if measured_start_time is None or measured_end_time is None:
8796
raise ValueError("No downloads completed during the measured interval.")
8897

@@ -224,13 +233,7 @@ async def _worker_coro():
224233
results = await asyncio.gather(*tasks)
225234

226235
await mrd.close()
227-
total_bytes, measured_elapsed_time = _aggregate_download_results(results)
228-
measured_start_time = min(result.measured_start_time for result in results)
229-
return DownloadResult(
230-
total_bytes=total_bytes,
231-
measured_start_time=measured_start_time,
232-
measured_end_time=measured_start_time + measured_elapsed_time,
233-
)
236+
return _aggregate_download_results(results)
234237

235238

236239
def _download_files_worker(process_idx, filename, params, bucket_type):
@@ -246,7 +249,8 @@ def download_files_mp_mc_wrapper(pool, files_names, params, bucket_type):
246249
args = [(i, files_names[i], params, bucket_type) for i in range(len(files_names))]
247250

248251
results = pool.starmap(_download_files_worker, args)
249-
return _aggregate_download_results(results)
252+
agg_res = _aggregate_download_results(results)
253+
return agg_res.total_bytes, agg_res.measured_end_time - agg_res.measured_start_time
250254

251255

252256
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)