Ginkgo Generated from branch based on main. Ginkgo version 1.9.0
A numerical linear algebra library targeting many-core architectures
 
Loading...
Searching...
No Matches
batch_lin_op.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
6#define GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
7
8
9#include <memory>
10#include <type_traits>
11#include <utility>
12
13#include <ginkgo/core/base/abstract_factory.hpp>
14#include <ginkgo/core/base/batch_multi_vector.hpp>
15#include <ginkgo/core/base/dim.hpp>
16#include <ginkgo/core/base/exception_helpers.hpp>
17#include <ginkgo/core/base/math.hpp>
18#include <ginkgo/core/base/matrix_assembly_data.hpp>
19#include <ginkgo/core/base/matrix_data.hpp>
20#include <ginkgo/core/base/polymorphic_object.hpp>
21#include <ginkgo/core/base/types.hpp>
22#include <ginkgo/core/base/utils.hpp>
23#include <ginkgo/core/log/logger.hpp>
24
25
26namespace gko {
27namespace batch {
28
29
// BatchLinOp: polymorphic base class (via EnableAbstractPolymorphicObject)
// that stores a batch_dim<2> size in `size_` and provides helpers which
// validate the operand sizes of the apply operations.
// NOTE(review): this is a rendered listing; source lines whose numbers are
// skipped (e.g. 66, 68, 76, 91-92, 108) are not shown here.
59class BatchLinOp : public EnableAbstractPolymorphicObject<BatchLinOp> {
60public:
// get_num_batch_items() (declared on line 66, not shown): returns the number
// of items in the batch operator. Only the braces of its body are visible.
67 {
69 }
70
// get_common_size() (declared on line 76, not shown): returns the common size
// of the batch items.
77
// Returns the size of the batch operator as a const reference to size_.
83 const batch_dim<2>& get_size() const noexcept { return size_; }
84
// validate_application_parameters(b, x) (signature on lines 91-92, not
// shown): validates the sizes for the apply(b, x) operation.
90 template <typename ValueType>
93 {
// b and x must hold the same number of batch items as this operator.
94 GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
95 GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());
96
// Per-item shape checks: operator vs. b conformance, operator vs. x rows,
// and b vs. x columns.
97 GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
98 GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
99 GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
100 }
101
// validate_application_parameters(alpha, b, beta, x) (first signature line,
// 108, not shown): validates the sizes for the apply(alpha, b, beta, x)
// operation. Performs the same checks as the two-argument overload and
// additionally requires alpha and beta to have a common size of 1 x 1.
// NOTE(review): the number of batch items of alpha and beta is not checked
// in the visible code - confirm whether that is intentional.
107 template <typename ValueType>
109 const MultiVector<ValueType>* b,
110 const MultiVector<ValueType>* beta,
111 MultiVector<ValueType>* x) const
112 {
113 GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
114 GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());
115
116 GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
117 GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
118 GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
// alpha and beta must each be a 1 x 1 entry per batch item.
119 GKO_ASSERT_EQUAL_DIMENSIONS(alpha->get_common_size(),
120 gko::dim<2>(1, 1));
121 GKO_ASSERT_EQUAL_DIMENSIONS(beta->get_common_size(), gko::dim<2>(1, 1));
122 }
123
124protected:
// Overwrites the stored batch size.
130 void set_size(const batch_dim<2>& size) { size_ = size; }
131
// Constructs the operator on `exec` with the given batch size.
138 explicit BatchLinOp(std::shared_ptr<const Executor> exec,
139 const batch_dim<2>& batch_size)
140 : EnableAbstractPolymorphicObject<BatchLinOp>(exec), size_{batch_size}
141 {}
142
// Constructs the operator from a batch item count and a common item size by
// delegating to the constructor above. When num_batch_items is 0 the
// common_size argument is ignored and a default batch_dim<2> is stored.
151 explicit BatchLinOp(std::shared_ptr<const Executor> exec,
152 const size_type num_batch_items = 0,
153 const dim<2>& common_size = dim<2>{})
154 : BatchLinOp{std::move(exec),
155 num_batch_items > 0
156 ? batch_dim<2>(num_batch_items, common_size)
157 : batch_dim<2>{}}
158 {}
159
160private:
// Batch size of the operator; value-initialized by default.
161 batch_dim<2> size_{};
162};
163
164
// BatchLinOpFactory: factory that takes a shared pointer to a const
// BatchLinOp and produces a new BatchLinOp.
// NOTE(review): the line carrying the class name (before 195) is not shown
// in this rendered listing.
195 : public AbstractFactory<BatchLinOp, std::shared_ptr<const BatchLinOp>> {
196public:
// Inherit the constructors of the AbstractFactory base.
197 using AbstractFactory<BatchLinOp,
198 std::shared_ptr<const BatchLinOp>>::AbstractFactory;
199
// Generates a BatchLinOp from `input`, emitting the
// batch_linop_factory_generate_started / _completed logger events around
// the call to the base class generate().
// `input` is dereferenced without a null check, so it must be non-null.
200 std::unique_ptr<BatchLinOp> generate(
201 std::shared_ptr<const BatchLinOp> input) const
202 {
203 this->template log<
204 gko::log::Logger::batch_linop_factory_generate_started>(
205 this, input.get());
206 const auto exec = this->get_executor();
207 std::unique_ptr<BatchLinOp> generated;
// Use the input as-is when it already lives on this factory's executor;
// otherwise pass a clone created on this factory's executor.
208 if (input->get_executor() == exec) {
209 generated = this->AbstractFactory::generate(input);
210 } else {
211 generated =
212 this->AbstractFactory::generate(gko::clone(exec, input));
213 }
// The completion event reports the caller's original input, not the clone.
214 this->template log<
215 gko::log::Logger::batch_linop_factory_generate_completed>(
216 this, input.get(), generated.get());
217 return generated;
218 }
219};
220
221
// EnableBatchLinOp mixin: derives from
// EnablePolymorphicObject<ConcreteBatchLinOp, PolymorphicBase> and
// EnablePolymorphicAssignment<ConcreteBatchLinOp>. PolymorphicBase defaults
// to BatchLinOp.
// NOTE(review): the line carrying the class name (between 249 and 251) is not
// shown in this rendered listing.
249template <typename ConcreteBatchLinOp, typename PolymorphicBase = BatchLinOp>
251 : public EnablePolymorphicObject<ConcreteBatchLinOp, PolymorphicBase>,
252 public EnablePolymorphicAssignment<ConcreteBatchLinOp> {
253public:
// Inherit the constructors of the EnablePolymorphicObject base.
254 using EnablePolymorphicObject<ConcreteBatchLinOp,
255 PolymorphicBase>::EnablePolymorphicObject;
256};
257
258
// EnableDefaultBatchLinOpFactory: alias template for EnableDefaultFactory
// with the same four template arguments; PolymorphicBase defaults to
// BatchLinOpFactory.
// NOTE(review): the `using EnableDefaultBatchLinOpFactory =` line (277) is
// not shown in this rendered listing.
275template <typename ConcreteFactory, typename ConcreteBatchLinOp,
276 typename ParametersType, typename PolymorphicBase = BatchLinOpFactory>
278 EnableDefaultFactory<ConcreteFactory, ConcreteBatchLinOp, ParametersType,
279 PolymorphicBase>;
280
281
// GKO_ENABLE_BATCH_LIN_OP_FACTORY(_batch_lin_op, _parameters_name,
//                                 _factory_name)
//
// Intended to be expanded inside the definition of the class _batch_lin_op.
// The expansion adds, in order:
//   - a public getter get_<_parameters_name>() returning a const reference to
//     the stored parameters member <_parameters_name>_;
//   - a nested class _factory_name deriving from
//     ::gko::batch::EnableDefaultBatchLinOpFactory<_factory_name,
//     _batch_lin_op, <_parameters_name>_type>. Its two explicit constructors
//     (executor only, and executor plus parameters) are declared without an
//     access specifier and are therefore private; the friend declarations of
//     ::gko::EnablePolymorphicObject and ::gko::enable_parameters_type give
//     those helpers access to them;
//   - a friend declaration of the factory's EnableDefaultBatchLinOpFactory
//     base in the enclosing class;
//   - the private member <_parameters_name>_ of type <_parameters_name>_type;
//   - a trailing `public:` followed by static_assert(true, ...), so the macro
//     invocation can be terminated with a semicolon without triggering
//     extra-semicolon warnings.
// The macro expects a type named <_parameters_name>_type to be visible at the
// point of expansion. Do not place `//` comments inside the continued lines
// below: they would comment out the remainder of the macro.
358#define GKO_ENABLE_BATCH_LIN_OP_FACTORY(_batch_lin_op, _parameters_name, \
359 _factory_name) \
360public: \
361 const _parameters_name##_type& get_##_parameters_name() const \
362 { \
363 return _parameters_name##_; \
364 } \
365 \
366 class _factory_name \
367 : public ::gko::batch::EnableDefaultBatchLinOpFactory< \
368 _factory_name, _batch_lin_op, _parameters_name##_type> { \
369 friend class ::gko::EnablePolymorphicObject< \
370 _factory_name, ::gko::batch::BatchLinOpFactory>; \
371 friend class ::gko::enable_parameters_type<_parameters_name##_type, \
372 _factory_name>; \
373 explicit _factory_name(std::shared_ptr<const ::gko::Executor> exec) \
374 : ::gko::batch::EnableDefaultBatchLinOpFactory< \
375 _factory_name, _batch_lin_op, _parameters_name##_type>( \
376 std::move(exec)) \
377 {} \
378 explicit _factory_name(std::shared_ptr<const ::gko::Executor> exec, \
379 const _parameters_name##_type& parameters) \
380 : ::gko::batch::EnableDefaultBatchLinOpFactory< \
381 _factory_name, _batch_lin_op, _parameters_name##_type>( \
382 std::move(exec), parameters) \
383 {} \
384 }; \
385 friend ::gko::batch::EnableDefaultBatchLinOpFactory< \
386 _factory_name, _batch_lin_op, _parameters_name##_type>; \
387 \
388 \
389private: \
390 _parameters_name##_type _parameters_name##_; \
391 \
392public: \
393 static_assert(true, \
394 "This assert is used to counter the false positive extra " \
395 "semi-colon warnings")
396
397
398} // namespace batch
399} // namespace gko
400
401
402#endif // GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
std::unique_ptr< abstract_product_type > generate(Args &&... args) const
Creates a new product from the given components.
Definition abstract_factory.hpp:67
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition polymorphic_object.hpp:345
This mixin provides a default implementation of a concrete factory.
Definition abstract_factory.hpp:126
This mixin is used to enable a default PolymorphicObject::copy_from() implementation for objects that...
Definition polymorphic_object.hpp:723
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition polymorphic_object.hpp:662
std::shared_ptr< const Executor > get_executor() const noexcept
Returns the Executor of the object.
Definition polymorphic_object.hpp:234
A BatchLinOpFactory represents a higher order mapping which transforms one batch linear operator into...
Definition batch_lin_op.hpp:195
Definition batch_lin_op.hpp:59
const batch_dim< 2 > & get_size() const noexcept
Returns the size of the batch operator.
Definition batch_lin_op.hpp:83
void validate_application_parameters(const MultiVector< ValueType > *b, MultiVector< ValueType > *x) const
Validates the sizes for the apply(b,x) operation in the concrete BatchLinOp.
Definition batch_lin_op.hpp:91
void validate_application_parameters(const MultiVector< ValueType > *alpha, const MultiVector< ValueType > *b, const MultiVector< ValueType > *beta, MultiVector< ValueType > *x) const
Validates the sizes for the apply(alpha, b, beta, x) operation in the concrete BatchLinOp.
Definition batch_lin_op.hpp:108
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition batch_lin_op.hpp:76
size_type get_num_batch_items() const noexcept
Returns the number of items in the batch operator.
Definition batch_lin_op.hpp:66
The EnableBatchLinOp mixin can be used to provide sensible default implementations of the majority of...
Definition batch_lin_op.hpp:252
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition batch_multi_vector.hpp:59
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition batch_multi_vector.hpp:144
size_type get_num_batch_items() const
Returns the number of batch items.
Definition batch_multi_vector.hpp:134
EnableDefaultFactory< ConcreteFactory, ConcreteBatchLinOp, ParametersType, PolymorphicBase > EnableDefaultBatchLinOpFactory
This is an alias for the EnableDefaultFactory mixin, which correctly sets the template parameters to ...
Definition batch_lin_op.hpp:277
The logger namespace.
Definition batch_logger.hpp:23
The Ginkgo namespace.
Definition abstract_factory.hpp:20
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:89
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition utils_helper.hpp:173
A type representing the dimensions of a multidimensional batch object.
Definition batch_dim.hpp:27
dim< dimensionality, dimension_type > get_common_size() const
Get the common size of the batch items.
Definition batch_dim.hpp:43
size_type get_num_batch_items() const
Get the number of batch items stored.
Definition batch_dim.hpp:36
A type representing the dimensions of a multidimensional object.
Definition dim.hpp:26