""" Module for real datasets with available ground truth word importance explanations. Also contains
methods and classes for word importance data manipulation. """
import json
from teex._utils._misc import _download_extract_file
from teex._utils._paths import _check_pathlib_dir
from teex._datasets.info.newsgroup import _newsgroupRoot, _newsgroupLabels, _newsgroupNEntries, _newsgroupAll, \
_newsgroup_url
from teex._baseClasses._baseDatasets import _ClassificationDataset
[docs]class Newsgroup(_ClassificationDataset):
""" 20 Newsgroup dataset. Contains 188 human
annotaded newsgroup texts belonging to two categories. From
Sina Mohseni, Jeremy E Block, and Eric Ragan. 2021.
Quantitative Evaluation of Machine Learning
Explanations: A Human-Grounded Benchmark.
https://doi.org/10.1145/3397481.3450689
:Example:
>>> nDataset = Newsgroup()
>>> obs, label, exp = nDataset[1]
where :code:`obs` is a str, :code:`label` is an int and :code:`exp` is a dict. containing a score for each important
word in :code:`obs`. When a slice is performed, obs, label and exp are lists of the objects described above. """
def __init__(self):
super(Newsgroup, self).__init__(path=_newsgroupRoot)
if self._check_integrity() is False:
print('Files do not exist or are corrupted:')
self._download()
self.classMap = self._get_class_map()
def __getitem__(self, item):
if isinstance(item, slice):
obs, label, exp = [], [], []
fileNames = _newsgroupAll[item]
labels = _newsgroupLabels[item]
for name, classLabel in zip(fileNames, labels):
with open(str(self._path / ('data/' + name)), 'rb') as t:
obs.append(t.read())
label.append(classLabel)
with open(str(self._path / ('expl/' + name + '.json')), 'rb') as t:
_e = json.load(t)['words']
exp.append({word[0]: word[1] for word in _e})
elif isinstance(item, int):
with open(str(self._path / ('data/' + _newsgroupAll[item])), 'rb') as t:
obs = t.read()
label = _newsgroupLabels[item]
with open(str(self._path / ('expl/' + _newsgroupAll[item] + '.json')), 'rb') as t:
_e = json.load(t)['words']
exp = {word[0]: word[1] for word in _e}
else:
raise TypeError('Invalid argument type.')
return obs, label, exp
def __len__(self) -> int:
return _newsgroupNEntries
def _check_integrity(self) -> bool:
return (_check_pathlib_dir(self._path / 'expl') and
_check_pathlib_dir(self._path / 'data'))
def _download(self) -> None:
_download_extract_file(self._path, _newsgroup_url, 'rawNewsgroup.zip')
self._isDownloaded = True
def _get_class_map(self) -> dict:
return {0: 'electronics', 1: 'medicine'}