Groovy AST の勉強 その4

前回のお題を実現する ASTTransformation を実装してみる。

visit メソッドの内容は次の通り。メインクラスの ClassNode を取得し、(A)(B)(C)の変更を行う。

  public void visit(ASTNode[] astNodes, SourceUnit sourceUnit) {
    // Locate the main class of this source unit; without one there is nothing to rewrite.
    ModuleNode module = sourceUnit.getAST()
    ClassNode mainClass = pickup_main_class(module)
    if (mainClass == null) {
      return
    }

    // (A) add the field ta to the main class
    add_field_ta(mainClass)

    // (B) add the methods print / println / openWindow to the main class
    add_method_print(mainClass)
    add_method_println(mainClass)
    add_method_openWindow(mainClass)

    // (C) wrap the body of the run method in a try-catch
    MethodNode runMethod = pickup_run_method(mainClass)
    if (runMethod != null) {
      insert_try_catch_to_method(runMethod)
    }
  }

(A) メインクラスにフィールド変数 ta を追加。

// (A) Adds a private static javax.swing.JTextArea field named ta, initialized to null, to cn.
void add_field_ta (ClassNode cn) {
  int modifiers = java.lang.reflect.Modifier.PRIVATE | java.lang.reflect.Modifier.STATIC // == 10
  ClassNode textAreaType = ClassHelper.make(javax.swing.JTextArea, false)
  cn.addField(new FieldNode("ta", modifiers, textAreaType, cn, ConstantExpression.NULL))
}

(B) メインクラスにメソッド print/println/openWindow を追加。

ここでは print メソッドの追加処理を示す。println/openWindow の追加処理は、末尾に掲載する全体ソースを参照。
AST ツリー構造にしたがって、たんたんとコーディングしていけばよいだけなのだが、非常に骨が折れた。*1

変数などの型の ClassNode の解決については、ClassHelper.make メソッドで行っている。

// (B) Adds `public static void print(Object x)` to cn. Generated body:
//   if (ta == null) { this.openWindow("console") }
//   ta.append("${x}")
//   ta.setCaretPosition(ta.getDocument().getLength())
// Compared with ta.setText(ta.getText()+x.toString()), append() does not copy
// the whole accumulated text on every call, and the GString formats x the way
// Groovy interpolation does (e.g. an int[] shows as [1, 2], not [I@...).
void add_method_print (ClassNode cn) {
  def n0 = new Parameter (
             ClassHelper.make(java.lang.Object,false),"x")
  def n1 = [n0] as Parameter[]
  def n2 = new VariableExpression ("ta")
  def n3 = new Token (Types.COMPARE_EQUAL,"==",-1,-1)
  def n4 = ConstantExpression.NULL
  def n5 = new BinaryExpression (n2,n3,n4)
  def n6 = new BooleanExpression (n5)
          // ta == null
  def n7 = new VariableExpression ("this")
  def n8 = new ConstantExpression ("openWindow")
  def n9 = new ConstantExpression ("console")
  def n10 = [n9]
  def n11 = new ArgumentListExpression (n10)
  def n12 = new MethodCallExpression (n7,n8,n11)
          // this.openWindow("console")
  def n13 = new ExpressionStatement (n12)
  def n14 = new BlockStatement()
  n14.addStatement(n13)
  def n15 = new EmptyStatement ()
  def n16 = new IfStatement (n6,n14,n15)
          // if (ta == null) { this.openWindow("console") }
  def n17 = new VariableExpression ("ta")
  def n18 = new ConstantExpression ("append")
  def n19 = new ConstantExpression ("")
  def n20 = new ConstantExpression ("")
  def n21 = new VariableExpression ("x")
  def n22 = new GStringExpression ('${x}',[n19,n20],[n21])
          // "${x}"
  def n23 = new ArgumentListExpression ([n22])
  def n24 = new MethodCallExpression (n17,n18,n23)
          // ta.append("${x}")
  def n25 = new ExpressionStatement (n24)
  def n26 = new VariableExpression ("ta")
  def n27 = new ConstantExpression ("getDocument")
  def n28 = ArgumentListExpression.EMPTY_ARGUMENTS
  def n29 = new MethodCallExpression (n26,n27,n28)
          // ta.getDocument()
  def n30 = new ConstantExpression ("getLength")
  def n31 = ArgumentListExpression.EMPTY_ARGUMENTS
  def n32 = new MethodCallExpression (n29,n30,n31)
          // ta.getDocument().getLength()
  def n33 = new VariableExpression ("ta")
  def n34 = new ConstantExpression ("setCaretPosition")
  def n35 = new ArgumentListExpression ([n32])
  def n36 = new MethodCallExpression (n33,n34,n35)
          // ta.setCaretPosition(ta.getDocument().getLength())
  def n37 = new ExpressionStatement (n36)
  def n38 = new BlockStatement()
  n38.addStatement(n16)
  n38.addStatement(n25)
  n38.addStatement(n37)
  def n39 = new MethodNode ("print",
              java.lang.reflect.Modifier.PUBLIC | java.lang.reflect.Modifier.STATIC, // == 9
              ClassHelper.make(void,false),
              n1,ClassNode.EMPTY_ARRAY,n38)
  cn.addMethod(n39)
}

(C) run メソッド内部に try-catch 構造を追加。

[1] で run メソッドの元の本体(statement)を TryCatchStatement の try ブロックとして組み込み、
[2] でその TryCatchStatement を含むブロックを run メソッドの新しい本体として setCode する。

// (C) Replaces the body of mn with:
//   try { <original body> }
//   catch (Exception e) {
//     this.println(e)
//     for (StackTraceElement ste in e.getStackTrace()) { this.println("\tat ${ste}") }
//   }
// Named insert_try_catch_to_method to match the call in visit().
void insert_try_catch_to_method (MethodNode mn) {

  def n9 = new Parameter (ClassHelper.make(java.lang.Exception,false),"e")
  def n10 = new VariableExpression ("this")
  def n11 = new ConstantExpression ("println")
  def n12 = new VariableExpression ("e")
  def n13 = [n12]
  def n14 = new ArgumentListExpression (n13)
  def n15 = new MethodCallExpression (n10,n11,n14)
          // this.println(e)
  def n16 = new ExpressionStatement (n15)
  def n17 = new Parameter (
              ClassHelper.make(java.lang.StackTraceElement,false),"ste")
  def n18 = new VariableExpression ("e")
  def n19 = new ConstantExpression ("getStackTrace")
  def n21 = ArgumentListExpression.EMPTY_ARGUMENTS
  def n22 = new MethodCallExpression (n18,n19,n21)
          // e.getStackTrace()
  def n23 = new VariableExpression ("this")
  def n24 = new ConstantExpression ("println")
  def n25 = new ConstantExpression ("\tat ")
  def n26 = new ConstantExpression ("")
  def n27 = [n25,n26]
  def n28 = new VariableExpression ("ste")
  def n29 = [n28]
  def n30 = new GStringExpression ("",n27,n29)
          // "\tat ${ste}"
  def n31 = [n30]
  def n32 = new ArgumentListExpression (n31)
  def n33 = new MethodCallExpression (n23,n24,n32)
          // this.println("\tat ${ste}")
  def n34 = new ExpressionStatement (n33)
  def n36 = new BlockStatement()
  n36.addStatement(n34)
  def n37 = new ForStatement (n17,n22,n36)
          // for (StackTraceElement ste in e.getStackTrace()) { ... }
  def n39 = new BlockStatement()
  n39.addStatement(n16)
  n39.addStatement(n37)
  def n40 = new CatchStatement (n9,n39)
          // catch (Exception e) { ... }
  def n42 = new EmptyStatement ()
  def n43 = new TryCatchStatement (mn.getCode(), n42) // [1] original body becomes the try block
  n43.addCatch(n40)
  def n45 = new BlockStatement()
  n45.addStatement(n43)

  mn.setCode(n45) // [2] the try-catch becomes the new method body
}

実装した MyASTTransformation2.groovy の全体を以下に示す。

// MyASTTransformation2.groovy
package bunji
import org.codehaus.groovy.syntax.*
import org.codehaus.groovy.ast.*
import org.codehaus.groovy.ast.stmt.*
import org.codehaus.groovy.ast.expr.*
import org.codehaus.groovy.control.*
import org.codehaus.groovy.transform.*

@GroovyASTTransformation(phase=CompilePhase.CONVERSION)
public class MyASTTransformation2 implements ASTTransformation {

  /** Modifiers of the generated print/println methods: public static (== 9). */
  private static final int PUBLIC_STATIC =
      java.lang.reflect.Modifier.PUBLIC | java.lang.reflect.Modifier.STATIC

  /** Modifiers of the generated ta field and openWindow method: private static (== 10). */
  private static final int PRIVATE_STATIC =
      java.lang.reflect.Modifier.PRIVATE | java.lang.reflect.Modifier.STATIC

  /**
   * Rewrites the main class of {@code sourceUnit}:
   * (A) adds the JTextArea field {@code ta},
   * (B) adds print/println/openWindow, which write output into {@code ta},
   * (C) wraps the body of the public, parameterless {@code run()} in try/catch.
   * Does nothing when no main class is found.
   */
  public void visit(ASTNode[] astNodes, SourceUnit sourceUnit) {

    def ast = sourceUnit.getAST()
    def cn = pickup_main_class(ast)

    if (cn != null) {
      add_field_ta(cn)                        // (A)
      add_method_print(cn)                    // (B)
      add_method_println(cn)
      add_method_openWindow(cn)

      def run_mn = pickup_run_method(cn)
      if (run_mn != null) {
        insert_try_catch_to_method(run_mn)    // (C)
      }
    }
  }

  /**
   * Returns the class in {@code mn} whose name equals the module's main class
   * name, or {@code null} if there is none.
   *
   * {@code mn.classes} is read BEFORE {@code mn.mainClassName}: for a script,
   * ModuleNode.getClasses() can lazily create the script class and assign
   * mainClassName at that moment, so reading the name first may see a stale value.
   */
  ClassNode pickup_main_class (ModuleNode mn) {
    List<ClassNode> classes = mn.classes
    String mainClassName = mn.mainClassName
    if (classes != null) {
      for (ClassNode cn in classes) {
        if (cn.name == mainClassName) {
          return cn
        }
      }
    }
    return null
  }

  /**
   * Returns the first public method named {@code run} that takes no
   * parameters, or {@code null} if there is none.
   * The return type is not checked.
   */
  MethodNode pickup_run_method (ClassNode cn) {
    String method_name = "run"
    if (cn.methods != null) {
      for (MethodNode mn in cn.methods) {
        boolean isPublic = (mn.modifiers & java.lang.reflect.Modifier.PUBLIC) != 0
        // Require no parameters so that an overload such as run(String) is not picked.
        boolean noArgs = mn.parameters == null || mn.parameters.length == 0
        if (mn.name == method_name && isPublic && noArgs) {
          return mn
        }
      }
    }
    return null
  }

  /**
   * (A) Adds {@code private static javax.swing.JTextArea ta = null} to {@code cn}.
   * TODO: what if the original source code already uses a variable named ta?
   */
  void add_field_ta (ClassNode cn) {
    def n = new FieldNode ("ta", PRIVATE_STATIC,
      ClassHelper.make(javax.swing.JTextArea,false),cn,
      ConstantExpression.NULL)
    cn.addField(n)
  }

  /**
   * (B) Adds {@code public static void print(Object x)} to {@code cn}. Generated body:
   *   if (ta == null) { this.openWindow("console") }
   *   ta.append("${x}")
   *   ta.setCaretPosition(ta.getDocument().getLength())
   * Compared with ta.setText(ta.getText()+x.toString()), append() does not copy
   * the whole accumulated text on every call, and the GString formats x the way
   * Groovy interpolation does (e.g. an int[] shows as [1, 2], not [I@...).
   */
  void add_method_print (ClassNode cn) {
    def n0 = new Parameter (
               ClassHelper.make(java.lang.Object,false),"x")
    def n1 = [n0] as Parameter[]
    def n2 = new VariableExpression ("ta")
    def n3 = new Token (Types.COMPARE_EQUAL,"==",-1,-1)
    def n4 = ConstantExpression.NULL
    def n5 = new BinaryExpression (n2,n3,n4)
    def n6 = new BooleanExpression (n5)
            // ta == null
    def n7 = new VariableExpression ("this")
    def n8 = new ConstantExpression ("openWindow")
    def n9 = new ConstantExpression ("console")
    def n10 = [n9]
    def n11 = new ArgumentListExpression (n10)
    def n12 = new MethodCallExpression (n7,n8,n11)
            // this.openWindow("console")
    def n13 = new ExpressionStatement (n12)
    def n14 = new BlockStatement()
    n14.addStatement(n13)
    def n15 = new EmptyStatement ()
    def n16 = new IfStatement (n6,n14,n15)
            // if (ta == null) { this.openWindow("console") }
    def n17 = new VariableExpression ("ta")
    def n18 = new ConstantExpression ("append")
    def n19 = new ConstantExpression ("")
    def n20 = new ConstantExpression ("")
    def n21 = new VariableExpression ("x")
    def n22 = new GStringExpression ('${x}',[n19,n20],[n21])
            // "${x}"
    def n23 = new ArgumentListExpression ([n22])
    def n24 = new MethodCallExpression (n17,n18,n23)
            // ta.append("${x}")
    def n25 = new ExpressionStatement (n24)
    def n26 = new VariableExpression ("ta")
    def n27 = new ConstantExpression ("getDocument")
    def n28 = ArgumentListExpression.EMPTY_ARGUMENTS
    def n29 = new MethodCallExpression (n26,n27,n28)
            // ta.getDocument()
    def n30 = new ConstantExpression ("getLength")
    def n31 = ArgumentListExpression.EMPTY_ARGUMENTS
    def n32 = new MethodCallExpression (n29,n30,n31)
            // ta.getDocument().getLength()
    def n33 = new VariableExpression ("ta")
    def n34 = new ConstantExpression ("setCaretPosition")
    def n35 = new ArgumentListExpression ([n32])
    def n36 = new MethodCallExpression (n33,n34,n35)
            // ta.setCaretPosition(ta.getDocument().getLength())
    def n37 = new ExpressionStatement (n36)
    def n38 = new BlockStatement()
    n38.addStatement(n16)
    n38.addStatement(n25)
    n38.addStatement(n37)
    def n39 = new MethodNode ("print",PUBLIC_STATIC,
                ClassHelper.make(void,false),
                n1,ClassNode.EMPTY_ARRAY,n38)
    cn.addMethod(n39)
  }

  /**
   * (B) Adds {@code public static void println(Object x)} to {@code cn}. Generated body:
   *   this.print("${x}\n")
   * x is formatted through a GString rather than x.toString(), matching print.
   */
  void add_method_println (ClassNode cn) {
    def n0 = new Parameter (ClassHelper.make(java.lang.Object,false),"x")
    def n1 = [n0] as Parameter[]
    def n2 = new VariableExpression ("this")
    def n3 = new ConstantExpression ("print")
    def n4 = new ConstantExpression ("")
    def n5 = new ConstantExpression ("\n")
    def n6 = new VariableExpression ("x")
    def n7 = new GStringExpression ('${x}\n',[n4,n5],[n6])
            // "${x}\n"
    def n8 = new ArgumentListExpression ([n7])
    def n9 = new MethodCallExpression (n2,n3,n8)
            // this.print("${x}\n")
    def n10 = new ExpressionStatement (n9)
    def n11 = new BlockStatement()
    n11.addStatement(n10)
    def n12 = new MethodNode ("println",PUBLIC_STATIC,
                ClassHelper.make(void,false),n1,ClassNode.EMPTY_ARRAY,n11)
    cn.addMethod(n12)
  }

  /**
   * (B) Adds {@code private static void openWindow(String title)} to {@code cn}.
   * The generated body creates an 800x600 JFrame titled {@code title} with
   * EXIT_ON_CLOSE, assigns a new JTextArea(100, 50) to the field ta (white text
   * on black, font "メイリオ" PLAIN 50), adds it inside a JScrollPane at
   * BorderLayout.CENTER, and makes the frame visible.
   */
  void add_method_openWindow (ClassNode cn) {
    def n0 = new Parameter (ClassHelper.make(String,false),"title")
    def n1 = [n0] as Parameter[]
    def n2 = new VariableExpression ("f")
    def n3 = new Token (Types.EQUAL,"=",-1,-1)
    def n4 = new VariableExpression ("title")
    def n5 = [n4]
    def n6 = new ArgumentListExpression (n5)
    def n7 = new ConstructorCallExpression (
               ClassHelper.make(javax.swing.JFrame,false),n6)
    def n8 = new DeclarationExpression (n2,n3,n7)
             // def f = new javax.swing.JFrame(title)
    def n9 = new ExpressionStatement (n8)
    def n10 = new VariableExpression ("f")
    def n11 = new ConstantExpression ("setSize")
    def n12 = new ConstantExpression (800)
    def n13 = new ConstantExpression (600)
    def n14 = [n12,n13]
    def n15 = new ArgumentListExpression (n14)
    def n16 = new MethodCallExpression (n10,n11,n15)
              // f.setSize(800,600)
    def n17 = new ExpressionStatement (n16)
    def n18 = new VariableExpression ("f")
    def n19 = new ConstantExpression ("setDefaultCloseOperation")
    def n20 = new VariableExpression ("javax")
    def n21 = new ConstantExpression ("swing")
    def n22 = new PropertyExpression (n20,n21)
    def n23 = new ConstantExpression ("JFrame")
    def n24 = new PropertyExpression (n22,n23)
    def n25 = new ConstantExpression ("EXIT_ON_CLOSE")
    def n26 = new PropertyExpression (n24,n25)
    def n27 = [n26]
    def n28 = new ArgumentListExpression (n27)
    def n29 = new MethodCallExpression (n18,n19,n28)
              // f.setDefaultCloseOperation(javax.swing.JFrame.EXIT_ON_CLOSE)
    def n30 = new ExpressionStatement (n29)
    def n31 = new VariableExpression ("ta")
    def n32 = new Token (Types.EQUAL,"=",-1,-1)
    def n33 = new ConstantExpression (100)
    def n34 = new ConstantExpression (50)
    def n35 = [n33,n34]
    def n36 = new ArgumentListExpression (n35)
    def n37 = new ConstructorCallExpression (
                ClassHelper.make(javax.swing.JTextArea,false),n36)
    def n38 = new BinaryExpression (n31,n32,n37)
              // ta = new javax.swing.JTextArea(100,50)
    def n39 = new ExpressionStatement (n38)
    def n40 = new VariableExpression ("ta")
    def n41 = new ConstantExpression ("setForeground")
    def n42 = new VariableExpression ("java")
    def n43 = new ConstantExpression ("awt")
    def n44 = new PropertyExpression (n42,n43)
    def n45 = new ConstantExpression ("Color")
    def n46 = new PropertyExpression (n44,n45)
    def n47 = new ConstantExpression ("white")
    def n48 = new PropertyExpression (n46,n47)
    def n49 = [n48]
    def n50 = new ArgumentListExpression (n49)
    def n51 = new MethodCallExpression (n40,n41,n50)
              // ta.setForeground(java.awt.Color.white)
    def n52 = new ExpressionStatement (n51)
    def n53 = new VariableExpression ("ta")
    def n54 = new ConstantExpression ("setBackground")
    def n55 = new VariableExpression ("java")
    def n56 = new ConstantExpression ("awt")
    def n57 = new PropertyExpression (n55,n56)
    def n58 = new ConstantExpression ("Color")
    def n59 = new PropertyExpression (n57,n58)
    def n60 = new ConstantExpression ("black")
    def n61 = new PropertyExpression (n59,n60)
    def n62 = [n61]
    def n63 = new ArgumentListExpression (n62)
    def n64 = new MethodCallExpression (n53,n54,n63)
              // ta.setBackground(java.awt.Color.black)
    def n65 = new ExpressionStatement (n64)
    def n66 = new VariableExpression ("ta")
    def n67 = new ConstantExpression ("setFont")
    def n68 = new ConstantExpression ("メイリオ")
    def n69 = new VariableExpression ("java")
    def n70 = new ConstantExpression ("awt")
    def n71 = new PropertyExpression (n69,n70)
    def n72 = new ConstantExpression ("Font")
    def n73 = new PropertyExpression (n71,n72)
    def n74 = new ConstantExpression ("PLAIN")
    def n75 = new PropertyExpression (n73,n74)
    def n76 = new ConstantExpression (50)
    def n77 = [n68,n75,n76]
    def n78 = new ArgumentListExpression (n77)
    def n79 = new ConstructorCallExpression (
                ClassHelper.make(java.awt.Font,false),n78)
    def n80 = [n79]
    def n81 = new ArgumentListExpression (n80)
    def n82 = new MethodCallExpression (n66,n67,n81)
           // ta.setFont(new java.awt.Font("メイリオ",java.awt.Font.PLAIN,50))
    def n83 = new ExpressionStatement (n82)
    def n84 = new VariableExpression ("f")
    def n85 = new ConstantExpression ("getContentPane")
    def n87 = ArgumentListExpression.EMPTY_ARGUMENTS
    def n88 = new MethodCallExpression (n84,n85,n87)
              // f.getContentPane()
    def n89 = new ConstantExpression ("add")
    def n90 = new VariableExpression ("ta")
    def n91 = [n90]
    def n92 = new ArgumentListExpression (n91)
    def n93 = new ConstructorCallExpression (
                ClassHelper.make(javax.swing.JScrollPane,false),n92)
    def n94 = new VariableExpression ("java")
    def n95 = new ConstantExpression ("awt")
    def n96 = new PropertyExpression (n94,n95)
    def n97 = new ConstantExpression ("BorderLayout")
    def n98 = new PropertyExpression (n96,n97)
    def n99 = new ConstantExpression ("CENTER")
    def n100 = new PropertyExpression (n98,n99)
    def n101 = [n93,n100]
    def n102 = new ArgumentListExpression (n101)
    def n103 = new MethodCallExpression (n88,n89,n102)
         // f.getContentPane()
         // .add(new javax.swing.JScrollPane(ta), java.awt.BorderLayout.CENTER)
    def n104 = new ExpressionStatement (n103)
    def n105 = new VariableExpression ("f")
    def n106 = new ConstantExpression ("setVisible")
    def n107 = ConstantExpression.TRUE
    def n108 = [n107]
    def n109 = new ArgumentListExpression (n108)
    def n110 = new MethodCallExpression (n105,n106,n109)
               // f.setVisible(true)
    def n111 = new ExpressionStatement (n110)
    def n113 = new BlockStatement()
    n113.addStatement(n9)
    n113.addStatement(n17)
    n113.addStatement(n30)
    n113.addStatement(n39)
    n113.addStatement(n52)
    n113.addStatement(n65)
    n113.addStatement(n83)
    n113.addStatement(n104)
    n113.addStatement(n111)
    def n114 = new MethodNode ("openWindow",PRIVATE_STATIC,
                 ClassHelper.make(void,false),n1,ClassNode.EMPTY_ARRAY,n113)
    cn.addMethod(n114)
  }

  /**
   * (C) Replaces the body of {@code mn} with:
   *   try { <original body> }
   *   catch (Exception e) {
   *     this.println(e)
   *     for (StackTraceElement ste in e.getStackTrace()) { this.println("\tat ${ste}") }
   *   }
   */
  void insert_try_catch_to_method (MethodNode mn) {

    def n9 = new Parameter (ClassHelper.make(java.lang.Exception,false),"e")
    def n10 = new VariableExpression ("this")
    def n11 = new ConstantExpression ("println")
    def n12 = new VariableExpression ("e")
    def n13 = [n12]
    def n14 = new ArgumentListExpression (n13)
    def n15 = new MethodCallExpression (n10,n11,n14)
            // this.println(e)
    def n16 = new ExpressionStatement (n15)
    def n17 = new Parameter (
                ClassHelper.make(java.lang.StackTraceElement,false),"ste")
    def n18 = new VariableExpression ("e")
    def n19 = new ConstantExpression ("getStackTrace")
    def n21 = ArgumentListExpression.EMPTY_ARGUMENTS
    def n22 = new MethodCallExpression (n18,n19,n21)
            // e.getStackTrace()
    def n23 = new VariableExpression ("this")
    def n24 = new ConstantExpression ("println")
    def n25 = new ConstantExpression ("\tat ")
    def n26 = new ConstantExpression ("")
    def n27 = [n25,n26]
    def n28 = new VariableExpression ("ste")
    def n29 = [n28]
    def n30 = new GStringExpression ("",n27,n29)
            // "\tat ${ste}"
    def n31 = [n30]
    def n32 = new ArgumentListExpression (n31)
    def n33 = new MethodCallExpression (n23,n24,n32)
    def n34 = new ExpressionStatement (n33)
    def n36 = new BlockStatement()
    n36.addStatement(n34)
    def n37 = new ForStatement (n17,n22,n36)
    def n39 = new BlockStatement()
    n39.addStatement(n16)
    n39.addStatement(n37)
    def n40 = new CatchStatement (n9,n39)
    def n42 = new EmptyStatement ()
    def n43 = new TryCatchStatement (mn.getCode(), n42) // original body becomes the try block
    n43.addCatch(n40)
    def n45 = new BlockStatement()
    n45.addStatement(n43)

    mn.setCode(n45)
  }

} // End of class

次のような ASTTransformation 用サービスプロバイダの構成情報ファイルを用意して META-INF/services/ 配下に設置すれば、この AST 変換を使うことができる。

ファイル名 org.codehaus.groovy.transform.ASTTransformation
ファイルの中身 bunji.MyASTTransformation2

●表示例1

前回の fib.groovy

// fib.groovy
int fib (int n) {
  if (n < 2) {
    return n
  }
  return fib(n - 2) + fib(n - 1)
}
def fibs = (0..10).collect { fib(it) }
println (fibs.join(" "))

これを AST変換した際のスクリーンショット

f:id:bunji2:20150329125149p:plain

●表示例2

日本語の表示。

// hello.groovy
final String msg = "Hello, world!"
println(msg)
println("(重要なことなのでもう一度)")
println(msg.split("").join(" "))

f:id:bunji2:20150329124748p:plain

●表示例3

例外の表示。

// e.groovy
println (100/0)  // deliberately divides by zero so the injected catch block prints the exception and its stack trace

f:id:bunji2:20150329125404p:plain

こんな感じです。

*1:AstBuilder を使えないか試してみたが、使いにくく感じた。