ferritin_onnx_models/
utilities.rs1use anyhow::Result;
2use candle_core::{Device, Tensor};
3use ndarray;
4
5pub fn ndarray_to_tensor_f32(
6 arr: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::IxDyn>,
7) -> Result<Tensor> {
8 let shape: Vec<usize> = arr.shape().to_vec();
9 let raw_data = arr
10 .as_slice()
11 .unwrap()
12 .iter()
13 .flat_map(|&x| x.to_ne_bytes())
14 .collect::<Vec<u8>>();
15
16 Tensor::from_raw_buffer(&raw_data, candle_core::DType::F32, &shape, &Device::Cpu)
17 .map_err(|e| anyhow::anyhow!("Failed to create tensor: {}", e))
18}
19
20pub fn tensor_to_ndarray_f32(
21 tensor: Tensor,
22) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::IxDyn>> {
23 let shape = tensor.dims().to_vec();
24 let flattened = tensor.flatten_all()?;
25 let f32_data = flattened.to_vec1::<f32>()?;
26 ndarray::Array::from_shape_vec(ndarray::IxDyn(&shape), f32_data)
27 .map_err(|e| anyhow::anyhow!("Failed to create ndarray: {}", e))
28}
29
30pub fn tensor_to_ndarray_i64(
31 tensor: Tensor,
32) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<i64>, ndarray::IxDyn>> {
33 let shape = tensor.dims().to_vec();
34 let flattened = tensor.flatten_all()?;
35 let i64_data = flattened.to_vec1::<i64>()?;
36 ndarray::Array::from_shape_vec(ndarray::IxDyn(&shape), i64_data)
37 .map_err(|e| anyhow::anyhow!("Failed to create ndarray: {}", e))
38}