1
1
mirror of https://github.com/KenanZhu/AutoLibrary.git synced 2026-06-18 07:23:03 +08:00
Files
AutoLibrary/src/managers/driver/WebDriverDownloader.py
T

462 lines
14 KiB
Python

# -*- coding: utf-8 -*-
"""
Copyright (c) 2026 KenanZhu.
All rights reserved.
This software is provided "as is", without any warranty of any kind.
You may use, modify, and distribute this file under the terms of the MIT License.
See the LICENSE file for details.
"""
import os
import time
import shutil
import threading
import requests
import zipfile
import tarfile
from enum import Enum
from pathlib import Path
from typing import Optional, Callable
class WebDriverType(Enum):
    """
    Supported web driver types.

    The member value is also used as the per-type sub-directory name under
    the download directory.
    """
    CHROME = "chrome"
    FIREFOX = "firefox"
    EDGE = "edge"
class WebDriverArch(Enum):
    """
    Web driver architectures, grouped per browser.

    Each nested enum lists the platforms one vendor publishes drivers for.
    A member's value is the platform token used in that vendor's archive
    file name / download URL, and as the architecture sub-directory name
    under the download directory.

    NOTE(review): before Python 3.13, a class defined inside an Enum body
    becomes a *member* of the outer enum (3.11/3.12 also emit a
    DeprecationWarning). In that case ``WebDriverArch.Chrome.WINX86_64``
    would not resolve to the nested member. This layout appears to assume
    Python 3.13+. Confirm the supported interpreter, or wrap the nested
    classes with ``enum.nonmember``.
    """
    class Chrome(Enum):
        """
        Chrome web driver architecture (Chrome for Testing platform tokens)
        """
        WINX86_32 = "win32"
        WINX86_64 = "win64"
        # LINUX86_32 : no support for linux 32bit
        LINUX86_64 = "linux64"
        # LINUXARM : no support for linux arm64
        MACX86_64 = "mac-x64"
        MACARM = "mac-arm64"
    class Firefox(Enum):
        """
        Firefox web driver architecture (geckodriver release asset tokens)
        """
        WINX86_32 = "win32"
        WINX86_64 = "win64"
        WINARM = "win-aarch64"
        LINUXX86_32 = "linux32"
        LINUXX86_64 = "linux64"
        LINUXARM = "linux-aarch64"
        MACX86_64 = "macos"
        MACARM = "macos-aarch64"
    class Edge(Enum):
        """
        Edge web driver architecture (tokens used in ``edgedriver_<arch>.zip``)
        """
        WINX86_32 = "win32"
        WINX86_64 = "win64"
        WINARM = "arm64"
        # LINUX86_32 : no support for linux 32bit
        LINUXX86_64 = "linux64"
        # LINUXARM : no support for linux arm64
        MACX86_64 = "mac64"
        MACARM = "mac64_m1"
class WebDriverName:
    """
    Base name of a web driver executable, without any platform extension.

    Args:
        driver_type (WebDriverType): Web driver type

    ``str()`` of an instance gives the name; unsupported driver types raise
    ``ValueError``.
    """

    def __init__(
        self,
        driver_type: WebDriverType
    ):
        self.driver_type = driver_type

    def __str__(
        self
    ) -> str:
        known_names = (
            (WebDriverType.CHROME, "chromedriver"),
            (WebDriverType.FIREFOX, "geckodriver"),
            (WebDriverType.EDGE, "msedgedriver"),
        )
        for candidate, name in known_names:
            if self.driver_type == candidate:
                return name
        raise ValueError(f"不受支持的 web driver 类型 : {self.driver_type}")
class WebDriverExecName:
"""
Web driver executable file name
"""
def __init__(
self,
driver_type: WebDriverType,
arch: WebDriverArch
):
self.driver_type = driver_type
self.arch = arch
def __str__(
self
) -> str:
is_win = True if self.arch is WebDriverArch.Chrome.WINX86_32 or\
self.arch is WebDriverArch.Chrome.WINX86_64 or\
self.arch is WebDriverArch.Firefox.WINX86_32 or\
self.arch is WebDriverArch.Firefox.WINX86_64 or\
self.arch is WebDriverArch.Edge.WINX86_32 or\
self.arch is WebDriverArch.Edge.WINX86_64 else False
match self.driver_type:
case WebDriverType.CHROME:
return f"{WebDriverName(self.driver_type)}" + (".exe" if is_win else "")
case WebDriverType.FIREFOX:
return f"{WebDriverName(self.driver_type)}" + (".exe" if is_win else "")
case WebDriverType.EDGE:
return f"{WebDriverName(self.driver_type)}" + (".exe" if is_win else "")
case _:
raise ValueError(f"不受支持的 web driver 类型 : {self.driver_type}")
class WebDriverFileName:
"""\
Web driver compressed file name
"""
def __init__(
self,
version: str,
driver_type: WebDriverType,
arch: WebDriverArch
):
self.version = version
self.driver_type = driver_type
self.arch = arch
def __str__(
self
) -> str:
match self.driver_type:
case WebDriverType.CHROME:
return f"{WebDriverName(self.driver_type)}-{self.arch.value}.zip"
case WebDriverType.FIREFOX:
if self.arch is WebDriverArch.Firefox.WINX86_32 or\
self.arch is WebDriverArch.Firefox.WINX86_64:
suffix = "zip"
else:
suffix = "tar.gz"
return f"{WebDriverName(self.driver_type)}-v{self.version}-{self.arch.value}.{suffix}"
case WebDriverType.EDGE:
return f"edgedriver_{self.arch.value}.zip" # Edge web driver file name is different
case _:
raise ValueError(f"不受支持的 web driver 类型 : {self.driver_type}")
class WebDriverURL:
    """
    Download URL of a web driver archive.

    Args:
        version (str): Driver version
        driver_type (WebDriverType): Web driver type
        arch (WebDriverArch): Member of the matching nested per-browser enum

    ``str()`` of an instance yields the full URL; unsupported driver types
    raise ``ValueError``.
    """

    def __init__(
        self,
        version: str,
        driver_type: WebDriverType,
        arch: WebDriverArch
    ):
        self.version = version
        self.driver_type = driver_type
        self.arch = arch
        # The archive name is the last path segment of every vendor URL.
        self.file_name = str(WebDriverFileName(
            version=self.version,
            driver_type=self.driver_type,
            arch=self.arch,
        ))

    def __str__(
        self
    ) -> str:
        if self.driver_type == WebDriverType.CHROME:
            base_url = "https://storage.googleapis.com/chrome-for-testing-public/"
            segments = (self.version, self.arch.value, self.file_name)
        elif self.driver_type == WebDriverType.FIREFOX:
            base_url = "https://github.com/mozilla/geckodriver/releases/download/"
            segments = (f"v{self.version}", self.file_name)
        elif self.driver_type == WebDriverType.EDGE:
            base_url = "https://msedgedriver.microsoft.com/"
            segments = (self.version, self.file_name)
        else:
            raise ValueError(f"不受支持的 web driver 类型 : {self.driver_type}")
        return base_url + "/".join(f"{segment}" for segment in segments)
class WebDriverDownloader:
    """
    Base class for WebDriver downloaders.

    Downloads the driver archive, resuming a partial file when the server
    honours HTTP ranges. It then extracts the driver executable into the
    download directory and deletes everything else there.

    Args:
        driver_type (WebDriverType): Web driver type
        driver_version (str): WebDriver version
        driver_arch (WebDriverArch): WebDriver architecture (a member of the
            matching nested per-browser enum, e.g. ``WebDriverArch.Chrome.WINX86_64``)
        download_dir (str): Root download directory; files are stored under
            ``<download_dir>/<type>/<version>/<arch>/``
    """

    def __init__(
        self,
        driver_type: WebDriverType,
        driver_version: str,
        driver_arch: WebDriverArch,
        download_dir: str
    ):
        self.driver_type = driver_type
        self.arch = driver_arch
        self.version = driver_version
        self.download_url = str(WebDriverURL(self.version, self.driver_type, self.arch))
        # One directory per (type, version, arch) so drivers never overwrite each other.
        self.download_dir = Path(download_dir)/self.driver_type.value/self.version/self.arch.value
        self.download_dir.mkdir(mode=0o0755, parents=True, exist_ok=True)
        self.download_path = self.download_dir/str(WebDriverFileName(self.version, self.driver_type, self.arch))

    @staticmethod
    def _report(
        progress_callback: Optional[Callable[[float, int, float, str], None]],
        progress: float,
        speed: float,
        message: str
    ) -> None:
        """Send ``(progress, 100, speed, message)`` to the callback, if one was given."""
        if progress_callback is not None:
            progress_callback(progress, 100, speed, message)

    def download(
        self,
        progress_callback: Optional[Callable[[float, int, float, str], None]] = None,
        cancel_event: Optional[threading.Event] = None
    ) -> Optional[Path]:
        """
        Download, verify and extract the web driver.

        Download progress is scaled to 0-98; verification and extraction then
        report 98 and finally 100.

        Args:
            progress_callback: Optional callable receiving
                ``(progress, 100, speed_kib_per_s, message)``.
            cancel_event: Optional event; once set, the download stops and
                ``None`` is returned.

        Returns:
            Path of the extracted driver executable, or ``None`` if the download
            was cancelled, verification failed or extraction failed.

        Raises:
            Exception: the error of the last download attempt once all
                retries are exhausted.
        """
        if not self._download(progress_callback, cancel_event=cancel_event):
            return None
        if not self._verify(progress_callback):
            self._report(progress_callback, 0, 0.0, "验证失败")
            return None
        driver_path = self._extract(progress_callback)
        if not driver_path:
            self._report(progress_callback, 0, 0.0, "解压失败")
            return None
        return driver_path

    def _open_response(
        self,
        headers: dict
    ) -> tuple:
        """
        Start the streaming GET request, asking to resume an existing partial file.

        Returns:
            ``(response, downloaded_size, mode)``: the streaming response, the
            number of bytes already on disk that the body continues, and the
            file mode (``'ab'`` or ``'wb'``) to write the body with.
        """
        if self.download_path.exists():
            downloaded_size = self.download_path.stat().st_size
            response = requests.get(
                self.download_url,
                headers={**headers, 'Range': f"bytes={downloaded_size}-"},
                stream=True,
                timeout=10
            )
            if response.status_code == 206:
                return response, downloaded_size, 'ab'
            if response.status_code == 200:
                # The server ignored the Range header and sends the whole file,
                # so overwrite instead of appending it to the partial data.
                return response, 0, 'wb'
            # Range rejected (e.g. 416 for a complete or stale file): drop the
            # local copy and download from scratch.
            response.close()
            self.download_path.unlink()
        response = requests.get(self.download_url, headers=headers, stream=True, timeout=10)
        try:
            response.raise_for_status()
        except requests.HTTPError:
            response.close()
            raise
        return response, 0, 'wb'

    def _download(
        self,
        progress_callback: Optional[Callable[[float, int, float, str], None]] = None,
        max_retries: int = 3,
        cancel_event: Optional[threading.Event] = None
    ) -> bool:
        """
        Download the archive to ``self.download_path``, resuming partial files.

        Args:
            progress_callback: Optional progress callable (see ``download``).
            max_retries: Number of attempts before the last error is re-raised.
            cancel_event: Optional event that aborts the download when set.

        Returns:
            True once the whole file is on disk; False if cancelled.
        """
        CHUNK_SIZE = 8192*8  # 64KB chunk
        CALLBACK_INTERVAL = 0.1  # seconds between progress callbacks
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36',
            # Request the raw bytes. With a compressed content encoding, Range
            # offsets and Content-Length refer to the compressed stream, while
            # iter_content() yields decompressed data.
            'Accept-Encoding': 'identity'
        }
        for attempt in range(max_retries):
            try:
                if cancel_event and cancel_event.is_set():
                    return False
                response, downloaded_size, mode = self._open_response(headers)
                with response:
                    total_size = int(response.headers.get('Content-Length', 0))
                    if response.status_code == 206:
                        # Partial Content: Content-Length only counts the remaining bytes.
                        total_size += downloaded_size
                    last_callback_time = time.time()
                    last_callback_size = downloaded_size
                    with open(self.download_path, mode) as f:
                        for chunk in response.iter_content(CHUNK_SIZE):
                            if cancel_event and cancel_event.is_set():
                                return False
                            if not chunk:
                                continue
                            f.write(chunk)
                            downloaded_size += len(chunk)
                            if not progress_callback or total_size <= 0:
                                continue
                            current_time = time.time()
                            current_progress = (downloaded_size/total_size)*98.0
                            if current_time - last_callback_time >= CALLBACK_INTERVAL or current_progress >= 98.0:
                                elapsed = current_time - last_callback_time
                                # Speed in KiB/s since the previous callback.
                                speed = (downloaded_size - last_callback_size)/(elapsed*1024.0) if elapsed > 0 else 0.0
                                progress_callback(current_progress, 100, speed, "下载中...")
                                last_callback_time = current_time
                                last_callback_size = downloaded_size
                actual_size = self.download_path.stat().st_size
                if total_size > 0 and actual_size < total_size:
                    raise Exception(f"下载不完整 : {actual_size}/{total_size} 字节")
                return True
            except Exception:
                if cancel_event and cancel_event.is_set():
                    return False
                if attempt < max_retries - 1:
                    self._report(progress_callback, 0, 0.0, f"{attempt+1} 次重试...")
                    time.sleep(1)
                    continue
                raise
        return False

    def _verify(
        self,
        progress_callback: Optional[Callable[[float, int, float, str], None]] = None
    ) -> bool:
        """Verify the downloaded archive. Placeholder: no check is performed yet."""
        self._report(progress_callback, 98, 0.0, "验证完成")
        return True

    def _find_driver(
        self,
        expected_name: str
    ) -> Optional[Path]:
        """Find ``expected_name`` under the download dir and move it to the top level."""
        dst_path = self.download_dir/expected_name
        for root, _, files in os.walk(self.download_dir):
            if expected_name not in files:
                continue
            src_path = Path(root, expected_name)
            if src_path != dst_path:
                # replace() overwrites a stale driver left by an earlier run,
                # where rename() raises on Windows.
                src_path.replace(dst_path)
            return dst_path
        return None

    def _extract(
        self,
        progress_callback: Optional[Callable[[float, int, float, str], None]] = None
    ) -> Optional[Path]:
        """
        Extract the archive and keep only the driver executable.

        Returns:
            Path of the driver executable, or ``None`` if anything failed.
        """
        try:
            self._report(progress_callback, 98, 0.0, "解压中...")
            if self.download_path.name.endswith('.tar.gz'):
                with tarfile.open(self.download_path, 'r:gz') as tar_ref:
                    if hasattr(tarfile, 'data_filter'):
                        # Refuse absolute / '..' member paths; this also avoids
                        # the Python 3.12+ DeprecationWarning for a missing filter.
                        tar_ref.extractall(self.download_dir, filter='data')
                    else:
                        tar_ref.extractall(self.download_dir)
            else:
                with zipfile.ZipFile(self.download_path, 'r') as zip_ref:
                    zip_ref.extractall(self.download_dir)
            expected_name = str(WebDriverExecName(self.driver_type, self.arch))
            driver_file = self._find_driver(expected_name)
            if driver_file is None:
                raise FileNotFoundError(f"未找到 web driver 文件 : {expected_name}")
            # zipfile does not restore Unix permission bits, so mark the driver
            # executable (on Windows chmod only affects the read-only flag).
            driver_file.chmod(driver_file.stat().st_mode | 0o111)
            self._report(progress_callback, 100, 0.0, "解压完成")
            self.download_path.unlink()
            self._cleanup(driver_file)
            return driver_file
        except Exception:
            # Best effort: download() reports "解压失败" when None is returned.
            return None

    def _cleanup(
        self,
        driver_file: Path
    ) -> None:
        """Delete everything in the download directory except ``driver_file``."""
        for item in self.download_dir.iterdir():
            if item != driver_file:
                if item.is_dir():
                    shutil.rmtree(item)
                else:
                    item.unlink()
class ChromeDriverDownloader(WebDriverDownloader):
    """
    Chrome web driver downloader.

    Drivers come from the Chrome for Testing storage, so only driver
    versions newer than 114 are supported.

    Args:
        version (str): Chrome driver version
        arch (WebDriverArch): Expected to be a ``WebDriverArch.Chrome`` member
        download_dir (str): Root download directory
    """

    def __init__(
        self,
        version: str,
        arch: WebDriverArch,
        download_dir: str
    ):
        super().__init__(
            driver_type=WebDriverType.CHROME,
            driver_version=version,
            driver_arch=arch,
            download_dir=download_dir,
        )
class FirefoxDriverDownloader(WebDriverDownloader):
    """
    Firefox (geckodriver) web driver downloader.

    This class does not resolve the browser-to-driver version mapping:
    pass a geckodriver version directly. Only driver versions higher
    than 0.17.0 are supported.

    Args:
        version (str): geckodriver version, without the leading "v"
        arch (WebDriverArch): Expected to be a ``WebDriverArch.Firefox`` member
        download_dir (str): Root download directory
    """

    def __init__(
        self,
        version: str,
        arch: WebDriverArch,
        download_dir: str
    ):
        super().__init__(
            driver_type=WebDriverType.FIREFOX,
            driver_version=version,
            driver_arch=arch,
            download_dir=download_dir,
        )
class EdgeDriverDownloader(WebDriverDownloader):
    """
    Edge web driver downloader.

    Args:
        version (str): Edge driver version
        arch (WebDriverArch): Expected to be a ``WebDriverArch.Edge`` member
        download_dir (str): Root download directory
    """

    def __init__(
        self,
        version: str,
        arch: WebDriverArch,
        download_dir: str
    ):
        super().__init__(
            driver_type=WebDriverType.EDGE,
            driver_version=version,
            driver_arch=arch,
            download_dir=download_dir,
        )