/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file lib_api.cc
 * \brief APIs to interact with libraries
 * This API specifies function prototypes to
 * register custom ops, partitioner, and passes
 * for library authors
 * See example/extension/lib_custom_op/README.md
 * See example/extension/lib_subgraph/README.md
 * See example/extension/lib_pass/README.md
 */

#include "mxnet/lib_api.h"

/*! \brief Return the process-wide error-message registry (a function-local static). */
mxnet::ext::MXerrorMsgs* mxnet::ext::MXerrorMsgs::get() {
  static MXerrorMsgs registry;
  return &registry;
}

/*!
 * \brief Begin a new error message prefixed with its source location.
 * \param file source file name written at the start of the message
 * \param line source line written in square brackets after the file name
 * \return the stream for the new message; callers append the message text to it
 */
std::stringstream& mxnet::ext::MXerrorMsgs::add(const char* file, int line) {
  messages.emplace_back();
  std::stringstream& msg = messages.back();
  msg << file << "[" << line << "]: ";
  return msg;
}

/*! \brief Number of error messages recorded so far. */
int mxnet::ext::MXerrorMsgs::size() {
  const auto count = messages.size();
  return static_cast<int>(count);
}

/*!
 * \brief Copy out the text of message `idx` (bounds-checked through `at()`).
 * \return a newly heap-allocated string; the caller takes ownership of it
 */
const std::string* mxnet::ext::MXerrorMsgs::get(int idx) {
  std::string text = messages.at(idx).str();
  return new std::string(std::move(text));
}

mxnet::ext::MXContext::MXContext() : dev_type("error"), dev_id(-1) {}

mxnet::ext::MXContext::MXContext(std::string dev_type_, int dev_id_)
    : dev_type(std::move(dev_type_)), dev_id(dev_id_) {}

mxnet::ext::MXContext::MXContext(const char* dev_type_, int dev_id_)
    : dev_type(dev_type_), dev_id(dev_id_) {}

mxnet::ext::MXContext mxnet::ext::MXContext::CPU() {
  return MXContext("cpu", 0);
}

mxnet::ext::MXContext mxnet::ext::MXContext::GPU() {
  return MXContext("gpu", 0);
}

mxnet::ext::MXContext mxnet::ext::MXContext::CPU(int dev_id) {
  return MXContext("cpu", dev_id);
}

mxnet::ext::MXContext mxnet::ext::MXContext::GPU(int dev_id) {
  return MXContext("gpu", dev_id);
}

/*!
 * \brief Point this sparse view at externally owned buffers (nothing is copied).
 * \param data_ptr    buffer of stored values
 * \param dims        dense shape, `ndims` entries
 * \param ndims       number of entries in `dims`
 * \param idx         index buffer with `num_idx` entries
 * \param num_idx     number of entries in `idx`
 * \param idx_ptr     CSR row-pointer buffer, or nullptr for row-sparse storage
 * \param num_idx_ptr number of entries in `idx_ptr` (unused when it is nullptr)
 */
void mxnet::ext::MXSparse::set(void* data_ptr,
                               const int64_t* dims,
                               int ndims,
                               void* idx,
                               int64_t num_idx,
                               void* idx_ptr,
                               int64_t num_idx_ptr) {
  data = data_ptr;
  // If CSR, the number of non-zero elements is num_idx.
  // If row sparse, the number of elements is num_idx * width (product of dims[1:]).
  data_len = num_idx;
  if (!idx_ptr) {
    for (int i = 1; i < ndims; ++i)
      data_len *= dims[i];
  }

  indices     = reinterpret_cast<int64_t*>(idx);
  indices_len = num_idx;

  if (idx_ptr) {
    indptr     = reinterpret_cast<int64_t*>(idx_ptr);
    indptr_len = num_idx_ptr;
  } else {
    // Row-sparse storage has no row-pointer array. Clear it so that a previous
    // CSR setting (or an uninitialized value) is never mistaken for a valid one.
    indptr     = nullptr;
    indptr_len = 0;
  }
}

/*! \brief Construct an empty tensor: no data, unset dtype, zeroed DLPack view. */
mxnet::ext::MXTensor::MXTensor()
    : data_ptr(nullptr), dtype(kUNSET), verID(0), stype(kDefaultStorage) {
  // Value-initialize the DLPack view so it never holds indeterminate fields.
  dltensor = decltype(dltensor)();
}

/*!
 * \brief Copy a tensor's metadata. The data buffer is shared, not copied.
 *
 * The DLPack view is rebuilt rather than copied because it holds a raw pointer
 * to the shape buffer, and this copy owns a separate shape buffer.
 */
mxnet::ext::MXTensor::MXTensor(const MXTensor& oth)
    : data_ptr(oth.data_ptr),
      shape(oth.shape),
      dtype(oth.dtype),
      verID(oth.verID),
      ctx(oth.ctx),
      stype(oth.stype) {
  if (dtype == kUNSET) {
    // setDLTensor() throws for kUNSET. Copying an empty (default-constructed)
    // tensor, for example during container reallocation, must not fail.
    dltensor = decltype(dltensor)();
  } else {
    setDLTensor();
  }
}

/*!
 * \brief Wrap an existing buffer as a tensor.
 * \param data_ptr data buffer (not copied)
 * \param shape    tensor shape (moved in)
 * \param dtype    element type
 * \param vID      version id
 * \param mx_ctx   device context
 * \param stype    storage type
 * \throws std::runtime_error from setDLTensor() if dtype has no DLPack mapping
 */
mxnet::ext::MXTensor::MXTensor(void* data_ptr,
                               std::vector<int64_t> shape,
                               MXDType dtype,
                               size_t vID,
                               MXContext mx_ctx,
                               MXStorageType stype)
    : data_ptr(data_ptr),
      shape(std::move(shape)),
      dtype(dtype),
      verID(vID),
      ctx(std::move(mx_ctx)),
      stype(stype) {
  setDLTensor();
}

/*!
 * \brief Re-point this tensor at new data and metadata, then refresh the DLPack view.
 * \param dptr         data buffer (not copied)
 * \param type         element type
 * \param dims         shape array with `ndims` entries (copied into `shape`)
 * \param ndims        number of dimensions; a non-positive value yields an empty shape
 * \param vID          version id
 * \param mx_ctx       device context
 * \param storage_type storage type
 * \throws std::runtime_error from setDLTensor() if `type` has no DLPack mapping
 */
void mxnet::ext::MXTensor::setTensor(void* dptr,
                                     MXDType type,
                                     const int64_t* dims,
                                     int ndims,
                                     size_t vID,
                                     MXContext mx_ctx,
                                     MXStorageType storage_type) {
  data_ptr = dptr;
  dtype    = type;
  verID    = vID;
  ctx      = mx_ctx;
  stype    = storage_type;
  if (ndims > 0)
    shape.assign(dims, dims + ndims);
  else
    shape.clear();
  setDLTensor();
}

/*!
 * \brief Rebuild the DLPack view (`dltensor`) from this tensor's fields.
 *
 * The view borrows `data_ptr` and `shape.data()`, so it must be refreshed
 * whenever those change. Unknown device type strings map to kDLExtDev.
 * \throws std::runtime_error if `dtype` has no DLPack mapping (e.g. kUNSET)
 */
void mxnet::ext::MXTensor::setDLTensor() {
  dltensor.data          = data_ptr;
  dltensor.ndim          = shape.size();
  dltensor.shape         = const_cast<int64_t*>(shape.data());
  dltensor.strides       = nullptr;
  dltensor.byte_offset   = 0;
  dltensor.dtype.lanes   = 1;
  dltensor.ctx.device_id = ctx.dev_id;

  const std::string& dev = ctx.dev_type;
  if (dev == "cpu")
    dltensor.ctx.device_type = kDLCPU;
  else if (dev == "gpu")
    dltensor.ctx.device_type = kDLGPU;
  else if (dev == "opencl")
    dltensor.ctx.device_type = kDLOpenCL;
  else if (dev == "vulkan" || dev == "vulcan")  // "vulcan" kept for existing callers
    dltensor.ctx.device_type = kDLVulkan;
  else if (dev == "metal")
    dltensor.ctx.device_type = kDLMetal;
  else if (dev == "vpi")
    dltensor.ctx.device_type = kDLVPI;
  else if (dev == "rocm")
    dltensor.ctx.device_type = kDLROCM;
  else
    dltensor.ctx.device_type = kDLExtDev;

  switch (dtype) {
    case kFloat32:
      dltensor.dtype.code = kDLFloat;
      dltensor.dtype.bits = 32;
      break;
    case kFloat64:
      dltensor.dtype.code = kDLFloat;
      dltensor.dtype.bits = 64;
      break;
    case kFloat16:
      dltensor.dtype.code = kDLFloat;
      dltensor.dtype.bits = 16;
      break;
    case kUint8:
      dltensor.dtype.code = kDLUInt;
      dltensor.dtype.bits = 8;
      break;
    case kInt32:
      dltensor.dtype.code = kDLInt;
      dltensor.dtype.bits = 32;
      break;
    case kInt8:
      dltensor.dtype.code = kDLInt;
      dltensor.dtype.bits = 8;
      break;
    case kInt64:
      dltensor.dtype.code = kDLInt;
      dltensor.dtype.bits = 64;
      break;
    default:
      dltensor.dtype.code = 0;
      dltensor.dtype.bits = 0;
      throw std::runtime_error(
          "Error! Invalid dtype flag: " + std::to_string(static_cast<int>(dtype)) +
          " when constructing MXTensor");
  }
}

/*! \brief Number of elements: the product of all dimensions (1 for an empty shape). */
int64_t mxnet::ext::MXTensor::size() const {
  int64_t total = 1;
  for (size_t d = 0; d < shape.size(); ++d)
    total *= shape[d];
  return total;
}

/*!
 * \brief True if `oth` describes the same buffer, version, metadata and device.
 * Compares pointers and metadata only; the data contents are not compared.
 */
bool mxnet::ext::MXTensor::isSame(const MXTensor& oth) const {
  const bool same_buffer = data_ptr == oth.data_ptr && verID == oth.verID;
  const bool same_meta   = dtype == oth.dtype && stype == oth.stype && shape == oth.shape;
  const bool same_device = ctx.dev_type == oth.ctx.dev_type && ctx.dev_id == oth.ctx.dev_id;
  return same_buffer && same_meta && same_device;
}

/*!
 * \brief Bundle the output maps and host allocator used by graph passes.
 * \param new_args  map that receives newly allocated argument tensors
 * \param new_aux   map that receives newly allocated auxiliary tensors
 * \param nd_malloc host allocation callback
 * \param nd_alloc  opaque allocator handle passed back to `nd_malloc`
 */
mxnet::ext::PassResource::PassResource(std::unordered_map<std::string, MXTensor>* new_args,
                                       std::unordered_map<std::string, MXTensor>* new_aux,
                                       nd_malloc_t nd_malloc,
                                       const void* nd_alloc)
    : new_args_(new_args),
      new_aux_(new_aux),
      nd_malloc_(nd_malloc),
      nd_alloc_(nd_alloc) {}

/*!
 * \brief Allocate a new graph-argument tensor through the host callback and record it.
 * \param name   key under which the tensor is stored in the new-args map
 * \param shapes tensor shape
 * \param ctx    device context for the allocation
 * \param dtype  element type
 * \return pointer to the tensor stored in the new-args map
 */
mxnet::ext::MXTensor* mxnet::ext::PassResource::alloc_arg(const std::string& name,
                                                          const std::vector<int64_t>& shapes,
                                                          const mxnet::ext::MXContext& ctx,
                                                          mxnet::ext::MXDType dtype) const {
  void* data = nullptr;
  nd_malloc_(nd_alloc_,
             shapes.data(),
             shapes.size(),
             ctx.dev_type.c_str(),
             ctx.dev_id,
             dtype,
             name.c_str(),
             1,  // 1 = argument (alloc_aux passes 0)
             &data);
  // Fill the map entry in place. Copy-assigning a local MXTensor would copy the
  // raw `dltensor.shape` pointer verbatim (assuming no user-defined operator=,
  // which is not visible here), leaving it aimed at the destroyed local's
  // buffer. setTensor() rebuilds the view against the stored shape.
  MXTensor& tensor = (*new_args_)[name];
  tensor.setTensor(
      data, dtype, shapes.data(), static_cast<int>(shapes.size()), 0, ctx, kDefaultStorage);
  return &tensor;
}

/*!
 * \brief Allocate a new auxiliary tensor through the host callback and record it.
 * \param name   key under which the tensor is stored in the new-aux map
 * \param shapes tensor shape
 * \param ctx    device context for the allocation
 * \param dtype  element type
 * \return pointer to the tensor stored in the new-aux map
 */
mxnet::ext::MXTensor* mxnet::ext::PassResource::alloc_aux(const std::string& name,
                                                          const std::vector<int64_t>& shapes,
                                                          const mxnet::ext::MXContext& ctx,
                                                          mxnet::ext::MXDType dtype) const {
  void* data = nullptr;
  nd_malloc_(nd_alloc_,
             shapes.data(),
             shapes.size(),
             ctx.dev_type.c_str(),
             ctx.dev_id,
             dtype,
             name.c_str(),
             0,  // 0 = auxiliary (alloc_arg passes 1)
             &data);
  // Fill the map entry in place so the stored tensor's DLPack view points at
  // its own shape buffer, not at a destroyed local's buffer (see alloc_arg).
  MXTensor& tensor = (*new_aux_)[name];
  tensor.setTensor(
      data, dtype, shapes.data(), static_cast<int>(shapes.size()), 0, ctx, kDefaultStorage);
  return &tensor;
}

/*!
 * \brief Bundle the host-provided allocators, CUDA stream and RNG states for an op call.
 * Every pointer is stored as-is; nothing is owned or copied.
 */
mxnet::ext::OpResource::OpResource(xpu_malloc_t cpu_malloc_fp,
                                   void* cpu_alloc_fp,
                                   xpu_malloc_t gpu_malloc_fp,
                                   void* gpu_alloc_fp,
                                   void* stream,
                                   sparse_malloc_t sparse_malloc_fp,
                                   void* sparse_alloc_fp,
                                   void* rng_cpu_states,
                                   void* rng_gpu_states)
    : cpu_malloc(cpu_malloc_fp),
      gpu_malloc(gpu_malloc_fp),
      cpu_alloc(cpu_alloc_fp),
      gpu_alloc(gpu_alloc_fp),
      cuda_stream(stream),
      sparse_malloc(sparse_malloc_fp),
      sparse_alloc(sparse_alloc_fp),
      rand_cpu_states(rng_cpu_states),
      rand_gpu_states(rng_gpu_states) {}

/*! \brief Request `size` bytes of CPU memory from the host allocator. */
void* mxnet::ext::OpResource::alloc_cpu(int size) const {
  void* buffer = cpu_malloc(cpu_alloc, size);
  return buffer;
}

/*! \brief Request `size` bytes of GPU memory from the host allocator. */
void* mxnet::ext::OpResource::alloc_gpu(int size) const {
  void* buffer = gpu_malloc(gpu_alloc, size);
  return buffer;
}

/*!
 * \brief Have the host allocate the buffers of a sparse output.
 * The host callback writes the new data/indices/indptr pointers into `*sparse`.
 */
void mxnet::ext::OpResource::alloc_sparse(mxnet::ext::MXSparse* sparse,
                                          int index,
                                          int indices_len,
                                          int indptr_len) const {
  MXSparse& target = *sparse;
  sparse_malloc(sparse_alloc,
                index,
                indices_len,
                indptr_len,
                &target.data,
                &target.indices,
                &target.indptr);
}

/*! \brief CPU random-generator states supplied by the host, cast to their typed form. */
mxnet::ext::mx_cpu_rand_t* mxnet::ext::OpResource::get_cpu_rand_states() const {
  auto* states = static_cast<mx_cpu_rand_t*>(rand_cpu_states);
  return states;
}

/*!
 * \brief Extract the `index`-th shape from a list-of-shapes string like "[[1,2],[3,4]]".
 * \param shape string holding a bracketed list of bracketed shapes
 * \param index zero-based position of the wanted shape
 * \return the bracketed shape (e.g. "[3,4]"), or "" if there are fewer than
 *         index+1 shapes. If the closing ']' is missing, returns the rest of the string.
 */
std::string mxnet::ext::getShapeAt(const std::string& shape, unsigned index) {
  // start at 1 to skip the outer opening bracket '['
  std::string::size_type start = 1;
  // find the beginning of the requested shape; keep positions as size_type so
  // that npos is detected instead of wrapping to -1 and restarting the search
  for (unsigned x = 0; x < index; x++) {
    start = shape.find('[', start + 1);
    if (start == std::string::npos)
      return "";
  }
  if (start >= shape.size())
    return "";
  const std::string::size_type stop = shape.find(']', start);
  if (stop == std::string::npos)
    return shape.substr(start);
  return shape.substr(start, stop - start + 1);
}

/*!
 * \brief Extract the `index`-th entry from a comma-separated list string like "[0,1,2]".
 * \param dtype string holding a bracketed, comma-separated list
 * \param index zero-based position of the wanted entry
 * \return the entry text (e.g. "1"), or "" if there are fewer than index+1 entries
 */
std::string mxnet::ext::getDtypeAt(const std::string& dtype, unsigned index) {
  // position of the delimiter preceding the wanted entry: the leading '[' at 0,
  // then successive ','. Kept as size_type so npos is detected, not wrapped.
  std::string::size_type delim = 0;
  for (unsigned x = 0; x < index; x++) {
    delim = dtype.find(',', delim + 1);
    if (delim == std::string::npos)
      return "";
  }
  const std::string::size_type begin = delim + 1;
  if (begin > dtype.size())
    return "";
  // the entry ends at the next ',' or, for the last entry, at the closing ']'
  std::string::size_type stop = dtype.find(',', begin);
  if (stop == std::string::npos)
    stop = dtype.find(']', begin);
  if (stop == std::string::npos)
    return dtype.substr(begin);
  return dtype.substr(begin, stop - begin);
}

// Every convenience constructor delegates to the (type, num, str) constructor.
// num defaults to -1 and str to "" unless the value supplies them.
mxnet::ext::JsonVal::JsonVal() : JsonVal(ERR, -1, "") {}
mxnet::ext::JsonVal::JsonVal(mxnet::ext::JsonType t) : JsonVal(t, -1, "") {}
mxnet::ext::JsonVal::JsonVal(std::string s) : JsonVal(STR, -1, std::move(s)) {}
// numbers keep their decimal text in `str`, which dump() emits verbatim
mxnet::ext::JsonVal::JsonVal(int n) : JsonVal(NUM, n, std::to_string(n)) {}
mxnet::ext::JsonVal::JsonVal(JsonType t, int n, std::string s)
    : type(t), num(n), str(std::move(s)) {}

/*!
 * \brief Strict weak ordering over JSON values, suitable for ordered-container keys
 *        such as JsonVal::map.
 *
 * Values of different types are ordered by JsonType. Strings compare
 * lexicographically and numbers numerically. Lists compare lexicographically
 * element by element; maps compare lexicographically entry by entry (key, then
 * value). All ERR values are equivalent. Equal values never compare less than
 * each other, which ordered associative containers require.
 */
bool mxnet::ext::JsonVal::operator<(const mxnet::ext::JsonVal& o) const {
  if (type != o.type)
    return type < o.type;
  switch (type) {
    case STR:
      return str < o.str;
    case NUM:
      return num < o.num;
    case LIST:
      for (size_t i = 0; i < list.size() && i < o.list.size(); i++) {
        if (list[i] < o.list[i])
          return true;
        if (o.list[i] < list[i])
          return false;
      }
      // common prefix is equal: the shorter list orders first
      return list.size() < o.list.size();
    case MAP: {
      auto a = map.begin();
      auto b = o.map.begin();
      for (; a != map.end() && b != o.map.end(); ++a, ++b) {
        if (a->first < b->first)
          return true;
        if (b->first < a->first)
          return false;
        if (a->second < b->second)
          return true;
        if (b->second < a->second)
          return false;
      }
      // common prefix is equal: the map with fewer entries orders first
      return a == map.end() && b != o.map.end();
    }
    default:
      return false;
  }
}

/*!
 * \brief Serialize this value to JSON text.
 * Strings are wrapped in quotes without re-escaping (the parser keeps escape
 * sequences raw in `str`). Map entries are written as `key : value`. ERR
 * values are written as "json(Error)".
 */
std::string mxnet::ext::JsonVal::dump() const {
  switch (type) {
    case ERR:
      return "json(Error)";
    case STR:
      return "\"" + str + "\"";
    case NUM:
      return str;
    case LIST: {
      std::string out = "[";
      for (size_t i = 0; i < list.size(); ++i) {
        if (i > 0)
          out += ",";
        out += list[i].dump();
      }
      out += "]";
      return out;
    }
    case MAP: {
      std::string out = "{";
      bool first = true;
      for (const auto& entry : map) {
        if (!first)
          out += ",";
        first = false;
        out += entry.first.dump() + " : " + entry.second.dump();
      }
      out += "}";
      return out;
    }
  }
  return "";
}

/*! \brief Parse the first JSON value found in `json`, scanning from its beginning. */
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json) {
  unsigned int cursor = 0;
  return parse(json, &cursor);
}

/*!
 * \brief Parse a string body; `*idx` is just past the opening quote.
 *
 * Escape sequences are kept verbatim (backslashes included) in the result's
 * `str`. On success `*idx` ends just past the closing quote.
 * \return a STR value, or ERR (with an error message recorded) if the input
 *         ends before an unescaped closing quote
 */
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_string(const std::string& json, unsigned int* idx) {
  JsonVal ret(STR);
  while (*idx < json.size()) {
    const char c = json[*idx];
    ++(*idx);
    if (c == '"') {
      // A quote is escaped only if it follows an odd number of backslashes.
      // In "a\\" the backslash is itself escaped, so the quote ends the string.
      size_t backslashes = 0;
      for (auto it = ret.str.rbegin(); it != ret.str.rend() && *it == '\\'; ++it)
        ++backslashes;
      if (backslashes % 2 == 0)
        return ret;
    }
    ret.str += c;
  }
  MX_ERROR_MSG << "Error! Unable to parse string: '" << json.substr(*idx) << "'" << std::endl;
  return JsonVal();
}

/*!
 * \brief Parse a run of decimal digits starting at `*idx`.
 * Leaves `*idx` on the first non-digit. Only digits are consumed, so no sign
 * is handled here. std::stoi may throw if the digit run is empty or overflows int.
 */
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_num(const std::string& json, unsigned int* idx) {
  JsonVal ret(NUM);
  for (; *idx < json.size(); ++(*idx)) {
    const char digit = json[*idx];
    if (digit < '0' || digit > '9')
      break;
    ret.str += digit;
  }
  ret.num = std::stoi(ret.str);
  return ret;
}

/*!
 * \brief Parse list elements; `*idx` is just past the opening '['.
 *
 * Elements that fail to parse are dropped. On success `*idx` ends just past
 * the closing ']'.
 * \return a LIST value, or ERR (with an error message recorded) if the input
 *         ends first or a stray '}' appears where an element is expected
 */
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_list(const std::string& json, unsigned int* idx) {
  JsonVal ret(LIST);
  while (*idx < json.size()) {
    if (json[*idx] == ']') {
      ++(*idx);
      return ret;
    }
    const unsigned int start = *idx;
    JsonVal item = JsonVal::parse(json, idx);
    if (item.type != ERR) {
      ret.list.push_back(std::move(item));
    } else if (*idx == start) {
      // parse() stopped on a '}' without consuming it; retrying would loop forever
      break;
    }
  }
  MX_ERROR_MSG << "Error! Unable to parse list: '" << json.substr(*idx) << "'" << std::endl;
  return JsonVal();
}

/*!
 * \brief Parse map entries; `*idx` is just past the opening '{'.
 *
 * Parsed values alternate as key, value, key, value, ... (separators are
 * skipped by parse()). On success `*idx` ends just past the closing '}'.
 * \return a MAP value, or ERR (with an error message recorded) if the input
 *         ends first or a stray ']' appears where a key or value is expected
 */
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_map(const std::string& json, unsigned int* idx) {
  JsonVal ret(MAP), key;
  while (*idx < json.size()) {
    if (json[*idx] == '}') {
      ++(*idx);
      return ret;
    }
    const unsigned int start = *idx;
    JsonVal item = JsonVal::parse(json, idx);
    if (item.type == ERR && *idx == start) {
      // parse() stopped on a ']' without consuming it; retrying would loop forever
      break;
    }
    if (key.type == ERR) {
      key = std::move(item);
    } else {
      ret.map[key] = std::move(item);
      key.type     = ERR;  // next parsed value is a key again
    }
  }
  MX_ERROR_MSG << "Error! Unable to parse map: '" << json.substr(*idx) << "'" << std::endl;
  return mxnet::ext::JsonVal();
}

/*!
 * \brief Parse one JSON value starting at `*idx`, advancing `*idx` past it.
 *
 * Characters that do not begin a value (whitespace, ',', ':', ...) are skipped
 * one at a time. A sub-parse that returns ERR is also skipped past by one
 * character, and scanning continues. Recognizes strings ('"'), digit runs,
 * lists ('[') and maps ('{').
 * NOTE(review): a leading '-' is skipped like any other character, so "-5" parses as 5.
 * \return the first successfully parsed value; or ERR, without consuming,
 *         when a ']' or '}' is reached (the enclosing list/map parser handles
 *         it); or ERR at end of input
 */
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json, unsigned int* idx) {
  JsonVal ret;
  while (*idx < json.size()) {
    if (json[*idx] == '"') {
      ++(*idx);  // step past the opening quote
      ret = JsonVal::parse_string(json, idx);
    } else if (json[*idx] >= '0' && json[*idx] <= '9') {
      ret = JsonVal::parse_num(json, idx);
    } else if (json[*idx] == '[') {
      ++(*idx);  // step past '['
      ret = JsonVal::parse_list(json, idx);
    } else if (json[*idx] == '{') {
      ++(*idx);  // step past '{'
      ret = JsonVal::parse_map(json, idx);
    } else if (json[*idx] == ']' || json[*idx] == '}') {
      // closing bracket of the enclosing container: leave it for the caller
      return ret;
    }
    if (ret.type != ERR)
      return ret;
    ++(*idx);  // nothing parsed here: skip one character and keep scanning
  }
  return ret;
}

std::string mxnet::ext::JsonVal::toString() const {
  std::string ret;
  switch (type) {
    case ERR:
      ret = "json(Error)";
      break;
    case STR:
      ret = "json(STR:" + str + ")";
      break;
    case NUM:
      ret = "json(INT:" + str + ")";
      break;
    case LIST:
      ret = "json(LIST:[";
      for (auto& item : list)
        ret += item.toString() + ",";
      ret += "])";
      break;
    case MAP:
      ret = "json(MAP:{";
      for (auto& item : map)
        ret += item.first.toString() + " : " + item.second.toString() + ",";
      ret += "})";
      break;
  }
  return ret;
}

/*! \brief Create a detached node with no tensor and no pass resource. */
mxnet::ext::Node::Node() {
  tensor = nullptr;
  // alloc_arg()/alloc_aux() check `!res` to reject use outside a graph pass,
  // so res must start out null rather than indeterminate.
  res = nullptr;
}

/*! \brief Attach the pass resource used by alloc_arg()/alloc_aux(). */
void mxnet::ext::Node::_setPassResource(mxnet::ext::PassResource* res_) {
  res = res_;
}

/*!
 * \brief Allocate an argument tensor named after this node and bind it to `tensor`.
 * \throws std::runtime_error if no pass resource has been attached
 */
void mxnet::ext::Node::alloc_arg(const std::vector<int64_t>& shapes,
                                 const mxnet::ext::MXContext& ctx,
                                 mxnet::ext::MXDType dtype) {
  if (res == nullptr) {
    throw std::runtime_error("Node not initialized. Cannot use alloc_arg outside of graph passes.");
  }
  MXTensor* allocated = res->alloc_arg(name, shapes, ctx, dtype);
  tensor              = allocated;
}

/*!
 * \brief Allocate an auxiliary tensor named after this node and bind it to `tensor`.
 * \throws std::runtime_error if no pass resource has been attached
 */
void mxnet::ext::Node::alloc_aux(const std::vector<int64_t>& shapes,
                                 const mxnet::ext::MXContext& ctx,
                                 mxnet::ext::MXDType dtype) {
  if (res == nullptr) {
    throw std::runtime_error("Node not initialized. Cannot use alloc_aux outside of graph passes.");
  }
  MXTensor* allocated = res->alloc_aux(name, shapes, ctx, dtype);
  tensor              = allocated;
}

/*! \brief Create an empty graph with no pass resource attached. */
mxnet::ext::Graph::Graph() : res(nullptr) {}

/*!
 * \brief Delete every node owned by this graph.
 * NOTE(review): Graph objects referenced from Node::subgraphs (e.g. created by
 * fromJson) are not deleted here; confirm who is meant to own them.
 */
mxnet::ext::Graph::~Graph() {
  for (Node* owned : nodes)
    delete owned;
}

/*! \brief Parse `json` and build a new graph from it; the caller owns the result. */
mxnet::ext::Graph* mxnet::ext::Graph::fromString(const std::string& json) {
  return fromJson(JsonVal::parse(json));
}

/*!
 * \brief Build a Graph from a parsed JSON graph description.
 *
 * Reads the "nodes" list and the "heads" list. Each node has "op", "name",
 * "attrs", "inputs" and an optional "subgraphs" list, which is parsed
 * recursively. Nodes whose op is "null" become graph inputs. Every other
 * top-level key except "node_row_ptr" and "arg_nodes" is stored in Graph::attrs.
 *
 * \param val parsed JSON. It is taken by value, so lookups that insert missing
 *            keys do not affect the caller's object.
 * \return a newly allocated graph owned by the caller
 * \throws std::runtime_error if a node input refers to a node not defined
 *         earlier in "nodes", or a head refers to a node index out of range
 */
mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) {
  Graph* g = new Graph();
  try {
    // Resolve a [node_index, output_index, ...] entry to a node already in g.
    // Only the first `limit` nodes may be referenced.
    auto resolve = [g](const JsonVal& ref, size_t limit) -> Node* {
      if (ref.list.size() < 2)
        throw std::runtime_error("Error! Malformed node entry in graph JSON: " + ref.dump());
      const int id = ref.list[0].num;
      if (id < 0 || static_cast<size_t>(id) >= limit)
        throw std::runtime_error("Error! Graph JSON references invalid node index " +
                                 std::to_string(id));
      return g->nodes[id];
    };

    // bind by reference: `val` is already a private copy, so avoid deep-copying nodes
    JsonVal& nodes = val.map[JsonVal("nodes")];
    for (size_t i = 0; i < nodes.list.size(); i++) {
      Node* n = new Node();
      g->nodes.push_back(n);
      JsonVal& node = nodes.list[i];

      // set the op info
      n->op   = node.map[JsonVal("op")].str;
      n->name = node.map[JsonVal("name")].str;

      // if op is null it is an input to the graph
      if (n->op == "null")
        g->inputs.push_back(n);

      // set attrs
      for (auto& kv : node.map[JsonVal("attrs")].map)
        n->attrs[kv.first.str] = kv.second.str;

      // set subgraphs, parsing each into a graph
      auto sg = node.map.find(JsonVal("subgraphs"));
      if (sg != node.map.end()) {
        for (auto& subgraph : sg->second.list)
          n->subgraphs.push_back(fromJson(subgraph));
      }

      // set node inputs; they may only refer to nodes defined before this one
      JsonVal& node_inputs = node.map[JsonVal("inputs")];
      n->inputs.resize(node_inputs.list.size());
      for (size_t j = 0; j < node_inputs.list.size(); j++) {
        const JsonVal& input = node_inputs.list[j];
        NodeEntry& entry     = n->inputs[j];
        entry.node           = resolve(input, i);
        entry.entry          = input.list[1].num;
        // record on the producer that its output feeds input j of this node
        entry.node->outputs.push_back({n, static_cast<int>(j)});
      }
    }

    // set graph level outputs
    JsonVal& heads = val.map[JsonVal("heads")];
    g->outputs.resize(heads.list.size());
    for (size_t i = 0; i < heads.list.size(); i++) {
      const JsonVal& head = heads.list[i];
      g->outputs[i].node  = resolve(head, g->nodes.size());
      g->outputs[i].entry = head.list[1].num;
    }

    // add all non-structural attributes to the graph
    for (auto& kv : val.map) {
      const std::string& key = kv.first.str;
      if (key != "nodes" && key != "heads" && key != "node_row_ptr" && key != "arg_nodes")
        g->attrs[key] = kv.second;
    }
  } catch (...) {
    delete g;  // don't leak the partially built graph
    throw;
  }
  return g;
}

/*!
 * \brief Convert this graph back into a JSON object.
 *
 * Produces a MAP holding the graph attributes plus the structural keys
 * "node_row_ptr", "arg_nodes", "heads" and "nodes". Nodes are emitted in
 * topological order (producers first) and are referenced everywhere by their
 * position in that emitted list. Node subgraphs are converted recursively.
 */
mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const {
  // top level object is a map
  JsonVal val(MAP);

  // add attributes; the structural keys written below overwrite any
  // attribute that happens to share their name
  for (auto& kv : attrs) {
    val.map[JsonVal(kv.first)] = kv.second;
  }

  // sort graph nodes in topological order, create mapping of node to index
  std::map<Node*, int> nodeMap;
  std::vector<Node*> sorted = topological_sort();
  // nodes are in reverse topological order in the vector (back is first)
  // so loop from end to front over the vector 'sorted'
  for (int i = sorted.size() - 1; i >= 0; i--) {
    nodeMap[sorted[i]] = sorted.size() - 1 - i;
  }

  // create node_row_ptr entry
  // NOTE(review): this writes one entry per node (0..N-1); confirm that this
  // matches what consumers of node_row_ptr expect.
  val.map[JsonVal("node_row_ptr")] = JsonVal(LIST);
  JsonVal& node_row_ptr            = val.map[JsonVal("node_row_ptr")];
  for (int i = 0; i < nodes.size(); i++)
    node_row_ptr.list.emplace_back(i);

  // add all input nodes (by emitted index)
  val.map[JsonVal("arg_nodes")] = JsonVal(LIST);
  JsonVal& arg_nodes            = val.map[JsonVal("arg_nodes")];
  for (auto& input : inputs)
    arg_nodes.list.emplace_back(nodeMap[input]);

  // add all output nodes, each as [node index, output index, 0]
  val.map[JsonVal("heads")] = JsonVal(LIST);
  JsonVal& heads            = val.map[JsonVal("heads")];
  for (int i = 0; i < outputs.size(); i++) {
    heads.list.emplace_back(LIST);
    JsonVal& out = heads.list[i];
    out.list.emplace_back(nodeMap[outputs[i].node]);
    out.list.emplace_back(outputs[i].entry);
    out.list.emplace_back(0);  // third field always 0 (presumably a version slot — confirm)
  }

  // add all graph nodes, producers first
  val.map[JsonVal("nodes")] = JsonVal(LIST);
  JsonVal& nodes_           = val.map[JsonVal("nodes")];
  for (int i = sorted.size() - 1; i >= 0; i--) {
    // each node is a map
    nodes_.list.emplace_back(MAP);
    Node* n     = sorted[i];
    // re-fetch the element just appended (earlier references may be invalidated by growth)
    JsonVal& n_ = nodes_.list[nodes_.list.size() - 1];

    n_.map[JsonVal("op")]     = JsonVal(n->op);
    n_.map[JsonVal("name")]   = JsonVal(n->name);
    n_.map[JsonVal("inputs")] = JsonVal(LIST);

    // add inputs for this node, each as [node index, output index, 0]
    JsonVal& inputs_ = n_.map[JsonVal("inputs")];
    for (int j = 0; j < n->inputs.size(); j++) {
      inputs_.list.emplace_back(LIST);
      NodeEntry& entry = n->inputs[j];
      JsonVal& in      = inputs_.list[j];
      in.list.emplace_back(nodeMap[entry.node]);
      in.list.emplace_back(entry.entry);
      in.list.emplace_back(0);
    }

    // add subgraphs for this node, convert each back to JSON
    if (n->subgraphs.size() > 0) {
      n_.map[JsonVal("subgraphs")] = JsonVal(LIST);
      JsonVal& subgraphs_          = n_.map[JsonVal("subgraphs")];
      for (Graph* subgraph : n->subgraphs) {
        subgraphs_.list.push_back(subgraph->toJson());
      }
    }

    // add attributes for this node (values stored as JSON strings)
    n_.map[JsonVal("attrs")] = JsonVal(MAP);
    JsonVal& attrs_          = n_.map[JsonVal("attrs")];
    for (auto& kv : n->attrs) {
      attrs_.map[JsonVal(kv.first)] = JsonVal(kv.second);
    }
  }
  return val;
}

/* \brief convert graph object to JSON string */
std::string mxnet::ext::Graph::toString() const {
  return toJson().dump();
}

/*!
 * \brief Recursive helper for DFS(): visit `n` and every unvisited node reachable
 *        from it along output edges.
 *
 * Removes `n` from `to_visit`, recurses into each consumer in n->outputs that
 * is still pending, and then calls `handler(n)`. This is a post-order visit, so
 * every consumer is handled before the node that feeds it.
 * NOTE(review): recursion depth grows with the longest output chain, so very
 * deep graphs could exhaust the stack.
 */
void mxnet::ext::Graph::_dfs_util(Node* n,
                                  std::unordered_set<mxnet::ext::Node*>* to_visit,
                                  std::function<void(mxnet::ext::Node*)> handler) const {
  to_visit->erase(n);  // remove node now that we're visiting it
  for (NodeEntry& e : n->outputs) {
    Node* o = e.node;
    if (to_visit->count(o) != 0) {
      _dfs_util(o, to_visit, handler);  // visit neighbor
    }
  }
  handler(n);  // post-order visit this node
}

/*!
 * \brief Post-order DFS over every node in the graph, following output edges.
 *
 * Graph inputs are used as starting points first. Nodes not reached from them
 * are then visited, starting from whichever element the unordered_set yields
 * first, so that part of the order may vary between runs or implementations.
 * \param handler called exactly once per node, after all of its consumers
 */
void mxnet::ext::Graph::DFS(std::function<void(Node*)> handler) const {
  std::unordered_set<Node*> to_visit;
  // put all nodes in set to visit
  for (auto& n : nodes)
    to_visit.insert(n);
  // visit all inputs first
  for (auto& i : inputs)
    if (to_visit.count(i) != 0)
      _dfs_util(i, &to_visit, handler);
  // visit any nodes left
  while (to_visit.size() > 0)
    _dfs_util(*(to_visit.begin()), &to_visit, handler);
}

/* \brief sort graph nodes in topological order */
std::vector<mxnet::ext::Node*> mxnet::ext::Graph::topological_sort() const {
  std::vector<mxnet::ext::Node*> sorted;
  auto handler = [&](mxnet::ext::Node* n) {
    sorted.push_back(n);  // when visiting each node, add it in order to the vector
  };
  DFS(handler);
  return sorted;
}

/* \brief print out graph details */
void mxnet::ext::Graph::print(int indent) const {
  std::string space = "";
  for (int i = 0; i < indent; i++)
    space += " ";

  std::cout << space << "########### Graph #############" << std::endl;
  std::cout << space << "attributes: " << std::endl;
  for (auto& kv : attrs)
    std::cout << space << "\t" << kv.first << " : " << kv.second.str << std::endl;
  std::cout << space << "inputs: " << inputs.size() << std::endl;
  std::cout << space << "outputs: " << outputs.size() << std::endl;
  std::cout << space << "nodes: " << nodes.size() << std::endl;
  std::vector<mxnet::ext::Node*> sorted = topological_sort();
  // loop over each node and print out its inputs/outputs
  for (int i = sorted.size() - 1; i >= 0; i--) {
    std::cout << space << "Node: " << sorted[i]->name << std::endl;
    for (auto& input : sorted[i]->inputs) {
      std::cout << space << "\tInput: " << input.node->name << " " << input.entry << std::endl;
    }
    for (auto& output : sorted[i]->outputs) {
      std::cout << space << "\tOutput: " << output.node->name << " " << output.entry << std::endl;
    }
    if (sorted[i]->subgraphs.size() > 0) {
      for (auto& subgraph : sorted[i]->subgraphs) {
        std::cout << space << "\tSubgraph:" << std::endl;
        subgraph->print(indent + 2);
      }
    }
  }
  std::cout << space << "###############################" << std::endl;
}

/*!
 * \brief Create a node owned by this graph and append it to the node list.
 *
 * If a pass resource is attached to the graph, it is also attached to the new
 * node. The node is not added to `inputs`, whatever its op is.
 * \return the new node (owned by the graph)
 */
mxnet::ext::Node* mxnet::ext::Graph::addNode(const std::string& name, const std::string& op) {
  Node* created = new Node();
  nodes.push_back(created);  // graph takes ownership immediately
  created->name = name;
  created->op   = op;
  if (res != nullptr)
    created->_setPassResource(res);
  return created;
}

/*!
 * \brief Get the node at position `idx` in this graph's node list.
 * Bounds-checked with at(), the same as the const overload, so an invalid
 * index throws std::out_of_range instead of reading out of bounds.
 */
mxnet::ext::Node* mxnet::ext::Graph::getNode(size_t idx) {
  return nodes.at(idx);
}

/*! \brief Const access to the node at position `idx` (bounds-checked with at()). */
const mxnet::ext::Node* mxnet::ext::Graph::getNode(size_t idx) const {
  return nodes.at(idx);
}

/*! \brief Get graph attribute `key`; at() throws if it is absent and never inserts. */
const mxnet::ext::JsonVal& mxnet::ext::Graph::getAttr(const std::string& key) const {
  return attrs.at(key);
}

/*! \brief Number of nodes in the graph. */
size_t mxnet::ext::Graph::size() const {
  return nodes.size();
}

// internally set passResource to enable tensor allocation for graph passes;
// the resource is recorded on the graph and propagated to every current node
void mxnet::ext::Graph::_setPassResource(PassResource* res_) {
  res = res_;
  for (size_t i = 0; i < nodes.size(); ++i)
    nodes[i]->_setPassResource(res_);
}

// internally bind arg/aux tensors to the graph's input nodes when available.
// A subgraph input marked isArg="True" is looked up by its "argName"
// attribute (the original node name); other inputs use their own name.
// Args take precedence over aux; unmatched inputs keep their current tensor.
void mxnet::ext::Graph::_setParams(std::unordered_map<std::string, mxnet::ext::MXTensor>* args,
                                   std::unordered_map<std::string, mxnet::ext::MXTensor>* aux) {
  for (Node* node : inputs) {
    std::string lookup = node->name;
    auto flag          = node->attrs.find("isArg");
    if (flag != node->attrs.end() && flag->second == "True")
      lookup = node->attrs["argName"];

    auto arg_it = args->find(lookup);
    if (arg_it != args->end()) {
      node->tensor = &arg_it->second;
      continue;
    }
    auto aux_it = aux->find(lookup);
    if (aux_it != aux->end())
      node->tensor = &aux_it->second;
  }
}

namespace {
// Compare context names by content. Overloaded so the lookup works whether the
// per-context maps are keyed by `const char*` (as the *_ctx_cstr vectors filled
// in mapToVector suggest — TODO confirm against lib_api.h) or by std::string.
inline bool sameCtxName(const char* a, const char* b) {
  if (a == b)
    return true;
  return a != nullptr && b != nullptr && std::string(a) == b;
}
inline bool sameCtxName(const std::string& a, const char* b) {
  return b != nullptr && a == b;
}

// True if `ctx_map` already has an entry for context `ctx`. A plain
// `count(ctx)` on a `const char*`-keyed map compares only pointer identity,
// and equal context strings from different call sites need not share an address.
template <typename CtxMap>
bool hasRegisteredCtx(const CtxMap& ctx_map, const char* ctx) {
  for (const auto& kv : ctx_map)
    if (sameCtxName(kv.first, ctx))
      return true;
  return false;
}
}  // namespace

/*! \brief Create an operator named `op_name` with no callbacks registered. */
mxnet::ext::CustomOp::CustomOp(const char* op_name)
    : name(op_name),
      parse_attrs(nullptr),
      infer_type(nullptr),
      infer_storage_type(nullptr),
      infer_shape(nullptr),
      mutate_inputs(nullptr),
      isSGop(false) {}

/*!
 * \brief Register the forward function for context `ctx`.
 * \throws std::runtime_error if a forward function is already registered for `ctx`
 */
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setForward(mxnet::ext::fcomp_t fcomp, const char* ctx) {
  if (hasRegisteredCtx(forward_ctx_map, ctx))
    raiseDuplicateContextError();
  forward_ctx_map[ctx] = fcomp;
  return *this;
}

/*!
 * \brief Register the backward (gradient) function for context `ctx`.
 * \throws std::runtime_error if a backward function is already registered for `ctx`
 */
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setBackward(mxnet::ext::fcomp_t fgrad,
                                                        const char* ctx) {
  if (hasRegisteredCtx(backward_ctx_map, ctx))
    raiseDuplicateContextError();
  backward_ctx_map[ctx] = fgrad;
  return *this;
}

/*! \brief Set the attribute-parsing callback. \return *this for chaining */
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setParseAttrs(mxnet::ext::parseAttrs_t func) {
  parse_attrs = func;
  return *this;
}

/*! \brief Set the type-inference callback. \return *this for chaining */
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setInferType(mxnet::ext::inferType_t func) {
  infer_type = func;
  return *this;
}

/*! \brief Set the storage-type-inference callback. \return *this for chaining */
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setInferSType(mxnet::ext::inferSType_t func) {
  infer_storage_type = func;
  return *this;
}

/*! \brief Set the shape-inference callback. \return *this for chaining */
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setInferShape(mxnet::ext::inferShape_t func) {
  infer_shape = func;
  return *this;
}

/*! \brief Set the mutate-inputs callback. \return *this for chaining */
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setMutateInputs(mxnet::ext::mutateInputs_t func) {
  mutate_inputs = func;
  return *this;
}

/*!
 * \brief Register the stateful-op factory for context `ctx`.
 * \throws std::runtime_error if a factory is already registered for `ctx`
 */
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setCreateOpState(mxnet::ext::createOpState_t func,
                                                             const char* ctx) {
  if (hasRegisteredCtx(create_op_ctx_map, ctx))
    raiseDuplicateContextError();
  create_op_ctx_map[ctx] = func;
  return *this;
}

/*! \brief Mark this operator as a subgraph operator. \return *this for chaining */
mxnet::ext::CustomOp& mxnet::ext::CustomOp::setIsSubgraphOp() {
  isSGop = true;
  return *this;
}

/*!
 * \brief Flatten each per-context map into parallel vectors (context name and
 *        function pointer at the same index) for consumption through the C API.
 * Entries are appended to the existing vectors, following each map's iteration order.
 */
void mxnet::ext::CustomOp::mapToVector() {
  for (const auto& entry : forward_ctx_map) {
    forward_ctx_cstr.push_back(entry.first);
    forward_fp.push_back(entry.second);
  }
  for (const auto& entry : backward_ctx_map) {
    backward_ctx_cstr.push_back(entry.first);
    backward_fp.push_back(entry.second);
  }
  for (const auto& entry : create_op_ctx_map) {
    create_op_ctx_cstr.push_back(entry.first);
    create_op_fp.push_back(entry.second);
  }
}

/*!
 * \brief throws std::runtime_error reporting that this operator already has a
 *        function registered for the requested context.
 * Called by the setForward/setBackward/setCreateOpState registration helpers.
 */
void mxnet::ext::CustomOp::raiseDuplicateContextError() {
  std::string op_name_str(name);
  // the message previously began with a duplicated "Error! Error!" prefix
  throw std::runtime_error(
      "Error! Cannot register multiple functions under same context for operator '" +
      op_name_str + "'");
}

mxnet::ext::CustomStatefulOp::CustomStatefulOp() : ignore_warn(false), created(false) {}
mxnet::ext::CustomStatefulOp::~CustomStatefulOp() = default;

/*! \brief releases the wrapped instance by passing it to the stored destroy_ callback */
mxnet::ext::CustomStatefulOpWrapper::~CustomStatefulOpWrapper() {
  this->destroy_(this->instance);
}

mxnet::ext::CustomPass::CustomPass() : name("ERROR") {}
mxnet::ext::CustomPass::CustomPass(const char* pass_name) : name(pass_name) {}
mxnet::ext::CustomPass& mxnet::ext::CustomPass::setBody(graphPass_t fn) {
  pass = fn;
  return *this;
}

mxnet::ext::CustomPartitioner::CustomPartitioner() : name("ERROR") {}
mxnet::ext::CustomPartitioner::CustomPartitioner(const char* backend_name) : name(backend_name) {}

/*!
 * \brief appends a partitioning strategy named \p prop_name whose subgraphs use
 *        operator \p sg_name; the strategy's index is its position in this list
 * \return *this, so calls can be chained
 */
mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::addStrategy(const char* prop_name,
                                                                          const char* sg_name) {
  this->strategies.emplace_back(prop_name);
  this->op_names.emplace_back(sg_name);
  return *this;
}

/*! \brief binds (or rebinds) the supportedOps callback for strategy \p prop_name */
mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setSupportedOps(
    const char* prop_name,
    mxnet::ext::supportedOps_t fn) {
  const std::string strategy_key(prop_name);
  supported_map[strategy_key] = fn;
  return *this;
}

/*! \brief binds (or rebinds) the selector factory for strategy \p prop_name */
mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setCreateSelector(
    const char* prop_name,
    mxnet::ext::createSelector_t fn) {
  const std::string strategy_key(prop_name);
  selector_map[strategy_key] = fn;
  return *this;
}

/*! \brief binds (or rebinds) the subgraph-review callback for strategy \p prop_name */
mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setReviewSubgraph(
    const char* prop_name,
    mxnet::ext::reviewSubgraph_t fn) {
  const std::string strategy_key(prop_name);
  review_map[strategy_key] = fn;
  return *this;
}

/*!
 * \brief looks up the supportedOps callback registered for strategy \p stg_id
 * \param stg_id index into the strategies added via addStrategy (not bounds-checked)
 * \return the callback, or nullptr when none was registered for that strategy name
 */
mxnet::ext::supportedOps_t mxnet::ext::CustomPartitioner::getSupportedOps(int stg_id) {
  // a single find() replaces count() followed by operator[] (two hash lookups)
  const auto found = supported_map.find(std::string(strategies[stg_id]));
  if (found == supported_map.end())
    return nullptr;
  return found->second;
}

/*!
 * \brief looks up the selector factory registered for strategy \p stg_id
 * \param stg_id index into the strategies added via addStrategy (not bounds-checked)
 * \return the factory, or nullptr when none was registered for that strategy name
 */
mxnet::ext::createSelector_t mxnet::ext::CustomPartitioner::getCreateSelector(int stg_id) {
  // a single find() replaces count() followed by operator[] (two hash lookups)
  const auto found = selector_map.find(std::string(strategies[stg_id]));
  if (found == selector_map.end())
    return nullptr;
  return found->second;
}

/*!
 * \brief looks up the subgraph-review callback registered for strategy \p stg_id
 * \param stg_id index into the strategies added via addStrategy (not bounds-checked)
 * \return the callback, or nullptr when none was registered for that strategy name
 */
mxnet::ext::reviewSubgraph_t mxnet::ext::CustomPartitioner::getReviewSubgraph(int stg_id) {
  // a single find() replaces count() followed by operator[] (two hash lookups)
  const auto found = review_map.find(std::string(strategies[stg_id]));
  if (found == review_map.end())
    return nullptr;
  return found->second;
}

/*! \brief reports the MX_LIBRARY_VERSION value this library was built with */
MX_INT_RET _opVersion() {
  return static_cast<int>(MX_LIBRARY_VERSION);
}

/*! \brief reports how many CustomOp entries this library has registered */
MX_INT_RET _opRegSize() {
  auto* const registry = mxnet::ext::Registry<mxnet::ext::CustomOp>::get();
  return registry->size();
}

/*!
 * \brief exports the registration of the operator at index \p idx across the C ABI
 *
 * The name, subgraph flag and single-function callbacks (parse/infer/mutate) are
 * copied straight from the registered CustomOp. Per-context functions are first
 * flattened with CustomOp::mapToVector(). The returned context/function arrays
 * point into that op's own vectors, which live in the registry entry.
 */
MX_VOID_RET _opRegGet(int idx,
                      const char** name,
                      int* isSGop,
                      const char*** forward_ctx,
                      mxnet::ext::fcomp_t** forward_fp,
                      int* forward_count,
                      const char*** backward_ctx,
                      mxnet::ext::fcomp_t** backward_fp,
                      int* backward_count,
                      const char*** create_op_ctx,
                      mxnet::ext::createOpState_t** create_op_fp,
                      int* create_op_count,
                      mxnet::ext::parseAttrs_t* parse,
                      mxnet::ext::inferType_t* type,
                      mxnet::ext::inferSType_t* stype,
                      mxnet::ext::inferShape_t* shape,
                      mxnet::ext::mutateInputs_t* mutate) {
  auto& registry        = *mxnet::ext::Registry<mxnet::ext::CustomOp>::get();
  mxnet::ext::CustomOp& op = registry.get(idx);

  // identity and single-function callbacks
  *name   = op.name;
  *isSGop = op.isSGop;
  *parse  = op.parse_attrs;
  *type   = op.infer_type;
  *stype  = op.infer_storage_type;
  *shape  = op.infer_shape;
  *mutate = op.mutate_inputs;

  // per-context callbacks: build the flat arrays, then expose pointer + count
  op.mapToVector();

  *forward_ctx   = op.forward_ctx_cstr.data();
  *forward_fp    = op.forward_fp.data();
  *forward_count = op.forward_fp.size();

  *backward_ctx   = op.backward_ctx_cstr.data();
  *backward_fp    = op.backward_fp.data();
  *backward_count = op.backward_fp.size();

  *create_op_ctx   = op.create_op_ctx_cstr.data();
  *create_op_fp    = op.create_op_fp.data();
  *create_op_count = op.create_op_fp.size();
}

/*!
 * \brief releases an array this library allocated with malloc (e.g. the shape,
 *        index and attribute arrays produced by the _opCall* / _partCall* helpers).
 * Freeing happens inside this library so the matching allocator is used.
 */
MX_VOID_RET _opCallFree(void* ptr) {
  ::free(ptr);
}

/*!
 * \brief invokes a library parseAttrs callback
 *
 * Builds a string->string attribute map from the parallel \p keys / \p vals arrays
 * (length \p num; a repeated key keeps its last value). Passes it, together with
 * the \p num_in / \p num_out out-params, to \p parseAttrs.
 * \return the callback's status
 */
MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs,
                             const char* const* keys,
                             const char* const* vals,
                             int num,
                             int* num_in,
                             int* num_out) {
  std::unordered_map<std::string, std::string> attrs;
  const char* const* key = keys;
  const char* const* val = vals;
  for (int remaining = num; remaining > 0; --remaining, ++key, ++val) {
    attrs[*key] = *val;
  }
  return parseAttrs(attrs, num_in, num_out);
}

/*!
 * \brief runs a library inferShape callback and marshals the results back to MXNet
 *
 * \p keys / \p vals (length \p num) become the attribute map. The \p num_in input
 * shapes (\p inshapes, each \p indims[i] long) become vectors. If the callback
 * succeeds, the following arrays are allocated with malloc and returned through
 * the out-params:
 *   - *mod_indims / *mod_inshapes: the (possibly modified) input shapes;
 *   - *outdims / *outshapes: the \p num_out inferred output shapes.
 * On failure nothing is allocated and the callback's status is returned as-is.
 */
MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape,
                             const char* const* keys,
                             const char* const* vals,
                             int num,
                             unsigned int** inshapes,
                             int* indims,
                             int num_in,
                             unsigned int*** mod_inshapes,
                             int** mod_indims,
                             unsigned int*** outshapes,
                             int** outdims,
                             int num_out) {
  std::unordered_map<std::string, std::string> attrs;
  for (int k = 0; k < num; ++k) {
    attrs[keys[k]] = vals[k];
  }

  std::vector<std::vector<unsigned int> > in_shapes(num_in);
  for (int i = 0; i < num_in; ++i) {
    std::vector<unsigned int>& dims = in_shapes[i];
    for (int d = 0; d < indims[i]; ++d) {
      dims.push_back(inshapes[i][d]);
    }
  }
  std::vector<std::vector<unsigned int> > out_shapes(num_out);

  const int status = inferShape(attrs, &in_shapes, &out_shapes);
  if (!status) {
    return status;
  }

  // hand back the (possibly modified) input shapes
  *mod_indims   = static_cast<int*>(malloc(num_in * sizeof(int)));
  *mod_inshapes = static_cast<unsigned**>(malloc(num_in * sizeof(unsigned*)));
  for (int i = 0; i < num_in; ++i) {
    const std::vector<unsigned int>& src = in_shapes[i];
    const int ndim                       = static_cast<int>(src.size());
    unsigned* row = static_cast<unsigned*>(malloc(ndim * sizeof(unsigned)));
    for (int d = 0; d < ndim; ++d) {
      row[d] = src[d];
    }
    (*mod_indims)[i]   = ndim;
    (*mod_inshapes)[i] = row;
  }

  // hand back the inferred output shapes
  *outdims   = static_cast<int*>(malloc(num_out * sizeof(int)));
  *outshapes = static_cast<unsigned**>(malloc(num_out * sizeof(unsigned*)));
  for (int o = 0; o < num_out; ++o) {
    const std::vector<unsigned int>& src = out_shapes[o];
    const int ndim                       = static_cast<int>(src.size());
    unsigned* row = static_cast<unsigned*>(malloc(ndim * sizeof(unsigned)));
    for (int d = 0; d < ndim; ++d) {
      row[d] = src[d];
    }
    (*outdims)[o]   = ndim;
    (*outshapes)[o] = row;
  }
  return status;
}

/*!
 * \brief runs a library inferType callback
 *
 * Input dtypes are read from \p intypes (\p num_in entries); output dtypes start
 * as -1. On success the callback's values are written back in place to
 * \p intypes and \p outtypes. On failure neither array is touched.
 * \return the callback's status
 */
MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType,
                            const char* const* keys,
                            const char* const* vals,
                            int num,
                            int* intypes,
                            int num_in,
                            int* outtypes,
                            int num_out) {
  std::unordered_map<std::string, std::string> attrs;
  for (int k = 0; k < num; ++k) {
    attrs[keys[k]] = vals[k];
  }

  std::vector<int> in_types(intypes, intypes + num_in);
  std::vector<int> out_types(num_out, -1);

  const int status = inferType(attrs, &in_types, &out_types);
  if (!status) {
    return status;
  }

  for (int i = 0; i < num_in; ++i) {
    intypes[i] = in_types[i];
  }
  for (int o = 0; o < num_out; ++o) {
    outtypes[o] = out_types[o];
  }
  return status;
}

/*!
 * \brief runs a library inferSType (storage-type inference) callback
 *
 * Input storage types come from \p instypes (\p num_in entries); output storage
 * types start as -1. If the callback succeeds, both arrays are overwritten with
 * the callback's results. Otherwise they are left unchanged.
 * \return the callback's status
 */
MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType,
                             const char* const* keys,
                             const char* const* vals,
                             int num,
                             int* instypes,
                             int num_in,
                             int* outstypes,
                             int num_out) {
  std::unordered_map<std::string, std::string> attrs;
  for (int k = 0; k < num; ++k) {
    attrs[keys[k]] = vals[k];
  }

  std::vector<int> in_stypes(instypes, instypes + num_in);
  std::vector<int> out_stypes(num_out, -1);

  const int status = inferSType(attrs, &in_stypes, &out_stypes);
  if (!status) {
    return status;
  }

  int idx = 0;
  for (int st : in_stypes) {
    if (idx >= num_in)
      break;
    instypes[idx++] = st;
  }
  for (int o = 0; o < num_out; ++o) {
    outstypes[o] = out_stypes[o];
  }
  return status;
}

/*!
 * \brief invokes a library Forward/Backward compute function for a stateless op
 *
 * Rebuilds the attribute map from \p keys / \p vals and wraps every raw input and
 * output array in an MXTensor. Storage-type code 0 is dense; 1 is row-sparse and
 * any other value is CSR. For sparse codes, the tensor's data pointer refers to an
 * MXSparse descriptor kept in a local vector that outlives the call to \p fcomp.
 * The allocator callbacks, stream and RNG states are bundled into an OpResource.
 * \return the status returned by \p fcomp
 */
MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp,
                           const char* const* keys,
                           const char* const* vals,
                           int num,
                           const int64_t** inshapes,
                           int* indims,
                           void** indata,
                           int* intypes,
                           size_t* inIDs,
                           const char** indev_type,
                           int* indev_id,
                           int num_in,
                           const int64_t** outshapes,
                           int* outdims,
                           void** outdata,
                           int* outtypes,
                           size_t* outIDs,
                           const char** outdev_type,
                           int* outdev_id,
                           int num_out,
                           mxnet::ext::xpu_malloc_t cpu_malloc,
                           void* cpu_alloc,
                           mxnet::ext::xpu_malloc_t gpu_malloc,
                           void* gpu_alloc,
                           void* cuda_stream,
                           mxnet::ext::sparse_malloc_t sparse_malloc,
                           void* sparse_alloc,
                           int* instypes,
                           int* outstypes,
                           void** in_indices,
                           void** out_indices,
                           void** in_indptr,
                           void** out_indptr,
                           int64_t* in_indices_shapes,
                           int64_t* out_indices_shapes,
                           int64_t* in_indptr_shapes,
                           int64_t* out_indptr_shapes,
                           void* rng_cpu_states,
                           void* rng_gpu_states) {
  std::unordered_map<std::string, std::string> attrs;
  for (int k = 0; k < num; ++k) {
    attrs[keys[k]] = vals[k];
  }

  // inputs: sparse descriptors must stay alive (and not move) while fcomp runs,
  // so both vectors are sized up front
  std::vector<mxnet::ext::MXTensor> inputs(num_in);
  std::vector<mxnet::ext::MXSparse> in_sparse(num_in);
  for (int i = 0; i < num_in; ++i) {
    void* payload                       = indata[i];
    mxnet::ext::MXStorageType storage = mxnet::ext::kDefaultStorage;
    if (instypes[i] == 1) {
      storage = mxnet::ext::kRowSparseStorage;
      in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
      payload = &in_sparse[i];
    } else if (instypes[i] != 0) {
      storage = mxnet::ext::kCSRStorage;
      in_sparse[i].set(indata[i],
                       inshapes[i],
                       indims[i],
                       in_indices[i],
                       in_indices_shapes[i],
                       in_indptr[i],
                       in_indptr_shapes[i]);
      payload = &in_sparse[i];
    }
    inputs[i].setTensor(payload,
                        static_cast<mxnet::ext::MXDType>(intypes[i]),
                        inshapes[i],
                        indims[i],
                        inIDs[i],
                        mxnet::ext::MXContext(indev_type[i], indev_id[i]),
                        storage);
  }

  // outputs: same layout rules as the inputs
  std::vector<mxnet::ext::MXTensor> outputs(num_out);
  std::vector<mxnet::ext::MXSparse> out_sparse(num_out);
  for (int o = 0; o < num_out; ++o) {
    void* payload                       = outdata[o];
    mxnet::ext::MXStorageType storage = mxnet::ext::kDefaultStorage;
    if (outstypes[o] == 1) {
      storage = mxnet::ext::kRowSparseStorage;
      out_sparse[o].set(
          outdata[o], outshapes[o], outdims[o], out_indices[o], out_indices_shapes[o]);
      payload = &out_sparse[o];
    } else if (outstypes[o] != 0) {
      storage = mxnet::ext::kCSRStorage;
      out_sparse[o].set(outdata[o],
                        outshapes[o],
                        outdims[o],
                        out_indices[o],
                        out_indices_shapes[o],
                        out_indptr[o],
                        out_indptr_shapes[o]);
      payload = &out_sparse[o];
    }
    outputs[o].setTensor(payload,
                         static_cast<mxnet::ext::MXDType>(outtypes[o]),
                         outshapes[o],
                         outdims[o],
                         outIDs[o],
                         mxnet::ext::MXContext(outdev_type[o], outdev_id[o]),
                         storage);
  }

  mxnet::ext::OpResource res(cpu_malloc,
                             cpu_alloc,
                             gpu_malloc,
                             gpu_alloc,
                             cuda_stream,
                             sparse_malloc,
                             sparse_alloc,
                             rng_cpu_states,
                             rng_gpu_states);
  return fcomp(attrs, &inputs, &outputs, res);
}

/*!
 * \brief asks a library mutateInputs callback which inputs the op writes to
 *
 * On success, the collected indices are copied into a malloc'd array returned via
 * *\p mutate_indices, and their count via *\p indices_size. On failure neither
 * out-param is written.
 * \return the callback's status
 */
MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate,
                               const char* const* keys,
                               const char* const* vals,
                               int num,
                               int** mutate_indices,
                               int* indices_size) {
  std::unordered_map<std::string, std::string> attrs;
  for (int k = 0; k < num; ++k) {
    attrs[keys[k]] = vals[k];
  }

  std::vector<int> mut_ind;
  const int status = mutate(attrs, &mut_ind);
  if (!status) {
    return status;
  }

  const int count = static_cast<int>(mut_ind.size());
  int* indices    = static_cast<int*>(malloc(count * sizeof(int)));
  for (int k = 0; k < count; ++k) {
    indices[k] = mut_ind[k];
  }
  *indices_size   = count;
  *mutate_indices = indices;
  return status;
}

/*!
 * \brief asks a library createOpState callback to build a stateful-op instance
 *
 * Packs the attributes, the device context (\p dev_type, \p dev_id), the input
 * shapes and the input dtypes into standard containers. The callback stores the
 * created instance through \p state_op, which is viewed here as CustomStatefulOp**.
 * \return the callback's status
 */
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op,
                                const char* const* keys,
                                const char* const* vals,
                                int num,
                                const char* dev_type,
                                int dev_id,
                                unsigned int** inshapes,
                                int* indims,
                                int num_in,
                                const int* intypes,
                                void** state_op) {
  std::unordered_map<std::string, std::string> attrs;
  for (int k = 0; k < num; ++k) {
    attrs[keys[k]] = vals[k];
  }

  mxnet::ext::MXContext ctx(dev_type, dev_id);

  std::vector<std::vector<unsigned int> > in_shapes(num_in);
  for (int i = 0; i < num_in; ++i) {
    std::vector<unsigned int>& dims = in_shapes[i];
    for (int d = 0; d < indims[i]; ++d) {
      dims.push_back(inshapes[i][d]);
    }
  }

  std::vector<int> in_types(intypes, intypes + num_in);

  auto** const instance_slot = reinterpret_cast<mxnet::ext::CustomStatefulOp**>(state_op);
  return create_op(attrs, ctx, in_shapes, in_types, instance_slot);
}

/*! \brief deletes a stateful-op instance (passed as void*) created by this library */
MX_VOID_RET _opCallDestroyOpState(void* state_op) {
  delete static_cast<mxnet::ext::CustomStatefulOp*>(state_op);
}

/*!
 * \brief runs Forward (\p is_forward != 0) or Backward on a stateful-op instance
 *
 * Raw input/output arrays are wrapped in MXTensors using the same layout rules as
 * _opCallFCompute: storage code 0 is dense, 1 is row-sparse, anything else is CSR.
 * Sparse tensors point at MXSparse descriptors held in local vectors.
 * \p state_op is the instance previously created by _opCallCreateOpState.
 * \return the status returned by the selected Forward/Backward call
 */
MX_INT_RET _opCallFStatefulCompute(int is_forward,
                                   void* state_op,
                                   const int64_t** inshapes,
                                   int* indims,
                                   void** indata,
                                   int* intypes,
                                   size_t* inIDs,
                                   const char** indev_type,
                                   int* indev_id,
                                   int num_in,
                                   const int64_t** outshapes,
                                   int* outdims,
                                   void** outdata,
                                   int* outtypes,
                                   size_t* outIDs,
                                   const char** outdev_type,
                                   int* outdev_id,
                                   int num_out,
                                   mxnet::ext::xpu_malloc_t cpu_malloc,
                                   void* cpu_alloc,
                                   mxnet::ext::xpu_malloc_t gpu_malloc,
                                   void* gpu_alloc,
                                   void* stream,
                                   mxnet::ext::sparse_malloc_t sparse_malloc,
                                   void* sparse_alloc,
                                   int* instypes,
                                   int* outstypes,
                                   void** in_indices,
                                   void** out_indices,
                                   void** in_indptr,
                                   void** out_indptr,
                                   int64_t* in_indices_shapes,
                                   int64_t* out_indices_shapes,
                                   int64_t* in_indptr_shapes,
                                   int64_t* out_indptr_shapes,
                                   void* rng_cpu_states,
                                   void* rng_gpu_states) {
  // inputs: both vectors are sized up front so sparse descriptors never move
  std::vector<mxnet::ext::MXTensor> inputs(num_in);
  std::vector<mxnet::ext::MXSparse> in_sparse(num_in);
  for (int i = 0; i < num_in; ++i) {
    void* payload                       = indata[i];
    mxnet::ext::MXStorageType storage = mxnet::ext::kDefaultStorage;
    if (instypes[i] == 1) {
      storage = mxnet::ext::kRowSparseStorage;
      in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
      payload = &in_sparse[i];
    } else if (instypes[i] != 0) {
      storage = mxnet::ext::kCSRStorage;
      in_sparse[i].set(indata[i],
                       inshapes[i],
                       indims[i],
                       in_indices[i],
                       in_indices_shapes[i],
                       in_indptr[i],
                       in_indptr_shapes[i]);
      payload = &in_sparse[i];
    }
    inputs[i].setTensor(payload,
                        static_cast<mxnet::ext::MXDType>(intypes[i]),
                        inshapes[i],
                        indims[i],
                        inIDs[i],
                        mxnet::ext::MXContext(indev_type[i], indev_id[i]),
                        storage);
  }

  // outputs: same layout rules as the inputs
  std::vector<mxnet::ext::MXTensor> outputs(num_out);
  std::vector<mxnet::ext::MXSparse> out_sparse(num_out);
  for (int o = 0; o < num_out; ++o) {
    void* payload                       = outdata[o];
    mxnet::ext::MXStorageType storage = mxnet::ext::kDefaultStorage;
    if (outstypes[o] == 1) {
      storage = mxnet::ext::kRowSparseStorage;
      out_sparse[o].set(
          outdata[o], outshapes[o], outdims[o], out_indices[o], out_indices_shapes[o]);
      payload = &out_sparse[o];
    } else if (outstypes[o] != 0) {
      storage = mxnet::ext::kCSRStorage;
      out_sparse[o].set(outdata[o],
                        outshapes[o],
                        outdims[o],
                        out_indices[o],
                        out_indices_shapes[o],
                        out_indptr[o],
                        out_indptr_shapes[o]);
      payload = &out_sparse[o];
    }
    outputs[o].setTensor(payload,
                         static_cast<mxnet::ext::MXDType>(outtypes[o]),
                         outshapes[o],
                         outdims[o],
                         outIDs[o],
                         mxnet::ext::MXContext(outdev_type[o], outdev_id[o]),
                         storage);
  }

  mxnet::ext::OpResource res(cpu_malloc,
                             cpu_alloc,
                             gpu_malloc,
                             gpu_alloc,
                             stream,
                             sparse_malloc,
                             sparse_alloc,
                             rng_cpu_states,
                             rng_gpu_states);

  auto* const op = static_cast<mxnet::ext::CustomStatefulOp*>(state_op);
  if (!is_forward) {
    return op->Backward(&inputs, &outputs, res);
  }
  return op->Forward(&inputs, &outputs, res);
}

/*! \brief reports how many CustomPartitioner entries this library has registered */
MX_INT_RET _partRegSize() {
  auto* const registry = mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get();
  return registry->size();
}

/*!
 * \brief reports the partitioner at index \p idx: writes its name to *\p name and
 *        returns how many strategies were added to it via addStrategy.
 */
MX_INT_RET _partRegGetCount(int idx, const char** name) {
  // bind by reference: the previous by-value copy duplicated the partitioner's
  // strategy vectors and callback maps on every call for no benefit
  const mxnet::ext::CustomPartitioner& part =
      mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->get(idx);
  *name = part.name;
  return static_cast<int>(part.strategies.size());
}

/*!
 * \brief exports strategy \p stg_idx of the partitioner at index \p part_idx
 *
 * Writes the strategy name, its subgraph operator name, and its three callbacks.
 * A callback is nullptr when none was registered for that strategy.
 * \p stg_idx is not bounds-checked.
 */
MX_VOID_RET _partRegGet(int part_idx,
                        int stg_idx,
                        const char** strategy,
                        mxnet::ext::supportedOps_t* supportedOps,
                        mxnet::ext::createSelector_t* createSelector,
                        mxnet::ext::reviewSubgraph_t* reviewSubgraph,
                        const char** op_name) {
  // bind by reference instead of copying the whole partitioner (vectors + maps);
  // non-const because the get* accessors are non-const members
  mxnet::ext::CustomPartitioner& part =
      mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->get(part_idx);
  *strategy       = part.strategies[stg_idx];
  *op_name        = part.op_names[stg_idx];
  *supportedOps   = part.getSupportedOps(stg_idx);
  *createSelector = part.getCreateSelector(stg_idx);
  *reviewSubgraph = part.getReviewSubgraph(stg_idx);
}

/*!
 * \brief runs the library's supportedOps callback over the graph in \p json
 *
 * Options come from the parallel \p opt_keys / \p opt_vals arrays (\p num_opts
 * long). The callback fills one int per node into a vector of \p num_ids entries
 * pre-set to -2. On success that vector is copied into \p ids; on failure \p ids
 * is left untouched.
 * \return the callback's status
 */
MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps,
                                 const char* json,
                                 int num_ids,
                                 int* ids,
                                 const char* const* opt_keys,
                                 const char* const* opt_vals,
                                 int num_opts) {
  mxnet::ext::Graph* graph = mxnet::ext::Graph::fromString(json);
  // create map of options from list
  std::unordered_map<std::string, std::string> opts;
  for (int i = 0; i < num_opts; i++)
    opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);

  // create array of subgraph IDs for operator support
  std::vector<int> _ids(num_ids, -2);
  // call user's supportedOps function
  mxnet::ext::MXReturnValue retval = supportedOps(graph, &_ids, opts);

  // The parsed graph is only used for the duration of the callback, so free it
  // on both the success and failure paths (it was previously leaked every call).
  // NOTE(review): assumes Graph::fromString returns a new-allocated Graph and that
  // supportedOps implementations do not retain the pointer -- confirm.
  delete graph;

  if (!retval)
    return retval;

  // copy the per-node subgraph IDs (ints) back to the caller
  for (int i = 0; i < num_ids; i++)
    ids[i] = _ids[i];

  return retval;
}

/*!
 * \brief asks the library to build an op selector for the graph in \p json
 *
 * Options arrive as parallel \p opt_keys / \p opt_vals arrays of length
 * \p num_opts. The callback writes the new selector through \p selector, which is
 * viewed here as CustomOpSelector**. Its status is returned unchanged.
 * NOTE(review): the parsed Graph is not freed here; presumably the created
 * selector keeps using it. Confirm ownership before adding cleanup.
 */
MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector,
                                   const char* json,
                                   void** selector,
                                   const char* const* opt_keys,
                                   const char* const* opt_vals,
                                   int num_opts) {
  mxnet::ext::Graph* graph = mxnet::ext::Graph::fromString(json);

  std::unordered_map<std::string, std::string> opts;
  for (int k = 0; k < num_opts; ++k) {
    opts[opt_keys[k]] = opt_vals[k];
  }

  auto** const selector_slot = reinterpret_cast<mxnet::ext::CustomOpSelector**>(selector);
  return createSelector(graph, selector_slot, opts);
}

/*! \brief writes to *\p selected whether the selector picks node \p nodeID as a subgraph seed */
MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected) {
  auto* const selector = static_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
  *selected            = selector->Select(nodeID);
}

/*!
 * \brief writes to *\p selected the selector's SelectInput(\p nodeID, \p input_nodeID)
 *        verdict on including the given input node
 */
MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID, int input_nodeID, int* selected) {
  auto* const selector = static_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
  *selected            = selector->SelectInput(nodeID, input_nodeID);
}

/*!
 * \brief writes to *\p selected the selector's SelectOutput(\p nodeID, \p output_nodeID)
 *        verdict on including the given output node
 */
MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID, int output_nodeID, int* selected) {
  auto* const selector = static_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
  *selected            = selector->SelectOutput(nodeID, output_nodeID);
}

/*!
 * \brief runs the selector's Filter over a list of candidate node IDs
 *
 * Copies the \p num_candidates IDs into a vector and lets the selector fill the
 * list of nodes to keep. That list is returned via *\p keep (allocated with malloc)
 * and its length via *\p num_keep.
 */
MX_VOID_RET _partCallFilter(void* sel_inst,
                            int* candidates,
                            int num_candidates,
                            int** keep,
                            int* num_keep) {
  auto* const selector = static_cast<mxnet::ext::CustomOpSelector*>(sel_inst);
  std::vector<int> candidate_list(candidates, candidates + num_candidates);
  std::vector<int> kept;

  selector->Filter(candidate_list, &kept);

  const size_t count = kept.size();
  int* out           = static_cast<int*>(malloc(count * sizeof(int)));
  for (size_t k = 0; k < count; ++k)
    out[k] = kept[k];
  *num_keep = static_cast<int>(count);
  *keep     = out;
}

/*! \brief calls Reset() on the selector instance passed as \p sel_inst */
MX_VOID_RET _partCallReset(void* sel_inst) {
  static_cast<mxnet::ext::CustomOpSelector*>(sel_inst)->Reset();
}

/*!
 * \brief returns status of calling review subgraph function from library
 *
 * Deserializes the subgraph JSON, wraps the caller-provided arg/aux arrays in
 * MXTensor objects keyed by name, attaches them to the graph, and invokes the
 * library's reviewSubgraph callback.
 *
 * \param reviewSubgraph library callback to invoke
 * \param json serialized subgraph
 * \param subgraph_id id of the subgraph being reviewed
 * \param accept [out] set from the callback's accept flag (only on success)
 * \param opt_keys,opt_vals,num_opts partitioning options as parallel arrays
 * \param attr_keys,attr_vals,num_attrs [out] attributes produced by the callback.
 *        Only written when the callback succeeds AND returns at least one
 *        attribute; the arrays and every string are malloc'd and owned by the caller.
 * \param arg_*,aux_* parallel arrays describing the named arg/aux tensors
 * \return the callback's MXReturnValue (falsy means failure)
 */
MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph,
                                   const char* json,
                                   int subgraph_id,
                                   int* accept,
                                   const char* const* opt_keys,
                                   const char* const* opt_vals,
                                   int num_opts,
                                   char*** attr_keys,
                                   char*** attr_vals,
                                   int* num_attrs,
                                   const char* const* arg_names,
                                   int num_args,
                                   void* const* arg_data,
                                   const int64_t* const* arg_shapes,
                                   const int* arg_dims,
                                   const int* arg_types,
                                   const size_t* arg_IDs,
                                   const char* const* arg_dev_type,
                                   const int* arg_dev_id,
                                   const char* const* aux_names,
                                   int num_aux,
                                   void* const* aux_data,
                                   const int64_t* const* aux_shapes,
                                   const int* aux_dims,
                                   const int* aux_types,
                                   const size_t* aux_IDs,
                                   const char* const* aux_dev_type,
                                   const int* aux_dev_id) {
  mxnet::ext::Graph* subgraph = mxnet::ext::Graph::fromString(json);
  bool accept_bool            = false;
  // create map of attributes from list
  std::unordered_map<std::string, std::string> opts;
  for (int i = 0; i < num_opts; i++)
    opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);

  // create a map of named tensors for args
  std::unordered_map<std::string, mxnet::ext::MXTensor> args;
  for (int i = 0; i < num_args; i++) {
    std::vector<int64_t> shapes;
    shapes.reserve(arg_dims[i]);
    for (int j = 0; j < arg_dims[i]; j++)
      shapes.push_back(arg_shapes[i][j]);

    mxnet::ext::MXTensor tensor(arg_data[i],
                                shapes,
                                (mxnet::ext::MXDType)arg_types[i],
                                arg_IDs[i],
                                mxnet::ext::MXContext(arg_dev_type[i], arg_dev_id[i]));
    args[arg_names[i]] = tensor;
  }
  // create a map of named tensors for aux
  std::unordered_map<std::string, mxnet::ext::MXTensor> aux;
  for (int i = 0; i < num_aux; i++) {
    std::vector<int64_t> shapes;
    shapes.reserve(aux_dims[i]);
    for (int j = 0; j < aux_dims[i]; j++)
      shapes.push_back(aux_shapes[i][j]);

    mxnet::ext::MXTensor tensor(aux_data[i],
                                shapes,
                                (mxnet::ext::MXDType)aux_types[i],
                                aux_IDs[i],
                                mxnet::ext::MXContext(aux_dev_type[i], aux_dev_id[i]));
    aux[aux_names[i]] = tensor;
  }

  subgraph->_setParams(&args, &aux);

  std::unordered_map<std::string, std::string> attrs;
  mxnet::ext::MXReturnValue retval =
      reviewSubgraph(subgraph, subgraph_id, &accept_bool, opts, &attrs);
  // The graph is only needed for the callback. Release it here so neither the
  // failure path nor the success path below leaks it.
  delete subgraph;
  subgraph = nullptr;
  if (!retval)
    return retval;

  *accept = accept_bool;

  if (!attrs.empty()) {
    *num_attrs = static_cast<int>(attrs.size());
    // allocate space for attributes; ownership passes to the caller
    *attr_keys = static_cast<char**>(malloc(*num_attrs * sizeof(char*)));
    *attr_vals = static_cast<char**>(malloc(*num_attrs * sizeof(char*)));

    // copy attributes
    int i = 0;
    for (const auto& kv : attrs) {
      (*attr_keys)[i] = static_cast<char*>(malloc((kv.first.size() + 1) * sizeof(char)));  // NOLINT
      (*attr_vals)[i] =
          static_cast<char*>(malloc((kv.second.size() + 1) * sizeof(char)));  // NOLINT
      snprintf((*attr_keys)[i], kv.first.size() + 1, "%s", kv.first.c_str());
      snprintf((*attr_vals)[i], kv.second.size() + 1, "%s", kv.second.c_str());
      i++;
    }
  }

  return retval;
}

/*!
 * \brief returns number of graph passes registered in this library
 * \return the size of the CustomPass registry
 */
MX_INT_RET _passRegSize() {
  auto pass_registry = mxnet::ext::Registry<mxnet::ext::CustomPass>::get();
  return pass_registry->size();
}

/*!
 * \brief returns pass registration at specified index
 * \param pass_idx index into the CustomPass registry
 * \param graphPass [out] set to the registered pass function
 * \param pass_name [out] set to the registered pass name
 */
MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t* graphPass, const char** pass_name) {
  const auto& entry = mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->get(pass_idx);
  *graphPass = entry.pass;
  *pass_name = entry.name;
}

/*!
 * \brief returns status of calling graph pass function from library
 *
 * Deserializes the graph JSON, builds name->MXTensor maps for the arg/aux
 * arrays, and attaches them to the graph. It also attaches a PassResource
 * built from local new_args/new_aux maps and nd_malloc/nd_alloc. It then runs
 * the library's graph pass and serializes the resulting graph.
 *
 * \param graphPass library pass function to invoke
 * \param json serialized input graph
 * \param out_graph [out] on success, set to a malloc'd NUL-terminated string
 *        holding the serialized graph; the caller owns it. Not written on failure.
 * \param opt_keys,opt_vals,num_opts pass options as parallel arrays
 * \param pass_name name of the pass (not used by this function)
 * \param arg_*,aux_* parallel arrays describing the named arg/aux tensors
 * \param nd_malloc,nd_alloc forwarded to the PassResource
 * \return the pass's MXReturnValue (falsy means failure)
 */
MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass,
                              const char* json,
                              char** out_graph,
                              const char* const* opt_keys,
                              const char* const* opt_vals,
                              int num_opts,
                              const char* pass_name,
                              const char* const* arg_names,
                              int num_args,
                              void* const* arg_data,
                              const int64_t* const* arg_shapes,
                              const int* arg_dims,
                              const int* arg_types,
                              const size_t* arg_IDs,
                              const char* const* arg_dev_type,
                              const int* arg_dev_id,
                              const char* const* aux_names,
                              int num_aux,
                              void* const* aux_data,
                              const int64_t* const* aux_shapes,
                              const int* aux_dims,
                              const int* aux_types,
                              const size_t* aux_IDs,
                              const char* const* aux_dev_type,
                              const int* aux_dev_id,
                              mxnet::ext::nd_malloc_t nd_malloc,
                              const void* nd_alloc) {
  mxnet::ext::Graph* graph = mxnet::ext::Graph::fromString(json);
  // create map of attributes from list
  std::unordered_map<std::string, std::string> opts;
  for (int i = 0; i < num_opts; i++)
    opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);

  // create a map of named tensors for args
  std::unordered_map<std::string, mxnet::ext::MXTensor> args;
  for (int i = 0; i < num_args; i++) {
    std::vector<int64_t> shapes;
    shapes.reserve(arg_dims[i]);
    for (int j = 0; j < arg_dims[i]; j++)
      shapes.push_back(arg_shapes[i][j]);

    mxnet::ext::MXTensor tensor(arg_data[i],
                                shapes,
                                (mxnet::ext::MXDType)arg_types[i],
                                arg_IDs[i],
                                mxnet::ext::MXContext(arg_dev_type[i], arg_dev_id[i]));
    args[arg_names[i]] = tensor;
  }
  // create a map of named tensors for aux
  std::unordered_map<std::string, mxnet::ext::MXTensor> aux;
  for (int i = 0; i < num_aux; i++) {
    std::vector<int64_t> shapes;
    shapes.reserve(aux_dims[i]);
    for (int j = 0; j < aux_dims[i]; j++)
      shapes.push_back(aux_shapes[i][j]);

    mxnet::ext::MXTensor tensor(aux_data[i],
                                shapes,
                                (mxnet::ext::MXDType)aux_types[i],
                                aux_IDs[i],
                                mxnet::ext::MXContext(aux_dev_type[i], aux_dev_id[i]));
    aux[aux_names[i]] = tensor;
  }

  std::unordered_map<std::string, mxnet::ext::MXTensor> new_args, new_aux;
  mxnet::ext::PassResource res(&new_args, &new_aux, nd_malloc, nd_alloc);
  graph->_setParams(&args, &aux);
  graph->_setPassResource(&res);
  mxnet::ext::MXReturnValue retval = graphPass(graph, opts);
  if (!retval) {
    // release the deserialized graph on the failure path too
    delete graph;
    return retval;
  }

  std::string serialized = graph->toString();
  // The graph is no longer needed once serialized. Previously it leaked here.
  delete graph;
  graph = nullptr;

  *out_graph = static_cast<char*>(malloc((serialized.size() + 1) * sizeof(char)));  // NOLINT
  snprintf((*out_graph), serialized.size() + 1, "%s", serialized.c_str());
  return retval;
}

/*!
 * \brief Checks if the MXNet version is supported by the library.
 * If supported, initializes the library.
 * This is only a declaration. The definition is presumably supplied by the
 * library author's own source (TODO confirm against the extension READMEs).
 * On Windows the symbol is exported with __declspec(dllexport) and uses the
 * __cdecl calling convention.
 * \param version MXNet version number passed to library and defined as:
 *                MXNET_VERSION = (MXNET_MAJOR*10000 + MXNET_MINOR*100 + MXNET_PATCH)
 * \return MXReturnValue indicating whether the library is compatible with the
 *         passed MXNet version.
 * NOTE(review): this comment previously read "Non-zero value on error". The
 * other wrappers in this file treat a falsy MXReturnValue as failure
 * (`if (!retval) return retval;`), which implies the opposite convention.
 * Confirm which one initialize follows.
 */
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
__declspec(dllexport) mxnet::ext::MXReturnValue __cdecl
#else
mxnet::ext::MXReturnValue
#endif
    initialize(int version);

/*! \brief returns the number of error messages recorded in the MXerrorMsgs singleton */
MX_INT_RET _msgSize() {
  mxnet::ext::MXerrorMsgs* error_log = mxnet::ext::MXerrorMsgs::get();
  return error_log->size();
}

/*!
 * \brief returns the error message recorded at the specified index
 * \param idx index of the message. The lookup uses messages.at(idx), so an
 *        out-of-range index throws std::out_of_range.
 * \param msg [out] set to the message text, which starts with the
 *        "file[line]: " prefix written by MXerrorMsgs::add
 * NOTE(review): MXerrorMsgs::get(idx) returns a heap-allocated std::string
 * that nothing frees, so every call leaks one string. That leak is also what
 * keeps *msg valid after this function returns. Fixing it requires changing
 * the ownership contract with the caller.
 */
MX_VOID_RET _msgGet(int idx, const char** msg) {
  const std::string* text = mxnet::ext::MXerrorMsgs::get()->get(idx);
  *msg = text->c_str();
}
