最近一直在学习python语言,之前一直使用matlab。看了很多博客都写了该算法的代码,我觉得还不具体,不详细。对于初学者来说,看懂不容易。写了一个NMF算法的python程序供大家参考。参照文献《Algorithms for non-negative matrix factorization》
具体代码如下:
import numpy as np
import torch
import random
import matplotlib.pyplot as plt
def nmf(X,r,maxiter,minError):
# X=U*V'
row,col = X.shape
U = np.around(np.array(np.random.rand(row, r)),5)
V = np.around(np.array(np.random.rand(col, r)),5)
obj = []
for iter in range(maxiter):
print('-----------------------------')
print('开始第',iter,'次迭代')
# update U
XV = np.dot(X,V)
UVV = np.dot(U,np.dot(V.T,V))
U = (U*(XV/np.maximum(UVV,1e-10)))
# update V
XU = np.dot(X.T,U)
VUU = np.dot(V,np.dot(V.T,V))
V = (V*(XU/np.maximum(VUU,1e-10)))
d = np.diag(1/np.maximum(np.sqrt(np.sum(V*V,0)),1e-10))
V = np.dot(V,d)
temp = X - np.dot(U,np.transpose(V))
error = np.sum(temp*temp)
print('error:',error)
print('第',iter,'次迭代结束')
obj.append(error)
if error<minError:
break
return U, V, obj
if __name__ =="__main__":
X = np.random.randn(20, 50)
X = np.array(np.abs(X))
#print('X:',X)
U,V,obj = nmf(X,2,100,0.01)
x = range(len(obj))
plt.plot(x, obj)
plt.show()
更多推荐
NMF算法python源代码
发布评论