numpy.VisibleDeprecationWarning

Here are examples of the Python API numpy.VisibleDeprecationWarning, taken from open source projects. By voting up you can indicate which examples are most useful and appropriate.

60 Examples
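
Before the individual examples, here is a minimal, self-contained sketch of the pattern most of them share: emit the warning with warnings.warn and check for it with warnings.catch_warnings. This assumes np.VisibleDeprecationWarning is exposed at the top level, which holds for NumPy older than 2.0; since NumPy 1.25 the class also lives in numpy.exceptions, and NumPy 2.0 removed the top-level alias. The legacy_api function is a hypothetical stand-in, not part of any project below.

import warnings
import numpy as np

def legacy_api():
    # Hypothetical deprecated function; it emits the warning the same way
    # the projects below do. VisibleDeprecationWarning subclasses UserWarning,
    # so it is shown by default, unlike a plain DeprecationWarning.
    warnings.warn("legacy_api is deprecated", np.VisibleDeprecationWarning)
    return 42

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    legacy_api()

assert w[0].category is np.VisibleDeprecationWarning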

3 Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby

    def test_header(self):
        # Test retrieving a header
        data = TextIO('gender age weight\nM 64.0 75.0\nF 25.0 60.0')
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.ndfromtxt(data, dtype=None, names=True)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        control = {'gender': np.array([b'M', b'F']),
                   'age': np.array([64.0, 25.0]),
                   'weight': np.array([75.0, 60.0])}
        assert_equal(test['gender'], control['gender'])
        assert_equal(test['age'], control['age'])
        assert_equal(test['weight'], control['weight'])

    def test_auto_dtype(self):

3 Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby

    def test_auto_dtype(self):
        # Test the automatic definition of the output dtype
        data = TextIO('A 64 75.0 3+4j True\nBCD 25 60.0 5+6j False')
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.ndfromtxt(data, dtype=None)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        control = [np.array([b'A', b'BCD']),
                   np.array([64, 25]),
                   np.array([75.0, 60.0]),
                   np.array([3 + 4j, 5 + 6j]),
                   np.array([True, False]), ]
        assert_equal(test.dtype.names, ['f0', 'f1', 'f2', 'f3', 'f4'])
        for (i, ctrl) in enumerate(control):
            assert_equal(test['f%i' % i], ctrl)

    def test_auto_dtype_uniform(self):

3 Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby

    def test_autonames_and_usecols(self):
        # Tests names and usecols
        data = TextIO('A B C D\n aaaa 121 45 9.1')
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.ndfromtxt(data, usecols=('A', 'C', 'D'),
                                names=True, dtype=None)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        control = np.array(('aaaa', 45, 9.1),
                           dtype=[('A', '|S4'), ('C', int), ('D', float)])
        assert_equal(test, control)

    def test_converters_with_usecols(self):

3 Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby

    def test_converters_with_usecols_and_names(self):
        # Tests names and usecols
        data = TextIO('A B C D\n aaaa 121 45 9.1')
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.ndfromtxt(data, usecols=('A', 'C', 'D'), names=True,
                                dtype=None,
                                converters={'C': lambda s: 2 * int(s)})
            assert_(w[0].category is np.VisibleDeprecationWarning)
        control = np.array(('aaaa', 90, 9.1),
                           dtype=[('A', '|S4'), ('C', int), ('D', float)])
        assert_equal(test, control)

    def test_converters_cornercases(self):

3 Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby

    def test_comments_is_none(self):
        # Github issue 329 (None was previously being converted to 'None').
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(TextIO("test1,testNonetherestofthedata"),
                                 dtype=None, comments=None, delimiter=',')
            assert_(w[0].category is np.VisibleDeprecationWarning)
        assert_equal(test[1], b'testNonetherestofthedata')
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(TextIO("test1, testNonetherestofthedata"),
                                 dtype=None, comments=None, delimiter=',')
            assert_(w[0].category is np.VisibleDeprecationWarning)
        assert_equal(test[1], b' testNonetherestofthedata')

    def test_latin1(self):

3 Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby

    def test_utf8_byte_encoding(self):
        utf8 = b"\xcf\x96"
        norm = b"norm1,norm2,norm3\n"
        enc = b"test1,testNonethe" + utf8 + b",test3\n"
        s = norm + enc + norm
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(TextIO(s),
                                 dtype=None, comments=None, delimiter=',')
            assert_(w[0].category is np.VisibleDeprecationWarning)
        ctl = np.array([
                 [b'norm1', b'norm2', b'norm3'],
                 [b'test1', b'testNonethe' + utf8, b'test3'],
                 [b'norm1', b'norm2', b'norm3']])
        assert_array_equal(test, ctl)

    def test_utf8_file(self):

3 Source : test_regression.py
with Apache License 2.0
from aws-samples

    def test_typeNA(self):
        # Issue gh-515
        with suppress_warnings() as sup:
            sup.filter(np.VisibleDeprecationWarning)
            assert_equal(np.typeNA[np.int64], 'Int64')
            assert_equal(np.typeNA[np.uint64], 'UInt64')

    def test_dtype_names(self):

3 Source : test_deprecations.py
with Apache License 2.0
from dashanji

def test_deprecate_ragged_arrays():
    # 2019-11-29 1.19.0
    #
    # NEP 34 deprecated automatic object dtype when creating ragged
    # arrays. Also see the "ragged" tests in `test_multiarray`
    #
    # emits a VisibleDeprecationWarning
    arg = [1, [2, 3]]
    with assert_warns(np.VisibleDeprecationWarning):
        np.array(arg)
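
As a side note to the NEP 34 test above: the supported way to build such a ragged array without the warning is to request the object dtype explicitly. A minimal sketch; the version range in the comments combines the timeline stated in the test (warning since 1.19.0) with the assumption that later releases turned the deprecation into an error.

import numpy as np

# Explicit object dtype: no warning, a 1-d array holding two Python objects.
ragged = np.array([1, [2, 3]], dtype=object)
assert ragged.shape == (2,)

# Without dtype=object, NumPy 1.19-1.23 emits VisibleDeprecationWarning,
# and releases where the deprecation expired raise ValueError instead.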


class TestToString(_DeprecationTestCase):

3 Source : test_io.py
with Apache License 2.0
from dashanji

    def test_header(self):
        # Test retrieving a header
        data = TextIO('gender age weight\nM 64.0 75.0\nF 25.0 60.0')
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(data, dtype=None, names=True)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        control = {'gender': np.array([b'M', b'F']),
                   'age': np.array([64.0, 25.0]),
                   'weight': np.array([75.0, 60.0])}
        assert_equal(test['gender'], control['gender'])
        assert_equal(test['age'], control['age'])
        assert_equal(test['weight'], control['weight'])

    def test_auto_dtype(self):

3 Source : test_io.py
with Apache License 2.0
from dashanji

    def test_auto_dtype(self):
        # Test the automatic definition of the output dtype
        data = TextIO('A 64 75.0 3+4j True\nBCD 25 60.0 5+6j False')
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(data, dtype=None)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        control = [np.array([b'A', b'BCD']),
                   np.array([64, 25]),
                   np.array([75.0, 60.0]),
                   np.array([3 + 4j, 5 + 6j]),
                   np.array([True, False]), ]
        assert_equal(test.dtype.names, ['f0', 'f1', 'f2', 'f3', 'f4'])
        for (i, ctrl) in enumerate(control):
            assert_equal(test['f%i' % i], ctrl)

    def test_auto_dtype_uniform(self):

3 Source : test_io.py
with Apache License 2.0
from dashanji

    def test_autonames_and_usecols(self):
        # Tests names and usecols
        data = TextIO('A B C D\n aaaa 121 45 9.1')
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(data, usecols=('A', 'C', 'D'),
                                names=True, dtype=None)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        control = np.array(('aaaa', 45, 9.1),
                           dtype=[('A', '|S4'), ('C', int), ('D', float)])
        assert_equal(test, control)

    def test_converters_with_usecols(self):

3 Source : test_io.py
with Apache License 2.0
from dashanji

    def test_converters_with_usecols_and_names(self):
        # Tests names and usecols
        data = TextIO('A B C D\n aaaa 121 45 9.1')
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(data, usecols=('A', 'C', 'D'), names=True,
                                dtype=None,
                                converters={'C': lambda s: 2 * int(s)})
            assert_(w[0].category is np.VisibleDeprecationWarning)
        control = np.array(('aaaa', 90, 9.1),
                           dtype=[('A', '|S4'), ('C', int), ('D', float)])
        assert_equal(test, control)

    def test_converters_cornercases(self):

3 Source : test_distributions.py
with Apache License 2.0
from dashanji

def test_rvs_no_size_warning():
    class rvs_no_size_gen(stats.rv_continuous):
        def _rvs(self):
            return 1

    rvs_no_size = rvs_no_size_gen(name='rvs_no_size')

    with assert_warns(np.VisibleDeprecationWarning):
        rvs_no_size.rvs()
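
The warning in the test above comes from SciPy itself: _rvs implementations that accept no size argument are deprecated. Here is a sketch of the non-warning spelling, under the assumption that the installed SciPy passes size and random_state through to _rvs (true for recent releases):

from scipy import stats

class rvs_with_size_gen(stats.rv_continuous):
    # Accepting size and random_state matches the newer _rvs signature,
    # so rvs() no longer goes through the deprecation path.
    def _rvs(self, size=None, random_state=None):
        return random_state.uniform(size=size)

rvs_with_size = rvs_with_size_gen(a=0.0, b=1.0, name='rvs_with_size')
assert rvs_with_size.rvs(size=3).shape == (3,)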

3 Source : test_voxels.py
with GNU General Public License v3.0
from dbbs-lab

    def test_data_ctor(self):
        with self.assertRaises(ValueError):
            VoxelSet([], [], [1])
        with self.assertRaises(ValueError):
            VoxelSet([], [], 1)
        VoxelSet([[1, 2, 3]], [[1, 0, 0]], [1])
        VoxelSet([[1, 2, 3]], 1, [1])
        VoxelSet([[1, 2, 3]], 1, [[1, 2]])
        with self.assertWarns(np.VisibleDeprecationWarning):
            VoxelSet([[1, 2, 3], [0, 0, 0]], 1, [[1, 2], [1]])
        het = VoxelSet([[1, 2, 3], [2, 0, 0]], 1, [[1, 2], [1, "a"]])

    def test_unequal_len(self):

3 Source : test_voxels.py
with GNU General Public License v3.0
from dbbs-lab

    def test_ragged(self):
        with self.assertRaises(ValueError):
            with self.assertWarns(np.VisibleDeprecationWarning):
                ragged = VoxelSet([[1, 0, 1], [1, 1]], [1, 1, 1, 1])

    def test_get_size(self):

3 Source : test_win_probability_simulation.py
with MIT License
from djcunningham0

def test_invalid_scores_to_result_proportions():
    # inconsistent lengths
    with pytest.raises(Exception):
        warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
        scores = np.array([
            [100, 90, 100],
            [95, 105],
            [80, 120, 80],
        ])
        multielo.MultiElo._convert_scores_to_result_proportions(scores)

3 Source : test_deprecations.py
with GNU General Public License v3.0
from dnn-security

def test_deprecate_ragged_arrays():
    # 2019-11-29 1.19.0
    #
    # NEP 34 deprecated automatic object dtype when creating ragged
    # arrays. Also see the "ragged" tests in `test_multiarray`
    #
    # emits a VisibleDeprecationWarning
    arg = [1, [2, 3]]
    with assert_warns(np.VisibleDeprecationWarning):
        np.array(arg)


class TestTooDeepDeprecation(_VisibleDeprecationTestCase):

3 Source : test_distributions.py
with MIT License
from osamhack2021

def test_rvs_no_size_warning():
    class rvs_no_size_gen(stats.rv_continuous):
        def _rvs(self):
            return 1

    rvs_no_size = rvs_no_size_gen(name='rvs_no_size')

    with assert_warns(np.VisibleDeprecationWarning):
        rvs_no_size.rvs()


@pytest.mark.parametrize('distname, args', invdistdiscrete + invdistcont)

3 Source : test_util.py
with GNU General Public License v3.0
from pyxem

    def test_deprecation_no_old_doc(self):
        @deprecated(since=0.7, alternative="bar", removal=0.8)
        def foo(n):
            return n + 1

        with pytest.warns(np.VisibleDeprecationWarning) as record:
            assert foo(4) == 5
        desired_msg = (
            "Function `foo()` is deprecated and will be removed in version 0.8. Use "
            "`bar()` instead."
        )
        assert str(record[0].message) == desired_msg
        assert foo.__doc__ == (
            "[*Deprecated*] \n"
            "\nNotes\n-----\n"
            ".. deprecated:: 0.7\n"
            f"   {desired_msg}"
        )

3 Source : test_distributions.py
with MIT License
from tpike3

def test_rvs_no_size_warning():
    class rvs_no_size_gen(stats.rv_continuous):
        def _rvs(self):
            return 1

    rvs_no_size = rvs_no_size_gen(name='rvs_no_size')

    with assert_warns(np.VisibleDeprecationWarning):
        rvs_no_size.rvs()

0 Source : test_indexing.py
with GNU General Public License v3.0
from adityaprakash-bobby

    def test_multidim(self):
        # Automatically test combinations with complex indexes on 2nd (or 1st)
        # spot and the simple ones in one other spot.
        with warnings.catch_warnings():
            # This is so that np.array(True) is not accepted in a full integer
            # index, when running the file separately.
            warnings.filterwarnings('error', '', DeprecationWarning)
            warnings.filterwarnings('error', '', np.VisibleDeprecationWarning)

            def isskip(idx):
                return isinstance(idx, str) and idx == "skip"

            for simple_pos in [0, 2, 3]:
                tocheck = [self.fill_indices, self.complex_indices,
                           self.fill_indices, self.fill_indices]
                tocheck[simple_pos] = self.simple_indices
                for index in product(*tocheck):
                    index = tuple(i for i in index if not isskip(i))
                    self._check_multi_index(self.a, index)
                    self._check_multi_index(self.b, index)

        # Check very simple item getting:
        self._check_multi_index(self.a, (0, 0, 0, 0))
        self._check_multi_index(self.b, (0, 0, 0, 0))
        # Also check (simple cases of) too many indices:
        assert_raises(IndexError, self.a.__getitem__, (0, 0, 0, 0, 0))
        assert_raises(IndexError, self.a.__setitem__, (0, 0, 0, 0, 0), 0)
        assert_raises(IndexError, self.a.__getitem__, (0, 0, [1], 0, 0))
        assert_raises(IndexError, self.a.__setitem__, (0, 0, [1], 0, 0), 0)

    def test_1d(self):

0 Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby

    def test_commented_header(self):
        # Check that names can be retrieved even if the line is commented out.
        data = TextIO("""
#gender age weight
M   21  72.100000
F   35  58.330000
M   33  21.99
        """)
        # The # is part of the first name and should be deleted automatically.
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(data, names=True, dtype=None)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        ctrl = np.array([('M', 21, 72.1), ('F', 35, 58.33), ('M', 33, 21.99)],
                        dtype=[('gender', '|S1'), ('age', int), ('weight', float)])
        assert_equal(test, ctrl)
        # Ditto, but we should get rid of the first element
        data = TextIO(b"""
# gender age weight
M   21  72.100000
F   35  58.330000
M   33  21.99
        """)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(data, names=True, dtype=None)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        assert_equal(test, ctrl)

    def test_autonames_and_usecols(self):

0 Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby

    def test_autostrip(self):
        # Test autostrip
        data = "01/01/2003  , 1.3,   abcde"
        kwargs = dict(delimiter=",", dtype=None)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            mtest = np.ndfromtxt(TextIO(data), **kwargs)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        ctrl = np.array([('01/01/2003  ', 1.3, '   abcde')],
                        dtype=[('f0', '|S12'), ('f1', float), ('f2', '|S8')])
        assert_equal(mtest, ctrl)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            mtest = np.ndfromtxt(TextIO(data), autostrip=True, **kwargs)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        ctrl = np.array([('01/01/2003', 1.3, 'abcde')],
                        dtype=[('f0', '|S10'), ('f1', float), ('f2', '|S5')])
        assert_equal(mtest, ctrl)

    def test_replace_space(self):

0 Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby

    def test_latin1(self):
        latin1 = b'\xf6\xfc\xf6'
        norm = b"norm1,norm2,norm3\n"
        enc = b"test1,testNonethe" + latin1 + b",test3\n"
        s = norm + enc + norm
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(TextIO(s),
                                 dtype=None, comments=None, delimiter=',')
            assert_(w[0].category is np.VisibleDeprecationWarning)
        assert_equal(test[1, 0], b"test1")
        assert_equal(test[1, 1], b"testNonethe" + latin1)
        assert_equal(test[1, 2], b"test3")
        test = np.genfromtxt(TextIO(s),
                             dtype=None, comments=None, delimiter=',',
                             encoding='latin1')
        assert_equal(test[1, 0], u"test1")
        assert_equal(test[1, 1], u"testNonethe" + latin1.decode('latin1'))
        assert_equal(test[1, 2], u"test3")

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(TextIO(b"0,testNonethe" + latin1),
                                 dtype=None, comments=None, delimiter=',')
            assert_(w[0].category is np.VisibleDeprecationWarning)
        assert_equal(test['f0'], 0)
        assert_equal(test['f1'], b"testNonethe" + latin1)

    def test_binary_decode_autodtype(self):

0 Source : test_io.py
with GNU General Public License v3.0
from adityaprakash-bobby

    def test_utf8_file_nodtype_unicode(self):
        # bytes encoding with non-latin1 -> unicode upcast
        utf8 = u'\u03d6'
        latin1 = u'\xf6\xfc\xf6'

        # skip test if cannot encode utf8 test string with preferred
        # encoding. The preferred encoding is assumed to be the default
        # encoding of io.open. Will need to change this for PyTest, maybe
        # using pytest.mark.xfail(raises=***).
        try:
            import locale
            encoding = locale.getpreferredencoding()
            utf8.encode(encoding)
        except (UnicodeError, ImportError):
            raise SkipTest('Skipping test_utf8_file_nodtype_unicode, '
                           'unable to encode utf8 in preferred encoding') 

        with temppath() as path:
            with io.open(path, "wt") as f:
                f.write(u"norm1,norm2,norm3\n")
                f.write(u"norm1," + latin1 + u",norm3\n")
                f.write(u"test1,testNonethe" + utf8 + u",test3\n")
            with warnings.catch_warnings(record=True) as w:
                warnings.filterwarnings('always', '',
                                        np.VisibleDeprecationWarning)
                test = np.genfromtxt(path, dtype=None, comments=None,
                                     delimiter=',')
                # Check for warning when encoding not specified.
                assert_(w[0].category is np.VisibleDeprecationWarning)
            ctl = np.array([
                     ["norm1", "norm2", "norm3"],
                     ["norm1", latin1, "norm3"],
                     ["test1", "testNonethe" + utf8, "test3"]],
                     dtype=np.unicode)
            assert_array_equal(test, ctl)

    def test_recfromtxt(self):

0 Source : test_histograms.py
with MIT License
from alvarobartt

    def test_normed(self):
        sup = suppress_warnings()
        with sup:
            rec = sup.record(np.VisibleDeprecationWarning, '.*normed.*')
            # Check that the integral of the density equals 1.
            n = 100
            v = np.random.rand(n)
            a, b = histogram(v, normed=True)
            area = np.sum(a * np.diff(b))
            assert_almost_equal(area, 1)
            assert_equal(len(rec), 1)

        sup = suppress_warnings()
        with sup:
            rec = sup.record(np.VisibleDeprecationWarning, '.*normed.*')
            # Check with non-constant bin widths (buggy but backwards
            # compatible)
            v = np.arange(10)
            bins = [0, 1, 5, 9, 10]
            a, b = histogram(v, bins, normed=True)
            area = np.sum(a * np.diff(b))
            assert_almost_equal(area, 1)
            assert_equal(len(rec), 1)

    def test_density(self):
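
The suppression in test_normed above reflects NumPy's deprecation of histogram(..., normed=True), which warns with VisibleDeprecationWarning; the replacement that emits no warning is the density keyword. A minimal sketch:

import numpy as np

v = np.random.rand(100)
counts, edges = np.histogram(v, density=True)  # no VisibleDeprecationWarning
area = np.sum(counts * np.diff(edges))
assert abs(area - 1.0) < 1e-10  # the density integrates to one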

0 Source : test_io.py
with MIT License
from alvarobartt

    def test_commented_header(self):
        # Check that names can be retrieved even if the line is commented out.
        data = TextIO("""
#gender age weight
M   21  72.100000
F   35  58.330000
M   33  21.99
        """)
        # The # is part of the first name and should be deleted automatically.
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(data, names=True, dtype=None)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        ctrl = np.array([('M', 21, 72.1), ('F', 35, 58.33), ('M', 33, 21.99)],
                        dtype=[('gender', '|S1'), ('age', int), ('weight', float)])
        assert_equal(test, ctrl)
        # Ditto, but we should get rid of the first element
        data = TextIO(b"""
# gender age weight
M   21  72.100000
F   35  58.330000
M   33  21.99
        """)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            test = np.genfromtxt(data, names=True, dtype=None)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        assert_equal(test, ctrl)

    def test_names_and_comments_none(self):

0 Source : test_io.py
with MIT License
from alvarobartt

    def test_utf8_file_nodtype_unicode(self):
        # bytes encoding with non-latin1 -> unicode upcast
        utf8 = u'\u03d6'
        latin1 = u'\xf6\xfc\xf6'

        # skip test if cannot encode utf8 test string with preferred
        # encoding. The preferred encoding is assumed to be the default
        # encoding of io.open. Will need to change this for PyTest, maybe
        # using pytest.mark.xfail(raises=***).
        try:
            encoding = locale.getpreferredencoding()
            utf8.encode(encoding)
        except (UnicodeError, ImportError):
            raise SkipTest('Skipping test_utf8_file_nodtype_unicode, '
                           'unable to encode utf8 in preferred encoding')

        with temppath() as path:
            with io.open(path, "wt") as f:
                f.write(u"norm1,norm2,norm3\n")
                f.write(u"norm1," + latin1 + u",norm3\n")
                f.write(u"test1,testNonethe" + utf8 + u",test3\n")
            with warnings.catch_warnings(record=True) as w:
                warnings.filterwarnings('always', '',
                                        np.VisibleDeprecationWarning)
                test = np.genfromtxt(path, dtype=None, comments=None,
                                     delimiter=',')
                # Check for warning when encoding not specified.
                assert_(w[0].category is np.VisibleDeprecationWarning)
            ctl = np.array([
                     ["norm1", "norm2", "norm3"],
                     ["norm1", latin1, "norm3"],
                     ["test1", "testNonethe" + utf8, "test3"]],
                     dtype=np.unicode)
            assert_array_equal(test, ctl)

    def test_recfromtxt(self):

0 Source : testing.py
with MIT License
from alvarobartt

def assert_warns(warning_class, func, *args, **kw):
    """Test that a certain warning occurs.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`

    Returns
    -------

    result : the return value of `func`

    """
    # very important to avoid uncontrolled state propagation
    clean_warning_registry()
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        # Trigger a warning.
        result = func(*args, **kw)
        if hasattr(np, 'VisibleDeprecationWarning'):
            # Filter out numpy-specific warnings in numpy >= 1.9
            w = [e for e in w
                 if e.category is not np.VisibleDeprecationWarning]

        # Verify some things
        if not len(w) > 0:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)

        found = any(warning.category is warning_class for warning in w)
        if not found:
            raise AssertionError("%s did not give warning: %s( is %s)"
                                 % (func.__name__, warning_class, w))
    return result


def assert_warns_message(warning_class, message, func, *args, **kw):

0 Source : testing.py
with MIT License
from alvarobartt

def assert_warns_message(warning_class, message, func, *args, **kw):
    # very important to avoid uncontrolled state propagation
    """Test that a certain warning occurs and with a certain message.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning.

    message : str | callable
        The entire message or a substring to test for. If callable,
        it takes a string as argument and will trigger an assertion error
        if it returns `False`.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`.

    Returns
    -------

    result : the return value of `func`

    """
    clean_warning_registry()
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        if hasattr(np, 'VisibleDeprecationWarning'):
            # Let's not catch the numpy internal DeprecationWarnings
            warnings.simplefilter('ignore', np.VisibleDeprecationWarning)
        # Trigger a warning.
        result = func(*args, **kw)
        # Verify some things
        if not len(w) > 0:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)

        found = [issubclass(warning.category, warning_class) for warning in w]
        if not any(found):
            raise AssertionError("No warning raised for %s with class "
                                 "%s"
                                 % (func.__name__, warning_class))

        message_found = False
        # Check the message of each warning that belongs to warning_class
        for index in [i for i, x in enumerate(found) if x]:
            # substring will match, the entire message with typo won't
            msg = w[index].message  # For Python 3 compatibility
            msg = str(msg.args[0] if hasattr(msg, 'args') else msg)
            if callable(message):  # add support for certain tests
                check_in_message = message
            else:
                check_in_message = lambda msg: message in msg

            if check_in_message(msg):
                message_found = True
                break

        if not message_found:
            raise AssertionError("Did not receive the message you expected "
                                 "('%s') for   <  %s>, got: '%s'"
                                 % (message, func.__name__, msg))

    return result


# To remove when we support numpy 1.7
def assert_no_warnings(func, *args, **kw):

0 Source : testing.py
with MIT License
from alvarobartt

def assert_no_warnings(func, *args, **kw):
    # very important to avoid uncontrolled state propagation
    clean_warning_registry()
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')

        result = func(*args, **kw)
        if hasattr(np, 'VisibleDeprecationWarning'):
            # Filter out numpy-specific warnings in numpy >= 1.9
            w = [e for e in w
                 if e.category is not np.VisibleDeprecationWarning]

        if len(w) > 0:
            raise AssertionError("Got warnings when calling %s: [%s]"
                                 % (func.__name__,
                                    ', '.join(str(warning) for warning in w)))
    return result


def ignore_warnings(obj=None, category=Warning):

0 Source : test_io.py
with Apache License 2.0
from aws-samples

    def test_utf8_file_nodtype_unicode(self):
        # bytes encoding with non-latin1 -> unicode upcast
        utf8 = u'\u03d6'
        latin1 = u'\xf6\xfc\xf6'

        # skip test if cannot encode utf8 test string with preferred
        # encoding. The preferred encoding is assumed to be the default
        # encoding of io.open. Will need to change this for PyTest, maybe
        # using pytest.mark.xfail(raises=***).
        try:
            encoding = locale.getpreferredencoding()
            utf8.encode(encoding)
        except (UnicodeError, ImportError):
            pytest.skip('Skipping test_utf8_file_nodtype_unicode, '
                        'unable to encode utf8 in preferred encoding')

        with temppath() as path:
            with io.open(path, "wt") as f:
                f.write(u"norm1,norm2,norm3\n")
                f.write(u"norm1," + latin1 + u",norm3\n")
                f.write(u"test1,testNonethe" + utf8 + u",test3\n")
            with warnings.catch_warnings(record=True) as w:
                warnings.filterwarnings('always', '',
                                        np.VisibleDeprecationWarning)
                test = np.genfromtxt(path, dtype=None, comments=None,
                                     delimiter=',')
                # Check for warning when encoding not specified.
                assert_(w[0].category is np.VisibleDeprecationWarning)
            ctl = np.array([
                     ["norm1", "norm2", "norm3"],
                     ["norm1", latin1, "norm3"],
                     ["test1", "testNonethe" + utf8, "test3"]],
                     dtype=np.unicode)
            assert_array_equal(test, ctl)

    def test_recfromtxt(self):

0 Source : __init__.py
with Apache License 2.0
from dashanji

def index_of(y):
    """
    A helper function to create reasonable x values for the given *y*.

    This is used for plotting (x, y) if x values are not explicitly given.

    First try ``y.index`` (assuming *y* is a `pandas.Series`), if that
    fails, use ``range(len(y))``.

    This will be extended in the future to deal with more types of
    labeled data.

    Parameters
    ----------
    y : float or array-like

    Returns
    -------
    x, y : ndarray
       The x and y values to plot.
    """
    try:
        return y.index.values, y.values
    except AttributeError:
        pass
    try:
        y = _check_1d(y)
    except (np.VisibleDeprecationWarning, ValueError):
        # NumPy 1.19 will warn on ragged input, and we can't actually use it.
        pass
    else:
        return np.arange(y.shape[0], dtype=float), y
    raise ValueError('Input could not be cast to an at-least-1D NumPy array')


def safe_first_element(obj):

0 Source : test_io.py
with Apache License 2.0
from dashanji

    def test_autostrip(self):
        # Test autostrip
        data = "01/01/2003  , 1.3,   abcde"
        kwargs = dict(delimiter=",", dtype=None)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            mtest = np.genfromtxt(TextIO(data), **kwargs)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        ctrl = np.array([('01/01/2003  ', 1.3, '   abcde')],
                        dtype=[('f0', '|S12'), ('f1', float), ('f2', '|S8')])
        assert_equal(mtest, ctrl)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', '', np.VisibleDeprecationWarning)
            mtest = np.genfromtxt(TextIO(data), autostrip=True, **kwargs)
            assert_(w[0].category is np.VisibleDeprecationWarning)
        ctrl = np.array([('01/01/2003', 1.3, 'abcde')],
                        dtype=[('f0', '|S10'), ('f1', float), ('f2', '|S5')])
        assert_equal(mtest, ctrl)

    def test_replace_space(self):

0 Source : test_io.py
with Apache License 2.0
from dashanji

    def test_utf8_file_nodtype_unicode(self):
        # bytes encoding with non-latin1 -> unicode upcast
        utf8 = u'\u03d6'
        latin1 = u'\xf6\xfc\xf6'

        # skip test if cannot encode utf8 test string with preferred
        # encoding. The preferred encoding is assumed to be the default
        # encoding of io.open. Will need to change this for PyTest, maybe
        # using pytest.mark.xfail(raises=***).
        try:
            encoding = locale.getpreferredencoding()
            utf8.encode(encoding)
        except (UnicodeError, ImportError):
            pytest.skip('Skipping test_utf8_file_nodtype_unicode, '
                        'unable to encode utf8 in preferred encoding')

        with temppath() as path:
            with io.open(path, "wt") as f:
                f.write(u"norm1,norm2,norm3\n")
                f.write(u"norm1," + latin1 + u",norm3\n")
                f.write(u"test1,testNonethe" + utf8 + u",test3\n")
            with warnings.catch_warnings(record=True) as w:
                warnings.filterwarnings('always', '',
                                        np.VisibleDeprecationWarning)
                test = np.genfromtxt(path, dtype=None, comments=None,
                                     delimiter=',')
                # Check for warning when encoding not specified.
                assert_(w[0].category is np.VisibleDeprecationWarning)
            ctl = np.array([
                     ["norm1", "norm2", "norm3"],
                     ["norm1", latin1, "norm3"],
                     ["test1", "testNonethe" + utf8, "test3"]],
                     dtype=np.unicode_)
            assert_array_equal(test, ctl)

    def test_recfromtxt(self):

0 Source : test_distributions.py
with Apache License 2.0
from dashanji

    def test_weights(self, rng):

        import warnings
        warnings.simplefilter("error", np.VisibleDeprecationWarning)

        n = 100
        x, y = rng.multivariate_normal([1, 3], [(.2, .5), (.5, 2)], n).T
        hue = np.repeat([0, 1], n // 2)
        weights = rng.uniform(0, 1, n)

        f, (ax1, ax2) = plt.subplots(ncols=2)
        kdeplot(x=x, y=y, hue=hue, ax=ax1)
        kdeplot(x=x, y=y, hue=hue, weights=weights, ax=ax2)

        for c1, c2 in zip(ax1.collections, ax2.collections):
            if c1.get_segments() and c2.get_segments():
                seg1 = np.concatenate(c1.get_segments(), axis=0)
                seg2 = np.concatenate(c2.get_segments(), axis=0)
                assert not np.array_equal(seg1, seg2)

    def test_hue_ignores_cmap(self, long_df):

0 Source : _testing.py
with Apache License 2.0
from dashanji

def assert_warns(warning_class, func, *args, **kw):
    """Test that a certain warning occurs.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`

    Returns
    -------

    result : the return value of `func`

    """
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        # Trigger a warning.
        result = func(*args, **kw)
        if hasattr(np, 'FutureWarning'):
            # Filter out numpy-specific warnings in numpy >= 1.9
            w = [e for e in w
                 if e.category is not np.VisibleDeprecationWarning]

        # Verify some things
        if not len(w) > 0:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)

        found = any(warning.category is warning_class for warning in w)
        if not found:
            raise AssertionError("%s did not give warning: %s( is %s)"
                                 % (func.__name__, warning_class, w))
    return result


def assert_warns_message(warning_class, message, func, *args, **kw):

0 Source : _testing.py
with Apache License 2.0
from dashanji

def assert_warns_message(warning_class, message, func, *args, **kw):
    # very important to avoid uncontrolled state propagation
    """Test that a certain warning occurs and with a certain message.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning.

    message : str | callable
        The message or a substring of the message to test for. If callable,
        it takes a string as the argument and will trigger an AssertionError
        if the callable returns `False`.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`.

    Returns
    -------
    result : the return value of `func`

    """
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        if hasattr(np, 'FutureWarning'):
            # Let's not catch the numpy internal DeprecationWarnings
            warnings.simplefilter('ignore', np.VisibleDeprecationWarning)
        # Trigger a warning.
        result = func(*args, **kw)
        # Verify some things
        if not len(w) > 0:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)

        found = [issubclass(warning.category, warning_class) for warning in w]
        if not any(found):
            raise AssertionError("No warning raised for %s with class "
                                 "%s"
                                 % (func.__name__, warning_class))

        message_found = False
        # Check the message of each warning that belongs to warning_class
        for index in [i for i, x in enumerate(found) if x]:
            # substring will match, the entire message with typo won't
            msg = w[index].message  # For Python 3 compatibility
            msg = str(msg.args[0] if hasattr(msg, 'args') else msg)
            if callable(message):  # add support for certain tests
                check_in_message = message
            else:
                def check_in_message(msg): return message in msg

            if check_in_message(msg):
                message_found = True
                break

        if not message_found:
            raise AssertionError("Did not receive the message you expected "
                                 "('%s') for   <  %s>, got: '%s'"
                                 % (message, func.__name__, msg))

    return result


def assert_warns_div0(func, *args, **kw):

0 Source : _testing.py
with Apache License 2.0
from dashanji

def assert_no_warnings(func, *args, **kw):
    """
    Parameters
    ----------
    func
    *args
    **kw
    """
    # very important to avoid uncontrolled state propagation
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')

        result = func(*args, **kw)
        if hasattr(np, 'FutureWarning'):
            # Filter out numpy-specific warnings in numpy >= 1.9
            w = [e for e in w
                 if e.category is not np.VisibleDeprecationWarning]

        if len(w) > 0:
            raise AssertionError("Got warnings when calling %s: [%s]"
                                 % (func.__name__,
                                    ', '.join(str(warning) for warning in w)))
    return result


def ignore_warnings(obj=None, category=Warning):

0 Source : compute_delta_var.py
with MIT License
from DPBayes

def get_delta_R(
        sigma_t: np.ndarray,
        q_t: np.ndarray,
        k: np.ndarray,
        target_eps: float = 1.0,
        nx: int = int(1E6),
        L: float = 20.0
    ):
    """
    Computes the DP delta for the remove/add neighbouring relation of datasets.

    The computed delta privacy value is for the composition of DP operations
    as specified by `sigma_t`, `q_t` and `k`, where `sigma_t` and `q_t` specify
    privacy noise and subsampling ratio for each operation and `k` is the number
    of repetitions, i.e.,
    - `k[0]` operations with privacy noise `sigma_t[0]` and subsampling ratio `q_t[0]`
    - `k[1]` operations with privacy noise `sigma_t[1]` and subsampling ratio `q_t[1]`
    - etc
    for a total of `np.sum(k)` operations.

    Note that this function relies on numerical approximations, which are influenced
    by choice of parameters nx and L. Increasing L roughly increases the range over
    which the integral of the privacy loss distribution is approximated, while nx is
    the number of evaluation points in [-L,L]. If you find results output by this function
    to be inaccurate, try adjusting these parameters. Refer to [1] for more details.

    Parameters:
        sigma_t (np.ndarray(float)): Privacy noise sigma for composed DP operations
        q_t (np.ndarray(float)): Subsampling ratios, i.e., how large are batches relative to the dataset
        k (np.ndarray(int)): Repetitions for each value in `sigma_t` and `q_t`
        target_eps (float): Target epsilon
        nx (int): Number of discretisation points
        L (float):  Limit for the approximation of the privacy loss distribution integral

    Returns:
        (float): delta value

    References:
        Antti Koskela, Joonas Jälkö, Antti Honkela:
        Computing Tight Differential Privacy Guarantees Using FFT
            https://arxiv.org/abs/1906.03049
    """
    warnings.warn("DEPRECATED FUNCTION! Use fourier_accountant.get_delta_R instead.", np.VisibleDeprecationWarning)
    return get_delta_R_new(target_eps, sigma_t, q_t, k, nx, L)


def get_delta_S(

0 Source : compute_delta_var.py
with MIT License
from DPBayes

def get_delta_S(
        sigma_t: np.ndarray,
        q_t: np.ndarray,
        k: np.ndarray,
        target_eps: float = 1.0,
        nx: int = int(1E6),
        L: float = 20.0
    ):
    """
    Computes the DP delta for the substitute neighbouring relation of datasets.

    The computed delta privacy value is for the composition of DP operations
    as specified by `sigma_t`, `q_t` and `k`, where `sigma_t` and `q_t` specify
    privacy noise and subsampling ratio for each operation and `k` is the number
    of repetitions, i.e.,
    - `k[0]` operations with privacy noise `sigma_t[0]` and subsampling ratio `q_t[0]`
    - `k[1]` operations with privacy noise `sigma_t[1]` and subsampling ratio `q_t[1]`
    - etc
    for a total of `np.sum(k)` operations.

    Note that this function relies on numerical approximations, which are influenced
    by choice of parameters nx and L. Increasing L roughly increases the range over
    which the integral of the privacy loss distribution is approximated, while nx is
    the number of evaluation points in [-L,L]. If you find results output by this function
    to be inaccurate, try adjusting these parameters. Refer to [1] for more details.

    Parameters:
        sigma_t (np.ndarray(float)): Privacy noise sigma for composed DP operations
        q_t (np.ndarray(float)): Subsampling ratios, i.e., how large are batches relative to the dataset
        k (np.ndarray(int)): Repetitions for each value in `sigma_t` and `q_t`
        target_eps (float): Target epsilon
        nx (int): Number of discretisation points
        L (float):  Limit for the approximation of the privacy loss distribution integral

    Returns:
        (float): delta value

    References:
        Antti Koskela, Joonas Jälkö, Antti Honkela:
        Computing Tight Differential Privacy Guarantees Using FFT
            https://arxiv.org/abs/1906.03049
    """
    warnings.warn("DEPRECATED FUNCTION! Use fourier_accountant.get_delta_S instead.", np.VisibleDeprecationWarning)
    return get_delta_S_new(target_eps, sigma_t, q_t, k, nx, L)

0 Source : compute_eps_var.py
with MIT License
from DPBayes

def get_epsilon_R(
        sigma_t: np.ndarray,
        q_t: np.ndarray,
        k: np.ndarray,
        target_delta: float = 1e-6,
        nx: int = int(1E6),
        L: float = 20.0
    ):
    """
    Computes the DP epsilon for the remove/add neighbouring relation of datasets.

    The computed epsilon privacy value is for the composition of DP operations
    as specified by `sigma_t`, `q_t` and `k`, where `sigma_t` and `q_t` specify
    privacy noise and subsampling ratio for each operation and `k` is the number
    of repetitions, i.e.,
    - `k[0]` operations with privacy noise `sigma_t[0]` and subsampling ratio `q_t[0]`
    - `k[1]` operations with privacy noise `sigma_t[1]` and subsampling ratio `q_t[1]`
    - etc
    for a total of `np.sum(k)` operations.

    Note that this function relies on numerical approximations, which are influenced
    by choice of parameters nx and L. Increasing L roughly increases the range over
    which the integral of the privacy loss distribution is approximated. L must be chosen
    large enough to cover the computed epsilon, otherwise a ValueError is raised. Try
    increasing L if this happens.

    nx is the number of evaluation points in [-L,L]. If you find results output by this
    function to be inaccurate, try adjusting these parameters. Refer to [1] for more details.
    Due to numerical instabilities, corner cases exist where this function sometimes returns
    inaccurate values. If you think this is occurring, increasing nx and verifying that
    the returned value does not change by much is usually a good heuristic to verify the output.

    Parameters:
        sigma_t (np.ndarray(float)): Privacy noise sigma for composed DP operations
        q_t (np.ndarray(float)): Subsampling ratios, i.e., how large are batches relative to the dataset
        k (np.ndarray(int)): Repetitions for each value in `sigma_t` and `q_t`
        target_delta (float): Target delta
        nx (int): Number of discretisation points
        L (float):  Limit for the approximation of the privacy loss distribution integral

    Returns:
        (float): epsilon value

    References:
        Antti Koskela, Joonas Jälkö, Antti Honkela:
        Computing Tight Differential Privacy Guarantees Using FFT
            https://arxiv.org/abs/1906.03049
    """
    warnings.warn("DEPRECATED FUNCTION! Use fourier_accountant.get_epsilon_R instead.", np.VisibleDeprecationWarning)
    return get_epsilon_R_new(target_delta, sigma_t, q_t, k, nx, L)

def get_epsilon_S(

0 Source : compute_eps_var.py
with MIT License
from DPBayes

def get_epsilon_S(
        sigma_t: np.ndarray,
        q_t: np.ndarray,
        k: np.ndarray,
        target_delta: float = 1e-6,
        nx: int = int(1E6),
        L: float = 20.0
    ):
    """
    Computes the DP epsilon for the substitute neighbouring relation of datasets.

    The computed epsilon privacy value is for the composition of DP operations
    as specified by `sigma_t`, `q_t` and `k`, where `sigma_t` and `q_t` specify
    privacy noise and subsampling ratio for each operation and `k` is the number
    of repetitions, i.e.,
    - `k[0]` operations with privacy noise `sigma_t[0]` and subsampling ratio `q_t[0]`
    - `k[1]` operations with privacy noise `sigma_t[1]` and subsampling ratio `q_t[1]`
    - etc
    for a total of `np.sum(k)` operations.

    Note that this function relies on numerical approximations, which are influenced
    by choice of parameters nx and L. Increasing L roughly increases the range over
    which the integral of the privacy loss distribution is approximated. L must be chosen
    large enough to cover the computed epsilon, otherwise a ValueError is raised. Try
    increasing L if this happens.

    nx is the number of evaluation points in [-L,L]. If you find results output by this
    function to be inaccurate, try adjusting these parameters. Refer to [1] for more details.
    Due to numerical instabilities, corner cases exist where this function sometimes returns
    inaccurate values. If you think this is occurring, increasing nx and verifying that
    the returned value does not change by much is usually a good heuristic to verify the output.

    Parameters:
        sigma_t (np.ndarray(float)): Privacy noise sigma for composed DP operations
        q_t (np.ndarray(float)): Subsampling ratios, i.e., how large are batches relative to the dataset
        k (np.ndarray(int)): Repetitions for each value in `sigma_t` and `q_t`
        target_delta (float): Target delta
        nx (int): Number of discretisation points
        L (float):  Limit for the approximation of the privacy loss distribution integral

    Returns:
        (float): epsilon value

    References:
        Antti Koskela, Joonas Jälkö, Antti Honkela:
        Computing Tight Differential Privacy Guarantees Using FFT
            https://arxiv.org/abs/1906.03049
    """
    warnings.warn("DEPRECATED FUNCTION! Use fourier_accountant.get_epsilon_S instead.", np.VisibleDeprecationWarning)
    return get_epsilon_S_new(target_delta, sigma_t, q_t, k, nx, L)

0 Source : arrow_dataset.py
with Apache License 2.0
from ExpressAI

    def to_tf_dataset(
        self,
        columns: Union[str, List[str]],
        batch_size: int,
        shuffle: bool,
        drop_remainder: bool = None,
        collate_fn: Callable = None,
        collate_fn_args: Dict[str, Any] = None,
        label_cols: Union[str, List[str]] = None,
        dummy_labels: bool = False,
        prefetch: bool = True,
    ):
        """Create a tf.data.Dataset from the underlying Dataset. This tf.data.Dataset will load and collate batches from
        the Dataset, and is suitable for passing to methods like model.fit() or model.predict().

        Args:
            columns (:obj:`List[str]` or :obj:`str`): Dataset column(s) to load in the tf.data.Dataset. In general,
            only columns that the model can use as input should be included here (numeric data only).
            batch_size (:obj:`int`): Size of batches to load from the dataset.
            shuffle(:obj:`bool`): Shuffle the dataset order when loading. Recommended True for training, False for
                validation/evaluation.
            drop_remainder(:obj:`bool`, default ``None``): Drop the last incomplete batch when loading. If not provided,
                defaults to the same setting as shuffle.
            collate_fn(:obj:`Callable`): A function or callable object (such as a `DataCollator`) that will collate
                lists of samples into a batch.
            collate_fn_args (:obj:`Dict`, optional): An optional `dict` of keyword arguments to be passed to the
                `collate_fn`.
            label_cols (:obj:`List[str]` or :obj:`str`, default ``None``): Dataset column(s) to load as
                labels. Note that many models compute loss internally rather than letting Keras do it, in which case it is
                not necessary to actually pass the labels here, as long as they're in the input `columns`.
            dummy_labels (:obj:`bool`, default ``False``): If no `label_cols` are set, output an array of "dummy" labels
                with each batch. This can avoid problems with `fit()` or `train_on_batch()` that expect labels to be
                a Tensor or np.ndarray, but should (hopefully) not be necessary with our standard train_step().
            prefetch (:obj:`bool`, default ``True``): Whether to run the dataloader in a separate thread and maintain
                a small buffer of batches for training. Improves performance by allowing data to be loaded in the
                background while the model is training.
        """

        # TODO There is some hacky hardcoding in this function that needs to be fixed.
        #      We're planning to rework it so less code is needed at the start to remove columns before
        #      we know the final list of fields (post-data collator). This should clean up most of the special
        #      casing while retaining the API.
        if config.TF_AVAILABLE:
            import tensorflow as tf
        else:
            raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.")

        if collate_fn_args is None:
            collate_fn_args = {}

        if label_cols is None:
            label_cols = []
        elif isinstance(label_cols, str):
            label_cols = [label_cols]
        elif len(set(label_cols)) < len(label_cols):
            raise ValueError("List of label_cols contains duplicates.")
        if not columns:
            raise ValueError("Need to specify at least one column.")
        elif isinstance(columns, str):
            columns = [columns]
        elif len(set(columns)) < len(columns):
            raise ValueError("List of columns contains duplicates.")
        if label_cols is not None:
            cols_to_retain = list(set(columns + label_cols))
        else:
            cols_to_retain = columns
        # Special casing when the dataset has 'label' and the model expects 'labels' and the collator fixes it up for us
        if "labels" in cols_to_retain and "labels" not in self.features and "label" in self.features:
            cols_to_retain[cols_to_retain.index("labels")] = "label"
        # Watch for nonexistent columns, except those that the data collators add for us
        for col in cols_to_retain:
            if col not in self.features and not (col in ("attention_mask", "labels") and collate_fn is not None):
                raise ValueError(f"Couldn't find column {col} in dataset.")
        if drop_remainder is None:
            # We assume that if you're shuffling it's the train set, so we drop the remainder unless told not to
            drop_remainder = shuffle
        dataset = self.with_format("python", columns=[col for col in cols_to_retain if col in self.features])

        def numpy_pad(data):
            try:
                # When this is finally fully removed, remove this line
                # Alternatively, find a more elegant way to do this whole thing
                np.warnings.filterwarnings("error", category=np.VisibleDeprecationWarning)
                data = np.array(data)
                if data.dtype == np.object:
                    raise AssertionError  # Do it this way so that the assert doesn't get optimized out
                return data
            except (np.VisibleDeprecationWarning, AssertionError):
                pass
            # Get lengths of each row of data
            lens = np.array([len(i) for i in data])

            # Mask of valid places in each row
            mask = np.arange(lens.max()) < lens[:, None]

            # Setup output array and put elements from data into masked positions
            out = np.zeros(mask.shape, dtype=np.array(data[0]).dtype)
            out[mask] = np.concatenate(data)
            return out

        def np_get_batch(indices):
            batch = dataset[indices]
            out_batch = []
            if collate_fn is not None:
                actual_size = len(list(batch.values())[0])  # Get the length of one of the arrays, assume all same
                # Our collators expect a list of dicts, not a dict of lists/arrays, so we invert
                batch = [{key: value[i] for key, value in batch.items()} for i in range(actual_size)]
                batch = collate_fn(batch, **collate_fn_args)
                # Special casing when the dataset has 'label' and the model
                # expects 'labels' and the collator fixes it up for us
                if "label" in cols_to_retain and "label" not in batch and "labels" in batch:
                    cols_to_retain[cols_to_retain.index("label")] = "labels"
                for key in cols_to_retain:
                    # In case the collate_fn returns something strange
                    array = np.array(batch[key])
                    cast_dtype = np.int64 if np.issubdtype(array.dtype, np.integer) else np.float32
                    array = array.astype(cast_dtype)
                    out_batch.append(array)
            else:
                for key in cols_to_retain:
                    array = batch[key]
                    array = numpy_pad(array)
                    cast_dtype = np.int64 if np.issubdtype(array.dtype, np.integer) else np.float32
                    array = array.astype(cast_dtype)
                    out_batch.append(array)
            return [tf.convert_to_tensor(arr) for arr in out_batch]

        test_batch = np_get_batch(np.arange(batch_size))

        @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)])
        def fetch_function(indices):
            output = tf.numpy_function(
                np_get_batch, inp=[indices], Tout=[tf.dtypes.as_dtype(arr.dtype) for arr in test_batch]
            )
            return {key: output[i] for i, key in enumerate(cols_to_retain)}

        test_batch_dict = {key: test_batch[i] for i, key in enumerate(cols_to_retain)}
        output_signature = TensorflowDatasetMixin._get_output_signature(
            dataset, cols_to_retain, test_batch_dict, batch_size=batch_size if drop_remainder else None
        )

        def ensure_shapes(input_dict):
            return {key: tf.ensure_shape(val, output_signature[key].shape) for key, val in input_dict.items()}

        tf_dataset = tf.data.Dataset.from_tensor_slices(np.arange(len(dataset), dtype=np.int64))

        if shuffle:
            tf_dataset = tf_dataset.shuffle(len(dataset))

        tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder).map(fetch_function).map(ensure_shapes)

        if label_cols:

            def split_features_and_labels(input_batch):
                features = {key: tensor for key, tensor in input_batch.items() if key in columns}
                labels = {key: tensor for key, tensor in input_batch.items() if key in label_cols}
                if len(features) == 1:
                    features = list(features.values())[0]
                if len(labels) == 1:
                    labels = list(labels.values())[0]
                return features, labels

            tf_dataset = tf_dataset.map(split_features_and_labels)

        elif len(columns) == 1:
            tf_dataset = tf_dataset.map(lambda x: list(x.values())[0])

        if dummy_labels and not label_cols:

            def add_dummy_labels(input_batch):
                return input_batch, tf.zeros(tf.shape(input_batch[columns[0]])[0])

            tf_dataset = tf_dataset.map(add_dummy_labels)

        def rename_label_col(inputs, labels=None):
            if not isinstance(inputs, tf.Tensor):
                if "label" in inputs:
                    inputs["labels"] = inputs["label"]
                    del inputs["label"]
            if labels is None:
                return inputs
            else:
                return inputs, labels

        tf_dataset = tf_dataset.map(rename_label_col)

        if prefetch:
            tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE)

        # Remove a reference to the open Arrow file on delete
        def cleanup_callback(ref):
            dataset.__del__()
            self._TF_DATASET_REFS.remove(ref)

        self._TF_DATASET_REFS.add(weakref.ref(tf_dataset, cleanup_callback))
        return tf_dataset


class DatasetTransformationNotAllowedError(Exception):
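
The `numpy_pad` helper above leans on a numpy transition quirk: before NEP 34 was fully enforced (numpy < 1.24), building an array from rows of unequal length emitted np.VisibleDeprecationWarning, which the helper escalates to an error in order to detect when padding is needed. Below is a minimal standalone sketch of that pattern, assuming numpy < 1.24 and using the stdlib warnings module instead of the deprecated np.warnings alias; `pad_ragged` is an invented name for illustration.

import warnings
import numpy as np

def pad_ragged(data):
    # Escalate the ragged-array warning so np.array() fails fast on rows of
    # unequal length (numpy < 1.24; newer releases raise ValueError instead).
    with warnings.catch_warnings():
        warnings.filterwarnings("error", category=np.VisibleDeprecationWarning)
        try:
            arr = np.array(data)
            if arr.dtype != object:
                return arr  # already rectangular, nothing to pad
        except np.VisibleDeprecationWarning:
            pass
    # Zero-pad every row out to the length of the longest row.
    lens = np.array([len(row) for row in data])
    mask = np.arange(lens.max()) < lens[:, None]
    out = np.zeros(mask.shape, dtype=np.asarray(data[0]).dtype)
    out[mask] = np.concatenate(data)
    return out

print(pad_ragged([[1, 2, 3], [4, 5]]))  # [[1 2 3] [4 5 0]]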

0 Source : _testing.py
with GNU General Public License v3.0
from gustavowillam

def assert_warns_message(warning_class, message, func, *args, **kw):
    # very important to avoid uncontrolled state propagation
    """Test that a certain warning occurs and with a certain message.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning.

    message : str | callable
        The message or a substring of the message to test for. If callable,
        it takes a string as the argument and will trigger an AssertionError
        if the callable returns `False`.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`.

    Returns
    -------
    result : the return value of `func`

    """
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        if hasattr(np, 'VisibleDeprecationWarning'):
            # Let's not catch the numpy internal DeprecationWarnings
            warnings.simplefilter('ignore', np.VisibleDeprecationWarning)
        # Trigger a warning.
        result = func(*args, **kw)
        # Verify some things
        if not len(w) > 0:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)

        found = [issubclass(warning.category, warning_class) for warning in w]
        if not any(found):
            raise AssertionError("No warning raised for %s with class "
                                 "%s"
                                 % (func.__name__, warning_class))

        message_found = False
        # Check the messages of all warnings that belong to warning_class
        for index in [i for i, x in enumerate(found) if x]:
            # substring will match, the entire message with typo won't
            msg = w[index].message  # For Python 3 compatibility
            msg = str(msg.args[0] if hasattr(msg, 'args') else msg)
            if callable(message):  # add support for certain tests
                check_in_message = message
            else:
                check_in_message = lambda msg: message in msg

            if check_in_message(msg):
                message_found = True
                break

        if not message_found:
            raise AssertionError("Did not receive the message you expected "
                                 "('%s') for   <  %s>, got: '%s'"
                                 % (message, func.__name__, msg))

    return result


def assert_warns_div0(func, *args, **kw):
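
A hypothetical usage sketch for the helper above: the call is asserted to emit a UserWarning whose message contains the given substring, while numpy-internal np.VisibleDeprecationWarning noise is ignored rather than counted. The `warn_and_return` function is invented for illustration.

import warnings

def warn_and_return():
    warnings.warn("alpha is deprecated; use beta instead", UserWarning)
    return 42

# Passes: a UserWarning containing the substring was emitted,
# and the return value of the wrapped callable is passed through.
result = assert_warns_message(UserWarning, "alpha is deprecated", warn_and_return)
assert result == 42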

0 Source : multiclass.py
with GNU General Public License v3.0
from gustavowillam

def is_multilabel(y):
    """ Check if ``y`` is in a multilabel format.

    Parameters
    ----------
    y : ndarray of shape (n_samples,)
        Target values.

    Returns
    -------
    out : bool
        Return ``True``, if ``y`` is in a multilabel format, else ``False``.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.utils.multiclass import is_multilabel
    >>> is_multilabel([0, 1, 0, 1])
    False
    >>> is_multilabel([[1], [0, 2], []])
    False
    >>> is_multilabel(np.array([[1, 0], [0, 0]]))
    True
    >>> is_multilabel(np.array([[1], [0], [0]]))
    False
    >>> is_multilabel(np.array([[1, 0, 0]]))
    True
    """
    if hasattr(y, '__array__') or isinstance(y, Sequence):
        # DeprecationWarning will be replaced by ValueError, see NEP 34
        # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
        with warnings.catch_warnings():
            warnings.simplefilter('error', np.VisibleDeprecationWarning)
            try:
                y = np.asarray(y)
            except np.VisibleDeprecationWarning:
                # dtype=object should be provided explicitly for ragged arrays,
                # see NEP 34
                y = np.array(y, dtype=object)

    if not (hasattr(y, "shape") and y.ndim == 2 and y.shape[1] > 1):
        return False

    if issparse(y):
        if isinstance(y, (dok_matrix, lil_matrix)):
            y = y.tocsr()
        return (len(y.data) == 0 or np.unique(y.data).size == 1 and
                (y.dtype.kind in 'biu' or  # bool, int, uint
                 _is_integral_float(np.unique(y.data))))
    else:
        labels = np.unique(y)

        return len(labels) < 3 and (y.dtype.kind in 'biu' or  # bool, int, uint
                                    _is_integral_float(labels))


def check_classification_targets(y):
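
The catch_warnings block above is the standard NEP 34 transition pattern: escalate np.VisibleDeprecationWarning to an error, then retry with an explicit dtype=object when the input turns out to be ragged. A minimal sketch of just that pattern, assuming numpy < 1.24 (newer releases raise ValueError instead of warning); `coerce_target` is an invented name for illustration.

import warnings
import numpy as np

def coerce_target(y):
    with warnings.catch_warnings():
        warnings.simplefilter('error', np.VisibleDeprecationWarning)
        try:
            return np.asarray(y)
        except np.VisibleDeprecationWarning:
            # Ragged input: dtype=object must be requested explicitly (NEP 34).
            return np.asarray(y, dtype=object)

print(coerce_target([[1, 0], [0, 1]]).shape)   # (2, 2)
print(coerce_target([[1], [0, 2], []]).dtype)  # object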

0 Source : multiclass.py
with GNU General Public License v3.0
from gustavowillam

def type_of_target(y):
    """Determine the type of data indicated by the target.

    Note that this type is the most specific type that can be inferred.
    For example:

        * ``binary`` is more specific but compatible with ``multiclass``.
        * ``multiclass`` of integers is more specific but compatible with
          ``continuous``.
        * ``multilabel-indicator`` is more specific but compatible with
          ``multiclass-multioutput``.

    Parameters
    ----------
    y : array-like

    Returns
    -------
    target_type : str
        One of:

        * 'continuous': `y` is an array-like of floats that are not all
          integers, and is 1d or a column vector.
        * 'continuous-multioutput': `y` is a 2d array of floats that are
          not all integers, and both dimensions are of size > 1.
        * 'binary': `y` contains <= 2 discrete values and is 1d or a column
          vector.
        * 'multiclass': `y` contains more than two discrete values, is not a
          sequence of sequences, and is 1d or a column vector.
        * 'multiclass-multioutput': `y` is a 2d array that contains more
          than two discrete values, is not a sequence of sequences, and both
          dimensions are of size > 1.
        * 'multilabel-indicator': `y` is a label indicator matrix, an array
          of two dimensions with at least two columns, and at most 2 unique
          values.
        * 'unknown': `y` is array-like but none of the above, such as a 3d
          array, sequence of sequences, or an array of non-sequence objects.

    Examples
    --------
    >>> import numpy as np
    >>> type_of_target([0.1, 0.6])
    'continuous'
    >>> type_of_target([1, -1, -1, 1])
    'binary'
    >>> type_of_target(['a', 'b', 'a'])
    'binary'
    >>> type_of_target([1.0, 2.0])
    'binary'
    >>> type_of_target([1, 0, 2])
    'multiclass'
    >>> type_of_target([1.0, 0.0, 3.0])
    'multiclass'
    >>> type_of_target(['a', 'b', 'c'])
    'multiclass'
    >>> type_of_target(np.array([[1, 2], [3, 1]]))
    'multiclass-multioutput'
    >>> type_of_target([[1, 2]])
    'multilabel-indicator'
    >>> type_of_target(np.array([[1.5, 2.0], [3.0, 1.6]]))
    'continuous-multioutput'
    >>> type_of_target(np.array([[0, 1], [1, 1]]))
    'multilabel-indicator'
    """
    valid = ((isinstance(y, (Sequence, spmatrix)) or hasattr(y, '__array__'))
             and not isinstance(y, str))

    if not valid:
        raise ValueError('Expected array-like (array or non-string sequence), '
                         'got %r' % y)

    sparse_pandas = (y.__class__.__name__ in ['SparseSeries', 'SparseArray'])
    if sparse_pandas:
        raise ValueError("y cannot be class 'SparseSeries' or 'SparseArray'")

    if is_multilabel(y):
        return 'multilabel-indicator'

    # DeprecationWarning will be replaced by ValueError, see NEP 34
    # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
    with warnings.catch_warnings():
        warnings.simplefilter('error', np.VisibleDeprecationWarning)
        try:
            y = np.asarray(y)
        except np.VisibleDeprecationWarning:
            # dtype=object should be provided explicitly for ragged arrays,
            # see NEP 34
            y = np.asarray(y, dtype=object)

    # The old sequence of sequences format
    try:
        if (not hasattr(y[0], '__array__') and isinstance(y[0], Sequence)
                and not isinstance(y[0], str)):
            raise ValueError('You appear to be using a legacy multi-label data'
                             ' representation. Sequence of sequences are no'
                             ' longer supported; use a binary array or sparse'
                             ' matrix instead - the MultiLabelBinarizer'
                             ' transformer can convert to this format.')
    except IndexError:
        pass

    # Invalid inputs
    if y.ndim > 2 or (y.dtype == object and len(y) and
                      not isinstance(y.flat[0], str)):
        return 'unknown'  # [[[1, 2]]] or [obj_1] and not ["label_1"]

    if y.ndim == 2 and y.shape[1] == 0:
        return 'unknown'  # [[]]

    if y.ndim == 2 and y.shape[1] > 1:
        suffix = "-multioutput"  # [[1, 2], [1, 2]]
    else:
        suffix = ""  # [1, 2, 3] or [[1], [2], [3]]

    # check float and contains non-integer float values
    if y.dtype.kind == 'f' and np.any(y != y.astype(int)):
        # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
        _assert_all_finite(y)
        return 'continuous' + suffix

    if (len(np.unique(y)) > 2) or (y.ndim >= 2 and len(y[0]) > 1):
        return 'multiclass' + suffix  # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
    else:
        return 'binary'  # [1, 2] or [["a"], ["b"]]


def _check_partial_fit_first_call(clf, classes=None):
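
One consequence of the dtype=object fallback above: a ragged sequence-of-sequences target survives the array conversion, only to be rejected by the legacy multi-label check that follows it. A small illustration, assuming `type_of_target` is importable as in the snippet above:

try:
    type_of_target([[1], [0, 2], []])
except ValueError as err:
    print(err)  # legacy sequence-of-sequences format is no longer supported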

0 Source : _testing.py
with GNU General Public License v3.0
from gustavowillam

def assert_warns(warning_class, func, *args, **kw):
    """Test that a certain warning occurs.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`

    Returns
    -------
    result : the return value of `func`

    """
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        # Trigger a warning.
        result = func(*args, **kw)
        if hasattr(np, 'VisibleDeprecationWarning'):
            # Filter out numpy-specific warnings in numpy >= 1.9
            w = [e for e in w
                 if e.category is not np.VisibleDeprecationWarning]

        # Verify some things
        if not len(w) > 0:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)

        found = any(warning.category is warning_class for warning in w)
        if not found:
            raise AssertionError("%s did not give warning: %s( is %s)"
                                 % (func.__name__, warning_class, w))
    return result


def assert_warns_message(warning_class, message, func, *args, **kw):
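
A hypothetical usage sketch for `assert_warns`: only the warning class is checked, and recorded np.VisibleDeprecationWarning entries are dropped first so that stray numpy noise cannot be mistaken for the expected warning. The `emit_future_warning` function is invented for illustration.

import warnings

def emit_future_warning():
    warnings.warn("behaviour will change in the next release", FutureWarning)
    return "ok"

# Passes: a FutureWarning was emitted by the wrapped callable.
assert assert_warns(FutureWarning, emit_future_warning) == "ok"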

0 Source : _testing.py
with GNU General Public License v3.0
from gustavowillam

def assert_warns_message(warning_class, message, func, *args, **kw):
    # very important to avoid uncontrolled state propagation
    """Test that a certain warning occurs and with a certain message.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning.

    message : str or callable
        The message or a substring of the message to test for. If callable,
        it takes a string as the argument and will trigger an AssertionError
        if the callable returns `False`.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`.

    Returns
    -------
    result : the return value of `func`

    """
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        if hasattr(np, 'VisibleDeprecationWarning'):
            # Let's not catch the numpy internal DeprecationWarnings
            warnings.simplefilter('ignore', np.VisibleDeprecationWarning)
        # Trigger a warning.
        result = func(*args, **kw)
        # Verify some things
        if not len(w) > 0:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)

        found = [issubclass(warning.category, warning_class) for warning in w]
        if not any(found):
            raise AssertionError("No warning raised for %s with class "
                                 "%s"
                                 % (func.__name__, warning_class))

        message_found = False
        # Check the messages of all warnings that belong to warning_class
        for index in [i for i, x in enumerate(found) if x]:
            # substring will match, the entire message with typo won't
            msg = w[index].message  # For Python 3 compatibility
            msg = str(msg.args[0] if hasattr(msg, 'args') else msg)
            if callable(message):  # add support for certain tests
                check_in_message = message
            else:
                def check_in_message(msg): return message in msg

            if check_in_message(msg):
                message_found = True
                break

        if not message_found:
            raise AssertionError("Did not receive the message you expected "
                                 "('%s') for   <  %s>, got: '%s'"
                                 % (message, func.__name__, msg))

    return result


def assert_warns_div0(func, *args, **kw):

0 Source : testing.py
with GNU General Public License v3.0
from HHHHhgqcdxhg

def assert_warns(warning_class, func, *args, **kw):
    """Test that a certain warning occurs.

    Parameters
    ----------
    warning_class : the warning class
        The class to test for, e.g. UserWarning.

    func : callable
        Callable object to trigger warnings.

    *args : the positional arguments to `func`.

    **kw : the keyword arguments to `func`

    Returns
    -------
    result : the return value of `func`

    """
    clean_warning_registry()
    with warnings.catch_warnings(record=True) as w:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        # Trigger a warning.
        result = func(*args, **kw)
        if hasattr(np, 'VisibleDeprecationWarning'):
            # Filter out numpy-specific warnings in numpy >= 1.9
            w = [e for e in w
                 if e.category is not np.VisibleDeprecationWarning]

        # Verify some things
        if not len(w) > 0:
            raise AssertionError("No warning raised when calling %s"
                                 % func.__name__)

        found = any(warning.category is warning_class for warning in w)
        if not found:
            raise AssertionError("%s did not give warning: %s( is %s)"
                                 % (func.__name__, warning_class, w))
    return result


def assert_warns_message(warning_class, message, func, *args, **kw):
