import os
from podium.datasets.dataset import Dataset
from podium.datasets.example_factory import ExampleFactory
from podium.field import Field, LabelField
from podium.storage.resources.large_resource import LargeResource
from podium.vocab import Vocab
class SST(Dataset):
    """
    The Stanford sentiment treebank dataset.

    Attributes
    ----------
    NAME : str
        dataset name
    URL : str
        url to the SST dataset archive
    DATASET_DIR : str
        path, relative to ``LargeResource.BASE_RESOURCE_DIR``, of the folder
        containing the train, dev and test tree files
    ARCHIVE_TYPE : str
        string that defines archive type, used for unpacking dataset
    TRAIN_FILE : str
        name of the file holding the train split
    VALID_FILE : str
        name of the file holding the validation (dev) split
    TEST_FILE : str
        name of the file holding the test split
    TEXT_FIELD_NAME : str
        name of the field containing the sentence text
    LABEL_FIELD_NAME : str
        name of the field containing the sentiment label
    """

    NAME = "sst"
    URL = "https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip"
    DATASET_DIR = os.path.join("sst", "trees")
    ARCHIVE_TYPE = "zip"
    TRAIN_FILE = "train.txt"
    VALID_FILE = "dev.txt"
    TEST_FILE = "test.txt"
    TEXT_FIELD_NAME = "text"
    LABEL_FIELD_NAME = "label"

    def __init__(self, file_path, fields, fine_grained=False, subtrees=False):
        """
        Dataset constructor. User should use static method get_dataset_splits
        rather than using the constructor directly.

        Parameters
        ----------
        file_path : str
            path to the file holding the examples of this split, one tree
            per line
        fields : dict(str, Field)
            dictionary that maps field name to the field; must contain the
            keys ``SST.TEXT_FIELD_NAME`` and ``SST.LABEL_FIELD_NAME``
        fine_grained : bool
            if ``False``, returns the binary (positive/negative) SST dataset
            and filters out neutral examples. In that case please set your
            Fields *not* to be eager: neutral examples are removed only after
            they have already been created from the fields.
        subtrees : bool
            also return the subtrees of each input instance as separate
            instances. This causes the dataset to become much larger.
        """
        # Register the SST archive with LargeResource, which is responsible for
        # making it available locally (download/unpack logic lives there).
        LargeResource(
            **{
                LargeResource.RESOURCE_NAME: SST.NAME,
                LargeResource.ARCHIVE: SST.ARCHIVE_TYPE,
                LargeResource.URI: SST.URL,
            }
        )
        examples = self._create_examples(
            file_path=file_path,
            fields=fields,
            fine_grained=fine_grained,
            subtrees=subtrees,
        )
        # If not fine-grained, return binary task: filter out neutral instances
        if not fine_grained:
            # TODO @mttk: Perhaps issue warning if any of fields is eager
            label_name = SST.LABEL_FIELD_NAME
            # Index 1 selects the transformed label value (the output of
            # label_transform in _create_examples); presumably index 0 is the
            # raw value — confirm against Example.
            examples = [ex for ex in examples if ex[label_name][1] != "neutral"]
        super().__init__(examples=examples, fields=fields)

    @staticmethod
    def _create_examples(file_path, fields, fine_grained, subtrees):
        """
        Method creates examples for the sst dataset. The split file holds one
        labelled sentence tree per line; each non-empty line becomes one
        example, or one example per subtree when ``subtrees`` is set.

        Parameters
        ----------
        file_path : str
            file where examples for this split are stored
        fields : dict(str, Field)
            dictionary mapping field names to fields
        fine_grained : bool
            if ``False``, the outer scores "0"/"4" are mapped to the same
            label strings as "1"/"3" ("negative"/"positive"); if ``True``
            they are mapped to "very negative"/"very positive".
        subtrees : bool
            also return the subtrees of each input instance as separate
            instances. This causes the dataset to become much larger.

        Returns
        -------
        examples : list(Example)
            list of examples read from ``file_path``
        """
        # Convert fields to list as the output is going to be a list
        fields_as_list = [fields[SST.TEXT_FIELD_NAME], fields[SST.LABEL_FIELD_NAME]]
        example_factory = ExampleFactory(fields_as_list)
        label_to_string_map = _label_to_string_map(fine_grained=fine_grained)

        def label_trf(label):
            # Tree labels are score strings "0".."4"; map them to sentiment names.
            return label_to_string_map[label]

        examples = []
        with open(file=file_path, encoding="utf8") as fpr:
            for line in fpr:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # handing an empty string to the tree parser.
                if not line.strip():
                    continue
                example = example_factory.from_fields_tree(
                    line, subtrees=subtrees, label_transform=label_trf
                )
                if subtrees:
                    # With subtrees the factory returns a list of examples
                    examples.extend(example)
                else:
                    examples.append(example)
        return examples

    @staticmethod
    def get_dataset_splits(fields=None, fine_grained=False, subtrees=False):
        """
        Method loads and creates dataset splits for the SST dataset.

        Parameters
        ----------
        fields : dict(str, Field), optional
            dictionary mapping field name to field, if not given method will
            use ``get_default_fields``. User should use default field names
            defined in class attributes.
        fine_grained : bool
            if ``False``, returns the binary (positive/negative) SST dataset
            and filters out neutral examples. In that case please set your
            Fields *not* to be eager.
        subtrees : bool
            also return the subtrees of each input instance as separate
            instances. This causes the dataset to become much larger.

        Returns
        -------
        (train_dataset, valid_dataset, test_dataset) : (Dataset, Dataset, Dataset)
            tuple containing train, valid and test dataset
        """
        data_location = os.path.join(LargeResource.BASE_RESOURCE_DIR, SST.DATASET_DIR)
        if fields is None:
            fields = SST.get_default_fields()
        # All three splits share the same field objects; built in train,
        # valid, test order.
        split_files = (SST.TRAIN_FILE, SST.VALID_FILE, SST.TEST_FILE)
        return tuple(
            SST(
                file_path=os.path.join(data_location, split_file),
                fields=fields,
                fine_grained=fine_grained,
                subtrees=subtrees,
            )
            for split_file in split_files
        )

    @staticmethod
    def get_default_fields():
        """
        Method returns default SST fields: text and label.

        Returns
        -------
        fields : dict(str, Field)
            Dictionary mapping field name to field.
        """
        # Non-eager vocabularies, so neutral examples filtered out after
        # creation (binary mode) do not leak into the vocab.
        text = Field(
            name=SST.TEXT_FIELD_NAME,
            numericalizer=Vocab(eager=False),
            tokenizer="split",
            keep_raw=False,
        )
        label = LabelField(
            name=SST.LABEL_FIELD_NAME, numericalizer=Vocab(specials=(), eager=False)
        )
        return {SST.TEXT_FIELD_NAME: text, SST.LABEL_FIELD_NAME: label}
def _label_to_string_map(fine_grained):
pre = "very " if fine_grained else ""
return {
"0": pre + "negative",
"1": "negative",
"2": "neutral",
"3": "positive",
"4": pre + "positive",
}
def _get_label_str(label, fine_grained):
    """
    Return the sentiment string for a single SST score string.

    Raises KeyError if ``label`` is not one of "0" through "4".
    """
    mapping = _label_to_string_map(fine_grained)
    return mapping[label]