Skip to content

Commit 4e02afb

Browse files
authored
[INS-170] Unify JDBC URL Parsing Across Detector and Analyzer (Continued) (#4606)
* changes to unify with analyzer
1 parent 964eab0 commit 4e02afb

File tree

11 files changed

+559
-45
lines changed

11 files changed

+559
-45
lines changed

pkg/detectors/jdbc/jdbc.go

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ matchLoop:
8383
}
8484

8585
if verify {
86-
j, err := newJDBC(logCtx, jdbcConn)
86+
j, err := NewJDBC(logCtx, jdbcConn)
8787
if err != nil {
8888
continue
8989
}
@@ -206,31 +206,13 @@ func tryRedactRegex(conn string) (string, bool) {
206206
return newConn, true
207207
}
208208

209-
var supportedSubprotocols = map[string]func(logContext.Context, string) (jdbc, error){
210-
"mysql": ParseMySQL,
211-
"postgresql": ParsePostgres,
212-
"sqlserver": ParseSqlServer,
209+
var supportedSubprotocols = map[string]func(logContext.Context, string) (JDBC, error){
210+
"mysql": parseMySQL,
211+
"postgresql": parsePostgres,
212+
"sqlserver": parseSqlServer,
213213
}
214214

215-
type pingResult struct {
216-
err error
217-
determinate bool
218-
}
219-
220-
// ConnectionInfo holds parsed connection information
221-
type ConnectionInfo struct {
222-
Host string // includes port if specified, e.g., "host:port"
223-
Database string
224-
User string
225-
Password string
226-
Params map[string]string
227-
}
228-
229-
type jdbc interface {
230-
ping(context.Context) pingResult
231-
}
232-
233-
func newJDBC(ctx logContext.Context, conn string) (jdbc, error) {
215+
func NewJDBC(ctx logContext.Context, conn string) (JDBC, error) {
234216
// expected format: "jdbc:{subprotocol}:{subname}"
235217
if !strings.HasPrefix(strings.ToLower(conn), "jdbc:") {
236218
return nil, errors.New("expected jdbc prefix")
@@ -242,11 +224,11 @@ func newJDBC(ctx logContext.Context, conn string) (jdbc, error) {
242224
return nil, errors.New("expected a colon separated subprotocol and subname")
243225
}
244226

245-
// get the subprotocol parser
246227
parser, ok := supportedSubprotocols[strings.ToLower(subprotocol)]
247228
if !ok {
248-
return nil, errors.New("unsupported subprotocol")
229+
return nil, fmt.Errorf("unsupported subprotocol: %s", subprotocol)
249230
}
231+
250232
return parser(ctx, subname)
251233
}
252234

pkg/detectors/jdbc/jdbc_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ import (
77

88
"github.com/google/go-cmp/cmp"
99
"github.com/kylelemons/godebug/pretty"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
1013
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors"
1114
"github.com/trufflesecurity/trufflehog/v3/pkg/engine/ahocorasick"
1215
)
@@ -183,3 +186,46 @@ func TestJdbc_FromDataWithIgnorePattern(t *testing.T) {
183186
})
184187
}
185188
}
189+
190+
func TestParseJDBCURL_EdgeCases(t *testing.T) {
191+
t.Run("MySQL with special characters in password", func(t *testing.T) {
192+
// Special chars: @ # $ % ^ & * ( )
193+
jdbcURL := "jdbc:mysql://user:p@ss%23word@localhost:3306/testdb"
194+
jdbc, err := NewJDBC(logContext.Background(), jdbcURL)
195+
require.NoError(t, err)
196+
197+
info := jdbc.GetConnectionInfo()
198+
assert.NoError(t, err)
199+
assert.NotNil(t, info)
200+
assert.Equal(t, "user", info.User)
201+
// URL encoding should be handled by url.Parse
202+
})
203+
204+
t.Run("PostgreSQL with empty database", func(t *testing.T) {
205+
jdbcURL := "jdbc:postgresql://user:pass@localhost:5432"
206+
jdbc, err := NewJDBC(logContext.Background(), jdbcURL)
207+
require.NoError(t, err)
208+
209+
info := jdbc.GetConnectionInfo()
210+
assert.Equal(t, "postgres", info.Database) // default
211+
})
212+
213+
t.Run("SQL Server with multiple semicolon params", func(t *testing.T) {
214+
jdbcURL := "jdbc:sqlserver://localhost:1433;database=testdb;user=sa;password=Pass123;encrypt=true;trustServerCertificate=false"
215+
jdbc, err := NewJDBC(logContext.Background(), jdbcURL)
216+
require.NoError(t, err)
217+
218+
info := jdbc.GetConnectionInfo()
219+
assert.Equal(t, "testdb", info.Database)
220+
assert.Equal(t, "sa", info.User)
221+
assert.Equal(t, "Pass123", info.Password)
222+
})
223+
224+
t.Run("MySQL missing host", func(t *testing.T) {
225+
// Missing // after prefix - will trigger error
226+
jdbcURL := "jdbc:mysql:/testdb"
227+
_, err := NewJDBC(logContext.Background(), jdbcURL)
228+
assert.Error(t, err)
229+
assert.Contains(t, err.Error(), "expected host to start with //")
230+
})
231+
}

pkg/detectors/jdbc/models.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package jdbc
2+
3+
import (
4+
"context"
5+
)
6+
7+
type DatabaseType int
8+
9+
const (
10+
Unknown DatabaseType = iota
11+
MySQL
12+
PostgreSQL
13+
SQLServer
14+
)
15+
16+
func (dt DatabaseType) String() string {
17+
switch dt {
18+
case MySQL:
19+
return "mysql"
20+
case PostgreSQL:
21+
return "postgresql"
22+
case SQLServer:
23+
return "sqlserver"
24+
default:
25+
return "unknown"
26+
}
27+
}
28+
29+
type pingResult struct {
30+
err error
31+
determinate bool
32+
}
33+
34+
// ConnectionInfo holds parsed connection information
35+
type ConnectionInfo struct {
36+
Host string // includes port if specified, e.g., "host:port"
37+
Database string
38+
User string
39+
Password string
40+
Params map[string]string
41+
}
42+
43+
type jdbcPinger interface {
44+
ping(context.Context) pingResult
45+
}
46+
47+
// public interfaces for analyzer
48+
type JDBCParser interface {
49+
GetConnectionInfo() *ConnectionInfo
50+
GetDBType() DatabaseType
51+
BuildConnectionString() string
52+
}
53+
type JDBC interface {
54+
jdbcPinger
55+
JDBCParser
56+
}

pkg/detectors/jdbc/mysql.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,26 @@ type MysqlJDBC struct {
1616
ConnectionInfo
1717
}
1818

19+
var _ JDBC = (*MysqlJDBC)(nil)
20+
1921
func (s *MysqlJDBC) ping(ctx context.Context) pingResult {
2022
return ping(ctx, "mysql", isMySQLErrorDeterminate,
21-
BuildMySQLConnectionString(s.Host, "", s.User, s.Password, s.Params))
23+
buildMySQLConnectionString(s.Host, "", s.User, s.Password, s.Params))
24+
}
25+
26+
func (s *MysqlJDBC) GetDBType() DatabaseType {
27+
return MySQL
28+
}
29+
30+
func (s *MysqlJDBC) GetConnectionInfo() *ConnectionInfo {
31+
return &s.ConnectionInfo
32+
}
33+
34+
func (s *MysqlJDBC) BuildConnectionString() string {
35+
return buildMySQLConnectionString(s.Host, s.Database, s.User, s.Password, s.Params)
2236
}
2337

24-
func BuildMySQLConnectionString(host, database, user, password string, params map[string]string) string {
38+
func buildMySQLConnectionString(host, database, user, password string, params map[string]string) string {
2539
conn := host + "/" + database
2640
userPass := user
2741
if password != "" {
@@ -56,7 +70,7 @@ func isMySQLErrorDeterminate(err error) bool {
5670
return false
5771
}
5872

59-
func ParseMySQL(ctx logContext.Context, subname string) (jdbc, error) {
73+
func parseMySQL(ctx logContext.Context, subname string) (JDBC, error) {
6074
// expected form: [subprotocol:]//[user:password@]HOST[/DB][?key=val[&key=val]]
6175
if !strings.HasPrefix(subname, "//") {
6276
return nil, errors.New("expected host to start with //")
@@ -86,7 +100,7 @@ func ParseMySQL(ctx logContext.Context, subname string) (jdbc, error) {
86100
}, nil
87101
}
88102

89-
func parseMySQLURI(ctx logContext.Context, subname string) (jdbc, error) {
103+
func parseMySQLURI(ctx logContext.Context, subname string) (JDBC, error) {
90104

91105
// for standard URI format, which is all i've seen for JDBC
92106
u, err := url.Parse(subname)

pkg/detectors/jdbc/mysql_integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func TestMySQL(t *testing.T) {
9191
}
9292
for _, tt := range tests {
9393
t.Run(tt.input, func(t *testing.T) {
94-
j, err := ParseMySQL(logContext.Background(), tt.input)
94+
j, err := parseMySQL(logContext.Background(), tt.input)
9595

9696
if err != nil {
9797
got := result{ParseErr: true}

0 commit comments

Comments
 (0)