2020import random
2121import time
2222from io import BytesIO
23- from typing import NamedTuple
23+ from typing import List , NamedTuple , Optional
2424
2525import pytest
2626
@@ -51,21 +51,26 @@ async def create_client():
5151 return AsyncGrpcClient ()
5252
5353
54- def _aggregate_download_results (results ) :
54+ def _aggregate_download_results (results : List [ DownloadResult ]) -> DownloadResult :
5555 if not results :
5656 raise ValueError ("At least one download result is required." )
5757
5858 total_bytes = sum (result .total_bytes for result in results )
5959 measured_start_time = min (result .measured_start_time for result in results )
6060 measured_end_time = max (result .measured_end_time for result in results )
61- measured_elapsed_time = measured_end_time - measured_start_time
62- if measured_elapsed_time <= 0 :
61+ if measured_end_time <= measured_start_time :
6362 raise ValueError ("Measured elapsed time must be positive." )
6463
65- return total_bytes , measured_elapsed_time
64+ return DownloadResult (
65+ total_bytes = total_bytes ,
66+ measured_start_time = measured_start_time ,
67+ measured_end_time = measured_end_time ,
68+ )
6669
6770
68- def _calculate_average_throughput_mib_s (download_bytes_list , download_elapsed_times ):
71+ def _calculate_average_throughput_mib_s (
72+ download_bytes_list : List [int ], download_elapsed_times : List [float ]
73+ ) -> float :
6974 total_bytes_downloaded = sum (download_bytes_list )
7075 total_elapsed_time = sum (download_elapsed_times )
7176 if total_elapsed_time <= 0 :
@@ -74,15 +79,19 @@ def _calculate_average_throughput_mib_s(download_bytes_list, download_elapsed_ti
7479 return (total_bytes_downloaded / total_elapsed_time ) / (1024 * 1024 )
7580
7681
77- def _record_measured_start (measured_start_time , current_time ):
82+ def _record_measured_start (
83+ measured_start_time : Optional [float ], current_time : float
84+ ) -> float :
7885 if measured_start_time is None :
7986 return current_time
8087 return measured_start_time
8188
8289
8390def _build_download_result (
84- total_bytes_downloaded , measured_start_time , measured_end_time
85- ):
91+ total_bytes_downloaded : int ,
92+ measured_start_time : Optional [float ],
93+ measured_end_time : Optional [float ],
94+ ) -> DownloadResult :
8695 if measured_start_time is None or measured_end_time is None :
8796 raise ValueError ("No downloads completed during the measured interval." )
8897
@@ -224,13 +233,7 @@ async def _worker_coro():
224233 results = await asyncio .gather (* tasks )
225234
226235 await mrd .close ()
227- total_bytes , measured_elapsed_time = _aggregate_download_results (results )
228- measured_start_time = min (result .measured_start_time for result in results )
229- return DownloadResult (
230- total_bytes = total_bytes ,
231- measured_start_time = measured_start_time ,
232- measured_end_time = measured_start_time + measured_elapsed_time ,
233- )
236+ return _aggregate_download_results (results )
234237
235238
236239def _download_files_worker (process_idx , filename , params , bucket_type ):
@@ -246,7 +249,8 @@ def download_files_mp_mc_wrapper(pool, files_names, params, bucket_type):
246249 args = [(i , files_names [i ], params , bucket_type ) for i in range (len (files_names ))]
247250
248251 results = pool .starmap (_download_files_worker , args )
249- return _aggregate_download_results (results )
252+ agg_res = _aggregate_download_results (results )
253+ return agg_res .total_bytes , agg_res .measured_end_time - agg_res .measured_start_time
250254
251255
252256@pytest .mark .parametrize (
0 commit comments