First, we need to install the blurr Python library, which integrates Hugging Face Transformers with fastai.
Grab data for binary classification:
Define task:
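library(fastai)
library(zeallot) # provides the %<-% multiple-assignment operator used below
HF_TASKS_AUTO = HF_TASKS_AUTO()
task = HF_TASKS_AUTO$SequenceClassification
pretrained_model_name = "roberta-base" # alternatives: "distilbert-base-uncased", "bert-base-uncased"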
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(pretrained_model_name, task=task)
Create Learner with Hugging Face data blocks:
imdb_df = data.table::fread('imdb_sample/texts.csv')
blocks = list(HF_TextBlock(hf_arch=hf_arch, hf_tokenizer=hf_tokenizer), CategoryBlock()) # tokenized text in, class label out
dblock = DataBlock(blocks=blocks,
                   get_x=ColReader('text'),
                   get_y=ColReader('label'),
                   splitter=ColSplitter(col='is_valid')) # rows where is_valid is TRUE form the validation set
dls = dblock %>% dataloaders(imdb_df, bs=4)
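With `bs=4`, each batch stacks four tokenized reviews into one matrix, so sequences of different lengths have to be padded to a common length, and an `attention_mask` records which positions hold real tokens (1) and which hold padding (0). The batch that `one_batch()` prints next contains exactly these two tensors. The base-R sketch below illustrates the idea with a hypothetical helper, `pad_batch`, which is not part of fastai or blurr; it assumes right padding and RoBERTa's special ids `<s>` = 0, `<pad>` = 1 and `</s>` = 2 (consistent with every printed row starting with 0 and ending with 2).

```r
# Hypothetical helper (not part of fastai or blurr): right-pad token-id
# vectors to a common length and build the matching attention mask.
# Assumes RoBERTa-style special ids: <s> = 0, <pad> = 1, </s> = 2.
pad_batch <- function(seqs, pad_id = 1L) {
  max_len <- max(lengths(seqs))
  # Real ids followed by pad_id; vapply returns one column per sequence,
  # so transpose to get one row per sequence
  input_ids <- t(vapply(seqs,
                        function(s) c(s, rep(pad_id, max_len - length(s))),
                        integer(max_len)))
  # 1 marks a real token, 0 marks a padding position
  attention_mask <- t(vapply(seqs,
                             function(s) c(rep(1L, length(s)),
                                           rep(0L, max_len - length(s))),
                             integer(max_len)))
  list(input_ids = input_ids, attention_mask = attention_mask)
}

# Two toy sequences with arbitrary ids, each wrapped in <s> ... </s>
batch <- pad_batch(list(c(0L, 11L, 12L, 13L, 2L), c(0L, 21L, 2L)))
batch$input_ids       # second row is 0 21 2 1 1
batch$attention_mask  # second row is 1 1 1 0 0
```

Building the mask from sequence lengths, rather than testing `input_ids != pad_id`, stays correct even for a tokenizer whose padding id can also appear as a real token.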
dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[    0,  4833,  3009,  ...,  1916,     6,     2],
        [    0,  1876, 13856,  ...,     7,    47,     2],
        [    0,  2647,     6,  ...,     6,    61,     2],
        [    0,    20,  2091,  ...,  5779,    30,     2]], device='cuda:0')
[[1]]$attention_mask
tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')
[[2]]
TensorCategory([0, 1, 0, 0], device='cuda:0')
Wrap model:
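model = HF_BaseModelWrapper(hf_model) # lets fastai call the Hugging Face model on a batch of inputs
learn = Learner(dls,
                model,
                opt_func=partial(Adam, decouple_wd=TRUE), # Adam with decoupled (AdamW-style) weight decay
                loss_func=CrossEntropyLossFlat(),         # cross-entropy on flattened predictions and targets
                metrics=accuracy,
                cbs=HF_BaseModelCallback(),               # blurr callback that adapts model inputs and outputs
                splitter=hf_splitter())                   # parameter groups used by freeze()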
learn$create_opt()
learn$freeze()
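The `opt_func=partial(Adam, decouple_wd=TRUE)` argument selects decoupled weight decay: the optimizer shrinks each weight directly by `lr * wd * p`, instead of adding `wd * p` to the gradient as an L2 penalty. The base-R function below is a simplified first Adam step written only for illustration (`adam_first_step` is hypothetical, not fastai code); its default hyperparameters mirror fastai's `Adam` defaults (`mom = 0.9`, `sqr_mom = 0.99`, `eps = 1e-5`, `wd = 0.01`).

```r
# Hypothetical, simplified first Adam step (not fastai's implementation)
# with either decoupled weight decay or a classic L2 penalty.
adam_first_step <- function(p, grad, lr = 1e-3, wd = 0.01, mom = 0.9,
                            sqr_mom = 0.99, eps = 1e-5, decouple_wd = TRUE) {
  if (decouple_wd) {
    p <- p * (1 - lr * wd)       # decay the weight directly
  } else {
    grad <- grad + wd * p        # fold an L2 penalty into the gradient
  }
  # Moving averages start at zero, so after one step they are:
  grad_avg <- (1 - mom) * grad
  sqr_avg  <- (1 - sqr_mom) * grad^2
  # Bias correction for step 1
  grad_hat <- grad_avg / (1 - mom)
  sqr_hat  <- sqr_avg / (1 - sqr_mom)
  p - lr * grad_hat / (sqrt(sqr_hat) + eps)
}

# A weight of 10 with zero gradient: only weight decay can move it
adam_first_step(10, 0, decouple_wd = TRUE)   # shrinks by lr * wd * p = 1e-4
adam_first_step(10, 0, decouple_wd = FALSE)  # shrinks by about lr = 1e-3
```

With the L2 variant the penalty passes through Adam's per-parameter normalization, so on this first step the weight moves by roughly `lr` whatever `wd` is; the decoupled variant keeps the decay proportional to `wd`, which is the motivation behind AdamW.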
learn %>% summary()
HF_BaseModelWrapper (Input shape: 4 x 512)
================================================================
Layer (type)         Output Shape         Param #    Trainable 
================================================================
Embedding            4 x 512 x 768        38,603,520 False     
________________________________________________________________
Embedding            4 x 512 x 768        394,752    False     
________________________________________________________________
Embedding            4 x 512 x 768        768        False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False     
________________________________________________________________
Linear               4 x 512 x 768        590,592    False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False     
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True      
________________________________________________________________
Dropout              4 x 512 x 768        0          False     
________________________________________________________________
Linear               4 x 768              590,592    True      
________________________________________________________________
Dropout              4 x 768              0          False     
________________________________________________________________
Linear               4 x 2                1,538      True      
________________________________________________________________
Total params: 124,647,170
Total trainable params: 630,530
Total non-trainable params: 124,016,640
Optimizer used: functools.partial(<function make_python_function.<locals>.python_function at 0x7fd850db18c8>, decouple_wd=True)
Loss function: FlattenedLoss of CrossEntropyLoss()
Model frozen up to parameter group #2
Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - HF_BaseModelCallback
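The totals at the end of the summary can be checked by hand. The base-R sketch below rebuilds them from the roberta-base shapes that the Param # column itself implies: hidden size 768, 12 encoder layers, a 3072-unit feed-forward block, a 50,265-token vocabulary (`50265 * 768 = 38,603,520`), 514 position embeddings (`514 * 768 = 394,752`), a single token type, and a two-class head. It also accounts for the trainable count: after `learn$freeze()` only the classification head and the LayerNorm layers remain trainable, consistent with fastai leaving normalization layers trainable when it freezes a model (`train_bn = TRUE` by default).

```r
# Recompute the parameter counts reported by summary() for roberta-base
# with a 2-class head; shapes are read off the summary table.
linear     <- function(n_in, n_out) n_in * n_out + n_out  # weight + bias
layer_norm <- function(d) 2 * d                            # gain + bias

d <- 768; d_ff <- 3072; n_layers <- 12; n_classes <- 2

embeddings <- 50265 * d + 514 * d + 1 * d  # token, position, token-type
emb_norm   <- layer_norm(d)

encoder_layer <- 4 * linear(d, d) +        # query, key, value, attention output
  linear(d, d_ff) + linear(d_ff, d) +      # feed-forward block
  2 * layer_norm(d)                        # two LayerNorms per layer

head <- linear(d, d) + linear(d, n_classes)  # dense layer + output projection

total <- embeddings + emb_norm + n_layers * encoder_layer + head

# After freeze(): the head plus every LayerNorm stays trainable
trainable <- head + emb_norm + n_layers * 2 * layer_norm(d)

# total 124,647,170; trainable 630,530; frozen 124,016,640
format(c(total = total, trainable = trainable, frozen = total - trainable),
       big.mark = ",")
```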