Bridge TensorFlow* to run on Intel®
nGraph™ backends
v0.5 : Source code analysis
Created: 2018/08/11, 19
Published on SlideShare: 2018/09/04
@Vengineer
Blog (since 2007): Vengineerの戯言
 http://blogs.yahoo.co.jp/verification_engineer
SlideShare :
 https://www.slideshare.net/ssuser479fa3
Twitter (since 2009):
@Vengineer
Lately: a source-code-reading craftsman
https://github.com/NervanaSystems/ngraph
Intel nGraph library
ONNX
neon
TensorFlow
MXNet
NNP = ARGON ?
TensorFlow
dynamically loadable XLA plugin
https://github.com/NervanaSystems/ngraph-tensorflow
tensorflow/compiler/plugin/dynamic
With this, the XLA-side code in TensorFlow itself no longer needs to be modified.
Bridge TensorFlow* to run on Intel®
nGraph™ backends
https://github.com/NervanaSystems/ngraph-tf
https://github.com/NervanaSystems/ngraph-tf/tree/r0.5/
Graph transformation
The graph is transformed by a series of optimization passes.
 1) Adding Feed/Fetch nodes
subgraph::RewriteGraphForExecution
The PRE_PLACEMENT passes run here.
 2) Placement
The POST_PLACEMENT passes run here.
  The POST_REWRITE_FOR_EXEC passes run in
   SimpleGraphExecutionState::BuildGraph.
 3) Graph partitioning
Partition
The POST_PARTITIONING passes run here.
OptimizationPass
tensorflow/core/common_runtime/optimization_registry.h
class OptimizationPassRegistry {
public:
// Groups of passes are run at different points in initialization.
enum Grouping {
PRE_PLACEMENT, // after cost model assignment, before placement.
POST_PLACEMENT, // after placement.
POST_REWRITE_FOR_EXEC, // after re-write using feed/fetch endpoints.
POST_PARTITIONING, // after partitioning
};
src/ngraph_rewrite_pass.cc
POST_PLACEMENT, 0, NGraphVariableCapturePass
POST_REWRITE_FOR_EXEC, 0, NGraphEncapsulationPass
NGraphVariableCapturePass and NGraphEncapsulationPass both inherit from
NGraphRewritePass.
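Putting the two lines above together, src/ngraph_rewrite_pass.cc presumably registers the passes with TensorFlow's standard REGISTER_OPTIMIZATION macro from optimization_registry.h. A minimal sketch, assuming that macro and the grouping/phase values listed above (the exact registration lines are not reproduced on the slide):
// Sketch: register the nGraph passes at the groupings shown above.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0,
                      NGraphVariableCapturePass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 0,
                      NGraphEncapsulationPass);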
Optimization pass
NGraphVariableCapturePass
POST_PLACEMENT
 ・NGraphVariableCapturePass
 ・Replaces every instance of VariableV2 with the NGraphVariable op
Optimization pass
src/ngraph_rewrite_pass.cc
class NGraphVariableCapturePass : public NGraphRewritePass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
….
};
NGraphVariableCapturePass
src/ngraph_rewrite_pass.cc
Status NGraphVariableCapturePass::Run(
const GraphOptimizationPassOptions& options) {
// For filename generation purposes, grab a fresh index. This is just an
// arbitrary integer to avoid filename collisions resulting from subsequent
// runs of this pass.
int idx = FreshIndex();
// Do variable capture then, if requested, dump the graphs.
TF_RETURN_IF_ERROR(CaptureVariables(options.graph->get()));
return Status::OK();
}
NGraphVariableCapturePass
src/ngraph_capture_variables.cc
static bool NGraphPlacementRequested(const Node* node) { return true; }
Status CaptureVariables(Graph* graph) {
for (auto node : graph->op_nodes()) {
if (NGraphPlacementRequested(node)) {
if (node->type_string() == "VariableV2") {
CaptureVariables
src/ngraph_capture_variables.cc
std::string container;
std::string shared_name;
if (GetNodeAttr(node->attrs(), "container", &container) != Status::OK()) {
container = "";
}
if (GetNodeAttr(node->attrs(), "shared_name", &shared_name) != Status::OK()) {
shared_name = "";
}
CaptureVariables
src/ngraph_capture_variables.cc
Node* replacement;
TF_RETURN_IF_ERROR( // create the replacement NGraphVariable node
NodeBuilder(node->name(), "NGraphVariable")
.Attr("shape", shape)
.Attr("dtype", dtype)
.Attr("container", container)
.Attr("shared_name", shared_name)
.Device(node->assigned_device_name())
.Finalize(graph, &replacement));
replacement->set_assigned_device_name(node->assigned_device_name());
CaptureVariables
src/ngraph_capture_variables.cc
std::vector<const Edge*> edges;
for (auto edge : node->out_edges()) { edges.push_back(edge); }
for (auto edge : edges) { // update the graph's edge information
graph->UpdateEdge(replacement, edge->src_output(),
edge->dst(), edge->dst_input());
}
}
}
}
return Status::OK();
}
CaptureVariables
NGraphVariableOp
src/ngraph_tracked_variable.cc
class NGraphVariableOp : public OpKernel {
public:
explicit NGraphVariableOp(OpKernelConstruction* context);
void Compute(OpKernelContext* ctx) override;
private:
DataType dtype_;
TensorShape shape_;
mutex init_mu_;
ContainerInfo cinfo_ GUARDED_BY(init_mu_);
bool initialized_ GUARDED_BY(init_mu_){false};
TF_DISALLOW_COPY_AND_ASSIGN(NGraphVariableOp);
};
NGraphVariableOp
src/ngraph_tracked_variable.cc
NGraphVariableOp::NGraphVariableOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
dtype_ = RemoveRefType(context->output_type(0));
}
NGraphVariableOp
src/ngraph_tracked_variable.cc
void NGraphVariableOp::Compute(OpKernelContext* ctx) {
mutex_lock l(init_mu_);
if (!initialized_) {
OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), true /* use name() */));
initialized_ = true;
}
auto creator = [this](NGraphVar** var) {
*var = new NGraphVar(dtype_,shape_);
//(*var)->tensor()->set_shape(shape_);
return Status::OK();
};
NGraphVariableOp
src/ngraph_tracked_variable.cc
NGraphVar* var;
OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate<NGraphVar>(
cinfo_.container(), cinfo_.name(), &var, creator));
ctx->set_output_ref(0, var->mu(), var->tensor());
if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
attr.set_nic_compatible(true);
ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
}
}
NGraphVariableOp
src/ngraph_variable_ops.cc
class NGraphVar : public ResourceBase {
public:
explicit NGraphVar(DataType dtype, const TensorShape& shape)
: tensor_(dtype, shape) {}
NGraphVar(const NGraphVar&) = delete;
NGraphVar& operator=(const NGraphVar&) = delete;
….
~NGraphVar() override {}
};
NGraphVar
src/ngraph_tracked_variable.cc
REGISTER_OP("NGraphVariable")
.Output("ref: Ref(dtype)")
.Attr("shape: shape")
.Attr("dtype: type")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ExplicitShape);
REGISTER_KERNEL_BUILDER(
Name("NGraphVariable").Device(DEVICE_CPU),
NGraphVariableOp);
NGraphVariableOp
NGraphEncapsulationPass
POST_REWRITE_FOR_EXEC
 ・NGraphEncapsulationPass
 ・Converts each partitioned subgraph into an NGraphEncapsulateOp
  (the same idea as the conversion to _XlaLaunchOp in TensorFlow XLA)
Optimization pass
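For intuition, here is a hypothetical before/after picture of what this pass produces for one cluster (the op names MatMul/Relu/Softmax are made up for illustration; the ngraph_cluster_N naming and the attributes come from the EncapsulateClusters code shown later):
// Before: ops already tagged with the _ngraph_cluster attribute
//   MatMul (_ngraph_cluster=0) -> Relu (_ngraph_cluster=0) -> Softmax (not clustered)
// After: the clustered ops are copied into a GraphDef held by NGraphClusterManager,
// and a single node replaces them in the TensorFlow graph:
//   ngraph_cluster_0 = NGraphEncapsulate(ngraph_cluster=0, Targuments=..., Tresults=...)
//   Softmax is rewired to read from ngraph_cluster_0:0 instead of Relu:0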
src/ngraph_rewrite_pass.cc
// Pass that rewrites the graph for nGraph operation.
//
// The pass has several phases, each executed in sequence:
//
// 1. Marking [ngraph_mark_for_clustering.cc]
// 2. Cluster Assignment [ngraph_assign_clusters.cc]
// 3. Cluster Deassignment [ngraph_deassign_clusters.cc]
// 4. Cluster Encapsulation [ngraph_encapsulate_clusters.cc]
NGraphEncapsulationPass
src/ngraph_rewrite_pass.cc
class NGraphEncapsulationPass : public NGraphRewritePass {
public:
Status Run(const GraphOptimizationPassOptions& options) override {
// For filename generation purposes, grab a fresh index. This is just an
// arbitrary integer to avoid filename collisions resulting from subsequent
// runs of this pass.
int idx = FreshIndex();
// If requested, dump unmarked graphs.
if (DumpUnmarkedGraphs()) {
DumpGraphs(options, idx, "unmarked", "Unmarked Graph");
}
NGraphEncapsulationPass
src/ngraph_rewrite_pass.cc
// 1. Marking [ngraph_mark_for_clustering.cc]
// Mark for clustering then, if requested, dump the graphs.
TF_RETURN_IF_ERROR(MarkForClustering(options.graph->get()));
if (DumpMarkedGraphs()) {
DumpGraphs(options, idx, "marked", "Graph Marked for Clustering");
}
NGraphEncapsulationPass
src/ngraph_rewrite_pass.cc
// 2. Cluster Assignment [ngraph_assign_clusters.cc]
// Assign clusters then, if requested, dump the graphs.
TF_RETURN_IF_ERROR(AssignClusters(options.graph->get()));
if (DumpClusteredGraphs()) {
DumpGraphs(options, idx, "clustered", "Graph with Clusters Assigned");
}
NGraphEncapsulationPass
src/ngraph_rewrite_pass.cc
// 3. Cluster Deassignment [ngraph_deassign_clusters.cc]
// Deassign trivial clusters then, if requested, dump the graphs.
TF_RETURN_IF_ERROR(DeassignClusters(options.graph->get()));
if (DumpDeclusteredGraphs()) {
DumpGraphs(options, idx, "declustered",
"Graph with Trivial Clusters De-Assigned");
}
NGraphEncapsulationPass
src/ngraph_rewrite_pass.cc
// 4. Cluster Encapsulation [ngraph_encapsulate_clusters.cc]
// Encapsulate clusters then, if requested, dump the graphs.
TF_RETURN_IF_ERROR(EncapsulateClusters(options.graph->get()));
if (DumpEncapsulatedGraphs()) {
DumpGraphs(options, idx, "encapsulated",
"Graph with Clusters Encapsulated");
}
NGraphEncapsulationPass
src/ngraph_rewrite_pass.cc
// Rewrite for tracking then, if requested, dump the graphs.
TF_RETURN_IF_ERROR(RewriteForTracking(options.graph->get()));
if (DumpTrackedGraphs()) {
DumpGraphs(options, idx, "tracked",
"Graph with Variables Rewritten for Tracking");
}
return Status::OK();
}
NGraphEncapsulationPass
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
Status EncapsulateClusters(Graph* graph) {
std::map<int, std::string> device_name_map;
std::map<std::tuple<int, int>, std::tuple<int, int>> output_remap_map;
std::map<std::tuple<int, int, int>, int> input_remap_map;
std::map<std::tuple<int, std::string, int>, string> input_rename_map;
std::map<int, std::vector<std::tuple<int, int, DataType>>> cluster_input_map;
std::map<int, std::vector<DataType>> cluster_output_dt_map;
std::map<int, Node*> cluster_node_map;
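A quick reading of these maps (my own annotations, inferred from how they are filled in the passes below):
// device_name_map      : cluster index -> requested device name (Pass 1)
// output_remap_map     : (node id, output slot) -> (cluster index, cluster output index) (Pass 2)
// input_remap_map      : (cluster index, node id, output slot) -> cluster input index (Pass 2)
// input_rename_map     : (cluster index, node name, output slot) -> new _Arg input name (Pass 2)
// cluster_input_map    : cluster index -> list of (node id, output slot, DataType) (Pass 2)
// cluster_output_dt_map: cluster index -> output DataTypes (Pass 2)
// cluster_node_map     : cluster index -> the NGraphEncapsulate node (Pass 3)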
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
// Pass 1: Populate the cluster-index-to-device name map for each existing
// cluster.
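// (the enclosing loop over graph->op_nodes() and the device_name_map.find()
//  lookup that defines "it" are omitted on this slide)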
if (it != device_name_map.end()) {
if (it->second != node->requested_device()) {
std::stringstream ss_err;
// the error message is assembled here
return errors::Internal(ss_err.str());
}
} else {
device_name_map[cluster_idx] = node->requested_device();
}
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
// Pass 2: Find all nodes that are feeding into/out of each cluster, and
// add inputs for them to the corresponding FunctionDef(s).
std::map<int, int> retval_index_count;
std::map<int, int> arg_index_count;
for (auto edge : graph->edges()) {
if (edge->IsControlEdge()) {
continue;
}
Node* src = edge->src();
Node* dst = edge->dst();
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
if (!src->IsOp() || !dst->IsOp()) {
continue;
}
int dst_cluster_idx;
bool dst_clustered =
(GetNodeCluster(dst, &dst_cluster_idx) == Status::OK());
int src_cluster_idx;
bool src_clustered =
(GetNodeCluster(src, &src_cluster_idx) == Status::OK());
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
if (dst_cluster_idx == src_cluster_idx) {
continue;
}
DataType dt = dst->input_type(edge->dst_input());
std::string flow_kind = dst_clustered && src_clustered
? "cross-flow"
: dst_clustered ? "in-flow" : "out-flow";
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
if (src_clustered &&
output_remap_map.find(std::make_tuple(src->id(), edge->src_output())) ==
output_remap_map.end()) {
output_remap_map[std::make_tuple(src->id(), edge->src_output())] =
std::make_tuple(src_cluster_idx,
cluster_output_dt_map[src_cluster_idx].size());
std::stringstream ss;
ss << "ngraph_output_" << cluster_output_dt_map[src_cluster_idx].size();
string output_name = ss.str();
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
auto new_output_node_def =
NGraphClusterManager::GetClusterGraph(src_cluster_idx)->add_node();
new_output_node_def->set_name(output_name);
new_output_node_def->set_op("_Retval");
std::stringstream ss_input_to_retval;
ss_input_to_retval << src->name() << ":" << edge->src_output();
new_output_node_def->add_input(ss_input_to_retval.str());
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
SetAttrValue(dt, &((*(new_output_node_def->mutable_attr()))["T"]));
SetAttrValue(retval_index_count[src_cluster_idx],
&((*(new_output_node_def->mutable_attr()))["index"]));
retval_index_count[src_cluster_idx]++;
cluster_output_dt_map[src_cluster_idx].push_back(dt);
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
if (dst_clustered &&
input_remap_map.find(
std::make_tuple(dst_cluster_idx, src->id(), edge->src_output())) ==
input_remap_map.end()) {
input_remap_map[std::make_tuple(dst_cluster_idx, src->id(),
edge->src_output())] =
cluster_input_map[dst_cluster_idx].size();
std::stringstream ss;
ss << "ngraph_input_" << cluster_input_map[dst_cluster_idx].size();
std::string new_input_name = ss.str();
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
input_rename_map[std::make_tuple(dst_cluster_idx, src->name(),
edge->src_output())] = new_input_name;
auto new_input_node_def =
NGraphClusterManager::GetClusterGraph(dst_cluster_idx)->add_node();
new_input_node_def->set_name(new_input_name);
new_input_node_def->set_op("_Arg");
SetAttrValue(dt, &((*(new_input_node_def->mutable_attr()))["T"]));
SetAttrValue(arg_index_count[dst_cluster_idx],
&((*(new_input_node_def->mutable_attr()))["index"]));
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
arg_index_count[dst_cluster_idx]++;
cluster_input_map[dst_cluster_idx].push_back(
std::make_tuple(src->id(), edge->src_output(), dt));
}
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
// Pass 3: Create encapsulation nodes for all clusters.
for (auto& kv : device_name_map) {
int cluster_idx = kv.first;
std::stringstream ss;
ss << "ngraph_cluster_" << cluster_idx;
std::vector<DataType> input_types;
std::vector<NodeBuilder::NodeOut> inputs;
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
for (auto& tup : cluster_input_map[cluster_idx]) {
int src_node_id;
int src_output_idx;
DataType dt;
std::tie(src_node_id, src_output_idx, dt) = tup;
input_types.push_back(dt);
inputs.push_back(
NodeBuilder::NodeOut(graph->FindNodeId(src_node_id), src_output_idx));
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
Node* n;
Status status = NodeBuilder(ss.str(), "NGraphEncapsulate")
.Attr("ngraph_cluster", cluster_idx)
.Attr("Targuments", input_types)
.Attr("Tresults", cluster_output_dt_map[cluster_idx])
.Device(device_name_map[cluster_idx])
.Input(inputs)
.Finalize(graph, &n);
TF_RETURN_IF_ERROR(status);
n->set_assigned_device_name(device_name_map[cluster_idx]);
cluster_node_map[cluster_idx] = n;
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
// Pass 4: Remap all non-clustered inputs that are reading from
// encapsulated edges, and all control edges that cross cluster
// boundaries.
// Copy the edge pointers, so as not to invalidate the iterator.
std::vector<Edge*> edges;
for (auto edge : graph->edges()) {
edges.push_back(edge);
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
for (auto edge : edges) {
int src_cluster_idx;
bool src_clustered =
(GetNodeCluster(edge->src(), &src_cluster_idx) == Status::OK());
int dst_cluster_idx;
bool dst_clustered =
(GetNodeCluster(edge->dst(), &dst_cluster_idx) == Status::OK());
if (src_cluster_idx == dst_cluster_idx) {
continue;
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
if (edge->IsControlEdge()) {
if (src_clustered && dst_clustered) {
graph->RemoveControlEdge(edge);
graph->AddControlEdge(cluster_node_map[src_cluster_idx],
cluster_node_map[dst_cluster_idx]);
} else if (src_clustered) {
Node* dst = edge->dst();
graph->RemoveControlEdge(edge);
graph->AddControlEdge(cluster_node_map[src_cluster_idx], dst);
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
} else if (dst_clustered) {
Node* src = edge->src();
graph->RemoveControlEdge(edge);
graph->AddControlEdge(src, cluster_node_map[dst_cluster_idx]);
}
} else {
// This is handled at a later stage (TODO(amprocte): explain)
if (dst_clustered) {
continue;
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
auto it = output_remap_map.find(
std::make_tuple(edge->src()->id(), edge->src_output()));
if (it == output_remap_map.end()) {
continue;
}
int cluster_idx, cluster_output;
std::tie(cluster_idx, cluster_output) = it->second;
graph->UpdateEdge(cluster_node_map[cluster_idx], cluster_output,
edge->dst(), edge->dst_input());
}
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
// Pass 5: Make copies of all clustered nodes inside the cluster graphs,
// rewiring the inputs in their NodeDefs as we go.
for (auto node : graph->op_nodes()) {
int cluster_idx;
if (GetNodeAttr(node->attrs(), "_ngraph_cluster", &cluster_idx) !=
Status::OK()) {
continue;
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
// Because the input names may have changed from the original node def,
// we will need to borrow some code from Graph::ToGraphDefSubRange in
// tensorflow/core/graph/graph.cc that rewrites the node's input list.
// begin code copied and pasted (and modified) from graph.cc...
NodeDef original_def = node->def();
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
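// (the declaration and population of the "inputs" edge vector is omitted on this slide)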
original_def.clear_input();
original_def.mutable_input()->Reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
const Edge* edge = inputs[i];
if (edge == nullptr) {
if (i < node->requested_inputs().size()) {
original_def.add_input(node->requested_inputs()[i]);
} else {
original_def.add_input("");
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
}
} else {
const Node* src = edge->src();
if (!src->IsOp()) continue;
AddInput(&original_def, src->name(), edge->src_output());
}
}
auto node_def =
NGraphClusterManager::GetClusterGraph(cluster_idx)->add_node();
*node_def = original_def;
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
for (auto& input : *(node_def->mutable_input())) {
TensorId tensor_id = ParseTensorName(input);
auto it = input_rename_map.find(std::make_tuple(
cluster_idx, tensor_id.first.ToString(), tensor_id.second));
if (it != input_rename_map.end()) {
input = it->second;
}
}
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
// Pass 6: Remove clustered nodes from the graph.
for (auto node : graph->op_nodes()) {
int cluster_idx;
if (GetNodeAttr(node->attrs(), "_ngraph_cluster", &cluster_idx) !=
Status::OK()) {
continue;
}
graph->RemoveNode(node);
}
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
// Pass 7 (optional, only run if environment variable  <= for debugging?
// NGRAPH_TF_VALIDATE_CLUSTER_GRAPHS is set):
// validate the graph def, and
// make sure we can construct a graph from it.
if (std::getenv("NGRAPH_TF_VALIDATE_CLUSTER_GRAPHS")) {
for (auto& kv : device_name_map) {
int cluster_idx = kv.first;
TF_RETURN_IF_ERROR(graph::ValidateGraphDef(
*NGraphClusterManager::GetClusterGraph(cluster_idx),
*OpRegistry::Global()));
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
// construct a TensorFlow Graph from the cluster GraphDef
Graph g(OpRegistry::Global());
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
opts, *NGraphClusterManager::GetClusterGraph(cluster_idx), &g));
std::stringstream ss;
ss << "ngraph_cluster_" << cluster_idx;
std::string filename_prefix = ss.str();
EncapsulateClusters
src/ngraph_encapsulate_clusters.cc
// dump the graph to files
GraphToPbTextFile(&g, filename_prefix + ".pbtxt");
GraphToDotFile(&g, filename_prefix + ".dot",
"nGraph Cluster Dump: " + filename_prefix);
}
}
return Status::OK();
}
EncapsulateClusters
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
class NGraphEncapsulateOp : public OpKernel {
public:
explicit NGraphEncapsulateOp(OpKernelConstruction* ctx)
: OpKernel(ctx), m_graph(OpRegistry::Global()) {
GraphDef* graph_def;
OP_REQUIRES_OK(ctx, ctx->GetAttr<int>("ngraph_cluster", &m_ngraph_cluster));
graph_def = NGraphClusterManager::GetClusterGraph(m_ngraph_cluster);
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
OP_REQUIRES_OK(ctx, ConvertGraphDefToGraph(opts, *graph_def, &m_graph));
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
if (m_ng_backend == nullptr) {
#if defined(NGRAPH_EMBEDDED_IN_TENSORFLOW)
NGRAPH_VLOG(2) << "Using INTERPRETER backend since "
"NGRAPH_EMBEDDED_IN_TENSORFLOW is enabled";
m_ng_backend_name = "INTERPRETER";
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
#else
const char* ng_backend_env_value = std::getenv("NGRAPH_TF_BACKEND");
if (ng_backend_env_value != nullptr) {
m_ng_backend_name = std::string(ng_backend_env_value);
if (m_ng_backend_name.empty()) {
m_ng_backend_name = "CPU";
}
} else {
m_ng_backend_name = "CPU";
}
#endif
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
m_ng_backend = ng::runtime::Backend::create(m_ng_backend_name);
OP_REQUIRES(ctx, m_ng_backend != nullptr,
errors::InvalidArgument("Cannot create nGraph backend"));
}
}
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
void Compute(OpKernelContext* ctx) override {
std::vector<TensorShape> input_shapes;
std::stringstream signature_ss;
for (int i = 0; i < ctx->num_inputs(); i++) { // loop over the inputs
const Tensor& input_tensor = ctx->input(i);
input_shapes.push_back(input_tensor.shape());
for (const auto& x : input_tensor.shape()) {
signature_ss << x.size << ",";
}
signature_ss << ";";
}
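As a worked example (hypothetical shapes, not taken from the slides): with two inputs of shapes [2,3] and [4], the loop above builds the signature string "2,3,;4,;". The signature therefore only changes when the set of input shapes changes, which is what the function cache below keys on.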
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
std::shared_ptr<ngraph::Function> ng_function;
std::string signature = signature_ss.str();
auto it = m_ng_functions.find(signature); // look up a compiled function for this signature
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
if (it == m_ng_functions.end()) { // no compiled function for this signature yet
// Translate the TensorFlow graph into an nGraph function
// input_shapes: shapes of the input data
// m_graph: the TensorFlow graph to translate
// ng_function: the resulting nGraph function
OP_REQUIRES_OK(ctx, Builder::TranslateGraph(input_shapes, &m_graph, ng_function));
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
if (std::getenv("NGRAPH_ENABLE_SERIALIZE") != nullptr) {
std::string file_name =
"tf_function_" + ctx->op_kernel().name() + ".json";
std::string js = ngraph::serialize(ng_function, 4);
{
std::ofstream f(file_name);
f << js;
}
}
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
m_ng_functions[signature] = ng_function; // cache the compiled function
} else {
ng_function = it->second; // reuse the already-compiled function
}
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
if (m_freshness_tracker == nullptr) {
auto creator = [](NGraphFreshnessTracker** tracker) {
*tracker = new NGraphFreshnessTracker();
return Status::OK();
};
OP_REQUIRES_OK(
ctx,
ctx->resource_manager()->LookupOrCreate<NGraphFreshnessTracker>(
ctx->resource_manager()->default_container(),
"ngraph_freshness_tracker", &m_freshness_tracker, creator));
}
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
vector<shared_ptr<ng::runtime::TensorView>> ng_inputs;
std::vector<std::pair<void*, std::shared_ptr<ng::runtime::TensorView>>>&
input_caches = m_ng_function_input_cache_map[ng_function];
input_caches.resize(input_shapes.size());
// convert the inputs into nGraph tensors
for (int i = 0; i < input_shapes.size(); i++) {
ng::Shape ng_shape(input_shapes[i].dims());
for (int j = 0; j < input_shapes[i].dims(); ++j) {
ng_shape[j] = input_shapes[i].dim_size(j);
}
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
ng::element::Type ng_element_type;
OP_REQUIRES_OK(ctx, TFDataTypeToNGraphElementType(ctx->input(i).dtype(),
&ng_element_type));
void* last_src_ptr = input_caches[i].first;
std::shared_ptr<ng::runtime::TensorView> last_tv = input_caches[i].second;
// get the raw base pointer of the input tensor (via DMAHelper)
void* current_src_ptr = (void*)DMAHelper::base(&ctx->input(i));
std::shared_ptr<ng::runtime::TensorView> current_tv;
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
if (m_ng_backend_name == "CPU") {
if (current_src_ptr == last_src_ptr && last_tv != nullptr) {
if (m_freshness_tracker->IsFresh(current_src_ptr, ng_function)) {
last_tv->set_stale(false);
} else {
last_tv->set_stale(true);
}
current_tv = last_tv;
} else {
current_tv = m_ng_backend->create_tensor(ng_element_type, ng_shape, current_src_ptr);
current_tv->set_stale(true);
}
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
} else {
if (last_tv != nullptr) {
current_tv = last_tv;
} else {
current_tv = m_ng_backend->create_tensor(ng_element_type, ng_shape);
}
current_tv->write(current_src_ptr, 0, current_tv->get_element_count() *
ng_element_type.size());
} // if (m_ng_backend_name == "CPU")
input_caches[i] = std::make_pair(current_src_ptr, current_tv);
ng_inputs.push_back(current_tv);
}
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
// outputs
vector<shared_ptr<ng::runtime::TensorView>> ng_outputs;
std::vector<std::pair<void*, std::shared_ptr<ng::runtime::TensorView>>>&
output_caches = m_ng_function_output_cache_map[ng_function];
output_caches.resize(ng_function->get_output_size());
for (auto i = 0; i < ng_function->get_output_size(); i++) {
auto ng_shape = ng_function->get_output_shape(i);
auto ng_element_type = ng_function->get_output_element_type(i);
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
vector<int64> dims;
for (auto dim : ng_shape) {
dims.push_back(dim);
}
TensorShape tf_shape(dims);
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, tf_shape, &output_tensor));
ng::element::Type expected_elem_type;
OP_REQUIRES_OK(ctx, TFDataTypeToNGraphElementType(ctx->expected_output_dtype(i),
&expected_elem_type));
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
OP_REQUIRES(
ctx, ng_element_type == expected_elem_type,
errors::Internal("Element type inferred by nGraph does not match "
"the element type expected by TensorFlow"));
void* last_dst_ptr = output_caches[i].first;
std::shared_ptr<ng::runtime::TensorView> last_tv =
output_caches[i].second;
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
// get the raw base pointer of the output tensor (via DMAHelper)
void* current_dst_ptr = DMAHelper::base(output_tensor);
std::shared_ptr<ng::runtime::TensorView> current_tv;
if (m_ng_backend_name == "CPU") {
if (current_dst_ptr == last_dst_ptr && last_tv != nullptr) {
current_tv = last_tv;
} else {
current_tv = m_ng_backend->create_tensor(ng_element_type, ng_shape, current_dst_ptr);
}
} else {
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
if (last_tv != nullptr) {
current_tv = last_tv;
} else {
current_tv = m_ng_backend->create_tensor(ng_element_type, ng_shape);
}
}
current_tv->set_stale(true);
output_caches[i] = std::make_pair(current_dst_ptr, current_tv);
ng_outputs.push_back(current_tv);
}
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
// execute the function on the nGraph backend
// ng_function: the nGraph function
// ng_outputs: nGraph output buffers
// ng_inputs: nGraph input buffers
m_ng_backend->call(ng_function, ng_outputs, ng_inputs);
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
// copy the outputs back (non-CPU backends only)
if (m_ng_backend_name != "CPU") {
for (size_t i = 0; i < output_caches.size(); ++i) {
void* dst_ptr;
std::shared_ptr<ng::runtime::TensorView> dst_tv;
std::tie(dst_ptr, dst_tv) = output_caches[i];
auto ng_element_type = dst_tv->get_tensor().get_element_type();
dst_tv->read(dst_ptr, 0,
dst_tv->get_element_count() * ng_element_type.size());
}
}
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
// finally, mark the inputs as fresh for the next call
for (int i = 0; i < input_shapes.size(); i++) {
void* src_ptr = (void*)DMAHelper::base(&ctx->input(i));
m_freshness_tracker->MarkFresh(src_ptr, ng_function);
}
}
NGraphEncapsulateOp
src/ngraph_encapsulate_op.cc
REGISTER_OP("NGraphEncapsulate")
.Input("args: Targuments")
.Attr("Targuments: list(type) >= 0")
.Output("results: Tresults")
.Attr("Tresults: list(type) >= 0")
.Attr("ngraph_cluster: int")
.SetIsStateful()
.Doc("nGraph Encapsulation Op. For use by the nGraph JIT only.");
REGISTER_KERNEL_BUILDER(
Name("NGraphEncapsulate").Device(DEVICE_CPU),
ngraph_bridge::NGraphEncapsulateOp);
NGraphEncapsulateOp
Builder::TranslateGraph
src/ngraph_builder.cc
Status Builder::TranslateGraph(
const std::vector<TensorShape>& inputs,
const Graph* input_graph,
shared_ptr<ng::Function>& ng_function) {
vector<Node*> ordered;
GetReversePostOrder(*input_graph, &ordered); // traverse the graph in reverse post-order
vector<const Node*> tf_params;
vector<const Node*> tf_ret_vals;
vector<const Node*> tf_ops;
Builder::TranslateGraph
src/ngraph_builder.cc
for (const auto n : ordered) {
if (n->IsSink() || n->IsSource()) { // skip the special Source/Sink nodes
continue;
}
if (n->IsControlFlow()) { // control flow is not supported
return errors::Unimplemented(
"Encountered a control flow op in the nGraph bridge: ",
n->DebugString());
}
Builder::TranslateGraph
src/ngraph_builder.cc
if (n->type_string() == "_Arg") { // parameter
tf_params.push_back(n);
} else if (n->type_string() == "_Retval") { // return value
tf_ret_vals.push_back(n);
} else {
tf_ops.push_back(n); // regular op
}
}
Builder::TranslateGraph
src/ngraph_builder.cc
Builder::OpMap ng_op_map;
vector<shared_ptr<ng::op::Parameter>> ng_parameter_list(tf_params.size());
for (auto parm : tf_params) { // process the parameters (_Arg nodes)
DataType dtype;
if (GetNodeAttr(parm->attrs(), "T", &dtype) != Status::OK()) {
return errors::InvalidArgument("No data type defined for _Arg");
}
int index;
if (GetNodeAttr(parm->attrs(), "index", &index) != Status::OK()) {
return errors::InvalidArgument("No index defined for _Arg");
}
Builder::TranslateGraph
src/ngraph_builder.cc
// convert the TensorFlow data type to an nGraph element type
ng::element::Type ng_et;
TFDataTypeToNGraphElementType(dtype, &ng_et);
// convert the TensorFlow tensor shape to an nGraph shape
ng::Shape ng_shape;
TFTensorShapeToNGraphShape(inputs[index], &ng_shape);
auto ng_param = make_shared<ng::op::Parameter>(ng_et, ng_shape);
// record it as an nGraph op
SaveNgOp(ng_op_map, parm->name(), ng_param);
ng_parameter_list[index] = ng_param;
}
Builder::TranslateGraph
src/ngraph_builder.cc
// process the ops
for (auto op : tf_ops) {
try {
// map the TensorFlow op onto nGraph ops
TRANSLATE_OP_MAP.at(op->type_string())(op, ng_op_map);
} catch (const std::out_of_range&) {
return errors::InvalidArgument("Unsupported Op: ", op->name(), " (",
op->type_string(), ")");
}
}
Builder::TranslateGraph
src/ngraph_builder.cc
vector<shared_ptr<ng::Node>> ng_result_list(tf_ret_vals.size());
// process the return values (_Retval nodes)
for (auto n : tf_ret_vals) {
if (n->num_inputs() != 1) {
return errors::InvalidArgument("_Retval has ", n->num_inputs(),
" inputs, should have 1");
}
int index;
if (GetNodeAttr(n->attrs(), "index", &index) != Status::OK()) {
return errors::InvalidArgument("No index defined for _Retval");
}
Builder::TranslateGraph
src/ngraph_builder.cc
shared_ptr<ng::Node> result;
GetInputNode(ng_op_map, n, 0, &result);
ng_result_list[index] = result;
}
// create the nGraph function (using the nGraph library)
ng_function = make_shared<ng::Function>(ng_result_list, ng_parameter_list);
return Status::OK();
}
Builder::TranslateGraph
src/ngraph_builder.cc
const static std::map<const string,
const function<Status(const Node*, Builder::OpMap&)>>
TRANSLATE_OP_MAP{
{"Abs", TranslateUnaryOp<ngraph::op::Abs>},
{"Add", TranslateBinaryOp<ngraph::op::Add>},
{"AddN", TranslateAddNOp},
{"AvgPool", TranslateAvgPoolOp},
{"AvgPoolGrad", TranslateAvgPoolGradOp},
{"BatchMatMul", TranslateBatchMatMulOp},
{"BiasAdd", TranslateBiasAddOp},
{"BiasAddGrad", TranslateBiasAddGradOp},
{"Cast", TranslateCastOp},
Builder::TranslateGraph
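For reference, a minimal sketch of the shape a translator such as TranslateUnaryOp<T> plausibly has, written only in terms of the GetInputNode/SaveNgOp helpers that appear elsewhere in this deck; the actual r0.5 implementation may differ in its details:
// Sketch (assumption): a generic unary-op translator in the style of ngraph_builder.cc.
template <typename T>
static Status TranslateUnaryOp(const Node* op, Builder::OpMap& ng_op_map) {
  std::shared_ptr<ng::Node> ng_input;
  TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, &ng_input));
  // Wrap the single nGraph input in the requested nGraph op, e.g. ngraph::op::Abs.
  SaveNgOp(ng_op_map, op->name(), std::make_shared<T>(ng_input));
  return Status::OK();
}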
NGraphStubOp
src/ngraph_stub_ops.cc
class NGraphStubOp : public OpKernel {
public:
explicit NGraphStubOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
OP_REQUIRES(ctx, false,
errors::Internal("NGraphStubOp compute kernel called"));
}
};
#define REGISTER_NGRAPH_STUB(name) \
REGISTER_KERNEL_BUILDER( \
Name(name).Device(ngraph_bridge::DEVICE_NGRAPH).Label("ngraph"), \
NGraphStubOp);
NGraphStubOp
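A plausible reading, not stated explicitly on the slides: these stub kernels exist only so that the listed ops have a registered kernel on the nGraph device and can be placed there; by the time the graph actually runs, every such op should have been folded into an NGraphEncapsulate node, so the Compute body above simply reports an internal error if it is ever reached.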
src/ngraph_stub_ops.cc
REGISTER_NGRAPH_STUB("Abs");
REGISTER_NGRAPH_STUB("Add");
REGISTER_NGRAPH_STUB("AddN");
REGISTER_NGRAPH_STUB("AvgPool");
REGISTER_NGRAPH_STUB("AvgPoolGrad");
REGISTER_NGRAPH_STUB("BatchMatMul");
REGISTER_NGRAPH_STUB("BiasAdd");
REGISTER_NGRAPH_STUB("BiasAddGrad");
REGISTER_NGRAPH_STUB("Cast");
REGISTER_NGRAPH_STUB("ConcatV2");
REGISTER_NGRAPH_STUB("Conv2D");
REGISTER_NGRAPH_STUB("Conv2DBackpropFilter");
REGISTER_NGRAPH_STUB("Conv2DBackpropInput");
REGISTER_NGRAPH_STUB("DepthwiseConv2dNative");
REGISTER_NGRAPH_STUB("Equal");
REGISTER_NGRAPH_STUB("Exp");
REGISTER_NGRAPH_STUB("ExpandDims");
NGraphStubOp
src/ngraph_stub_ops.cc
REGISTER_NGRAPH_STUB("Fill")
REGISTER_NGRAPH_STUB("Floor");
REGISTER_NGRAPH_STUB("FloorDiv");
REGISTER_NGRAPH_STUB("FloorMod");
REGISTER_NGRAPH_STUB("FusedBatchNorm");
REGISTER_NGRAPH_STUB("FusedBatchNormGrad");
REGISTER_NGRAPH_STUB("Greater");
REGISTER_NGRAPH_STUB("GreaterEqual");
REGISTER_NGRAPH_STUB("L2Loss");
REGISTER_NGRAPH_STUB("Less");
REGISTER_NGRAPH_STUB("LessEqual");
REGISTER_NGRAPH_STUB("Log");
REGISTER_NGRAPH_STUB("LogicalAnd");
REGISTER_NGRAPH_STUB("LogicalNot");
REGISTER_NGRAPH_STUB("MatMul");
REGISTER_NGRAPH_STUB("Maximum");
REGISTER_NGRAPH_STUB("Minimum");
NGraphStubOp
src/ngraph_stub_ops.cc
REGISTER_NGRAPH_STUB("Mul");
REGISTER_NGRAPH_STUB("Neg");
REGISTER_NGRAPH_STUB("Pack");
REGISTER_NGRAPH_STUB("Pad");
REGISTER_NGRAPH_STUB("Pow");
REGISTER_NGRAPH_STUB("PreventGradient");
REGISTER_NGRAPH_STUB("Prod");
REGISTER_NGRAPH_STUB("RealDiv");
REGISTER_NGRAPH_STUB("Reciprocal");
REGISTER_NGRAPH_STUB("Relu");
REGISTER_NGRAPH_STUB("Relu6");
REGISTER_NGRAPH_STUB("ReluGrad");
REGISTER_NGRAPH_STUB("Reshape");
REGISTER_NGRAPH_STUB("Rsqrt");
REGISTER_NGRAPH_STUB("Sigmoid");
REGISTER_NGRAPH_STUB("Sign");
REGISTER_NGRAPH_STUB("Slice");
NGraphStubOp
src/ngraph_stub_ops.cc
REGISTER_NGRAPH_STUB("Snapshot");
REGISTER_NGRAPH_STUB("Softmax");
REGISTER_NGRAPH_STUB("SparseSoftmaxCrossEntropyWithLogits");
REGISTER_NGRAPH_STUB("Split");
REGISTER_NGRAPH_STUB("SplitV");
REGISTER_NGRAPH_STUB("Square");
REGISTER_NGRAPH_STUB("SquaredDifference");
REGISTER_NGRAPH_STUB("Squeeze");
REGISTER_NGRAPH_STUB("StridedSlice");
REGISTER_NGRAPH_STUB("Sub");
REGISTER_NGRAPH_STUB("Sum");
REGISTER_NGRAPH_STUB("Tanh");
REGISTER_NGRAPH_STUB("Tile");
REGISTER_NGRAPH_STUB("Transpose");
REGISTER_NGRAPH_STUB("Unpack");
NGraphStubOp
Blog (since 2007): Vengineerの戯言
 http://blogs.yahoo.co.jp/verification_engineer
SlideShare :
 https://www.slideshare.net/ssuser479fa3
Twitter (since 2009):
@Vengineer
Thank you very much.

Bridge TensorFlow to run on Intel nGraph backends (v0.5)