Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,5 @@ if(ARNOLDI_BUILD_EXAMPLES AND ORIGINAL_ARPACK_LIB)
endif()
target_link_libraries(bench_arpack PRIVATE ${CMAKE_Fortran_IMPLICIT_LINK_LIBRARIES})

add_test(NAME arnoldi_vs_arpack COMMAND bench_arpack --warmup 3 --reps 5)
set_tests_properties(arnoldi_vs_arpack PROPERTIES LABELS "regression;performance")
message(STATUS "Original ARPACK comparison enabled (lib: ${ORIGINAL_ARPACK_LIB})")
endif()
7 changes: 7 additions & 0 deletions include/arnoldi/comm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
// Users can provide their own Comm type satisfying:
// T allreduce_sum(T local) const; // scalar reduce
// void allreduce_sum(T* data, int n) const; // in-place array reduce
// int rank() const; // local rank in the comm (0 for serial)
//
// rank() is consumed by getv0 to seed larnv differently on each rank so
// that the random starting vector is not tiled across ranks (PARPACK's
// pdlarnv does the same).
//
// See arpack/mpi.hpp for an MPI-aware implementation.

Expand All @@ -21,6 +26,8 @@ namespace arnoldi {

template <typename T>
void allreduce_sum(T*, int) const noexcept {}

int rank() const noexcept { return 0; }
};

} // namespace arnoldi
Expand Down
19 changes: 13 additions & 6 deletions include/arnoldi/detail/getv0.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@ namespace arnoldi::detail {
void getv0(detail::BackendRef<Backend> bref, const char* bmat, int itry, bool initv, int n, int j, Scalar* v, int ldv,
Scalar* resid, detail::real_t<Scalar>& rnorm, Scalar* workd, Scalar* workc, int& ierr, OP&& op, BOP&& bop,
const Comm& comm) {
using Real = detail::real_t<Scalar>;
static int iseed[4] = {1, 3, 5, 7};
int msglvl = detail::debug.getv0;
double t0, t1, t2, t3;
using Real = detail::real_t<Scalar>;
static int iseed[4] = {-1, -1, -1, -1};
if (iseed[0] < 0) {
int r = comm.rank();
iseed[0] = (1 + 17 * r) & 0xFFF;
iseed[1] = (3 + 23 * r) & 0xFFF;
iseed[2] = (5 + 31 * r) & 0xFFF;
iseed[3] = ((7 + 2 * r) & 0xFFF) | 1; // last entry must be odd
}
int msglvl = detail::debug.getv0;
double t0, t1, t2, t3;

detail::arscnd(t0);
ierr = 0;
Expand All @@ -29,8 +36,8 @@ namespace arnoldi::detail {
if (itry == 1) {
detail::stats.nopx++;
detail::arscnd(t2);
detail::Ops<Scalar, Backend>::copy(bref, n, resid, 1, workd, 1);
op(workd, &workd[n]);
detail::Ops<Scalar, Backend>::copy(bref, n, resid, 1, &workd[2 * n], 1);
op(&workd[2 * n], &workd[n]);
if (*bmat == 'G') {
detail::arscnd(t3);
detail::stats.mvopx += (t3 - t2);
Expand Down
6 changes: 6 additions & 0 deletions include/arnoldi/mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ namespace arnoldi {
void allreduce_sum(std::complex<double>* data, int count) const {
MPI_Allreduce(MPI_IN_PLACE, data, count, MPI_CXX_DOUBLE_COMPLEX, MPI_SUM, mpi_comm);
}

int rank() const {
int r = 0;
MPI_Comm_rank(mpi_comm, &r);
return r;
}
};

} // namespace arnoldi
Expand Down
10 changes: 5 additions & 5 deletions tests/test_getv0.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ static void bop_diag(const double* x, double* y) {

static bool test_basic() {
const int n = 8;
std::vector<double> resid(n), workd(2 * n), workc(n), v(n);
std::vector<double> resid(n), workd(3 * n), workc(n), v(n);
double rnorm = 0;
int ierr = 0;
arnoldi::SerialComm comm;
Expand All @@ -43,7 +43,7 @@ static bool test_basic() {
// itry > 1 with bmat='G' takes the else-if branch in getv0.
static bool test_generalized_retry() {
const int n = 8;
std::vector<double> resid(n), workd(2 * n), workc(n), v(n);
std::vector<double> resid(n), workd(3 * n), workc(n), v(n);
double rnorm = 0;
int ierr = 0;
arnoldi::SerialComm comm;
Expand All @@ -64,7 +64,7 @@ static bool test_orthogonalization() {
const int n = 8;
const int j = 2;
std::vector<double> v(n * j, 0.0);
std::vector<double> resid(n), workd(2 * n), workc(n);
std::vector<double> resid(n), workd(3 * n), workc(n);
double rnorm = 0;
int ierr = 0;
arnoldi::SerialComm comm;
Expand All @@ -90,7 +90,7 @@ static bool test_orthogonalization_generalized() {
const int n = 8;
const int j = 2;
std::vector<double> v(n * j, 0.0);
std::vector<double> resid(n), workd(2 * n), workc(n);
std::vector<double> resid(n), workd(3 * n), workc(n);
double rnorm = 0;
int ierr = 0;
arnoldi::SerialComm comm;
Expand All @@ -111,7 +111,7 @@ static bool test_orthogonalization_generalized() {
// msglvl > 3 triggers both the rnorm and full-vector debug prints.
static bool test_debug_output() {
const int n = 8;
std::vector<double> resid(n), workd(2 * n), workc(n), v(n);
std::vector<double> resid(n), workd(3 * n), workc(n), v(n);
double rnorm = 0;
int ierr = 0;
arnoldi::SerialComm comm;
Expand Down
74 changes: 74 additions & 0 deletions tests/test_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,77 @@ TEST_CASE("test_accessors", "[solver]") {
check("accessors: iparam() != nullptr", s.iparam() != nullptr);
check("accessors: ipntr() != nullptr", s.ipntr() != nullptr);
}

// Rank-r operators force the Krylov sequence to saturate at step r+1,
// driving rnorm to zero mid-factorization. That triggers the
// restart-at-step branch in saitr / naitr (the rnorm<=0 path that calls
// getv0 to find a new starting vector orthogonal to the existing
// factorization). The iterative-refinement rescue and the h[j-1] < 0
// sign-fixup are also exercised because the rebuilt vectors land in the
// null space of the operator.
TEST_CASE("test_sym_low_rank_restart", "[solver][branch]") {
const int n = 32, nev = 2, ncv = 12;
arnoldi::Arnoldi<arnoldi::Kind::Sym, double> s("I", n, "LA", nev, ncv);
s.tol(1e-10).maxiter(500);
s.solve([n](const double* x, double* y) {
for (int i = 0; i < n; ++i) y[i] = 0.0;
y[0] = 5.0 * x[0];
y[1] = 4.0 * x[1];
y[2] = 3.0 * x[2];
});

check("sym low-rank: converged", s.converged());
check("sym low-rank: nconv >= nev", s.num_converged() >= nev);

auto r = s.eigenpairs(false);
std::sort(r.values.begin(), r.values.end(), std::greater<double>());
check("sym low-rank: largest is 5.0", std::abs(r.values[0] - 5.0) < 1e-8);
check("sym low-rank: second is 4.0", std::abs(r.values[1] - 4.0) < 1e-8);
}

TEST_CASE("test_nonsym_low_rank_restart", "[solver][branch]") {
const int n = 32, nev = 2, ncv = 12;
arnoldi::Arnoldi<arnoldi::Kind::Nonsym, double> s("I", n, "LM", nev, ncv);
s.tol(1e-10).maxiter(500);
s.solve([n](const double* x, double* y) {
for (int i = 0; i < n; ++i) y[i] = 0.0;
// Upper-triangular 3x3 block: eigenvalues 5, 4, 3 (all real).
y[0] = 5.0 * x[0] + 1.0 * x[1] + 0.5 * x[2];
y[1] = 4.0 * x[1] + 1.0 * x[2];
y[2] = 3.0 * x[2];
});

check("nonsym low-rank: converged", s.converged());
check("nonsym low-rank: nconv >= nev", s.num_converged() >= nev);

auto r = s.eigenpairs(false);
double max_mag = 0;
for (size_t i = 0; i < r.values_re.size(); ++i) {
double mag = std::sqrt(r.values_re[i] * r.values_re[i] +
r.values_im[i] * r.values_im[i]);
max_mag = std::max(max_mag, mag);
}
check("nonsym low-rank: largest |lambda| ~ 5", std::abs(max_mag - 5.0) < 1e-6);
}

TEST_CASE("test_herm_low_rank_restart", "[solver][branch]") {
using cplx = std::complex<double>;
const int n = 32, nev = 2, ncv = 12;
arnoldi::Arnoldi<arnoldi::Kind::Herm, cplx> s("I", n, "LM", nev, ncv);
s.tol(1e-10).maxiter(500);
s.solve([n](const cplx* x, cplx* y) {
for (int i = 0; i < n; ++i) y[i] = cplx(0.0, 0.0);
// Hermitian 3x3 block: real diagonal + conjugate off-diagonals.
const cplx off(0.5, 0.25);
y[0] = 5.0 * x[0] + off * x[1];
y[1] = std::conj(off) * x[0] + 4.0 * x[1] + off * x[2];
y[2] = std::conj(off) * x[1] + 3.0 * x[2];
});

check("herm low-rank: converged", s.converged());
auto r = s.eigenpairs(false);
double max_eig = 0;
for (size_t i = 0; i < r.values.size(); ++i)
max_eig = std::max(max_eig, r.values[i]);
check("herm low-rank: largest eig ~ 5 + |off|^2 contribution", max_eig > 4.9);
}
Loading