"""Test the functionality of the mldata fetching utilities."""

import os
import shutil
import tempfile

import numpy as np
import scipy as sp

from sklearn import datasets
from sklearn.datasets import mldata_filename, fetch_mldata

from sklearn.utils.testing import assert_in
from sklearn.utils.testing import assert_not_in
from sklearn.utils.testing import mock_mldata_urlopen
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import with_setup
from sklearn.utils.testing import assert_array_equal


# Path of the temporary data home used by the tests below; it is set by
# setup_tmpdata() and cleared again by teardown_tmpdata().
tmpdir = None


def setup_tmpdata():
    """Create a fresh temporary data home containing an ``mldata`` subdir.

    The new path is stored in the module-level ``tmpdir`` so that the tests
    can pass it as ``data_home`` to ``fetch_mldata``.
    """
    global tmpdir
    tmpdir = tempfile.mkdtemp()
    # pre-create the 'mldata' subdirectory inside the temporary data home
    os.makedirs(os.path.join(tmpdir, 'mldata'))


def teardown_tmpdata():
    """Remove the temporary data home created by setup_tmpdata, if any.

    ``tmpdir`` is reset to ``None`` afterwards (even if removal fails), so a
    repeated teardown is a no-op instead of failing on an already-deleted
    directory, and no later code sees a stale path.
    """
    global tmpdir
    if tmpdir is None:
        return
    try:
        shutil.rmtree(tmpdir)
    finally:
        tmpdir = None


def test_mldata_filename():
    """Check mldata_filename lowercases, hyphenates and strips punctuation."""
    expected_by_name = (
        ('datasets-UCI iris', 'datasets-uci-iris'),
        ('news20.binary', 'news20binary'),
        ('book-crossing-ratings-1.0', 'book-crossing-ratings-10'),
        ('Nile Water Level', 'nile-water-level'),
        ('MNIST (original)', 'mnist-original'),
    )
    for raw_name, expected in expected_by_name:
        assert_equal(mldata_filename(raw_name), expected)


@with_setup(setup_tmpdata, teardown_tmpdata)
def test_download():
    """Test that fetch_mldata is able to download and cache a data set."""
    # Swap the module's urlopen for a mock that serves a single 'mock' data
    # set; the real urlopen is restored in the ``finally`` clause below.
    _urlopen_ref = datasets.mldata.urlopen
    datasets.mldata.urlopen = mock_mldata_urlopen({
        'mock': {
            'label': np.ones((150,)),
            'data': np.ones((150, 4)),
        },
    })
    try:
        mock = fetch_mldata('mock', data_home=tmpdir)
        for n in ["COL_NAMES", "DESCR", "target", "data"]:
            assert_in(n, mock)

        assert_equal(mock.target.shape, (150,))
        assert_equal(mock.data.shape, (150, 4))

        # Unknown names must raise HTTPError.  data_home is passed
        # explicitly so the failing lookup stays inside the temporary
        # directory rather than using fetch_mldata's default data home.
        assert_raises(datasets.mldata.HTTPError,
                      fetch_mldata, 'not_existing_name', data_home=tmpdir)
    finally:
        datasets.mldata.urlopen = _urlopen_ref


@with_setup(setup_tmpdata, teardown_tmpdata)
def test_fetch_one_column():
    """Test fetching a data set that has a single column.

    The only column is returned as ``data`` and no ``target`` is created;
    ``transpose_data`` controls the orientation of ``data``.
    """
    _urlopen_ref = datasets.mldata.urlopen
    try:
        dataname = 'onecol'
        # serve a fake single-column data set through a mocked urlopen
        x = np.arange(6).reshape(2, 3)
        datasets.mldata.urlopen = mock_mldata_urlopen({dataname: {'x': x}})

        dset = fetch_mldata(dataname, data_home=tmpdir)
        for n in ["COL_NAMES", "DESCR", "data"]:
            assert_in(n, dset)
        assert_not_in("target", dset)

        assert_equal(dset.data.shape, (2, 3))
        assert_array_equal(dset.data, x)

        # without transposing, the data array comes back as x.T
        dset = fetch_mldata(dataname, transpose_data=False, data_home=tmpdir)
        assert_equal(dset.data.shape, (3, 2))
        assert_array_equal(dset.data, x.T)
    finally:
        datasets.mldata.urlopen = _urlopen_ref


@with_setup(setup_tmpdata, teardown_tmpdata)
def test_fetch_multiple_column():
    """Test how fetch_mldata picks ``data`` and ``target`` among columns.

    Covers the default column names, positional ordering, and explicit
    selection by column index and by column name.  Columns not used as
    ``data``/``target`` stay available under their own names.
    """
    _urlopen_ref = datasets.mldata.urlopen
    try:
        # fake columns, served below through a mocked urlopen
        x = np.arange(6).reshape(2, 3)
        y = np.array([1, -1])
        z = np.arange(12).reshape(4, 3)

        # by default: the columns named 'label' and 'data' become
        # ``target`` and ``data`` regardless of their position
        dataname = 'threecol-default'
        datasets.mldata.urlopen = mock_mldata_urlopen({
            dataname: (
                {
                    'label': y,
                    'data': x,
                    'z': z,
                },
                ['z', 'data', 'label'],
            ),
        })

        dset = fetch_mldata(dataname, data_home=tmpdir)
        for n in ["COL_NAMES", "DESCR", "target", "data", "z"]:
            assert_in(n, dset)
        assert_not_in("x", dset)
        assert_not_in("y", dset)

        assert_array_equal(dset.data, x)
        assert_array_equal(dset.target, y)
        # the extra 'z' column comes back transposed relative to the input
        assert_array_equal(dset.z, z.T)

        # by order: without 'label'/'data' columns, the first column is
        # used as ``target`` and the second as ``data``
        dataname = 'threecol-order'
        datasets.mldata.urlopen = mock_mldata_urlopen({
            dataname: ({'y': y, 'x': x, 'z': z},
                       ['y', 'x', 'z']), })

        dset = fetch_mldata(dataname, data_home=tmpdir)
        for n in ["COL_NAMES", "DESCR", "target", "data", "z"]:
            assert_in(n, dset)
        assert_not_in("x", dset)
        assert_not_in("y", dset)

        assert_array_equal(dset.data, x)
        assert_array_equal(dset.target, y)
        assert_array_equal(dset.z, z.T)

        # by number: target_name / data_name are indices into the column
        # ordering ['z', 'x', 'y']
        dataname = 'threecol-number'
        datasets.mldata.urlopen = mock_mldata_urlopen({
            dataname: ({'y': y, 'x': x, 'z': z},
                       ['z', 'x', 'y']),
        })

        dset = fetch_mldata(dataname, target_name=2, data_name=0,
                            data_home=tmpdir)
        for n in ["COL_NAMES", "DESCR", "target", "data", "x"]:
            assert_in(n, dset)
        assert_not_in("y", dset)
        assert_not_in("z", dset)

        assert_array_equal(dset.data, z)
        assert_array_equal(dset.target, y)

        # by name: same data set, columns selected by their names; the
        # result must match the by-number selection above
        dset = fetch_mldata(dataname, target_name='y', data_name='z',
                            data_home=tmpdir)
        for n in ["COL_NAMES", "DESCR", "target", "data", "x"]:
            assert_in(n, dset)
        assert_not_in("y", dset)
        assert_not_in("z", dset)

        assert_array_equal(dset.data, z)
        assert_array_equal(dset.target, y)
    finally:
        datasets.mldata.urlopen = _urlopen_ref