# Copyright (c) OpenMMLab. All rights reserved.
import functools
import hashlib
import importlib
import os
import os.path as osp
import pkg_resources
import re
import subprocess
import tarfile
import typing
from collections import defaultdict
from email.parser import FeedParser
from pkg_resources import get_distribution, parse_version
from typing import Any, List, Optional, Sequence, Tuple, Union

import click
import requests
from pip._vendor.packaging import version
from requests.exceptions import InvalidURL, RequestException, Timeout
from requests.models import Response

from .default import PKG2PROJECT
from .progress_bars import rich_progress_bar

def parse_url(url: str) -> Tuple[str, str]:
    """Parse username and repo from url.

    SSH-style urls (starting with ``git@``) are split after the last ``:``;
    any other url containing ``git`` is split on ``/`` and its last two
    components are used. A trailing ``/`` is ignored and only a trailing
    ``.git`` suffix is removed from the repo name, so repo names that contain
    dots are kept intact.

    Args:
        url (str): Url for parsing username and repo name.

    Returns:
        Tuple[str, str]: The username and the repo name.

    Raises:
        ValueError: If ``url`` neither starts with ``git@`` nor contains
            ``git``.

    Example:
        >>> parse_url('git@github.com:open-mmlab/mmcv.git')
        ('open-mmlab', 'mmcv')
        >>> parse_url('https://github.com/open-mmlab/mmcv.git')
        ('open-mmlab', 'mmcv')
    """
    # A trailing slash would otherwise produce an empty last component and
    # shift the (username, repo) pair by one position.
    cleaned = url.rstrip('/')
    if cleaned.startswith('git@'):
        res = cleaned.split(':')[-1].split('/')
    elif 'git' in cleaned:
        res = cleaned.split('/')[-2:]
    else:
        raise ValueError(highlighted_error(f'{url} is invalid.'))

    username, repo = res[0], res[1]
    # Strip only the ".git" suffix instead of cutting at the first dot.
    if repo.endswith('.git'):
        repo = repo[:-len('.git')]
    return username, repo
def is_installed(package: str) -> bool:
    """Check whether a package is installed.

    Args:
        package (str): Name of package to be checked.

    Returns:
        bool: True if ``get_distribution`` finds the package, False if it
        raises ``DistributionNotFound``.
    """
    # Reload pkg_resources before the lookup so the check is made against a
    # refreshed module.
    importlib.reload(pkg_resources)
    try:
        get_distribution(package)
    except pkg_resources.DistributionNotFound:
        return False
    return True
def ensure_installation(func):
    """Decorator that makes sure a package has been installed.

    The decorated function is expected to take the package name as its only
    argument. The check runs on every call, before ``func`` is invoked.

    Args:
        func (callable): Function taking a single ``package`` argument.

    Returns:
        callable: Wrapper that raises ``RuntimeError`` when the package is
        not installed and otherwise returns ``func(package)``.
    """

    @functools.wraps(func)
    def checked(package):
        if is_installed(package):
            return func(package)
        raise RuntimeError(highlighted_error(f'{package} is not installed.'))

    return checked
@ensure_installation
def parse_home_page(package: str) -> Optional[str]:
    """Parse home page from package metadata.

    Args:
        package (str): Package to parse home page.

    Returns:
        Optional[str]: Value of the ``home-page`` field of the ``METADATA``
        file, or None when the distribution has no ``METADATA`` or the field
        is absent.
    """
    dist = get_distribution(package)
    if not dist.has_metadata('METADATA'):
        return None

    # METADATA uses an email-header like layout, so the email feed parser
    # is used to read its fields.
    parser = FeedParser()
    parser.feed(dist.get_metadata('METADATA'))
    message = parser.close()
    return message.get('home-page')
def get_github_url(package: str) -> str:
    """Get github url.

    The home page is read from the metadata of the installed package first.
    When that yields nothing, the ``home_page`` entry of the PyPI package
    info is tried; any error raised by that lookup is ignored. The home page
    is then turned into a url ending with ``.git``.

    Args:
        package (str): Name of package, like mmcls.

    Returns:
        str: The home page itself if it already ends with ``.git``; the home
        page with ``.com`` replaced by ``.git`` if it ends with ``.com``;
        otherwise the home page with ``.git`` appended.

    Raises:
        ValueError: If no home page could be found.
    """
    home_page = parse_home_page(package) if is_installed(package) else None

    if not home_page:
        # Best-effort fallback: a failing PyPI lookup must not mask the
        # ValueError raised below.
        try:
            home_page = get_package_info_from_pypi(package)['info'].get(
                'home_page')
        except Exception:
            pass

    if not home_page:
        raise ValueError(
            highlighted_error(f'Failed to get github url of {package}.'))

    if home_page.endswith('.git'):
        return home_page
    if home_page.endswith('.com'):
        return home_page.replace('.com', '.git')
    return home_page + '.git'
def get_content_from_url(url: str,
                         timeout: int = 15,
                         stream: bool = False,
                         check_certificate: bool = True) -> Response:
    """Get content from url.

    Args:
        url (str): Url for getting content.
        timeout (int): Set the socket timeout. Default: 15.
        stream (bool): Passed to ``requests.get`` as ``stream``.
            Default: False.
        check_certificate (bool): Whether to check the ssl certificate.
            Default: True.

    Returns:
        Response: The response returned by ``requests.get``.

    Raises:
        RuntimeError: If the request fails for any reason. The original
            exception is attached as ``__cause__``.
    """
    try:
        return requests.get(
            url, timeout=timeout, stream=stream, verify=check_certificate)
    except Exception as err:
        # highlighted_error returns a string and a string cannot be raised,
        # so the highlighted message is wrapped in a real exception.
        raise RuntimeError(highlighted_error(err)) from err
@typing.no_type_check
def download_from_file(url: str,
                       dest_path: str,
                       hash_prefix: Optional[str] = None,
                       check_certificate: bool = True) -> None:
    """Download object at the given URL to a local path.

    The response is streamed in chunks of 1024 bytes and written to
    ``dest_path`` while a progress bar is shown.

    Args:
        url (str): URL of the object to download.
        dest_path (str): Path where object will be saved.
        hash_prefix (string, optional): If not None, the SHA256 of the
            downloaded file should start with ``hash_prefix``. Default: None.
        check_certificate (bool): Whether to check the ssl certificate.
            Default: True.

    Raises:
        RuntimeError: If ``hash_prefix`` is given and the SHA256 digest of
            the downloaded content does not start with it.
    """
    verify_hash = hash_prefix is not None
    if verify_hash:
        hasher = hashlib.sha256()

    response = get_content_from_url(
        url, stream=True, check_certificate=check_certificate)
    # NOTE(review): int(None) raises TypeError when the response carries no
    # content-length header -- confirm the servers in use always send it.
    total_size = int(response.headers.get('content-length'))

    with open(dest_path, 'wb') as out_file:
        chunks = response.iter_content(chunk_size=1024)
        for chunk in rich_progress_bar(chunks, size=total_size):
            if not chunk:
                continue
            out_file.write(chunk)
            out_file.flush()
            if verify_hash:
                hasher.update(chunk)

    if verify_hash:
        digest = hasher.hexdigest()
        if digest[:len(hash_prefix)] != hash_prefix:
            raise RuntimeError(
                highlighted_error(
                    f'invalid hash value, expected "{hash_prefix}", but got '
                    f'"{digest}"'))
def split_package_version(package: str) -> Tuple[str, ...]:
    """Split the package which maybe contains version info.

    The string is split on every run of one or more ``=`` characters.

    Args:
        package (str): Name of package to split.

    Returns:
        Tuple[str, ...]: ``(package, '')`` when no ``=`` is present,
        otherwise the pieces between the ``=`` runs.

    Example:
        >>> split_package_version('mmcls')
        ('mmcls', '')
        >>> split_package_version('mmcls=0.11.0')
        ('mmcls', '0.11.0')
        >>> split_package_version('mmcls==0.11.0')
        ('mmcls', '0.11.0')
    """
    if '=' not in package:
        return package, ''
    pieces = re.split(r'=+', package)
    return tuple(pieces)
def get_package_version(repo_root: str) -> Tuple[str, str]:
    """Get package and version from local repo.

    Every entry of ``repo_root`` is probed, in sorted order, for a
    ``version.py`` file. The first one found is executed and the value it
    assigns to ``__version__`` is returned together with the entry name.

    Args:
        repo_root (str): Directory of repo.

    Returns:
        Tuple[str, str]: ``(entry_name, version)`` of the first entry that
        contains ``version.py``, or ``('', '')`` if there is none.

    Raises:
        KeyError: If the executed ``version.py`` does not define
            ``__version__``.
    """
    for file_name in sorted(os.listdir(repo_root)):
        version_path = osp.join(repo_root, file_name, 'version.py')
        if not osp.exists(version_path):
            continue
        with open(version_path, encoding='utf-8') as f:
            source = f.read()
        # Execute into an explicit namespace: reading the result back via
        # locals() is not reliable because locals() may return an
        # independent snapshot of the function scope.
        # NOTE: this runs the file's code; only use it on trusted repos.
        namespace: dict = {}
        exec(compile(source, version_path, 'exec'), namespace)
        return file_name, namespace['__version__']
    return '', ''
@ensure_installation
def get_installed_version(package: str) -> str:
    """Get the version of package from local environment.

    Args:
        package (str): Name of package.

    Returns:
        str: The ``version`` attribute of the distribution found by
        ``get_distribution``.
    """
    dist = get_distribution(package)
    return dist.version
def get_package_info_from_pypi(package: str, timeout: int = 15) -> dict:
    """Get package information from pypi.

    The information is fetched from the JSON endpoint of PyPI,
    ``https://pypi.org/pypi/<package>/json``.

    Args:
        package (str): Package to get information.
        timeout (int): Set the socket timeout. Default: 15.

    Returns:
        dict: The decoded JSON body of the response.
    """
    # The url needs an explicit scheme and host to be requestable.
    pkg_url = f'https://pypi.org/pypi/{package}/json'
    response = get_content_from_url(pkg_url, timeout=timeout)
    return response.json()
def get_release_version(package: str, timeout: int = 15) -> List[str]:
    """Get release version from pypi.

    The returned list of versions is sorted in ascending order, using
    ``parse_version`` as the sort key.

    Args:
        package (str): Package to get version.
        timeout (int): Set the socket timeout. Default: 15.

    Returns:
        List[str]: The entries of ``releases`` in the PyPI package info,
        sorted from oldest to newest.
    """
    releases = get_package_info_from_pypi(package, timeout)['releases']
    versions = list(releases)
    versions.sort(key=parse_version)
    return versions
def get_latest_version(package: str, timeout: int = 15) -> str:
    """Get latest version of package.

    Args:
        package (str): Package to get latest version.
        timeout (int): Set the socket timeout. Default: 15.

    Returns:
        str: The last element of the ascending list returned by
        ``get_release_version``.

    Example:
        >>> get_latest_version('mmcv-full')  # result depends on PyPI
        '0.11.0'
    """
    versions = get_release_version(package, timeout)
    newest = versions[-1]
    return newest
def is_version_equal(version1: str, version2: str) -> bool:
    """Check whether two version strings denote the same version.

    Both strings are parsed with ``version.parse`` and the parsed objects
    are compared, rather than the raw text.

    Args:
        version1 (str): The first version string.
        version2 (str): The second version string.

    Returns:
        bool: True if the parsed versions compare equal.
    """
    parsed_first = version.parse(version1)
    parsed_second = version.parse(version2)
    return parsed_first == parsed_second
@ensure_installation
def package2module(package: str):
    """Infer module name from package.

    The module name is taken from the first line of the ``top_level.txt``
    metadata of the installed distribution.

    Args:
        package (str): Package to infer module name.

    Returns:
        str: The first line of ``top_level.txt``.

    Raises:
        ValueError: If the distribution has no ``top_level.txt`` metadata.
    """
    dist = get_distribution(package)
    if not dist.has_metadata('top_level.txt'):
        raise ValueError(
            highlighted_error(f'can not infer the module name of {package}'))
    top_level = dist.get_metadata('top_level.txt')
    return top_level.split('\n')[0]
@ensure_installation
def get_installed_path(package: str) -> str:
    """Get installed path of package.

    Args:
        package (str): Name of package.

    Returns:
        str: ``<location>/<package>`` if that path exists, otherwise
        ``<location>/<module>`` with the module name inferred by
        ``package2module``.

    Example:
        >>> get_installed_path('mmcls')
        '.../lib/python3.7/site-packages/mmcls'
    """
    dist = get_distribution(package)
    candidate = osp.join(dist.location, package)
    if not osp.exists(candidate):
        # The package name may differ from the module name, e.g. mmcv-full
        # is the package name while mmcv is the module name. In that case
        # the location is joined with the inferred module name instead.
        candidate = osp.join(dist.location, package2module(package))
    return candidate
def get_torch_cuda_version() -> Tuple[str, str]:
    """Get PyTorch version and CUDA version if it is available.

    Returns:
        Tuple[str, str]: The PyTorch version without anything following a
        ``+``, and the CUDA version with the dots removed, or ``'cpu'`` when
        ``torch.version.cuda`` is None.

    Raises:
        ImportError: If torch cannot be imported.

    Example:
        >>> get_torch_cuda_version()
        ('1.8.0', '102')
    """
    # Imported lazily; an ImportError propagates to the caller unchanged.
    import torch

    torch_v = torch.__version__
    if '+' in torch_v:
        # 1.8.1+cu111 -> 1.8.1
        torch_v = torch_v.split('+')[0]

    cuda = torch.version.cuda
    # torch.version.cuda like 10.2 -> 102
    cuda_v = 'cpu' if cuda is None else cuda.replace('.', '')
    return torch_v, cuda_v
def cast2lowercase(input: Union[list, tuple, str]) -> Any:
    """Cast input into lowercase.

    Args:
        input (list | tuple | str): A string, or an iterable of strings.

    Returns:
        A lowercased string for a string input, a tuple of lowercased
        strings for a tuple input, and a list of lowercased strings
        otherwise.

    Example:
        >>> cast2lowercase('Hello World')
        'hello world'
        >>> cast2lowercase(['Hello', 'World'])
        ['hello', 'world']
    """
    if isinstance(input, str):
        return input.lower()
    lowered = [element.lower() for element in input]
    return tuple(lowered) if isinstance(input, tuple) else lowered
def recursively_find(root: str, base_name: str, followlinks=False) -> list:
    """Recursively list a directory, return all files with a given base_name.

    Args:
        root (str): The root directory to list.
        base_name (str): The base_name.
        followlinks (bool): Follow symbolic links. Defaults to False.

    Return:
        list: Paths of the files named ``base_name``, one per directory that
        contains such a file, in the order ``os.walk`` visits them.
    """
    return [
        osp.join(current_dir, base_name)
        for current_dir, _, file_names in os.walk(
            root, followlinks=followlinks)
        if base_name in file_names
    ]
def highlighted_error(msg: Union[str, Exception]) -> str:
    """Return ``msg`` styled as bold red text.

    Args:
        msg (str | Exception): The message to highlight. It is converted to
            a string before being styled.

    Returns:
        str: The styled message.
    """
    return click.style(str(msg), fg='red', bold=True)


def color_echo(msg: str, color: str) -> None:
    """Print ``msg`` in the given foreground color.

    Args:
        msg (str): The message to print.
        color (str): Foreground color passed to ``click.style``.
    """
    click.echo(click.style(str(msg), fg=color))


def echo_error(msg: Union[str, Exception]) -> None:
    """Print ``msg`` in red."""
    color_echo(msg=msg, color='red')  # type: ignore


def echo_warning(msg: Union[str, Exception]) -> None:
    """Print ``msg`` in yellow."""
    color_echo(msg=msg, color='yellow')  # type: ignore


def echo_success(msg: str) -> None:
    """Print ``msg`` in green."""
    color_echo(msg=msg, color='green')


def exit_with_error(msg: Union[str, Exception]) -> None:
    """Print ``msg`` in red and terminate with exit status 1.

    Raises:
        SystemExit: Always, with code 1.
    """
    echo_error(msg)
    # Raising SystemExit directly does not rely on the ``exit`` helper that
    # is only injected by the ``site`` module.
    raise SystemExit(1)


def call_command(cmd: list) -> None:
    """Run ``cmd`` with ``subprocess.check_call``.

    Args:
        cmd (list): The command and its arguments.

    Raises:
        RuntimeError: If running the command fails for any reason. The
            original exception is attached as ``__cause__``.
    """
    try:
        subprocess.check_call(cmd)
    except Exception as e:
        # highlighted_error returns a string and a string cannot be raised,
        # so the highlighted message is wrapped in a real exception.
        raise RuntimeError(highlighted_error(e)) from e
def string2args(text: str) -> dict:
    """Parse string to arguments.

    Args:
        text (str): The string to be parsed, which should be of the format:
            "--arg1 value1 value2 --arg2 value1 ... --argn value1".
            Using '=' is also OK, like "--argn=value1"; only the first '='
            separates the name from the value, so the value itself may
            contain '='. It also supports flag args like "--arg1".

    Return:
        A dictionary (a ``defaultdict(list)``) that contains parsed args.
        Values are lists of strings; a flag without any value is mapped to
        the ``bool`` type object.

    Example:
        >>> text = '--arg1 value1 value2 --arg2 value3 --arg3 value4'
        >>> string2args(text)
        args = {
            'arg1': [value1, value2],
            'arg2': [value3],
            'arg3': [value4]
        }
    """
    ret: dict = defaultdict(list)
    name = None
    for item in text.split():
        is_option = item.startswith('--')
        if name is None:
            # A value can only appear after an option that accepts values.
            assert is_option, f'expected an option, but got {item}'
        if not is_option:
            ret[name].append(item)
            continue

        # A new option starts: if the previous one got no value it is a flag.
        if name is not None and ret[name] == []:
            ret[name] = bool
        if '=' in item:
            # Split on the first '=' only, so "--opt=a=b" yields "a=b".
            name, value = item[2:].split('=', 1)
            ret[name] = [value]
            name = None
        else:
            name = item[2:]

    # The last option may be a flag as well.
    if name is not None and ret[name] == []:
        ret[name] = bool
    return ret
def args2string(args: dict) -> str:
    """Convert args dictionary to a string.

    Each key is emitted as ``--key`` followed by its values converted with
    ``str``. A key whose value is the ``bool`` type object is emitted as a
    bare flag.

    Args:
        args (dict): A dictionary that contains parsed args.

    Return:
        A converted string.

    Example:
        >>> args = {
            'arg1': [value1, value2],
            'arg2': [value3],
            'arg3': [value4]
        }
        >>> args2string(args)
        '--arg1 value1 value2 --arg2 value3 --arg3 value4'
    """
    tokens = []
    for key, values in args.items():
        tokens.append(f'--{key}')
        if values is bool:
            continue
        tokens.extend(str(value) for value in values)
    return ' '.join(tokens)
def get_config(cfg, name):
    """Given the argument name, read the value from the config file.

    The name can be multi-level, with the levels separated by ``.``. When
    the current level is a non-string sequence, the name component is used
    as an integer index; otherwise it is used as a key.

    Args:
        cfg: The config to read from.
        name (str): The (possibly multi-level) argument name.

    Returns:
        The value found at ``name``.
    """
    node = cfg
    suffix = ''
    for item in name.split('.'):
        indexable = isinstance(node, Sequence) and not isinstance(node, str)
        if indexable:
            node = node[int(item)]
            continue
        assert item in node, f'attribute {item} not cfg{suffix}'
        node = node[item]
        suffix += f'.{item}'
    return node
def set_config(cfg, name, value):
    """Given the argument name and value, set the value of the config file.

    The name can be multi-level, with the levels separated by ``.``. When a
    level is a non-string sequence, the name component is used as an integer
    index; otherwise it is used as a key, which must already exist. The
    config is modified in place.

    Args:
        cfg: The config to modify.
        name (str): The (possibly multi-level) argument name.
        value: The value to store at ``name``.
    """
    *parents, last = name.split('.')
    suffix = ''
    for item in parents:
        if isinstance(cfg, Sequence) and not isinstance(cfg, str):
            cfg = cfg[int(item)]
        else:
            assert item in cfg, f'attribute {item} not cfg{suffix}'
            cfg = cfg[item]
            suffix += f'.{item}'

    # The last level is handled like the intermediate ones, so an element
    # of a sequence can be replaced as well.
    if isinstance(cfg, Sequence) and not isinstance(cfg, str):
        cfg[int(last)] = value
    else:
        assert last in cfg, f'attribute {last} not cfg{suffix}'
        cfg[last] = value
def extract_tar(tar_path: str, dst: str) -> None:
    """Extract file from tar.

    Args:
        tar_path (str): Path of the tar file to extract.
        dst (str): Destination directory to save the extracted files.
    """
    assert tarfile.is_tarfile(tar_path), f'{tar_path} is an invalid path.'
    with tarfile.open(tar_path, 'r') as tar_file:
        if hasattr(tarfile, 'data_filter'):
            # Use the 'data' extraction filter when the interpreter offers
            # it, so members cannot be written outside of dst.
            tar_file.extractall(dst, filter='data')
        else:
            # NOTE: without extraction filters member paths are trusted
            # as-is; only extract archives from trusted sources.
            tar_file.extractall(dst)
def module_full_name(abbr: str) -> str:
    """Get the full name of the module given abbreviation.

    Args:
        abbr (str): The abbreviation, should be the sub-string of one (and
            only one) supported module.

    Return:
        str: The single name in ``PKG2PROJECT`` that contains ``abbr``. If
        there is not exactly one such name, ``abbr`` itself is returned when
        it is one of the matching names or is an installed package;
        otherwise an empty string is returned.
    """
    matches = [project for project in PKG2PROJECT if abbr in project]
    if len(matches) == 1:
        return matches[0]
    if abbr in matches or is_installed(abbr):
        return abbr
    return ''
