https://github.com/scipy/scipy/issues/18759 https://github.com/scipy/scipy/pull/18763 https://github.com/scipy/scipy/commit/61d892c9faa543ad80bd5e2d0bf69821188487e0 From 61d892c9faa543ad80bd5e2d0bf69821188487e0 Mon Sep 17 00:00:00 2001 From: Ilhan Polat Date: Tue, 27 Jun 2023 12:00:38 +0200 Subject: [PATCH] MAINT:linalg.det:Return scalars for singleton inputs (#18763) --- a/scipy/linalg/_basic.py +++ b/scipy/linalg/_basic.py @@ -1001,7 +1001,8 @@ def det(a, overwrite_a=False, check_finite=True): det : (...) float or complex Determinant of `a`. For stacked arrays, a scalar is returned for each (m, m) slice in the last two dimensions of the input. For example, an - input of shape (p, q, m, m) will produce a result of shape (p, q). + input of shape (p, q, m, m) will produce a result of shape (p, q). If + all dimensions are 1 a scalar is returned regardless of ndim. Notes ----- @@ -1066,11 +1067,17 @@ def det(a, overwrite_a=False, check_finite=True): # Scalar case if a1.shape[-2:] == (1, 1): - if a1.dtype.char in 'dD': - return np.squeeze(a1) + # Either ndarray with spurious singletons or a single element + if max(*a1.shape) > 1: + temp = np.squeeze(a1) + if a1.dtype.char in 'dD': + return temp + else: + return (temp.astype('d') if a1.dtype.char == 'f' else + temp.astype('D')) else: - return (np.squeeze(a1).astype('d') if a1.dtype.char == 'f' else - np.squeeze(a1).astype('D')) + return (np.float64(a1.item()) if a1.dtype.char in 'fd' else + np.complex128(a1.item())) # Then check overwrite permission if not _datacopied(a1, a): # "a" still alive through "a1" --- a/scipy/linalg/tests/test_basic.py +++ b/scipy/linalg/tests/test_basic.py @@ -930,6 +930,23 @@ class TestDet: def setup_method(self): self.rng = np.random.default_rng(1680305949878959) + def test_1x1_all_singleton_dims(self): + a = np.array([[1]]) + deta = det(a) + assert deta.dtype.char == 'd' + assert np.isscalar(deta) + assert deta == 1. + a = np.array([[[[1]]]], dtype='f') + deta = det(a) + assert deta.dtype.char == 'd' + assert np.isscalar(deta) + assert deta == 1. + a = np.array([[[1 + 3.j]]], dtype=np.complex64) + deta = det(a) + assert deta.dtype.char == 'D' + assert np.isscalar(deta) + assert deta == 1.+3.j + def test_1by1_stacked_input_output(self): a = self.rng.random([4, 5, 1, 1], dtype=np.float32) deta = det(a)