This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Add way of skipping pretrained weights download#5172

Merged
epwalsh merged 3 commits into main from transformer-no-load-weights on May 2, 2021

Conversation

@epwalsh (Member) commented Apr 30, 2021

Fixes #4599.

Changes proposed in this pull request:

  • Adds a `load_weights: bool` (default = `True`) parameter to `cached_transformers.get()` and to all higher-level modules that call it, such as `PretrainedTransformerEmbedder` and `PretrainedTransformerMismatchedEmbedder`. Setting this parameter to `False` skips downloading and loading the pretrained transformer weights, so only the architecture is instantiated. In particular, you can set the parameter to `False` via the `overrides` parameter when loading an AllenNLP model/predictor from an archive, avoiding an unnecessary download.
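In `transformers` terms, skipping the weights download amounts to building the model from its config alone, which instantiates the architecture with randomly initialized weights. A minimal sketch of that idea (illustration only; AllenNLP wires this up internally, and the tiny config sizes here are arbitrary, chosen just to keep the example cheap):

```python
# Building a model from a config alone instantiates the architecture with
# randomly initialized weights, so no checkpoint is downloaded.
from transformers import BertConfig, BertModel

# A deliberately small config so the example is cheap to construct.
config = BertConfig(
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=128,
)
model = BertModel(config)  # architecture only, random weights, no download
```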

For example, suppose your training config looks something like this:

{
  "model": {
    "type": "basic_classifier",
    "text_field_embedder": {
      "tokens": {
        "type": "pretrained_transformer",
        "model_name": "bert-base-cased",
        // ... other stuff ...
      }
    },
  },
  // ... other stuff ...
}

Now suppose you have an archive from training this model, model.tar.gz. You can load the trained model into a predictor like so:

from allennlp.predictors import Predictor

overrides = {"model.text_field_embedder.tokens.load_weights": False}
predictor = Predictor.from_path("model.tar.gz", overrides=overrides)
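Conceptually, an override like this descends the archived config along the dotted key path and sets the value at the end. A minimal sketch of that mechanism (`apply_override` is a hypothetical helper, not AllenNLP's actual implementation, which also accepts jsonnet strings and merges recursively):

```python
# Hypothetical helper illustrating how a dotted-path override such as
# "model.text_field_embedder.tokens.load_weights" could be applied to a
# nested config dict.
def apply_override(config: dict, dotted_key: str, value) -> None:
    *path, last = dotted_key.split(".")
    node = config
    for key in path:
        node = node.setdefault(key, {})  # walk (or create) nested dicts
    node[last] = value

config = {
    "model": {
        "text_field_embedder": {
            "tokens": {
                "type": "pretrained_transformer",
                "model_name": "bert-base-cased",
            }
        }
    }
}
apply_override(config, "model.text_field_embedder.tokens.load_weights", False)
# config["model"]["text_field_embedder"]["tokens"]["load_weights"] is now False
```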

@epwalsh (Member, Author) commented Apr 30, 2021

Unfortunately this actually doesn't address #5170, because the SrlBert model uses the transformers library directly. But that's not hard to fix. I'll follow up with a separate PR for that in allennlp-models.

@ArjunSubramonian (Contributor) left a comment


Awesome! Looks great to me. Thanks for meticulously going down the stack :)

@epwalsh epwalsh merged commit a463e0e into main on May 2, 2021
@epwalsh epwalsh deleted the transformer-no-load-weights branch May 2, 2021 21:51
dirkgr pushed a commit that referenced this pull request May 10, 2021
* add way of skipping pretrained weights download

* clarify docstring

* add link to PR in CHANGELOG


Development

Successfully merging this pull request may close these issues.

When loading archived fine-tuned models for prediction, prevent non-fine-tuned pretrained transformer models from being downloaded

2 participants