diff --git a/pkg/http/handler.go b/pkg/http/handler.go index d55d7c53d7..1ae4713216 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -223,10 +223,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + // Bypass cross-origin protection: this server uses bearer tokens (not + // cookies), so Sec-Fetch-Site CSRF checks are unnecessary. See PR #2359. + crossOriginProtection := http.NewCrossOriginProtection() + crossOriginProtection.AddInsecureBypassPattern("/") + mcpHandler := mcp.NewStreamableHTTPHandler(func(_ *http.Request) *mcp.Server { return ghServer }, &mcp.StreamableHTTPOptions{ - Stateless: true, + Stateless: true, + CrossOriginProtection: crossOriginProtection, }) mcpHandler.ServeHTTP(w, r) diff --git a/pkg/http/handler_test.go b/pkg/http/handler_test.go index aeda12f424..46e86b4a89 100644 --- a/pkg/http/handler_test.go +++ b/pkg/http/handler_test.go @@ -756,3 +756,72 @@ func buildStaticInventoryFromTools(cfg *ServerConfig, tools []inventory.ServerTo ctx := context.Background() return inv.AvailableTools(ctx), inv.AvailableResourceTemplates(ctx), inv.AvailablePrompts(ctx) } + +func TestCrossOriginProtection(t *testing.T) { + jsonRPCBody := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1"}}}` + + apiHost, err := utils.NewAPIHost("https://api.githubcopilot.com") + require.NoError(t, err) + + handler := NewHTTPMcpHandler( + context.Background(), + &ServerConfig{ + Version: "test", + }, + nil, + translations.NullTranslationHelper, + slog.Default(), + apiHost, + WithInventoryFactory(func(_ *http.Request) (*inventory.Inventory, error) { + return inventory.NewBuilder().Build() + }), + WithGitHubMCPServerFactory(func(_ *http.Request, _ github.ToolDependencies, _ *inventory.Inventory, _ *github.MCPServerConfig) (*mcp.Server, error) { + return mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil), nil + }), + WithScopeFetcher(allScopesFetcher{}), + ) + + r := chi.NewRouter() + handler.RegisterMiddleware(r) + handler.RegisterRoutes(r) + + tests := []struct { + name string + secFetchSite string + origin string + }{ + { + name: "cross-site request with bearer token succeeds", + secFetchSite: "cross-site", + origin: "https://example.com", + }, + { + name: "same-origin request succeeds", + secFetchSite: "same-origin", + }, + { + name: "native client without Sec-Fetch-Site succeeds", + secFetchSite: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(jsonRPCBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set(headers.AuthorizationHeader, "Bearer github_pat_xyz") + if tt.secFetchSite != "" { + req.Header.Set("Sec-Fetch-Site", tt.secFetchSite) + } + if tt.origin != "" { + req.Header.Set("Origin", tt.origin) + } + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code, "unexpected status code; body: %s", rr.Body.String()) + }) + } +} diff --git a/pkg/http/middleware/cors.go b/pkg/http/middleware/cors.go new file mode 100644 index 0000000000..2eaf4227b4 --- /dev/null +++ b/pkg/http/middleware/cors.go @@ -0,0 +1,43 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/github/github-mcp-server/pkg/http/headers" +) + +// SetCorsHeaders is middleware that sets CORS headers to allow browser-based +// MCP clients to connect from any origin. This is safe because the server +// authenticates via bearer tokens (not cookies), so cross-origin requests +// cannot exploit ambient credentials. +func SetCorsHeaders(h http.Handler) http.Handler { + allowHeaders := strings.Join([]string{ + "Content-Type", + "Mcp-Session-Id", + "Mcp-Protocol-Version", + "Last-Event-ID", + headers.AuthorizationHeader, + headers.MCPReadOnlyHeader, + headers.MCPToolsetsHeader, + headers.MCPToolsHeader, + headers.MCPExcludeToolsHeader, + headers.MCPFeaturesHeader, + headers.MCPLockdownHeader, + headers.MCPInsidersHeader, + }, ", ") + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Max-Age", "86400") + w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id, WWW-Authenticate") + w.Header().Set("Access-Control-Allow-Headers", allowHeaders) + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + h.ServeHTTP(w, r) + }) +} diff --git a/pkg/http/middleware/cors_test.go b/pkg/http/middleware/cors_test.go new file mode 100644 index 0000000000..fbd7c40cf9 --- /dev/null +++ b/pkg/http/middleware/cors_test.go @@ -0,0 +1,45 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/github/github-mcp-server/pkg/http/middleware" + "github.com/stretchr/testify/assert" +) + +func TestSetCorsHeaders(t *testing.T) { + inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := middleware.SetCorsHeaders(inner) + + t.Run("OPTIONS preflight returns 200 with CORS headers", func(t *testing.T) { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + req.Header.Set("Origin", "http://localhost:6274") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin")) + assert.Contains(t, rr.Header().Get("Access-Control-Allow-Methods"), "POST") + assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Authorization") + assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Content-Type") + assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Mcp-Session-Id") + assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Lockdown") + assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Insiders") + assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "Mcp-Session-Id") + assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "WWW-Authenticate") + }) + + t.Run("POST request includes CORS headers", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set("Origin", "http://localhost:6274") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin")) + }) +} diff --git a/pkg/http/server.go b/pkg/http/server.go index d1e8192ba4..f7cdaf9093 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -13,6 +13,7 @@ import ( ghcontext "github.com/github/github-mcp-server/pkg/context" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/middleware" "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" @@ -167,6 +168,8 @@ func RunHTTPServer(cfg ServerConfig) error { } r.Group(func(r chi.Router) { + r.Use(middleware.SetCorsHeaders) + // Register Middleware First, needs to be before route registration handler.RegisterMiddleware(r)