$f(x)=x$
测试数据集:在上式基础上叠加"均匀×高斯"扰动 `rd.uniform(-4,4)*rd.gauss(0,5)`
import random as rd
import cma
import math
from multiprocessing import Pool
%pylab
from __future__ import division
# Default size (width, height in inches) for every figure drawn below.
plt.rcParams['figure.figsize'] = (16, 9)
# Build the one-dimensional dataset.
def linefunc(x):
    """Return the noisy sample y = x + noise at ``x``.

    The noise is ``rd.uniform(-4, 4) * rd.gauss(0, 5)``; the uniform draw is
    taken before the Gaussian draw.
    """
    clean = np.poly1d([1, 0])(x)
    scale = rd.uniform(-4, 4)
    return clean + scale * rd.gauss(0, 5)
E = math.exp
DIM = 4  # two line segments x (slope, intercept)
PI = math.pi
# Materialise X and Y as lists. On Python 3, range()/map() are lazy:
# max(Y) would exhaust a map object before min(Y) runs, and draw()
# concatenates X with a list.
X = list(range(100))
Y = [linefunc(x) for x in X]
RANGE = max(Y) - min(Y)  # spread of the data, used to normalise errors
ALPHA = 0.01  # normalised jump above which the E(a - ALPHA) - 1 penalty turns positive
BETA = PI / 18.0  # tolerated change of slope angle: 10 degrees
def realLineFunc(param):
    """Return the two-segment piecewise-linear function encoded by ``param``.

    For x < 50 the line ``param[0] * x + param[1]`` is used, otherwise
    ``param[2] * x + param[3]``. Coefficients are read from ``param`` each
    time the returned function is called.
    """
    def f(x):
        if x < 50:
            slope, intercept = param[0], param[1]
        else:
            slope, intercept = param[2], param[3]
        return np.poly1d([slope, intercept])(x)

    return f
def fittingFunc(param):
    """Wrap the coefficient sequence ``param`` (highest power first) as a numpy polynomial."""
    polynomial = np.poly1d(param)
    return polynomial
def mse(x, y, param):
    """Return the root-mean-square residual of the two-segment line model.

    Despite the name, the square root is taken, so this is an RMSE. Samples
    with index < 50 are scored against the line ``param[:2]``, samples 50..99
    against ``param[2:4]``. Exactly the first 100 samples of ``x``/``y`` are
    read, and the sum of squares is divided by 100.
    """
    lines = (fittingFunc(param[:2]), fittingFunc(param[2:4]))
    total = 0
    for idx in range(100):
        line = lines[0] if idx < 50 else lines[1]
        total += (y[idx] - line(x[idx])) ** 2
    return math.sqrt(total / 100.0)
# Objective with a penalty on the jump (discontinuity) at the breakpoint.
def evalfunc1(param):
    """Return normalised RMSE plus a penalty on the jump between segments at x = 50.

    The jump ``|f1(50) - f2(50)|`` is divided by RANGE; the term
    ``E(jump - ALPHA) - 1`` is negative below ALPHA and grows exponentially above it.
    """
    left, right = fittingFunc(param[:2]), fittingFunc(param[2:4])
    fit_err = mse(X, Y, param) / RANGE
    jump = abs(left(50) - right(50)) / RANGE
    return fit_err + E(jump - ALPHA) - 1
# Objective with a penalty on the change of slope (first derivative) at the breakpoint.
def evalfunc2(param):
    """Return normalised RMSE plus a slope-angle penalty between the two segments.

    The angle gap is ``|atan(param[0]) - atan(param[2])|``; the penalty is
    ``E((gap - BETA) / (10 * E(1))) - 1``.
    """
    fit_err = mse(X, Y, param) / RANGE
    angle_gap = abs(math.atan(param[0]) - math.atan(param[2]))
    scale = 10 * E(1)
    return fit_err + (E((angle_gap - BETA) / scale) - 1)
# Objective combining both penalties above (jump and slope change).
def evalfunc3(param):
    """Return normalised RMSE plus the slope-angle penalty and the jump penalty at x = 50."""
    left = fittingFunc(param[:2])
    right = fittingFunc(param[2:4])
    fit_err = mse(X, Y, param) / RANGE
    jump = abs(left(50) - right(50)) / RANGE
    angle_gap = abs(math.atan(param[0]) - math.atan(param[2]))
    slope_term = E((angle_gap - BETA) / (10 * math.e)) - 1
    return fit_err + slope_term + E(jump - ALPHA) - 1
# Baseline objective: fit error only, no continuity constraints.
def evalfunc(param):
    """Return the model RMSE on (X, Y) normalised by the data range RANGE."""
    return mse(X, Y, param) / RANGE
def cmaUser(func, x0=None, sigma0=0.3, popsize=15):
    """Minimise ``func`` with CMA-ES, scoring each population in a process pool.

    Args:
        func: objective taking a parameter vector and returning a float. It is
            shipped to worker processes by ``Pool.map``, so it must be
            picklable (a module-level function).
        x0: starting point; defaults to ``DIM * [1]`` (DIM is read at call time).
        sigma0: initial CMA-ES step size.
        popsize: number of candidate solutions per generation.

    Returns:
        The best parameter vector found. The best objective value is printed.
    """
    if x0 is None:
        x0 = DIM * [1]
    pool = Pool()
    try:
        es = cma.CMAEvolutionStrategy(x0, sigma0, {'popsize': popsize})
        while not es.stop():
            solutions = es.ask()
            es.tell(solutions, pool.map(func, solutions))
    finally:
        # Release the workers; otherwise every call (draw() makes four)
        # leaves a full pool of idle processes behind.
        pool.terminate()
        pool.join()
    # Older cma releases expose result() as a method, newer ones as a property.
    result = es.result
    if callable(result):
        result = result()
    print('eval value:%s' % result[1])
    return result[0]
def draw(title, real_func=None):
    """Fit the two-segment model with four objectives and plot the fits.

    Figure 1 shows the data with the fits from ``evalfunc`` (no constraint),
    ``evalfunc1`` (jump penalty), ``evalfunc2`` (slope penalty) and
    ``evalfunc3`` (both). Figure 2 compares the ``evalfunc3`` fit with the
    noise-free curve.

    Args:
        title: title of figure 1.
        real_func: noise-free reference curve for figure 2. Defaults to
            ``y = x``, the noise-free part of ``linefunc``.
    """
    if real_func is None:
        # linefunc samples around np.poly1d([1, 0]), i.e. y = x.
        real_func = np.poly1d([1, 0])
    res = cmaUser(evalfunc)
    res1 = cmaUser(evalfunc1)
    res2 = cmaUser(evalfunc2)
    res3 = cmaUser(evalfunc3)
    # Extra sample just left of the breakpoint at x = 50 so the jump between
    # the two segments is drawn as a near-vertical line.
    X1 = sorted(list(X) + [49.99])
    fits = [
        (res, "red", "-", "interval"),
        (res1, "blue", "-", "continuity1"),
        (res2, "green", "-", "continuity2"),
        (res3, "black", "--", "continuity3"),
    ]

    plt.figure(1)
    plt.plot(X, Y, 'b.', alpha=0.6, label="measure point")
    for params, colour, style, label in fits:
        model = realLineFunc(params)
        plt.plot(X1, [model(x) for x in X1], c=colour, lw=2, ls=style,
                 alpha=0.7, label=label)
    plt.legend(loc='best')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title(title)
    plt.show()

    # Label and show figure 2 too; otherwise it stays pending until some
    # later plt.show() call.
    best = realLineFunc(res3)
    plt.figure(2)
    plt.plot(X, Y, 'b.', alpha=0.6, label="measure point")
    plt.plot(X1, [real_func(x) for x in X1], c="cyan", lw=2, ls="-",
             alpha=0.7, label="real function")
    plt.plot(X1, [best(x) for x in X1], c="black", lw=2, ls="--",
             alpha=0.7, label="continuity3")
    plt.legend(loc='best')
    plt.title("Compare best fitting curve with real curve ")
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()
# Run the four CMA-ES fits of the two-segment line model and plot them.
draw('fitting of interval line')
# Build the second dataset (noisy quadratic, see linefunc2 below).
# DIM = 14: seven linear segments with two coefficients each.
E, DIM, PI = math.exp, 14, math.pi
def linefunc2(x):
    """Return the noisy sample y = 100 - x**2 + noise at ``x``.

    The noise is ``rd.uniform(-3, 3) * rd.gauss(0, 4)``; the uniform draw is
    taken before the Gaussian draw. (The original ``++`` was a typo that
    parsed as a harmless unary plus; the value is unchanged.)
    """
    return 100 - x ** 2 + rd.uniform(-3, 3) * rd.gauss(0, 4)
X = np.linspace(-7, 7, 301)
# A list, not map(): on Python 3 map() is lazy and max(Y) would exhaust it
# before min(Y) runs.
Y = [linefunc2(x) for x in X]
RANGE = max(Y) - min(Y)  # spread of the data, used to normalise errors
ALPHA = 0.01  # normalised jump above which the E(gap - ALPHA) - 1 penalty turns positive
BETA = PI / 4.0  # tolerated change of slope angle: 45 degrees
scatter(X, Y)
将区间 $[-7,7]$ 以 $-5,-3,-1,1,3,5$ 为分段点划分为七段,每段用直线拟合,共 14 个参数,由 CMA-ES 优化
def realLineFunc(param):
    """Return the seven-segment piecewise-linear function encoded by ``param``.

    Segment k (k = 0..6) is the line ``param[2*k] * x + param[2*k + 1]``.
    The breakpoints are -5, -3, -1, 1, 3, 5; a point equal to a breakpoint
    belongs to the segment on its right.
    """
    breakpoints = (-5, -3, -1, 1, 3, 5)

    def f(x):
        seg = len(breakpoints)
        for k, edge in enumerate(breakpoints):
            if x < edge:
                seg = k
                break
        return np.poly1d([param[2 * seg], param[2 * seg + 1]])(x)

    return f
def mse(x, y, param):
    """Return the root-mean-square residual of the seven-segment linear model.

    Despite the name, the square root is taken, so this is an RMSE. Every
    sample is scored against the segment its x-value falls into, using the
    same breakpoints as ``realLineFunc`` (-5, -3, -1, 1, 3, 5; a breakpoint
    belongs to the segment on its right). Segment k is the line
    ``param[2*k] * x + param[2*k + 1]``.

    The previous version scored only samples 0..69 (by index, ignoring the
    rest of the 301 samples) and divided by a fixed 200.

    Raises:
        ValueError: if ``x`` is empty.
    """
    breakpoints = (-5, -3, -1, 1, 3, 5)
    n = len(x)
    if n == 0:
        raise ValueError("mse() needs at least one sample")
    lines = [np.poly1d(param[2 * k:2 * k + 2]) for k in range(len(breakpoints) + 1)]
    total = 0.0
    for xi, yi in zip(x, y):
        seg = len(breakpoints)
        for k, edge in enumerate(breakpoints):
            if xi < edge:
                seg = k
                break
        total += (yi - lines[seg](xi)) ** 2
    return math.sqrt(total / n)
# Objective with a penalty on jumps (discontinuities) at the segment boundaries.
def evalfunc1(param):
    """Return normalised RMSE plus a jump penalty at each segment boundary.

    At each boundary x0 (the breakpoints used by ``realLineFunc``) the gap
    between the two neighbouring lines is normalised by RANGE and adds
    ``E(gap - ALPHA) - 1``.

    The boundaries were previously taken as ``X[(j/2)*10]``. That is a float
    index under true division, which numpy rejects, and with 301 samples it
    does not land on the boundaries anyway.
    """
    s = mse(X, Y, param) / RANGE
    breakpoints = (-5, -3, -1, 1, 3, 5)
    a = 0
    for k, x0 in enumerate(breakpoints):
        j = 2 * (k + 1)  # index of the right-hand segment's first coefficient
        left = fittingFunc(param[j - 2:j])
        right = fittingFunc(param[j:j + 2])
        a += E(abs(left(x0) - right(x0)) / RANGE - ALPHA) - 1
    return s + a
# Objective with a penalty on slope (first-derivative) changes at the boundaries.
def evalfunc2(param):
    """Return normalised RMSE plus a slope-angle penalty for each pair of neighbouring lines.

    For neighbouring slopes ``param[j-2]`` and ``param[j]`` the angle
    difference ``|atan(.) - atan(.)|`` contributes
    ``(E(angle - BETA) - 1) / (100 * e)``.
    """
    fit_err = mse(X, Y, param) / RANGE
    penalty = 0
    # NOTE(review): evalfunc3 puts the 1/(100e) factor inside the exponential
    # instead of dividing the whole term by it; confirm which form is intended.
    for right in range(2, 14, 2):
        angle = abs(math.atan(param[right - 2]) - math.atan(param[right]))
        penalty += (E(angle - BETA) - 1) / (100 * math.e)
    return fit_err + penalty
# Objective combining both penalties: jumps and slope changes at the boundaries.
def evalfunc3(param):
    """Return normalised RMSE plus jump and slope-change penalties.

    At each boundary x0 (the breakpoints used by ``realLineFunc``):

    * the gap between the neighbouring lines, normalised by RANGE, adds
      ``E(gap - ALPHA) - 1``;
    * the slope-angle difference of the neighbouring lines adds
      ``E((angle - BETA) / (100 * e)) - 1``.

    The boundaries were previously taken as ``X[(j/2)*10]``. That is a float
    index under true division, which numpy rejects, and with 301 samples it
    does not land on the boundaries anyway.
    """
    s = mse(X, Y, param) / RANGE
    breakpoints = (-5, -3, -1, 1, 3, 5)
    a = 0
    b = 0
    for k, x0 in enumerate(breakpoints):
        j = 2 * (k + 1)  # index of the right-hand segment's first coefficient
        left = fittingFunc(param[j - 2:j])
        right = fittingFunc(param[j:j + 2])
        a += E(abs(left(x0) - right(x0)) / RANGE - ALPHA) - 1
        angle = abs(math.atan(param[j - 2]) - math.atan(param[j]))
        b += E((angle - BETA) / (100 * math.e)) - 1
    return s + a + b
# Baseline objective: fit error only, no continuity constraints.
def evalfunc(param):
    """Return the model RMSE on (X, Y) normalised by the data range RANGE."""
    return mse(X, Y, param) / RANGE
def draw(title):
    """Fit the seven-segment linear model with four objectives and plot the fits.

    Figure 1 shows the data with the fits from ``evalfunc``, ``evalfunc1``,
    ``evalfunc2`` and ``evalfunc3``. Figure 2 compares the ``evalfunc3`` fit
    with the noise-free curve ``100 - x**2``.

    Model values are built as lists rather than ``map`` objects so the plots
    also work on Python 3.
    """
    res = cmaUser(evalfunc)
    res1 = cmaUser(evalfunc1)
    res2 = cmaUser(evalfunc2)
    res3 = cmaUser(evalfunc3)
    # Extra samples just left of each breakpoint so jumps between segments
    # are drawn as near-vertical lines.
    X1 = np.insert(X, 0, [-5.01, -3.01, -1.01, 0.99, 2.99, 4.99])
    X1.sort()
    fits = [
        (res, "red", "-", "interval"),
        (res1, "blue", "-", "continuity1"),
        (res2, "green", "-", "continuity2"),
        (res3, "black", "--", "continuity3"),
    ]

    plt.figure(1)
    plt.plot(X, Y, 'b.', alpha=0.6, label="measure point")
    for params, colour, style, label in fits:
        model = realLineFunc(params)
        plt.plot(X1, [model(x) for x in X1], c=colour, lw=2, ls=style,
                 alpha=0.7, label=label)
    plt.legend(loc='best')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title(title)
    plt.show()

    best = realLineFunc(res3)
    plt.figure(2)
    plt.plot(X, Y, 'b.', alpha=0.6, label="measure point")
    plt.plot(X1, [100 - x ** 2 for x in X1], c="cyan", lw=2, ls="-",
             alpha=0.7, label="real function")
    plt.plot(X1, [best(x) for x in X1], c="black", lw=2, ls="--",
             alpha=0.7, label="continuity3")
    plt.legend(loc='best')
    plt.title("Compare best fitting curve with real curve ")
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()
# Run the four CMA-ES fits of the seven-segment linear model and plot them.
draw('fitting of interval quadratic curve with line')
interval | continuity1 | continuity2 | continuity3 |
---|---|---|---|
0.0171697128166 | -0.0416912639382 | 0.0146455354793 | -0.0480936701006 |
设置BETA=PI/4.0(45度)
利用二次曲线 $y=ax^2+bx+c$ 进行分段拟合(七段,每段 3 个系数,共 21 个参数;运行前需将 DIM 设为 21)
def realLineFunc(param):
    """Return the seven-segment piecewise-quadratic function encoded by ``param``.

    Segment k (k = 0..6) is the parabola with coefficients
    ``param[3*k], param[3*k + 1], param[3*k + 2]`` (highest power first).
    The breakpoints are -5, -3, -1, 1, 3, 5; a point equal to a breakpoint
    belongs to the segment on its right.
    """
    breakpoints = (-5, -3, -1, 1, 3, 5)

    def f(x):
        seg = len(breakpoints)
        for k, edge in enumerate(breakpoints):
            if x < edge:
                seg = k
                break
        base = 3 * seg
        return np.poly1d([param[base], param[base + 1], param[base + 2]])(x)

    return f
def mse(x, y, param):
    """Return the root-mean-square residual of the seven-segment quadratic model.

    Despite the name, the square root is taken, so this is an RMSE. Every
    sample is scored against the segment its x-value falls into, using the
    same breakpoints as ``realLineFunc`` (-5, -3, -1, 1, 3, 5; a breakpoint
    belongs to the segment on its right). Segment k is the parabola with
    coefficients ``param[3*k:3*k + 3]`` (highest power first).

    The previous version scored only samples 0..69 (by index) and divided by
    a fixed 200.

    Raises:
        ValueError: if ``x`` is empty.
    """
    breakpoints = (-5, -3, -1, 1, 3, 5)
    n = len(x)
    if n == 0:
        raise ValueError("mse() needs at least one sample")
    curves = [np.poly1d(param[3 * k:3 * k + 3]) for k in range(len(breakpoints) + 1)]
    total = 0.0
    for xi, yi in zip(x, y):
        seg = len(breakpoints)
        for k, edge in enumerate(breakpoints):
            if xi < edge:
                seg = k
                break
        total += (yi - curves[seg](xi)) ** 2
    return math.sqrt(total / n)
# Objective with a penalty on jumps (discontinuities) at the segment boundaries.
def evalfunc1(param):
    """Return normalised RMSE plus a jump penalty at each segment boundary.

    At each boundary x0 (the breakpoints used by ``realLineFunc``) the gap
    between the two neighbouring parabolas is normalised by RANGE and adds
    ``E(gap - ALPHA) - 1``.

    The boundaries were previously taken as ``X[(j/3)*10]``. That is a float
    index under true division, which numpy rejects, and with 301 samples it
    does not land on the boundaries anyway.
    """
    s = mse(X, Y, param) / RANGE
    breakpoints = (-5, -3, -1, 1, 3, 5)
    a = 0
    for k, x0 in enumerate(breakpoints):
        j = 3 * (k + 1)  # index of the right-hand segment's first coefficient
        left = fittingFunc(param[j - 3:j])
        right = fittingFunc(param[j:j + 3])
        a += E(abs(left(x0) - right(x0)) / RANGE - ALPHA) - 1
    return s + a
# Objective with a penalty on slope (first-derivative) changes at the boundaries.
def evalfunc2(param):
    """Return normalised RMSE plus a slope-angle penalty at each segment boundary.

    At each boundary x0 the slopes of both neighbouring parabolas are
    evaluated at x0, using d/dx(a*x**2 + b*x + c) = 2*a*x + b. The angle
    difference ``|atan(s_left) - atan(s_right)|`` then contributes
    ``(E(angle - BETA) - 1) / (10 * e)``.

    This replaces an expression that called ``param(j + 1)``; CMA solutions
    are arrays, so that raised TypeError. The same expression also applied
    atan to only part of the right-hand slope and used 2*a + b instead of the
    slope at the boundary.
    """
    s = mse(X, Y, param) / RANGE
    breakpoints = (-5, -3, -1, 1, 3, 5)
    b = 0
    for k, x0 in enumerate(breakpoints):
        j = 3 * (k + 1)  # index of the right-hand segment's first coefficient
        slope_left = 2 * param[j - 3] * x0 + param[j - 2]
        slope_right = 2 * param[j] * x0 + param[j + 1]
        angle = abs(math.atan(slope_left) - math.atan(slope_right))
        b += (E(angle - BETA) - 1) / (10 * math.e)
    return s + b
# Objective combining both penalties: jumps and slope changes at the boundaries.
def evalfunc3(param):
    """Return normalised RMSE plus jump and slope-change penalties.

    At each boundary x0 (the breakpoints used by ``realLineFunc``):

    * the gap between the neighbouring parabolas, normalised by RANGE, adds
      ``E(gap - ALPHA) - 1``;
    * the slopes of both parabolas at x0 (2*a*x0 + b) give an angle
      difference that adds ``(E(angle - BETA) - 1) / (10 * e)``.

    This fixes a float index (``X[(j/3)*10]``) and a ``param(j + 1)`` call.
    CMA solutions are arrays, so that call raised TypeError. It also fixes
    the slope formula.
    """
    s = mse(X, Y, param) / RANGE
    breakpoints = (-5, -3, -1, 1, 3, 5)
    a = 0
    b = 0
    for k, x0 in enumerate(breakpoints):
        j = 3 * (k + 1)  # index of the right-hand segment's first coefficient
        left = fittingFunc(param[j - 3:j])
        right = fittingFunc(param[j:j + 3])
        a += E(abs(left(x0) - right(x0)) / RANGE - ALPHA) - 1
        slope_left = 2 * param[j - 3] * x0 + param[j - 2]
        slope_right = 2 * param[j] * x0 + param[j + 1]
        angle = abs(math.atan(slope_left) - math.atan(slope_right))
        b += (E(angle - BETA) - 1) / (10 * math.e)
    return s + a + b
# Baseline objective: fit error only, no continuity constraints.
def evalfunc(param):
    """Return the model RMSE on (X, Y) normalised by the data range RANGE."""
    return mse(X, Y, param) / RANGE
def draw(title, real_func=None):
    """Fit the seven-segment quadratic model with four objectives and plot the fits.

    Figure 1 shows the data with the fits from ``evalfunc``, ``evalfunc1``,
    ``evalfunc2`` and ``evalfunc3``. Figure 2 compares the ``evalfunc3`` fit
    with the noise-free curve.

    Args:
        title: title of figure 1.
        real_func: noise-free reference curve for figure 2. Defaults to
            ``100 - x**2``, the curve behind ``linefunc2``. The previous code
            plotted ``realLineFunc([1, 0, 2, -40])``, which raises IndexError
            against this 21-parameter model.

    Raises:
        ValueError: if DIM < 21. ``realLineFunc`` reads ``param[0..20]``, so
            the fits would otherwise run to completion and only then fail
            while plotting.
    """
    if DIM < 21:
        raise ValueError("quadratic segments need DIM >= 21, got DIM=%s" % DIM)
    if real_func is None:
        real_func = np.poly1d([-1, 0, 100])  # 100 - x**2
    res = cmaUser(evalfunc)
    res1 = cmaUser(evalfunc1)
    res2 = cmaUser(evalfunc2)
    res3 = cmaUser(evalfunc3)
    # Extra samples just left of each breakpoint so jumps between segments
    # are drawn as near-vertical lines.
    X1 = np.insert(X, 0, [-5.01, -3.01, -1.01, 0.99, 2.99, 4.99])
    X1.sort()
    fits = [
        (res, "red", "-", "interval"),
        (res1, "blue", "-", "continuity1"),
        (res2, "green", "-", "continuity2"),
        (res3, "black", "--", "continuity3"),
    ]

    plt.figure(1)
    plt.plot(X, Y, 'b.', alpha=0.6, label="measure point")
    for params, colour, style, label in fits:
        model = realLineFunc(params)
        plt.plot(X1, [model(x) for x in X1], c=colour, lw=2, ls=style,
                 alpha=0.7, label=label)
    plt.legend(loc='best')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title(title)
    plt.show()

    best = realLineFunc(res3)
    plt.figure(2)
    plt.plot(X, Y, 'b.', alpha=0.6, label="measure point")
    plt.plot(X1, [real_func(x) for x in X1], c="cyan", lw=2, ls="-",
             alpha=0.7, label="real function")
    plt.plot(X1, [best(x) for x in X1], c="black", lw=2, ls="--",
             alpha=0.7, label="continuity3")
    plt.legend(loc='best')
    plt.title("Compare best fitting curve with real curve ")
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()
# Run the four CMA-ES fits of the seven-segment quadratic model and plot them.
draw('fitting of interval quadratic curve with curve')
interval | continuity1 | continuity2 | continuity3 |
---|---|---|---|
0.0166098265438 | -0.0418342446941 | -0.0174967646088 | -0.0757380809252 |
# NOTE(review): the title mentions a sine curve, but the only X, Y built in the
# visible code are 100 - x**2 plus noise; the cell that creates sine data is not
# shown here. Confirm before re-running.
draw('fitting of interval sin curve with curve')
interval | continuity1 | continuity2 | continuity3 |
---|---|---|---|
0.0521840576532 | -0.00661281531701 | 0.0192881944499 | -0.0394042468057 |
# NOTE(review): the title says "with line", yet the most recent draw() defined
# above fits quadratic segments. This output presumably came from re-running
# the linear-segment cells; confirm.
draw('fitting of interval quadratic curve with line')
interval | continuity1 | continuity2 | continuity3 |
---|---|---|---|
0.0556539053525 | -0.00249753416862 | 0.0600579725391 | -0.0044681390571 |
设置BETA=PI/4.0(45度)