python 関数のオーバーロード (型 引数の数編)

pythonにて、 JAVAとかC++風味に、関数の引数の型とか数によって、呼び出す関数の本体を切り替えるためのデコレータ書いてみました

先に、サンプルを兼ねたテストコード(長くなったのでさわりだけ)
Py2 Py3 ともに動作確認済み

同じ addという名前の関数ですが、引数の数や型によって、呼び出す関数本体を切り替えています。
annotationsは、py3風のアノテーションをpy2にて行うためのデコレータ。
(py3でも動作可能)

クラス本体は下のほうにあります
ライセンスはパブリックドメイン扱いでお願いします。


# -*- coding: utf8 -*-
from __future__ import print_function
from typeoverload import TypeOverload,annotations
import sys
if sys.version_info[0]==3:
long=int

if __name__=="__main__":

if True or sys.version_info[0]==2:
#関数オーバーロード デコレータのインスタンス生成
add=TypeOverload("add")
#
#
@add.typeoverload
@annotations(x=int,returns=int) # 引数 1個 整数のみ
def iadd1(x):
y=x+1
return y
#
@add.typeoverload #引数2個 int long float どれでも
@annotations(x=(int,long,float),y=(int,long,float),returns=(int,long,float))
def fadd2(x,y):
z=x+y
return z
#
assert add(100)==101
try:
b=add(2.0) #floatだとエラーになる
raise Exception
except TypeError as e:
assert e.args[0]=="add (2.0,)"
assert add(1,2)==3
assert add(4,5)==9
try:
b=add()  #引数 0個もエラー
raise Exception
except TypeError as e:
assert e.args[0]=="add ()"
assert add(1.0,2.0)==3.0
print ("Py2 done")

#-------------------------

if sys.version_info[0]==3:
txt="""
if __name__=="__main__":

@add.typeoverload
def dummy(x:(int,float),y:(int,float),z:(int,float)) ->float:
return x+y+z
assert add(1.0, 2,3)==6.0
try:        #ついでなので戻り値の型のチェックしてます
add(1,2,3)==6  # 戻り値が整数なのでエラーになるのを確認してる
except TypeError as e:
print(e.args)

print ("Py3 done")

#-------------------------
"""
exec(txt)

typepverload.py

そんなにたいした大きさのクラスじゃないのだけれど、いきなり完成形を見るとよくわかりませんね


# -*- coding: utf8 -*-
""" function override with arg type. PUBLIC DOMAIN """
from __future__ import print_function
#Py2用  Py3の関数 引数アノテーション互換品
def annotations(**kwargs):
__annotations__={}
__annotations__.update(kwargs)
if "returns" in kwargs:
__annotations__["return"]=__annotations__["returns"]
__annotations__.pop("returns")

def _(f):
f.__annotations__= getattr(f,"__annotations__",{})
f.__annotations__.update(__annotations__)
n_an=(len(f.__annotations__.keys())-
(1 if "return" in __annotations__ else 0))
assert (f.__code__.co_argcount==n_an)
for vname in f.__code__.co_varnames[:n_an]:
assert (vname in __annotations__)
return f
return _

def combinations( nested ,fix=[]):
if nested==[]:
yield tuple(fix)
else:
car=nested[0]
cdr=nested[1:]
if not isinstance(car,(list,tuple)):
car=[car]
for i in car:
fix2=fix+[i]
for j in combinations(cdr,fix2):
yield j

#クラス名とかメソッド(デコレータ名)は、微妙にダサいので変えたほうがいいかも
class TypeOverload(dict):
def __init__(self,fname):
self.fname=fname
def typeoverload(self,f):
annotations=f.__annotations__ #check exists
argcount=f.__code__.co_argcount
varnames=f.__code__.co_varnames[:argcount]
annolist=[annotations[v] for v in varnames]
f.__name__=self.fname
for c in combinations(annolist):
self[c]=f

if "return" in annotations:
ret_type=annotations["return"]
def _(*args):
ret=f(*args)
assert isinstance(ret, ret_type),TypeError(
"%s [%s]"%(self.fname,c))
return ret
f_ret=_
else :
f_ret=f_ret
return f_ret

def __call__(self,*args):
argtype=tuple([type(a) for a in args])
argcount=len(args)
#完全一致
if argtype in self:
f=self[argtype]
else:
#サブクラス用
for argtype in self.keys():
if ( argcount==len(argtype)and
all([isinstance(a,t) for a,t in zip(args,argtype)])):
f= self[argtype]
break
else:
raise TypeError("%s %s"%(self.fname,args))
return f(*args)