pingcap/parser (MySQL互換) で SQL を手軽に解析

この記事は freee Developers Advent Calendar 2021 の23日目の記事です🎄

freee の DBRE チームに所属している caterpillar です. なんだか大きなデータベースを眺める仕事をしています.

突然ですが, pingcap/parser を使って SQL を簡単に解析していきたいと思います. Go 製 の SQL Parser で, MySQL への高い互換性を謳う TiDB で利用されています. この parser の嬉しい点はこんな感じです.

  • シンプルで使いやすい
  • TiDB に利用されていることから, ある程度結果を信頼できる
  • mask 済 SQL もおおよそ構文解析可能

3つ目について, mask済の SQL は select * from users where id = ? のように一部が別の文字に置き換わっているものを指します. freee では, ORM が生成するクエリのパフォーマンス改善のため, セキュリティ観点で扱いやすいよう pt-fingerprint で masking を施したクエリを活用しています. pingcap/parser はおおよそ (後述) この mask 済 SQL を構文解析できました*1.

なおこの記事のソースコードは, github.com/pingcap/parser@6870058 のバージョンで動作確認しています*2.

利用例1: SQL に含まれるテーブル一覧を取得する

公式のquickstart を見るのが早いですが, まずは簡単な例で使い方を紹介します. 作成するのは「SQL中に出現したテーブル名の一覧を取得する」クエリ解析器です.

まず, チュートリアルに従って parse 関数を用意します(下記コード①). この関数にクエリ文字列を渡すと, AST (abstract syntax tree🎄) を得られます(③). 今回入力に使用した pt-fingerprint を用いて mask したクエリでは, 現状出会った範囲では where id in (?+) という形式が parse に失敗するため, 愚直に修正しています(②).

package main

import (
    "fmt"
    "github.com/pingcap/parser"
    "github.com/pingcap/parser/ast"
    _ "github.com/pingcap/parser/test_driver"
    "strings"
)

func parse(sql string) (*ast.StmtNode, error) { // ① 受け取った sql を Parse() に渡し, 結果ノードを返す
    p := parser.New()

    stmtNodes, _, err := p.Parse(sql, "", "")
    if err != nil {
        return nil, err
    }

    return &stmtNodes[0], nil
}

func main() {
    input := "select * from users where id = ?"
    sql := strings.ReplaceAll(input, "?+", "?")     // ② in (?+) が parser できないので置換
    astNode, _ := parse(sql)                        // ③ AST 生成
    ...
}

次に, AST を探索するインターフェース ast.Visitor が提供されているので, これを実装して AST を解析します.

Visitor を利用するには, 2つの関数 EnterLeave を実装する必要があります. Visitor は AST を深さ優先に巡回するわけですが, Enter は 親ノードから自身に到達した際に行う処理, Leave は子ノードの探索が終了した後に行う処理を記述します.

各ノードの型は parser/ast 以下でそれぞれ定義されていて, 型アサーションを利用することで判別できます.

Enter 関数内で, テーブル名を表すノード ast.TableName を見つけたら, 実際のテーブル名を記録しています(③). 欲しい情報を持つノードの型は, parse.y を眺めたり, 定義されている構造体を眺めていくと発見できます. TiDB document で確認できるものもあります.

// 見つけたテーブル名を格納する struct
type table struct {
    Names []string
}

func (v *table) Enter(in ast.Node) (ast.Node, bool) {
    if name, ok := in.(*ast.TableName); ok { // ③ 型アサーションで識別
        v.Names = append(v.Names, name.Name.O) // O は original, L は lower case を取得できる
    }
    return in, false
}

func (v *table) Leave(in ast.Node) (ast.Node, bool) {
    return in, true
}

func extract(rootNode *ast.StmtNode) []string {
    v := &table{}
    (*rootNode).Accept(v)
    return v.Names
}

作成した table Visitor を使用して, 以下のように SQL に登場するテーブルの一覧を取得できました.

func main() {
    input := "select `authors`.* from `authors` inner join `books` on `authors`.`id` = `books`.`author_id`" +
        " where `books`.`published_at` < ? order by `authors`.`id` limit ?"
    sql := strings.ReplaceAll(input, "?+", "?")     // ② in (?+) が parser できないので置換
    astNode, _ := parse(sql)                      // ③ AST 生成
    fmt.Printf(strings.Join(extract(astNode), ",")) // => authors,books
}

利用例2: Join しているテーブルを取得する

次に, SQL 中で Join が発生しているテーブルのペアを順に取得する Visitor パターンを実装してみます. 実際に Join される順番は Optimizer の実行計画に依存しますが, テーブル整理や怪しいクエリの監視に役立つかもしれません.

まずは Join を扱う ast.Node を探します.

・・・

・・・

・・・・・・

ありました! parser.y#L8836-L8876

基本 ast.Join に置き換えられ, メンバ LeftRight に Join するテーブル情報が子ノードとして加わるようです.

JoinTable:
    /* Use %prec to evaluate production TableRef before cross join */
    TableRef CrossOpt TableRef %prec tableRefPriority
    {
        $$ = ast.NewCrossJoin($1.(ast.ResultSetNode), $3.(ast.ResultSetNode))
    }
|   TableRef CrossOpt TableRef "ON" Expression
    {
        on := &ast.OnCondition{Expr: $5}
        $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: on}
    }
|   TableRef CrossOpt TableRef "USING" '(' ColumnNameList ')'
    {
        $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, Using: $6.([]*ast.ColumnName)}
    }

ast.Join を見ると, LeftRight はそれぞれ ast.TableSource か JoinNode(ast.Join) と書かれています. TableSource のコメントを参考に, LeftRight はそれぞれ以下に派生していくようです.

  • ast.TableName: テーブル名
  • ast.SelectStmt: Select文
  • ast.SetOprStmt: Union句
  • ast.Join
// Join represents table join.
type Join struct {
    node

    // Left table can be TableSource or JoinNode.
    Left ResultSetNode
    // Right table can be TableSource or JoinNode or nil.
    Right ResultSetNode
    ...
}

// TableSource represents table source with a name.
type TableSource struct {
    node

    // Source is the source of the data, can be a TableName,
    // a SelectStmt, a SetOprStmt, or a JoinNode.
    Source ResultSetNode

    // AsName is the alias name of the table source.
    AsName model.CIStr
}

ここで大変申し訳ないのですが, 簡単のため Left, Right に入りうる ast.SelectStmt, ast.SetOprStmt ノードについて「Join が発生しない場合, 使用されるテーブルは1つ」と仮定します. 例えば, 今回紹介するコードでは以下のようなクエリは正しく検査できません.

SELECT * FROM table_a
INNER JOIN
  (SELECT *
   FROM table_b
   WHERE (EXISTS
            (SELECT 1
             FROM table_c
             WHERE table_c.id = 1 )) ) b ON table_a.some_id = b.id ;

さて, 探索結果を保存する struct として 2種類用意します. Joinされるテーブル を二分木で管理する tablePair と, 二分木の root を管理する result です. tablePair は, 木の leaf を兼用していて, Name にテーブル名が入ります.

type result struct {
    Joins []*tablePair
}

type tablePair struct {
    Name string
    Left *tablePair
    Right *tablePair
}

Visitor を実装していきます. まずは AST の root から順にたどり, Join を見つける Visitor です. ast.JoinNode を見つけると, Left, Right それぞれの子ノードを新たな Visitor に探索させます(④). この Visitor の実装は次の項で紹介します. Left, Right 両者の探索が終わると, 結果として得られた tablePair を保存し, (⑤)終了します. 重複探索を防ぐため, 返り値の二要素目 skipChildren を true にし, 子ノードの探索を skip しています(⑥).

func (v *result) Enter(in ast.Node) (ast.Node, bool) {
    switch node := in.(type) {
    case *ast.Join: // ④
        left, right := &tablePair{}, &tablePair{}
        node.Left.Accept(left)
        if node.Right != nil { // Right は nil の可能性あり
            node.Right.Accept(right)
        } else {
            right = nil
        }
        v.Joins = append(v.Joins, &tablePair{Left: left, Right: right}) // ⑤ 結果を格納
        return in, true                                                 // ⑥ 子ノードを探索しない
    default:
        return in, false
    }
}

func (v *result) Leave(in ast.Node) (ast.Node, bool) {
    return in, true
}

次に, ast.JoinNodeLeft, Right をそれぞれ探索する tablePair Visitor を実装します.

ast.SelectStmt, ast.SetOprStmt ノードについて「Join が発生しない場合, 使用されるテーブルは1つ」

という仮定を踏まえると, チェックすべきノードは ast.Joinast.TableName の2つです. ast.TableName は テーブル名を保存し, ast.Join は再帰的に Left, Right を辿っていきます.

func (v *tablePair) Enter(in ast.Node) (ast.Node, bool) {
    switch node := in.(type) {
    case *ast.Join:
        left, right := &tablePair{}, &tablePair{}
        node.Left.Accept(left)
        v.Left = left
        if node.Right != nil { // Right は nil の可能性あり
            node.Right.Accept(right)
            v.Right = right
        } else {
            v.Right = nil
        }
        return in, true
    case *ast.TableName:
        v.Name = node.Name.O
        return in, true
    default:
        return in, false
    }
}

func (v *tablePair) Leave(in ast.Node) (ast.Node, bool) {
    return in, true
}

作成した Visitor を利用して, 次のように結果を得られます. より現実的な解析では, "LEFT OUTER JOIN", "STRAIGHT JOIN" といった Join の種類などテーブル名以外の情報も収集することで, 便利に活用できそうです.

func main() {
    input := "SELECT DISTINCT `A`.* FROM `A`" +
        " INNER JOIN `B` ON `B`.`a_id` = `A`.`id`" +
        " INNER JOIN `C` ON `C`.`b_id` = `B`.`id`" +
        " WHERE `A`.`some_id` = ?" +
        " AND NOT (EXISTS" +
        " (SELECT ? FROM `B` INNER JOIN `C` ON `B`.`id` = `C`.`b_id`" +
        " WHERE `B`.`A_id` = `A`.`id` AND `C`.`some_type` IN(?+)))" +
        " AND `C`.`some_id` = ?" +
        " ORDER BY A.created_at DESC LIMIT ? OFFSET ?"
    sql := strings.ReplaceAll(input, "?+", "?") // ② in (?+) が parser できないので置換
    astNode, _ := parse(sql)                    // ③ AST 生成
    // extract の実装は省略
    fmt.Printf("%+v\n", extract(astNode)) // => [((A, B), C) (B, C)]
}

まとめ

pingcap/parser で手軽に AST を探索することができました. サービスの規模拡大に伴い, データベースに対して叩かれるクエリの把握は難しくなっていきますが, 機械的に立ち向かっていけたら楽しいですね.

さて, アドベントカレンダーも残すところあと2日となりました. 明日は酒のあるところに行けばだいたい会える, Public API チームのリーダー matz パイセンの記事です. お楽しみに!!!

*1:Prepared Statement として受け入れられているようです. ref: ast.ParamMakerExpr

*2:最近, pingcap/parser は pingcap/tidb に移動 (出戻り) したようです(記事). 最新に追従したい場合はご注意ください.