SympyからC++/Eigenな運動方程式を生成する

Posted on 7/15/2020
このエントリーをはてなブックマークに追加

ちょっと複雑な運動方程式をCやPythonコード(numba)に落とすとき、Sympyが便利です。Sympyのコード生成にちょっと癖があったのでメモ代わりに残します。 例として倒立振子をモデルにした。

モデル

欲しい式

非線形な運動方程式に対して、

\[ \dot{q} = f(q, u) \]

テーラー展開で一次線形近似した式の赤いところと青いところを計算してくれる式が欲しい。

\[ \dot{q} = \color{red}{f(q_0, u_0)} + \color{blue}{\frac{\partial f}{\partial q} \Big\vert_{q_0 u_0}} q +\color{blue}{ \frac{\partial f}{\partial u}\Big\vert_{q_0 u_0}} u \]

以降簡単のため、

\[ \begin{array}{rl} A_c(q_0, u_0) &= \frac{\partial f}{\partial q} \Big\vert_{q_0 u_0} \\ B_c(q_0, u_0) &= \frac{\partial f}{\partial u}\Big\vert_{q_0 u_0} \\ g(q_0, u_0) &= f(q_0, u_0) \end{array} \]

運動方程式

一般化座標 $ q $を使って、カートの駆動力 $ u $, カート質量 $ M $, 振り子長 $ l $, 振り子先端質量 $ m $ について運動方程式を立てます。

\[ q = [x, \theta, \dot{x}, \dot{\theta}] \]

あとの導出方法等は他所様の記事を引用して終わります

(7節まで、8節はスキップ. あと回転座標のとり方が若干違う)

Sympyモデルの準備

さて、Qiitaの記事ではラグランジアンを使っていましたが、別ルートで導出したので、こちらになります。

import sympy as sy
class cart_pole():
    def gen_rhe_sympy(self):
        g = sy.symbols('g')
        l = sy.symbols('l')
        M = sy.symbols('M')
        m = sy.symbols('m')

        q  = sy.symbols('q:{0}'.format(4))
        qd = q[2:4]
        u  = sy.symbols('u')
        
        I = sy.Matrix([[1, 0, 0, 0], 
                      [0, 1, 0, 0], 
                      [0, 0, M + m, l*m*sy.cos(q[1])], 
                      [0, 0, l*m*sy.cos(q[1]), l**2*m]])
        f = sy.Matrix([
                      qd[0], 
                      qd[1],
                      l*m*sy.sin(q[1])*qd[1]**2 + u,
                      -g*l*m*sy.sin(q[1])])
        return sy.simplify(I.inv()*f)
    
    def gen_lmodel(self):
        mat = self.gen_rhe_sympy()
        q = sy.symbols('q:{0}'.format(4))
        u = sy.symbols('u')
        
        A = mat.jacobian(q)
        #B = mat.jacobian(u)
        B = mat.diff(u)
        
        return A,B

Python

Pythonの場合はlamdifyを使って

c = CartPole()
q = sy.symbols('q:{0}'.format(4))
u = sy.symbols('u')
# calc_rhe(q, u) -> np.array
calc_rhe = sy.lambdify([q,u], c.gen_rhe_sympy())

などとしたほうが自然ですが、numbaに放り込みたいなどの事情があることもあるでしょう。そこでリファレンス通りにやってみます

from sympy.printing.pycode import pycode
c = CartPole()
pycode(c.gen_rhe_sympy())

すると、あんまりありがたくない形の出力が出ます。

'ImmutableDenseMatrix([[q2], [q3], [((1/2)*g*m*math.sin(2*q1) + l*m*q3**2*math.sin(q1) + u)/(M + m*math.sin(q1)**2)], [-(g*(M + m)*math.sin(q1) + (l*m*q3**2*math.sin(q1) + u)*math.cos(q1))/(l*(M + m*math.sin(q1)**2))]])'
  • ImmutableDenseMatrixというSympyの型そのままの出力
  • sin/cosがmath? numpyがいいな!
  • q2, q3じゃなくてq[2], q[3]と出して欲しい


さらに、リファレンスには使い方が詳しくない裏メニューのNumPyPrinterを使うとこうなります

from sympy.printing.pycode import NumPyPrinter
c = cart_pole()
NumPyPrinter().doprint(c.gen_rhe_sympy())

出力

'numpy.array([[q2], [q3], [((1/2)*g*m*numpy.sin(2*q1) + l*m*q3**2*numpy.sin(q1) + u)/(M + m*numpy.sin(q1)**2)], [-(g*(M + m)*numpy.sin(q1) + (l*m*q3**2*numpy.sin(q1) + u)*numpy.cos(q1))/(l*(M + m*numpy.sin(q1)**2))]])'

かなり良くなりましたが、

  • q2, q3じゃなくてq[2], q[3]と出して欲しい

が残ります。掘ってみるとSympyの立式で

q = sy.symbols('q:{0}'.format(4))

と定義すると、q[0] が q0 というsympyシンボルとして定義されているところから来ている様子です。

Matrixやindexingを駆使するとできるようなことも書いてある気がしないでもないですが、面倒なのでこうしました。

from sympy.printing.pycode import NumPyPrinter
c = cart_pole()
class NumPyPrinterR(NumPyPrinter):
  def _print_Symbol(self, expr):
      name = super(NumPyPrinter, self)._print_Symbol(expr)

      # 変数名が他とかぶらないのでマッチングせずに頭文字だけで判断
      if name[0] == 'q':
          name = name[0] + '[' + name[1] + ']'
      elif name[0] == 'u':
          name = 'u[0]'
      return name
NumPyPrinterR().doprint(np.squeeze(c.gen_rhe_sympy()))

これでめでたし。

C++11

Pythonで検証したあとは同じようにするだけです。C++ではEigenを使いたいのとImmutableDenseMatrixの処理が定義されていないのでもう一捻りします。

別途

Eigen::Matrix<double, 4, 1> g;
Eigen::Matrix<double, 4, 4> A;
Eigen::Matrix<double, 4, 1> B;

と定義されていることとしましょう

c = cart_pole()
class CXX11CodePrinterR(CXX11CodePrinter):
    def _print_Symbol(self, expr):
        name = super(CXX11CodePrinter, self)._print_Symbol(expr)
        
        if name[0] == 'q':
            name = name[0] + '(' + name[1] + ', 0)'
        elif name[0] == 'u':
            name = 'u(0, 0)'
        return name

print('g << ')
for expr in np.squeeze(c.gen_rhe_sympy()).tolist():
    print(CXX11CodePrinterR().doprint(expr), end=',\n')

A, B = np.squeeze(c.gen_lmodel())

print('\nA << ', end='')
for expr in A:
    print(CXX11CodePrinterR().doprint(expr), end=',\n')

print('\nB << ', end='')
for expr in B:
    print(CXX11CodePrinterR().doprint(expr), end=',\n')

あとはセミコロンとか処理してお終い。

その他

Sympyには他にも

  • julia
  • Rust

などのプリンタも用意されているそうです!

0 件のコメント:

コメントを投稿