Creating a TF Operation


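How TensorFlow creates an Op

Using AddNOp as an example, this post shows how an Operation goes from its definition in ops to a concrete kernel instance.

When the TensorFlow Executor is initialized, it iterates over every node in the computation graph and creates the Operation for each one. This happens in the call params_.create_kernel(n->def(), &item->kernel):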


// Code in executor.cc
Status ExecutorImpl::Initialize() {
  ……
  for (const Node* n : graph_->nodes()) {
    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    // See if this node is a root node, and if so, add to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(n);
    }

    NodeItem* item = gview_.node(id);
    item->node = n;

    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    Status s = params_.create_kernel(n->def(), &item->kernel);
    ……

params_.create_kernel is a lambda that was set up earlier. Calling it eventually reaches lib->CreateKernel(ndef, kernel), where lib is a FunctionLibraryRuntimeImpl instance:

// Code in direct_session.cc
params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
                                          OpKernel** kernel) {
  // NOTE(mrry): We must not share function kernels (implemented
  // using `CallOp`) between subgraphs, because `CallOp::handle_`
  // is tied to a particular subgraph. Even if the function itself
  // is stateful, the `CallOp` that invokes it is not.
  if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
    return lib->CreateKernel(ndef, kernel);
  }
  auto create_fn = [lib, &ndef](OpKernel** kernel) {
    return lib->CreateKernel(ndef, kernel);
  };
  // Kernels created for subgraph nodes need to be cached. On
  // cache miss, create_fn() is invoked to create a kernel based
  // on the function library here + global op registry.
  return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
                             create_fn);
};
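The find-or-create pattern in that lambda can be illustrated without TensorFlow. The following is only a minimal sketch: CachedKernel and SketchKernelCache are hypothetical names invented for this example, and the real OpSegment also keys on the session handle and reports errors through Status, which the sketch leaves out.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for OpKernel; this is not the TensorFlow class.
struct CachedKernel {
  std::string node_name;
};

// Hypothetical cache keyed by node name. It mimics the role that
// opseg->FindOrCreate plays above: reuse a kernel if one exists,
// otherwise build it with the supplied callback and remember it.
class SketchKernelCache {
 public:
  using CreateFn = std::function<std::unique_ptr<CachedKernel>()>;

  // Returns the cached kernel for node_name; create_fn runs only on a miss.
  CachedKernel* FindOrCreate(const std::string& node_name,
                             const CreateFn& create_fn) {
    auto it = kernels_.find(node_name);
    if (it != kernels_.end()) {
      return it->second.get();  // Cache hit: no new kernel is built.
    }
    std::unique_ptr<CachedKernel> created = create_fn();  // Cache miss.
    CachedKernel* raw = created.get();
    kernels_[node_name] = std::move(created);
    return raw;
  }

 private:
  std::map<std::string, std::unique_ptr<CachedKernel>> kernels_;
};
```

Calling FindOrCreate twice with the same node name returns the same pointer and invokes the callback once, which is the property the comment in the lambda relies on when it says kernels for subgraph nodes need to be cached.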

CreateKernel in turn ends up calling CreateNonCachedKernel:

// Code in function.cc
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
                                                OpKernel** kernel) {
  return CreateKernel(ndef, base_lib_def_, kernel);
}

Status FunctionLibraryRuntimeImpl::CreateKernel(
    const NodeDef& ndef, const FunctionLibraryDefinition* lib_def,
    OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  if (custom_kernel_creator_) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator_(this, ndef, &ret);
    if (s.ok()) {
      *kernel = ret.release();
      return s;
    } else {
      VLOG(2) << "Custom creator error: " << s;
      // Falls through.
      s = Status::OK();
    }
  }

  if (lib_def->Find(ndef.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, this, ndef, graph_def_version_,
                                 kernel);
  }
  ……

CreateNonCachedKernel in executor.cc calls CreateOpKernel in op_kernel.cc, which creates the Operation through registration->factory->Create(&context). Here, registration is found by FindKernelRegistration, which looks up the name AddN in GlobalKernelRegistry(). registration->factory is the PtrOpKernelFactory instance that was created when the kernel was registered, so registration->factory->Create ultimately comes down to new AddNOp(context).

// Code in executor.cc
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                             const NodeDef& ndef, int graph_def_version,
                             OpKernel** kernel) {
  const auto device_type = DeviceType(device->attributes().device_type());
  auto allocator = device->GetAllocator(AllocatorAttributes());
  return CreateOpKernel(device_type, device, allocator, flib, ndef,
                        graph_def_version, kernel);
}

// Code in op_kernel.cc
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& node_def, int graph_def_version,
                      OpKernel** kernel) {
  ……
  // Look up kernel registration.
  const KernelRegistration* registration;
  bool was_attr_mismatch;
  s = FindKernelRegistration(device_type, node_def, &registration,
                             &was_attr_mismatch);
  ……
  // Everything needed for OpKernel construction.
  OpKernelConstruction context(
      device_type, device, allocator, &node_def, op_def, flib, inputs,
      input_memory_types, outputs, output_memory_types, graph_def_version, &s);
  *kernel = registration->factory->Create(&context);
  if (!s.ok()) {
    delete *kernel;
    *kernel = nullptr;
  }
  return s;
}
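The lookup-then-factory flow can be reproduced in a few lines of standalone C++. This is a sketch under stated assumptions, not TensorFlow's implementation: every name starting with Sketch is hypothetical, the registry is keyed by op name only (the real lookup also considers device type and attributes), and a std::function stands in for PtrOpKernelFactory.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for OpKernelConstruction.
struct SketchConstruction {
  std::string node_name;
};

// Hypothetical stand-in for OpKernel.
class SketchKernel {
 public:
  explicit SketchKernel(SketchConstruction* context)
      : name_(context->node_name) {}
  virtual ~SketchKernel() = default;
  virtual std::string OpName() const = 0;
  const std::string& name() const { return name_; }

 private:
  std::string name_;
};

// Hypothetical stand-in for AddNOp: the constructor only forwards context.
class SketchAddNOp : public SketchKernel {
 public:
  explicit SketchAddNOp(SketchConstruction* context) : SketchKernel(context) {}
  std::string OpName() const override { return "AddN"; }
};

// A factory is anything that turns a construction context into a kernel.
using SketchFactory = std::function<SketchKernel*(SketchConstruction*)>;

// Function-local static, so the map exists before any registrar touches it.
std::map<std::string, SketchFactory>& SketchGlobalRegistry() {
  static std::map<std::string, SketchFactory> registry;
  return registry;
}

// The constructor runs during static initialization and records the factory.
struct SketchRegistrar {
  SketchRegistrar(const std::string& op, SketchFactory factory) {
    SketchGlobalRegistry()[op] = std::move(factory);
  }
};

// Registration: the stored factory boils down to `new SketchAddNOp(context)`.
static SketchRegistrar sketch_add_n_registrar(
    "AddN", [](SketchConstruction* context) -> SketchKernel* {
      return new SketchAddNOp(context);
    });

// Looks up the factory by op name and runs it; nullptr if nothing registered.
std::unique_ptr<SketchKernel> SketchCreateKernel(const std::string& op,
                                                 SketchConstruction* context) {
  auto it = SketchGlobalRegistry().find(op);
  if (it == SketchGlobalRegistry().end()) {
    return nullptr;
  }
  return std::unique_ptr<SketchKernel>(it->second(context));
}
```

The point of the indirection is that the code creating kernels never names SketchAddNOp. It only knows the string "AddN", and the registry supplies the constructor call.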

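The context argument is passed in through the AddNOp constructor. Later, after Session run, whenever a node invokes its Operation to do the actual computation, the kernel's Compute method is called: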

// Code in aggregate_ops.cc
template <typename Device, typename T>
class AddNOp : public OpKernel {
 public:
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    if (!ctx->ValidateInputsAreSameShape(this)) return;
    ……
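The listing stops right after the shape check. As a rough illustration of what an AddN-style Compute has to do, here is a standalone sketch that validates that all inputs have the same size and then sums them elementwise. SketchAddN is a hypothetical helper operating on plain std::vector<float>; it is not the code from aggregate_ops.cc, which works on tensors and dispatches per device.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: elementwise sum of N equally sized inputs.
// Returns false (leaving *output untouched) when there are no inputs or
// the sizes differ, playing the role of the early return in Compute.
bool SketchAddN(const std::vector<std::vector<float>>& inputs,
                std::vector<float>* output) {
  if (inputs.empty()) {
    return false;
  }
  const std::size_t size = inputs[0].size();
  for (const std::vector<float>& in : inputs) {
    if (in.size() != size) {
      return false;  // Inputs must all have the same shape.
    }
  }
  std::vector<float> sum(size, 0.0f);
  for (const std::vector<float>& in : inputs) {
    for (std::size_t i = 0; i < size; ++i) {
      sum[i] += in[i];
    }
  }
  *output = std::move(sum);
  return true;
}
```

The kernel object is built once during Executor initialization, while a function like this would run on every step, which is why construction and Compute are kept separate.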