タコの足シミュレーター

タコの足で機械学習できると聞いて、タコの足のシミュレーター書いてみた。
のですが、 あんまりおもしろい感じにならないので 多分 一味足りないと思われます
自分用のメモとして保存

mnistの読み込みモジュールは こちらのを少し改造
https://endoyuta.com/2017/01/12/


import numpy as np
import cv2
import time
import mnist
#GLOBAL CONSTANT
GRAVITY=0.1
DT=1
isDISP=True
SPRING=0.01
ACC=1
DECAY_ACC=0.99
DUMPER=0.1
VMOMENTUM=0.7
XMOMENTUM=0.1
FADE=0.99
DELAY=True
NSEG=8
#値が変
class Seg(object):

def __init__(self,delay=False,acc=1.0,spring=0,dumper=0,gravity=0,
fade=1,
vmomentum=0,xmomentum=0):
self.parent=None
self.child=None
self.acc=acc
self.vmomentum=vmomentum
self.xmomentum=xmomentum
self.spring=spring
self.dumper=dumper
self.gravity=gravity
self.delay=delay
self.fade=fade
self.init()
self.fwd=self.fwd2 if delay else self.fwd1
def init(self):
self.value=0
self.x=0
self.v=0
self.rot_local=0
self.rot_global=0
if self.child:
self.child.init()
def rots(self,rs=None,gs=None):
if rs is None:
rs=[]
gs=[]
rs.append(self.rot_local)
gs.append(self.rot_global)
if self.child :
self.child.rots(rs,gs)

return rs,gs

def appendChild(self,child):
self.child=child
self.child.parent=self

def fwd1(self,x):
#no delay
self.fwd4(x)

if self.child:
self.child.fwd(self.fade*x)
self.x=x
def fwd2(self,x):
#delay
newx=x
x=self.x
self.fwd4(x)
if self.child:
self.child.fwd(self.fade*self.x)
self.x=newx

def fwd4(self,x):

x=x+x*self.x*self.xmomentum
if self.parent:
r0=self.parent.rot_global
else:
r0=0
v1=self.acc*x
vglobal=((r0+self.rot_local+v1*DT)-self.rot_global)/DT
d=vglobal*self.dumper*0.5 #グローバルな移動速度で抵抗が決まる
#d=v1*self.dumper
g=(self.rot_global-r0)*self.gravity #val0の真下に垂れ下がる
s=self.spring*self.rot_local

d=clamp(d)
g=clamp(g)
s=clamp(s)
v=(v1-d-g-s)+self.vmomentum*self.v
v=clamp(v,10)
self.rot_local+=v*DT
self.rot_global=r0+self.rot_local
self.v=v

def clamp(x,mx=1):
a=np.absolute(x)
a=a if a<mx else mx
return np.sign(x)*a
W=300
H=300
cvs=np.zeros((H,W),dtype=np.uint8)

def disp(wname,gs):
x0=int(W/2)
y0=10
thickness=1
color=255
sc=2
cvs[:,:]=0
for r in gs:
x=int(x0+r*sc)
y=int(y0+15)

cv2.line(cvs,(x,y),(x0,y0),color,thickness)
x0,y0=x,y
cv2.imshow(wname,cvs)
cv2.waitKey(1)
time.sleep(0.1)

def drnd():
return 0.9+np.random.randn(1)*0.1

#################################
gravity=GRAVITY*drnd()
delay=DELAY

segroot=Seg(acc=0,delay=False,spring=0,dumper=0,gravity=0,
xmomentum=0,vmomentum=0,fade=1) #empty
seg0=segroot
for i in range(NSEG):
spring=SPRING*drnd()
acc=ACC*drnd()
decay_acc=DECAY_ACC*drnd()
dumper=DUMPER*drnd()
vmomentum=VMOMENTUM*drnd()
xmomentum=XMOMENTUM*(10-i)/10
fade=FADE*drnd()

seg=Seg(acc=acc,delay=delay,spring=spring,dumper=dumper,gravity=gravity,
xmomentum=xmomentum, vmomentum=vmomentum,fade=fade)
acc*=decay_acc
seg0.appendChild(seg)
seg0=seg

ims ,labs= mnist.load_mnist(train=True)
datlen=labs.shape[0]
idx=np.random.randint(datlen*10)%10
img1 = ims[idx].reshape(28, 28)
lab1=labs[idx]
cv2.imshow("mnist",img1)
x,y,w,h = cv2.boundingRect(img1)
dat=img1[y:y+h,x:x+w].flatten()
#dat=img1.flatten()
if 0:
dat=img1.flatten()
idx=np.where(dat>0)[0]
dat=dat[idx[0]:idx[-1]+1]
print("label",lab1)
print("H")
for i,f in enumerate(dat):
f=1 if f>0 else -0.001
segroot.fwd(f)
rs,gs=segroot.rots()
gs=[x/(n+1) for n,x in enumerate(gs)]
if isDISP and i%20==0:
disp("H",gs)
#print("value",segroot.rots())
disp("H",gs)
print("V")
img2=img1.T
dat=img2.flatten()


x,y,w,h = cv2.boundingRect(img2)
dat=img2[y:y+h,x:x+w].flatten()
if 0:
idx=np.where(dat>0)[0]
dat=dat[idx[0]:idx[-1]+1]
segroot.init()
for i,f in enumerate(dat):

f=1 if f>0 else -0.001
segroot.fwd(f)
rs,gs=segroot.rots()
gs=[x/(n+1) for n,x in enumerate(gs)]
if isDISP and i%20==0:
disp("V",gs)
#print("value",segroot.rots())
disp("V",gs)
y=gs
print ("y",y)