Graph Framework
Adding New Operations Tutorial

A tutorial for creating new operations.

Introduction

In most cases, physics problems can be generated from combinations of existing graph nodes. For instance, the graph::tan nodes are built from \(\frac{\sin\left(x\right)}{\cos\left(x\right)}\). However, some problems will call for adding new operations. This page provides a basic example of how to implement a new operator \(foo\left(x\right)\) in the graph_framework.


Node Subclasses

All graph nodes are subclasses of graph::leaf_node or subclasses of other nodes. In the case of our \(foo\left(x\right)\) example we can subclass graph::straight_node, since it assumes a single argument. If there are two or three operands, you can subclass graph::branch_node or graph::triple_node respectively.

Note
Any existing node can be subclassed, but do so with caution. Subclasses inherit reduction rules which may be incorrect for the new operation.

In this case, graph::straight_node (along with graph::branch_node and graph::triple_node) has no reduction assumptions. Since our operation \(foo\left(x\right)\) takes one argument, we will subclass graph::straight_node.

Subclassing a node starts with a class declaration and a constructor.

template<jit::float_scalar T, bool SAFE_MATH=false>
class foo_node : public straight_node<T, SAFE_MATH> {
private:
// Build a string identifier from the argument pointer; used to hash the node.
static std::string to_string(leaf_node<T, SAFE_MATH> *x) {
return "foo(" +
jit::format_to_string(reinterpret_cast<size_t> (x)) +
")";
}
public:
foo_node(shared_leaf<T, SAFE_MATH> x) :
straight_node<T, SAFE_MATH> (x, foo_node::to_string(x.get())) {}
};

The static to_string method provides an identifier that can be used to generate a hash for the node. This hash will be used later in a factory function to ensure nodes only exist once.

A factory function constructs a node, then immediately reduces it. The reduced node is then checked against the existing nodes in graph::leaf_node::caches_t::nodes. If the node is new, we add it to the cache and return it. Otherwise we discard the node and return the cached node in its place.

template<jit::float_scalar T, bool SAFE_MATH=false>
shared_leaf<T, SAFE_MATH> foo(shared_leaf<T, SAFE_MATH> x) {
auto temp = std::make_shared<foo_node<T, SAFE_MATH>> (x)->reduce();
// Test for hash collisions.
for (size_t i = temp->get_hash();
i < std::numeric_limits<size_t>::max(); i++) {
if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
leaf_node<T, SAFE_MATH>::caches.nodes[i] = temp;
return temp;
} else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
return leaf_node<T, SAFE_MATH>::caches.nodes[i];
}
}
}
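
The caching pattern used by the factory function can be tried in isolation. The following is a minimal sketch, not framework code: toy_node, toy_cache, and make_toy are hypothetical stand-ins for the node class, graph::leaf_node::caches_t::nodes, and the factory function, and the hash is simply taken from a label string.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for a graph node: only a label is stored.
struct toy_node {
    std::string label;

    explicit toy_node(const std::string &l) : label(l) {}

    // Hash derived from the label, playing the role of get_hash().
    std::size_t get_hash() const {
        return std::hash<std::string> {} (label);
    }

    // Structural comparison, playing the role of is_match().
    bool is_match(const std::shared_ptr<toy_node> &other) const {
        return label == other->label;
    }
};

// Hypothetical cache keyed by hash, standing in for caches.nodes.
std::map<std::size_t, std::shared_ptr<toy_node>> toy_cache;

// Return the cached node when one matches, otherwise store the new node.
std::shared_ptr<toy_node> make_toy(const std::string &label) {
    auto temp = std::make_shared<toy_node> (label);
    for (std::size_t i = temp->get_hash();
         i < std::numeric_limits<std::size_t>::max(); i++) {
        auto found = toy_cache.find(i);
        if (found == toy_cache.end()) {
            toy_cache[i] = temp;    // Free slot: register the new node.
            return temp;
        } else if (temp->is_match(found->second)) {
            return found->second;   // Same structure: reuse the cached node.
        }
        // A different node owns this slot, so probe the next key.
    }
    return temp;  // Only reached if no slot could be probed.
}
```

Calling make_toy twice with the same label returns the same pointer, while a different label creates a second cache entry.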

To aid in introspection, we also need a function to cast a generic graph::shared_leaf back to the specific node type. For convenience, we also define a type alias for the shared pointer type.

template<jit::float_scalar T, bool SAFE_MATH=false>
using shared_foo = std::shared_ptr<foo_node<T, SAFE_MATH>>;
template<jit::float_scalar T, bool SAFE_MATH=false>
shared_foo<T, SAFE_MATH> foo_cast(shared_leaf<T, SAFE_MATH> x) {
return std::dynamic_pointer_cast<foo_node<T, SAFE_MATH>> (x);
}
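
The cast relies on std::dynamic_pointer_cast returning an empty pointer when the node is of a different type. A minimal sketch of that behavior, using hypothetical non-template classes base_node, foo_like, and bar_like in place of the framework types:

```cpp
#include <cassert>
#include <memory>

// Hypothetical minimal hierarchy; the framework classes are templates.
struct base_node {
    virtual ~base_node() = default;
};
struct foo_like : base_node {};
struct bar_like : base_node {};

// Returns a non-null pointer only when x really points to a foo_like.
std::shared_ptr<foo_like> foo_like_cast(std::shared_ptr<base_node> x) {
    return std::dynamic_pointer_cast<foo_like> (x);
}
```

Testing the result with .get() is then enough to find out whether a generic node is of the specific type.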

Method overrides

To subclass a graph::leaf_node there are several methods that need to be provided.


Evaluate

To start, let's provide a way to evaluate the node. The first step in evaluating a node is to evaluate the node's argument.

virtual backend::buffer<T> evaluate() {
backend::buffer<T> result = this->arg->evaluate();
}

A backend::buffer is a quick way to evaluate the node on the host before device kernels need to be generated. It is used by the graph::leaf_node::reduce method to precompute constant values. We can extend the backend::buffer class with a new method to evaluate foo, or we can use the existing operators. In this case, let's assume \(foo\left(x\right)=x^{2}\).

virtual backend::buffer<T> evaluate() {
backend::buffer<T> result = this->arg->evaluate();
return result*result;
}
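
To make the host-side evaluation concrete, here is a minimal sketch that assumes a buffer behaves like an array with element-wise multiplication. The names toy_buffer and evaluate_foo are hypothetical and std::vector stands in for backend::buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for backend::buffer<T>.
template<typename T>
using toy_buffer = std::vector<T>;

// Element-wise square, mirroring result*result in the evaluate method.
template<typename T>
toy_buffer<T> evaluate_foo(const toy_buffer<T> &arg) {
    toy_buffer<T> result(arg.size());
    for (std::size_t i = 0; i < arg.size(); i++) {
        result[i] = arg[i]*arg[i];
    }
    return result;
}
```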

Is Match

This method checks if the node matches another node. The first thing to check is whether the pointers match. If they do not, we check whether the structures of the graphs match. This is important for the factory function: when checking for cached nodes, two graphs can be identical but have different pointer values. Checking the structure of the graphs ensures that we catch identical graphs.

virtual bool is_match(shared_leaf<T, SAFE_MATH> x) {
if (this == x.get()) {
return true;
}
auto x_cast = foo_cast(x);
if (x_cast.get()) {
return this->arg->is_match(x_cast->get_arg());
}
return false;
}
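
The difference between pointer equality and structural equality can be shown with a small sketch. The tree type below is hypothetical and compares operation labels recursively; it simplifies the framework, where each node type decides for itself what counts as a match.

```cpp
#include <cassert>
#include <memory>
#include <string>

struct tree;
using shared_tree = std::shared_ptr<tree>;

// Hypothetical expression tree with at most one argument per node.
struct tree {
    std::string op;
    shared_tree arg;  // Null for leaves.

    tree(const std::string &o, shared_tree a = nullptr) : op(o), arg(a) {}

    bool is_match(const shared_tree &x) const {
        if (this == x.get()) {
            return true;    // Same object, nothing more to check.
        }
        if (!x || op != x->op) {
            return false;   // Different operation.
        }
        if (!arg || !x->arg) {
            return !arg && !x->arg;  // Both must be leaves.
        }
        return arg->is_match(x->arg);  // Compare the sub graphs.
    }
};
```

Two trees built separately have different pointers but still match when their structure is the same.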

Reduce

Let's add a simple reduction method. When the argument \(x\) is a constant, we can reduce this node down to a single constant by pre-evaluating it.

virtual shared_leaf<T, SAFE_MATH> reduce() {
if (constant_cast(this->arg).get()) {
return constant<T, SAFE_MATH> (this->evaluate());
}
return this->shared_from_this();
}

In this example we first check if the argument can be cast to a constant. If the cast succeeds, we evaluate this node and create a new constant to return in its place. Otherwise we return the current node unchanged.
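
The same constant folding idea can be run outside the framework. This is a minimal sketch with hypothetical types (expr, constant_expr, variable_expr, square_expr) that evaluate to a single double instead of a buffer.

```cpp
#include <cassert>
#include <memory>

struct expr;
using shared_expr = std::shared_ptr<expr>;

// Hypothetical minimal node interface.
struct expr : std::enable_shared_from_this<expr> {
    virtual ~expr() = default;
    virtual double evaluate() const = 0;
    // By default a node cannot be simplified.
    virtual shared_expr reduce() { return shared_from_this(); }
};

struct constant_expr : expr {
    double value;
    explicit constant_expr(double v) : value(v) {}
    double evaluate() const override { return value; }
};

struct variable_expr : expr {
    double value = 0.0;
    double evaluate() const override { return value; }
};

// Plays the role of foo_node with foo(x) = x^2.
struct square_expr : expr {
    shared_expr arg;
    explicit square_expr(shared_expr a) : arg(a) {}
    double evaluate() const override {
        const double r = arg->evaluate();
        return r*r;
    }
    shared_expr reduce() override {
        // A constant argument lets us fold the node into one constant.
        if (std::dynamic_pointer_cast<constant_expr> (arg)) {
            return std::make_shared<constant_expr> (evaluate());
        }
        return shared_from_this();
    }
};
```

Reducing the square of a constant yields a constant node, while the square of a variable is returned unchanged.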

Note
Other reductions are possible but not shown here.

df

Automatic differentiation is provided by returning the derivative expression. For our example the chain rule gives \(\frac{\partial}{\partial y}foo\left(x\right)=2x\frac{\partial x}{\partial y}\). In this framework it is also possible to take the derivative of a node with respect to itself, \(\frac{\partial foo\left(x\right)}{\partial foo\left(x\right)}=1\), so that case is checked first.

virtual shared_leaf<T, SAFE_MATH> df(shared_leaf<T, SAFE_MATH> x) {
if (this->is_match(x)) {
return one<T, SAFE_MATH> ();
}
const size_t hash = reinterpret_cast<size_t> (x.get());
if (this->df_cache.find(hash) == this->df_cache.end()) {
this->df_cache[hash] = 2.0*this->arg*this->arg->df(x);
}
return this->df_cache[hash];
}

Here we made use of the graph::leaf_node::df_cache to avoid rebuilding the expression every time the same derivative is taken.
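
As a worked instance of the chain rule used in df, assume purely for illustration that the argument is \(x=\sin\left(y\right)\). The expression built by 2.0*this->arg*this->arg->df(x) then corresponds to:

```latex
\begin{align*}
foo\left(x\right) &= x^{2}, \qquad x = \sin\left(y\right) \\
\frac{\partial}{\partial y}foo\left(x\right) &= 2x\frac{\partial x}{\partial y} = 2\sin\left(y\right)\cos\left(y\right)
\end{align*}
```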


Compile preamble

The graph::leaf_node::compile_preamble method provides a way to include header files or define functions. Let's use this method to define a function that can be called from the kernel.

virtual void compile_preamble(std::ostringstream &stream,
jit::register_map &registers,
jit::visiter_map &visited,
jit::register_usage &usage,
jit::texture1d_list &textures1d,
jit::texture2d_list &textures2d,
int &avail_const_mem) {
if (visited.find(this) == visited.end()) {
this->arg->compile_preamble(stream, registers,
visited, usage,
textures1d, textures2d,
avail_const_mem);
jit::add_type<T> (stream);
stream << " foo(const ";
jit::add_type<T> (stream);
// The kernel function must agree with evaluate: foo(x) = x*x.
stream << " x) {"
<< " return x*x;"
<< "}" << std::endl;
visited.insert(this);
#ifdef SHOW_USE_COUNT
usage[this] = 1;
} else {
++usage[this];
#endif
}
}

The compile methods generate kernel source code. In this case we created a function in the preamble to evaluate foo. Since we only want to create this preamble once, we first check if this node has already been visited. The build system option SHOW_USE_COUNT tracks the number of times a node is used in the kernel. When this option is set, we need to increment the node's usage count.

Note
Most nodes don't require a preamble so this method can be left out.
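
The visited check can be exercised on its own. The sketch below is an assumption-laden simplification: emit_foo_preamble is a hypothetical free function, the type name is passed in as a string instead of using jit::add_type, and a std::set of pointers stands in for jit::visiter_map.

```cpp
#include <cassert>
#include <set>
#include <sstream>
#include <string>

// Hypothetical helper: write the foo definition at most once per key.
void emit_foo_preamble(std::ostringstream &stream,
                       std::set<const void *> &visited,
                       const void *key,
                       const std::string &type) {
    if (visited.find(key) == visited.end()) {
        stream << type << " foo(const " << type << " x) {"
               << "    return x*x;"
               << "}" << std::endl;
        visited.insert(key);  // Later calls with this key emit nothing.
    }
}
```

Calling the helper twice with the same key leaves a single definition in the stream.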

Compile

The compile method writes a line of source code to the kernel. Here we can use the function defined in the preamble.

virtual shared_leaf<T, SAFE_MATH>
compile(std::ostringstream &stream,
jit::register_map &registers,
jit::register_map &indices,
const jit::register_usage &usage) {
if (registers.find(this) == registers.end()) {
shared_leaf<T, SAFE_MATH> a = this->arg->compile(stream,
registers,
indices,
usage);
registers[this] = jit::to_string('r', this);
stream << " const ";
jit::add_type<T> (stream);
stream << " " << registers[this] << " = foo("
<< registers[a.get()] << ")";
this->endline(stream, usage);
}
return this->shared_from_this();
}

Kernels are created by assuming infinite registers. In this case, a register is a temporary variable. To provide a unique name, the node pointer value is converted into a string. Since we only want to evaluate this node once, we check if its register has already been created.
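
A minimal sketch of this register scheme follows. The helpers toy_register_name and toy_compile are hypothetical, the type is fixed to float, and the exact text produced by jit::to_string may differ from what is shown here.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <sstream>
#include <string>

// Hypothetical naming helper: a prefix followed by the pointer value.
std::string toy_register_name(const char prefix, const void *pointer) {
    std::ostringstream name;
    name << prefix << reinterpret_cast<std::uintptr_t> (pointer);
    return name.str();
}

// Emit the assignment only the first time a node is seen.
void toy_compile(std::ostringstream &stream,
                 std::map<const void *, std::string> &registers,
                 const void *node,
                 const std::string &arg_register) {
    if (registers.find(node) == registers.end()) {
        registers[node] = toy_register_name('r', node);
        stream << "const float " << registers[node]
               << " = foo(" << arg_register << ");" << std::endl;
    }
}
```

Compiling the same node a second time adds no source, and distinct nodes receive distinct register names because their addresses differ.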


To Latex

This method returns the code to generate the \(\LaTeX \) expression for the node.

virtual void to_latex () const {
std::cout << "foo\\left(";
this->arg->to_latex();
std::cout << "\\right)";
}
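
Note that backslashes must be escaped inside C++ string literals, otherwise sequences such as \r are read as control characters. A small sketch that builds the string instead of printing it (foo_to_latex is a hypothetical helper, not part of the framework):

```cpp
#include <cassert>
#include <string>

// Hypothetical helper returning the LaTeX source for foo applied to arg.
std::string foo_to_latex(const std::string &arg) {
    return "foo\\left(" + arg + "\\right)";
}
```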

Is Power Like

This method tells other nodes whether this node behaves like a power, which they can use in their reduction methods. In this case we need to return true. If this node did not act like a power, this method could be left out.

virtual bool is_power_like() const {
return true;
}

Get power base

Return the base of the power node. Other nodes use this information in their reduction methods. In this case the power base is the function argument.

virtual shared_leaf<T, SAFE_MATH> get_power_base() const {
return this->arg;
}

Get power exponent

Return the exponent of the power node. Other nodes use this information in their reduction methods. In this case, the power exponent is \(2\).

virtual shared_leaf<T, SAFE_MATH> get_power_exponent() const {
return constant<T, SAFE_MATH> (static_cast<T> (2.0));
}
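
One plausible use of these three methods is merging products of powers with a common base, for example \(foo\left(x\right)x=x^{3}\). The sketch below only illustrates that idea: toy_power and merge_powers are hypothetical, and the reductions actually performed by the framework are not shown on this page.

```cpp
#include <cassert>
#include <string>

// Hypothetical summary of what get_power_base and get_power_exponent report.
struct toy_power {
    std::string base;
    double exponent;
};

// Merge base^a * base^b into base^(a + b); returns false if the bases differ.
bool merge_powers(const toy_power &l, const toy_power &r, toy_power &result) {
    if (l.base != r.base) {
        return false;
    }
    result.base = l.base;
    result.exponent = l.exponent + r.exponent;
    return true;
}
```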

Remove Pseudo

Return the node with pseudo variables removed. A graph::pseudo_variable_node is used to end derivative construction by treating a sub graph as a pseudo variable. Before graphs can be evaluated, these graph::pseudo_variable_node instances need to be removed.

virtual shared_leaf<T, SAFE_MATH> remove_pseudo() {
if (this->has_pseudo()) {
return foo(this->arg->remove_pseudo());
}
return this->shared_from_this();
}

To Vizgraph

Generates a Graphviz node for visualization.

virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
jit::register_map &registers) {
if (registers.find(this) == registers.end()) {
const std::string name = jit::to_string('r', this);
registers[this] = name;
stream << " " << name
<< " [label = \"foo\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
auto a = this->arg->to_vizgraph(stream, registers);
stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
}
return this->shared_from_this();
}
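
The lines written by to_vizgraph use the Graphviz DOT syntax, where the -- edge operator is only valid inside an undirected graph block. A minimal sketch of a wrapper that would make the output a complete DOT file; wrap_dot and the graph name foo_graph are hypothetical, and how the framework wraps the output is not shown on this page.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical wrapper placing node and edge lines in an undirected graph.
std::string wrap_dot(const std::string &body) {
    std::ostringstream out;
    out << "graph foo_graph {" << std::endl
        << body
        << "}" << std::endl;
    return out.str();
}
```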