Skip to content

Commit 9bd29cd

Browse files
authored
Merge pull request #10 from SpeedOfSpin/main
Made the source argument mutually exclusive - you must specify either --repo or --dir.
2 parents 60b5467 + b6ab52a commit 9bd29cd

File tree

2 files changed

+114
-25
lines changed

2 files changed

+114
-25
lines changed

main.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,27 @@
88

99
# Default file patterns
1010
DEFAULT_INCLUDE_PATTERNS = {
11-
"*.py", "*.js", "*.ts", "*.go", "*.java", "*.pyi", "*.pyx",
11+
"*.py", "*.js", "*.jsx", "*.ts", "*.tsx", "*.go", "*.java", "*.pyi", "*.pyx",
1212
"*.c", "*.cc", "*.cpp", "*.h", "*.md", "*.rst", "Dockerfile",
13-
"Makefile", "*.yaml", "*.yml"
13+
"Makefile", "*.yaml", "*.yml",
1414
}
1515

1616
DEFAULT_EXCLUDE_PATTERNS = {
1717
"*test*", "tests/*", "docs/*", "examples/*", "v1/*",
1818
"dist/*", "build/*", "experimental/*", "deprecated/*",
19-
"legacy/*", ".git/*", ".github/*"
19+
"legacy/*", ".git/*", ".github/*", ".next/*", ".vscode/*", "obj/*", "bin/*", "node_modules/*", "*.log"
2020
}
2121

2222
# --- Main Function ---
2323
def main():
24-
parser = argparse.ArgumentParser(description="Generate a tutorial for a GitHub codebase.")
25-
parser.add_argument("repo_url", help="URL of the public GitHub repository.")
26-
parser.add_argument("-n", "--name", help="Project name (optional, derived from URL if omitted).")
24+
parser = argparse.ArgumentParser(description="Generate a tutorial for a GitHub codebase or local directory.")
25+
26+
# Create mutually exclusive group for source
27+
source_group = parser.add_mutually_exclusive_group(required=True)
28+
source_group.add_argument("--repo", help="URL of the public GitHub repository.")
29+
source_group.add_argument("--dir", help="Path to local directory.")
30+
31+
parser.add_argument("-n", "--name", help="Project name (optional, derived from repo/directory if omitted).")
2732
parser.add_argument("-t", "--token", help="GitHub personal access token (optional, reads from GITHUB_TOKEN env var if not provided).")
2833
parser.add_argument("-o", "--output", default="output", help="Base directory for output (default: ./output).")
2934
parser.add_argument("-i", "--include", nargs="+", help="Include file patterns (e.g. '*.py' '*.js'). Defaults to common code files if not specified.")
@@ -32,14 +37,17 @@ def main():
3237

3338
args = parser.parse_args()
3439

35-
# Get GitHub token from argument or environment variable
36-
github_token = args.token or os.environ.get('GITHUB_TOKEN')
37-
if not github_token:
38-
print("Warning: No GitHub token provided. You might hit rate limits for public repositories.")
40+
# Get GitHub token from argument or environment variable if using repo
41+
github_token = None
42+
if args.repo:
43+
github_token = args.token or os.environ.get('GITHUB_TOKEN')
44+
if not github_token:
45+
print("Warning: No GitHub token provided. You might hit rate limits for public repositories.")
3946

4047
# Initialize the shared dictionary with inputs
4148
shared = {
42-
"repo_url": args.repo_url,
49+
"repo_url": args.repo,
50+
"local_dir": args.dir,
4351
"project_name": args.name, # Can be None, FetchRepo will derive it
4452
"github_token": github_token,
4553
"output_dir": args.output, # Base directory for CombineTutorial output
@@ -58,7 +66,7 @@ def main():
5866
"final_output_dir": None
5967
}
6068

61-
print(f"Starting tutorial generation for: {args.repo_url}")
69+
print(f"Starting tutorial generation for: {args.repo or args.dir}")
6270

6371
# Create the flow instance
6472
tutorial_flow = create_tutorial_flow()

nodes.py

Lines changed: 94 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,73 @@
11
import os
22
import yaml
3+
import fnmatch
34
from pocketflow import Node, BatchNode
45
from utils.crawl_github_files import crawl_github_files
56
from utils.call_llm import call_llm # Assuming you have this utility
67

8+
def crawl_local_files(directory, include_patterns=None, exclude_patterns=None, max_file_size=None, use_relative_paths=True):
    """
    Crawl files in a local directory with the same interface as crawl_github_files.

    Args:
        directory (str): Path to local directory.
        include_patterns (set): fnmatch patterns to include (e.g. {"*.py", "*.js"}).
            If None/empty, every file is considered included.
        exclude_patterns (set): fnmatch patterns to exclude (e.g. {"tests/*"}).
        max_file_size (int): Maximum file size in bytes; larger files are skipped.
            (A falsy value, e.g. None or 0, disables the size check.)
        use_relative_paths (bool): Whether dict keys are paths relative to `directory`.

    Returns:
        dict: {"files": {filepath: content}}

    Raises:
        ValueError: If `directory` is not an existing directory.
    """
    if not os.path.isdir(directory):
        raise ValueError(f"Directory does not exist: {directory}")

    files_dict = {}

    for root, _, files in os.walk(directory):
        for filename in files:
            filepath = os.path.join(root, filename)

            # Key used both for pattern matching and for the result dict.
            if use_relative_paths:
                relpath = os.path.relpath(filepath, directory)
            else:
                relpath = filepath

            # Include check: with no include patterns, everything is included.
            if include_patterns:
                included = any(fnmatch.fnmatch(relpath, pattern) for pattern in include_patterns)
            else:
                included = True

            # Exclude check: any matching exclude pattern drops the file.
            excluded = any(fnmatch.fnmatch(relpath, pattern) for pattern in (exclude_patterns or ()))

            if not included or excluded:
                continue

            # Size check and read are both inside the guard: getsize() can raise
            # OSError (broken symlink, file removed mid-walk), which previously
            # aborted the whole crawl instead of skipping the one bad file.
            try:
                if max_file_size and os.path.getsize(filepath) > max_file_size:
                    continue
                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read()
                files_dict[relpath] = content
            except Exception as e:
                # Best-effort crawl: report and keep going (e.g. binary files
                # that fail UTF-8 decoding, unreadable files).
                print(f"Warning: Could not read file {filepath}: {e}")

    return {"files": files_dict}
70+
771
# Helper to create context from files, respecting limits (basic example)
872
def create_llm_context(files_data):
973
context = ""
@@ -26,20 +90,26 @@ def get_content_for_indices(files_data, indices):
2690

2791
class FetchRepo(Node):
2892
def prep(self, shared):
29-
repo_url = shared["repo_url"]
93+
repo_url = shared.get("repo_url")
94+
local_dir = shared.get("local_dir")
3095
project_name = shared.get("project_name")
96+
3197
if not project_name:
32-
# Basic name derivation from URL
33-
project_name = repo_url.split('/')[-1].replace('.git', '')
98+
# Basic name derivation from URL or directory
99+
if repo_url:
100+
project_name = repo_url.split('/')[-1].replace('.git', '')
101+
else:
102+
project_name = os.path.basename(os.path.abspath(local_dir))
34103
shared["project_name"] = project_name
35104

36-
# Get file patterns directly from shared (defaults are defined in main.py)
105+
# Get file patterns directly from shared
37106
include_patterns = shared["include_patterns"]
38107
exclude_patterns = shared["exclude_patterns"]
39108
max_file_size = shared["max_file_size"]
40109

41110
return {
42111
"repo_url": repo_url,
112+
"local_dir": local_dir,
43113
"token": shared.get("github_token"),
44114
"include_patterns": include_patterns,
45115
"exclude_patterns": exclude_patterns,
@@ -48,15 +118,26 @@ def prep(self, shared):
48118
}
49119

50120
def exec(self, prep_res):
51-
print(f"Crawling repository: {prep_res['repo_url']}...")
52-
result = crawl_github_files(
53-
repo_url=prep_res["repo_url"],
54-
token=prep_res["token"],
55-
include_patterns=prep_res["include_patterns"],
56-
exclude_patterns=prep_res["exclude_patterns"],
57-
max_file_size=prep_res["max_file_size"],
58-
use_relative_paths=prep_res["use_relative_paths"]
59-
)
121+
if prep_res["repo_url"]:
122+
print(f"Crawling repository: {prep_res['repo_url']}...")
123+
result = crawl_github_files(
124+
repo_url=prep_res["repo_url"],
125+
token=prep_res["token"],
126+
include_patterns=prep_res["include_patterns"],
127+
exclude_patterns=prep_res["exclude_patterns"],
128+
max_file_size=prep_res["max_file_size"],
129+
use_relative_paths=prep_res["use_relative_paths"]
130+
)
131+
else:
132+
print(f"Crawling directory: {prep_res['local_dir']}...")
133+
result = crawl_local_files(
134+
directory=prep_res["local_dir"],
135+
include_patterns=prep_res["include_patterns"],
136+
exclude_patterns=prep_res["exclude_patterns"],
137+
max_file_size=prep_res["max_file_size"],
138+
use_relative_paths=prep_res["use_relative_paths"]
139+
)
140+
60141
# Convert dict to list of tuples: [(path, content), ...]
61142
files_list = list(result.get("files", {}).items())
62143
print(f"Fetched {len(files_list)} files.")

0 commit comments

Comments
 (0)