Skip to content

Commit b6ab52a

Browse files
committed
Made the source arguments mutually exclusive — you must specify either --repo or --dir
Added a new crawl_local_files() function that mimics the interface of crawl_github_files(), and modified the FetchRepo node to handle both cases. The project name is now derived from either the repository name (from the GitHub URL) or the directory name (from the local path), or it can be specified manually with -n/--name. The tool uses the same file-pattern matching and size limits for both sources. All other functionality (generating abstractions, relationships, chapters, etc.) remains unchanged, since it works with the same file-list format.
1 parent 60b5467 commit b6ab52a

2 files changed

Lines changed: 114 additions & 25 deletions

File tree

main.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,27 @@
88

99
# Default file patterns
1010
DEFAULT_INCLUDE_PATTERNS = {
11-
"*.py", "*.js", "*.ts", "*.go", "*.java", "*.pyi", "*.pyx",
11+
"*.py", "*.js", "*.jsx", "*.ts", "*.tsx", "*.go", "*.java", "*.pyi", "*.pyx",
1212
"*.c", "*.cc", "*.cpp", "*.h", "*.md", "*.rst", "Dockerfile",
13-
"Makefile", "*.yaml", "*.yml"
13+
"Makefile", "*.yaml", "*.yml",
1414
}
1515

1616
DEFAULT_EXCLUDE_PATTERNS = {
1717
"*test*", "tests/*", "docs/*", "examples/*", "v1/*",
1818
"dist/*", "build/*", "experimental/*", "deprecated/*",
19-
"legacy/*", ".git/*", ".github/*"
19+
"legacy/*", ".git/*", ".github/*", ".next/*", ".vscode/*", "obj/*", "bin/*", "node_modules/*", "*.log"
2020
}
2121

2222
# --- Main Function ---
2323
def main():
24-
parser = argparse.ArgumentParser(description="Generate a tutorial for a GitHub codebase.")
25-
parser.add_argument("repo_url", help="URL of the public GitHub repository.")
26-
parser.add_argument("-n", "--name", help="Project name (optional, derived from URL if omitted).")
24+
parser = argparse.ArgumentParser(description="Generate a tutorial for a GitHub codebase or local directory.")
25+
26+
# Create mutually exclusive group for source
27+
source_group = parser.add_mutually_exclusive_group(required=True)
28+
source_group.add_argument("--repo", help="URL of the public GitHub repository.")
29+
source_group.add_argument("--dir", help="Path to local directory.")
30+
31+
parser.add_argument("-n", "--name", help="Project name (optional, derived from repo/directory if omitted).")
2732
parser.add_argument("-t", "--token", help="GitHub personal access token (optional, reads from GITHUB_TOKEN env var if not provided).")
2833
parser.add_argument("-o", "--output", default="output", help="Base directory for output (default: ./output).")
2934
parser.add_argument("-i", "--include", nargs="+", help="Include file patterns (e.g. '*.py' '*.js'). Defaults to common code files if not specified.")
@@ -32,14 +37,17 @@ def main():
3237

3338
args = parser.parse_args()
3439

35-
# Get GitHub token from argument or environment variable
36-
github_token = args.token or os.environ.get('GITHUB_TOKEN')
37-
if not github_token:
38-
print("Warning: No GitHub token provided. You might hit rate limits for public repositories.")
40+
# Get GitHub token from argument or environment variable if using repo
41+
github_token = None
42+
if args.repo:
43+
github_token = args.token or os.environ.get('GITHUB_TOKEN')
44+
if not github_token:
45+
print("Warning: No GitHub token provided. You might hit rate limits for public repositories.")
3946

4047
# Initialize the shared dictionary with inputs
4148
shared = {
42-
"repo_url": args.repo_url,
49+
"repo_url": args.repo,
50+
"local_dir": args.dir,
4351
"project_name": args.name, # Can be None, FetchRepo will derive it
4452
"github_token": github_token,
4553
"output_dir": args.output, # Base directory for CombineTutorial output
@@ -58,7 +66,7 @@ def main():
5866
"final_output_dir": None
5967
}
6068

61-
print(f"Starting tutorial generation for: {args.repo_url}")
69+
print(f"Starting tutorial generation for: {args.repo or args.dir}")
6270

6371
# Create the flow instance
6472
tutorial_flow = create_tutorial_flow()

nodes.py

Lines changed: 94 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,73 @@
11
import os
22
import yaml
3+
import fnmatch
34
from pocketflow import Node, BatchNode
45
from utils.crawl_github_files import crawl_github_files
56
from utils.call_llm import call_llm # Assuming you have this utility
67

8+
def crawl_local_files(directory, include_patterns=None, exclude_patterns=None, max_file_size=None, use_relative_paths=True):
    """
    Crawl files in a local directory with a similar interface as crawl_github_files.

    Args:
        directory (str): Path to the local directory to crawl.
        include_patterns (set): fnmatch patterns to include (e.g. {"*.py", "*.js"}).
            If None or empty, every file is included.
        exclude_patterns (set): fnmatch patterns to exclude (e.g. {"tests/*"}).
        max_file_size (int): Skip files larger than this many bytes. None disables
            the check; an explicit 0 is honored as a real (zero-byte) limit.
        use_relative_paths (bool): Key results (and match patterns) by the path
            relative to `directory` instead of the absolute/joined path.

    Returns:
        dict: {"files": {filepath: content}}

    Raises:
        ValueError: If `directory` is not an existing directory.
    """
    if not os.path.isdir(directory):
        raise ValueError(f"Directory does not exist: {directory}")

    files_dict = {}

    for root, _, files in os.walk(directory):
        for filename in files:
            filepath = os.path.join(root, filename)

            # Patterns are matched against the same path form that keys the result.
            relpath = os.path.relpath(filepath, directory) if use_relative_paths else filepath

            # Empty/None include_patterns means "include everything".
            included = (
                not include_patterns
                or any(fnmatch.fnmatch(relpath, pattern) for pattern in include_patterns)
            )
            excluded = bool(exclude_patterns) and any(
                fnmatch.fnmatch(relpath, pattern) for pattern in exclude_patterns
            )
            if not included or excluded:
                continue

            # Size check: `is not None` so an explicit 0 limit is honored
            # (a plain truthiness test would silently ignore max_file_size=0).
            # stat can fail on broken symlinks or permission races — skip the
            # file rather than aborting the whole crawl.
            try:
                if max_file_size is not None and os.path.getsize(filepath) > max_file_size:
                    continue
            except OSError as e:
                print(f"Warning: Could not stat file {filepath}: {e}")
                continue

            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    files_dict[relpath] = f.read()
            except Exception as e:
                # Binary or undecodable files are skipped with a warning.
                print(f"Warning: Could not read file {filepath}: {e}")

    return {"files": files_dict}
70+
771
# Helper to create context from files, respecting limits (basic example)
872
def create_llm_context(files_data):
973
context = ""
@@ -26,20 +90,26 @@ def get_content_for_indices(files_data, indices):
2690

2791
class FetchRepo(Node):
2892
def prep(self, shared):
29-
repo_url = shared["repo_url"]
93+
repo_url = shared.get("repo_url")
94+
local_dir = shared.get("local_dir")
3095
project_name = shared.get("project_name")
96+
3197
if not project_name:
32-
# Basic name derivation from URL
33-
project_name = repo_url.split('/')[-1].replace('.git', '')
98+
# Basic name derivation from URL or directory
99+
if repo_url:
100+
project_name = repo_url.split('/')[-1].replace('.git', '')
101+
else:
102+
project_name = os.path.basename(os.path.abspath(local_dir))
34103
shared["project_name"] = project_name
35104

36-
# Get file patterns directly from shared (defaults are defined in main.py)
105+
# Get file patterns directly from shared
37106
include_patterns = shared["include_patterns"]
38107
exclude_patterns = shared["exclude_patterns"]
39108
max_file_size = shared["max_file_size"]
40109

41110
return {
42111
"repo_url": repo_url,
112+
"local_dir": local_dir,
43113
"token": shared.get("github_token"),
44114
"include_patterns": include_patterns,
45115
"exclude_patterns": exclude_patterns,
@@ -48,15 +118,26 @@ def prep(self, shared):
48118
}
49119

50120
def exec(self, prep_res):
51-
print(f"Crawling repository: {prep_res['repo_url']}...")
52-
result = crawl_github_files(
53-
repo_url=prep_res["repo_url"],
54-
token=prep_res["token"],
55-
include_patterns=prep_res["include_patterns"],
56-
exclude_patterns=prep_res["exclude_patterns"],
57-
max_file_size=prep_res["max_file_size"],
58-
use_relative_paths=prep_res["use_relative_paths"]
59-
)
121+
if prep_res["repo_url"]:
122+
print(f"Crawling repository: {prep_res['repo_url']}...")
123+
result = crawl_github_files(
124+
repo_url=prep_res["repo_url"],
125+
token=prep_res["token"],
126+
include_patterns=prep_res["include_patterns"],
127+
exclude_patterns=prep_res["exclude_patterns"],
128+
max_file_size=prep_res["max_file_size"],
129+
use_relative_paths=prep_res["use_relative_paths"]
130+
)
131+
else:
132+
print(f"Crawling directory: {prep_res['local_dir']}...")
133+
result = crawl_local_files(
134+
directory=prep_res["local_dir"],
135+
include_patterns=prep_res["include_patterns"],
136+
exclude_patterns=prep_res["exclude_patterns"],
137+
max_file_size=prep_res["max_file_size"],
138+
use_relative_paths=prep_res["use_relative_paths"]
139+
)
140+
60141
# Convert dict to list of tuples: [(path, content), ...]
61142
files_list = list(result.get("files", {}).items())
62143
print(f"Fetched {len(files_list)} files.")

0 commit comments

Comments
 (0)