vggish

VGGish neural network

    Syntax

    net = vggish

    Description

    net = vggish returns a pretrained VGGish model.

    This function requires both Audio Toolbox™ and Deep Learning Toolbox™.

    Examples


    Download and unzip the Audio Toolbox™ model for VGGish.

    Type vggish at the Command Window. If the Audio Toolbox model for VGGish is not installed, then the function provides a link to the location of the network weights. To download the model, click the link. Unzip the file to a location on the MATLAB path.

    Alternatively, execute these commands to download and unzip the VGGish model to your temporary directory.

    downloadFolder = fullfile(tempdir,'VGGishDownload');
    loc = websave(downloadFolder,'https://ssd.mathworks.com/supportfiles/audio/vggish.zip');
    VGGishLocation = tempdir;
    unzip(loc,VGGishLocation)
    addpath(fullfile(VGGishLocation,'vggish'))

    Check that the installation is successful by typing vggish at the Command Window. If the network is installed, then the function returns a SeriesNetwork (Deep Learning Toolbox) object.

    vggish
    ans = 
      SeriesNetwork with properties:
    
             Layers: [24×1 nnet.cnn.layer.Layer]
         InputNames: {'InputBatch'}
        OutputNames: {'regressionoutput'}
    
    

    Load a pretrained VGGish convolutional neural network and examine the layers and classes.

    Use vggish to load the pretrained VGGish network. The output net is a SeriesNetwork (Deep Learning Toolbox) object.

    net = vggish
    net = 
      SeriesNetwork with properties:
    
             Layers: [24×1 nnet.cnn.layer.Layer]
         InputNames: {'InputBatch'}
        OutputNames: {'regressionoutput'}
    
    

    View the network architecture using the Layers property. The network has 24 layers. There are nine layers with learnable weights, of which six are convolutional layers and three are fully connected layers.

    net.Layers
    ans = 
      24×1 Layer array with layers:
    
         1   'InputBatch'         Image Input         96×64×1 images
         2   'conv1'              Convolution         64 3×3×1 convolutions with stride [1  1] and padding 'same'
         3   'relu'               ReLU                ReLU
         4   'pool1'              Max Pooling         2×2 max pooling with stride [2  2] and padding 'same'
         5   'conv2'              Convolution         128 3×3×64 convolutions with stride [1  1] and padding 'same'
         6   'relu2'              ReLU                ReLU
         7   'pool2'              Max Pooling         2×2 max pooling with stride [2  2] and padding 'same'
         8   'conv3_1'            Convolution         256 3×3×128 convolutions with stride [1  1] and padding 'same'
         9   'relu3_1'            ReLU                ReLU
        10   'conv3_2'            Convolution         256 3×3×256 convolutions with stride [1  1] and padding 'same'
        11   'relu3_2'            ReLU                ReLU
        12   'pool3'              Max Pooling         2×2 max pooling with stride [2  2] and padding 'same'
        13   'conv4_1'            Convolution         512 3×3×256 convolutions with stride [1  1] and padding 'same'
        14   'relu4_1'            ReLU                ReLU
        15   'conv4_2'            Convolution         512 3×3×512 convolutions with stride [1  1] and padding 'same'
        16   'relu4_2'            ReLU                ReLU
        17   'pool4'              Max Pooling         2×2 max pooling with stride [2  2] and padding 'same'
        18   'fc1_1'              Fully Connected     4096 fully connected layer
        19   'relu5_1'            ReLU                ReLU
        20   'fc1_2'              Fully Connected     4096 fully connected layer
        21   'relu5_2'            ReLU                ReLU
        22   'fc2'                Fully Connected     128 fully connected layer
        23   'EmbeddingBatch'     ReLU                ReLU
        24   'regressionoutput'   Regression Output   mean-squared-error
    

    Use analyzeNetwork (Deep Learning Toolbox) to visually explore the network.

    analyzeNetwork(net)

    The VGGish network requires you to preprocess audio signals by resampling them to the 16 kHz sample rate the network was trained on and then extracting log mel spectrograms. This example walks through the preprocessing and feature extraction needed to match what was used to train VGGish. The vggishFeatures function performs these steps for you (a sketch using it appears at the end of this example).

    Read in an audio signal to extract feature embeddings from. Resample the audio signal to 16 kHz and then convert it to single precision.

    [audioIn,fs0] = audioread('Ambiance-16-44p1-mono-12secs.wav');
    
    fs = 16e3;
    audioIn = resample(audioIn,fs,fs0);
    
    audioIn = single(audioIn);

    Define mel spectrogram parameters and then extract features using the melSpectrogram function.

    FFTLength = 512;
    numBands = 64;
    frequencyRange = [125 7500];
    windowLength = 0.025*fs;
    overlapLength = 0.015*fs;
    
    melSpect = melSpectrogram(audioIn,fs, ...
        'Window',hann(windowLength,'periodic'), ...
        'OverlapLength',overlapLength, ...
        'FFTLength',FFTLength, ...
        'FrequencyRange',frequencyRange, ...
        'NumBands',numBands, ...
        'FilterBankNormalization','none', ...
        'WindowNormalization',false, ...
        'SpectrumType','magnitude', ...
        'FilterBankDesignDomain','warped');

    Convert the mel spectrogram to the log scale.

    melSpect = log(melSpect + single(0.001));

    Reorient the mel spectrogram so that time is along the first dimension as rows.

    melSpect = melSpect.';
    [numSTFTWindows,numBands] = size(melSpect)
    numSTFTWindows = 1222
    
    numBands = 64
    

    Partition the spectrogram into frames of length 96 with an overlap of 48, and place the frames along the fourth dimension. With the 1222 spectrogram frames computed above, the hop length is 96 - 48 = 48, so this yields floor((1222 - 96)/48) + 1 = 24 frames.

    frameWindowLength = 96;
    frameOverlapLength = 48;
    
    hopLength = frameWindowLength - frameOverlapLength;
    numHops = floor((numSTFTWindows - frameWindowLength)/hopLength) + 1;
    
    frames = zeros(frameWindowLength,numBands,1,numHops,'like',melSpect);
    for hop = 1:numHops
        range = 1 + hopLength*(hop-1):hopLength*(hop - 1) + frameWindowLength;
        frames(:,:,1,hop) = melSpect(range,:);
    end

    Create a VGGish network.

    net = vggish;

    Call predict to extract feature embeddings from the spectrogram images. The feature embeddings are returned as a numFrames-by-128 matrix, where numFrames is the number of individual spectrograms, and 128 is the number of elements in each feature vector.

    features = predict(net,frames);
    
    [numFrames,numFeatures] = size(features)
    numFrames = 24
    
    numFeatures = 128
    

    Compare visualizations of the mel spectrogram and the VGGish feature embeddings.

    melSpectrogram(audioIn,fs, ...
        'Window',hann(windowLength,'periodic'), ...
        'OverlapLength',overlapLength, ...
        'FFTLength',FFTLength, ...
        'FrequencyRange',frequencyRange, ...
        'NumBands',numBands, ...
        'FilterBankNormalization','none', ...
        'WindowNormalization',false, ...
        'SpectrumType','magnitude', ...
        'FilterBankDesignDomain','warped');

    figure
    surf(features,'EdgeColor','none')
    view([90,-90])
    axis([1 numFeatures 1 numFrames])
    xlabel('Feature')
    ylabel('Frame')
    title('VGGish Feature Embeddings')
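
    As noted earlier, the vggishFeatures function encapsulates this preprocessing and embedding extraction in a single call. The following is a minimal sketch assuming the same audio file and a 50% frame overlap; the 'OverlapPercentage' name-value argument is an assumption here, so check the vggishFeatures documentation for the exact interface and defaults, which may differ slightly from the manual steps above.

    % Sketch: extract VGGish embeddings directly from the audio signal.
    % The 'OverlapPercentage' name-value argument (assumed here) sets the
    % overlap between consecutive log mel spectrogram frames.
    [audioIn,fs0] = audioread('Ambiance-16-44p1-mono-12secs.wav');
    embeddings = vggishFeatures(audioIn,fs0,'OverlapPercentage',50);
    
    size(embeddings)

    The embeddings are returned as an L-by-128 matrix, comparable to the numFrames-by-128 output of predict above.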

    In this example, you transfer the learning in the VGGish regression model to an audio classification task.

    Download and unzip the environmental sound classification data set. This data set consists of recordings labeled as one of 10 different audio sound classes (ESC-10).

    url = 'http://ssd.mathworks.com/supportfiles/audio/ESC-10.zip';
    downloadFolder = fullfile(tempdir,'ESC-10');
    datasetLocation = tempdir;
    
    if ~exist(fullfile(tempdir,'ESC-10'),'dir')
        loc = websave(downloadFolder,url);
        unzip(loc,fullfile(tempdir,'ESC-10'))
    end

    Create an audioDatastore object to manage the data and split it into training and validation sets. Call countEachLabel to display the distribution of sound classes and the number of recordings in each class.

    ads = audioDatastore(downloadFolder,'IncludeSubfolders',true,'LabelSource','foldernames');
    labelTable = countEachLabel(ads)
    labelTable=10×2 table
            Label         Count
        ______________    _____
    
        chainsaw           40  
        clock_tick         40  
        crackling_fire     40  
        crying_baby        40  
        dog                40  
        helicopter         40  
        rain               40  
        rooster            38  
        sea_waves          40  
        sneezing           40  
    
    

    Determine the total number of classes.

    numClasses = size(labelTable,1);

    Call splitEachLabel to split the data set into training and validation sets. Inspect the distribution of labels in the training and validation sets.

    [adsTrain, adsValidation] = splitEachLabel(ads,0.8);
    
    countEachLabel(adsTrain)
    ans=10×2 table
            Label         Count
        ______________    _____
    
        chainsaw           32  
        clock_tick         32  
        crackling_fire     32  
        crying_baby        32  
        dog                32  
        helicopter         32  
        rain               32  
        rooster            30  
        sea_waves          32  
        sneezing           32  
    
    
    countEachLabel(adsValidation)
    ans=10×2 table
            Label         Count
        ______________    _____
    
        chainsaw            8  
        clock_tick          8  
        crackling_fire      8  
        crying_baby         8  
        dog                 8  
        helicopter          8  
        rain                8  
        rooster             8  
        sea_waves           8  
        sneezing            8  
    
    

    The VGGish network expects audio to be preprocessed into log mel spectrograms. The supporting function vggishPreprocess takes an audioDatastore object and the overlap percentage between log mel spectrograms as input, and returns matrices of predictors and responses suitable as input to the VGGish network.

    overlapPercentage = 75;
    
    [trainFeatures,trainLabels] = vggishPreprocess(adsTrain,overlapPercentage);
    [validationFeatures,validationLabels,segmentsPerFile] = vggishPreprocess(adsValidation,overlapPercentage);

    Load the VGGish model and convert it to a layerGraph (Deep Learning Toolbox) object.

    net = vggish;
    
    lgraph = layerGraph(net.Layers);

    Use removeLayers (Deep Learning Toolbox) to remove the final regression output layer from the graph. After you remove the regression layer, the new final layer of the graph is a ReLU layer named 'EmbeddingBatch'.

    lgraph = removeLayers(lgraph,'regressionoutput');
    lgraph.Layers(end)
    ans = 
      ReLULayer with properties:
    
        Name: 'EmbeddingBatch'
    
    

    Use addLayers (Deep Learning Toolbox) to add a fullyConnectedLayer (Deep Learning Toolbox), a softmaxLayer (Deep Learning Toolbox), and a classificationLayer (Deep Learning Toolbox) to the graph.

    lgraph = addLayers(lgraph,fullyConnectedLayer(numClasses,'Name','FCFinal'));
    lgraph = addLayers(lgraph,softmaxLayer('Name','softmax'));
    lgraph = addLayers(lgraph,classificationLayer('Name','classOut'));

    Use connectLayers (Deep Learning Toolbox) to connect the new fully connected, softmax, and classification layers in sequence to the end of the layer graph.

    lgraph = connectLayers(lgraph,'EmbeddingBatch','FCFinal');
    lgraph = connectLayers(lgraph,'FCFinal','softmax');
    lgraph = connectLayers(lgraph,'softmax','classOut');

    To define training options, use trainingOptions (Deep Learning Toolbox).

    miniBatchSize = 128;
    options = trainingOptions('adam', ...
        'MaxEpochs',5, ...
        'MiniBatchSize',miniBatchSize, ...
        'Shuffle','every-epoch', ...
        'ValidationData',{validationFeatures,validationLabels}, ...
        'ValidationFrequency',50, ...
        'LearnRateSchedule','piecewise', ...
        'LearnRateDropFactor',0.5, ...
        'LearnRateDropPeriod',2);

    To train the network, use trainNetwork (Deep Learning Toolbox).

    [trainedNet, netInfo] = trainNetwork(trainFeatures,trainLabels,lgraph,options);
    Training on single GPU.
    |======================================================================================================================|
    |  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
    |         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
    |======================================================================================================================|
    |       1 |           1 |       00:00:00 |       10.94% |       26.03% |       2.2253 |       2.0317 |          0.0010 |
    |       2 |          50 |       00:00:05 |       93.75% |       83.75% |       0.1884 |       0.7001 |          0.0010 |
    |       3 |         100 |       00:00:10 |       96.88% |       80.07% |       0.1150 |       0.7838 |          0.0005 |
    |       4 |         150 |       00:00:15 |       92.97% |       81.99% |       0.1656 |       0.7612 |          0.0005 |
    |       5 |         200 |       00:00:20 |       92.19% |       79.04% |       0.1738 |       0.8192 |          0.0003 |
    |       5 |         210 |       00:00:21 |       95.31% |       80.15% |       0.1389 |       0.8581 |          0.0003 |
    |======================================================================================================================|
    

    Each audio file was split into several segments to feed into the VGGish network. Combine the predictions for each file in the validation set using a majority-rule decision.

    validationPredictions = classify(trainedNet,validationFeatures);
    
    idx = 1;
    validationPredictionsPerFile = categorical;
    for ii = 1:numel(adsValidation.Files)
        validationPredictionsPerFile(ii,1) = mode(validationPredictions(idx:idx+segmentsPerFile(ii)-1));
        idx = idx + segmentsPerFile(ii);
    end

    Use confusionchart (Deep Learning Toolbox) to evaluate the performance of the network on the validation set.

    figure('Units','normalized','Position',[0.2 0.2 0.5 0.5]);
    cm = confusionchart(adsValidation.Labels,validationPredictionsPerFile);
    cm.Title = sprintf('Confusion Matrix for Validation Data \nAccuracy = %0.2f %%',mean(validationPredictionsPerFile==adsValidation.Labels)*100);
    cm.ColumnSummary = 'column-normalized';
    cm.RowSummary = 'row-normalized';

    Supporting Functions

    function [predictor,response,segmentsPerFile] = vggishPreprocess(ads,overlap)
    % This function is for example purposes only and may be changed or removed
    % in a future release.
    
    % Create filter bank
    FFTLength = 512;
    numBands = 64;
    fs0 = 16e3;
    filterBank = designAuditoryFilterBank(fs0, ...
        'FrequencyScale','mel', ...
        'FFTLength',FFTLength, ...
        'FrequencyRange',[125 7500], ...
        'NumBands',numBands, ...
        'Normalization','none', ...
        'FilterBankDesignDomain','warped');
    
    % Define STFT parameters
    windowLength = 0.025 * fs0;
    hopLength = 0.01 * fs0;
    win = hann(windowLength,'periodic');
    
    % Define spectrogram segmentation parameters
    segmentDuration = 0.96; % seconds
    segmentRate = 100; % hertz
    segmentLength = segmentDuration*segmentRate; % Number of spectra per auditory spectrogram
    segmentHopDuration = (100-overlap) * segmentDuration / 100; % Duration (s) advanced between auditory spectrograms
    segmentHopLength = round(segmentHopDuration * segmentRate); % Number of spectra advanced between auditory spectrograms
    
    % Preallocate cell arrays for the predictors and responses
    numFiles = numel(ads.Files);
    predictor = cell(numFiles,1);
    response = predictor;
    segmentsPerFile = zeros(numFiles,1);
    
    % Extract predictors and responses for each file
    for ii = 1:numFiles
        [audioIn,info] = read(ads);
    
        x = single(resample(audioIn,fs0,info.SampleRate));
    
        Y = stft(x, ...
            'Window',win, ...
            'OverlapLength',windowLength-hopLength, ...
            'FFTLength',FFTLength, ...
            'FrequencyRange','onesided');
        Y = abs(Y);
    
        logMelSpectrogram = log(filterBank*Y + single(0.01))';
        
        % Segment log-mel spectrogram
        numHops = floor((size(Y,2)-segmentLength)/segmentHopLength) + 1;
        segmentedLogMelSpectrogram = zeros(segmentLength,numBands,1,numHops);
        for hop = 1:numHops
            segmentedLogMelSpectrogram(:,:,1,hop) = logMelSpectrogram(1+segmentHopLength*(hop-1):segmentLength+segmentHopLength*(hop-1),:);
        end
    
        predictor{ii} = segmentedLogMelSpectrogram;
        response{ii} = repelem(info.Label,numHops);
        segmentsPerFile(ii) = numHops;
    end
    
    % Concatenate predictors and responses into arrays
    predictor = cat(4,predictor{:});
    response = cat(2,response{:});
    end

    Output Arguments


    net
    Pretrained VGGish neural network, returned as a SeriesNetwork (Deep Learning Toolbox) object.

    References

    [1] Gemmeke, Jort F., Daniel P. W. Ellis, Dylan Freedman, Aren Jansen, Wade Lawrence, R. Channing Moore, Manoj Plakal, and Marvin Ritter. 2017. “Audio Set: An Ontology and Human-Labeled Dataset for Audio Events.” In 2017 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 776–80. New Orleans, LA: IEEE. https://doi.org/10.1109/ICASSP.2017.7952261.

    [2] Hershey, Shawn, Sourish Chaudhuri, Daniel P. W. Ellis, Jort F. Gemmeke, Aren Jansen, R. Channing Moore, Manoj Plakal, et al. 2017. “CNN Architectures for Large-Scale Audio Classification.” In 2017 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 131–35. New Orleans, LA: IEEE. https://doi.org/10.1109/ICASSP.2017.7952132.


    Introduced in R2020b