# Reserve 10% of the samples for testing and 20% for validation;
# whatever remains becomes the training split.
test_size = int(0.1 * len(dataset))
valid_size = int(0.2 * len(dataset))
train_size = len(dataset) - test_size - valid_size
train_data, valid_data, test_data = random_split(
    dataset, [train_size, valid_size, test_size]
)
Now let me double-check the lengths of the split sets.
# Sanity check: the three split sizes should add up to len(dataset).
len(train_data), len(valid_data), len(test_data)
(51041, 14583, 7291)
Set up training
Everything looks right. Next, we will build the optimizer. Needless to say, I
will use Adam here. Note that, in the PyTorch implementation of Adam,
weight_decay
is actually L2 regularization. See the post
here.
# Adam with a 2e-3 learning rate; note that PyTorch's `weight_decay`
# here is plain L2 regularization (see the linked post above).
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3, weight_decay=2e-4)
Here I will initialize the loss function, which is just the binary cross entropy loss, and the dictionary to log the training metrics. Nothing special.
# Binary cross entropy is the loss function for this two-class task.
loss_func = torch.nn.BCELoss()

# Per-epoch history of losses/accuracies, plus the best validation loss
# observed so far (None until the first epoch completes).
metrics = {key: [] for key in ("train_loss", "valid_loss", "train_acc", "valid_acc")}
metrics["best_valid_loss"] = None
Next step, I will define a function that implements the batch size and learning rate regularization (in an ugly way). It basically does what was said in the famous paper, "Don't Decay the Learning Rate, Increase the Batch Size".
def regularization(
    batch_size: int,
    max_batch_size: int,
    train_data: TensorDataset,
    valid_data: TensorDataset,
    optimizer: torch.optim.Optimizer,
    model: torch.nn.Module,
    l2_reg=2e-4,
):
    """Double the batch size until the cap is reached, then halve the lr.

    Implements the schedule from "Don't Decay the Learning Rate, Increase
    the Batch Size": while the batch size can still grow, double it (capped
    at ``max_batch_size``); once it is maxed out, halve the learning rate
    instead.

    Parameters
    ----------
    batch_size : current batch size.
    max_batch_size : upper bound on the batch size.
    train_data, valid_data : datasets to re-wrap in fresh DataLoaders.
    optimizer : the optimizer currently in use; its lr is halved in place
        once the batch size is maxed out.
    model : the model being trained (kept for interface compatibility).
    l2_reg : L2 penalty (kept for interface compatibility; the existing
        optimizer already carries its weight_decay setting).

    Returns
    -------
    tuple
        ``(batch_size, train_dataloader, valid_dataloader, optimizer)``.
    """
    if batch_size == max_batch_size:
        # BUGFIX: halve the learning rate by mutating the param group in
        # place. Rebuilding a fresh Adam (as the original code did) silently
        # discards Adam's running first/second-moment estimates, degrading
        # the updates right after every lr drop.
        optimizer.param_groups[0]["lr"] /= 2
    else:
        # Double the batch size, but never exceed the cap.
        batch_size = min(2 * batch_size, max_batch_size)
    # Rebuild the loaders so the new batch size takes effect immediately.
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_data, batch_size=batch_size)
    return batch_size, train_dataloader, valid_dataloader, optimizer
If you want a little bit of explanation, here it goes. We keep track of a certain metric as the criterion — for example, the validation loss. If the validation loss has not improved for a certain number of epochs, we double the batch size, until the defined maximum batch size is reached. After that, every time the model is stuck for that many epochs, we halve the learning rate. This helps us train a better model. The parameters used for this regularization are defined below.
# Hyper-parameters for the batch-size / learning-rate schedule.
batch_size = 32        # starting batch size
max_batch_size = 1024  # cap on the batch size
# After this many epochs without a validation-loss improvement, call
# `regularization` (double the batch size or halve the learning rate).
raise_batch_patience = 8
# Abort training once the validation loss has been stuck this many epochs.
early_stop_patience = 60
Here we build dataset loaders with torch's built-in DataLoader, such that the batching can be done easily. To minimize the impact of initial random states, we use shuffle=True, so every epoch gets different batches.
# Wrap each split in a DataLoader. Only the training split is shuffled,
# so validation and test batches are drawn in a fixed order.
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
The training loop
Here is the main training loop. I guess it is pretty much the standard way of training a model in torch. Basically, for each epoch, we calculate the gradients with the training set to update the model, and then apply the current model to the validation set to determine its actual performance without bias. Then we use the validation loss to do the batch size/learning rate regularization and check whether early stopping should be done. When we exit the training loop, the parameters of the "best" model are restored to ensure the performance of the network. Here, the "best" model is the one that has the best validation loss. Of course, you can use other metrics as the criterion.
# Counters for consecutive epochs without a validation-loss improvement
# (one drives the batch-size schedule, the other early stopping).
raise_batch_plateau = 0
early_stop_plateau = 0
best_parameters = None  # frozen snapshot of the best model's weights
for i in range(1000):
    start_time = time.time()
    model.train()
    print(f"Current epoch {i + 1}")
    print(f"  Current batch size {batch_size}")
    print(f"  Validation loss has plateaued for {early_stop_plateau} epochs")
    train_loss = 0
    valid_loss = 0
    train_acc = 0
    valid_acc = 0
    # training the model
    for train_protein1, train_protein2, targets in train_dataloader:
        train_protein1 = train_protein1.to(device)
        train_protein2 = train_protein2.to(device)
        targets = targets.to(device)
        pred = model(train_protein1, train_protein2)
        batch_loss = loss_func(pred, targets)
        # get training metrics (detach so the autograd graph is released)
        train_loss += batch_loss.detach().cpu().numpy()
        train_acc += (
            (torch.argmax(pred, axis=1) == torch.argmax(targets, axis=1))
            .sum()
            .cpu()
            .numpy()
        )
        # get gradients and update model weights
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
    # validation: no gradient tracking, model in eval mode
    with torch.no_grad():
        model.eval()
        for valid_protein1, valid_protein2, targets in valid_dataloader:
            valid_protein1 = valid_protein1.to(device)
            valid_protein2 = valid_protein2.to(device)
            targets = targets.to(device)
            pred = model(valid_protein1, valid_protein2)
            batch_loss = loss_func(pred, targets)
            valid_loss += batch_loss.cpu().numpy()
            valid_acc += (
                (torch.argmax(pred, axis=1) == torch.argmax(targets, axis=1))
                .sum()
                .cpu()
                .numpy()
            )
    # losses are averaged per batch, accuracies per sample
    train_loss /= len(train_dataloader)
    valid_loss /= len(valid_dataloader)
    train_acc /= len(train_data)
    valid_acc /= len(valid_data)
    # print and log
    print(f"  Training loss: {train_loss}")
    print(f"  Training accuracy: {train_acc}")
    print(f"  Validation loss: {valid_loss}")
    print(f"  Validation accuracy: {valid_acc}")
    metrics["train_loss"].append(train_loss)
    metrics["valid_loss"].append(valid_loss)
    metrics["train_acc"].append(train_acc)
    metrics["valid_acc"].append(valid_acc)
    # regularization for the batch size and learning rate
    if metrics["best_valid_loss"] is None or metrics["best_valid_loss"] > valid_loss:
        raise_batch_plateau = 0
        early_stop_plateau = 0
        metrics["best_valid_loss"] = valid_loss
        # BUGFIX: model.state_dict() returns live references to the model's
        # tensors, which later epochs keep mutating in place — storing it
        # directly made "restore the best model" a no-op. Clone every
        # tensor so the snapshot is truly frozen at this epoch.
        best_parameters = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    else:
        raise_batch_plateau += 1
        early_stop_plateau += 1
    if raise_batch_plateau == raise_batch_patience:
        batch_size, train_dataloader, valid_dataloader, optimizer = regularization(
            batch_size, max_batch_size, train_data, valid_data, optimizer, model
        )
        raise_batch_plateau = 0
    # early stop to prevent overfitting
    if early_stop_plateau == early_stop_patience:
        print(
            f"Early stopping after no improvements in the validation loss for {early_stop_patience} epochs."
        )
        # restore the parameters of the "best" model
        model.load_state_dict(best_parameters)
        break
    print(f"  Best validation loss so far: {metrics['best_valid_loss']}")
    print(f"  Epoch time: {time.time() - start_time}")
else:
    # for/else: runs only when all 1000 epochs finish without early stopping
    print(f"Best validation loss: {metrics['best_valid_loss']}")
    # BUGFIX: the original message carried a stray "seconds" suffix
    print(f"Total number of epochs: {i + 1}")
Current epoch 1 Current batch size 32 Validation loss has plateaued for 0 epochs Training loss: 0.3683407397679331 Training accuracy: 0.8278246899551341 Validation loss: 0.14916993617978797 Validation accuracy: 0.9509017348967976 Best validation loss so far: 0.14916993617978797 Epoch time: 17.75820255279541 Current epoch 2 Current batch size 32 Validation loss has plateaued for 0 epochs Training loss: 0.1534857010179407 Training accuracy: 0.9502360847162086 Validation loss: 0.10900586530394656 Validation accuracy: 0.9663992319824453 Best validation loss so far: 0.10900586530394656 Epoch time: 16.510152339935303 Current epoch 3 Current batch size 32 Validation loss has plateaued for 0 epochs Training loss: 0.11228975994155899 Training accuracy: 0.9659489430066025 Validation loss: 0.09190676893769323 Validation accuracy: 0.9726393746142769 Best validation loss so far: 0.09190676893769323 Epoch time: 17.22013568878174 Current epoch 4 Current batch size 32 Validation loss has plateaued for 0 epochs Training loss: 0.09009803373166442 Training accuracy: 0.973589859132854 Validation loss: 0.079203839863144 Validation accuracy: 0.9788109442501542 Best validation loss so far: 0.079203839863144 Epoch time: 17.492194652557373 Current epoch 5 Current batch size 32 Validation loss has plateaued for 0 epochs Training loss: 0.07818448865149558 Training accuracy: 0.9778217511412394 Validation loss: 0.07646888854357888 Validation accuracy: 0.9796338202016046 Best validation loss so far: 0.07646888854357888 Epoch time: 18.225395917892456 Current epoch 6 Current batch size 32 Validation loss has plateaued for 0 epochs Training loss: 0.06945815600738113 Training accuracy: 0.9797613683117494 Validation loss: 0.07843868468434696 Validation accuracy: 0.979085236233971 Best validation loss so far: 0.07646888854357888 Epoch time: 20.313061475753784 Current epoch 7 Current batch size 32 Validation loss has plateaued for 1 epochs Training loss: 0.06475827005861927 Training accuracy: 
0.9818577222233107 Validation loss: 0.07695504867348191 Validation accuracy: 0.9829253240074058 Best validation loss so far: 0.07646888854357888 Epoch time: 19.2711660861969 Current epoch 8 Current batch size 32 Validation loss has plateaued for 2 epochs Training loss: 0.06451458194719646 Training accuracy: 0.9823671166317275 Validation loss: 0.0691892049217131 Validation accuracy: 0.981896729068093 Best validation loss so far: 0.0691892049217131 Epoch time: 17.760085821151733 Current epoch 9 Current batch size 32 Validation loss has plateaued for 0 epochs Training loss: 0.0597318727137205 Training accuracy: 0.9836014184675065 Validation loss: 0.07737614678090292 Validation accuracy: 0.9805938421449634 Best validation loss so far: 0.0691892049217131 Epoch time: 18.603302240371704 Current epoch 10 Current batch size 32 Validation loss has plateaued for 1 epochs Training loss: 0.05820741842538504 Training accuracy: 0.9847181677474971 Validation loss: 0.06344870385905858 Validation accuracy: 0.9863539738051156 Best validation loss so far: 0.06344870385905858 Epoch time: 17.374438524246216 Current epoch 11 Current batch size 32 Validation loss has plateaued for 0 epochs Training loss: 0.052588668085047086 Training accuracy: 0.9854430751748594 Validation loss: 0.07269282722993746 Validation accuracy: 0.9845710759103066 Best validation loss so far: 0.06344870385905858 Epoch time: 16.83235502243042 Current epoch 12 Current batch size 32 Validation loss has plateaued for 1 epochs Training loss: 0.057011060293171295 Training accuracy: 0.9840128524127663 Validation loss: 0.064653842423005 Validation accuracy: 0.9835424809709936 Best validation loss so far: 0.06344870385905858 Epoch time: 17.263275623321533 Current epoch 13 Current batch size 32 Validation loss has plateaued for 2 epochs Training loss: 0.0546581399321622 Training accuracy: 0.9841891812464489 Validation loss: 0.06777635000323419 Validation accuracy: 0.98299389700336 Best validation loss so far: 
0.06344870385905858 Epoch time: 17.358654260635376 Current epoch 14 Current batch size 32 Validation loss has plateaued for 3 epochs Training loss: 0.05259665207112514 Training accuracy: 0.9855018514527537 Validation loss: 0.06801316980068293 Validation accuracy: 0.9832681889871768 Best validation loss so far: 0.06344870385905858 Epoch time: 17.22840452194214 Current epoch 15 Current batch size 32 Validation loss has plateaued for 4 epochs Training loss: 0.04765466454235057 Training accuracy: 0.9867165611959013 Validation loss: 0.0723494520208445 Validation accuracy: 0.9830624699993142 Best validation loss so far: 0.06344870385905858 Epoch time: 17.35137963294983 Current epoch 16 Current batch size 32 Validation loss has plateaued for 5 epochs Training loss: 0.050809676872624984 Training accuracy: 0.9854038909895966 Validation loss: 0.06181443450955078 Validation accuracy: 0.9857368168415278 Best validation loss so far: 0.06181443450955078 Epoch time: 17.174020051956177 Current epoch 17 Current batch size 32 Validation loss has plateaued for 0 epochs Training loss: 0.0511245566147746 Training accuracy: 0.9857761407495934 Validation loss: 0.06468374196801481 Validation accuracy: 0.9860796818212988 Best validation loss so far: 0.06181443450955078 Epoch time: 17.998218536376953 Current epoch 18 Current batch size 32 Validation loss has plateaued for 1 epochs Training loss: 0.04710395816011287 Training accuracy: 0.9865402323622187 Validation loss: 0.06764085641107317 Validation accuracy: 0.9868339847767948 Best validation loss so far: 0.06181443450955078 Epoch time: 16.71057677268982 Current epoch 19 Current batch size 32 Validation loss has plateaued for 2 epochs Training loss: 0.04492825991094566 Training accuracy: 0.9874022844380008 Validation loss: 0.07456026459406355 Validation accuracy: 0.9823767400397724 Best validation loss so far: 0.06181443450955078 Epoch time: 16.96254301071167 Current epoch 20 Current batch size 32 Validation loss has plateaued for 3 epochs 
Training loss: 0.04858407398926813 Training accuracy: 0.9863051272506417 Validation loss: 0.06886662286706269 Validation accuracy: 0.9827881780154974 Best validation loss so far: 0.06181443450955078 Epoch time: 16.29864525794983 Current epoch 21 Current batch size 32 Validation loss has plateaued for 4 epochs Training loss: 0.047012624492612974 Training accuracy: 0.9880684155874689 Validation loss: 0.061586590847582556 Validation accuracy: 0.98511965987794 Best validation loss so far: 0.061586590847582556 Epoch time: 17.49224019050598 Current epoch 22 Current batch size 32 Validation loss has plateaued for 0 epochs Training loss: 0.0441433174594702 Training accuracy: 0.9871084030485296 Validation loss: 0.07087576272808069 Validation accuracy: 0.9827881780154974 Best validation loss so far: 0.061586590847582556 Epoch time: 16.837544918060303 Current epoch 23 Current batch size 32 Validation loss has plateaued for 1 epochs Training loss: 0.04327004051605796 Training accuracy: 0.987696165827472 Validation loss: 0.07305410647422193 Validation accuracy: 0.9820338750600014 Best validation loss so far: 0.061586590847582556 Epoch time: 17.781333684921265 Current epoch 24 Current batch size 32 Validation loss has plateaued for 2 epochs Training loss: 0.04319740521447572 Training accuracy: 0.9873043239748438 Validation loss: 0.07156930744046501 Validation accuracy: 0.98299389700336 Best validation loss so far: 0.061586590847582556 Epoch time: 17.007068634033203 Current epoch 25 Current batch size 32 Validation loss has plateaued for 3 epochs Training loss: 0.04459984330658831 Training accuracy: 0.9872847318822123 Validation loss: 0.07416862640974087 Validation accuracy: 0.9838167729548104 Best validation loss so far: 0.061586590847582556 Epoch time: 16.850345849990845 Current epoch 26 Current batch size 32 Validation loss has plateaued for 4 epochs Training loss: 0.04506697157228787 Training accuracy: 0.9870888109558982 Validation loss: 0.06764495102744836 Validation 
accuracy: 0.9832681889871768 Best validation loss so far: 0.061586590847582556 Epoch time: 19.053792476654053 Current epoch 27 Current batch size 32 Validation loss has plateaued for 5 epochs Training loss: 0.044546362687009365 Training accuracy: 0.9868537058443212 Validation loss: 0.06427123854135923 Validation accuracy: 0.9825138860316808 Best validation loss so far: 0.061586590847582556 Epoch time: 18.845555782318115 Current epoch 28 Current batch size 32 Validation loss has plateaued for 6 epochs Training loss: 0.043435348496369125 Training accuracy: 0.9863051272506417 Validation loss: 0.06533311409714913 Validation accuracy: 0.9845025029143524 Best validation loss so far: 0.061586590847582556 Epoch time: 17.404787302017212 Current epoch 29 Current batch size 32 Validation loss has plateaued for 7 epochs Training loss: 0.047297982143466206 Training accuracy: 0.9872455476969495 Validation loss: 0.07197319508505691 Validation accuracy: 0.9820338750600014 Best validation loss so far: 0.061586590847582556 Epoch time: 17.614145278930664 Current epoch 30 Current batch size 64 Validation loss has plateaued for 8 epochs Training loss: 0.03229245651010914 Training accuracy: 0.9910464136674438 Validation loss: 0.061415670516955866 Validation accuracy: 0.9866282657889324 Best validation loss so far: 0.061415670516955866 Epoch time: 10.00475025177002 Current epoch 31 Current batch size 64 Validation loss has plateaued for 0 epochs Training loss: 0.029930324474058433 Training accuracy: 0.9916537685390177 Validation loss: 0.06336071727957715 Validation accuracy: 0.9870397037646574 Best validation loss so far: 0.061415670516955866 Epoch time: 9.40141487121582 Current epoch 32 Current batch size 64 Validation loss has plateaued for 1 epochs Training loss: 0.03111360014136472 Training accuracy: 0.9904782429811329 Validation loss: 0.07372458294476178 Validation accuracy: 0.9853253788658026 Best validation loss so far: 0.061415670516955866 Epoch time: 9.495984554290771 Current 
epoch 33 Current batch size 64 Validation loss has plateaued for 2 epochs Training loss: 0.033617910492102054 Training accuracy: 0.990537019259027 Validation loss: 0.09375316726660107 Validation accuracy: 0.9814167180964136 Best validation loss so far: 0.061415670516955866 Epoch time: 9.735952138900757 Current epoch 34 Current batch size 64 Validation loss has plateaued for 3 epochs Training loss: 0.03288972672520854 Training accuracy: 0.9901255853137674 Validation loss: 0.07779436399281972 Validation accuracy: 0.9820338750600014 Best validation loss so far: 0.061415670516955866 Epoch time: 9.665679931640625 Current epoch 35 Current batch size 64 Validation loss has plateaued for 4 epochs Training loss: 0.03503895539186141 Training accuracy: 0.9907133480927098 Validation loss: 0.09059854827169227 Validation accuracy: 0.9835424809709936 Best validation loss so far: 0.061415670516955866 Epoch time: 9.891529560089111 Current epoch 36 Current batch size 64 Validation loss has plateaued for 5 epochs Training loss: 0.033166191175948816 Training accuracy: 0.989635782997982 Validation loss: 0.07999834188810558 Validation accuracy: 0.9850510868819858 Best validation loss so far: 0.061415670516955866 Epoch time: 8.977028846740723 Current epoch 37 Current batch size 64 Validation loss has plateaued for 6 epochs Training loss: 0.030503041897684494 Training accuracy: 0.9909092690190239 Validation loss: 0.08545267803913846 Validation accuracy: 0.9853253788658026 Best validation loss so far: 0.061415670516955866 Epoch time: 9.67720079421997 Current epoch 38 Current batch size 64 Validation loss has plateaued for 7 epochs Training loss: 0.032573911578763264 Training accuracy: 0.9897729276464019 Validation loss: 0.06656704326629997 Validation accuracy: 0.9873139957484742 Best validation loss so far: 0.061415670516955866 Epoch time: 9.957160472869873 Current epoch 39 Current batch size 128 Validation loss has plateaued for 8 epochs Training loss: 0.02171412575698194 Training 
accuracy: 0.9932603201347936 Validation loss: 0.07847999895947348 Validation accuracy: 0.986902557772749 Best validation loss so far: 0.061415670516955866 Epoch time: 5.88164496421814 Current epoch 40 Current batch size 128 Validation loss has plateaued for 9 epochs Training loss: 0.02085860077505301 Training accuracy: 0.9936129778021591 Validation loss: 0.08312600874035668 Validation accuracy: 0.9864225468010698 Best validation loss so far: 0.061415670516955866 Epoch time: 5.754244089126587 Current epoch 41 Current batch size 128 Validation loss has plateaued for 10 epochs Training loss: 0.022007626555755007 Training accuracy: 0.9932015438568994 Validation loss: 0.07022756190047387 Validation accuracy: 0.9869711307687032 Best validation loss so far: 0.061415670516955866 Epoch time: 5.307782888412476 Current epoch 42 Current batch size 128 Validation loss has plateaued for 11 epochs Training loss: 0.023238363061893853 Training accuracy: 0.9934170568758449 Validation loss: 0.07184197927167509 Validation accuracy: 0.9856682438455736 Best validation loss so far: 0.061415670516955866 Epoch time: 5.3743133544921875 Current epoch 43 Current batch size 128 Validation loss has plateaued for 12 epochs Training loss: 0.023080671148930202 Training accuracy: 0.9930643992084794 Validation loss: 0.0685470320668333 Validation accuracy: 0.9877254337241994 Best validation loss so far: 0.061415670516955866 Epoch time: 5.585885524749756 Current epoch 44 Current batch size 128 Validation loss has plateaued for 13 epochs Training loss: 0.02344710743060583 Training accuracy: 0.9930448071158481 Validation loss: 0.08351820774923749 Validation accuracy: 0.9862854008091614 Best validation loss so far: 0.061415670516955866 Epoch time: 5.4615912437438965 Current epoch 45 Current batch size 128 Validation loss has plateaued for 14 epochs Training loss: 0.024643406720673267 Training accuracy: 0.9925550048000626 Validation loss: 0.0716867510927841 Validation accuracy: 0.9853253788658026 Best 
validation loss so far: 0.061415670516955866 Epoch time: 5.499894380569458 Current epoch 46 Current batch size 128 Validation loss has plateaued for 15 epochs Training loss: 0.02531626852629543 Training accuracy: 0.9921043866695402 Validation loss: 0.08132775970367029 Validation accuracy: 0.9859425358293904 Best validation loss so far: 0.061415670516955866 Epoch time: 5.873080015182495 Current epoch 47 Current batch size 256 Validation loss has plateaued for 16 epochs Training loss: 0.01887055601226166 Training accuracy: 0.9945729903410984 Validation loss: 0.07769224495349224 Validation accuracy: 0.986491119797024 Best validation loss so far: 0.061415670516955866 Epoch time: 3.655604839324951 Current epoch 48 Current batch size 256 Validation loss has plateaued for 17 epochs Training loss: 0.014967346995836123 Training accuracy: 0.9951803452126722 Validation loss: 0.08718655184074714 Validation accuracy: 0.9862168278132072 Best validation loss so far: 0.061415670516955866 Epoch time: 3.77701735496521 Current epoch 49 Current batch size 256 Validation loss has plateaued for 18 epochs Training loss: 0.016167639949708247 Training accuracy: 0.9948472796379382 Validation loss: 0.08126736897975206 Validation accuracy: 0.9865596927929782 Best validation loss so far: 0.061415670516955866 Epoch time: 3.7280848026275635 Current epoch 50 Current batch size 256 Validation loss has plateaued for 19 epochs Training loss: 0.01921810561674647 Training accuracy: 0.9945533982484669 Validation loss: 0.0831841843431456 Validation accuracy: 0.9862168278132072 Best validation loss so far: 0.061415670516955866 Epoch time: 3.7328131198883057 Current epoch 51 Current batch size 256 Validation loss has plateaued for 20 epochs Training loss: 0.017456095467205158 Training accuracy: 0.9947101349895182 Validation loss: 0.08114791524253394 Validation accuracy: 0.9863539738051156 Best validation loss so far: 0.061415670516955866 Epoch time: 3.824042320251465 Current epoch 52 Current batch size 
256 Validation loss has plateaued for 21 epochs Training loss: 0.01735563467082102 Training accuracy: 0.9944750298779412 Validation loss: 0.08258781973435952 Validation accuracy: 0.9884111636837414 Best validation loss so far: 0.061415670516955866 Epoch time: 3.763444662094116 Current epoch 53 Current batch size 256 Validation loss has plateaued for 22 epochs Training loss: 0.01743745368090458 Training accuracy: 0.9945142140632041 Validation loss: 0.07817265899492461 Validation accuracy: 0.985805389837482 Best validation loss so far: 0.061415670516955866 Epoch time: 3.5514976978302 Current epoch 54 Current batch size 256 Validation loss has plateaued for 23 epochs Training loss: 0.016653583026491104 Training accuracy: 0.9947101349895182 Validation loss: 0.09138312354161028 Validation accuracy: 0.98724542275252 Best validation loss so far: 0.061415670516955866 Epoch time: 3.5773463249206543 Current epoch 55 Current batch size 512 Validation loss has plateaued for 24 epochs Training loss: 0.012735030422918498 Training accuracy: 0.9956309633431947 Validation loss: 0.08574167549096305 Validation accuracy: 0.9871082767606116 Best validation loss so far: 0.061415670516955866 Epoch time: 2.704899549484253 Current epoch 56 Current batch size 512 Validation loss has plateaued for 25 epochs Training loss: 0.012711949818767608 Training accuracy: 0.9955134107874062 Validation loss: 0.09290048094659016 Validation accuracy: 0.9878625797161078 Best validation loss so far: 0.061415670516955866 Epoch time: 2.6715188026428223 Current epoch 57 Current batch size 512 Validation loss has plateaued for 26 epochs Training loss: 0.01198679932160303 Training accuracy: 0.9958660684547717 Validation loss: 0.09201742496726842 Validation accuracy: 0.9871082767606116 Best validation loss so far: 0.061415670516955866 Epoch time: 2.6517176628112793 Current epoch 58 Current batch size 512 Validation loss has plateaued for 27 epochs Training loss: 0.013237406165571883 Training accuracy: 
0.9953370819537235 Validation loss: 0.08688194381779638 Validation accuracy: 0.986902557772749 Best validation loss so far: 0.061415670516955866 Epoch time: 2.7578959465026855 Current epoch 59 Current batch size 512 Validation loss has plateaued for 28 epochs Training loss: 0.01205850999103859 Training accuracy: 0.9958268842695088 Validation loss: 0.08346784551595819 Validation accuracy: 0.9871768497565658 Best validation loss so far: 0.061415670516955866 Epoch time: 2.6216750144958496 Current epoch 60 Current batch size 512 Validation loss has plateaued for 29 epochs Training loss: 0.012675212565809488 Training accuracy: 0.9956701475284575 Validation loss: 0.08758672371763608 Validation accuracy: 0.9873825687444284 Best validation loss so far: 0.061415670516955866 Epoch time: 2.716305732727051 Current epoch 61 Current batch size 512 Validation loss has plateaued for 30 epochs Training loss: 0.013350522943073883 Training accuracy: 0.9955721870653005 Validation loss: 0.0879034168761352 Validation accuracy: 0.9859425358293904 Best validation loss so far: 0.061415670516955866 Epoch time: 2.683640480041504 Current epoch 62 Current batch size 512 Validation loss has plateaued for 31 epochs Training loss: 0.01285180420614779 Training accuracy: 0.995552594972669 Validation loss: 0.08913909052980357 Validation accuracy: 0.9876568607282452 Best validation loss so far: 0.061415670516955866 Epoch time: 2.682720899581909 Current epoch 63 Current batch size 1024 Validation loss has plateaued for 32 epochs Training loss: 0.010929818265140057 Training accuracy: 0.99612076565898 Validation loss: 0.08601320832967758 Validation accuracy: 0.9866282657889324 Best validation loss so far: 0.061415670516955866 Epoch time: 2.3699469566345215 Current epoch 64 Current batch size 1024 Validation loss has plateaued for 33 epochs Training loss: 0.010039424323476851 Training accuracy: 0.9966301600673968 Validation loss: 0.1001866136988004 Validation accuracy: 0.9875197147363368 Best validation 
loss so far: 0.061415670516955866 Epoch time: 2.305299758911133 Current epoch 65 Current batch size 1024 Validation loss has plateaued for 34 epochs Training loss: 0.009725354425609112 Training accuracy: 0.9963950549558198 Validation loss: 0.10690353165070215 Validation accuracy: 0.9874511417403826 Best validation loss so far: 0.061415670516955866 Epoch time: 2.294971466064453 Current epoch 66 Current batch size 1024 Validation loss has plateaued for 35 epochs Training loss: 0.009057683399878442 Training accuracy: 0.9967868968084481 Validation loss: 0.10738842810193698 Validation accuracy: 0.9877940067201536 Best validation loss so far: 0.061415670516955866 Epoch time: 2.21946120262146 Current epoch 67 Current batch size 1024 Validation loss has plateaued for 36 epochs Training loss: 0.00999259621836245 Training accuracy: 0.9965517916968711 Validation loss: 0.09545341928799947 Validation accuracy: 0.988274017691833 Best validation loss so far: 0.061415670516955866 Epoch time: 2.223292589187622 Current epoch 68 Current batch size 1024 Validation loss has plateaued for 37 epochs Training loss: 0.008173715975135564 Training accuracy: 0.9967868968084481 Validation loss: 0.10537191952268282 Validation accuracy: 0.98724542275252 Best validation loss so far: 0.061415670516955866 Epoch time: 2.313368797302246 Current epoch 69 Current batch size 1024 Validation loss has plateaued for 38 epochs Training loss: 0.008103228085674345 Training accuracy: 0.9970807781979193 Validation loss: 0.11642497678597769 Validation accuracy: 0.987931152712062 Best validation loss so far: 0.061415670516955866 Epoch time: 2.2715446949005127 Current epoch 70 Current batch size 1024 Validation loss has plateaued for 39 epochs Training loss: 0.00919597113970667 Training accuracy: 0.9967281205305538 Validation loss: 0.10784027427434921 Validation accuracy: 0.986491119797024 Best validation loss so far: 0.061415670516955866 Epoch time: 2.3160805702209473 Current epoch 71 Current batch size 1024 
Validation loss has plateaued for 40 epochs Training loss: 0.010282265054993332 Training accuracy: 0.9963166865852942 Validation loss: 0.0963414303958416 Validation accuracy: 0.9877254337241994 Best validation loss so far: 0.061415670516955866 Epoch time: 2.2900140285491943 Current epoch 72 Current batch size 1024 Validation loss has plateaued for 41 epochs Training loss: 0.008897184254601597 Training accuracy: 0.9968456730863423 Validation loss: 0.10360249256094296 Validation accuracy: 0.9866282657889324 Best validation loss so far: 0.061415670516955866 Epoch time: 2.238320827484131 Current epoch 73 Current batch size 1024 Validation loss has plateaued for 42 epochs Training loss: 0.008531770464032888 Training accuracy: 0.9968456730863423 Validation loss: 0.10750259707371394 Validation accuracy: 0.9873825687444284 Best validation loss so far: 0.061415670516955866 Epoch time: 2.2965457439422607 Current epoch 74 Current batch size 1024 Validation loss has plateaued for 43 epochs Training loss: 0.008573842472396792 Training accuracy: 0.9970415940126565 Validation loss: 0.10380735645691554 Validation accuracy: 0.9866968387848866 Best validation loss so far: 0.061415670516955866 Epoch time: 2.3397650718688965 Current epoch 75 Current batch size 1024 Validation loss has plateaued for 44 epochs Training loss: 0.008687639902345836 Training accuracy: 0.996688936345291 Validation loss: 0.10118161862095197 Validation accuracy: 0.9871768497565658 Best validation loss so far: 0.061415670516955866 Epoch time: 2.3468687534332275 Current epoch 76 Current batch size 1024 Validation loss has plateaued for 45 epochs Training loss: 0.008084502387791872 Training accuracy: 0.9970024098273936 Validation loss: 0.11261810660362244 Validation accuracy: 0.9881368716999246 Best validation loss so far: 0.061415670516955866 Epoch time: 2.2893991470336914 Current epoch 77 Current batch size 1024 Validation loss has plateaued for 46 epochs Training loss: 0.007966156965121627 Training accuracy: 
0.9971199623831821 Validation loss: 0.10353251074751219 Validation accuracy: 0.9871768497565658 Best validation loss so far: 0.061415670516955866 Epoch time: 2.331054925918579 Current epoch 78 Current batch size 1024 Validation loss has plateaued for 47 epochs Training loss: 0.00788309826515615 Training accuracy: 0.9969828177347623 Validation loss: 0.11331362550457319 Validation accuracy: 0.9863539738051156 Best validation loss so far: 0.061415670516955866 Epoch time: 2.2312557697296143 Current epoch 79 Current batch size 1024 Validation loss has plateaued for 48 epochs Training loss: 0.007866435796022414 Training accuracy: 0.9968848572716051 Validation loss: 0.12261365205049515 Validation accuracy: 0.9876568607282452 Best validation loss so far: 0.061415670516955866 Epoch time: 2.304025888442993 Current epoch 80 Current batch size 1024 Validation loss has plateaued for 49 epochs Training loss: 0.007455730964429677 Training accuracy: 0.9974726200505476 Validation loss: 0.11368072579304377 Validation accuracy: 0.988274017691833 Best validation loss so far: 0.061415670516955866 Epoch time: 2.1812942028045654 Current epoch 81 Current batch size 1024 Validation loss has plateaued for 50 epochs Training loss: 0.007067323885858059 Training accuracy: 0.9972179228463393 Validation loss: 0.11195060685276985 Validation accuracy: 0.9877940067201536 Best validation loss so far: 0.061415670516955866 Epoch time: 2.182356357574463 Current epoch 82 Current batch size 1024 Validation loss has plateaued for 51 epochs Training loss: 0.005886196801438928 Training accuracy: 0.9976685409768618 Validation loss: 0.11387138118346532 Validation accuracy: 0.9873139957484742 Best validation loss so far: 0.061415670516955866 Epoch time: 2.2456588745117188 Current epoch 83 Current batch size 1024 Validation loss has plateaued for 52 epochs Training loss: 0.006019606455229223 Training accuracy: 0.9976097646989674 Validation loss: 0.12921293278535206 Validation accuracy: 0.987931152712062 Best 
validation loss so far: 0.061415670516955866 Epoch time: 2.2601568698883057 Current epoch 84 Current batch size 1024 Validation loss has plateaued for 53 epochs Training loss: 0.006778459954075515 Training accuracy: 0.9974138437726534 Validation loss: 0.1152724509437879 Validation accuracy: 0.9873139957484742 Best validation loss so far: 0.061415670516955866 Epoch time: 2.2022740840911865 Current epoch 85 Current batch size 1024 Validation loss has plateaued for 54 epochs Training loss: 0.006473128912039101 Training accuracy: 0.9973550674947591 Validation loss: 0.11730540543794632 Validation accuracy: 0.9884797366796956 Best validation loss so far: 0.061415670516955866 Epoch time: 2.2285830974578857 Current epoch 86 Current batch size 1024 Validation loss has plateaued for 55 epochs Training loss: 0.006308788727037609 Training accuracy: 0.9975901726063361 Validation loss: 0.12370573580265046 Validation accuracy: 0.9877940067201536 Best validation loss so far: 0.061415670516955866 Epoch time: 2.30120849609375 Current epoch 87 Current batch size 1024 Validation loss has plateaued for 56 epochs Training loss: 0.005978679561521858 Training accuracy: 0.9978644619031759 Validation loss: 0.12931236649552982 Validation accuracy: 0.9881368716999246 Best validation loss so far: 0.061415670516955866 Epoch time: 2.259540319442749 Current epoch 88 Current batch size 1024 Validation loss has plateaued for 57 epochs Training loss: 0.006088510861154646 Training accuracy: 0.9976685409768618 Validation loss: 0.1283781218032042 Validation accuracy: 0.9884111636837414 Best validation loss so far: 0.061415670516955866 Epoch time: 2.1688272953033447 Current epoch 89 Current batch size 1024 Validation loss has plateaued for 58 epochs Training loss: 0.006190169749315828 Training accuracy: 0.9976097646989674 Validation loss: 0.12261400173107782 Validation accuracy: 0.9884111636837414 Best validation loss so far: 0.061415670516955866 Epoch time: 2.283600091934204 Current epoch 90 Current 
batch size 1024 Validation loss has plateaued for 59 epochs Training loss: 0.005575546114705503 Training accuracy: 0.9977077251621246 Validation loss: 0.12886554822325708 Validation accuracy: 0.9882054446958788 Early stopping after no improvements in the validation loss for 60 epochs.
# Plot the logged training/validation loss and accuracy per epoch, with a
# horizontal marker at the best validation loss reached during training.
n_epoch = len(metrics["train_loss"])
epochs = list(range(1, n_epoch + 1))
fig, ax = plt.subplots()
fig.suptitle("Metrics over time")  # fixed typo: "Metics" -> "Metrics"
ax.set_xlim([0, n_epoch])
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.plot(epochs, metrics["train_loss"], c="#2980b9", label="Training loss")
ax.plot(epochs, metrics["valid_loss"], c="#2980b9", ls="--", label="Validation loss")
ax.axhline(metrics["best_valid_loss"], c="#3498db", ls=":", label="best validation loss")
# Accuracy shares the x-axis with the losses but needs its own y-scale.
ax2 = ax.twinx()
ax2.set_ylabel("Accuracy")
ax2.plot(epochs, metrics["train_acc"], c="#27ae60", label="Training accuracy")
ax2.plot(epochs, metrics["valid_acc"], c="#27ae60", ls="--", label="Validation accuracy")
# fig.legend (rather than ax.legend) gathers artists from both axes;
# loc=10 centers the combined legend in the figure.
fig.legend(loc=10)
<matplotlib.legend.Legend at 0x7f8db2180750>
Even with the regularization, overfitting is still a huge issue. Basically, the
model stops improving on the validation set after about 30 epochs. In fact, we
can tighten the early stopping criterion early_stop_plateau
by quite a lot. Of
course, these results are all subject to RNG. It is possible that, next time, even 60 is
too short to trigger the early stop. However, always remember that even if the
model does improve after a prolonged plateau, the improvement is likely abysmal.
# Evaluate the trained model on the held-out test set: accumulate the per-batch
# mean BCE loss and the number of correct predictions, and collect the
# per-sample predicted/true class indices for the confusion matrix below.
with torch.no_grad():
    model.eval()
    test_loss = 0
    test_acc = 0
    # Collect per-batch label arrays and concatenate once at the end.
    # The original used np.append inside the loop, which copies the whole
    # array on every batch (accidentally O(n^2)).
    predicted_batches = []
    true_batches = []
    for (test_protein1, test_protein2, targets) in test_dataloader:
        test_protein1 = test_protein1.to(device)
        test_protein2 = test_protein2.to(device)
        targets = targets.to(device)
        pred = model(test_protein1, test_protein2)
        batch_loss = loss_func(pred, targets)
        test_loss += batch_loss.cpu().numpy()
        # Targets appear to be one-hot over 2 classes (BCELoss on a 2-column
        # output — see setup above), so argmax recovers the class index for
        # both predictions and ground truth.
        predicted_label = torch.argmax(pred, axis=1).cpu().numpy()
        true_label = torch.argmax(targets, axis=1).cpu().numpy()
        predicted_batches.append(predicted_label)
        true_batches.append(true_label)
        test_acc += (predicted_label == true_label).sum()
# `predicted` and `true` keep the same names/types as before — they are
# consumed by the confusion-matrix cell further down.
predicted = np.concatenate(predicted_batches) if predicted_batches else np.empty(0)
true = np.concatenate(true_batches) if true_batches else np.empty(0)
# Mean of per-batch mean losses, and overall test accuracy.
print(test_loss / len(test_dataloader), test_acc / len(test_data))
0.1437098534805827 0.9849129063228638
The two values printed out are the test loss and test accuracy. The loss is not so intuitive, but we have got 98.5% accuracy on the test set! This is how simple this kind of sequence-based binary classification can actually be.
With the predicted labels, we can draw a confusion matrix.
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# sklearn's documented signature is confusion_matrix(y_true, y_pred): rows are
# the true labels, columns are the predictions. The original call passed the
# arguments swapped (predicted first), which transposes the matrix — harmless
# for a near-symmetric binary result, but wrong in general.
cm = confusion_matrix(true, predicted)
disp = ConfusionMatrixDisplay(cm)
disp.plot(colorbar=False)
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7f8db0c54850>
I did not bother to fine-tune the heatmap, but we have a very balanced dataset, and very balanced predictions as well. This just makes sure we are not in the case where the dataset is very unbalanced and we have a high accuracy, but possibly very poor values for other metrics, like recall and precision.
Again the notebook itself is on GitHub and can be downloaded from here.
Machine Learning Neural Network Coding Python Torch