ferritin_onnx_models/
utilities.rs

1use anyhow::Result;
2use candle_core::{Device, Tensor};
3use ndarray;
4
5pub fn ndarray_to_tensor_f32(
6    arr: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::IxDyn>,
7) -> Result<Tensor> {
8    let shape: Vec<usize> = arr.shape().to_vec();
9    let raw_data = arr
10        .as_slice()
11        .unwrap()
12        .iter()
13        .flat_map(|&x| x.to_ne_bytes())
14        .collect::<Vec<u8>>();
15
16    Tensor::from_raw_buffer(&raw_data, candle_core::DType::F32, &shape, &Device::Cpu)
17        .map_err(|e| anyhow::anyhow!("Failed to create tensor: {}", e))
18}
19
20pub fn tensor_to_ndarray_f32(
21    tensor: Tensor,
22) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::IxDyn>> {
23    let shape = tensor.dims().to_vec();
24    let flattened = tensor.flatten_all()?;
25    let f32_data = flattened.to_vec1::<f32>()?;
26    ndarray::Array::from_shape_vec(ndarray::IxDyn(&shape), f32_data)
27        .map_err(|e| anyhow::anyhow!("Failed to create ndarray: {}", e))
28}
29
30pub fn tensor_to_ndarray_i64(
31    tensor: Tensor,
32) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<i64>, ndarray::IxDyn>> {
33    let shape = tensor.dims().to_vec();
34    let flattened = tensor.flatten_all()?;
35    let i64_data = flattened.to_vec1::<i64>()?;
36    ndarray::Array::from_shape_vec(ndarray::IxDyn(&shape), i64_data)
37        .map_err(|e| anyhow::anyhow!("Failed to create ndarray: {}", e))
38}