66import subprocess
77import tempfile
88from pathlib import Path
9- from typing import TYPE_CHECKING , TypedDict
9+ from typing import TYPE_CHECKING , Any , TypedDict
1010
1111import questionary
12+ from prompt_toolkit .application .current import get_app
1213
1314from commitizen import factory , git , out
1415from commitizen .cz .exceptions import CzException
2526 NothingToCommitError ,
2627)
2728from commitizen .git import smart_open
29+ from commitizen .interactive_preview import (
30+ make_length_validator as make_length_validator_preview ,
31+ )
32+ from commitizen .interactive_preview import (
33+ make_toolbar_content as make_toolbar_content_preview ,
34+ )
2835
2936if TYPE_CHECKING :
37+ from collections .abc import Callable
38+
3039 from commitizen .config import BaseConfig
3140
3241
@@ -45,6 +54,11 @@ class CommitArgs(TypedDict, total=False):
4554class Commit :
4655 """Show prompt for the user to create a guided commit."""
4756
57+ # Questionary types for interactive preview hooks (length validator / toolbar),
58+ # based on questionary 2.0.1
59+ VALIDATABLE_TYPES = {"input" , "text" , "password" , "path" , "checkbox" }
60+ BOTTOM_TOOLBAR_TYPES = {"input" , "text" , "password" , "confirm" }
61+
4862 def __init__ (self , config : BaseConfig , arguments : CommitArgs ) -> None :
4963 if not git .is_git_project ():
5064 raise NotAGitProjectError ()
@@ -71,13 +85,120 @@ def _read_backup_message(self) -> str | None:
7185 encoding = self .config .settings ["encoding" ]
7286 ).strip ()
7387
88+ def _build_commit_questions (
89+ self ,
90+ questions : list ,
91+ preview_enabled : bool ,
92+ max_preview_length : int ,
93+ ) -> list :
94+ """Build the list of questions to ask; add toolbar/validate when preview enabled."""
95+ if not preview_enabled :
96+ return list (questions )
97+
98+ default_answers : dict [str , Any ] = {
99+ q ["name" ]: q .get ("default" , "" )
100+ for q in questions
101+ if isinstance (q .get ("name" ), str )
102+ }
103+ field_filters : dict [str , Any ] = {
104+ q ["name" ]: q .get ("filter" )
105+ for q in questions
106+ if isinstance (q .get ("name" ), str )
107+ }
108+ answers_state : dict [str , Any ] = {}
109+
110+ def _get_current_buffer_text () -> str :
111+ try :
112+ app = get_app ()
113+ buffer = app .layout .current_buffer
114+ return buffer .text if buffer is not None else ""
115+ except Exception :
116+ return ""
117+
118+ def subject_builder (current_field : str , current_text : str ) -> str :
119+ preview_answers : dict [str , Any ] = default_answers .copy ()
120+ preview_answers .update (answers_state )
121+ if current_field :
122+ field_filter = field_filters .get (current_field )
123+ if field_filter :
124+ try :
125+ preview_answers [current_field ] = field_filter (current_text )
126+ except Exception :
127+ preview_answers [current_field ] = current_text
128+ else :
129+ preview_answers [current_field ] = current_text
130+ try :
131+ return self .cz .message (preview_answers ).partition ("\n " )[0 ].strip ()
132+ except Exception :
133+ return ""
134+
135+ def make_stateful_filter (
136+ name : str , original_filter : Callable [[str ], Any ] | None
137+ ) -> Callable [[str ], Any ]:
138+ def _filter (raw : str ) -> Any :
139+ value = original_filter (raw ) if original_filter else raw
140+ answers_state [name ] = value
141+ return value
142+
143+ return _filter
144+
145+ def make_toolbar (name : str ) -> Callable [[], str ]:
146+ def _toolbar () -> str :
147+ return make_toolbar_content_preview (
148+ subject_builder ,
149+ name ,
150+ _get_current_buffer_text (),
151+ max_length = max_preview_length ,
152+ )
153+
154+ return _toolbar
155+
156+ def make_length_validator (name : str ) -> Callable [[str ], bool | str ]:
157+ return make_length_validator_preview (
158+ subject_builder ,
159+ name ,
160+ max_length = max_preview_length ,
161+ )
162+
163+ enhanced_questions : list [dict [str , object ]] = []
164+ for q in questions :
165+ q_dict = dict (q )
166+ q_type = q_dict .get ("type" )
167+ name = q_dict .get ("name" )
168+
169+ if isinstance (name , str ):
170+ original_filter = q_dict .get ("filter" )
171+ q_dict ["filter" ] = make_stateful_filter (name , original_filter )
172+
173+ if q_type in self .BOTTOM_TOOLBAR_TYPES :
174+ q_dict ["bottom_toolbar" ] = make_toolbar (name )
175+
176+ if q_type in self .VALIDATABLE_TYPES :
177+ q_dict ["validate" ] = make_length_validator (name )
178+
179+ enhanced_questions .append (q_dict )
180+ return enhanced_questions
181+
74182 def _get_message_by_prompt_commit_questions (self ) -> str :
75- # Prompt user for the commit message
76183 questions = self .cz .questions ()
77184 for question in (q for q in questions if q ["type" ] == "list" ):
78185 question ["use_shortcuts" ] = self .config .settings ["use_shortcuts" ]
186+
187+ preview_enabled = bool (
188+ self .arguments .get ("preview" , False )
189+ or self .config .settings .get ("preview" , False )
190+ )
191+ max_preview_length = self .arguments .get (
192+ "message_length_limit" ,
193+ self .config .settings .get ("message_length_limit" , 0 ),
194+ )
195+
196+ questions_to_ask = self ._build_commit_questions (
197+ questions , preview_enabled , max_preview_length
198+ )
199+
79200 try :
80- answers = questionary .prompt (questions , style = self .cz .style )
201+ answers = questionary .prompt (questions_to_ask , style = self .cz .style )
81202 except ValueError as err :
82203 root_err = err .__context__
83204 if isinstance (root_err , CzException ):
0 commit comments