ap.train

Модуль для поддержания работы с данными

class ap.train.data_manager.ModelDataManager(data_dir: str, experiment_config: str)

Базовые классы: object

Класс для поддержания работы с данными

generate_batches_balanced_by_rubric()

Возвращает artm.BatchVectorizer, построенный на сбалансированных батчах.

Генерирует батчи, в которых документы сбалансированны относительно рубрик ГРНТИ. Из всего тренировочного датасета сэмплируются документы так, чтобы в обучении на эпохе участвовало одинаковое количество документов каждой рубрики ГРНТИ. Количество документов каждой рубрики равно average_rubric_size - среднему размеру рубрики ГРНТИ.

Если в конфиге для обучения модели self._config присутствует путь до батчей, построенных по википедии self._path_batches_wiki, то батчи будут использованы для обучения модели. Иначе в обучении будут принимать участие только батчи, сбалансированные относительно рубрик ГРНТИ.

Возвращает artm.BatchVectorizer, построенный на этих батчах.

Результат

artm.BatchVectorizer, построенный на сбалансированных батчах.

Тип результата

batch_vectorizer (artm.BatchVectorizer)

get_modality_distribution() Dict[str, int]

Возвращает количество документов каждой модальности из self.class_ids для тренировочных данных.

Если в конфиге для обучения модели self.config передан путь до словаря, содержащего количество документов Wikipedia по модальностям, эти данные учитываются для оценки всего тренировочного датасета.

Результат

словарь, ключ это модальность, значение это количество документов с такой модальностью

Тип результата

modality_distribution_all (dict)

load_train_data()

Загружает тренировочные данные.

Создает два атрибута:
  • self.train_docs это словарь, где по doc_id содержиться документ в Vowpal Wabbit формате

  • self._docs_of_rubrics это словарь, где по рубрике хранится

    список всех doc_id с такой рубрикой из self.rubrics_train.

update_config(config: str)

Обновляет конфиг, хранящийся по пути self._config_path.

Параметры

config – конфиг обучаемой модели

update_ds_metrics()

Обновляет метрики о датасете.

write_new_docs(vw, docs: Dict[str, Dict[str, str]])

Сохраняет документы.

Параметры
  • vw (VowpalWabbitBPE) – объект класса VowpalWabbitBPE для сохранения VW файлов.

  • docs (dict) – документы

exception ap.train.data_manager.NoTranslationException

Базовые классы: Exception