test_size = int(0.1 * len(dataset))
valid_size = int(0.2 * len(dataset))
train_size = len(dataset) - test_size - valid_size
train_data, valid_data, test_data = random_split(dataset, [train_size, valid_size, test_size])
Now let me double-check the length of the split sets.
len(train_data), len(valid_data), len(test_data)
(51041, 14583, 7291)
Set up training
Everything looks right. Next, we will build the optimizer. Needless to say, I
will use Adam here. Note that, in the PyTorch implementation of Adam,
weight_decay is actually L2 regularization. See the post
here.
optimizer = torch.optim.Adam(model.parameters(), 2e-3, weight_decay=2e-4)
Here I will initialize the loss function, which is just the binary cross entropy loss, and the dictionary to log the training metics. Nothing special.
loss_func = torch.nn.BCELoss()
metrics = {
"train_loss": [],
"valid_loss": [],
"train_acc": [],
"valid_acc": [],
"best_valid_loss": None,
}
Next step, I will define a function that implements the batch size and learning rate regularization (in an ugly way). It basically does what was said in the famous paper, "Don't Decay the Learning Rate, Increase the Batch Size".
def regularization(
batch_size: int,
max_batch_size: int,
train_data: TensorDataset,
valid_data: TensorDataset,
optimizer: torch.optim,
model: torch.nn.Module,
l2_reg=2e-4,
):
if batch_size == max_batch_size:
lr = optimizer.param_groups[0]["lr"]
lr /= 2
optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=l2_reg)
else:
tmp = 2 * batch_size
if tmp > max_batch_size:
batch_size = max_batch_size
else:
batch_size = tmp
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size)
return batch_size, train_dataloader, valid_dataloader, optimizer
If you want a little bit explanation, here it goes. We keep track of a certain metric as the criteria, for example, validation loss. If the validation loss has not been improved from a certain epochs, we double the batch size, until the defined maximum batch size is reached. After that, every time, the model is stuck for that many epochs, we will lower the learning rate by half. This will help us to train a better model. The parameters used from this regularization are defined below.
batch_size = 32 # initial batch size
max_batch_size = 1024 # maximum batch size
raise_batch_patience = 8 # call `regularization` when validation loss is stuck for this number of epochs
early_stop_patience = 60 # exit the training loop if the validation loss is stuck more than 60 epochs
Here we build a dataset loader with torch's built-in DataLoader, such that the
batching can be done easily. To minimize the impact of initial random states, we
use shuffle=True, so every epoch get different batches.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
The training loop
Here is the main training loop. I guess it is pretty much the standard way to training the model in torch. Basically, for each epoch, we calculate the gradients with the training set to update the model, and then apply the current model onto the validation set to determine its actual performance without biases. Then we will use the validation loss to do the batch size/learning rate regularization and check if early stopping should be done. When we exit the training loop, the parameters from the "best" model will be restored to ensure the performance of the network. Here, "best" is the one that has the best validation loss. Apparently, you can use other metrics as the criterion.
raise_batch_plateau = 0
early_stop_plateau = 0
for i in range(1000):
start_time = time.time()
model.train()
print(f"Current epoch {i + 1}")
print(f" Current batch size {batch_size}")
print(f" Validation loss has plateaued for {early_stop_plateau} epochs")
train_loss = 0
valid_loss = 0
train_acc = 0
valid_acc = 0
# training the model
for train_protein1, train_protein2, targets in train_dataloader:
train_protein1 = train_protein1.to(device)
train_protein2 = train_protein2.to(device)
targets = targets.to(device)
pred = model(train_protein1, train_protein2)
batch_loss = loss_func(pred, targets)
# get training metrics
train_loss += batch_loss.detach().cpu().numpy()
train_acc += (
(torch.argmax(pred, axis=1) == torch.argmax(targets, axis=1))
.sum()
.cpu()
.numpy()
)
# get gradients and update model weights
optimizer.zero_grad()
batch_loss.backward()
optimizer.step()
# validation
with torch.no_grad():
model.eval()
for valid_protein1, valid_protein2, targets in valid_dataloader:
valid_protein1 = valid_protein1.to(device)
valid_protein2 = valid_protein2.to(device)
targets = targets.to(device)
pred = model(valid_protein1, valid_protein2)
batch_loss = loss_func(pred, targets)
valid_loss += batch_loss.cpu().numpy()
valid_acc += (
(torch.argmax(pred, axis=1) == torch.argmax(targets, axis=1))
.sum()
.cpu()
.numpy()
)
train_loss /= len(train_dataloader)
valid_loss /= len(valid_dataloader)
train_acc /= len(train_data)
valid_acc /= len(valid_data)
# print and log
print(f" Training loss: {train_loss}")
print(f" Training accuracy: {train_acc}")
print(f" Validation loss: {valid_loss}")
print(f" Validation accuracy: {valid_acc}")
metrics["train_loss"].append(train_loss)
metrics["valid_loss"].append(valid_loss)
metrics["train_acc"].append(train_acc)
metrics["valid_acc"].append(valid_acc)
# regularization for the batch size and learning rate
if metrics["best_valid_loss"] is None or metrics["best_valid_loss"] > valid_loss:
raise_batch_plateau = 0
early_stop_plateau = 0
metrics["best_valid_loss"] = valid_loss
best_parameters = model.state_dict()
else:
raise_batch_plateau += 1
early_stop_plateau += 1
if raise_batch_plateau == raise_batch_patience:
batch_size, train_dataloader, valid_dataloader, optimizer = regularization(
batch_size, max_batch_size, train_data, valid_data, optimizer, model
)
raise_batch_plateau = 0
# early stop to prevent overfitting
if early_stop_plateau == early_stop_patience:
print(
f"Early stopping after no improvements in the validation loss for {early_stop_patience} epochs."
)
# restore the parameters of the "best" model
model.load_state_dict(best_parameters)
break
print(f" Best validation loss so far: {metrics['best_valid_loss']}")
print(f" Epoch time: {time.time() - start_time}")
else:
print(f"Best validation loss: {metrics['best_valid_loss']}")
print(f"Total number of epochs: {i + 1} seconds")
Current epoch 1
Current batch size 32
Validation loss has plateaued for 0 epochs
Training loss: 0.3683407397679331
Training accuracy: 0.8278246899551341
Validation loss: 0.14916993617978797
Validation accuracy: 0.9509017348967976
Best validation loss so far: 0.14916993617978797
Epoch time: 17.75820255279541
Current epoch 2
Current batch size 32
Validation loss has plateaued for 0 epochs
Training loss: 0.1534857010179407
Training accuracy: 0.9502360847162086
Validation loss: 0.10900586530394656
Validation accuracy: 0.9663992319824453
Best validation loss so far: 0.10900586530394656
Epoch time: 16.510152339935303
Current epoch 3
Current batch size 32
Validation loss has plateaued for 0 epochs
Training loss: 0.11228975994155899
Training accuracy: 0.9659489430066025
Validation loss: 0.09190676893769323
Validation accuracy: 0.9726393746142769
Best validation loss so far: 0.09190676893769323
Epoch time: 17.22013568878174
Current epoch 4
Current batch size 32
Validation loss has plateaued for 0 epochs
Training loss: 0.09009803373166442
Training accuracy: 0.973589859132854
Validation loss: 0.079203839863144
Validation accuracy: 0.9788109442501542
Best validation loss so far: 0.079203839863144
Epoch time: 17.492194652557373
Current epoch 5
Current batch size 32
Validation loss has plateaued for 0 epochs
Training loss: 0.07818448865149558
Training accuracy: 0.9778217511412394
Validation loss: 0.07646888854357888
Validation accuracy: 0.9796338202016046
Best validation loss so far: 0.07646888854357888
Epoch time: 18.225395917892456
Current epoch 6
Current batch size 32
Validation loss has plateaued for 0 epochs
Training loss: 0.06945815600738113
Training accuracy: 0.9797613683117494
Validation loss: 0.07843868468434696
Validation accuracy: 0.979085236233971
Best validation loss so far: 0.07646888854357888
Epoch time: 20.313061475753784
Current epoch 7
Current batch size 32
Validation loss has plateaued for 1 epochs
Training loss: 0.06475827005861927
Training accuracy: 0.9818577222233107
Validation loss: 0.07695504867348191
Validation accuracy: 0.9829253240074058
Best validation loss so far: 0.07646888854357888
Epoch time: 19.2711660861969
Current epoch 8
Current batch size 32
Validation loss has plateaued for 2 epochs
Training loss: 0.06451458194719646
Training accuracy: 0.9823671166317275
Validation loss: 0.0691892049217131
Validation accuracy: 0.981896729068093
Best validation loss so far: 0.0691892049217131
Epoch time: 17.760085821151733
Current epoch 9
Current batch size 32
Validation loss has plateaued for 0 epochs
Training loss: 0.0597318727137205
Training accuracy: 0.9836014184675065
Validation loss: 0.07737614678090292
Validation accuracy: 0.9805938421449634
Best validation loss so far: 0.0691892049217131
Epoch time: 18.603302240371704
Current epoch 10
Current batch size 32
Validation loss has plateaued for 1 epochs
Training loss: 0.05820741842538504
Training accuracy: 0.9847181677474971
Validation loss: 0.06344870385905858
Validation accuracy: 0.9863539738051156
Best validation loss so far: 0.06344870385905858
Epoch time: 17.374438524246216
Current epoch 11
Current batch size 32
Validation loss has plateaued for 0 epochs
Training loss: 0.052588668085047086
Training accuracy: 0.9854430751748594
Validation loss: 0.07269282722993746
Validation accuracy: 0.9845710759103066
Best validation loss so far: 0.06344870385905858
Epoch time: 16.83235502243042
Current epoch 12
Current batch size 32
Validation loss has plateaued for 1 epochs
Training loss: 0.057011060293171295
Training accuracy: 0.9840128524127663
Validation loss: 0.064653842423005
Validation accuracy: 0.9835424809709936
Best validation loss so far: 0.06344870385905858
Epoch time: 17.263275623321533
Current epoch 13
Current batch size 32
Validation loss has plateaued for 2 epochs
Training loss: 0.0546581399321622
Training accuracy: 0.9841891812464489
Validation loss: 0.06777635000323419
Validation accuracy: 0.98299389700336
Best validation loss so far: 0.06344870385905858
Epoch time: 17.358654260635376
Current epoch 14
Current batch size 32
Validation loss has plateaued for 3 epochs
Training loss: 0.05259665207112514
Training accuracy: 0.9855018514527537
Validation loss: 0.06801316980068293
Validation accuracy: 0.9832681889871768
Best validation loss so far: 0.06344870385905858
Epoch time: 17.22840452194214
Current epoch 15
Current batch size 32
Validation loss has plateaued for 4 epochs
Training loss: 0.04765466454235057
Training accuracy: 0.9867165611959013
Validation loss: 0.0723494520208445
Validation accuracy: 0.9830624699993142
Best validation loss so far: 0.06344870385905858
Epoch time: 17.35137963294983
Current epoch 16
Current batch size 32
Validation loss has plateaued for 5 epochs
Training loss: 0.050809676872624984
Training accuracy: 0.9854038909895966
Validation loss: 0.06181443450955078
Validation accuracy: 0.9857368168415278
Best validation loss so far: 0.06181443450955078
Epoch time: 17.174020051956177
Current epoch 17
Current batch size 32
Validation loss has plateaued for 0 epochs
Training loss: 0.0511245566147746
Training accuracy: 0.9857761407495934
Validation loss: 0.06468374196801481
Validation accuracy: 0.9860796818212988
Best validation loss so far: 0.06181443450955078
Epoch time: 17.998218536376953
Current epoch 18
Current batch size 32
Validation loss has plateaued for 1 epochs
Training loss: 0.04710395816011287
Training accuracy: 0.9865402323622187
Validation loss: 0.06764085641107317
Validation accuracy: 0.9868339847767948
Best validation loss so far: 0.06181443450955078
Epoch time: 16.71057677268982
Current epoch 19
Current batch size 32
Validation loss has plateaued for 2 epochs
Training loss: 0.04492825991094566
Training accuracy: 0.9874022844380008
Validation loss: 0.07456026459406355
Validation accuracy: 0.9823767400397724
Best validation loss so far: 0.06181443450955078
Epoch time: 16.96254301071167
Current epoch 20
Current batch size 32
Validation loss has plateaued for 3 epochs
Training loss: 0.04858407398926813
Training accuracy: 0.9863051272506417
Validation loss: 0.06886662286706269
Validation accuracy: 0.9827881780154974
Best validation loss so far: 0.06181443450955078
Epoch time: 16.29864525794983
Current epoch 21
Current batch size 32
Validation loss has plateaued for 4 epochs
Training loss: 0.047012624492612974
Training accuracy: 0.9880684155874689
Validation loss: 0.061586590847582556
Validation accuracy: 0.98511965987794
Best validation loss so far: 0.061586590847582556
Epoch time: 17.49224019050598
Current epoch 22
Current batch size 32
Validation loss has plateaued for 0 epochs
Training loss: 0.0441433174594702
Training accuracy: 0.9871084030485296
Validation loss: 0.07087576272808069
Validation accuracy: 0.9827881780154974
Best validation loss so far: 0.061586590847582556
Epoch time: 16.837544918060303
Current epoch 23
Current batch size 32
Validation loss has plateaued for 1 epochs
Training loss: 0.04327004051605796
Training accuracy: 0.987696165827472
Validation loss: 0.07305410647422193
Validation accuracy: 0.9820338750600014
Best validation loss so far: 0.061586590847582556
Epoch time: 17.781333684921265
Current epoch 24
Current batch size 32
Validation loss has plateaued for 2 epochs
Training loss: 0.04319740521447572
Training accuracy: 0.9873043239748438
Validation loss: 0.07156930744046501
Validation accuracy: 0.98299389700336
Best validation loss so far: 0.061586590847582556
Epoch time: 17.007068634033203
Current epoch 25
Current batch size 32
Validation loss has plateaued for 3 epochs
Training loss: 0.04459984330658831
Training accuracy: 0.9872847318822123
Validation loss: 0.07416862640974087
Validation accuracy: 0.9838167729548104
Best validation loss so far: 0.061586590847582556
Epoch time: 16.850345849990845
Current epoch 26
Current batch size 32
Validation loss has plateaued for 4 epochs
Training loss: 0.04506697157228787
Training accuracy: 0.9870888109558982
Validation loss: 0.06764495102744836
Validation accuracy: 0.9832681889871768
Best validation loss so far: 0.061586590847582556
Epoch time: 19.053792476654053
Current epoch 27
Current batch size 32
Validation loss has plateaued for 5 epochs
Training loss: 0.044546362687009365
Training accuracy: 0.9868537058443212
Validation loss: 0.06427123854135923
Validation accuracy: 0.9825138860316808
Best validation loss so far: 0.061586590847582556
Epoch time: 18.845555782318115
Current epoch 28
Current batch size 32
Validation loss has plateaued for 6 epochs
Training loss: 0.043435348496369125
Training accuracy: 0.9863051272506417
Validation loss: 0.06533311409714913
Validation accuracy: 0.9845025029143524
Best validation loss so far: 0.061586590847582556
Epoch time: 17.404787302017212
Current epoch 29
Current batch size 32
Validation loss has plateaued for 7 epochs
Training loss: 0.047297982143466206
Training accuracy: 0.9872455476969495
Validation loss: 0.07197319508505691
Validation accuracy: 0.9820338750600014
Best validation loss so far: 0.061586590847582556
Epoch time: 17.614145278930664
Current epoch 30
Current batch size 64
Validation loss has plateaued for 8 epochs
Training loss: 0.03229245651010914
Training accuracy: 0.9910464136674438
Validation loss: 0.061415670516955866
Validation accuracy: 0.9866282657889324
Best validation loss so far: 0.061415670516955866
Epoch time: 10.00475025177002
Current epoch 31
Current batch size 64
Validation loss has plateaued for 0 epochs
Training loss: 0.029930324474058433
Training accuracy: 0.9916537685390177
Validation loss: 0.06336071727957715
Validation accuracy: 0.9870397037646574
Best validation loss so far: 0.061415670516955866
Epoch time: 9.40141487121582
Current epoch 32
Current batch size 64
Validation loss has plateaued for 1 epochs
Training loss: 0.03111360014136472
Training accuracy: 0.9904782429811329
Validation loss: 0.07372458294476178
Validation accuracy: 0.9853253788658026
Best validation loss so far: 0.061415670516955866
Epoch time: 9.495984554290771
Current epoch 33
Current batch size 64
Validation loss has plateaued for 2 epochs
Training loss: 0.033617910492102054
Training accuracy: 0.990537019259027
Validation loss: 0.09375316726660107
Validation accuracy: 0.9814167180964136
Best validation loss so far: 0.061415670516955866
Epoch time: 9.735952138900757
Current epoch 34
Current batch size 64
Validation loss has plateaued for 3 epochs
Training loss: 0.03288972672520854
Training accuracy: 0.9901255853137674
Validation loss: 0.07779436399281972
Validation accuracy: 0.9820338750600014
Best validation loss so far: 0.061415670516955866
Epoch time: 9.665679931640625
Current epoch 35
Current batch size 64
Validation loss has plateaued for 4 epochs
Training loss: 0.03503895539186141
Training accuracy: 0.9907133480927098
Validation loss: 0.09059854827169227
Validation accuracy: 0.9835424809709936
Best validation loss so far: 0.061415670516955866
Epoch time: 9.891529560089111
Current epoch 36
Current batch size 64
Validation loss has plateaued for 5 epochs
Training loss: 0.033166191175948816
Training accuracy: 0.989635782997982
Validation loss: 0.07999834188810558
Validation accuracy: 0.9850510868819858
Best validation loss so far: 0.061415670516955866
Epoch time: 8.977028846740723
Current epoch 37
Current batch size 64
Validation loss has plateaued for 6 epochs
Training loss: 0.030503041897684494
Training accuracy: 0.9909092690190239
Validation loss: 0.08545267803913846
Validation accuracy: 0.9853253788658026
Best validation loss so far: 0.061415670516955866
Epoch time: 9.67720079421997
Current epoch 38
Current batch size 64
Validation loss has plateaued for 7 epochs
Training loss: 0.032573911578763264
Training accuracy: 0.9897729276464019
Validation loss: 0.06656704326629997
Validation accuracy: 0.9873139957484742
Best validation loss so far: 0.061415670516955866
Epoch time: 9.957160472869873
Current epoch 39
Current batch size 128
Validation loss has plateaued for 8 epochs
Training loss: 0.02171412575698194
Training accuracy: 0.9932603201347936
Validation loss: 0.07847999895947348
Validation accuracy: 0.986902557772749
Best validation loss so far: 0.061415670516955866
Epoch time: 5.88164496421814
Current epoch 40
Current batch size 128
Validation loss has plateaued for 9 epochs
Training loss: 0.02085860077505301
Training accuracy: 0.9936129778021591
Validation loss: 0.08312600874035668
Validation accuracy: 0.9864225468010698
Best validation loss so far: 0.061415670516955866
Epoch time: 5.754244089126587
Current epoch 41
Current batch size 128
Validation loss has plateaued for 10 epochs
Training loss: 0.022007626555755007
Training accuracy: 0.9932015438568994
Validation loss: 0.07022756190047387
Validation accuracy: 0.9869711307687032
Best validation loss so far: 0.061415670516955866
Epoch time: 5.307782888412476
Current epoch 42
Current batch size 128
Validation loss has plateaued for 11 epochs
Training loss: 0.023238363061893853
Training accuracy: 0.9934170568758449
Validation loss: 0.07184197927167509
Validation accuracy: 0.9856682438455736
Best validation loss so far: 0.061415670516955866
Epoch time: 5.3743133544921875
Current epoch 43
Current batch size 128
Validation loss has plateaued for 12 epochs
Training loss: 0.023080671148930202
Training accuracy: 0.9930643992084794
Validation loss: 0.0685470320668333
Validation accuracy: 0.9877254337241994
Best validation loss so far: 0.061415670516955866
Epoch time: 5.585885524749756
Current epoch 44
Current batch size 128
Validation loss has plateaued for 13 epochs
Training loss: 0.02344710743060583
Training accuracy: 0.9930448071158481
Validation loss: 0.08351820774923749
Validation accuracy: 0.9862854008091614
Best validation loss so far: 0.061415670516955866
Epoch time: 5.4615912437438965
Current epoch 45
Current batch size 128
Validation loss has plateaued for 14 epochs
Training loss: 0.024643406720673267
Training accuracy: 0.9925550048000626
Validation loss: 0.0716867510927841
Validation accuracy: 0.9853253788658026
Best validation loss so far: 0.061415670516955866
Epoch time: 5.499894380569458
Current epoch 46
Current batch size 128
Validation loss has plateaued for 15 epochs
Training loss: 0.02531626852629543
Training accuracy: 0.9921043866695402
Validation loss: 0.08132775970367029
Validation accuracy: 0.9859425358293904
Best validation loss so far: 0.061415670516955866
Epoch time: 5.873080015182495
Current epoch 47
Current batch size 256
Validation loss has plateaued for 16 epochs
Training loss: 0.01887055601226166
Training accuracy: 0.9945729903410984
Validation loss: 0.07769224495349224
Validation accuracy: 0.986491119797024
Best validation loss so far: 0.061415670516955866
Epoch time: 3.655604839324951
Current epoch 48
Current batch size 256
Validation loss has plateaued for 17 epochs
Training loss: 0.014967346995836123
Training accuracy: 0.9951803452126722
Validation loss: 0.08718655184074714
Validation accuracy: 0.9862168278132072
Best validation loss so far: 0.061415670516955866
Epoch time: 3.77701735496521
Current epoch 49
Current batch size 256
Validation loss has plateaued for 18 epochs
Training loss: 0.016167639949708247
Training accuracy: 0.9948472796379382
Validation loss: 0.08126736897975206
Validation accuracy: 0.9865596927929782
Best validation loss so far: 0.061415670516955866
Epoch time: 3.7280848026275635
Current epoch 50
Current batch size 256
Validation loss has plateaued for 19 epochs
Training loss: 0.01921810561674647
Training accuracy: 0.9945533982484669
Validation loss: 0.0831841843431456
Validation accuracy: 0.9862168278132072
Best validation loss so far: 0.061415670516955866
Epoch time: 3.7328131198883057
Current epoch 51
Current batch size 256
Validation loss has plateaued for 20 epochs
Training loss: 0.017456095467205158
Training accuracy: 0.9947101349895182
Validation loss: 0.08114791524253394
Validation accuracy: 0.9863539738051156
Best validation loss so far: 0.061415670516955866
Epoch time: 3.824042320251465
Current epoch 52
Current batch size 256
Validation loss has plateaued for 21 epochs
Training loss: 0.01735563467082102
Training accuracy: 0.9944750298779412
Validation loss: 0.08258781973435952
Validation accuracy: 0.9884111636837414
Best validation loss so far: 0.061415670516955866
Epoch time: 3.763444662094116
Current epoch 53
Current batch size 256
Validation loss has plateaued for 22 epochs
Training loss: 0.01743745368090458
Training accuracy: 0.9945142140632041
Validation loss: 0.07817265899492461
Validation accuracy: 0.985805389837482
Best validation loss so far: 0.061415670516955866
Epoch time: 3.5514976978302
Current epoch 54
Current batch size 256
Validation loss has plateaued for 23 epochs
Training loss: 0.016653583026491104
Training accuracy: 0.9947101349895182
Validation loss: 0.09138312354161028
Validation accuracy: 0.98724542275252
Best validation loss so far: 0.061415670516955866
Epoch time: 3.5773463249206543
Current epoch 55
Current batch size 512
Validation loss has plateaued for 24 epochs
Training loss: 0.012735030422918498
Training accuracy: 0.9956309633431947
Validation loss: 0.08574167549096305
Validation accuracy: 0.9871082767606116
Best validation loss so far: 0.061415670516955866
Epoch time: 2.704899549484253
Current epoch 56
Current batch size 512
Validation loss has plateaued for 25 epochs
Training loss: 0.012711949818767608
Training accuracy: 0.9955134107874062
Validation loss: 0.09290048094659016
Validation accuracy: 0.9878625797161078
Best validation loss so far: 0.061415670516955866
Epoch time: 2.6715188026428223
Current epoch 57
Current batch size 512
Validation loss has plateaued for 26 epochs
Training loss: 0.01198679932160303
Training accuracy: 0.9958660684547717
Validation loss: 0.09201742496726842
Validation accuracy: 0.9871082767606116
Best validation loss so far: 0.061415670516955866
Epoch time: 2.6517176628112793
Current epoch 58
Current batch size 512
Validation loss has plateaued for 27 epochs
Training loss: 0.013237406165571883
Training accuracy: 0.9953370819537235
Validation loss: 0.08688194381779638
Validation accuracy: 0.986902557772749
Best validation loss so far: 0.061415670516955866
Epoch time: 2.7578959465026855
Current epoch 59
Current batch size 512
Validation loss has plateaued for 28 epochs
Training loss: 0.01205850999103859
Training accuracy: 0.9958268842695088
Validation loss: 0.08346784551595819
Validation accuracy: 0.9871768497565658
Best validation loss so far: 0.061415670516955866
Epoch time: 2.6216750144958496
Current epoch 60
Current batch size 512
Validation loss has plateaued for 29 epochs
Training loss: 0.012675212565809488
Training accuracy: 0.9956701475284575
Validation loss: 0.08758672371763608
Validation accuracy: 0.9873825687444284
Best validation loss so far: 0.061415670516955866
Epoch time: 2.716305732727051
Current epoch 61
Current batch size 512
Validation loss has plateaued for 30 epochs
Training loss: 0.013350522943073883
Training accuracy: 0.9955721870653005
Validation loss: 0.0879034168761352
Validation accuracy: 0.9859425358293904
Best validation loss so far: 0.061415670516955866
Epoch time: 2.683640480041504
Current epoch 62
Current batch size 512
Validation loss has plateaued for 31 epochs
Training loss: 0.01285180420614779
Training accuracy: 0.995552594972669
Validation loss: 0.08913909052980357
Validation accuracy: 0.9876568607282452
Best validation loss so far: 0.061415670516955866
Epoch time: 2.682720899581909
Current epoch 63
Current batch size 1024
Validation loss has plateaued for 32 epochs
Training loss: 0.010929818265140057
Training accuracy: 0.99612076565898
Validation loss: 0.08601320832967758
Validation accuracy: 0.9866282657889324
Best validation loss so far: 0.061415670516955866
Epoch time: 2.3699469566345215
Current epoch 64
Current batch size 1024
Validation loss has plateaued for 33 epochs
Training loss: 0.010039424323476851
Training accuracy: 0.9966301600673968
Validation loss: 0.1001866136988004
Validation accuracy: 0.9875197147363368
Best validation loss so far: 0.061415670516955866
Epoch time: 2.305299758911133
Current epoch 65
Current batch size 1024
Validation loss has plateaued for 34 epochs
Training loss: 0.009725354425609112
Training accuracy: 0.9963950549558198
Validation loss: 0.10690353165070215
Validation accuracy: 0.9874511417403826
Best validation loss so far: 0.061415670516955866
Epoch time: 2.294971466064453
Current epoch 66
Current batch size 1024
Validation loss has plateaued for 35 epochs
Training loss: 0.009057683399878442
Training accuracy: 0.9967868968084481
Validation loss: 0.10738842810193698
Validation accuracy: 0.9877940067201536
Best validation loss so far: 0.061415670516955866
Epoch time: 2.21946120262146
Current epoch 67
Current batch size 1024
Validation loss has plateaued for 36 epochs
Training loss: 0.00999259621836245
Training accuracy: 0.9965517916968711
Validation loss: 0.09545341928799947
Validation accuracy: 0.988274017691833
Best validation loss so far: 0.061415670516955866
Epoch time: 2.223292589187622
Current epoch 68
Current batch size 1024
Validation loss has plateaued for 37 epochs
Training loss: 0.008173715975135564
Training accuracy: 0.9967868968084481
Validation loss: 0.10537191952268282
Validation accuracy: 0.98724542275252
Best validation loss so far: 0.061415670516955866
Epoch time: 2.313368797302246
Current epoch 69
Current batch size 1024
Validation loss has plateaued for 38 epochs
Training loss: 0.008103228085674345
Training accuracy: 0.9970807781979193
Validation loss: 0.11642497678597769
Validation accuracy: 0.987931152712062
Best validation loss so far: 0.061415670516955866
Epoch time: 2.2715446949005127
Current epoch 70
Current batch size 1024
Validation loss has plateaued for 39 epochs
Training loss: 0.00919597113970667
Training accuracy: 0.9967281205305538
Validation loss: 0.10784027427434921
Validation accuracy: 0.986491119797024
Best validation loss so far: 0.061415670516955866
Epoch time: 2.3160805702209473
Current epoch 71
Current batch size 1024
Validation loss has plateaued for 40 epochs
Training loss: 0.010282265054993332
Training accuracy: 0.9963166865852942
Validation loss: 0.0963414303958416
Validation accuracy: 0.9877254337241994
Best validation loss so far: 0.061415670516955866
Epoch time: 2.2900140285491943
Current epoch 72
Current batch size 1024
Validation loss has plateaued for 41 epochs
Training loss: 0.008897184254601597
Training accuracy: 0.9968456730863423
Validation loss: 0.10360249256094296
Validation accuracy: 0.9866282657889324
Best validation loss so far: 0.061415670516955866
Epoch time: 2.238320827484131
Current epoch 73
Current batch size 1024
Validation loss has plateaued for 42 epochs
Training loss: 0.008531770464032888
Training accuracy: 0.9968456730863423
Validation loss: 0.10750259707371394
Validation accuracy: 0.9873825687444284
Best validation loss so far: 0.061415670516955866
Epoch time: 2.2965457439422607
Current epoch 74
Current batch size 1024
Validation loss has plateaued for 43 epochs
Training loss: 0.008573842472396792
Training accuracy: 0.9970415940126565
Validation loss: 0.10380735645691554
Validation accuracy: 0.9866968387848866
Best validation loss so far: 0.061415670516955866
Epoch time: 2.3397650718688965
Current epoch 75
Current batch size 1024
Validation loss has plateaued for 44 epochs
Training loss: 0.008687639902345836
Training accuracy: 0.996688936345291
Validation loss: 0.10118161862095197
Validation accuracy: 0.9871768497565658
Best validation loss so far: 0.061415670516955866
Epoch time: 2.3468687534332275
Current epoch 76
Current batch size 1024
Validation loss has plateaued for 45 epochs
Training loss: 0.008084502387791872
Training accuracy: 0.9970024098273936
Validation loss: 0.11261810660362244
Validation accuracy: 0.9881368716999246
Best validation loss so far: 0.061415670516955866
Epoch time: 2.2893991470336914
Current epoch 77
Current batch size 1024
Validation loss has plateaued for 46 epochs
Training loss: 0.007966156965121627
Training accuracy: 0.9971199623831821
Validation loss: 0.10353251074751219
Validation accuracy: 0.9871768497565658
Best validation loss so far: 0.061415670516955866
Epoch time: 2.331054925918579
Current epoch 78
Current batch size 1024
Validation loss has plateaued for 47 epochs
Training loss: 0.00788309826515615
Training accuracy: 0.9969828177347623
Validation loss: 0.11331362550457319
Validation accuracy: 0.9863539738051156
Best validation loss so far: 0.061415670516955866
Epoch time: 2.2312557697296143
Current epoch 79
Current batch size 1024
Validation loss has plateaued for 48 epochs
Training loss: 0.007866435796022414
Training accuracy: 0.9968848572716051
Validation loss: 0.12261365205049515
Validation accuracy: 0.9876568607282452
Best validation loss so far: 0.061415670516955866
Epoch time: 2.304025888442993
Current epoch 80
Current batch size 1024
Validation loss has plateaued for 49 epochs
Training loss: 0.007455730964429677
Training accuracy: 0.9974726200505476
Validation loss: 0.11368072579304377
Validation accuracy: 0.988274017691833
Best validation loss so far: 0.061415670516955866
Epoch time: 2.1812942028045654
Current epoch 81
Current batch size 1024
Validation loss has plateaued for 50 epochs
Training loss: 0.007067323885858059
Training accuracy: 0.9972179228463393
Validation loss: 0.11195060685276985
Validation accuracy: 0.9877940067201536
Best validation loss so far: 0.061415670516955866
Epoch time: 2.182356357574463
Current epoch 82
Current batch size 1024
Validation loss has plateaued for 51 epochs
Training loss: 0.005886196801438928
Training accuracy: 0.9976685409768618
Validation loss: 0.11387138118346532
Validation accuracy: 0.9873139957484742
Best validation loss so far: 0.061415670516955866
Epoch time: 2.2456588745117188
Current epoch 83
Current batch size 1024
Validation loss has plateaued for 52 epochs
Training loss: 0.006019606455229223
Training accuracy: 0.9976097646989674
Validation loss: 0.12921293278535206
Validation accuracy: 0.987931152712062
Best validation loss so far: 0.061415670516955866
Epoch time: 2.2601568698883057
Current epoch 84
Current batch size 1024
Validation loss has plateaued for 53 epochs
Training loss: 0.006778459954075515
Training accuracy: 0.9974138437726534
Validation loss: 0.1152724509437879
Validation accuracy: 0.9873139957484742
Best validation loss so far: 0.061415670516955866
Epoch time: 2.2022740840911865
Current epoch 85
Current batch size 1024
Validation loss has plateaued for 54 epochs
Training loss: 0.006473128912039101
Training accuracy: 0.9973550674947591
Validation loss: 0.11730540543794632
Validation accuracy: 0.9884797366796956
Best validation loss so far: 0.061415670516955866
Epoch time: 2.2285830974578857
Current epoch 86
Current batch size 1024
Validation loss has plateaued for 55 epochs
Training loss: 0.006308788727037609
Training accuracy: 0.9975901726063361
Validation loss: 0.12370573580265046
Validation accuracy: 0.9877940067201536
Best validation loss so far: 0.061415670516955866
Epoch time: 2.30120849609375
Current epoch 87
Current batch size 1024
Validation loss has plateaued for 56 epochs
Training loss: 0.005978679561521858
Training accuracy: 0.9978644619031759
Validation loss: 0.12931236649552982
Validation accuracy: 0.9881368716999246
Best validation loss so far: 0.061415670516955866
Epoch time: 2.259540319442749
Current epoch 88
Current batch size 1024
Validation loss has plateaued for 57 epochs
Training loss: 0.006088510861154646
Training accuracy: 0.9976685409768618
Validation loss: 0.1283781218032042
Validation accuracy: 0.9884111636837414
Best validation loss so far: 0.061415670516955866
Epoch time: 2.1688272953033447
Current epoch 89
Current batch size 1024
Validation loss has plateaued for 58 epochs
Training loss: 0.006190169749315828
Training accuracy: 0.9976097646989674
Validation loss: 0.12261400173107782
Validation accuracy: 0.9884111636837414
Best validation loss so far: 0.061415670516955866
Epoch time: 2.283600091934204
Current epoch 90
Current batch size 1024
Validation loss has plateaued for 59 epochs
Training loss: 0.005575546114705503
Training accuracy: 0.9977077251621246
Validation loss: 0.12886554822325708
Validation accuracy: 0.9882054446958788
Early stopping after no improvements in the validation loss for 60 epochs.
n_epoch = len(metrics["train_loss"])
epochs = list(range(1, n_epoch + 1))
fig, ax = plt.subplots()
fig.suptitle("Metics over time")
ax.set_xlim([0, n_epoch])
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.plot(epochs, metrics["train_loss"], c="#2980b9", label="Training loss")
ax.plot(epochs, metrics["valid_loss"], c="#2980b9", ls="--", label="Validation loss")
ax.axhline(metrics["best_valid_loss"], c="#3498db", ls=":", label="best validation loss")
ax2 = ax.twinx()
ax2.set_ylabel("Accuracy")
ax2.plot(epochs, metrics["train_acc"], c="#27ae60", label="Training accuracy")
ax2.plot(epochs, metrics["valid_acc"], c="#27ae60", ls="--", label="Validation accuracy")
fig.legend(loc=10)
<matplotlib.legend.Legend at 0x7f8db2180750>
Even with the regularization, overfitting is still a huge issue. Basically, the
model stop improving on the validation set after about 30 epochs. In fact, we
can tighten the early stopping criterion early_stop_plateau by quite a lot. Of
course, these are all RNG bounded. It is possible that, next time, even 60 is
too short to make the early stop. However, always remember that even if the
model does improve after a prolonged plateau, the improvement is likely abysmal.
with torch.no_grad():
model.eval()
test_loss = 0
test_acc =0
predicted = np.empty(0)
true = np.empty(0)
for (test_protein1, test_protein2, targets) in test_dataloader:
test_protein1 = test_protein1.to(device)
test_protein2 = test_protein2.to(device)
targets = targets.to(device)
pred = model(test_protein1, test_protein2)
batch_loss = loss_func(pred, targets)
test_loss += batch_loss.cpu().numpy()
predicted_label = torch.argmax(pred, axis=1).cpu().numpy()
true_label = torch.argmax(targets, axis=1).cpu().numpy()
predicted = np.append(predicted, predicted_label)
true = np.append(true, true_label)
test_acc += (predicted_label == true_label).sum()
print(test_loss / len(test_dataloader), test_acc / len(test_data))
0.1437098534805827 0.9849129063228638
The two values printed out are the test loss and test accuracy. The loss is not so intuitive, but we have got 98.5% accuracy on the test set! This is how simple these sequence-based binary classification can actually be.
With the predicted labels, we can draw a confusion matrix.
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
cm = confusion_matrix(predicted, true)
disp = ConfusionMatrixDisplay(cm)
disp.plot(colorbar=False)
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7f8db0c54850>
I did not bother to fine tune the heatmap, but we have a very balanced dataset, and very balanced predictions as well. This just makes sure we are not in the case where the dataset is very unbalanced and we have a high accuracy, but possibly very poor other metrics, like recall and precision.
Again the notebook itself is on GitHub and can be downloaded from here.
Machine Learning Neural Network Coding Python Torch