use tch::{Kind, Tensor};

/// Round-trips a single `f64` tensor through `Tensor::save` / `Tensor::load`
/// and checks that the element values are unchanged.
#[test]
fn save_and_load() {
    let filename = std::env::temp_dir().join(format!("tch-{}", std::process::id()));
    let vec = vec![3.0, 1.0, 4.0, 1.0, 5.0];
    let t1 = Tensor::of_slice(&vec);
    t1.save(&filename).unwrap();
    let t2 = Tensor::load(&filename).unwrap();
    // Delete the temporary file before asserting so it does not linger when an
    // assertion fails. Best-effort: a cleanup failure must not mask the
    // round-trip result this test is about.
    let _ = std::fs::remove_file(&filename);
    assert_eq!(Vec::<f64>::from(&t2), vec);
}

/// Round-trips two named tensors (a float one and an integer one) through
/// `Tensor::save_multi` / `Tensor::load_multi`. It checks that names come back
/// in insertion order and that both tensors keep their contents.
#[test]
fn save_and_load_multi() {
    let filename = std::env::temp_dir().join(format!("tch2-{}", std::process::id()));
    let pi = Tensor::of_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
    let e = Tensor::of_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]);
    Tensor::save_multi(&[(&"pi", &pi), (&"e", &e)], &filename).unwrap();
    let named_tensors = Tensor::load_multi(&filename).unwrap();
    // Best-effort cleanup before the assertions, so the temp file is removed
    // even if one of them fails.
    let _ = std::fs::remove_file(&filename);
    assert_eq!(named_tensors.len(), 2);
    assert_eq!(named_tensors[0].0, "pi");
    assert_eq!(named_tensors[1].0, "e");
    assert_eq!(
        Vec::<f64>::from(&named_tensors[0].1),
        [3.0, 1.0, 4.0, 1.0, 5.0]
    );
    // 2+7+1+8+2+8+1+8+2+8+4+6 = 57
    assert_eq!(i64::from(&named_tensors[1].1.sum(Kind::Float)), 57);
}

/// Round-trips two named tensors through `Tensor::write_npz` /
/// `Tensor::read_npz`. It checks that names come back in insertion order and
/// that both tensors keep their contents.
#[test]
fn save_and_load_npz() {
    let filename = std::env::temp_dir().join(format!("tch3-{}.npz", std::process::id()));
    let pi = Tensor::of_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
    let e = Tensor::of_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]);
    Tensor::write_npz(&[(&"pi", &pi), (&"e", &e)], &filename).unwrap();
    let named_tensors = Tensor::read_npz(&filename).unwrap();
    // Best-effort cleanup before the assertions, so the temp file is removed
    // even if one of them fails.
    let _ = std::fs::remove_file(&filename);
    assert_eq!(named_tensors.len(), 2);
    assert_eq!(named_tensors[0].0, "pi");
    assert_eq!(named_tensors[1].0, "e");
    assert_eq!(
        Vec::<f64>::from(&named_tensors[0].1),
        [3.0, 1.0, 4.0, 1.0, 5.0]
    );
    // 2+7+1+8+2+8+1+8+2+8+4+6 = 57
    assert_eq!(i64::from(&named_tensors[1].1.sum(Kind::Float)), 57);
}

/// Same as the `.npz` round trip, but with both tensors converted to
/// `Kind::Half` before writing. Contents are compared via their sums, which are
/// small integers and therefore exactly representable in half precision.
#[test]
fn save_and_load_npz_half() {
    let filename = std::env::temp_dir().join(format!("tch4-{}.npz", std::process::id()));
    let pi = Tensor::of_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]).to2(Kind::Half, true, false);
    let e = Tensor::of_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]).to2(Kind::Half, true, false);
    Tensor::write_npz(&[(&"pi", &pi), (&"e", &e)], &filename).unwrap();
    let named_tensors = Tensor::read_npz(&filename).unwrap();
    // Best-effort cleanup before the assertions, so the temp file is removed
    // even if one of them fails.
    let _ = std::fs::remove_file(&filename);
    assert_eq!(named_tensors.len(), 2);
    assert_eq!(named_tensors[0].0, "pi");
    assert_eq!(named_tensors[1].0, "e");
    // 3+1+4+1+5 = 14
    assert_eq!(i64::from(&named_tensors[0].1.sum(Kind::Float)), 14);
    // 2+7+1+8+2+8+1+8+2+8+4+6 = 57
    assert_eq!(i64::from(&named_tensors[1].1.sum(Kind::Float)), 57);
}

/// Same as the `.npz` round trip, but with both tensors converted to
/// `Kind::Int8` before writing. Contents are compared via Int8 sums; both sums
/// (14 and 57) fit in `i8` without overflow.
#[test]
fn save_and_load_npz_byte() {
    let filename = std::env::temp_dir().join(format!("tch5-{}.npz", std::process::id()));
    let pi = Tensor::of_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]).to2(Kind::Int8, true, false);
    let e = Tensor::of_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]).to2(Kind::Int8, true, false);
    Tensor::write_npz(&[(&"pi", &pi), (&"e", &e)], &filename).unwrap();
    let named_tensors = Tensor::read_npz(&filename).unwrap();
    // Best-effort cleanup before the assertions, so the temp file is removed
    // even if one of them fails.
    let _ = std::fs::remove_file(&filename);
    assert_eq!(named_tensors.len(), 2);
    assert_eq!(named_tensors[0].0, "pi");
    assert_eq!(named_tensors[1].0, "e");
    // 3+1+4+1+5 = 14
    assert_eq!(i8::from(&named_tensors[0].1.sum(Kind::Int8)), 14);
    // 2+7+1+8+2+8+1+8+2+8+4+6 = 57
    assert_eq!(i8::from(&named_tensors[1].1.sum(Kind::Int8)), 57);
}

/// Round-trips a tensor through `write_npy` / `read_npy` twice: first as a 1-D
/// tensor, then reshaped to `[3, 1, 2]`. The second pass checks that the shape
/// survives and that the flattened values are unchanged.
#[test]
fn save_and_load_npy() {
    let filename = std::env::temp_dir().join(format!("tch6-{}.npy", std::process::id()));
    let pi = Tensor::of_slice(&[3.0, 1.0, 4.0, 1.0, 5.0, 9.0]);
    pi.write_npy(&filename).unwrap();
    let pi = Tensor::read_npy(&filename).unwrap();
    assert_eq!(Vec::<f64>::from(&pi), [3.0, 1.0, 4.0, 1.0, 5.0, 9.0]);
    // Overwrites the same file with the reshaped tensor.
    let pi = pi.reshape(&[3, 1, 2]);
    pi.write_npy(&filename).unwrap();
    let pi = Tensor::read_npy(&filename).unwrap();
    // Best-effort cleanup before the final assertions, so the temp file is
    // removed even if one of them fails.
    let _ = std::fs::remove_file(&filename);
    assert_eq!(pi.size(), [3, 1, 2]);
    assert_eq!(
        Vec::<f64>::from(&pi.flatten(0, -1)),
        [3.0, 1.0, 4.0, 1.0, 5.0, 9.0]
    );
}
