Basic expression templates on element-wise algebraic expressions


Introduction and motivation


Expression templates (denoted as ETs in the following) are a powerful template metaprogramming technique used to speed up the evaluation of sometimes quite expensive expressions. They are widely used in different domains, for example in the implementation of linear algebra libraries.

For this example, consider linear algebraic computations, more specifically computations involving only element-wise operations. Such computations are the most basic application of ETs, and they serve as a good introduction to how ETs work internally.

Let’s look at a motivating example. Consider the computation of the expression:

Vector vec_1, vec_2, vec_3;

// Initializing vec_1, vec_2 and vec_3.

Vector result = vec_1 + vec_2*vec_3;

Here, for the sake of simplicity, I'll assume that the class Vector, operation + (vector plus: element-wise addition) and operation * (called "vector inner product" in this example, although it is really the element-wise product: the i-th element of the result is the product of the i-th elements of the operands) are all correctly implemented.

In a conventional implementation without using ETs (or other similar techniques), at least five constructions of Vector instances take place in order to obtain the final result:

  1. Three instances corresponding to vec_1, vec_2 and vec_3.
  2. A temporary Vector instance _tmp, holding the result of _tmp = vec_2*vec_3;.
  3. Finally, assuming proper use of return value optimization, the construction of the final result in result = vec_1 + _tmp;.
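The count above can be reproduced with a minimal sketch of the conventional, eager approach. `Vec`, its `constructions` counter and `eager_demo()` are hypothetical names, not part of the article's code, and `*` is the element-wise product as in the text. The total of five assumes that returned prvalues are elided (guaranteed since C++17, and what mainstream compilers do in practice for C++14).

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical eager vector: every arithmetic operator returns a new Vec.
struct Vec {
  static int constructions;  // incremented by every Vec constructor
  std::vector<double> data;

  explicit Vec(std::vector<double> d) : data(std::move(d)) { ++constructions; }
  Vec(const Vec &other) : data(other.data) { ++constructions; }
  Vec(Vec &&other) noexcept : data(std::move(other.data)) { ++constructions; }
};
int Vec::constructions = 0;

// Element-wise plus: allocates the result eagerly.
Vec operator+(const Vec &lhs, const Vec &rhs) {
  std::vector<double> out(lhs.data.size());
  for (std::size_t i = 0; i < out.size(); ++i) out[i] = lhs.data[i] + rhs.data[i];
  return Vec(std::move(out));
}

// Element-wise product ("inner product" in the text): also allocates eagerly.
Vec operator*(const Vec &lhs, const Vec &rhs) {
  std::vector<double> out(lhs.data.size());
  for (std::size_t i = 0; i < out.size(); ++i) out[i] = lhs.data[i] * rhs.data[i];
  return Vec(std::move(out));
}

// Evaluates vec_1 + vec_2*vec_3 and reports how many Vec objects were built.
int eager_demo(std::vector<double> &values_out) {
  Vec::constructions = 0;
  Vec vec_1(std::vector<double>{1, 2, 3});
  Vec vec_2(std::vector<double>{4, 5, 6});
  Vec vec_3(std::vector<double>{7, 8, 9});
  Vec result = vec_1 + vec_2 * vec_3;  // vec_2*vec_3 materializes a temporary Vec
  values_out = result.data;
  return Vec::constructions;  // 3 inputs + 1 temporary + 1 result
}
```

Calling `eager_demo` yields 5 constructions and the values 29, 42 and 57.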

An implementation using ETs can eliminate the creation of the temporary Vector _tmp in step 2, leaving only four constructions of Vector instances. More interestingly, consider the following, more complex expression:

Vector result = vec_1 + (vec_2*vec_3 + vec_1)*(vec_2 + vec_3*vec_1);

Here, too, there will be only four constructions of Vector instances in total: vec_1, vec_2, vec_3 and result. In other words, in this example, where only element-wise operations are involved, it is guaranteed that no temporary Vector objects are created for intermediate calculations.
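For contrast, a deliberately small expression-template sketch evaluates the more complex expression with only four `Vec` constructions. The names `ExprTag`, `Leaf`, `Binary` and `et_demo()` are hypothetical and this is not the article's implementation (which follows later): the overloaded operators merely assemble tree nodes, and a single evaluating constructor walks the tree element by element. The count again assumes copy elision for the final initialization (guaranteed since C++17).

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <type_traits>
#include <utility>
#include <vector>

struct ExprTag {};  // hypothetical marker base for expression-tree nodes

template <typename T>
using is_expr = std::is_base_of<ExprTag, T>;

// Hypothetical vector type that counts its constructions.
struct Vec {
  static int constructions;
  std::vector<double> data;

  explicit Vec(std::vector<double> d) : data(std::move(d)) { ++constructions; }
  Vec(const Vec &other) : data(other.data) { ++constructions; }

  // Evaluating constructor: computes each element straight from the tree,
  // so no intermediate Vec is ever created.
  template <typename Expr, typename = std::enable_if_t<is_expr<Expr>::value>>
  Vec(const Expr &e) : data(e.size()) {
    ++constructions;
    for (std::size_t i = 0; i < data.size(); ++i) data[i] = e[i];
  }
};
int Vec::constructions = 0;

// Leaf node: a non-owning reference to an input Vec, cheap to copy.
struct Leaf : ExprTag {
  const Vec &v;
  explicit Leaf(const Vec &x) : v(x) {}
  std::size_t size() const { return v.data.size(); }
  double operator[](std::size_t i) const { return v.data[i]; }
};

// Binary node: holds its operands by value and computes one element on demand.
template <typename Op, typename L, typename R>
struct Binary : ExprTag {
  L lhs;
  R rhs;
  Binary(const L &l, const R &r) : lhs(l), rhs(r) {}
  std::size_t size() const { return lhs.size(); }
  double operator[](std::size_t i) const { return Op{}(lhs[i], rhs[i]); }
};

// These operators accept only expression nodes and only build new nodes.
template <typename L, typename R, typename = std::enable_if_t<is_expr<L>::value && is_expr<R>::value>>
Binary<std::plus<double>, L, R> operator+(const L &l, const R &r) { return {l, r}; }

template <typename L, typename R, typename = std::enable_if_t<is_expr<L>::value && is_expr<R>::value>>
Binary<std::multiplies<double>, L, R> operator*(const L &l, const R &r) { return {l, r}; }

// Evaluates the complex expression from the text; returns the construction count.
int et_demo(std::vector<double> &values_out) {
  Vec::constructions = 0;
  Vec vec_1(std::vector<double>{1, 2, 3});
  Vec vec_2(std::vector<double>{4, 5, 6});
  Vec vec_3(std::vector<double>{7, 8, 9});
  Leaf a(vec_1), b(vec_2), c(vec_3);
  Vec result = a + (b * c + a) * (b + c * a);  // only the final Vec is built
  values_out = result.data;
  return Vec::constructions;  // 3 inputs + 1 result
}
```

The operators are constrained to node types so that they do not hijack `+` and `*` for unrelated types.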


How ETs work


Basically speaking, ETs for any algebraic computations consist of two building blocks:

  1. Pure algebraic expressions (PAE): proxies / abstractions of algebraic expressions. A PAE does not perform any actual computation; it merely models the computation work-flow. A PAE can model either an input or the output of any algebraic expression. Instances of PAEs are usually cheap to copy.
  2. Lazy evaluation: the implementation of the real computations. In the following example we will see that, for expressions involving only element-wise operations, lazy evaluation can perform the actual computations inside the indexed-access operation on the final result, creating a scheme of evaluation on demand: a computation is performed only when the corresponding element of the final result is accessed / asked for.
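Evaluation on demand can be observed directly. In the sketch below, `Ref`, `Plus`, `Times`, `make_plus()`, `make_times()` and the global `element_ops` counter are hypothetical names chosen for illustration: building the PAE performs no arithmetic at all, and one indexed access computes exactly one element.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int element_ops = 0;  // counts element-wise arithmetic actually performed

// Hypothetical leaf: non-owning view of an input vector.
struct Ref {
  const std::vector<double> &v;
  double operator[](std::size_t i) const { return v[i]; }
};

// Hypothetical PAE for lhs + rhs: constructing it computes nothing.
template <typename L, typename R>
struct Plus {
  L lhs;
  R rhs;
  double operator[](std::size_t i) const { ++element_ops; return lhs[i] + rhs[i]; }
};

// Hypothetical PAE for the element-wise product lhs * rhs.
template <typename L, typename R>
struct Times {
  L lhs;
  R rhs;
  double operator[](std::size_t i) const { ++element_ops; return lhs[i] * rhs[i]; }
};

template <typename L, typename R> Plus<L, R> make_plus(L l, R r) { return {l, r}; }
template <typename L, typename R> Times<L, R> make_times(L l, R r) { return {l, r}; }

// Builds vec_1 + vec_2*vec_3, then asks for a single element of the result.
// Returns {operations after building, operations after one indexed access}.
std::vector<int> lazy_demo(double &element_out) {
  const std::vector<double> vec_1{1, 2, 3}, vec_2{4, 5, 6}, vec_3{7, 8, 9};
  element_ops = 0;
  auto expr = make_plus(Ref{vec_1}, make_times(Ref{vec_2}, Ref{vec_3}));
  int after_build = element_ops;  // nothing has been computed yet
  element_out = expr[2];          // computes only 3 + 6*9 (one +, one *)
  return {after_build, element_ops};
}
```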

So how, specifically, do we implement ETs in this example? Let's walk through it now.

Consider again the following code snippet:

Vector vec_1, vec_2, vec_3;

// Initializing vec_1, vec_2 and vec_3.

Vector result = vec_1 + vec_2*vec_3;

The expression to compute result can be decomposed further into two sub-expressions:

  1. A vector plus expression (denoted as plus_expr)
  2. A vector inner product expression (denoted as innerprod_expr).

What ETs do is the following:

result = plus_expr( vec_1, innerprod_expr(vec_2, vec_3) )
                     /   \
                    /     \
                   /       \
               vec_1    innerprod_expr( vec_2, vec_3 )
                              /  \
                             /    \
                            /      \
                         vec_2    vec_3
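Since all operations are element-wise, each element of result can be computed directly from the corresponding elements of the inputs:

elem_res = elem_1 + elem_2*elem_3;

There is therefore no need to create a temporary Vector to store the result of the intermediate inner product: the whole computation for one element can be done at once and encoded inside the indexed-access operation.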
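The tree above does not need to exist as a runtime data structure: with ETs it is encoded in the type of the expression object. This sketch uses hypothetical node types `PlusExpr` and `ProdExpr` standing in for plus_expr and innerprod_expr, checks the nesting with a `static_assert`, and computes one element with exactly the formula elem_res = elem_1 + elem_2*elem_3.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical leaf: non-owning reference to input data.
struct Leaf {
  const std::vector<double> &v;
  double operator[](std::size_t i) const { return v[i]; }
};

// Hypothetical nodes mirroring plus_expr and innerprod_expr in the tree.
template <typename L, typename R> struct PlusExpr {
  L lhs; R rhs;
  double operator[](std::size_t i) const { return lhs[i] + rhs[i]; }
};
template <typename L, typename R> struct ProdExpr {
  L lhs; R rhs;
  double operator[](std::size_t i) const { return lhs[i] * rhs[i]; }
};

// Only tree nodes take part in the overloaded operators below.
template <typename T> struct is_node : std::false_type {};
template <> struct is_node<Leaf> : std::true_type {};
template <typename L, typename R> struct is_node<PlusExpr<L, R>> : std::true_type {};
template <typename L, typename R> struct is_node<ProdExpr<L, R>> : std::true_type {};

template <typename L, typename R>
using both_nodes = std::enable_if_t<is_node<L>::value && is_node<R>::value>;

template <typename L, typename R, typename = both_nodes<L, R>>
PlusExpr<L, R> operator+(const L &l, const R &r) { return {l, r}; }

template <typename L, typename R, typename = both_nodes<L, R>>
ProdExpr<L, R> operator*(const L &l, const R &r) { return {l, r}; }

// The type of vec_1 + vec_2*vec_3 spells out the tree drawn above.
using Tree = PlusExpr<Leaf, ProdExpr<Leaf, Leaf>>;

double tree_demo(std::size_t i) {
  static const std::vector<double> vec_1{1, 2, 3}, vec_2{4, 5, 6}, vec_3{7, 8, 9};
  Leaf a{vec_1}, b{vec_2}, c{vec_3};
  auto expr = a + b * c;
  static_assert(std::is_same<decltype(expr), Tree>::value, "type mirrors the tree");
  return expr[i];  // elem_res = elem_1 + elem_2*elem_3, nothing else is computed
}
```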


Here is the example code in action.


File vec.hh : wrapper for std::vector, used to print a log message whenever a constructor is called.

```cpp
#ifndef EXPR_VEC
#define EXPR_VEC

#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <iterator>
#include <utility>
#include <vector>

///
/// This is a wrapper for std::vector. Its only purpose is to print out a log
/// when a Vector constructor is called.
/// It wraps the indexed access operator [] and the size() method, which are
/// important for the later ETs implementation.
///

// std::vector wrapper.
template<typename T>
class Vector
{
public:
  explicit Vector() { std::cout << "ctor called.\n"; }
  explicit Vector(int size): _vec(size) { std::cout << "ctor called.\n"; }
  // Taken by value so that callers can move an existing std::vector in.
  explicit Vector(std::vector<T> vec): _vec(std::move(vec)) { std::cout << "ctor called.\n"; }

  Vector(const Vector &vec): _vec(vec()) { std::cout << "copy ctor called.\n"; }
  Vector(Vector &&vec): _vec(std::move(vec())) { std::cout << "move ctor called.\n"; }

  Vector & operator=(const Vector &) = default;
  Vector & operator=(Vector &&) = default;

  decltype(auto) operator[](int indx) { return _vec[indx]; }
  decltype(auto) operator[](int indx) const { return _vec[indx]; }

  decltype(auto) operator()() & { return (_vec); }
  decltype(auto) operator()() const & { return (_vec); }
  Vector && operator()() && { return std::move(*this); }

  int size() const { return static_cast<int>(_vec.size()); }

private:
  std::vector<T> _vec;
};

///
/// These are conventional overloads of operator + (the vector plus operation)
/// and operator * (the vector inner product operation) that do not use
/// expression templates. They are later used for benchmarking.
///

// + (vector plus) operator.
template<typename T>
auto operator+(const Vector<T> &lhs, const Vector<T> &rhs)
{
  assert(lhs().size() == rhs().size() && "error: ops plus -> lhs and rhs size mismatch.");

  std::vector<T> _vec;
  _vec.resize(lhs().size());
  std::transform(std::cbegin(lhs()), std::cend(lhs()), std::cbegin(rhs()),
                 std::begin(_vec), std::plus<>());
  return Vector<T>(std::move(_vec));
}

// * (vector inner product) operator.
template<typename T>
auto operator*(const Vector<T> &lhs, const Vector<T> &rhs)
{
  assert(lhs().size() == rhs().size() && "error: ops multiplies -> lhs and rhs size mismatch.");

  std::vector<T> _vec;
  _vec.resize(lhs().size());
  std::transform(std::cbegin(lhs()), std::cend(lhs()), std::cbegin(rhs()),
                 std::begin(_vec), std::multiplies<>());
  return Vector<T>(std::move(_vec));
}

#endif //!EXPR_VEC
```

File expr.hh : implementation of expression templates for element-wise operations (vector plus and vector inner product).

Let's break it down into sections.

1. Section 1 implements a base class for all expressions. It employs the Curiously Recurring Template Pattern (CRTP, http://stackoverflow.com/documentation/c%2B%2B/709/curiously-recurring-template-pattern-crtp#t=201607241604559383674).

2. Section 2 implements the first PAE: a terminal, which is just a wrapper (holding a const reference) around an input data structure containing the actual input values for the computation.

3. Section 3 implements the second PAE: binary_ops, a class template later used for vector plus and vector inner product. It is parametrized by the type of operation, the left-hand-side PAE and the right-hand-side PAE. The actual computation is encoded in the indexed-access operator.

4. Section 4 defines vector plus and vector inner product (vec_plus_t and vec_prod_t) as element-wise operations. It also overloads operator + and * for PAEs, so that these two operations also return PAEs.
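The CRTP idea behind Section 1 can be shown in isolation: the base class forwards size() and operator[] to the derived class's size_impl() and at_impl() through a static_cast, so the dispatch is resolved at compile time without virtual functions. A stripped-down sketch with hypothetical names `base`, `squares` and `sum()`, simplified relative to expr_base:

```cpp
#include <cassert>

// Base class template in the style of expr_base: Derived is the concrete type.
template <typename Derived>
class base {
public:
  int size() const { return self().size_impl(); }
  auto operator[](int i) const { return self().at_impl(i); }

private:
  // Static downcast: valid because Derived always inherits from base<Derived>.
  const Derived &self() const { return static_cast<const Derived &>(*this); }
};

// Hypothetical expression whose i-th element (i*i) is computed on access.
class squares : public base<squares> {
public:
  explicit squares(int n) : n_(n) {}
  int size_impl() const { return n_; }
  int at_impl(int i) const { return i * i; }

private:
  int n_;
};

// Generic code written only against the base interface; every call is
// resolved at compile time, with no virtual functions involved.
template <typename E>
int sum(const base<E> &e) {
  int total = 0;
  for (int i = 0; i < e.size(); ++i) total += e[i];
  return total;
}
```

For example, sum(squares(4)) adds 0, 1, 4 and 9.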

```cpp
#ifndef EXPR_EXPR
#define EXPR_EXPR

/// Fwd declaration.
template<typename T> class Vector;

namespace expr {

/// Section 1. Base class template for all kinds of expressions. It employs the
/// Curiously Recurring Template Pattern, which lets it be instantiated for any
/// kind of expression structure inheriting from it.
template<typename Expr>
class expr_base
{
public:
  const Expr& self() const { return static_cast<const Expr&>(*this); }
  Expr& self() { return static_cast<Expr&>(*this); }

protected:
  explicit expr_base() {}
  int size() const { return self().size_impl(); }
  auto operator[](int indx) const { return self().at_impl(indx); }
  auto operator()() const { return self()(); }
};

/// Sections 2 & 3 are pure algebraic expressions (PAE). A PAE is converted to a
/// concrete object through operator() (terminal) or a conversion operator
/// (binary_ops): it is in this conversion that the real computations are done.
///
/// Section 2. Terminal: wraps a const reference to a Vector, giving it the
/// unified PAE interface of expr_base (size(), indexed access through at_impl()
/// and access to the referenced object through operator()). User-defined data
/// structures can inherit expr_base directly and skip this wrapper.
template<typename DataType>
class terminal: public expr_base<terminal<DataType>>
{
public:
  using base_type = expr_base<terminal<DataType>>;
  using base_type::size;
  using base_type::operator[];
  friend base_type;

  explicit terminal(const DataType &val): _val(val) {}
  int size_impl() const { return _val.size(); }
  auto at_impl(int indx) const { return _val[indx]; }
  decltype(auto) operator()() const { return (_val); }
private:
  const DataType &_val;
};

/// Section 3. Binary operation expression: a PAE abstraction of any binary
/// expression, also inheriting from expr_base. Each call to at_impl() is one
/// element-wise computation.
template<typename Ops, typename lExpr, typename rExpr>
class binary_ops: public expr_base<binary_ops<Ops,lExpr,rExpr>>
{
public:
  using base_type = expr_base<binary_ops<Ops,lExpr,rExpr>>;
  using base_type::size;
  using base_type::operator[];
  friend base_type;

  explicit binary_ops(const Ops &ops, const lExpr &lxpr, const rExpr &rxpr)
    : _ops(ops), _lxpr(lxpr), _rxpr(rxpr) {}
  int size_impl() const { return _lxpr.size(); }

  /// This does the element-wise computation for index indx.
  auto at_impl(int indx) const { return _ops(_lxpr[indx], _rxpr[indx]); }

  /// Conversion from arbitrary expr to a concrete data type. It evaluates the
  /// element-wise computations for all indices.
  template<typename DataType>
  operator DataType() const
  {
    DataType _vec(size());
    for(int _ind = 0; _ind < _vec.size(); ++_ind)
      _vec[_ind] = (*this)[_ind];
    return _vec;
  }
private:
  /// Ops and expr are assumed cheap to copy.
  Ops _ops;
  lExpr _lxpr;
  rExpr _rxpr;
};

/// Section 4. vec_plus_t and vec_prod_t act on elements of Vectors, not on whole
/// Vectors. Operator + and * are then overloaded on PAEs so that they also
/// return PAEs.
struct vec_plus_t {  /// Element-wise plus operation.
  constexpr explicit vec_plus_t() = default;
  template<typename LType, typename RType>
  auto operator()(const LType &lhs, const RType &rhs) const { return lhs+rhs; }
};
struct vec_prod_t {  /// Element-wise inner product operation.
  constexpr explicit vec_prod_t() = default;
  template<typename LType, typename RType>
  auto operator()(const LType &lhs, const RType &rhs) const { return lhs*rhs; }
};

/// Constant plus and inner product operator objects.
constexpr vec_plus_t vec_plus{};
constexpr vec_prod_t vec_prod{};

/// Plus and inner product overloads on expressions: both return a binary expression.
template<typename lExpr, typename rExpr>
auto operator+(const lExpr &lhs, const rExpr &rhs)
{ return binary_ops<vec_plus_t,lExpr,rExpr>(vec_plus,lhs,rhs); }
template<typename lExpr, typename rExpr>
auto operator*(const lExpr &lhs, const rExpr &rhs)
{ return binary_ops<vec_prod_t,lExpr,rExpr>(vec_prod,lhs,rhs); }

} //!expr

#endif //!EXPR_EXPR
```

File main.cc : test source file.

```cpp
# include <chrono>
# include <iomanip>
# include <iostream>
# include <typeinfo>
# include <utility>
# include <vector>
# include "vec.hh"
# include "expr.hh"
# include "boost/core/demangle.hpp"

int main()
{
  using dtype = float;
  constexpr int size = 5e7;

  std::vector<dtype> _vec1(size);
  std::vector<dtype> _vec2(size);
  std::vector<dtype> _vec3(size);

  // ... Initialize vectors' contents.

  Vector<dtype> vec1(std::move(_vec1));
  Vector<dtype> vec2(std::move(_vec2));
  Vector<dtype> vec3(std::move(_vec3));

  // steady_clock is monotonic, which makes it the appropriate clock for timing.
  unsigned long start_ms_no_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::steady_clock::now().time_since_epoch()).count();
  std::cout << "\nNo-ETs evaluation starts.\n";

  Vector<dtype> result_no_ets = vec1 + (vec2*vec3);

  unsigned long stop_ms_no_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::steady_clock::now().time_since_epoch()).count();
  std::cout << std::setprecision(6) << std::fixed
            << "No-ETs. Time elapsed: " << (stop_ms_no_ets-start_ms_no_ets)/1000.0
            << " s.\n" << std::endl;

  unsigned long start_ms_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::steady_clock::now().time_since_epoch()).count();
  std::cout << "Evaluation using ETs starts.\n";

  expr::terminal<Vector<dtype>> vec4(vec1);
  expr::terminal<Vector<dtype>> vec5(vec2);
  expr::terminal<Vector<dtype>> vec6(vec3);

  Vector<dtype> result_ets = (vec4 + vec5*vec6);

  unsigned long stop_ms_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::steady_clock::now().time_since_epoch()).count();
  std::cout << std::setprecision(6) << std::fixed
            << "With ETs. Time elapsed: " << (stop_ms_ets-start_ms_ets)/1000.0
            << " s.\n" << std::endl;

  auto ets_ret_type = (vec4 + vec5*vec6);
  std::cout << "\nETs result's type:\n";
  std::cout << boost::core::demangle( typeid(decltype(ets_ret_type)).name() ) << '\n';

  return 0;
}
```

Here’s one possible output when compiled with -O3 -std=c++14 using GCC 5.3:

ctor called.
ctor called.
ctor called.

No-ETs evaluation starts.
ctor called.
ctor called.
No-ETs. Time elapsed: 0.571000 s.

Evaluation using ETs starts.
ctor called.
With ETs. Time elapsed: 0.164000 s.

ETs result's type:
expr::binary_ops<expr::vec_plus_t, expr::terminal<Vector<float> >, expr::binary_ops<expr::vec_prod_t, expr::terminal<Vector<float> >, expr::terminal<Vector<float> > > >

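The observations are:

  1. In this run, the evaluation using ETs is about 3.5 times faster than the one without (0.164 s versus 0.571 s).
  2. Without ETs, the constructor is called twice: once for the temporary holding vec2*vec3 and once for result_no_ets. With ETs it is called only once, for result_ets, so the temporary is eliminated.
  3. boost::core::demangle is used to print the type of the ETs expression before its conversion to Vector: it is a nested binary_ops type that reproduces exactly the expression tree shown earlier.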


Drawbacks and caveats


auto result = ...;                // Some expensive expression: 
                                  // auto returns the expr graph, 
                                  // NOT the computed value.
for(auto i = 0; i < 100; ++i)
    ScalarType value = result* ... // Some other expensive computations using result.

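Here, result is re-evaluated in every iteration of the for loop, because auto deduces the type of the expression graph rather than that of the computed value, and it is the graph that is used inside the loop. Declaring result with a concrete type instead (as in Vector<dtype> result_ets = ...; in main.cc) forces a single evaluation.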
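A minimal sketch of this pitfall, using a hypothetical lazy node `LazySum` with an `evaluations` counter (not part of the article's code): keeping the expression in an `auto` variable repeats the arithmetic on every use, whereas copying it once into a concrete container evaluates each element a single time.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int evaluations = 0;  // counts element-wise additions actually performed

// Hypothetical lazy node for lhs + rhs over two std::vectors.
struct LazySum {
  const std::vector<double> &lhs, &rhs;
  std::size_t size() const { return lhs.size(); }
  double operator[](std::size_t i) const { ++evaluations; return lhs[i] + rhs[i]; }
};

// Reads element 0 of an auto-typed result `uses` times; returns additions run.
int reuse_lazy(int uses) {
  std::vector<double> a{1, 2}, b{3, 4};
  evaluations = 0;
  auto result = LazySum{a, b};  // auto keeps the graph, not the value
  double acc = 0;
  for (int k = 0; k < uses; ++k) acc += result[0];  // re-evaluated every time
  (void)acc;
  return evaluations;
}

// Same loop, but the expression is materialized into a std::vector first.
int reuse_materialized(int uses) {
  std::vector<double> a{1, 2}, b{3, 4};
  evaluations = 0;
  LazySum expr{a, b};
  std::vector<double> result(expr.size());
  for (std::size_t i = 0; i < result.size(); ++i) result[i] = expr[i];  // once per element
  double acc = 0;
  for (int k = 0; k < uses; ++k) acc += result[0];  // plain memory reads
  (void)acc;
  return evaluations;
}
```

With 100 uses, the lazy version performs 100 additions while the materialized one performs 2, one per element.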


Existing libraries implementing ETs

