最近一直在学习python语言,之前一直使用matlab。看了很多博客都写了该算法的代码,我觉得还不具体,不详细。对于初学者来说,看懂不容易。写了一个NMF算法的python程序供大家参考。参照文献《Algorithms for non-negative matrix factorization》

具体代码如下:

import numpy as np
import torch
import random
import matplotlib.pyplot as plt



def nmf(X,r,maxiter,minError):
    # X=U*V'
    row,col = X.shape
    U = np.around(np.array(np.random.rand(row, r)),5)
    V = np.around(np.array(np.random.rand(col, r)),5)
    obj = []
    for iter in range(maxiter):
        print('-----------------------------')
        print('开始第',iter,'次迭代')
        # update U
        XV = np.dot(X,V)
        UVV = np.dot(U,np.dot(V.T,V))
        U = (U*(XV/np.maximum(UVV,1e-10)))
        # update V
        XU = np.dot(X.T,U)
        VUU = np.dot(V,np.dot(V.T,V))
        V = (V*(XU/np.maximum(VUU,1e-10)))
        d = np.diag(1/np.maximum(np.sqrt(np.sum(V*V,0)),1e-10))
        V = np.dot(V,d)
        
        
        temp = X - np.dot(U,np.transpose(V))
        error = np.sum(temp*temp)
        print('error:',error)
        print('第',iter,'次迭代结束')
        obj.append(error)
        if error<minError:
            break
    return U, V, obj
    
if __name__ =="__main__":
    X = np.random.randn(20, 50)
    X = np.array(np.abs(X))
    #print('X:',X)
    U,V,obj = nmf(X,2,100,0.01)
    x = range(len(obj))
    plt.plot(x, obj)
    plt.show()

更多推荐

NMF算法python源代码