Here are examples of the Python API numpy.VisibleDeprecationWarning, taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.
60 Examples
3
Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby
with GNU General Public License v3.0
from adityaprakash-bobby
def test_header(self):
    # Column names are read from the first line when ``names=True``.
    stream = TextIO('gender age weight\nM 64.0 75.0\nF 25.0 60.0')
    with warnings.catch_warnings(record=True) as caught:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        parsed = np.ndfromtxt(stream, dtype=None, names=True)
        # The first recorded warning must be the numpy deprecation warning.
        assert_(caught[0].category is np.VisibleDeprecationWarning)
    expected = {
        'gender': np.array([b'M', b'F']),
        'age': np.array([64.0, 25.0]),
        'weight': np.array([75.0, 60.0]),
    }
    for column in ('gender', 'age', 'weight'):
        assert_equal(parsed[column], expected[column])
def test_auto_dtype(self):
3
Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby
with GNU General Public License v3.0
from adityaprakash-bobby
def test_auto_dtype(self):
    # With ``dtype=None`` each column gets its own inferred type and the
    # fields are named f0..f4.
    stream = TextIO('A 64 75.0 3+4j True\nBCD 25 60.0 5+6j False')
    with warnings.catch_warnings(record=True) as caught:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        parsed = np.ndfromtxt(stream, dtype=None)
        assert_(caught[0].category is np.VisibleDeprecationWarning)
    expected_columns = [
        np.array([b'A', b'BCD']),
        np.array([64, 25]),
        np.array([75.0, 60.0]),
        np.array([3 + 4j, 5 + 6j]),
        np.array([True, False]),
    ]
    assert_equal(parsed.dtype.names, ['f0', 'f1', 'f2', 'f3', 'f4'])
    for position, column in enumerate(expected_columns):
        assert_equal(parsed['f%i' % position], column)
def test_auto_dtype_uniform(self):
3
Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby
with GNU General Public License v3.0
from adityaprakash-bobby
def test_autonames_and_usecols(self):
    # Columns are selected by header name through ``usecols``.
    stream = TextIO('A B C D\n aaaa 121 45 9.1')
    wanted = ('A', 'C', 'D')
    with warnings.catch_warnings(record=True) as caught:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        parsed = np.ndfromtxt(stream, usecols=wanted, names=True, dtype=None)
        assert_(caught[0].category is np.VisibleDeprecationWarning)
    expected_dtype = [('A', '|S4'), ('C', int), ('D', float)]
    expected = np.array(('aaaa', 45, 9.1), dtype=expected_dtype)
    assert_equal(parsed, expected)
def test_converters_with_usecols(self):
3
Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby
with GNU General Public License v3.0
from adityaprakash-bobby
def test_converters_with_usecols_and_names(self):
    # A converter keyed by column name is applied to the selected column.
    def twice_as_int(field):
        return 2 * int(field)

    stream = TextIO('A B C D\n aaaa 121 45 9.1')
    with warnings.catch_warnings(record=True) as caught:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        parsed = np.ndfromtxt(
            stream,
            usecols=('A', 'C', 'D'),
            names=True,
            dtype=None,
            converters={'C': twice_as_int},
        )
        assert_(caught[0].category is np.VisibleDeprecationWarning)
    expected_dtype = [('A', '|S4'), ('C', int), ('D', float)]
    expected = np.array(('aaaa', 90, 9.1), dtype=expected_dtype)
    assert_equal(parsed, expected)
def test_converters_cornercases(self):
3
Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby
with GNU General Public License v3.0
from adityaprakash-bobby
def test_comments_is_none(self):
    # Github issue 329 (None was previously being converted to 'None').
    # Each case pairs the input text with the expected second field; the
    # second case keeps the leading blank of that field.
    cases = (
        ("test1,testNonetherestofthedata", b'testNonetherestofthedata'),
        ("test1, testNonetherestofthedata", b' testNonetherestofthedata'),
    )
    for text, second_field in cases:
        with warnings.catch_warnings(record=True) as caught:
            warnings.filterwarnings('always', '',
                                    np.VisibleDeprecationWarning)
            parsed = np.genfromtxt(TextIO(text), dtype=None, comments=None,
                                   delimiter=',')
            assert_(caught[0].category is np.VisibleDeprecationWarning)
        assert_equal(parsed[1], second_field)
def test_latin1(self):
3
Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby
with GNU General Public License v3.0
from adityaprakash-bobby
def test_utf8_byte_encoding(self):
    # A UTF-8 encoded field between two plain rows is returned as raw bytes.
    utf8 = b"\xcf\x96"
    plain_row = b"norm1,norm2,norm3\n"
    encoded_row = b"test1,testNonethe" + utf8 + b",test3\n"
    payload = plain_row + encoded_row + plain_row
    with warnings.catch_warnings(record=True) as caught:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        parsed = np.genfromtxt(TextIO(payload), dtype=None, comments=None,
                               delimiter=',')
        assert_(caught[0].category is np.VisibleDeprecationWarning)
    plain_fields = [b'norm1', b'norm2', b'norm3']
    expected = np.array([
        plain_fields,
        [b'test1', b'testNonethe' + utf8, b'test3'],
        plain_fields,
    ])
    assert_array_equal(parsed, expected)
def test_utf8_file(self):
3
Source : test_regression.py
with Apache License 2.0
from aws-samples
with Apache License 2.0
from aws-samples
def test_typeNA(self):
    # Issue gh-515
    # Look up the 64-bit integer types while silencing the deprecation
    # warning.
    expected_labels = ((np.int64, 'Int64'), (np.uint64, 'UInt64'))
    with suppress_warnings() as sup:
        sup.filter(np.VisibleDeprecationWarning)
        for scalar_type, label in expected_labels:
            assert_equal(np.typeNA[scalar_type], label)
def test_dtype_names(self):
3
Source : test_deprecations.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def test_deprecate_ragged_arrays():
    # 2019-11-29 1.19.0
    #
    # NEP 34 deprecated the automatic object dtype for ragged nested
    # sequences (see also the "ragged" tests in `test_multiarray`), so
    # building an array from one must emit a VisibleDeprecationWarning.
    ragged = [1, [2, 3]]
    with assert_warns(np.VisibleDeprecationWarning):
        np.array(ragged)
class TestToString(_DeprecationTestCase):
3
Source : test_io.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def test_header(self):
    # The header row supplies the field names of the structured result.
    source = TextIO('gender age weight\nM 64.0 75.0\nF 25.0 60.0')
    with warnings.catch_warnings(record=True) as log:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        result = np.genfromtxt(source, dtype=None, names=True)
        first_category = log[0].category
        assert_(first_category is np.VisibleDeprecationWarning)
    gender = np.array([b'M', b'F'])
    age = np.array([64.0, 25.0])
    weight = np.array([75.0, 60.0])
    assert_equal(result['gender'], gender)
    assert_equal(result['age'], age)
    assert_equal(result['weight'], weight)
def test_auto_dtype(self):
3
Source : test_io.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def test_auto_dtype(self):
    # Every column should get its own automatically determined dtype.
    source = TextIO('A 64 75.0 3+4j True\nBCD 25 60.0 5+6j False')
    with warnings.catch_warnings(record=True) as log:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        result = np.genfromtxt(source, dtype=None)
        assert_(log[0].category is np.VisibleDeprecationWarning)
    per_field = [
        np.array([b'A', b'BCD']),
        np.array([64, 25]),
        np.array([75.0, 60.0]),
        np.array([3 + 4j, 5 + 6j]),
        np.array([True, False]),
    ]
    assert_equal(result.dtype.names, ['f0', 'f1', 'f2', 'f3', 'f4'])
    field_no = 0
    for wanted in per_field:
        assert_equal(result['f%i' % field_no], wanted)
        field_no += 1
def test_auto_dtype_uniform(self):
3
Source : test_io.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def test_autonames_and_usecols(self):
    # ``usecols`` may refer to columns by the names found in the header.
    source = TextIO('A B C D\n aaaa 121 45 9.1')
    read_options = {'usecols': ('A', 'C', 'D'), 'names': True, 'dtype': None}
    with warnings.catch_warnings(record=True) as log:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        result = np.genfromtxt(source, **read_options)
        assert_(log[0].category is np.VisibleDeprecationWarning)
    wanted = np.array(('aaaa', 45, 9.1),
                      dtype=[('A', '|S4'), ('C', int), ('D', float)])
    assert_equal(result, wanted)
def test_converters_with_usecols(self):
3
Source : test_io.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def test_converters_with_usecols_and_names(self):
    # Column 'C' is doubled by its converter before the comparison.
    source = TextIO('A B C D\n aaaa 121 45 9.1')
    column_converters = {'C': lambda raw: 2 * int(raw)}
    with warnings.catch_warnings(record=True) as log:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        result = np.genfromtxt(source, usecols=('A', 'C', 'D'), names=True,
                               dtype=None, converters=column_converters)
        assert_(log[0].category is np.VisibleDeprecationWarning)
    wanted = np.array(('aaaa', 90, 9.1),
                      dtype=[('A', '|S4'), ('C', int), ('D', float)])
    assert_equal(result, wanted)
def test_converters_cornercases(self):
3
Source : test_distributions.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def test_rvs_no_size_warning():
    # Sampling from a distribution whose ``_rvs`` takes no size argument
    # must emit a VisibleDeprecationWarning.
    class rvs_no_size_gen(stats.rv_continuous):
        def _rvs(self):
            return 1

    dist = rvs_no_size_gen(name='rvs_no_size')
    with assert_warns(np.VisibleDeprecationWarning):
        dist.rvs()
3
Source : test_voxels.py
with GNU General Public License v3.0
from dbbs-lab
with GNU General Public License v3.0
from dbbs-lab
def test_data_ctor(self):
    # Exercise the VoxelSet constructor with various third-argument inputs.
    # With empty voxel lists, a third argument of [1] or 1 must be rejected.
    with self.assertRaises(ValueError):
        VoxelSet([], [], [1])
    with self.assertRaises(ValueError):
        VoxelSet([], [], 1)
    # These combinations are only checked for not raising.
    VoxelSet([[1, 2, 3]], [[1, 0, 0]], [1])
    VoxelSet([[1, 2, 3]], 1, [1])
    VoxelSet([[1, 2, 3]], 1, [[1, 2]])
    # A ragged third argument is expected to surface numpy's ragged-array
    # deprecation warning.
    with self.assertWarns(np.VisibleDeprecationWarning):
        VoxelSet([[1, 2, 3], [0, 0, 0]], 1, [[1, 2], [1]])
    # NOTE(review): the original indentation was lost; this statement is
    # assumed to sit outside the assertWarns block - confirm against the
    # upstream file.
    het = VoxelSet([[1, 2, 3], [2, 0, 0]], 1, [[1, 2], [1, "a"]])
def test_unequal_len(self):
3
Source : test_voxels.py
with GNU General Public License v3.0
from dbbs-lab
with GNU General Public License v3.0
from dbbs-lab
def test_ragged(self):
    # Ragged coordinates must both warn (numpy) and be rejected (ValueError).
    expect_error = self.assertRaises(ValueError)
    expect_warning = self.assertWarns(np.VisibleDeprecationWarning)
    with expect_error, expect_warning:
        VoxelSet([[1, 0, 1], [1, 1]], [1, 1, 1, 1])
def test_get_size(self):
3
Source : test_win_probability_simulation.py
with MIT License
from djcunningham0
with MIT License
from djcunningham0
def test_invalid_scores_to_result_proportions():
    # inconsistent lengths
    #
    # The "ignore" filter is installed inside ``catch_warnings`` so it is
    # removed again when the test ends; calling ``filterwarnings`` directly
    # would change the process-wide filter list and silently hide
    # VisibleDeprecationWarning in every test that runs afterwards.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
        with pytest.raises(Exception):
            scores = np.array([
                [100, 90, 100],
                [95, 105],
                [80, 120, 80],
            ])
            multielo.MultiElo._convert_scores_to_result_proportions(scores)
3
Source : test_deprecations.py
with GNU General Public License v3.0
from dnn-security
with GNU General Public License v3.0
from dnn-security
def test_deprecate_ragged_arrays():
    # 2019-11-29 1.19.0
    #
    # Per NEP 34, ragged input no longer silently becomes an object array;
    # the "ragged" tests in `test_multiarray` cover this as well.
    #
    # Creating the array must emit a VisibleDeprecationWarning.
    nested_unequal = [1, [2, 3]]
    with assert_warns(np.VisibleDeprecationWarning):
        np.array(nested_unequal)
class TestTooDeepDeprecation(_VisibleDeprecationTestCase):
3
Source : test_distributions.py
with MIT License
from osamhack2021
with MIT License
from osamhack2021
def test_rvs_no_size_warning():
    # ``_rvs`` below takes no size argument; drawing from the distribution
    # is expected to raise a VisibleDeprecationWarning.
    class rvs_no_size_gen(stats.rv_continuous):
        def _rvs(self):
            return 1

    sizeless = rvs_no_size_gen(name='rvs_no_size')
    with assert_warns(np.VisibleDeprecationWarning):
        sizeless.rvs()
@pytest.mark.parametrize('distname, args', invdistdiscrete + invdistcont)
3
Source : test_util.py
with GNU General Public License v3.0
from pyxem
with GNU General Public License v3.0
from pyxem
def test_deprecation_no_old_doc(self):
    # A deprecated function without a docstring warns when called and gets
    # a generated docstring.
    @deprecated(since=0.7, alternative="bar", removal=0.8)
    def foo(n):
        return n + 1

    with pytest.warns(np.VisibleDeprecationWarning) as caught:
        assert foo(4) == 5
    desired_msg = (
        "Function `foo()` is deprecated and will be removed in version 0.8. Use "
        "`bar()` instead."
    )
    emitted = str(caught[0].message)
    assert emitted == desired_msg
    expected_doc = (
        "[*Deprecated*] \n"
        "\nNotes\n-----\n"
        ".. deprecated:: 0.7\n"
        f" {desired_msg}"
    )
    assert foo.__doc__ == expected_doc
3
Source : test_distributions.py
with MIT License
from tpike3
with MIT License
from tpike3
def test_rvs_no_size_warning():
    # Drawing a variate from a generator whose ``_rvs`` has no size
    # parameter must warn with VisibleDeprecationWarning.
    class rvs_no_size_gen(stats.rv_continuous):
        def _rvs(self):
            return 1

    generator = rvs_no_size_gen(name='rvs_no_size')
    with assert_warns(np.VisibleDeprecationWarning):
        generator.rvs()
0
Source : test_indexing.py
with GNU General Public License v3.0
from adityaprakash-bobby
with GNU General Public License v3.0
from adityaprakash-bobby
def test_multidim(self):
    # Automatically test combinations with complex indexes on 2nd (or 1st)
    # spot and the simple ones in one other spot.
    with warnings.catch_warnings():
        # This is so that np.array(True) is not accepted in a full integer
        # index, when running the file separately.
        warnings.filterwarnings('error', '', DeprecationWarning)
        warnings.filterwarnings('error', '', np.VisibleDeprecationWarning)

        def isskip(idx):
            # True only for the literal string "skip".
            return isinstance(idx, str) and idx == "skip"

        for simple_pos in [0, 2, 3]:
            # Position 1 always holds the complex indices; one of the other
            # positions is switched to the simple indices.
            tocheck = [self.fill_indices, self.complex_indices,
                       self.fill_indices, self.fill_indices]
            tocheck[simple_pos] = self.simple_indices
            for index in product(*tocheck):
                # Drop "skip" entries before indexing.
                index = tuple(i for i in index if not isskip(i))
                self._check_multi_index(self.a, index)
                self._check_multi_index(self.b, index)

    # NOTE(review): the original indentation was lost; the checks below are
    # assumed to run outside the catch_warnings block - confirm upstream.
    # Check very simple item getting:
    self._check_multi_index(self.a, (0, 0, 0, 0))
    self._check_multi_index(self.b, (0, 0, 0, 0))
    # Also check (simple cases of) too many indices:
    assert_raises(IndexError, self.a.__getitem__, (0, 0, 0, 0, 0))
    assert_raises(IndexError, self.a.__setitem__, (0, 0, 0, 0, 0), 0)
    assert_raises(IndexError, self.a.__getitem__, (0, 0, [1], 0, 0))
    assert_raises(IndexError, self.a.__setitem__, (0, 0, [1], 0, 0), 0)
def test_1d(self):
0
Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby
with GNU General Public License v3.0
from adityaprakash-bobby
def test_commented_header(self):
    # Names must be recoverable even when the header line is commented out.
    # In the first input the '#' is attached to the first name; in the
    # second it is a separate leading token that has to be dropped.
    sources = (
        TextIO("""
#gender age weight
M 21 72.100000
F 35 58.330000
M 33 21.99
"""),
        TextIO(b"""
# gender age weight
M 21 72.100000
F 35 58.330000
M 33 21.99
"""),
    )
    expected = np.array(
        [('M', 21, 72.1), ('F', 35, 58.33), ('M', 33, 21.99)],
        dtype=[('gender', '|S1'), ('age', int), ('weight', float)])
    for stream in sources:
        with warnings.catch_warnings(record=True) as caught:
            warnings.filterwarnings('always', '',
                                    np.VisibleDeprecationWarning)
            parsed = np.genfromtxt(stream, names=True, dtype=None)
            assert_(caught[0].category is np.VisibleDeprecationWarning)
        assert_equal(parsed, expected)
def test_autonames_and_usecols(self):
0
Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby
with GNU General Public License v3.0
from adityaprakash-bobby
def test_autostrip(self):
    # Parse the same line without and then with ``autostrip``.
    line = "01/01/2003 , 1.3, abcde"
    common = {"delimiter": ",", "dtype": None}

    # Without autostrip the blanks around the fields are kept.
    with warnings.catch_warnings(record=True) as caught:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        unstripped = np.ndfromtxt(TextIO(line), **common)
        assert_(caught[0].category is np.VisibleDeprecationWarning)
    expected = np.array(
        [('01/01/2003 ', 1.3, ' abcde')],
        dtype=[('f0', '|S12'), ('f1', float), ('f2', '|S8')])
    assert_equal(unstripped, expected)

    # With autostrip the blanks are removed.
    with warnings.catch_warnings(record=True) as caught:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        stripped = np.ndfromtxt(TextIO(line), autostrip=True, **common)
        assert_(caught[0].category is np.VisibleDeprecationWarning)
    expected = np.array(
        [('01/01/2003', 1.3, 'abcde')],
        dtype=[('f0', '|S10'), ('f1', float), ('f2', '|S5')])
    assert_equal(stripped, expected)
def test_replace_space(self):
0
Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby
with GNU General Public License v3.0
from adityaprakash-bobby
def test_latin1(self):
    # Latin-1 bytes come back as bytes by default and as text when
    # ``encoding='latin1'`` is given.
    latin1 = b'\xf6\xfc\xf6'
    plain_row = b"norm1,norm2,norm3\n"
    encoded_row = b"test1,testNonethe" + latin1 + b",test3\n"
    payload = plain_row + encoded_row + plain_row

    # No encoding given: a warning is recorded and fields stay bytes.
    with warnings.catch_warnings(record=True) as caught:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        raw = np.genfromtxt(TextIO(payload), dtype=None, comments=None,
                            delimiter=',')
        assert_(caught[0].category is np.VisibleDeprecationWarning)
    raw_row = (b"test1", b"testNonethe" + latin1, b"test3")
    for column, value in enumerate(raw_row):
        assert_equal(raw[1, column], value)

    # Explicit encoding: fields are decoded to text.
    decoded = np.genfromtxt(TextIO(payload), dtype=None, comments=None,
                            delimiter=',', encoding='latin1')
    text_row = (u"test1", u"testNonethe" + latin1.decode('latin1'), u"test3")
    for column, value in enumerate(text_row):
        assert_equal(decoded[1, column], value)

    # Mixed int/bytes line: a structured result with fields f0 and f1.
    with warnings.catch_warnings(record=True) as caught:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        record = np.genfromtxt(TextIO(b"0,testNonethe" + latin1),
                               dtype=None, comments=None, delimiter=',')
        assert_(caught[0].category is np.VisibleDeprecationWarning)
    assert_equal(record['f0'], 0)
    assert_equal(record['f1'], b"testNonethe" + latin1)
def test_binary_decode_autodtype(self):
0
Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby
with GNU General Public License v3.0
from adityaprakash-bobby
def test_utf8_file_nodtype_unicode(self):
    # bytes encoding with non-latin1 -> unicode upcast
    utf8 = u'\u03d6'
    latin1 = u'\xf6\xfc\xf6'

    # skip test if cannot encode utf8 test string with preferred
    # encoding. The preferred encoding is assumed to be the default
    # encoding of io.open. Will need to change this for PyTest, maybe
    # using pytest.mark.xfail(raises=***).
    try:
        import locale
        encoding = locale.getpreferredencoding()
        utf8.encode(encoding)
    except (UnicodeError, ImportError):
        raise SkipTest('Skipping test_utf8_file_nodtype_unicode, '
                       'unable to encode utf8 in preferred encoding')

    with temppath() as path:
        # Write the fixture with the default (preferred) text encoding.
        with io.open(path, "wt") as f:
            f.write(u"norm1,norm2,norm3\n")
            f.write(u"norm1," + latin1 + u",norm3\n")
            f.write(u"test1,testNonethe" + utf8 + u",test3\n")
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '',
                                    np.VisibleDeprecationWarning)
            test = np.genfromtxt(path, dtype=None, comments=None,
                                 delimiter=',')
            # Check for warning when encoding not specified.
            assert_(w[0].category is np.VisibleDeprecationWarning)
        # ``np.unicode`` was only an alias of the builtin text type; it was
        # deprecated in NumPy 1.20 and later removed.  ``np.unicode_`` is
        # the NumPy unicode scalar type and yields the same array dtype.
        ctl = np.array([
            ["norm1", "norm2", "norm3"],
            ["norm1", latin1, "norm3"],
            ["test1", "testNonethe" + utf8, "test3"]],
            dtype=np.unicode_)
        assert_array_equal(test, ctl)
def test_recfromtxt(self):
0
Source : test_histograms.py
with MIT License
from alvarobartt
with MIT License
from alvarobartt
def test_normed(self):
    # ``normed=True`` must still integrate to 1 and must emit exactly one
    # warning mentioning "normed" per call.
    with suppress_warnings() as sup:
        recorded = sup.record(np.VisibleDeprecationWarning, '.*normed.*')
        # Check that the integral of the density equals 1.
        sample_size = 100
        samples = np.random.rand(sample_size)
        density, edges = histogram(samples, normed=True)
        integral = np.sum(density * np.diff(edges))
        assert_almost_equal(integral, 1)
        assert_equal(len(recorded), 1)

    with suppress_warnings() as sup:
        recorded = sup.record(np.VisibleDeprecationWarning, '.*normed.*')
        # Check with non-constant bin widths (buggy but backwards
        # compatible)
        samples = np.arange(10)
        uneven_bins = [0, 1, 5, 9, 10]
        density, edges = histogram(samples, uneven_bins, normed=True)
        integral = np.sum(density * np.diff(edges))
        assert_almost_equal(integral, 1)
        assert_equal(len(recorded), 1)
def test_density(self):
0
Source : test_io.py
with MIT License
from alvarobartt
with MIT License
from alvarobartt
def test_commented_header(self):
    # The header may be commented out and the names are still retrieved.
    wanted = np.array(
        [('M', 21, 72.1), ('F', 35, 58.33), ('M', 33, 21.99)],
        dtype=[('gender', '|S1'), ('age', int), ('weight', float)])

    # Case 1: the '#' is glued to the first name and must be removed.
    text_source = TextIO("""
#gender age weight
M 21 72.100000
F 35 58.330000
M 33 21.99
""")
    with warnings.catch_warnings(record=True) as log:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        result = np.genfromtxt(text_source, names=True, dtype=None)
        assert_(log[0].category is np.VisibleDeprecationWarning)
    assert_equal(result, wanted)

    # Case 2: the '#' is its own first element and must be discarded.
    bytes_source = TextIO(b"""
# gender age weight
M 21 72.100000
F 35 58.330000
M 33 21.99
""")
    with warnings.catch_warnings(record=True) as log:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        result = np.genfromtxt(bytes_source, names=True, dtype=None)
        assert_(log[0].category is np.VisibleDeprecationWarning)
    assert_equal(result, wanted)
def test_names_and_comments_none(self):
0
Source : test_io.py
with MIT License
from alvarobartt
with MIT License
from alvarobartt
def test_utf8_file_nodtype_unicode(self):
    # bytes encoding with non-latin1 -> unicode upcast
    utf8 = u'\u03d6'
    latin1 = u'\xf6\xfc\xf6'

    # skip test if cannot encode utf8 test string with preferred
    # encoding. The preferred encoding is assumed to be the default
    # encoding of io.open. Will need to change this for PyTest, maybe
    # using pytest.mark.xfail(raises=***).
    try:
        encoding = locale.getpreferredencoding()
        utf8.encode(encoding)
    except (UnicodeError, ImportError):
        raise SkipTest('Skipping test_utf8_file_nodtype_unicode, '
                       'unable to encode utf8 in preferred encoding')

    with temppath() as path:
        # Write the fixture with the default (preferred) text encoding.
        with io.open(path, "wt") as f:
            f.write(u"norm1,norm2,norm3\n")
            f.write(u"norm1," + latin1 + u",norm3\n")
            f.write(u"test1,testNonethe" + utf8 + u",test3\n")
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '',
                                    np.VisibleDeprecationWarning)
            test = np.genfromtxt(path, dtype=None, comments=None,
                                 delimiter=',')
            # Check for warning when encoding not specified.
            assert_(w[0].category is np.VisibleDeprecationWarning)
        # ``np.unicode`` was only an alias of the builtin text type; it was
        # deprecated in NumPy 1.20 and later removed.  ``np.unicode_`` is
        # the NumPy unicode scalar type and yields the same array dtype.
        ctl = np.array([
            ["norm1", "norm2", "norm3"],
            ["norm1", latin1, "norm3"],
            ["test1", "testNonethe" + utf8, "test3"]],
            dtype=np.unicode_)
        assert_array_equal(test, ctl)
def test_recfromtxt(self):
0
Source : testing.py
with MIT License
from alvarobartt
with MIT License
from alvarobartt
def assert_warns(warning_class, func, *args, **kw):
    """Call `func` and check that it emits a warning of `warning_class`.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning.  The category must be
        exactly this class; subclasses do not count.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`

    Returns
    -------
    result : the return value of `func`

    Raises
    ------
    AssertionError
        If no warning is recorded, or none has category `warning_class`.
    """
    # very important to avoid uncontrolled state propagation
    clean_warning_registry()
    with warnings.catch_warnings(record=True) as recorded:
        # Record every warning, even ones already shown before.
        warnings.simplefilter("always")
        result = func(*args, **kw)
        if hasattr(np, 'VisibleDeprecationWarning'):
            # Filter out numpy-specific warnings in numpy >= 1.9
            recorded = [
                entry for entry in recorded
                if entry.category is not np.VisibleDeprecationWarning
            ]
        if not recorded:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)
        for entry in recorded:
            if entry.category is warning_class:
                break
        else:
            raise AssertionError("%s did not give warning: %s( is %s)"
                                 % (func.__name__, warning_class, recorded))
    return result
def assert_warns_message(warning_class, message, func, *args, **kw):
0
Source : testing.py
with MIT License
from alvarobartt
with MIT License
from alvarobartt
def assert_warns_message(warning_class, message, func, *args, **kw):
    """Call `func` and check for a warning class with a given message.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning.  Subclasses are accepted.

    message : str | callable
        The entire message or a substring to test for. If callable,
        it takes a string as argument and will trigger an assertion error
        if it returns `False`.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`.

    Returns
    -------
    result : the return value of `func`

    Raises
    ------
    AssertionError
        If no warning is recorded, none is a `warning_class`, or none of
        the matching warnings carries the expected message.
    """
    # very important to avoid uncontrolled state propagation
    clean_warning_registry()
    with warnings.catch_warnings(record=True) as recorded:
        # Record every warning, even ones already shown before.
        warnings.simplefilter("always")
        if hasattr(np, 'VisibleDeprecationWarning'):
            # Let's not catch the numpy internal DeprecationWarnings
            warnings.simplefilter('ignore', np.VisibleDeprecationWarning)
        result = func(*args, **kw)

        if not recorded:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)

        class_matches = [issubclass(entry.category, warning_class)
                         for entry in recorded]
        if True not in class_matches:
            raise AssertionError("No warning raised for %s with class "
                                 "%s"
                                 % (func.__name__, warning_class))

        # A callable `message` is the predicate itself; a string is
        # matched as a substring.
        if callable(message):
            accepts = message
        else:
            def accepts(text):
                return message in text

        message_found = False
        # Only warnings of the requested class are inspected.
        for entry, is_match in zip(recorded, class_matches):
            if not is_match:
                continue
            payload = entry.message  # For Python 3 compatibility
            msg = str(payload.args[0] if hasattr(payload, 'args')
                      else payload)
            if accepts(msg):
                message_found = True
                break

        if not message_found:
            raise AssertionError("Did not receive the message you expected "
                                 "('%s') for < %s>, got: '%s'"
                                 % (message, func.__name__, msg))
    return result
# To remove when we support numpy 1.7
def assert_no_warnings(func, *args, **kw):
0
Source : testing.py
with MIT License
from alvarobartt
with MIT License
from alvarobartt
def assert_no_warnings(func, *args, **kw):
    """Call `func` and fail if it emits any warning.

    numpy's VisibleDeprecationWarning is ignored.  Returns the return
    value of `func`; raises AssertionError listing the warnings otherwise.
    """
    # very important to avoid uncontrolled state propagation
    clean_warning_registry()
    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter('always')
        result = func(*args, **kw)
        if hasattr(np, 'VisibleDeprecationWarning'):
            # Filter out numpy-specific warnings in numpy >= 1.9
            recorded = [
                entry for entry in recorded
                if entry.category is not np.VisibleDeprecationWarning
            ]
        if recorded:
            caller = func.__name__
            listing = ', '.join(str(warning) for warning in recorded)
            raise AssertionError("Got warnings when calling %s: [%s]"
                                 % (caller, listing))
    return result
def ignore_warnings(obj=None, category=Warning):
0
Source : test_io.py
with Apache License 2.0
from aws-samples
with Apache License 2.0
from aws-samples
def test_utf8_file_nodtype_unicode(self):
    # bytes encoding with non-latin1 -> unicode upcast
    utf8 = u'\u03d6'
    latin1 = u'\xf6\xfc\xf6'

    # skip test if cannot encode utf8 test string with preferred
    # encoding. The preferred encoding is assumed to be the default
    # encoding of io.open. Will need to change this for PyTest, maybe
    # using pytest.mark.xfail(raises=***).
    try:
        encoding = locale.getpreferredencoding()
        utf8.encode(encoding)
    except (UnicodeError, ImportError):
        pytest.skip('Skipping test_utf8_file_nodtype_unicode, '
                    'unable to encode utf8 in preferred encoding')

    with temppath() as path:
        # Write the fixture with the default (preferred) text encoding.
        with io.open(path, "wt") as f:
            f.write(u"norm1,norm2,norm3\n")
            f.write(u"norm1," + latin1 + u",norm3\n")
            f.write(u"test1,testNonethe" + utf8 + u",test3\n")
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '',
                                    np.VisibleDeprecationWarning)
            test = np.genfromtxt(path, dtype=None, comments=None,
                                 delimiter=',')
            # Check for warning when encoding not specified.
            assert_(w[0].category is np.VisibleDeprecationWarning)
        # ``np.unicode`` was only an alias of the builtin text type; it was
        # deprecated in NumPy 1.20 and later removed.  ``np.unicode_`` is
        # the NumPy unicode scalar type and yields the same array dtype.
        ctl = np.array([
            ["norm1", "norm2", "norm3"],
            ["norm1", latin1, "norm3"],
            ["test1", "testNonethe" + utf8, "test3"]],
            dtype=np.unicode_)
        assert_array_equal(test, ctl)
def test_recfromtxt(self):
0
Source : __init__.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def index_of(y):
    """
    A helper function to create reasonable x values for the given *y*.

    This is used for plotting (x, y) if x values are not explicitly given.

    First try ``y.index`` (assuming *y* is a `pandas.Series`), if that
    fails, use ``range(len(y))``.

    This will be extended in the future to deal with more types of
    labeled data.

    Parameters
    ----------
    y : float or array-like

    Returns
    -------
    x, y : ndarray
       The x and y values to plot.

    Raises
    ------
    ValueError
        If *y* has no ``index``/``values`` attributes and the 1D conversion
        fails.
    """
    # Labeled data carries its own x values.
    try:
        return y.index.values, y.values
    except AttributeError:
        pass
    # ``VisibleDeprecationWarning`` lives in ``numpy.exceptions`` on newer
    # NumPy and is no longer in the top-level namespace on NumPy 2.0+.
    # Naming ``np.VisibleDeprecationWarning`` directly in the ``except``
    # clause would itself raise AttributeError there, so resolve the class
    # from whichever location exists.
    try:
        from numpy.exceptions import VisibleDeprecationWarning
    except ImportError:  # older NumPy without ``numpy.exceptions``
        from numpy import VisibleDeprecationWarning
    try:
        y = _check_1d(y)
    except (VisibleDeprecationWarning, ValueError):
        # NumPy 1.19 will warn on ragged input, and we can't actually use it.
        pass
    else:
        return np.arange(y.shape[0], dtype=float), y
    raise ValueError('Input could not be cast to an at-least-1D NumPy array')
def safe_first_element(obj):
0
Source : test_io.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def test_autostrip(self):
    # ``autostrip`` decides whether blanks around fields are kept.
    raw_line = "01/01/2003 , 1.3, abcde"
    options = {"delimiter": ",", "dtype": None}

    # Default: fields keep their surrounding blanks.
    with warnings.catch_warnings(record=True) as log:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        padded = np.genfromtxt(TextIO(raw_line), **options)
        assert_(log[0].category is np.VisibleDeprecationWarning)
    wanted = np.array(
        [('01/01/2003 ', 1.3, ' abcde')],
        dtype=[('f0', '|S12'), ('f1', float), ('f2', '|S8')])
    assert_equal(padded, wanted)

    # autostrip=True: blanks are trimmed.
    with warnings.catch_warnings(record=True) as log:
        warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
        trimmed = np.genfromtxt(TextIO(raw_line), autostrip=True, **options)
        assert_(log[0].category is np.VisibleDeprecationWarning)
    wanted = np.array(
        [('01/01/2003', 1.3, 'abcde')],
        dtype=[('f0', '|S10'), ('f1', float), ('f2', '|S5')])
    assert_equal(trimmed, wanted)
def test_replace_space(self):
0
Source : test_io.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def test_utf8_file_nodtype_unicode(self):
    # bytes encoding with non-latin1 -> unicode upcast
    utf8 = u'\u03d6'
    latin1 = u'\xf6\xfc\xf6'

    # The fixture is written with io.open's default encoding, assumed to
    # be the locale's preferred one; skip when that encoding cannot
    # represent the utf8 test character.  Will need to change this for
    # PyTest, maybe using pytest.mark.xfail(raises=***).
    try:
        preferred = locale.getpreferredencoding()
        utf8.encode(preferred)
    except (UnicodeError, ImportError):
        pytest.skip('Skipping test_utf8_file_nodtype_unicode, '
                    'unable to encode utf8 in preferred encoding')

    rows = (
        u"norm1,norm2,norm3\n",
        u"norm1," + latin1 + u",norm3\n",
        u"test1,testNonethe" + utf8 + u",test3\n",
    )
    with temppath() as path:
        with io.open(path, "wt") as handle:
            for row in rows:
                handle.write(row)
        with warnings.catch_warnings(record=True) as caught:
            warnings.filterwarnings('always', '',
                                    np.VisibleDeprecationWarning)
            parsed = np.genfromtxt(path, dtype=None, comments=None,
                                   delimiter=',')
            # Check for warning when encoding not specified.
            assert_(caught[0].category is np.VisibleDeprecationWarning)
        expected = np.array(
            [["norm1", "norm2", "norm3"],
             ["norm1", latin1, "norm3"],
             ["test1", "testNonethe" + utf8, "test3"]],
            dtype=np.unicode_)
        assert_array_equal(parsed, expected)
def test_recfromtxt(self):
0
Source : test_distributions.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def test_weights(self, rng):
    # Weighted and unweighted bivariate KDE contours must differ.
    import warnings

    # The "error" filter is installed inside ``catch_warnings`` so it is
    # removed when the test ends; ``simplefilter`` on its own would change
    # the process-wide filter list and turn VisibleDeprecationWarning into
    # an exception for every test that runs afterwards.
    with warnings.catch_warnings():
        warnings.simplefilter("error", np.VisibleDeprecationWarning)

        n = 100
        x, y = rng.multivariate_normal([1, 3], [(.2, .5), (.5, 2)], n).T
        hue = np.repeat([0, 1], n // 2)
        weights = rng.uniform(0, 1, n)

        f, (ax1, ax2) = plt.subplots(ncols=2)
        kdeplot(x=x, y=y, hue=hue, ax=ax1)
        kdeplot(x=x, y=y, hue=hue, weights=weights, ax=ax2)

        # Compare only collection pairs that both have segments.
        for c1, c2 in zip(ax1.collections, ax2.collections):
            if c1.get_segments() and c2.get_segments():
                seg1 = np.concatenate(c1.get_segments(), axis=0)
                seg2 = np.concatenate(c2.get_segments(), axis=0)
                assert not np.array_equal(seg1, seg2)
def test_hue_ignores_cmap(self, long_df):
0
Source : _testing.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def assert_warns(warning_class, func, *args, **kw):
    """Test that a certain warning occurs.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning.  The category must be
        exactly this class; subclasses do not count.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`

    Returns
    -------
    result : the return value of `func`

    Raises
    ------
    AssertionError
        If no warning is recorded, or none has category `warning_class`.
        numpy's VisibleDeprecationWarning is discarded before the check.
    """
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        # Trigger a warning.
        result = func(*args, **kw)
        # The guard used to test for ``np.FutureWarning``, which numpy
        # never defines, so this filter was dead code.  Test for the
        # attribute that is actually used below.
        if hasattr(np, 'VisibleDeprecationWarning'):
            # Filter out numpy-specific warnings in numpy >= 1.9
            w = [e for e in w
                 if e.category is not np.VisibleDeprecationWarning]

        # Verify some things
        if not w:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)

        found = any(warning.category is warning_class for warning in w)
        if not found:
            raise AssertionError("%s did not give warning: %s( is %s)"
                                 % (func.__name__, warning_class, w))
    return result
def assert_warns_message(warning_class, message, func, *args, **kw):
0
Source : _testing.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def assert_warns_message(warning_class, message, func, *args, **kw):
    """Test that a certain warning occurs and with a certain message.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning. Subclasses also match.

    message : str | callable
        The message or a substring of the message to test for. If callable,
        it takes a string as the argument and will trigger an AssertionError
        if the callable returns `False`.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`.

    Returns
    -------
    result : the return value of `func`

    Raises
    ------
    AssertionError
        If no warning is recorded, none is an instance of `warning_class`,
        or no matching warning carries the expected message.
    """
    # The previous guard tested ``hasattr(np, 'FutureWarning')``; NumPy does
    # not appear to export that name, so the "ignore" filter was never
    # installed (NOTE(review): confirm for the oldest supported NumPy).
    # ``np.exceptions`` hosts the class on NumPy >= 1.25; the top-level alias
    # is the fallback for older releases.
    numpy_noise = getattr(getattr(np, "exceptions", np),
                          "VisibleDeprecationWarning", None)
    # Callables such as functools.partial objects have no ``__name__``.
    func_name = getattr(func, "__name__", repr(func))

    # very important to avoid uncontrolled state propagation
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        # Let's not catch the numpy internal DeprecationWarnings, unless
        # they could satisfy the class the caller is asserting on.
        if numpy_noise is not None and not issubclass(numpy_noise,
                                                      warning_class):
            warnings.simplefilter('ignore', numpy_noise)
        # Trigger a warning.
        result = func(*args, **kw)

    # Verify some things
    if not w:
        raise AssertionError("No warning raised when calling %s"
                             % func_name)

    matching = [warning for warning in w
                if issubclass(warning.category, warning_class)]
    if not matching:
        raise AssertionError("No warning raised for %s with class "
                             "%s"
                             % (func_name, warning_class))

    if callable(message):  # add support for certain tests
        check_in_message = message
    else:
        def check_in_message(msg):
            return message in msg

    # Checks the message of all warnings belonging to warning_class
    for warning in matching:
        # substring will match, the entire message with typo won't
        msg = warning.message  # For Python 3 compatibility
        msg = str(msg.args[0] if hasattr(msg, 'args') else msg)
        if check_in_message(msg):
            return result

    # ``msg`` is the last candidate inspected; ``matching`` is non-empty.
    raise AssertionError("Did not receive the message you expected "
                         "('%s') for < %s>, got: '%s'"
                         % (message, func_name, msg))
def assert_warns_div0(func, *args, **kw):
0
Source : _testing.py
with Apache License 2.0
from dashanji
with Apache License 2.0
from dashanji
def assert_no_warnings(func, *args, **kw):
    """Call `func` and fail if it emits any warning.

    NumPy's own ``VisibleDeprecationWarning`` is not counted.

    Parameters
    ----------
    func : callable
        Callable object expected to run without warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`.

    Returns
    -------
    result : the return value of `func`

    Raises
    ------
    AssertionError
        If any warning other than NumPy's ``VisibleDeprecationWarning`` is
        recorded while `func` runs.
    """
    # The previous guard tested ``hasattr(np, 'FutureWarning')``; NumPy does
    # not appear to export that name, so NumPy's warnings were never
    # filtered (NOTE(review): confirm for the oldest supported NumPy).
    # ``np.exceptions`` hosts the class on NumPy >= 1.25; the top-level alias
    # is the fallback for older releases.
    numpy_noise = getattr(getattr(np, "exceptions", np),
                          "VisibleDeprecationWarning", None)

    # very important to avoid uncontrolled state propagation
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        result = func(*args, **kw)

    if numpy_noise is not None:
        # Filter out numpy-specific warnings
        w = [e for e in w if e.category is not numpy_noise]

    if w:
        raise AssertionError("Got warnings when calling %s: [%s]"
                             % (getattr(func, "__name__", repr(func)),
                                ', '.join(str(warning) for warning in w)))
    return result
def ignore_warnings(obj=None, category=Warning):
0
Source : compute_delta_var.py
with MIT License
from DPBayes
with MIT License
from DPBayes
def get_delta_R(
    sigma_t: np.ndarray,
    q_t: np.ndarray,
    k: np.ndarray,
    target_eps: float = 1.0,
    nx: int = int(1E6),
    L: float = 20.0
):
    """
    Computes the DP delta for the remove/add neighbouring relation of datasets.

    Deprecated: emits a ``VisibleDeprecationWarning`` and forwards to
    `get_delta_R_new`; use `fourier_accountant.get_delta_R` instead.

    The computed delta privacy value is for the composition of DP operations
    as specified by `sigma_t`, `q_t` and `k`, where `sigma_t` and `q_t` specify
    privacy noise and subsampling ratio for each operation and `k` is the number
    of repetitions, i.e.,
    - `k[0]` operations with privacy noise `sigma_t[0]` and subsampling ratio `q_t[0]`
    - `k[1]` operations with privacy noise `sigma_t[1]` and subsampling ratio `q_t[1]`
    - etc
    for a total of `np.sum(k)` operations.

    Note that this function relies on numerical approximations, which are influenced
    by choice of parameters nx and L. Increasing L roughly increases the range over
    which the integral of the privacy loss distribution is approximated, while nx is
    the number of evaluation points in [-L,L]. If you find results output by this function
    to be inaccurate, try adjusting these parameters. Refer to [1] for more details.

    Parameters:
        sigma_t (np.ndarray(float)): Privacy noise sigma for composed DP operations
        q_t (np.ndarray(float)): Subsampling ratios, i.e., how large are batches relative to the dataset
        k (np.ndarray(int)): Repetitions for each values in `sigma_t` and `q_t`
        target_eps (float): Target epsilon
        nx (int): Number of discretisation points
        L (float): Limit for the approximation of the privacy loss distribution integral

    Returns:
        (float): delta value

    References:
        [1] Antti Koskela, Joonas Jälkö, Antti Honkela:
            Computing Tight Differential Privacy Guarantees Using FFT
            https://arxiv.org/abs/1906.03049
    """
    # NumPy >= 1.25 exposes the class under ``np.exceptions`` and NumPy 2.0
    # drops the top-level alias; older releases only have the alias.
    category = getattr(np, "exceptions", np).VisibleDeprecationWarning
    # stacklevel=2 attributes the warning to the caller of this shim, which
    # is the code that needs to be migrated.
    warnings.warn(
        "DEPRECATED FUNCTION! Use fourier_accountant.get_delta_R instead.",
        category,
        stacklevel=2,
    )
    # The replacement takes the target value first.
    return get_delta_R_new(target_eps, sigma_t, q_t, k, nx, L)
def get_delta_S(
0
Source : compute_delta_var.py
with MIT License
from DPBayes
with MIT License
from DPBayes
def get_delta_S(
    sigma_t: np.ndarray,
    q_t: np.ndarray,
    k: np.ndarray,
    target_eps: float = 1.0,
    nx: int = int(1E6),
    L: float = 20.0
):
    """
    Computes the DP delta for the substitute neighbouring relation of datasets.

    Deprecated: emits a ``VisibleDeprecationWarning`` and forwards to
    `get_delta_S_new`; use `fourier_accountant.get_delta_S` instead.

    The computed delta privacy value is for the composition of DP operations
    as specified by `sigma_t`, `q_t` and `k`, where `sigma_t` and `q_t` specify
    privacy noise and subsampling ratio for each operation and `k` is the number
    of repetitions, i.e.,
    - `k[0]` operations with privacy noise `sigma_t[0]` and subsampling ratio `q_t[0]`
    - `k[1]` operations with privacy noise `sigma_t[1]` and subsampling ratio `q_t[1]`
    - etc
    for a total of `np.sum(k)` operations.

    Note that this function relies on numerical approximations, which are influenced
    by choice of parameters nx and L. Increasing L roughly increases the range over
    which the integral of the privacy loss distribution is approximated, while nx is
    the number of evaluation points in [-L,L]. If you find results output by this function
    to be inaccurate, try adjusting these parameters. Refer to [1] for more details.

    Parameters:
        sigma_t (np.ndarray(float)): Privacy noise sigma for composed DP operations
        q_t (np.ndarray(float)): Subsampling ratios, i.e., how large are batches relative to the dataset
        k (np.ndarray(int)): Repetitions for each values in `sigma_t` and `q_t`
        target_eps (float): Target epsilon
        nx (int): Number of discretisation points
        L (float): Limit for the approximation of the privacy loss distribution integral

    Returns:
        (float): delta value

    References:
        [1] Antti Koskela, Joonas Jälkö, Antti Honkela:
            Computing Tight Differential Privacy Guarantees Using FFT
            https://arxiv.org/abs/1906.03049
    """
    # NumPy >= 1.25 exposes the class under ``np.exceptions`` and NumPy 2.0
    # drops the top-level alias; older releases only have the alias.
    category = getattr(np, "exceptions", np).VisibleDeprecationWarning
    # stacklevel=2 attributes the warning to the caller of this shim, which
    # is the code that needs to be migrated.
    warnings.warn(
        "DEPRECATED FUNCTION! Use fourier_accountant.get_delta_S instead.",
        category,
        stacklevel=2,
    )
    # The replacement takes the target value first.
    return get_delta_S_new(target_eps, sigma_t, q_t, k, nx, L)
0
Source : compute_eps_var.py
with MIT License
from DPBayes
with MIT License
from DPBayes
def get_epsilon_R(
    sigma_t: np.ndarray,
    q_t: np.ndarray,
    k: np.ndarray,
    target_delta: float = 1e-6,
    nx: int = int(1E6),
    L: float = 20.0
):
    """
    Computes the DP epsilon for the remove/add neighbouring relation of datasets.

    Deprecated: emits a ``VisibleDeprecationWarning`` and forwards to
    `get_epsilon_R_new`; use `fourier_accountant.get_epsilon_R` instead.

    The computed epsilon privacy value is for the composition of DP operations
    as specified by `sigma_t`, `q_t` and `k`, where `sigma_t` and `q_t` specify
    privacy noise and subsampling ratio for each operation and `k` is the number
    of repetitions, i.e.,
    - `k[0]` operations with privacy noise `sigma_t[0]` and subsampling ratio `q_t[0]`
    - `k[1]` operations with privacy noise `sigma_t[1]` and subsampling ratio `q_t[1]`
    - etc
    for a total of `np.sum(k)` operations.

    Note that this function relies on numerical approximations, which are influenced
    by choice of parameters nx and L. Increasing L roughly increases the range over
    which the integral of the privacy loss distribution is approximated. L must be chosen
    large enough to cover the computed epsilon, otherwise a ValueError is raised. Try
    increasing L if this happens.

    nx is the number of evaluation points in [-L,L]. If you find results output by this
    function to be inaccurate, try adjusting these parameters. Refer to [1] for more details.

    Due to numerical instabilities, corner cases exist where this function sometimes returns
    inaccurate values. If you think this is occurring, increasing nx and verifying that
    the returned value does not change by much is usually a good heuristic to verify the output.

    Parameters:
        sigma_t (np.ndarray(float)): Privacy noise sigma for composed DP operations
        q_t (np.ndarray(float)): Subsampling ratios, i.e., how large are batches relative to the dataset
        k (np.ndarray(int)): Repetitions for each values in `sigma_t` and `q_t`
        target_delta (float): Target delta
        nx (int): Number of discretisation points
        L (float): Limit for the approximation of the privacy loss distribution integral

    Returns:
        (float): epsilon value

    References:
        [1] Antti Koskela, Joonas Jälkö, Antti Honkela:
            Computing Tight Differential Privacy Guarantees Using FFT
            https://arxiv.org/abs/1906.03049
    """
    # NumPy >= 1.25 exposes the class under ``np.exceptions`` and NumPy 2.0
    # drops the top-level alias; older releases only have the alias.
    category = getattr(np, "exceptions", np).VisibleDeprecationWarning
    # stacklevel=2 attributes the warning to the caller of this shim, which
    # is the code that needs to be migrated.
    warnings.warn(
        "DEPRECATED FUNCTION! Use fourier_accountant.get_epsilon_R instead.",
        category,
        stacklevel=2,
    )
    # The replacement takes the target value first.
    return get_epsilon_R_new(target_delta, sigma_t, q_t, k, nx, L)
def get_epsilon_S(
0
Source : compute_eps_var.py
with MIT License
from DPBayes
with MIT License
from DPBayes
def get_epsilon_S(
    sigma_t: np.ndarray,
    q_t: np.ndarray,
    k: np.ndarray,
    target_delta: float = 1e-6,
    nx: int = int(1E6),
    L: float = 20.0
):
    """
    Computes the DP epsilon for the substitute neighbouring relation of datasets.

    Deprecated: emits a ``VisibleDeprecationWarning`` and forwards to
    `get_epsilon_S_new`; use `fourier_accountant.get_epsilon_S` instead.

    The computed epsilon privacy value is for the composition of DP operations
    as specified by `sigma_t`, `q_t` and `k`, where `sigma_t` and `q_t` specify
    privacy noise and subsampling ratio for each operation and `k` is the number
    of repetitions, i.e.,
    - `k[0]` operations with privacy noise `sigma_t[0]` and subsampling ratio `q_t[0]`
    - `k[1]` operations with privacy noise `sigma_t[1]` and subsampling ratio `q_t[1]`
    - etc
    for a total of `np.sum(k)` operations.

    Note that this function relies on numerical approximations, which are influenced
    by choice of parameters nx and L. Increasing L roughly increases the range over
    which the integral of the privacy loss distribution is approximated. L must be chosen
    large enough to cover the computed epsilon, otherwise a ValueError is raised. Try
    increasing L if this happens.

    nx is the number of evaluation points in [-L,L]. If you find results output by this
    function to be inaccurate, try adjusting these parameters. Refer to [1] for more details.

    Due to numerical instabilities, corner cases exist where this function sometimes returns
    inaccurate values. If you think this is occurring, increasing nx and verifying that
    the returned value does not change by much is usually a good heuristic to verify the output.

    Parameters:
        sigma_t (np.ndarray(float)): Privacy noise sigma for composed DP operations
        q_t (np.ndarray(float)): Subsampling ratios, i.e., how large are batches relative to the dataset
        k (np.ndarray(int)): Repetitions for each values in `sigma_t` and `q_t`
        target_delta (float): Target delta
        nx (int): Number of discretisation points
        L (float): Limit for the approximation of the privacy loss distribution integral

    Returns:
        (float): epsilon value

    References:
        [1] Antti Koskela, Joonas Jälkö, Antti Honkela:
            Computing Tight Differential Privacy Guarantees Using FFT
            https://arxiv.org/abs/1906.03049
    """
    # NumPy >= 1.25 exposes the class under ``np.exceptions`` and NumPy 2.0
    # drops the top-level alias; older releases only have the alias.
    category = getattr(np, "exceptions", np).VisibleDeprecationWarning
    # stacklevel=2 attributes the warning to the caller of this shim, which
    # is the code that needs to be migrated.
    warnings.warn(
        "DEPRECATED FUNCTION! Use fourier_accountant.get_epsilon_S instead.",
        category,
        stacklevel=2,
    )
    # The replacement takes the target value first.
    return get_epsilon_S_new(target_delta, sigma_t, q_t, k, nx, L)
0
Source : arrow_dataset.py
with Apache License 2.0
from ExpressAI
with Apache License 2.0
from ExpressAI
def to_tf_dataset(
self,
columns: Union[str, List[str]],
batch_size: int,
shuffle: bool,
drop_remainder: bool = None,
collate_fn: Callable = None,
collate_fn_args: Dict[str, Any] = None,
label_cols: Union[str, List[str]] = None,
dummy_labels: bool = False,
prefetch: bool = True,
):
"""Create a tf.data.Dataset from the underlying Dataset. This tf.data.Dataset will load and collate batches from
the Dataset, and is suitable for passing to methods like model.fit() or model.predict().
Args:
columns (:obj:`List[str]` or :obj:`str`): Dataset column(s) to load in the tf.data.Dataset. In general,
only columns that the model can use as input should be included here (numeric data only).
batch_size (:obj:`int`): Size of batches to load from the dataset.
shuffle(:obj:`bool`): Shuffle the dataset order when loading. Recommended True for training, False for
validation/evaluation.
drop_remainder(:obj:`bool`, default ``None``): Drop the last incomplete batch when loading. If not provided,
defaults to the same setting as shuffle.
collate_fn(:obj:`Callable`): A function or callable object (such as a `DataCollator`) that will collate
lists of samples into a batch.
collate_fn_args (:obj:`Dict`, optional): An optional `dict` of keyword arguments to be passed to the
`collate_fn`.
label_cols (:obj:`List[str]` or :obj:`str`, default ``None``): Dataset column(s) to load as
labels. Note that many models compute loss internally rather than letting Keras do it, in which case it is
not necessary to actually pass the labels here, as long as they're in the input `columns`.
dummy_labels (:obj:`bool`, default ``False``): If no `label_cols` are set, output an array of "dummy" labels
with each batch. This can avoid problems with `fit()` or `train_on_batch()` that expect labels to be
a Tensor or np.ndarray, but should (hopefully) not be necessary with our standard train_step().
prefetch (:obj:`bool`, default ``True``): Whether to run the dataloader in a separate thread and maintain
a small buffer of batches for training. Improves performance by allowing data to be loaded in the
background while the model is training.
"""
# TODO There is some hacky hardcoding in this function that needs to be fixed.
# We're planning to rework it so less code is needed at the start to remove columns before
# we know the final list of fields (post-data collator). This should clean up most of the special
# casing while retaining the API.
if config.TF_AVAILABLE:
import tensorflow as tf
else:
raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.")
if collate_fn_args is None:
collate_fn_args = {}
if label_cols is None:
label_cols = []
elif isinstance(label_cols, str):
label_cols = [label_cols]
elif len(set(label_cols)) < len(label_cols):
raise ValueError("List of label_cols contains duplicates.")
if not columns:
raise ValueError("Need to specify at least one column.")
elif isinstance(columns, str):
columns = [columns]
elif len(set(columns)) < len(columns):
raise ValueError("List of columns contains duplicates.")
if label_cols is not None:
cols_to_retain = list(set(columns + label_cols))
else:
cols_to_retain = columns
# Special casing when the dataset has 'label' and the model expects 'labels' and the collator fixes it up for us
if "labels" in cols_to_retain and "labels" not in self.features and "label" in self.features:
cols_to_retain[cols_to_retain.index("labels")] = "label"
# Watch for nonexistent columns, except those that the data collators add for us
for col in cols_to_retain:
if col not in self.features and not (col in ("attention_mask", "labels") and collate_fn is not None):
raise ValueError(f"Couldn't find column {col} in dataset.")
if drop_remainder is None:
# We assume that if you're shuffling it's the train set, so we drop the remainder unless told not to
drop_remainder = shuffle
dataset = self.with_format("python", columns=[col for col in cols_to_retain if col in self.features])
def numpy_pad(data):
    """Convert a batch of rows to an array, zero-padding ragged rows.

    If ``np.array(data)`` yields a regular (non-object) array it is returned
    as is. Otherwise every row is right-padded with zeros to the length of
    the longest row; the output dtype is taken from the first row.

    Args:
        data: Sequence of rows, where each row supports ``len()``.

    Returns:
        np.ndarray: The converted array, or the zero-padded 2-D array.
    """
    # Local import: the file-level import block is not visible from here.
    import warnings

    # NumPy >= 1.25 exposes the class under ``np.exceptions`` and NumPy 2.0
    # drops the top-level alias; older releases only have the alias.
    ragged_warning = getattr(np, "exceptions", np).VisibleDeprecationWarning

    try:
        # Scope the "error" filter to this one conversion. The old
        # ``np.warnings.filterwarnings`` call changed the process-wide
        # filters on every batch, and ``np.warnings`` no longer exists on
        # NumPy >= 1.24.
        with warnings.catch_warnings():
            warnings.simplefilter("error", category=ragged_warning)
            converted = np.array(data)
    except (ragged_warning, ValueError):
        # Ragged rows: a deprecation warning on older NumPy, a ValueError
        # since NumPy 1.24 (NEP 34).
        converted = None

    # ``np.object`` was removed in NumPy 1.24; the builtin compares equal.
    if converted is not None and converted.dtype != object:
        return converted

    # Get lengths of each row of data
    lens = np.array([len(i) for i in data])
    # Mask of valid places in each row
    mask = np.arange(lens.max()) < lens[:, None]
    # Setup output array and put elements from data into masked positions
    out = np.zeros(mask.shape, dtype=np.array(data[0]).dtype)
    out[mask] = np.concatenate(data)
    return out
def np_get_batch(indices):
batch = dataset[indices]
out_batch = []
if collate_fn is not None:
actual_size = len(list(batch.values())[0]) # Get the length of one of the arrays, assume all same
# Our collators expect a list of dicts, not a dict of lists/arrays, so we invert
batch = [{key: value[i] for key, value in batch.items()} for i in range(actual_size)]
batch = collate_fn(batch, **collate_fn_args)
# Special casing when the dataset has 'label' and the model
# expects 'labels' and the collator fixes it up for us
if "label" in cols_to_retain and "label" not in batch and "labels" in batch:
cols_to_retain[cols_to_retain.index("label")] = "labels"
for key in cols_to_retain:
# In case the collate_fn returns something strange
array = np.array(batch[key])
cast_dtype = np.int64 if np.issubdtype(array.dtype, np.integer) else np.float32
array = array.astype(cast_dtype)
out_batch.append(array)
else:
for key in cols_to_retain:
array = batch[key]
array = numpy_pad(array)
cast_dtype = np.int64 if np.issubdtype(array.dtype, np.integer) else np.float32
array = array.astype(cast_dtype)
out_batch.append(array)
return [tf.convert_to_tensor(arr) for arr in out_batch]
test_batch = np_get_batch(np.arange(batch_size))
@tf.function(input_signature=[tf.TensorSpec(None, tf.int64)])
def fetch_function(indices):
output = tf.numpy_function(
np_get_batch, inp=[indices], Tout=[tf.dtypes.as_dtype(arr.dtype) for arr in test_batch]
)
return {key: output[i] for i, key in enumerate(cols_to_retain)}
test_batch_dict = {key: test_batch[i] for i, key in enumerate(cols_to_retain)}
output_signature = TensorflowDatasetMixin._get_output_signature(
dataset, cols_to_retain, test_batch_dict, batch_size=batch_size if drop_remainder else None
)
def ensure_shapes(input_dict):
return {key: tf.ensure_shape(val, output_signature[key].shape) for key, val in input_dict.items()}
tf_dataset = tf.data.Dataset.from_tensor_slices(np.arange(len(dataset), dtype=np.int64))
if shuffle:
tf_dataset = tf_dataset.shuffle(len(dataset))
tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder).map(fetch_function).map(ensure_shapes)
if label_cols:
def split_features_and_labels(input_batch):
features = {key: tensor for key, tensor in input_batch.items() if key in columns}
labels = {key: tensor for key, tensor in input_batch.items() if key in label_cols}
if len(features) == 1:
features = list(features.values())[0]
if len(labels) == 1:
labels = list(labels.values())[0]
return features, labels
tf_dataset = tf_dataset.map(split_features_and_labels)
elif len(columns) == 1:
tf_dataset = tf_dataset.map(lambda x: list(x.values())[0])
if dummy_labels and not label_cols:
def add_dummy_labels(input_batch):
return input_batch, tf.zeros(tf.shape(input_batch[columns[0]])[0])
tf_dataset = tf_dataset.map(add_dummy_labels)
def rename_label_col(inputs, labels=None):
if not isinstance(inputs, tf.Tensor):
if "label" in inputs:
inputs["labels"] = inputs["label"]
del inputs["label"]
if labels is None:
return inputs
else:
return inputs, labels
tf_dataset = tf_dataset.map(rename_label_col)
if prefetch:
tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE)
# Remove a reference to the open Arrow file on delete
def cleanup_callback(ref):
dataset.__del__()
self._TF_DATASET_REFS.remove(ref)
self._TF_DATASET_REFS.add(weakref.ref(tf_dataset, cleanup_callback))
return tf_dataset
class DatasetTransformationNotAllowedError(Exception):
0
Source : _testing.py
with GNU General Public License v3.0
from gustavowillam
with GNU General Public License v3.0
from gustavowillam
def assert_warns_message(warning_class, message, func, *args, **kw):
    """Test that a certain warning occurs and with a certain message.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning. Subclasses also match.

    message : str | callable
        The message or a substring of the message to test for. If callable,
        it takes a string as the argument and will trigger an AssertionError
        if the callable returns `False`.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`.

    Returns
    -------
    result : the return value of `func`

    Raises
    ------
    AssertionError
        If no warning is recorded, none is an instance of `warning_class`,
        or no matching warning carries the expected message.
    """
    # The previous guard tested ``hasattr(np, 'FutureWarning')``; NumPy does
    # not appear to export that name, so the "ignore" filter was never
    # installed (NOTE(review): confirm for the oldest supported NumPy).
    # ``np.exceptions`` hosts the class on NumPy >= 1.25; the top-level alias
    # is the fallback for older releases.
    numpy_noise = getattr(getattr(np, "exceptions", np),
                          "VisibleDeprecationWarning", None)
    # Callables such as functools.partial objects have no ``__name__``.
    func_name = getattr(func, "__name__", repr(func))

    # very important to avoid uncontrolled state propagation
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        # Let's not catch the numpy internal DeprecationWarnings, unless
        # they could satisfy the class the caller is asserting on.
        if numpy_noise is not None and not issubclass(numpy_noise,
                                                      warning_class):
            warnings.simplefilter('ignore', numpy_noise)
        # Trigger a warning.
        result = func(*args, **kw)

    # Verify some things
    if not w:
        raise AssertionError("No warning raised when calling %s"
                             % func_name)

    candidates = [warning for warning in w
                  if issubclass(warning.category, warning_class)]
    if not candidates:
        raise AssertionError("No warning raised for %s with class "
                             "%s"
                             % (func_name, warning_class))

    if callable(message):  # add support for certain tests
        check_in_message = message
    else:
        def check_in_message(msg):
            return message in msg

    # Checks the message of all warnings belonging to warning_class
    for warning in candidates:
        # substring will match, the entire message with typo won't
        msg = warning.message  # For Python 3 compatibility
        msg = str(msg.args[0] if hasattr(msg, 'args') else msg)
        if check_in_message(msg):
            return result

    # ``msg`` is the last candidate inspected; ``candidates`` is non-empty.
    raise AssertionError("Did not receive the message you expected "
                         "('%s') for < %s>, got: '%s'"
                         % (message, func_name, msg))
def assert_warns_div0(func, *args, **kw):
0
Source : multiclass.py
with GNU General Public License v3.0
from gustavowillam
with GNU General Public License v3.0
from gustavowillam
def is_multilabel(y):
    """ Check if ``y`` is in a multilabel format.

    Parameters
    ----------
    y : ndarray of shape (n_samples,)
        Target values.

    Returns
    -------
    out : bool
        Return ``True``, if ``y`` is in a multilabel format, else ``False``.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.utils.multiclass import is_multilabel
    >>> is_multilabel([0, 1, 0, 1])
    False
    >>> is_multilabel([[1], [0, 2], []])
    False
    >>> is_multilabel(np.array([[1, 0], [0, 0]]))
    True
    >>> is_multilabel(np.array([[1], [0], [0]]))
    False
    >>> is_multilabel(np.array([[1, 0, 0]]))
    True
    """
    if hasattr(y, '__array__') or isinstance(y, Sequence):
        # NumPy >= 1.25 exposes the class under ``np.exceptions`` and
        # NumPy 2.0 drops the top-level alias; older releases only have the
        # alias.
        ragged_warning = getattr(np, "exceptions",
                                 np).VisibleDeprecationWarning
        # Ragged input was a deprecation warning and is a ValueError since
        # NumPy 1.24, see NEP 34
        # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
        with warnings.catch_warnings():
            warnings.simplefilter('error', ragged_warning)
            try:
                y = np.asarray(y)
            except (ragged_warning, ValueError):
                # dtype=object should be provided explicitly for ragged
                # arrays, see NEP 34
                y = np.array(y, dtype=object)

    # Anything that is not 2-D with at least two columns cannot be a label
    # indicator matrix.
    if not (hasattr(y, "shape") and y.ndim == 2 and y.shape[1] > 1):
        return False

    if issparse(y):
        if isinstance(y, (dok_matrix, lil_matrix)):
            y = y.tocsr()
        return (len(y.data) == 0 or np.unique(y.data).size == 1 and
                (y.dtype.kind in 'biu' or  # bool, int, uint
                 _is_integral_float(np.unique(y.data))))
    else:
        labels = np.unique(y)

        return len(labels) < 3 and (y.dtype.kind in 'biu' or  # bool, int, uint
                                    _is_integral_float(labels))
def check_classification_targets(y):
0
Source : multiclass.py
with GNU General Public License v3.0
from gustavowillam
with GNU General Public License v3.0
from gustavowillam
def type_of_target(y):
    """Determine the type of data indicated by the target.

    Note that this type is the most specific type that can be inferred.
    For example:

        * ``binary`` is more specific but compatible with ``multiclass``.
        * ``multiclass`` of integers is more specific but compatible with
          ``continuous``.
        * ``multilabel-indicator`` is more specific but compatible with
          ``multiclass-multioutput``.

    Parameters
    ----------
    y : array-like

    Returns
    -------
    target_type : str
        One of:

        * 'continuous': `y` is an array-like of floats that are not all
          integers, and is 1d or a column vector.
        * 'continuous-multioutput': `y` is a 2d array of floats that are
          not all integers, and both dimensions are of size > 1.
        * 'binary': `y` contains <= 2 discrete values and is 1d or a column
          vector.
        * 'multiclass': `y` contains more than two discrete values, is not a
          sequence of sequences, and is 1d or a column vector.
        * 'multiclass-multioutput': `y` is a 2d array that contains more
          than two discrete values, is not a sequence of sequences, and both
          dimensions are of size > 1.
        * 'multilabel-indicator': `y` is a label indicator matrix, an array
          of two dimensions with at least two columns, and at most 2 unique
          values.
        * 'unknown': `y` is array-like but none of the above, such as a 3d
          array, sequence of sequences, or an array of non-sequence objects.

    Raises
    ------
    ValueError
        If `y` is not array-like, is a pandas ``SparseSeries`` or
        ``SparseArray``, or uses the legacy sequence-of-sequences
        multilabel representation.

    Examples
    --------
    >>> import numpy as np
    >>> type_of_target([0.1, 0.6])
    'continuous'
    >>> type_of_target([1, -1, -1, 1])
    'binary'
    >>> type_of_target(['a', 'b', 'a'])
    'binary'
    >>> type_of_target([1.0, 2.0])
    'binary'
    >>> type_of_target([1, 0, 2])
    'multiclass'
    >>> type_of_target([1.0, 0.0, 3.0])
    'multiclass'
    >>> type_of_target(['a', 'b', 'c'])
    'multiclass'
    >>> type_of_target(np.array([[1, 2], [3, 1]]))
    'multiclass-multioutput'
    >>> type_of_target([[1, 2]])
    'multilabel-indicator'
    >>> type_of_target(np.array([[1.5, 2.0], [3.0, 1.6]]))
    'continuous-multioutput'
    >>> type_of_target(np.array([[0, 1], [1, 1]]))
    'multilabel-indicator'
    """
    valid = ((isinstance(y, (Sequence, spmatrix)) or hasattr(y, '__array__'))
             and not isinstance(y, str))

    if not valid:
        raise ValueError('Expected array-like (array or non-string sequence), '
                         'got %r' % y)

    sparse_pandas = (y.__class__.__name__ in ['SparseSeries', 'SparseArray'])
    if sparse_pandas:
        raise ValueError("y cannot be class 'SparseSeries' or 'SparseArray'")

    if is_multilabel(y):
        return 'multilabel-indicator'

    # NumPy >= 1.25 exposes the class under ``np.exceptions`` and NumPy 2.0
    # drops the top-level alias; older releases only have the alias.
    ragged_warning = getattr(np, "exceptions", np).VisibleDeprecationWarning
    # Ragged input was a deprecation warning and is a ValueError since
    # NumPy 1.24, see NEP 34
    # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
    with warnings.catch_warnings():
        warnings.simplefilter('error', ragged_warning)
        try:
            y = np.asarray(y)
        except (ragged_warning, ValueError):
            # dtype=object should be provided explicitly for ragged arrays,
            # see NEP 34
            y = np.asarray(y, dtype=object)

    # The old sequence of sequences format
    try:
        if (not hasattr(y[0], '__array__') and isinstance(y[0], Sequence)
                and not isinstance(y[0], str)):
            raise ValueError('You appear to be using a legacy multi-label data'
                             ' representation. Sequence of sequences are no'
                             ' longer supported; use a binary array or sparse'
                             ' matrix instead - the MultiLabelBinarizer'
                             ' transformer can convert to this format.')
    except IndexError:
        # Empty input has no first element to inspect.
        pass

    # Invalid inputs
    if y.ndim > 2 or (y.dtype == object and len(y) and
                      not isinstance(y.flat[0], str)):
        return 'unknown'  # [[[1, 2]]] or [obj_1] and not ["label_1"]

    if y.ndim == 2 and y.shape[1] == 0:
        return 'unknown'  # [[]]

    if y.ndim == 2 and y.shape[1] > 1:
        suffix = "-multioutput"  # [[1, 2], [1, 2]]
    else:
        suffix = ""  # [1, 2, 3] or [[1], [2], [3]]

    # check float and contains non-integer float values
    if y.dtype.kind == 'f' and np.any(y != y.astype(int)):
        # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
        _assert_all_finite(y)
        return 'continuous' + suffix

    if (len(np.unique(y)) > 2) or (y.ndim >= 2 and len(y[0]) > 1):
        return 'multiclass' + suffix  # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
    else:
        return 'binary'  # [1, 2] or [["a"], ["b"]]
def _check_partial_fit_first_call(clf, classes=None):
0
Source : _testing.py
with GNU General Public License v3.0
from gustavowillam
with GNU General Public License v3.0
from gustavowillam
def assert_warns(warning_class, func, *args, **kw):
    """Test that a certain warning occurs.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning. The match is by identity
        (``category is warning_class``), so subclasses do not count.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`

    Returns
    -------
    result : the return value of `func`

    Raises
    ------
    AssertionError
        If no warning is recorded, or none of the recorded warnings has
        exactly the category `warning_class`.
    """
    # The previous guard tested ``hasattr(np, 'FutureWarning')``; NumPy does
    # not appear to export that name, so the filter below never ran
    # (NOTE(review): confirm for the oldest NumPy still supported).
    # ``np.exceptions`` is the home of the class on NumPy >= 1.25; the
    # top-level alias is the fallback for older releases.
    numpy_noise = getattr(getattr(np, "exceptions", np),
                          "VisibleDeprecationWarning", None)
    # Callables such as functools.partial objects have no ``__name__``.
    func_name = getattr(func, "__name__", repr(func))

    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        # Trigger a warning.
        result = func(*args, **kw)

    # Filter out numpy-specific warnings, unless that is precisely the
    # category the caller is asserting on.
    if numpy_noise is not None and warning_class is not numpy_noise:
        w = [e for e in w if e.category is not numpy_noise]

    # Verify some things
    if not w:
        raise AssertionError("No warning raised when calling %s"
                             % func_name)

    if not any(warning.category is warning_class for warning in w):
        raise AssertionError("%s did not give warning: %s( is %s)"
                             % (func_name, warning_class, w))
    return result
def assert_warns_message(warning_class, message, func, *args, **kw):
0
Source : _testing.py
with GNU General Public License v3.0
from gustavowillam
with GNU General Public License v3.0
from gustavowillam
def assert_warns_message(warning_class, message, func, *args, **kw):
    """Test that a certain warning occurs and with a certain message.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning. Subclasses also match.

    message : str or callable
        The message or a substring of the message to test for. If callable,
        it takes a string as the argument and will trigger an AssertionError
        if the callable returns `False`.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`.

    Returns
    -------
    result : the return value of `func`

    Raises
    ------
    AssertionError
        If no warning is recorded, none is an instance of `warning_class`,
        or no matching warning carries the expected message.
    """
    # The previous guard tested ``hasattr(np, 'FutureWarning')``; NumPy does
    # not appear to export that name, so the "ignore" filter was never
    # installed (NOTE(review): confirm for the oldest supported NumPy).
    # ``np.exceptions`` hosts the class on NumPy >= 1.25; the top-level alias
    # is the fallback for older releases.
    numpy_noise = getattr(getattr(np, "exceptions", np),
                          "VisibleDeprecationWarning", None)
    # Callables such as functools.partial objects have no ``__name__``.
    func_name = getattr(func, "__name__", repr(func))

    # very important to avoid uncontrolled state propagation
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        # Let's not catch the numpy internal DeprecationWarnings, unless
        # they could satisfy the class the caller is asserting on.
        if numpy_noise is not None and not issubclass(numpy_noise,
                                                      warning_class):
            warnings.simplefilter('ignore', numpy_noise)
        # Trigger a warning.
        result = func(*args, **kw)

    # Verify some things
    if not w:
        raise AssertionError("No warning raised when calling %s"
                             % func_name)

    of_class = [warning for warning in w
                if issubclass(warning.category, warning_class)]
    if not of_class:
        raise AssertionError("No warning raised for %s with class "
                             "%s"
                             % (func_name, warning_class))

    if callable(message):  # add support for certain tests
        check_in_message = message
    else:
        def check_in_message(msg):
            return message in msg

    # Checks the message of all warnings belonging to warning_class
    for warning in of_class:
        # substring will match, the entire message with typo won't
        msg = warning.message  # For Python 3 compatibility
        msg = str(msg.args[0] if hasattr(msg, 'args') else msg)
        if check_in_message(msg):
            return result

    # ``msg`` is the last candidate inspected; ``of_class`` is non-empty.
    raise AssertionError("Did not receive the message you expected "
                         "('%s') for < %s>, got: '%s'"
                         % (message, func_name, msg))
def assert_warns_div0(func, *args, **kw):
0
Source : testing.py
with GNU General Public License v3.0
from HHHHhgqcdxhg
with GNU General Public License v3.0
from HHHHhgqcdxhg
def assert_warns(warning_class, func, *args, **kw):
    """Run `func` and require that it emits a warning of `warning_class`.

    Parameters
    ----------
    warning_class : the warning class
        Expected category, e.g. UserWarning. Compared by identity, so a
        subclass of `warning_class` does not satisfy the check.

    func : callable
        Callable expected to emit the warning.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`

    Returns
    -------
    result : the return value of `func`
    """
    clean_warning_registry()
    with warnings.catch_warnings(record=True) as recorded:
        # Record every warning, including repeats.
        warnings.simplefilter("always")
        result = func(*args, **kw)

        if hasattr(np, 'VisibleDeprecationWarning'):
            # Drop numpy-specific warnings (numpy >= 1.9) before checking.
            recorded = [entry for entry in recorded
                        if entry.category is not np.VisibleDeprecationWarning]

        if len(recorded) == 0:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)

        for entry in recorded:
            if entry.category is warning_class:
                break
        else:
            # Loop finished without a hit: wrong categories only.
            raise AssertionError("%s did not give warning: %s( is %s)"
                                 % (func.__name__, warning_class, recorded))
    return result
def assert_warns_message(warning_class, message, func, *args, **kw):
See More Examples