diff --git a/README.md b/README.md index 44a15a8..3fa813c 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Project description --- -## 🚀 Quick Start +## Quick Start ### Setup @@ -19,3 +19,13 @@ Project description - **Test**: `uv run pytest` - **Format**: `uv run ruff format .` - **Lint**: `uv run ruff check . --fix` + +## References + +[1] Susnjara et al., Accelerated filtering on graphs using Lanczos method. + +[2] https://epfl-lts2.github.io/gspbox-html/ + +## Online guides + +https://every-algorithm.github.io/2024/05/23/lanczos_algorithm.html diff --git a/src/afgl/lanczos.py b/src/afgl/lanczos.py index 9bfe2b3..5cd2305 100644 --- a/src/afgl/lanczos.py +++ b/src/afgl/lanczos.py @@ -1,6 +1,12 @@ import numpy as np +import numpy.linalg as LA """ +Arguments +L Real valued NxN symmetric matrix +s vector of size N +M natural number indicating basis size + Returns ------- V : ndarray @@ -13,16 +19,22 @@ beta : ndarray def lanczos(L, s, M): + N = len(s) alp = np.zeros(M) - beta = np.zeros(M) - V = np.zeros(M) - V[0] = s / np.norm(s) - - for j in range(0, M): - w = L * V - alp[j] = V[j] @ w - Vtmp = w - V[j] * alp[j] - if j > 1: - Vtmp = Vtmp - V[j - 1] * beta[j - 1] - beta[j] = np.norm(Vtmp) - V[j + 1] = Vtmp / beta[j] + beta = np.zeros(M - 1) + V = np.zeros((N, M)) + V[:, 0] = s / LA.norm(s) + + for j in range(M): + w = L @ V[:, j] + alp[j] = np.dot(V[:, j], w) + + v_tilde = w - V[:, j] * alp[j] + if j > 0: + v_tilde = v_tilde - V[:, j - 1] * beta[j - 1] + + if j < M - 1: + beta[j] = LA.norm(v_tilde) + V[:, j + 1] = v_tilde / beta[j] + + return [V, alp, beta] diff --git a/tests/test_main.py b/tests/test_main.py index e69de29..29a97e1 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -0,0 +1,27 @@ +import numpy as np +import numpy.linalg as LA +from afgl.lanczos import lanczos + +""" +Todo: better test case +""" + + +def test_lanczos_return_correct_solution(): + N = 6 + M = 4 + + A = np.random.randint(1, 10, size=(N, N)) + L = (A + A.T) / 2 + s = np.random.randint(1, 10, N) + [V, alp, beta] = lanczos(L, s, M) + + T = np.diag(alp) + np.diag(beta, -1) + np.diag(beta, 1) + + x = LA.solve(A, s) + e_1 = np.zeros(M) + e_1[0] = 1 + y = (LA.inv(T) @ e_1) * LA.norm(s) + x_lanczos = V @ y + + assert LA.norm(x - x_lanczos) < 1e-3