Week 03 Lab: Convolutional Networks

Student Name: YOUR NAME HERE

Predicting molecular properties

In this lab we will build networks that learn how to predict molecular properties from PDB structures.

As a test system we will use simulation-generated snapshots of \(AAQAA_5\). This peptide adopts partially helical structures. An example conformation is shown below:

[Figure: an example partially helical conformation of the peptide]

Downloading data

We begin by downloading a set of conformations in PDB format for \(AAQAA_5\) that we will use for this tutorial:

aq15_pdbs.tgz

Instead of downloading it via the browser, you can use wget. That may be convenient, for example, when working in Google Colab.

!wget -nc https://raw.githubusercontent.com/ADicksonLab/ml4md-jb/main/Week-03/aq15_pdbs.tgz

Once downloaded, we unpack the data:

!tar xzf aq15_pdbs.tgz

Preparing target data

As target data we will use different molecular properties. For training and validation, these are calculated from the PDB structures using conventional methods. The goal of the machine learning models is then to predict these properties directly from the input conformations.

Let’s begin by calculating the radius of gyration based on \(C\alpha\) atoms from the PDB files using mdtraj.

Note: mdtraj should already be installed along with openmm. Install it now if you get an import error below.
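If mdtraj is not available, one way to install it from within a notebook is sketched below (this assumes a pip-based environment such as Google Colab; adjust accordingly if you manage packages with conda, e.g. via conda-forge):

# install mdtraj if it is not already available (assumes a pip-based environment)
!pip install mdtraj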

import os.path
import mdtraj as md
import numpy as np

if os.path.isfile('aq15.rgyr.dat'):
    print('output file exists already, skipping')
else:
    rgout=open('aq15.rgyr.dat','w')

    ndata=8000
    for n in range(ndata):
        fname=f'aq15_pdbs/aq15.{n+1}.pdb'
        if ((n+1)%1000==0): 
            print(f'reading: {fname}')
    
        # reading PDB
        pdb=md.load_pdb(fname)
    
        # selecting CA atoms
        calist=pdb.topology.select("name CA")
        pdbca=pdb.atom_slice(calist)
    
        # calculating radius of gyration (output is in [nm])
        rg=md.compute_rg(pdbca)
    
        # writing output to file
        rgout.write(f'{n+1} {rg[0]}\n')

    rgout.close()

If the above code does not work or runs too slowly, you can download the data file from here:

aq15.rgyr.dat

To check, we can make a histogram and plot the distribution of \(R_g\) values:

import matplotlib.pyplot as plt
import numpy as np

rgdata=np.loadtxt('aq15.rgyr.dat')
rgdata=rgdata[:,1]

rgmin=np.min(rgdata)
rgmax=np.max(rgdata)

bins=20
hist=np.zeros(bins+1)
delta=(rgmax-rgmin)/bins

for d in rgdata:
    inx=int((d-rgmin)/delta+0.5)
    hist[inx]+=1

hist/=len(rgdata)
    
xhist=np.linspace(rgmin,rgmax,bins+1)
plt.plot(xhist,hist,'r')
plt.show()
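As an aside, NumPy’s built-in histogram function produces an equivalent plot with less code. A minimal sketch (note that the binning convention differs slightly from the loop above):

import matplotlib.pyplot as plt
import numpy as np

rgdata=np.loadtxt('aq15.rgyr.dat')[:,1]

# np.histogram returns counts and bin edges; normalize counts to fractions
hist,edges=np.histogram(rgdata,bins=20)
centers=0.5*(edges[:-1]+edges[1:])
plt.plot(centers,hist/len(rgdata),'r')
plt.xlabel('Rg [nm]')
plt.ylabel('fraction')
plt.show()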

Let’s also download data files with pre-calculated solvation free energies (from Poisson-Boltzmann (PB) and Generalized Born (GB) models) and pre-calculated solvent-accessible surface areas. We will need these files later.

aq15.pb.dat

aq15.gb.dat

aq15.sasa.dat

As before, you can use wget to download all of the files:

!wget -nc https://raw.githubusercontent.com/ADicksonLab/ml4md-jb/main/Week-03/aq15.pb.dat
!wget -nc https://raw.githubusercontent.com/ADicksonLab/ml4md-jb/main/Week-03/aq15.gb.dat
!wget -nc https://raw.githubusercontent.com/ADicksonLab/ml4md-jb/main/Week-03/aq15.sasa.dat

Preparing input features

The input data used here are the structures from the PDB files. However, there are different ways in which such conformations can be ‘featurized’ as input to a machine learning model.

We will try the following features:

  • XYZ coordinates as is from the PDB

  • distance of each atom from the center of the molecule

  • pairwise distances for all pairs of atoms

To keep things simple and manageable we will only use \(C\alpha\) atoms for now. It is easy to modify the code below to consider other atoms as well.

Let’s get started!

Input feature: XYZ coordinates

import os.path
import mdtraj as md
import numpy as np

oname='aq15.input.xyz.npy'

if os.path.isfile(oname):
    print('output file exists already, skipping')
else:
    xyzlist=[]
    ndata=8000
    for n in range(ndata):
        fname=f'aq15_pdbs/aq15.{n+1}.pdb'
        if ((n+1)%1000==0): 
            print(f'reading: {fname}')
    
        pdb=md.load_pdb(fname)
    
        x, y, z = [], [], []   # start with empty lists; np.append returns new arrays
        calist=pdb.topology.select("name CA")
        for ca in calist:
            xyz=pdb.xyz[0][ca]
            x=np.append(x,xyz[0])
            y=np.append(y,xyz[1])
            z=np.append(z,xyz[2])
        xyzlist=np.append(xyzlist,[x,y,z])
         
    data=np.reshape(xyzlist,(ndata,3,-1))
    np.save(oname,data)
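As an aside, the inner per-atom loop above can be replaced by a single slicing operation on the coordinate array. A minimal sketch for one frame (producing the same (3, N) per-frame layout as above):

import mdtraj as md

pdb=md.load_pdb('aq15_pdbs/aq15.1.pdb')
calist=pdb.topology.select("name CA")
# pdb.xyz has shape (n_frames, n_atoms, 3); slicing and transposing gives (3, n_CA)
ca_xyz=pdb.xyz[0][calist].T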

Input feature: Distance from center

import os.path
import mdtraj as md
import numpy as np

oname='aq15.input.distance_from_center.npy'

if os.path.isfile(oname):
    print('output file exists already, skipping')
else:
    d=[]
    ndata=8000
    for n in range(ndata):
        fname=f'aq15_pdbs/aq15.{n+1}.pdb'
        if ((n+1)%1000==0): 
            print(f'reading: {fname}')
    
        pdb=md.load_pdb(fname)
        pdb.center_coordinates()
        
        calist=pdb.topology.select("name CA")
        for ca in calist:
            d=np.append(d,np.linalg.norm(pdb.xyz[0][ca]))
                                
    data=np.reshape(d,(ndata,1,-1))
    np.save(oname,data)

Input feature: Atom-atom distances

import os.path
import mdtraj as md
import numpy as np

oname='aq15.input.atom_atom.npy'

if os.path.isfile(oname):
    print('output file exists already, skipping')
else:
    d=[]
    ndata=8000
    for n in range(ndata):
        fname=f'aq15_pdbs/aq15.{n+1}.pdb'
        if ((n+1)%1000==0): 
            print(f'reading: {fname}')
    
        pdb=md.load_pdb(fname)
        calist=pdb.topology.select("name CA")
        
        pairs=[]
        for cai in calist:
            for caj in calist:
                pairs=np.append(pairs,[cai,caj])
        pairs=np.reshape(pairs,(-1,2))
        
        distances=md.compute_distances(pdb,pairs)
        d=np.append(d,distances)
                                
    data=np.reshape(d,(ndata,1,len(calist),-1))
    np.save(oname,data)

Setting up the Data Loader

We are now ready to set up the data functions and classes for machine learning.

We will use the following class and functions to read two files (one with the target data and one with a set of input features) and to generate merged data sets for training and validation.

import random
import numpy as np
import torch

class Dataset(torch.utils.data.Dataset):
    # optional parameters allow target data to be shifted and scaled
    def __init__(self, target, data, offset=0.0, scale=1.0):
        self.label = (target[:,1].astype(np.float32)+offset)/scale
        # assumes that data is prepared in correct shape beforehand
        self.input = data
    def __len__(self):
        return self.label.shape[0]
    def __getitem__(self, index):
        return self.input[index].astype(np.float32), self.label[index]
    
def randomsplitdata(targetfn,inputfn,training_fraction):
    targetdata=np.loadtxt(targetfn)    
    inputdata=np.load(inputfn)
    
    flag=np.zeros(len(targetdata),dtype=int)
    while np.average(flag)<training_fraction:
        flag[random.randint(0,len(targetdata)-1)]=1
    
    target_training=targetdata[np.nonzero(flag)].copy()
    target_validation=targetdata[np.nonzero(1-flag)].copy()
    input_training=inputdata[np.nonzero(flag)].copy()
    input_validation=inputdata[np.nonzero(1-flag)].copy()
        
    return target_training,input_training,target_validation,input_validation

# use optional parameters to shift and scale data if used above for Dataset class
def get_loaders(targetfn,inputfn,training_fraction,batch_size=64,offset=0.0,scale=1.0):
    [ttarget,tinput,vtarget,vinput]=randomsplitdata(targetfn,inputfn,training_fraction) 
    train_set=Dataset(ttarget,tinput,offset=offset,scale=scale)
    validation_set=Dataset(vtarget,vinput,offset=offset,scale=scale)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=1)
    return train_loader,validation_loader

Setting up Training Functions

As before we set up training functions.

import numpy as np
import torch
import torch.nn as nn

# klfactor allows KL divergence to be used as additional loss
# use factors of 0.01-0.3 to obtain useful results
def train(m,loss_fn,opt,loader,klfactor=0.0):
    klloss=nn.KLDivLoss(reduction='batchmean')
    
    loss_sum = 0.0
    for input, label in loader:
        opt.zero_grad()
        
        output = m(input)              # this is where the model is evaluated
        output = torch.flatten(output)
        
        loss = loss_fn(output, label)  # this is where the loss is calculated
        loss_sum += loss.item()        # accumulate MSE loss
        
        if (klfactor>0):               # additional KL loss if requested
            loss=loss+klfactor*klloss(output,label)                         
                
        loss.backward()                # this calculates the back-propagated loss
        opt.step()                     # this carries out the gradient descent
    
    return loss_sum / len(loader)      # Note: KL loss is not included in reported loss

def validate(m,loss_fn,loader):
    loss_sum = 0.0
    for input, label in loader:
        with torch.no_grad():
            output = m(input)
        
        output=torch.flatten(output)
        loss = loss_fn(output, label)
        loss_sum += loss.item()
    return loss_sum / len(loader)

# klfactor allows KL divergence to be used as additional loss
# use factors of 0.01-0.3 to obtain useful results
def do_training(m,opt,tloader,vloader,epochs,output,klfactor=0.0):
    # use MSE loss function
    loss_fn = nn.MSELoss()
    
    tloss=np.zeros(epochs)
    vloss=np.zeros(epochs)

    for i in range(epochs):
        tloss[i] = train(m,loss_fn,opt,tloader,klfactor=klfactor)
        vloss[i] = validate(m,loss_fn,vloader)
        if (output):
            print (i, tloss[i], vloss[i])
            
    return tloss,vloss

We also set up plotting functions, including a linear regression that we will use later to analyze the results.

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.linear_model import LinearRegression

def plot_progress(epochs,tloss,vloss):
    epoch_index=np.arange(epochs)
    plt.plot(epoch_index,np.log(tloss),color='r',label='training')
    plt.plot(epoch_index,np.log(vloss),color='b',label='validation')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.show()

# specify offset/scale if used with DataSet
def plot_validation(loader,m,offset=0.0,scale=1.0):    
    target=[]
    prediction=[]

    for input, label in loader:        
        with torch.no_grad():
            output = m(input)
        output=torch.flatten(output)        
        target=np.append(target,label)
        prediction=np.append(prediction,output)
        
    target=target*scale-offset
    prediction=prediction*scale-offset
    
    minval=np.min(target)
    maxval=np.max(target)
    lin=np.linspace(minval-0.1*(maxval-minval),maxval+0.1*(maxval-minval),num=100)
    
    plt.plot(lin,lin,'k',linewidth=2)
    plt.plot(target,prediction,'ro',markersize=2)
    plt.xlabel('target')
    plt.ylabel('prediction')
    plt.show()
    
    x=target.reshape((-1,1))
    y=prediction
    
    linmodel=LinearRegression().fit(x,y)
    
    r2=linmodel.score(x,y)
    mval=linmodel.coef_[0]
    nval=linmodel.intercept_
    print(f'r2: {r2} slope: {mval} intercept: {nval}')    
    

Building Models

We now define models with convolutional layers.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelFC(nn.Module):
    def __init__(self):
        super(ModelFC, self).__init__()
        # define layers to be used
        self.fc_1 = nn.Linear(45,128)       
        self.fc_2 = nn.Linear(128, 64)         
        self.fc_f = nn.Linear(64, 1)           
    def forward(self, x):
        # back-propagation is done automatically
        x = x.reshape(len(x),-1)
        #print(x.size())
        x = F.relu(self.fc_1(x))
        x = F.relu(self.fc_2(x)) 
        x = self.fc_f(x)         
        return x
    def initialize_weights(self, m):
        # initialization of weights, setting them to zero is not good
        if hasattr(m, 'weight') and m.weight.dim() > 1:
            nn.init.xavier_uniform_(m.weight.data)

class Model1D(nn.Module):
    def __init__(self):
        super(Model1D, self).__init__()
        # define layers to be used
        self.conv_1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv1d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
        self.conv_f = nn.Conv1d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
        # dimensional flattening
        self.flatten = nn.Flatten(start_dim=1) 
        # fully connected layers
        self.fc_1 = nn.Linear(240,128)       
        self.fc_2 = nn.Linear(128, 64)         
        self.fc_f = nn.Linear(64, 1)           
    def forward(self, x):
        # back-propagation is done automatically
        x = self.conv_1(x)
        x = F.relu(self.conv_2(x))
        x = F.relu(self.conv_f(x))
        x = self.flatten(x)
        #print(x.size())
        x = F.relu(self.fc_1(x))
        x = F.relu(self.fc_2(x)) 
        x = self.fc_f(x)         
        return x
    def initialize_weights(self, m):
        # initialization of weights, setting them to zero is not good
        if hasattr(m, 'weight') and m.weight.dim() > 1:
            nn.init.xavier_uniform_(m.weight.data)
            
class Model1D3(nn.Module):
    def __init__(self):
        super(Model1D3, self).__init__()
        # define layers to be used
        self.conv_1 = nn.Conv1d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv1d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
        self.conv_f = nn.Conv1d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
        # dimensional flattening
        self.flatten = nn.Flatten(start_dim=1) 
        # fully connected layers
        self.fc_1 = nn.Linear(240,128)       
        self.fc_2 = nn.Linear(128, 64)         
        self.fc_f = nn.Linear(64, 1)           
    def forward(self, x):
        # back-propagation is done automatically
        x = self.conv_1(x)
        x = F.relu(self.conv_2(x))
        x = F.relu(self.conv_f(x))
        x = self.flatten(x)
        #print(x.size())
        x = F.relu(self.fc_1(x))
        x = F.relu(self.fc_2(x)) 
        x = self.fc_f(x)         
        return x
    def initialize_weights(self, m):
        # initialization of weights, setting them to zero is not good
        if hasattr(m, 'weight') and m.weight.dim() > 1:
            nn.init.xavier_uniform_(m.weight.data)
            
class Model2D(nn.Module):
    def __init__(self):
        super(Model2D, self).__init__()
        # define layers to be used
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
        self.conv_f = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
        # dimensional flattening
        self.flatten = nn.Flatten(start_dim=1) 
        # fully connected layers
        self.fc_1 = nn.Linear(3600,128)       
        self.fc_2 = nn.Linear(128, 64)         
        self.fc_f = nn.Linear(64, 1)           
    def forward(self, x):
        # back-propagation is done automatically
        x = self.conv_1(x)
        x = F.relu(self.conv_2(x))
        x = F.relu(self.conv_f(x))
        x = self.flatten(x)
        #print(x.size())
        x = F.relu(self.fc_1(x))
        x = F.relu(self.fc_2(x)) 
        x = self.fc_f(x)         
        return x
    def initialize_weights(self, m):
        # initialization of weights, setting them to zero is not good
        if hasattr(m, 'weight') and m.weight.dim() > 1:
            nn.init.xavier_uniform_(m.weight.data)

Training models

We will start training models to predict the radius of gyration. Let’s begin with a simple, fully-connected model and use XYZ coordinates as the input.

We will use the class ModelFC that has been set up with the correct dimensions to work with XYZ data for this data set. For other input data, it would need to be adjusted.
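If you switch to different input features, a quick way to find the input dimension required by the first linear layer is to inspect the shape of the saved feature array. A minimal sketch (file name as generated above):

import numpy as np

data=np.load('aq15.input.xyz.npy')
# shape is (ndata, 3, N); after flattening, 3*N must match the input size of fc_1 in ModelFC (45 here)
print(data.shape, data[0].size)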

[tloader,vloader]=get_loaders('aq15.rgyr.dat','aq15.input.xyz.npy',0.8) 

m = ModelFC()
m.apply(m.initialize_weights)
m.zero_grad()

opt = torch.optim.Adam(m.parameters(), lr=0.001, weight_decay=0.0000001)

epochs=30
showoutput=True

[tloss,vloss]=do_training(m,opt,tloader,vloader,epochs,showoutput)

plot_progress(epochs,tloss,vloss)
plot_validation(vloader,m)

That’s pretty good. Now let’s try with a convolutional network.

We use the Model1D3 class that applies 1D convolutions to data with a depth of 3. The data is arranged as 1D data because the \(C\alpha\) atoms form a single sequence, and the x/y/z coordinate values are used as the depth layers (like the RGB channels of an image).

The data could be arranged differently, for example as \(N \times 3\) 2D data with an initial layer depth of 1, or as length-\(3N\) 1D data, also with an initial layer depth of 1.
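For illustration, here is a minimal sketch of how the saved XYZ array could be rearranged into the alternative layouts just mentioned (assuming the (ndata, 3, N) shape generated above):

import numpy as np

xyz=np.load('aq15.input.xyz.npy')                # (ndata, 3, N): 1D data with depth 3
as_2d=xyz.transpose(0,2,1)[:,np.newaxis,:,:]     # (ndata, 1, N, 3): 2D data with depth 1
as_1d=xyz.reshape(len(xyz),1,-1)                 # (ndata, 1, 3*N): 1D data with depth 1
print(xyz.shape, as_2d.shape, as_1d.shape)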

[tloader,vloader]=get_loaders('aq15.rgyr.dat','aq15.input.xyz.npy',0.8) 

m = Model1D3()
m.apply(m.initialize_weights)
m.zero_grad()

opt = torch.optim.Adam(m.parameters(), lr=0.001, weight_decay=0.0000001)

epochs=30
showoutput=True

[tloss,vloss]=do_training(m,opt,tloader,vloader,epochs,showoutput)

plot_progress(epochs,tloss,vloss)
plot_validation(vloader,m)

You will probably find that this performs better.

Let’s explore using different input data features.

We will try next using the distance from the center. We will use Model1D, which calculates 1D convolutions as above but for input data that has a depth of 1 (instead of 3 for the XYZ data).

[tloader,vloader]=get_loaders('aq15.rgyr.dat','aq15.input.distance_from_center.npy',0.8) 

m = Model1D()
m.apply(m.initialize_weights)
m.zero_grad()

opt = torch.optim.Adam(m.parameters(), lr=0.001, weight_decay=0.0000001)

epochs=30
showoutput=True

[tloss,vloss]=do_training(m,opt,tloader,vloader,epochs,showoutput)

plot_progress(epochs,tloss,vloss)
plot_validation(vloader,m)

This should look even better.

Finally, let’s use the intramolecular distance matrix.

For this we use Model2D, which applies 2D convolutions to input data that is 2D with a depth of 1:

[tloader,vloader]=get_loaders('aq15.rgyr.dat','aq15.input.atom_atom.npy',0.8) 

m = Model2D()
m.apply(m.initialize_weights)
m.zero_grad()

opt = torch.optim.Adam(m.parameters(), lr=0.001, weight_decay=0.0000001)

epochs=30
showoutput=True

[tloss,vloss]=do_training(m,opt,tloader,vloader,epochs,showoutput)

plot_progress(epochs,tloss,vloss)
plot_validation(vloader,m)

This should also work well, but training the model was probably a bit slower since the model is larger and each input (a full distance matrix instead of a 1D feature vector) is larger.

TODO: Based on the loss curves, when should we have stopped the training?

TODO: Discuss why different input features give different results:

  • Why are the results with the XYZ coordinates not as good?

  • Why can you fit a very good model with just the distances from the center?

Training models to represent more complex properties

Now that we have some initial experience with convolutional networks, we will train models to predict more interesting properties for which an ML model would be of greater practical value.

We will focus on electrostatic solvation free energies. An accurate calculation requires solving the Poisson equation, which is numerically demanding, so an ML model would offer practical advantages.

We can reuse the model classes from above since only the target data is different.

Let’s start by using the XYZ coordinates again as input:

[tloader,vloader]=get_loaders('aq15.pb.dat','aq15.input.xyz.npy',0.8) 

m = Model1D3()
m.apply(m.initialize_weights)
m.zero_grad()

opt = torch.optim.Adam(m.parameters(), lr=0.001, weight_decay=0.0000001)

epochs=30
showoutput=True

[tloss,vloss]=do_training(m,opt,tloader,vloader,epochs,showoutput)

plot_progress(epochs,tloss,vloss)
plot_validation(vloader,m)

Probably you are disappointed by the poor performance. But how bad is it?

Let’s compare with what can be achieved with a highly optimized Generalized Born (GB) model, a physics-based method for the rapid calculation of approximate electrostatic solvation free energies.

pbdata=np.loadtxt('aq15.pb.dat')
pbdata=pbdata[:,1]

gbdata=np.loadtxt('aq15.gb.dat')
gbdata=gbdata[:,1]          
    
minval=np.min(pbdata)
maxval=np.max(pbdata)
lin=np.linspace(minval-0.1*(maxval-minval),maxval+0.1*(maxval-minval),num=100)
    
plt.plot(lin,lin,'k',linewidth=2)
plt.plot(pbdata,gbdata,'ro',markersize=2)
plt.xlabel('PB')
plt.ylabel('GB')
plt.show()
    
x=pbdata.reshape((-1,1))
y=gbdata
    
linmodel=LinearRegression().fit(x,y)
    
r2=linmodel.score(x,y)
mval=linmodel.coef_[0]
nval=linmodel.intercept_
print(f'r2: {r2} slope: {mval} intercept: {nval}')    

As you can see, the GB model performs better, but it is not perfect.

The GB model is our baseline here. Can we reach the same performance (or do better?) with a ML model?

Tuning the model

To improve the model, we will try a number of different things.

First, for numerical reasons, ML models train best when the (output) data values are on the order of 1.

Let’s see if we can improve the model by shifting and scaling the target data. This does not change any information, because we can simply reverse the scaling and shift after the final output is obtained.
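For example, with offset=200 and scale=100 (as used below), the Dataset class maps a target value of \(-250\) to \((-250+200)/100 = -0.5\), and plot_validation maps the predictions back via prediction*scale - offset.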

[tloader,vloader]=get_loaders('aq15.pb.dat','aq15.input.xyz.npy',0.8,offset=200,scale=100) 

m = Model1D3()
m.apply(m.initialize_weights)
m.zero_grad()

opt = torch.optim.Adam(m.parameters(), lr=0.001, weight_decay=0.0000001)

epochs=30
showoutput=True

[tloss,vloss]=do_training(m,opt,tloader,vloader,epochs,showoutput)

plot_progress(epochs,tloss,vloss)
plot_validation(vloader,m,offset=200,scale=100)

This should make a (small) difference. But the performance is still not very good.

Can we improve the prediction by using different input data?

TODO: Try training the model again using distances from the center. Remember to change the model class.

TODO: What about using the intramolecular distances as input? Again, you will need a different Model class.

The correlation probably improved when using the atom-atom distances, but there is a practical issue: calculating the intramolecular distance matrix is computationally somewhat expensive because we have to evaluate all pairwise distances, which is an \(O(N^2)\) operation. If we want a very fast model, it would be better to use input features that can be computed more quickly.

TODO: Reconsider using XYZ coordinates, but think about how we could change the coordinates to make training more successful. Think about centering and/or rotating coordinates with respect to a reference. For this you will need to generate new input data and then rerun the model training. Hint: Check the mdtraj functions for centering and superposing a molecule!
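A minimal starting-point sketch for this TODO (the choice of the first structure as the alignment reference and the output file name are assumptions; adapt as needed):

import mdtraj as md
import numpy as np

ndata=8000
ref=md.load_pdb('aq15_pdbs/aq15.1.pdb')            # assumed reference structure for alignment
refca=ref.topology.select("name CA")

xyzlist=[]
for n in range(ndata):
    pdb=md.load_pdb(f'aq15_pdbs/aq15.{n+1}.pdb')
    calist=pdb.topology.select("name CA")
    # remove overall translation/rotation by superposing the CA atoms onto the reference
    pdb.superpose(ref, atom_indices=calist, ref_atom_indices=refca)
    xyzlist.append(pdb.xyz[0][calist].T)           # (3, N) per frame, as for the original XYZ features

data=np.array(xyzlist, dtype=np.float32)
np.save('aq15.input.xyz_aligned.npy', data)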

Depending on the results you get, decide whether to continue with the atom-atom distances or with the updated XYZ coordinates.

What is a good model? Improving the loss.

In all cases above you should notice that the predictions are biased at the most negative solvation free energies. As a result, the slope in all of the predictions is less than 1.

One way to address this issue is to change the loss function. So far we have only used the MSE (mean-squared error) loss. We can add a second loss term based on the Kullback-Leibler (KL) divergence, which compares the distributions of predicted and target values.

The KL loss function is already included in the training calculation above but it needs to be activated by giving it a non-zero weight.

The example below shows how to use it for the XYZ input data set:

[tloader,vloader]=get_loaders('aq15.pb.dat','aq15.input.xyz.npy',0.8,offset=200,scale=100) 

m = Model1D3()
m.apply(m.initialize_weights)
m.zero_grad()

opt = torch.optim.Adam(m.parameters(), lr=0.001, weight_decay=0.0000001)

epochs=30
showoutput=True

[tloss,vloss]=do_training(m,opt,tloader,vloader,epochs,showoutput,klfactor=0.4)

plot_progress(epochs,tloss,vloss)
plot_validation(vloader,m,offset=200,scale=100)

You should find that the slope improves but that the MSE loss (that is printed out) may be worse, especially if you use a large weight factor for the KL loss. One consequence is that you will find that the predictions have an offset, but that could be corrected by applying a constant offset at the end.

TODO: Experiment to find good values that improve the slope but still give a good model with low MSE and high R2 values.

Going deeper

The next step for improving the model is to try a deeper model with more (convolutional) layers and increased depths for each layer.

TODO: Design a deep version of either Model1D3 (used with updated XYZ coordinates) or Model2D (used with atom-atom distance) and retrain using your new model.

How much better is your model? Does it approach the performance of the GB model yet?
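For reference, one possible (certainly not the only) way to build a deeper 1D model is sketched below; the number of layers and the channel sizes are assumptions, and your own design may well differ. The class name Model1D3deep matches the placeholder used in the ‘Reusing models’ section below.

import torch.nn as nn
import torch.nn.functional as F

class Model1D3deep(nn.Module):
    def __init__(self):
        super(Model1D3deep, self).__init__()
        # more convolutional layers with larger depths
        self.conv_1 = nn.Conv1d(in_channels=3,  out_channels=32, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv_3 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.conv_f = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.flatten = nn.Flatten(start_dim=1)
        # 64 channels * 15 CA atoms = 960 inputs, consistent with the 45-dimensional XYZ input above
        self.fc_1 = nn.Linear(960, 128)
        self.fc_2 = nn.Linear(128, 64)
        self.fc_f = nn.Linear(64, 1)
    def forward(self, x):
        x = F.relu(self.conv_1(x))
        x = F.relu(self.conv_2(x))
        x = F.relu(self.conv_3(x))
        x = F.relu(self.conv_f(x))
        x = self.flatten(x)
        x = F.relu(self.fc_1(x))
        x = F.relu(self.fc_2(x))
        x = self.fc_f(x)
        return x
    def initialize_weights(self, m):
        # initialization of weights, setting them to zero is not good
        if hasattr(m, 'weight') and m.weight.dim() > 1:
            nn.init.xavier_uniform_(m.weight.data)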

Increasing details

You are probably now approaching the maximum accuracy that can be reached when predicting solvation free energies with a neural network based on the given input data (i.e. \(C\alpha\) coordinates).

It may be possible to do a little bit better with a different type of model (e.g. a graph neural network), but the more promising direction for further improvements is to increase the detail in the input data.

Using only \(C\alpha\) coordinates neglects the side chains, and while the model may in principle learn indirectly where the side chains are located, there is significant uncertainty associated with that, especially for extended chains.

TODO (optional): Let’s see if we can improve the model further by including additional atoms. To do that you will need to generate new data, for example XYZ coordinates for all heavy atoms, and then retrain the best model so far to see how much the performance can be improved.
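A minimal sketch for generating such data (the selection string and the output file name are assumptions; note that the model input dimensions would then need to be adjusted to the larger number of atoms):

import mdtraj as md
import numpy as np

ndata=8000
xyzlist=[]
for n in range(ndata):
    pdb=md.load_pdb(f'aq15_pdbs/aq15.{n+1}.pdb')
    heavy=pdb.topology.select("not element H")     # all heavy atoms instead of only CA
    xyzlist.append(pdb.xyz[0][heavy].T)            # (3, n_heavy) per frame

data=np.array(xyzlist, dtype=np.float32)
np.save('aq15.input.xyz_heavy.npy', data)
print(data.shape)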

Reusing models

You should now have a well-trained model that predicts solvation free energies reasonably well. In this model, the convolutional layers encode the structural information, while the linear layers are trained to predict the desired output (i.e. the electrostatic solvation free energies or radii of gyration at the beginning).

Let’s see if we can re-use the convolutional part and just re-train the output layers to target different properties.

Assuming that your last model is stored in m, let’s begin by ‘saving the model’ (i.e. the optimized weights):

torch.save(m.state_dict(),'greatmodel.dict')

Then we will create an instance of a new model (using the same Model class), load the saved weights, turn off optimization for the convolutional layers, and retrain it to target SASA (solvent-accessible surface area) values.

# this assumes that you defined a deep 1D-model for XYZ coordinates that we are reusing
# adjust if your Model class is different

msasa = Model1D3deep()
msasa.load_state_dict(torch.load('greatmodel.dict'))

# turn off optimization for all layers
for param in msasa.parameters():
    param.requires_grad = False

# turn optimization back on for linear layers
for param in msasa.fc_1.parameters():
    param.requires_grad = True
for param in msasa.fc_2.parameters():
    param.requires_grad = True
for param in msasa.fc_f.parameters():
    param.requires_grad = True

# tell the optimizer to only optimize linear layers (where requires_grad = True)
optsasa = torch.optim.Adam(filter(lambda p: p.requires_grad, msasa.parameters()), lr=0.001, weight_decay=0.0000001)

# now load SASA data, offset/scale is not needed and train
[tloader,vloader]=get_loaders('aq15.sasa.dat','aq15.input.xyz.npy',0.8) 

epochs=30
showoutput=True

# make sure to use 'msasa' and 'optsasa'
[tloss,vloss]=do_training(msasa,optsasa,tloader,vloader,epochs,showoutput)

plot_progress(epochs,tloss,vloss)
plot_validation(vloader,msasa)

TODO: For comparison, also retrain another model as before, but without reusing any of the pretrained weights.

Does it make a big difference?

TODO (optional): mdtraj has a function for calculating SASA (shrake_rupley) that uses a different algorithm than the one used for the SASA values provided in this exercise. How does the error in the ML predictions compare to the differences between the provided values and those calculated with mdtraj?
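A minimal sketch of calling mdtraj’s shrake_rupley on a single structure (summing the per-atom values to a total SASA per frame is an assumption about how the provided values were defined):

import mdtraj as md

pdb=md.load_pdb('aq15_pdbs/aq15.1.pdb')
sasa_per_atom=md.shrake_rupley(pdb)     # shape (n_frames, n_atoms), values in nm^2
total_sasa=sasa_per_atom.sum(axis=1)    # total SASA per frame
print(total_sasa)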

Final thoughts

TODO: The models you have trained probably work fairly well by now, and you could easily train new models for new properties, such as other kinds of energies, moments of inertia, rotational diffusion, etc. However, there are serious limitations in terms of transferability with the approach so far. Discuss the following questions:

  • What practical applications could you imagine with the model(s) developed so far?

  • In what way is transferability to other applications limited?

  • How could transferability issues be overcome?