TRUST 1.9.8
HPC thermohydraulic platform
Loading...
Searching...
No Matches
Solv_AMGX.cpp
1/****************************************************************************
2* Copyright (c) 2026, CEA
3* All rights reserved.
4*
5* Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
6* 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7* 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8* 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
9*
10* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
11* IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
12* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13*
14*****************************************************************************/
15
16#include <Solv_AMGX.h>
17#include <Matrice_Morse.h>
18#include <ctime>
19#include <communications.h>
20#include <Perf_counters.h>
21#include <chrono>
22#include <Device.h>
23
// TRUST registration macro for class Solv_AMGX (keyword "Solv_AMGX", base class Solv_Petsc).
// NOTE(review): judging by its name ("sans_constructeur"), presumably no constructor is
// generated by the macro — confirm in the macro definition.
// The "XD" comment lines below are presumably parsed by TRUST's syntax/documentation
// tooling: keep their exact format.
24Implemente_instanciable_sans_constructeur(Solv_AMGX,"Solv_AMGX",Solv_Petsc);
25// XD amgx petsc amgx NO_BRACE Solver via AmgX API
26// XD attr solveur chaine solveur REQ not_set
27// XD attr option_solveur bloc_lecture option_solveur REQ not_set
28
29
30// printOn: writes the solver on stream s.
// Everything is delegated to the base class Solv_Petsc::printOn; nothing
// AmgX-specific is printed here.
32{
33 return Solv_Petsc::printOn(s);
34}
35// readOn: reads the solver options from stream is.
// Everything is delegated to the base class Solv_Petsc::readOn; nothing
// AmgX-specific is parsed here.
37{
38 return Solv_Petsc::readOn(is);
39}
40
41#ifdef PETSCKSP_H
42#ifdef TRUST_USE_CUDA
43void Solv_AMGX::initialize()
44{
45 if (amgx_initialized()) return;
46 Nom AmgXmode = "dDDI"; // dDDI:GPU hDDI:CPU (not supported yet by AmgXWrapper)
47 /* Possible de jouer avec simple precision peut etre:
48 1. (lowercase) letter: whether the code will run on the host (h) or device (d).
49 2. (uppercase) letter: whether the matrix precision is float (F) or double (D).
50 3. (uppercase) letter: whether the vector precision is float (F) or double (D).
51 4. (uppercase) letter: whether the index type is 32-bit int (I) or else (not currently supported).
52 typedef enum { AMGX_mode_hDDI, AMGX_mode_hDFI, AMGX_mode_hFFI, AMGX_mode_dDDI, AMGX_mode_dDFI, AMGX_mode_dFFI } AMGX_Mode; */
53 Cerr << "Initializing Amgx and reading the " << config() << " file." << finl;
54 Perf_counters::time_point start = statistics().start_clock();
55 SolveurAmgX_.initialize(PETSC_COMM_WORLD, AmgXmode.getString(), config().getString());
56 Cout << "[AmgX] Time to initialize: " << statistics().compute_time(start) << finl;
57 amgx_initialized_ = true;
58 // MPI_Barrier(PETSC_COMM_WORLD); Voir dans https://github.com/barbagroup/AmgXWrapper/pull/30/commits/1554808a3689f51fa43ab81a35c47a9a1525939a
59}
60
61// Object creation: build the PETSc matrix from mat_morse, extract its CSR arrays and hand them to AmgX.
62void Solv_AMGX::Create_objects(const Matrice_Morse& mat_morse, int blocksize)
63{
// Make sure AmgX is set up (initialize() returns immediately if already done).
64 initialize();
65 if (read_matrix())
66 {
// Reading the matrix from a file is not supported on this GPU path.
67 Cerr << "Read_matrix not supported on GPU yet." << finl;
69 }
70 // Create the PETSc matrix (the CSR pointers point into it)
// Release the matrix built by a previous call before creating the new one.
71 if (MatricePetsc_ != nullptr) MatDestroy(&MatricePetsc_);
72
73 Create_MatricePetsc(MatricePetsc_, mataij_, mat_morse);
74 Perf_counters::time_point start = statistics().start_clock();
// Fills nRowsLocal, nRowsGlobal, nNz, rowOffsets, colIndices and values from MatricePetsc_.
// NOTE(review): the PetscErrorCode returned by petscToCSR is ignored here.
75 petscToCSR(MatricePetsc_, SolutionPetsc_, SecondMembrePetsc_);
76 Cout << "[AmgX] Time to create CSR pointers: " << statistics().compute_time(start) << finl;
77 statistics().begin_count(STD_COUNTERS::gpu_copytodevice,statistics().get_last_opened_counter_level()+1);
78 // Use device pointer to enable device consolidation in AmgXWrapper:
// NOTE(review): values_device is allocated with cudaMalloc on every call and never
// released (the cudaFree below is commented out), so each matrix rebuild leaks
// nNz doubles of device memory. Confirm whether AmgXWrapper still uses this buffer
// after setA() returns before adding the cudaFree. The return codes of
// cudaMalloc/cudaMemcpy are not checked either.
79 double* values_device;
80 cudaMalloc((void**)&values_device, nNz * sizeof(double));
81 cudaMemcpy(values_device, values, nNz * sizeof(double), cudaMemcpyHostToDevice);
82 //SolveurAmgX_.setA(nRowsGlobal, nRowsLocal, nNz, rowOffsets, colIndices, values, nullptr);
83 SolveurAmgX_.setA(nRowsGlobal, nRowsLocal, nNz, rowOffsets, colIndices, values_device, nullptr);
84 //cudaFree(values_device);delete[] hostArray;
85 Cout << "[AmgX] Time to set matrix (copy+setup) on GPU: " << statistics().get_time_since_last_open(STD_COUNTERS::gpu_copytodevice) << finl;// Warning: this tag is read by the validation sheet (keep the message unchanged)
// Transferred volume in bytes: CSR indices (int) plus values (double).
// NOTE(review): the byte count is narrowed to int by static_cast<int>; it could
// overflow for very large matrices.
86 statistics().end_count(STD_COUNTERS::gpu_copytodevice, 1, static_cast<int>(sizeof(int) * (nRowsLocal + nNz) + sizeof(double) * nNz));
87}
88
// Create the vectors used by the AmgX solve for a right-hand side b.
// NOTE(review): presumably allocates the device-side rhs_/lhs_ arrays that solve()
// passes to AmgX (e.g. through Create_lhs_rhs_onDevice()) — TODO confirm.
89void Solv_AMGX::Create_vectors(const DoubleVect& b)
90{
92}
93
// Update the AmgX work vectors from the right-hand side secmem and the solution.
// NOTE(review): presumably copies secmem/solution into rhs_/lhs_ — TODO confirm.
// Matrix reordering (reorder_matrix_) is not supported with AmgX: Process::exit is
// called in that case.
94void Solv_AMGX::Update_vectors(const DoubleVect& secmem, DoubleVect& solution)
95{
97 if (reorder_matrix_) Process::exit("Option not supported yet for AmgX.");
98}
99
// Update solution after the AmgX solve.
// NOTE(review): presumably copies the result held in lhs_ back into solution — TODO confirm.
100void Solv_AMGX::Update_solution(DoubleVect& solution)
101{
103}
104
105// Fonction de conversion Petsc ->CSR
106PetscErrorCode Solv_AMGX::petscToCSR(Mat& A, Vec& lhs_petsc, Vec& rhs_petsc)
107{
108 PetscFunctionBeginUser;
109 PetscBool done;
110 MatType type;
111
112 // Get the Mat type
113 PetscErrorCode ierr = MatGetType(A, &type);
114 CHKERRQ(ierr);
115
116 // Check whether the Mat type is supported
117 if (std::strcmp(type, MATSEQAIJ) == 0) // sequential AIJ
118 {
119 // Make localA point to the same memory space as A does
120 localA = A;
121 }
122 else if (std::strcmp(type, MATMPIAIJ) == 0)
123 {
124 // Get local matrix from redistributed matrix
125 if (localA!=nullptr) MatDestroy(&localA); // Suite accroissement memoire !
126 ierr = MatMPIAIJGetLocalMat(A, MAT_INITIAL_MATRIX, &localA);
127 CHKERRQ(ierr);
128 }
129 else
130 {
131 SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_ARG_WRONG,"Mat type %s is not supported!\n", type);
132 }
133
134 // Get row and column indices in compressed row format
135 ierr = MatGetRowIJ(localA, 0, PETSC_FALSE, PETSC_FALSE, &nRowsLocal, &rowOffsets, &colIndices, &done);
136 if (done==PETSC_FALSE) Process::abort();
137 CHKERRQ(ierr);
138
139 ierr = MatSeqAIJGetArray(localA, &values);
140 CHKERRQ(ierr);
141
142 // Get pointers to the raw data of local vectors
143 /*
144 ierr = VecGetArray(lhs_petsc, &lhs);
145 CHKERRQ(ierr);
146 ierr = VecGetArray(rhs_petsc, &rhs);
147 CHKERRQ(ierr);
148 */
149
150 // Calculate the number of rows in nRowsGlobal
151 ierr = MPI_Allreduce(&nRowsLocal, &nRowsGlobal, 1, MPI_INT, MPI_SUM, PETSC_COMM_WORLD);
152 CHKERRQ(ierr);
153
154 // Store the number of non-zeros
155 nNz = rowOffsets[nRowsLocal];
156 PetscFunctionReturn(0);
157}
158
159void Solv_AMGX::Update_matrix(Mat& MatricePetsc, const Matrice_Morse& mat_morse)
160{
161 // La matrice CSR de PETSc a ete mise a jour dans check_stencil
162 statistics().begin_count(STD_COUNTERS::gpu_copytodevice,statistics().get_last_opened_counter_level()+1);
163 SolveurAmgX_.updateA(nRowsLocal, nNz, values); // ToDo erreur valgrind au premier appel de updateA...
164 Cout << "[AmgX] Time to update matrix (copy+resetup) on GPU: " << statistics().get_time_since_last_open(STD_COUNTERS::gpu_copytodevice) << finl; // Attention balise lue par fiche de validation
165 statistics().end_count(STD_COUNTERS::gpu_copytodevice, 1, sizeof(double)*nNz);
166}
167
168// Check and return true if new stencil
169bool Solv_AMGX::detect_new_stencil(const Matrice_Morse& mat_morse)
170{
171 int num_devices = 0;
172 cudaGetDeviceCount(&num_devices);
173 if (num_devices>1)
174 {
175 // Exemple cas PETSC_AMGX en parallele:
176 Cout << "[AmgX] In Solv_AMGX::check_stencil same_stencil=true cause bug in SolveurAmgX_::updateA on multi-GPU (ToDo: fix by switching to CSR interface)!" << finl;
177 return true;
178 }
179 Perf_counters::time_point start = statistics().start_clock();
180 // Parcours de la matrice_morse (qui peut contenir des 0 et qui n'est pas triee par colonnes croissantes)
181 // si matrice sur le GPU deja construite (qui est sans 0 et qui est triee par colonnes croissantes):
182 const auto& tab1 = mat_morse.get_tab1();
183 const auto& tab2 = mat_morse.get_tab2();
184 const auto& coeff = mat_morse.get_coeff();
185 const auto& renum_array = renum_;
186 int new_stencil = 0, RowLocal = 0;
187 Journal() << "Provisoire: nb_rows_=" << nb_rows_ << " nb_rows_tot_=" << nb_rows_tot_ << finl;
188 for (int i = 0; i < tab1.size_array() - 1; i++)
189 {
190 if (items_to_keep_[i])
191 {
192 int nnz_row = 0;
193 for (auto k = tab1(i) - 1; k < tab1(i + 1) - 1; k++)
194 if (coeff(k) != 0) nnz_row++;
195 if (nnz_row != rowOffsets[RowLocal + 1] - rowOffsets[RowLocal])
196 {
197 Journal() << "Provisoire: Number of non-zero on GPU will change from " << rowOffsets[RowLocal + 1] - rowOffsets[RowLocal] << " to " << nnz_row << " on row " << RowLocal << finl;
198 new_stencil = 1;
199 break;
200 }
201 else
202 {
203 for (auto k = tab1(i) - 1; k < tab1(i + 1) - 1; k++)
204 {
205 if (coeff(k) != 0)
206 {
207 bool found = false;
208 auto col = renum_array[tab2(k) - 1];
209 // Boucle pour voir si le coeff est sur le GPU:
210 auto RowGlobal = decalage_local_global_+RowLocal;
211 for (auto kk = rowOffsets[RowLocal]; kk < rowOffsets[RowLocal + 1]; kk++)
212 {
213 if (colIndices[kk] == col)
214 {
215 values[kk] = coeff(k); // On met a jour le coefficient
216 found = true;
217 break;
218 }
219 }
220 if (!found)
221 {
222 Journal() << "Provisoire: mat_morse(" << RowGlobal << "," << col << ")!=0 new on GPU " << finl;
223 new_stencil = 1;
224 break;
225 }
226 }
227 }
228 }
229 RowLocal++;
230 }
231 }
232 new_stencil = mp_max(new_stencil);
233 Cout << "[AmgX] Time to check stencil: " << statistics().compute_time(start) << finl;
234 return new_stencil;
235}
236
237// Resolution
238int Solv_AMGX::solve(ArrOfDouble& residu)
239{
240 mapToDevice(rhs_);
241 computeOnTheDevice(lhs_);
242 statistics().begin_count(STD_COUNTERS::gpu_library,statistics().get_last_opened_counter_level()+1);
243 // Offer device pointers to AmgX:
244 SolveurAmgX_.solve(addrOnDevice(lhs_), addrOnDevice(rhs_), static_cast<int>(nRowsLocal), seuil_);
245 statistics().end_count(STD_COUNTERS::gpu_library);
246 Cout << "[AmgX] Time to solve system on GPU: " << statistics().get_total_time(STD_COUNTERS::gpu_library) << finl;
247 return nbiter(residu);
248}
249
250int Solv_AMGX::nbiter(ArrOfDouble& residu)
251{
252 int nbiter = -1;
253 SolveurAmgX_.getIters(nbiter);
254 // Bug AmgX, seul le process 0 renvoie correctement nbiter...
255 envoyer_broadcast(nbiter, 0);
256 if (limpr() > -1)
257 {
258 SolveurAmgX_.getResidual(0, residu(0));
259 if (nbiter>0) SolveurAmgX_.getResidual(nbiter - 1, residu(nbiter));
260 }
261 return nbiter;
262}
263#endif
264#endif
Class defining operators and methods for all reading operation in an input flow (file,...
Definition Entree.h:42
Classe Matrice_Morse Represente une matrice M (creuse), non necessairement carree.
const auto & get_tab2() const
const auto & get_tab1() const
const auto & get_coeff() const
const std::string & getString() const
Definition Nom.h:92
virtual Entree & readOn(Entree &)
Lecture d'un Objet_U sur un flot d'entree Methode a surcharger.
Definition Objet_U.cpp:293
virtual Sortie & printOn(Sortie &) const
Ecriture de l'objet sur un flot de sortie Methode a surcharger.
Definition Objet_U.cpp:282
std::chrono::time_point< clock > time_point
void begin_count(const STD_COUNTERS &std_cnt, int counter_lvl=-100000)
double get_time_since_last_open(const STD_COUNTERS &name)
Give as a double the time (in seconds) elapsed in the operation tracked by the standard counter call n...
double get_total_time(const STD_COUNTERS &name)
Give as a double the total time (in seconds) elapsed in the operation tracked by the standard counter ...
double compute_time(time_point start)
return time since start in seconds
time_point start_clock()
Start a clock, return a time_point, not a double.
void end_count(const std::string &custom_count_name, int count_increment=1, long int quantity_increment=0)
End the count of a counter and update the counter values.
static double mp_max(double)
Definition Process.cpp:376
static Sortie & Journal(int message_level=0)
Renvoie un objet statique de type Sortie qui sert de journal d'evenements.
Definition Process.cpp:588
static void abort()
Routine de sortie de Trio-U sur une erreur abort().
Definition Process.cpp:570
static void exit(int exit_code=-1)
Routine de sortie de TRUST dans une region Kokkos.
Definition Process.cpp:455
ArrOfDouble lhs_
public_for_cuda void Update_lhs_rhs(const DoubleVect &b, DoubleVect &x)
ArrOfDouble rhs_
void Create_lhs_rhs_onDevice()
void Update_solution(DoubleVect &x)
bool amgx_initialized_
Definition Solv_Petsc.h:205
bool reorder_matrix_
Definition Solv_Petsc.h:213
const Nom config()
ArrOfBit items_to_keep_
Definition Solv_tools.h:31
trustIdType nb_rows_tot_
Definition Solv_tools.h:35
TIDTab renum_
Definition Solv_tools.h:29
trustIdType decalage_local_global_
Definition Solv_tools.h:36
bool read_matrix() const
int limpr() const
Classe de base des flux de sortie.
Definition Sortie.h:52