|
13 | 13 | import subprocess |
14 | 14 | import sys |
15 | 15 | import typing as t |
16 | | -from collections.abc import Callable |
17 | 16 | from copy import deepcopy |
18 | 17 | from dataclasses import dataclass |
19 | 18 | from datetime import datetime |
20 | 19 | from io import StringIO |
21 | 20 | from time import perf_counter |
22 | 21 |
|
| 22 | +from libvcs._internal.run import ProgressCallbackProtocol |
23 | 23 | from libvcs._internal.shortcuts import create_project |
24 | 24 | from libvcs._internal.types import VCSLiteral |
25 | 25 | from libvcs.sync.git import GitSync |
|
48 | 48 |
|
49 | 49 | log = logging.getLogger(__name__) |
50 | 50 |
|
51 | | -ProgressCallback = Callable[[str, datetime], None] |
| 51 | +ProgressCallback: t.TypeAlias = ProgressCallbackProtocol |
| 52 | + |
| 53 | + |
| 54 | +class RepoPayloadBase(t.TypedDict): |
| 55 | + """Keyword arguments used to create a repo via libvcs.""" |
| 56 | + |
| 57 | + url: str |
| 58 | + path: str | os.PathLike[str] |
| 59 | + progress_callback: ProgressCallback | None |
| 60 | + |
| 61 | + |
| 62 | +class GitRepoPayload(RepoPayloadBase): |
| 63 | + """Keyword arguments for git repositories.""" |
| 64 | + |
| 65 | + vcs: t.Literal["git"] |
| 66 | + |
| 67 | + |
| 68 | +class HgRepoPayload(RepoPayloadBase): |
| 69 | + """Keyword arguments for Mercurial repositories.""" |
| 70 | + |
| 71 | + vcs: t.Literal["hg"] |
| 72 | + |
| 73 | + |
| 74 | +class SvnRepoPayload(RepoPayloadBase): |
| 75 | + """Keyword arguments for Subversion repositories.""" |
| 76 | + |
| 77 | + vcs: t.Literal["svn"] |
| 78 | + |
| 79 | + |
| 80 | +class RepoPayload(t.TypedDict): |
| 81 | + """Keyword arguments used to create a repo via libvcs.""" |
| 82 | + |
| 83 | + url: str |
| 84 | + path: str | os.PathLike[str] |
| 85 | + vcs: VCSLiteral | None |
| 86 | + progress_callback: ProgressCallback | None |
52 | 87 |
|
53 | 88 |
|
54 | 89 | PLAN_SYMBOLS: dict[PlanAction, str] = { |
@@ -836,28 +871,39 @@ def __init__(self, repo_url: str) -> None: |
836 | 871 |
|
837 | 872 |
|
838 | 873 | def update_repo( |
839 | | - repo_dict: t.Any, |
| 874 | + repo_dict: ConfigDict, |
840 | 875 | progress_callback: ProgressCallback | None = None, |
841 | 876 | # repo_dict: Dict[str, Union[str, Dict[str, GitRemote], pathlib.Path]] |
842 | 877 | ) -> GitSync: |
843 | 878 | """Synchronize a single repository.""" |
844 | | - repo_dict = deepcopy(repo_dict) |
845 | | - if "pip_url" not in repo_dict: |
846 | | - repo_dict["pip_url"] = repo_dict.pop("url") |
847 | | - if "url" not in repo_dict: |
848 | | - repo_dict["url"] = repo_dict.pop("pip_url") |
| 879 | + repo_payload = t.cast("dict[str, object]", deepcopy(repo_dict)) |
| 880 | + if "pip_url" not in repo_payload: |
| 881 | + repo_payload["pip_url"] = repo_payload.pop("url") |
| 882 | + if "url" not in repo_payload: |
| 883 | + repo_payload["url"] = repo_payload.pop("pip_url") |
| 884 | + |
| 885 | + repo_payload["progress_callback"] = progress_callback or progress_cb |
| 886 | + |
| 887 | + repo_url = t.cast("str", repo_payload["url"]) |
| 888 | + repo_vcs = t.cast("VCSLiteral | None", repo_payload.get("vcs")) |
| 889 | + if repo_vcs is None: |
| 890 | + vcs = guess_vcs(url=repo_url) |
| 891 | + if vcs is None: |
| 892 | + raise CouldNotGuessVCSFromURL(repo_url=repo_url) |
849 | 893 |
|
850 | | - repo_dict["progress_callback"] = progress_callback or progress_cb |
| 894 | + repo_payload["vcs"] = vcs |
| 895 | + repo_vcs = vcs |
851 | 896 |
|
852 | | - if repo_dict.get("vcs") is None: |
853 | | - vcs = guess_vcs(url=repo_dict["url"]) |
854 | | - if vcs is None: |
855 | | - raise CouldNotGuessVCSFromURL(repo_url=repo_dict["url"]) |
| 897 | + assert repo_vcs is not None |
856 | 898 |
|
857 | | - repo_dict["vcs"] = vcs |
| 899 | + if repo_vcs == "git": |
| 900 | + r = create_project(**t.cast("GitRepoPayload", repo_payload)) |
| 901 | + elif repo_vcs == "svn": |
| 902 | + r = t.cast("GitSync", create_project(**t.cast("SvnRepoPayload", repo_payload))) |
| 903 | + else: |
| 904 | + r = t.cast("GitSync", create_project(**t.cast("HgRepoPayload", repo_payload))) |
858 | 905 |
|
859 | | - r = create_project(**repo_dict) # Creates the repo object |
860 | 906 | r.update_repo(set_remotes=True) # Creates repo if not exists and fetches |
861 | 907 |
|
862 | 908 | # TODO: Fix this |
863 | | - return r # type:ignore |
| 909 | + return r |
0 commit comments