ap.train¶
Модуль для поддержания работы с данными
- class ap.train.data_manager.ModelDataManager(data_dir: str, experiment_config: str)¶
Базовые классы:
objectКласс для поддержания работы с данными
- generate_batches_balanced_by_rubric()¶
Возвращает artm.BatchVectorizer, построенный на сбалансированных батчах.
Генерирует батчи, в которых документы сбалансированны относительно рубрик ГРНТИ. Из всего тренировочного датасета сэмплируются документы так, чтобы в обучении на эпохе участвовало одинаковое количество документов каждой рубрики ГРНТИ. Количество документов каждой рубрики равно average_rubric_size - среднему размеру рубрики ГРНТИ.
Если в конфиге для обучения модели self._config присутствует путь до батчей, построенных по википедии self._path_batches_wiki, то батчи будут использованы для обучения модели. Иначе в обучении будут принимать участие только батчи, сбалансированные относительно рубрик ГРНТИ.
Возвращает artm.BatchVectorizer, построенный на этих батчах.
- Результат
artm.BatchVectorizer, построенный на сбалансированных батчах.
- Тип результата
batch_vectorizer (artm.BatchVectorizer)
- get_modality_distribution() Dict[str, int]¶
Возвращает количество документов каждой модальности из self.class_ids для тренировочных данных.
Если в конфиге для обучения модели self.config передан путь до словаря, содержащего количество документов Wikipedia по модальностям, эти данные учитываются для оценки всего тренировочного датасета.
- Результат
словарь, ключ это модальность, значение это количество документов с такой модальностью
- Тип результата
modality_distribution_all (dict)
- load_train_data()¶
Загружает тренировочные данные.
- Создает два атрибута:
self.train_docs это словарь, где по doc_id содержиться документ в Vowpal Wabbit формате
- self._docs_of_rubrics это словарь, где по рубрике хранится
список всех doc_id с такой рубрикой из self.rubrics_train.
- update_config(config: str)¶
Обновляет конфиг, хранящийся по пути self._config_path.
- Параметры
config – конфиг обучаемой модели
- update_ds_metrics()¶
Обновляет метрики о датасете.
- write_new_docs(vw, docs: Dict[str, Dict[str, str]])¶
Сохраняет документы.
- Параметры
vw (VowpalWabbitBPE) – объект класса VowpalWabbitBPE для сохранения VW файлов.
docs (dict) – документы
- exception ap.train.data_manager.NoTranslationException¶
Базовые классы:
Exception