Quantifying Classification Uncertainty in Deep Neural Networks

The purpose of this page is to provide an easy-to-run demo, with low computational requirements, of the ideas proposed in the paper Evidential Deep Learning to Quantify Classification Uncertainty. Using the MNIST dataset, I demonstrate how to create neural networks that are able to quantify classification uncertainty. The paper can be accessed at http://arxiv.org/abs/1806.01768

You can run this notebook in Colab, or download it from https://muratsensoy.github.io/uncertainty.ipynb

Neural Networks Trained with Softmax Cross Entropy Loss

The following lines of code demonstrate how softmax-based deep neural networks fail when they encounter out-of-sample queries.

In [ ]:
# use this while running this notebook in Colab
%tensorflow_version 1.x
In [5]:
#import necessary libraries
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import scipy.ndimage as nd

%matplotlib inline
import pylab as pl
from IPython import display

from tensorflow.examples.tutorials.mnist import input_data
In [5]:
# Download MNIST dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

K= 10 # number of classes
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
In [6]:
digit_one = mnist.train.images[4].copy()
plt.imshow(digit_one.reshape(28,28)) 
plt.show()
In [7]:
# define some utility functions
def var(name, shape, init=None):
    if init is None:
        init = tf.truncated_normal_initializer(stddev=(2/shape[0])**0.5)
    return tf.get_variable(name=name, shape=shape, dtype=tf.float32,
                          initializer=init)

def conv(Xin, f, strides=[1, 1, 1, 1], padding='SAME'):
    return tf.nn.conv2d(Xin, f, strides, padding)

def max_pool(Xin, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME'):
    return tf.nn.max_pool(Xin, ksize, strides, padding)

def rotate_img(x, deg):
    return nd.rotate(x.reshape(28,28),deg,reshape=False).ravel()
In [92]:
# Create a LeNet network with softmax cross entropy loss function
def LeNet_softmax(lmb=0.005): 
    g = tf.Graph()
    with g.as_default():
        X = tf.placeholder(shape=[None,28*28], dtype=tf.float32)
        Y = tf.placeholder(shape=[None,10], dtype=tf.float32)
        keep_prob = tf.placeholder(dtype=tf.float32)
        
        # first hidden layer - conv
        W1 = var('W1', [5,5,1,20])
        b1 = var('b1', [20])
        out1 = max_pool(tf.nn.relu(conv(tf.reshape(X, [-1, 28,28, 1]), 
                                        W1, strides=[1, 1, 1, 1]) + b1))
        # second hidden layer - conv
        W2 = var('W2', [5,5,20,50])
        b2 = var('b2', [50])
        out2 = max_pool(tf.nn.relu(conv(out1, W2, strides=[1, 1, 1, 1]) + b2))
        # flatten the output
        Xflat = tf.contrib.layers.flatten(out2)
        # third hidden layer - fully connected
        W3 = var('W3', [Xflat.get_shape()[1].value, 500])
        b3 = var('b3', [500]) 
        out3 = tf.nn.relu(tf.matmul(Xflat, W3) + b3)
        out3 = tf.nn.dropout(out3, keep_prob=keep_prob)
        #output layer
        W4 = var('W4', [500,10])
        b4 = var('b4',[10])
        logits = tf.matmul(out3, W4) + b4
        
        prob = tf.nn.softmax(logits=logits) 
        
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
        l2_loss = (tf.nn.l2_loss(W3)+tf.nn.l2_loss(W4)) * lmb
        
        step = tf.train.AdamOptimizer().minimize(loss + l2_loss)
        
        # Calculate accuracy
        pred = tf.argmax(logits, 1)
        truth = tf.argmax(Y, 1)
        acc = tf.reduce_mean(tf.cast(tf.equal(pred, truth), tf.float32))
        
        return g, step, X, Y, keep_prob, prob, acc, loss
In [93]:
# get the LeNet network
g1, step1, X1, Y1, keep_prob1, prob1, acc1, loss1 = LeNet_softmax()
In [169]:
sess1 = tf.Session(graph=g1)
with g1.as_default(): 
    sess1.run(tf.global_variables_initializer())
In [170]:
bsize = 1000 #batch size
n_batches = mnist.train.num_examples // bsize
for epoch in range(50):   
    for i in range(n_batches):
        data, label = mnist.train.next_batch(bsize)
        feed_dict={X1:data, Y1:label, keep_prob1:.5}
        sess1.run(step1,feed_dict)
        print('epoch %d - %d%%) '% (epoch+1, (100*(i+1))//n_batches), end='\r' if i<n_batches-1 else '')
        
    train_acc = sess1.run(acc1, feed_dict={X1:mnist.train.images,Y1:mnist.train.labels,keep_prob1:1.})
    test_acc = sess1.run(acc1, feed_dict={X1:mnist.test.images,Y1:mnist.test.labels,keep_prob1:1.})
    
    print('training accuracy: %2.4f \t testing accuracy: %2.4f' % (train_acc, test_acc))
epoch 1 - 100%) training accuracy: 0.9235 	 testing accuracy: 0.9255
epoch 2 - 100%) training accuracy: 0.9514 	 testing accuracy: 0.9510
epoch 3 - 100%) training accuracy: 0.9648 	 testing accuracy: 0.9644
epoch 4 - 100%) training accuracy: 0.9699 	 testing accuracy: 0.9685
epoch 5 - 100%) training accuracy: 0.9765 	 testing accuracy: 0.9732
epoch 6 - 100%) training accuracy: 0.9796 	 testing accuracy: 0.9744
epoch 7 - 100%) training accuracy: 0.9800 	 testing accuracy: 0.9766
epoch 8 - 100%) training accuracy: 0.9816 	 testing accuracy: 0.9767
epoch 9 - 100%) training accuracy: 0.9824 	 testing accuracy: 0.9787
epoch 10 - 100%) training accuracy: 0.9863 	 testing accuracy: 0.9812
epoch 11 - 100%) training accuracy: 0.9862 	 testing accuracy: 0.9822
epoch 12 - 100%) training accuracy: 0.9873 	 testing accuracy: 0.9831
epoch 13 - 100%) training accuracy: 0.9866 	 testing accuracy: 0.9816
epoch 14 - 100%) training accuracy: 0.9875 	 testing accuracy: 0.9826
epoch 15 - 100%) training accuracy: 0.9883 	 testing accuracy: 0.9837
epoch 16 - 100%) training accuracy: 0.9895 	 testing accuracy: 0.9850
epoch 17 - 100%) training accuracy: 0.9892 	 testing accuracy: 0.9838
epoch 18 - 100%) training accuracy: 0.9877 	 testing accuracy: 0.9838
epoch 19 - 100%) training accuracy: 0.9899 	 testing accuracy: 0.9851
epoch 20 - 100%) training accuracy: 0.9903 	 testing accuracy: 0.9866
epoch 21 - 100%) training accuracy: 0.9901 	 testing accuracy: 0.9868
epoch 22 - 100%) training accuracy: 0.9909 	 testing accuracy: 0.9864
epoch 23 - 100%) training accuracy: 0.9915 	 testing accuracy: 0.9855
epoch 24 - 100%) training accuracy: 0.9914 	 testing accuracy: 0.9856
epoch 25 - 100%) training accuracy: 0.9904 	 testing accuracy: 0.9837
epoch 26 - 100%) training accuracy: 0.9913 	 testing accuracy: 0.9873
epoch 27 - 100%) training accuracy: 0.9931 	 testing accuracy: 0.9892
epoch 28 - 100%) training accuracy: 0.9910 	 testing accuracy: 0.9874
epoch 29 - 100%) training accuracy: 0.9920 	 testing accuracy: 0.9877
epoch 30 - 100%) training accuracy: 0.9923 	 testing accuracy: 0.9881
epoch 31 - 100%) training accuracy: 0.9922 	 testing accuracy: 0.9870
epoch 32 - 100%) training accuracy: 0.9934 	 testing accuracy: 0.9886
epoch 33 - 100%) training accuracy: 0.9937 	 testing accuracy: 0.9893
epoch 34 - 100%) training accuracy: 0.9946 	 testing accuracy: 0.9904
epoch 35 - 100%) training accuracy: 0.9937 	 testing accuracy: 0.9890
epoch 36 - 100%) training accuracy: 0.9933 	 testing accuracy: 0.9874
epoch 37 - 100%) training accuracy: 0.9944 	 testing accuracy: 0.9883
epoch 38 - 100%) training accuracy: 0.9944 	 testing accuracy: 0.9890
epoch 39 - 100%) training accuracy: 0.9948 	 testing accuracy: 0.9895
epoch 40 - 100%) training accuracy: 0.9947 	 testing accuracy: 0.9896
epoch 41 - 100%) training accuracy: 0.9935 	 testing accuracy: 0.9893
epoch 42 - 100%) training accuracy: 0.9943 	 testing accuracy: 0.9898
epoch 43 - 100%) training accuracy: 0.9936 	 testing accuracy: 0.9884
epoch 44 - 100%) training accuracy: 0.9945 	 testing accuracy: 0.9901
epoch 45 - 100%) training accuracy: 0.9945 	 testing accuracy: 0.9887
epoch 46 - 100%) training accuracy: 0.9949 	 testing accuracy: 0.9908
epoch 47 - 100%) training accuracy: 0.9950 	 testing accuracy: 0.9893
epoch 48 - 100%) training accuracy: 0.9949 	 testing accuracy: 0.9896
epoch 49 - 100%) training accuracy: 0.9940 	 testing accuracy: 0.9891
epoch 50 - 100%) training accuracy: 0.9949 	 testing accuracy: 0.9886

The test accuracy after 50 epochs is around 98.9%. Now, we want to classify a rotating digit from the MNIST dataset to see how this network behaves on samples that are not from the training distribution. The following lines of code help us see this.

In [322]:
# This method rotates an image counter-clockwise and classifies it at different degrees of rotation.
# It plots the highest classification probability along with the class label for each rotation degree.
def rotating_image_classification(img, sess, prob, X, keep_prob, uncertainty=None, threshold=0.5):
    Mdeg = 180 
    Ndeg = int(Mdeg/10)+1
    ldeg = []
    lp = []
    lu=[]
    scores = np.zeros((1,K))
    rimgs = np.zeros((28,28*Ndeg))
    for i,deg in enumerate(np.linspace(0,Mdeg, Ndeg)):
        nimg = rotate_img(img,deg).reshape(28,28)
        nimg = np.clip(a=nimg,a_min=0,a_max=1)
        rimgs[:,i*28:(i+1)*28] = nimg
        feed_dict={X:nimg.reshape(1,-1), keep_prob:1.0}
        if uncertainty is None:
            p_pred_t = sess.run(prob, feed_dict=feed_dict)
        else:
            p_pred_t,u = sess.run([prob,uncertainty], feed_dict=feed_dict)
            lu.append(u.mean())
        scores += p_pred_t >= threshold
        ldeg.append(deg) 
        lp.append(p_pred_t[0])
    
    labels = np.arange(10)[scores[0].astype(bool)]
    lp = np.array(lp)[:,labels]
    c = ['black','blue','red','brown','purple','cyan']
    marker = ['s','^','o']*2
    labels = labels.tolist()
    for i in range(len(labels)):
        plt.plot(ldeg,lp[:,i],marker=marker[i],c=c[i])
    
    if uncertainty is not None:
        labels += ['uncertainty']
        plt.plot(ldeg,lu,marker='<',c='red')
        
    plt.legend(labels)
 
    plt.xlim([0,Mdeg])  
    plt.xlabel('Rotation Degree')
    plt.ylabel('Classification Probability')
    plt.show()

    plt.figure(figsize=[6.2,100])
    plt.imshow(1-rimgs,cmap='gray')
    plt.axis('off')
    plt.show()
In [323]:
rotating_image_classification(digit_one, sess1, prob1, X1, keep_prob1)

As shown above, a neural network trained to produce softmax probabilities fails significantly when it encounters a sample that differs from the training examples. The softmax forces the network to pick one class, even if the object belongs to an unknown category; the network still assigns high probability to some class, even though the rotated image no longer looks like any digit it was trained on. This is demonstrated when we rotate the digit one between 60 and 130 degrees.
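
To make this concrete, here is a minimal numpy sketch (not part of the original notebook; the logit values are made up) of why softmax outputs can look confident even for unfamiliar inputs: whatever the logits are, softmax normalizes them into a probability distribution, and a moderate gap between the largest logit and the rest already yields a high maximum probability.

import numpy as np

def softmax(z):
    z = z - z.max()          # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

# hypothetical logits for an out-of-distribution input
logits = np.array([2.1, 0.3, 5.0, 0.1, 0.0, 0.2, 0.4, 0.1, 0.3, 0.2])
p = softmax(logits)
print(np.round(p, 3), p.max())   # the largest probability is close to 0.9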

Classification with Evidential Deep Learning

In the following sections, we train the same neural network using the loss functions introduced in the paper.

Using the Expected Mean Square Error (Eq. 5)

As described in the paper, a neural network can be trained to learn the parameters of a Dirichlet distribution instead of softmax probabilities. A Dirichlet distribution with parameters $\alpha \geq 1$ behaves like a generative model over softmax probabilities (categorical distributions): it associates a likelihood value with each possible categorical distribution.
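
As a minimal numpy sketch of this idea (the evidence vector below is made up; the same quantities are computed by the TensorFlow graph later in this notebook): per-class evidence $e_k$ gives Dirichlet parameters $\alpha_k = e_k + 1$, the expected class probabilities (the Dirichlet mean) are $\alpha_k / S$ with $S = \sum_k \alpha_k$, and the uncertainty mass is $u = K / S$.

import numpy as np

K = 10                                                     # number of classes
ev = np.array([0., 0., 38., 0., 0., 1., 0., 0., 2., 0.])   # hypothetical evidence vector
alpha = ev + 1                                             # Dirichlet parameters
S = alpha.sum()                                            # Dirichlet strength
p_expected = alpha / S                                     # Dirichlet mean (expected probabilities)
u = K / S                                                  # uncertainty mass (1.0 when evidence is all zero)
print(np.round(p_expected, 3), np.round(u, 3))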

Some functions to convert logits to evidence

In [2]:
# This function to generate evidence is used for the first example
def relu_evidence(logits):
    return tf.nn.relu(logits)

# This one usually works better and is used for the second and third examples
# For general settings and different datasets, you may try this one first
def exp_evidence(logits): 
    return tf.exp(tf.clip_by_value(logits,-10,10))

# This one is another alternative and 
# usually behaves better than the relu_evidence 
def softplus_evidence(logits):
    return tf.nn.softplus(logits)

Define the loss function

In [201]:
def KL(alpha):
    beta=tf.constant(np.ones((1,K)),dtype=tf.float32)
    S_alpha = tf.reduce_sum(alpha,axis=1,keep_dims=True)
    S_beta = tf.reduce_sum(beta,axis=1,keep_dims=True)
    lnB = tf.lgamma(S_alpha) - tf.reduce_sum(tf.lgamma(alpha),axis=1,keep_dims=True)
    lnB_uni = tf.reduce_sum(tf.lgamma(beta),axis=1,keep_dims=True) - tf.lgamma(S_beta)
    
    dg0 = tf.digamma(S_alpha)
    dg1 = tf.digamma(alpha)
    
    kl = tf.reduce_sum((alpha - beta)*(dg1-dg0),axis=1,keep_dims=True) + lnB + lnB_uni
    return kl

def mse_loss(p, alpha, global_step, annealing_step): 
    S = tf.reduce_sum(alpha, axis=1, keep_dims=True) 
    E = alpha - 1
    m = alpha / S
    
    A = tf.reduce_sum((p-m)**2, axis=1, keep_dims=True) 
    B = tf.reduce_sum(alpha*(S-alpha)/(S*S*(S+1)), axis=1, keep_dims=True) 
    
    annealing_coef = tf.minimum(1.0,tf.cast(global_step/annealing_step,tf.float32))
    
    alp = E*(1-p) + 1 
    C =  annealing_coef * KL(alp)
    return (A + B) + C
In [199]:
# Create the LeNet network trained with the evidential loss (default: expected mean square error)
def LeNet_EDL(logits2evidence=relu_evidence,loss_function=mse_loss, lmb=0.005):
    g = tf.Graph()
    with g.as_default():
        X = tf.placeholder(shape=[None,28*28], dtype=tf.float32)
        Y = tf.placeholder(shape=[None,10], dtype=tf.float32)
        keep_prob = tf.placeholder(dtype=tf.float32)
        global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)
        annealing_step = tf.placeholder(dtype=tf.int32) 
    
        # first hidden layer - conv
        W1 = var('W1', [5,5,1,20])
        b1 = var('b1', [20])
        out1 = max_pool(tf.nn.relu(conv(tf.reshape(X, [-1, 28,28, 1]), 
                                        W1, strides=[1, 1, 1, 1]) + b1))
        # second hidden layer - conv
        W2 = var('W2', [5,5,20,50])
        b2 = var('b2', [50])
        out2 = max_pool(tf.nn.relu(conv(out1, W2, strides=[1, 1, 1, 1]) + b2))
        # flatten the output
        Xflat = tf.contrib.layers.flatten(out2)
        # third hidden layer - fully connected
        W3 = var('W3', [Xflat.get_shape()[1].value, 500])
        b3 = var('b3', [500]) 
        out3 = tf.nn.relu(tf.matmul(Xflat, W3) + b3)
        out3 = tf.nn.dropout(out3, keep_prob=keep_prob)
        #output layer
        W4 = var('W4', [500,10])
        b4 = var('b4',[10])
        logits = tf.matmul(out3, W4) + b4
        
        evidence = logits2evidence(logits)
        alpha = evidence + 1
        
        u = K / tf.reduce_sum(alpha, axis=1, keep_dims=True) #uncertainty
        
        prob = alpha/tf.reduce_sum(alpha, 1, keepdims=True) 
        
        loss = tf.reduce_mean(loss_function(Y, alpha, global_step, annealing_step))
        l2_loss = (tf.nn.l2_loss(W3)+tf.nn.l2_loss(W4)) * lmb
        
        step = tf.train.AdamOptimizer().minimize(loss + l2_loss, global_step=global_step)
        
        # Calculate accuracy
        pred = tf.argmax(logits, 1)
        truth = tf.argmax(Y, 1)
        match = tf.reshape(tf.cast(tf.equal(pred, truth), tf.float32),(-1,1))
        acc = tf.reduce_mean(match)
        
        total_evidence = tf.reduce_sum(evidence,1, keepdims=True) 
        mean_ev = tf.reduce_mean(total_evidence)
        mean_ev_succ = tf.reduce_sum(tf.reduce_sum(evidence,1, keepdims=True)*match) / tf.reduce_sum(match+1e-20)
        mean_ev_fail = tf.reduce_sum(tf.reduce_sum(evidence,1, keepdims=True)*(1-match)) / (tf.reduce_sum(tf.abs(1-match))+1e-20) 
        
        return g, step, X, Y, annealing_step, keep_prob, prob, acc, loss, u, evidence, mean_ev, mean_ev_succ, mean_ev_fail
In [156]:
g2, step2, X2, Y2, annealing_step, keep_prob2, prob2, acc2, loss2, u, evidence, \
    mean_ev, mean_ev_succ, mean_ev_fail= LeNet_EDL()
In [172]:
sess2 = tf.Session(graph=g2)
with g2.as_default():
    sess2.run(tf.global_variables_initializer())
In [173]:
bsize = 1000 #batch size
n_batches = mnist.train.num_examples // bsize
L_train_acc1=[]
L_train_ev_s=[]
L_train_ev_f=[]
L_test_acc1=[]
L_test_ev_s=[]
L_test_ev_f=[]
for epoch in range(50):   
    for i in range(n_batches):
        data, label = mnist.train.next_batch(bsize)
        feed_dict={X2:data, Y2:label, keep_prob2:.5, annealing_step:10*n_batches}
        sess2.run(step2,feed_dict)
        print('epoch %d - %d%%) '% (epoch+1, (100*(i+1))//n_batches), end='\r' if i<n_batches-1 else '')
        
    train_acc, train_succ, train_fail = sess2.run([acc2,mean_ev_succ,mean_ev_fail], feed_dict={X2:mnist.train.images,Y2:mnist.train.labels,keep_prob2:1.})
    test_acc, test_succ, test_fail = sess2.run([acc2,mean_ev_succ,mean_ev_fail], feed_dict={X2:mnist.test.images,Y2:mnist.test.labels,keep_prob2:1.})
    
    L_train_acc1.append(train_acc)
    L_train_ev_s.append(train_succ)
    L_train_ev_f.append(train_fail)
    
    L_test_acc1.append(test_acc)
    L_test_ev_s.append(test_succ)
    L_test_ev_f.append(test_fail)
    
    print('training: %2.4f (%2.4f - %2.4f) \t testing: %2.4f (%2.4f - %2.4f)' % 
          (train_acc, train_succ, train_fail, test_acc, test_succ, test_fail))
epoch 1 - 100%) training: 0.9470 (29.9496 - 6.0470) 	 testing: 0.9525 (30.3331 - 6.2809)
epoch 2 - 100%) training: 0.9674 (33.9131 - 6.2303) 	 testing: 0.9695 (34.4653 - 7.0255)
epoch 3 - 100%) training: 0.9756 (33.1459 - 4.3576) 	 testing: 0.9762 (33.7443 - 4.1752)
epoch 4 - 100%) training: 0.9745 (33.9431 - 3.5602) 	 testing: 0.9749 (34.4402 - 3.7496)
epoch 5 - 100%) training: 0.9807 (36.8166 - 4.0170) 	 testing: 0.9789 (37.4320 - 4.1132)
epoch 6 - 100%) training: 0.9791 (36.6413 - 3.0833) 	 testing: 0.9803 (37.2622 - 3.1119)
epoch 7 - 100%) training: 0.9782 (39.2723 - 3.2930) 	 testing: 0.9778 (40.0590 - 3.1863)
epoch 8 - 100%) training: 0.9808 (37.4109 - 1.9068) 	 testing: 0.9800 (38.1145 - 2.2643)
epoch 9 - 100%) training: 0.9815 (39.3951 - 2.7377) 	 testing: 0.9805 (40.1664 - 2.9485)
epoch 10 - 100%) training: 0.9831 (40.2205 - 2.0345) 	 testing: 0.9830 (40.8734 - 2.4930)
epoch 11 - 100%) training: 0.9839 (40.3787 - 1.5182) 	 testing: 0.9840 (41.1746 - 1.7158)
epoch 12 - 100%) training: 0.9833 (39.9269 - 1.6233) 	 testing: 0.9836 (40.7200 - 2.0022)
epoch 13 - 100%) training: 0.9844 (42.0865 - 1.8232) 	 testing: 0.9835 (42.9469 - 2.2165)
epoch 14 - 100%) training: 0.9816 (40.6318 - 1.4549) 	 testing: 0.9811 (41.4050 - 2.1860)
epoch 15 - 100%) training: 0.9851 (44.5367 - 2.1988) 	 testing: 0.9842 (45.3868 - 3.0271)
epoch 16 - 100%) training: 0.9837 (44.3451 - 1.9175) 	 testing: 0.9829 (45.0711 - 3.2015)
epoch 17 - 100%) training: 0.9839 (47.7866 - 2.7627) 	 testing: 0.9824 (48.6517 - 3.8829)
epoch 18 - 100%) training: 0.9856 (44.8029 - 1.7993) 	 testing: 0.9855 (45.7763 - 3.0277)
epoch 19 - 100%) training: 0.9841 (44.6586 - 1.8779) 	 testing: 0.9834 (45.5720 - 3.4869)
epoch 20 - 100%) training: 0.9875 (45.7646 - 1.8881) 	 testing: 0.9877 (46.7228 - 3.3131)
epoch 21 - 100%) training: 0.9866 (46.3577 - 1.9462) 	 testing: 0.9861 (47.1481 - 3.0685)
epoch 22 - 100%) training: 0.9861 (46.5912 - 1.9597) 	 testing: 0.9863 (47.3737 - 2.7471)
epoch 23 - 100%) training: 0.9869 (48.9247 - 2.2133) 	 testing: 0.9867 (49.9383 - 2.7514)
epoch 24 - 100%) training: 0.9870 (46.9884 - 2.0665) 	 testing: 0.9855 (48.0439 - 2.1623)
epoch 25 - 100%) training: 0.9873 (50.8303 - 2.4647) 	 testing: 0.9854 (51.8676 - 3.1680)
epoch 26 - 100%) training: 0.9880 (49.6770 - 2.3419) 	 testing: 0.9879 (50.5636 - 3.7460)
epoch 27 - 100%) training: 0.9871 (49.9567 - 2.2482) 	 testing: 0.9862 (51.1154 - 3.4443)
epoch 28 - 100%) training: 0.9877 (50.9868 - 2.4529) 	 testing: 0.9869 (52.1459 - 3.8382)
epoch 29 - 100%) training: 0.9883 (46.9654 - 1.3623) 	 testing: 0.9873 (47.9654 - 2.1033)
epoch 30 - 100%) training: 0.9882 (51.7587 - 2.4527) 	 testing: 0.9871 (52.7797 - 4.4101)
epoch 31 - 100%) training: 0.9883 (52.9645 - 3.1079) 	 testing: 0.9873 (54.1205 - 3.6993)
epoch 32 - 100%) training: 0.9881 (51.1556 - 2.3109) 	 testing: 0.9875 (52.0979 - 4.0251)
epoch 33 - 100%) training: 0.9880 (48.3095 - 1.5548) 	 testing: 0.9874 (49.4096 - 2.0053)
epoch 34 - 100%) training: 0.9885 (51.3298 - 2.3731) 	 testing: 0.9864 (52.2786 - 3.8313)
epoch 35 - 100%) training: 0.9892 (53.0268 - 2.6064) 	 testing: 0.9875 (54.1511 - 3.8138)
epoch 36 - 100%) training: 0.9902 (50.2909 - 1.8811) 	 testing: 0.9884 (51.5063 - 2.5316)
epoch 37 - 100%) training: 0.9885 (51.8024 - 2.0432) 	 testing: 0.9879 (52.7857 - 2.8535)
epoch 38 - 100%) training: 0.9888 (51.6903 - 1.5005) 	 testing: 0.9881 (52.7587 - 2.3378)
epoch 39 - 100%) training: 0.9896 (54.4732 - 1.9507) 	 testing: 0.9876 (55.7110 - 2.9229)
epoch 40 - 100%) training: 0.9885 (49.8664 - 1.7354) 	 testing: 0.9866 (50.7666 - 2.3627)
epoch 41 - 100%) training: 0.9893 (54.7296 - 2.3521) 	 testing: 0.9878 (55.9593 - 3.4784)
epoch 42 - 100%) training: 0.9888 (55.0189 - 2.9684) 	 testing: 0.9884 (55.9647 - 3.5898)
epoch 43 - 100%) training: 0.9907 (54.7551 - 2.2131) 	 testing: 0.9887 (55.9601 - 4.9516)
epoch 44 - 100%) training: 0.9889 (55.8486 - 2.6489) 	 testing: 0.9880 (57.0468 - 4.3089)
epoch 45 - 100%) training: 0.9895 (56.3373 - 2.6319) 	 testing: 0.9888 (57.3626 - 4.7427)
epoch 46 - 100%) training: 0.9886 (52.5418 - 1.4541) 	 testing: 0.9882 (53.6143 - 2.9547)
epoch 47 - 100%) training: 0.9908 (53.2042 - 1.7521) 	 testing: 0.9876 (54.3501 - 2.6650)
epoch 48 - 100%) training: 0.9893 (57.3922 - 2.6419) 	 testing: 0.9877 (58.5497 - 4.1028)
epoch 49 - 100%) training: 0.9889 (55.2722 - 2.0515) 	 testing: 0.9876 (56.4727 - 2.7934)
epoch 50 - 100%) training: 0.9887 (55.1349 - 2.0851) 	 testing: 0.9880 (56.2424 - 2.6495)

The following function plots the average total evidence and the prediction uncertainty, in addition to the accuracy, for the training and test sets. Note that the uncertainty approaches 1.0 as the total evidence approaches 0.
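
Concretely, since $u = K/S$ and $S = \sum_k \alpha_k = K + e_{\text{total}}$, the uncertainty plotted below is $u = K/(K + e_{\text{total}})$: for example, $e_{\text{total}} = 0$ gives $u = 1$, while $e_{\text{total}} = 90$ with $K = 10$ gives $u = 0.1$.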

In [195]:
def draw_EDL_results(train_acc1, train_ev_s, train_ev_f, test_acc1, test_ev_s, test_ev_f): 
    # calculate uncertainty for correctly classified and misclassified samples in the training and test data
    train_u_succ = K / (K+np.array(train_ev_s))
    train_u_fail = K / (K+np.array(train_ev_f))
    test_u_succ  = K / (K+np.array(test_ev_s))
    test_u_fail  = K / (K+np.array(test_ev_f))
    
    f, axs = pl.subplots(2, 2)
    f.set_size_inches([10,10])
    
    axs[0,0].plot(train_ev_s,c='r',marker='+')
    axs[0,0].plot(train_ev_f,c='k',marker='x')
    axs[0,0].set_title('Train Data')
    axs[0,0].set_xlabel('Epoch')
    axs[0,0].set_ylabel('Estimated total evidence for classification') 
    axs[0,0].legend(['Correct Classifications','Misclassifications'])
    
    
    axs[0,1].plot(train_u_succ,c='r',marker='+')
    axs[0,1].plot(train_u_fail,c='k',marker='x')
    axs[0,1].plot(train_acc1,c='blue',marker='*')
    axs[0,1].set_title('Train Data')
    axs[0,1].set_xlabel('Epoch')
    axs[0,1].set_ylabel('Estimated uncertainty for classification')
    axs[0,1].legend(['Correct classifications','Misclassifications', 'Accuracy'])
    
    axs[1,0].plot(test_ev_s,c='r',marker='+')
    axs[1,0].plot(test_ev_f,c='k',marker='x')
    axs[1,0].set_title('Test Data')
    axs[1,0].set_xlabel('Epoch')
    axs[1,0].set_ylabel('Estimated total evidence for classification') 
    axs[1,0].legend(['Correct Classifications','Misclassifications'])
    
    
    axs[1,1].plot(test_u_succ,c='r',marker='+')
    axs[1,1].plot(test_u_fail,c='k',marker='x')
    axs[1,1].plot(test_acc1,c='blue',marker='*')
    axs[1,1].set_title('Test Data')
    axs[1,1].set_xlabel('Epoch')
    axs[1,1].set_ylabel('Estimated uncertainty for classification')
    axs[1,1].legend(['Correct classifications','Misclassifications', 'Accuracy'])
    
    plt.show()
In [196]:
draw_EDL_results(L_train_acc1, L_train_ev_s, L_train_ev_f, L_test_acc1, L_test_ev_s, L_test_ev_f)

The figure above indicates that the proposed approach generates a much smaller amount of evidence for misclassified samples than for correctly classified ones. The uncertainty of the misclassified samples is around 0.8, while it is around 0.1 for the correctly classified ones, on both the training and test sets. This means that the neural network is very uncertain about the misclassified samples and provides confident predictions only for the correctly classified ones. In other words, the network signals its own failures by assigning high uncertainty to its wrong predictions.

In [174]:
rotating_image_classification(digit_one, sess2, prob2, X2, keep_prob2, u)

Using the Expected Cross Entropy (Eq. 4)

In this section, we train the neural network using the loss function described in Eq. 4 of the paper. This loss function is derived by taking the expected value of the cross entropy loss over the predicted Dirichlet distribution.
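
For reference, the data term computed by loss_EDL(tf.digamma) in the cell below is $\sum_{j=1}^{K} y_j\,(\psi(S) - \psi(\alpha_j))$, where $\psi$ is the digamma function and $S = \sum_j \alpha_j$; as with the MSE loss above, an annealed KL divergence term towards the uniform Dirichlet is added as a regularizer.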

In [219]:
def loss_EDL(func=tf.digamma):
    def loss_func(p, alpha, global_step, annealing_step): 
        S = tf.reduce_sum(alpha, axis=1, keep_dims=True) 
        E = alpha - 1
    
        A = tf.reduce_sum(p * (func(S) - func(alpha)), axis=1, keepdims=True)
    
        annealing_coef = tf.minimum(1.0, tf.cast(global_step/annealing_step,tf.float32))
    
        alp = E*(1-p) + 1 
        B =  annealing_coef * KL(alp)
    
        return (A + B)
    return loss_func
In [210]:
g3, step3, X3, Y3, annealing_step3, keep_prob3, prob3, acc3, loss3, u3, evidence3, \
    mean_ev3, mean_ev_succ3, mean_ev_fail3 = LeNet_EDL(exp_evidence, loss_EDL(tf.digamma), lmb=0.001)
In [211]:
sess3 = tf.Session(graph=g3)
with g3.as_default():
    sess3.run(tf.global_variables_initializer())
In [212]:
bsize = 1000 #batch size
n_batches = mnist.train.num_examples // bsize
L3_train_acc1=[]
L3_train_ev_s=[]
L3_train_ev_f=[]
L3_test_acc1=[]
L3_test_ev_s=[]
L3_test_ev_f=[]
for epoch in range(50):   
    for i in range(n_batches):
        data, label = mnist.train.next_batch(bsize)
        feed_dict={X3:data, Y3:label, keep_prob3:.5, annealing_step3:10*n_batches}
        sess3.run(step3,feed_dict)
        print('epoch %d - %d%%) '% (epoch+1, (100*(i+1))//n_batches), end='\r' if i<n_batches-1 else '')
        
    train_acc, train_succ, train_fail = sess3.run([acc3,mean_ev_succ3,mean_ev_fail3], feed_dict={X3:mnist.train.images,Y3:mnist.train.labels,keep_prob3:1.})
    test_acc, test_succ, test_fail = sess3.run([acc3,mean_ev_succ3,mean_ev_fail3], feed_dict={X3:mnist.test.images,Y3:mnist.test.labels,keep_prob3:1.})
    
    L3_train_acc1.append(train_acc)
    L3_train_ev_s.append(train_succ)
    L3_train_ev_f.append(train_fail)
    
    L3_test_acc1.append(test_acc)
    L3_test_ev_s.append(test_succ)
    L3_test_ev_f.append(test_fail)
    
    print('training: %2.4f (%2.4f - %2.4f) \t testing: %2.4f (%2.4f - %2.4f)' % 
          (train_acc, train_succ, train_fail, test_acc, test_succ, test_fail))
epoch 1 - 100%) training: 0.8303 (42.6223 - 16.4772) 	 testing: 0.8412 (42.8285 - 16.6934)
epoch 2 - 100%) training: 0.8992 (123.3957 - 11.1870) 	 testing: 0.9062 (122.5605 - 11.3092)
epoch 3 - 100%) training: 0.9206 (208.4110 - 9.0025) 	 testing: 0.9272 (203.0646 - 9.2725)
epoch 4 - 100%) training: 0.9413 (264.3557 - 8.2034) 	 testing: 0.9446 (259.2374 - 8.0735)
epoch 5 - 100%) training: 0.9485 (298.0337 - 6.7932) 	 testing: 0.9540 (297.0534 - 6.4097)
epoch 6 - 100%) training: 0.9526 (300.2455 - 5.8476) 	 testing: 0.9573 (301.2290 - 5.1179)
epoch 7 - 100%) training: 0.9595 (536.3224 - 6.1392) 	 testing: 0.9631 (548.7296 - 5.8329)
epoch 8 - 100%) training: 0.9632 (616.2153 - 5.7967) 	 testing: 0.9669 (642.6508 - 5.4167)
epoch 9 - 100%) training: 0.9664 (691.5225 - 5.5155) 	 testing: 0.9705 (711.5176 - 4.6776)
epoch 10 - 100%) training: 0.9671 (743.4854 - 4.3620) 	 testing: 0.9693 (765.5163 - 4.1878)
epoch 11 - 100%) training: 0.9695 (1386.9512 - 6.1451) 	 testing: 0.9735 (1369.8014 - 5.8578)
epoch 12 - 100%) training: 0.9723 (1498.5973 - 6.5162) 	 testing: 0.9747 (1543.9032 - 6.2582)
epoch 13 - 100%) training: 0.9717 (1548.7797 - 4.6337) 	 testing: 0.9743 (1684.6056 - 4.7657)
epoch 14 - 100%) training: 0.9735 (1498.0391 - 5.4553) 	 testing: 0.9757 (1613.8787 - 5.6297)
epoch 15 - 100%) training: 0.9759 (1493.4908 - 5.2220) 	 testing: 0.9779 (1563.8809 - 4.5270)
epoch 16 - 100%) training: 0.9764 (2248.6760 - 5.7876) 	 testing: 0.9777 (2307.0652 - 5.2200)
epoch 17 - 100%) training: 0.9769 (1995.6049 - 4.3072) 	 testing: 0.9804 (2252.7351 - 4.1795)
epoch 18 - 100%) training: 0.9765 (2430.6807 - 4.4716) 	 testing: 0.9795 (2640.8728 - 4.7834)
epoch 19 - 100%) training: 0.9776 (3172.2893 - 5.9542) 	 testing: 0.9792 (3272.4224 - 5.0041)
epoch 20 - 100%) training: 0.9788 (3570.1475 - 4.9005) 	 testing: 0.9812 (4264.9502 - 5.8081)
epoch 21 - 100%) training: 0.9787 (4876.1646 - 6.1165) 	 testing: 0.9808 (5475.2358 - 5.8976)
epoch 22 - 100%) training: 0.9799 (3180.2634 - 3.9057) 	 testing: 0.9813 (3317.7156 - 3.9012)
epoch 23 - 100%) training: 0.9806 (2328.6206 - 3.7536) 	 testing: 0.9826 (2397.2944 - 3.6987)
epoch 24 - 100%) training: 0.9798 (4677.1987 - 4.4783) 	 testing: 0.9823 (5023.0850 - 4.6507)
epoch 25 - 100%) training: 0.9811 (3927.5339 - 4.8530) 	 testing: 0.9830 (4024.1538 - 5.2015)
epoch 26 - 100%) training: 0.9798 (4298.3862 - 5.0999) 	 testing: 0.9831 (4730.0220 - 4.9904)
epoch 27 - 100%) training: 0.9811 (4639.9775 - 5.3606) 	 testing: 0.9831 (4931.9014 - 5.6236)
epoch 28 - 100%) training: 0.9825 (4483.1514 - 4.3366) 	 testing: 0.9844 (4708.0806 - 4.3383)
epoch 29 - 100%) training: 0.9829 (4930.3579 - 4.2995) 	 testing: 0.9836 (5377.6284 - 4.9327)
epoch 30 - 100%) training: 0.9829 (6808.6553 - 3.7501) 	 testing: 0.9822 (7229.0522 - 3.9000)
epoch 31 - 100%) training: 0.9829 (6299.5278 - 5.7005) 	 testing: 0.9845 (7308.3398 - 5.1458)
epoch 32 - 100%) training: 0.9829 (6325.0591 - 5.7004) 	 testing: 0.9854 (6570.1626 - 6.2992)
epoch 33 - 100%) training: 0.9834 (5909.4917 - 4.6970) 	 testing: 0.9843 (6029.9980 - 5.9591)
epoch 34 - 100%) training: 0.9841 (7869.9731 - 5.8773) 	 testing: 0.9849 (8628.2305 - 5.6427)
epoch 35 - 100%) training: 0.9836 (8436.3779 - 4.7038) 	 testing: 0.9850 (9239.3271 - 6.6275)
epoch 36 - 100%) training: 0.9841 (5976.3433 - 3.9183) 	 testing: 0.9854 (6832.0786 - 4.6664)
epoch 37 - 100%) training: 0.9835 (7943.0410 - 4.1992) 	 testing: 0.9848 (9035.4375 - 5.9112)
epoch 38 - 100%) training: 0.9856 (7578.3315 - 4.6529) 	 testing: 0.9869 (8479.9932 - 5.4112)
epoch 39 - 100%) training: 0.9851 (6393.1548 - 4.1753) 	 testing: 0.9859 (7391.0679 - 4.8497)
epoch 40 - 100%) training: 0.9861 (7625.8906 - 4.5776) 	 testing: 0.9862 (8205.8027 - 5.7368)
epoch 41 - 100%) training: 0.9859 (8331.8799 - 3.7653) 	 testing: 0.9870 (9274.6816 - 4.7081)
epoch 42 - 100%) training: 0.9859 (13754.2227 - 4.9862) 	 testing: 0.9875 (14675.8350 - 5.6302)
epoch 43 - 100%) training: 0.9855 (15424.7334 - 6.8568) 	 testing: 0.9864 (16225.4473 - 7.8045)
epoch 44 - 100%) training: 0.9862 (12365.5068 - 4.6459) 	 testing: 0.9866 (13331.6621 - 5.5387)
epoch 45 - 100%) training: 0.9861 (12153.3447 - 4.4586) 	 testing: 0.9876 (14052.3027 - 6.7540)
epoch 46 - 100%) training: 0.9865 (17505.6484 - 4.8752) 	 testing: 0.9870 (18024.4199 - 6.1390)
epoch 47 - 100%) training: 0.9859 (15554.4268 - 4.5205) 	 testing: 0.9860 (17220.3281 - 5.6510)
epoch 48 - 100%) training: 0.9870 (12240.6279 - 3.7974) 	 testing: 0.9873 (14007.5693 - 5.2733)
epoch 49 - 100%) training: 0.9867 (13552.8066 - 3.8983) 	 testing: 0.9875 (15117.2031 - 4.9112)
epoch 50 - 100%) training: 0.9858 (13994.8203 - 3.8285) 	 testing: 0.9854 (14965.4561 - 4.7617)
In [213]:
draw_EDL_results(L3_train_acc1, L3_train_ev_s, L3_train_ev_f, L3_test_acc1, L3_test_ev_s, L3_test_ev_f)

The figure above indicates that the neural network generates much more evidence for the correctly classified samples. As a result, it has very low uncertainty (around zero) for the correctly classified samples, while the uncertainty is very high (around 0.7) for the misclassified ones.

In [214]:
rotating_image_classification(digit_one, sess3, prob3, X3, keep_prob3, u3)

Using Negative Log of the Expected Likelihood (Eq. 3)

In this section, we repeat our experiments using the loss function based on Eq. 3 in the paper.
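
With func=tf.log, the same loss_EDL helper computes the negative log of the expected likelihood, $\sum_{j=1}^{K} y_j\,(\log S - \log \alpha_j)$, again followed by the annealed KL regularization term.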

In [221]:
g4, step4, X4, Y4, annealing_step4, keep_prob4, prob4, acc4, loss4, u4, evidence4, \
    mean_ev4, mean_ev_succ4, mean_ev_fail4 = LeNet_EDL(exp_evidence, loss_EDL(tf.log), lmb=0.001)
In [225]:
sess4 = tf.Session(graph=g4)
with g4.as_default():
    sess4.run(tf.global_variables_initializer())
In [226]:
bsize = 1000 #batch size
n_batches = mnist.train.num_examples // bsize
L4_train_acc1=[]
L4_train_ev_s=[]
L4_train_ev_f=[]
L4_test_acc1=[]
L4_test_ev_s=[]
L4_test_ev_f=[]
for epoch in range(50):   
    for i in range(n_batches):
        data, label = mnist.train.next_batch(bsize)
        feed_dict={X4:data, Y4:label, keep_prob4:.5, annealing_step4:10*n_batches}
        sess4.run(step4,feed_dict)
        print('epoch %d - %d%%) '% (epoch+1, (100*(i+1))//n_batches), end='\r' if i<n_batches-1 else '')
        
    train_acc, train_succ, train_fail = sess4.run([acc4,mean_ev_succ4,mean_ev_fail4], feed_dict={X4:mnist.train.images,Y4:mnist.train.labels,keep_prob4:1.})
    test_acc, test_succ, test_fail = sess4.run([acc4,mean_ev_succ4,mean_ev_fail4], feed_dict={X4:mnist.test.images,Y4:mnist.test.labels,keep_prob4:1.})
    
    L4_train_acc1.append(train_acc)
    L4_train_ev_s.append(train_succ)
    L4_train_ev_f.append(train_fail)
    
    L4_test_acc1.append(test_acc)
    L4_test_ev_s.append(test_succ)
    L4_test_ev_f.append(test_fail)
    
    print('training: %2.4f (%2.4f - %2.4f) \t testing: %2.4f (%2.4f - %2.4f)' % 
          (train_acc, train_succ, train_fail, test_acc, test_succ, test_fail))
epoch 1 - 100%) training: 0.8389 (47.7661 - 14.2396) 	 testing: 0.8481 (48.1106 - 14.5359)
epoch 2 - 100%) training: 0.9120 (98.9960 - 8.6988) 	 testing: 0.9194 (99.6993 - 8.8808)
epoch 3 - 100%) training: 0.9304 (180.1434 - 7.9925) 	 testing: 0.9390 (182.9537 - 7.6880)
epoch 4 - 100%) training: 0.9454 (357.4596 - 8.9753) 	 testing: 0.9515 (353.7061 - 7.7063)
epoch 5 - 100%) training: 0.9546 (374.0661 - 6.4429) 	 testing: 0.9585 (369.2362 - 6.1300)
epoch 6 - 100%) training: 0.9579 (717.8140 - 7.7555) 	 testing: 0.9627 (698.8102 - 7.6518)
epoch 7 - 100%) training: 0.9617 (820.1666 - 5.8483) 	 testing: 0.9646 (819.3505 - 5.0063)
epoch 8 - 100%) training: 0.9646 (685.4586 - 4.7092) 	 testing: 0.9682 (700.0871 - 4.2215)
epoch 9 - 100%) training: 0.9676 (864.0211 - 4.7806) 	 testing: 0.9700 (882.0850 - 4.2597)
epoch 10 - 100%) training: 0.9699 (1145.9790 - 4.9029) 	 testing: 0.9728 (1166.5842 - 4.3104)
epoch 11 - 100%) training: 0.9698 (1265.3774 - 3.9921) 	 testing: 0.9730 (1349.9486 - 3.4476)
epoch 12 - 100%) training: 0.9730 (1586.4016 - 5.1049) 	 testing: 0.9747 (1696.7997 - 4.9364)
epoch 13 - 100%) training: 0.9735 (2014.4788 - 5.1080) 	 testing: 0.9763 (2115.8186 - 4.7999)
epoch 14 - 100%) training: 0.9735 (2741.4673 - 4.3296) 	 testing: 0.9752 (2957.9802 - 3.9254)
epoch 15 - 100%) training: 0.9752 (2673.9426 - 4.4678) 	 testing: 0.9772 (2692.9707 - 4.9921)
epoch 16 - 100%) training: 0.9768 (2388.1035 - 4.0882) 	 testing: 0.9781 (2634.2166 - 3.7764)
epoch 17 - 100%) training: 0.9764 (2701.2002 - 4.6162) 	 testing: 0.9791 (3023.1316 - 5.1292)
epoch 18 - 100%) training: 0.9773 (2878.8640 - 4.0546) 	 testing: 0.9792 (3105.1746 - 3.6160)
epoch 19 - 100%) training: 0.9768 (3326.2048 - 4.1775) 	 testing: 0.9787 (3591.4688 - 4.3881)
epoch 20 - 100%) training: 0.9788 (3257.1694 - 4.1882) 	 testing: 0.9791 (3451.4497 - 3.3292)
epoch 21 - 100%) training: 0.9788 (4058.1035 - 4.2671) 	 testing: 0.9801 (4339.0176 - 4.3674)
epoch 22 - 100%) training: 0.9795 (4905.1646 - 5.0746) 	 testing: 0.9806 (5299.4648 - 4.9685)
epoch 23 - 100%) training: 0.9789 (4146.0679 - 4.8996) 	 testing: 0.9794 (4280.2456 - 4.3978)
epoch 24 - 100%) training: 0.9789 (5102.7090 - 5.0765) 	 testing: 0.9811 (5573.3687 - 4.9651)
epoch 25 - 100%) training: 0.9797 (4721.9268 - 3.4639) 	 testing: 0.9824 (4908.5527 - 3.7816)
epoch 26 - 100%) training: 0.9794 (4624.8179 - 3.5023) 	 testing: 0.9810 (4622.0835 - 4.1510)
epoch 27 - 100%) training: 0.9803 (7247.1953 - 4.8597) 	 testing: 0.9823 (7369.5410 - 5.0344)
epoch 28 - 100%) training: 0.9820 (6480.2974 - 4.0996) 	 testing: 0.9825 (7157.7944 - 4.7370)
epoch 29 - 100%) training: 0.9813 (7673.8716 - 4.6725) 	 testing: 0.9817 (7892.9888 - 4.8215)
epoch 30 - 100%) training: 0.9819 (7318.6362 - 4.3854) 	 testing: 0.9841 (7933.3677 - 4.6528)
epoch 31 - 100%) training: 0.9825 (8063.1187 - 4.8635) 	 testing: 0.9827 (8506.7451 - 4.2365)
epoch 32 - 100%) training: 0.9816 (6621.0513 - 3.1235) 	 testing: 0.9819 (7588.6938 - 3.7367)
epoch 33 - 100%) training: 0.9826 (10533.4658 - 4.9056) 	 testing: 0.9843 (11875.2090 - 5.4827)
epoch 34 - 100%) training: 0.9833 (10538.7793 - 4.8260) 	 testing: 0.9830 (11016.7715 - 4.3061)
epoch 35 - 100%) training: 0.9830 (9445.5898 - 4.3326) 	 testing: 0.9830 (10311.8994 - 4.2545)
epoch 36 - 100%) training: 0.9824 (9012.2568 - 3.6567) 	 testing: 0.9831 (10201.7939 - 3.7273)
epoch 37 - 100%) training: 0.9838 (8231.3916 - 3.4212) 	 testing: 0.9848 (9213.5693 - 4.3222)
epoch 38 - 100%) training: 0.9835 (9698.7676 - 3.6955) 	 testing: 0.9844 (12237.3604 - 4.4046)
epoch 39 - 100%) training: 0.9813 (12806.3682 - 3.5367) 	 testing: 0.9833 (13110.8662 - 4.3908)
epoch 40 - 100%) training: 0.9834 (12758.0078 - 3.7434) 	 testing: 0.9852 (14355.3516 - 5.7558)
epoch 41 - 100%) training: 0.9841 (18760.6660 - 4.6124) 	 testing: 0.9847 (18847.6367 - 4.9546)
epoch 42 - 100%) training: 0.9844 (14055.1133 - 4.2950) 	 testing: 0.9850 (15240.1768 - 5.1640)
epoch 43 - 100%) training: 0.9839 (17531.1875 - 4.8606) 	 testing: 0.9856 (20936.6172 - 6.0384)
epoch 44 - 100%) training: 0.9842 (11528.1709 - 4.0646) 	 testing: 0.9857 (12673.5449 - 4.8653)
epoch 45 - 100%) training: 0.9838 (13236.9697 - 4.4013) 	 testing: 0.9848 (14658.7588 - 6.3498)
epoch 46 - 100%) training: 0.9841 (16241.7793 - 3.5962) 	 testing: 0.9871 (18014.9746 - 4.5276)
epoch 47 - 100%) training: 0.9851 (15670.2637 - 5.8087) 	 testing: 0.9855 (17604.8535 - 5.6529)
epoch 48 - 100%) training: 0.9855 (14058.3115 - 4.6745) 	 testing: 0.9865 (15463.1846 - 5.6835)
epoch 49 - 100%) training: 0.9854 (16627.9102 - 4.3638) 	 testing: 0.9851 (18772.2539 - 4.5671)
epoch 50 - 100%) training: 0.9860 (18039.0078 - 3.9769) 	 testing: 0.9875 (20124.0410 - 4.8621)
In [227]:
draw_EDL_results(L4_train_acc1, L4_train_ev_s, L4_train_ev_f, L4_test_acc1, L4_test_ev_s, L4_test_ev_f)
In [228]:
rotating_image_classification(digit_one, sess4, prob4, X4, keep_prob4, u4)

Some Other Data Uncertainty Experiments

Consider the case where we mix two digits from the MNIST dataset and ask a classifier trained on MNIST to classify the result. For example, the following image is created by overlaying a digit 0 with a digit 6. The resulting image resembles both digits, but it is neither a 0 nor a 6.

In [317]:
im0 =  mnist.test.images[10]
im6 =  mnist.test.images[21]
img = im0 + im6
img /= img.max()
plt.subplot(1,3,1)
plt.imshow(im0.reshape(28,28))
plt.subplot(1,3,2)
plt.imshow(im6.reshape(28,28))
plt.subplot(1,3,3)
plt.imshow(img.reshape(28,28))
plt.show()

The neural network trained with the softmax cross entropy loss makes the following prediction for this image, classifying it as 0 with probability 0.9.

In [318]:
p1 = sess1.run(prob1, feed_dict={X1:img[None,:], keep_prob1:1.0})
print('softmax prob: ', np.round(p1[0], decimals=3))
softmax prob:  [0.901 0.    0.    0.    0.    0.017 0.078 0.    0.003 0.   ]

When we run the same experiment on the neural network trained using the loss function in Eq. 5, we get quite different results. The network does not generate any evidence to classify the image as one of the 10 digits. Hence, it outputs a uniform distribution as its prediction; in effect, it says "I do not know" by reporting maximum uncertainty.
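
(With zero evidence, $\alpha_k = 1$ for every class, so $S = K = 10$, the uncertainty is $u = K/S = 1.0$, and the Dirichlet mean is $\alpha_k/S = 0.1$ for each digit, which is exactly what is printed below.)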

In [324]:
uncertainty2, p2 = sess2.run([u, prob2], feed_dict={X2:img[None,:], keep_prob2:1.0})
print('uncertainty:', np.round(uncertainty2[0,0], decimals=2))
print('Dirichlet mean: ', np.round(p2[0], decimals=3))
uncertainty: 1.0
Dirichlet mean:  [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]

When we use the loss function in Eq. 4, the expected probability is highest for digit 0. It is around 0.32; however, the associated uncertainty is quite high, around 0.73, as shown below.

In [325]:
uncertainty3, p3 = sess3.run([u3, prob3], feed_dict={X3:img[None,:], keep_prob3:1.0})
print('uncertainty:', np.round(uncertainty3[0,0], decimals=2))
print('Dirichlet mean: ', np.round(p3[0], decimals=3))
uncertainty: 0.73
Dirichlet mean:  [0.325 0.073 0.075 0.074 0.073 0.081 0.078 0.073 0.074 0.074]

The uncertainty increases to 0.85, while the expected probability for digit 0 decreases to 0.184, when the loss function in Eq. 3 is used.

In [326]:
uncertainty4, p4 = sess4.run([u4, prob4], feed_dict={X4:img[None,:], keep_prob4:1.0})
print('uncertainty:', np.round(uncertainty4[0,0], decimals=2))
print('Dirichlet mean: ', np.round(p4[0], decimals=3))
uncertainty: 0.85
Dirichlet mean:  [0.184 0.085 0.085 0.085 0.085 0.097 0.123 0.085 0.087 0.085]

Let's try another setting where each of the two digits is easily recognizable. Below you can see an image created by placing the digit 0 and the digit 6 side by side without any overlap.

In [330]:
img = np.zeros((28,28))
img[:,:-6] += mnist.test.images[10].reshape(28,28)[:,6:]
img[:,14:] += mnist.test.images[21].reshape(28,28)[:,5:19]
img /= img.max()
plt.imshow(img)
plt.show()

Below, you can see the prediction of the neural network trained with softmax cross entropy for this example. The network predicts digit 2 with probability 0.775. Hence, it associates quite a high probability with a wrong label.

In [331]:
p1 = sess1.run(prob1, feed_dict={X1:img.reshape(1,-1), keep_prob1:1.0})
print('softmax prob: ', np.round(p1[0], decimals=3))
softmax prob:  [0.    0.199 0.775 0.007 0.003 0.    0.    0.015 0.    0.   ]

On the other hand, when we do the same with the network trained using the loss in Eq. 5, the output is a uniform distribution with uncertainty 1.0, as shown below.

In [332]:
uncertainty2, p2 = sess2.run([u, prob2], feed_dict={X2:img.reshape(1,-1), keep_prob2:1.0})
print('uncertainty:', np.round(uncertainty2[0,0], decimals=2))
print('Dirichlet mean: ', np.round(p2[0], decimals=3))
uncertainty: 1.0
Dirichlet mean:  [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]

The neural networks trained using the loss functions defined in Eq. 4 and Eq. 3 of the paper also have very high uncertainty for their predictions. These networks assign a small amount of evidence to classifying the image as digit 2, but they associate very high uncertainty with this misclassification.

In [333]:
uncertainty3, p3 = sess3.run([u3, prob3], feed_dict={X3:img.reshape(1,-1), keep_prob3:1.0})
print('uncertainty:', np.round(uncertainty3[0,0], decimals=2))
print('Dirichlet mean: ', np.round(p3[0], decimals=3))
uncertainty: 0.92
Dirichlet mean:  [0.092 0.094 0.143 0.12  0.092 0.092 0.092 0.093 0.092 0.092]
In [334]:
uncertainty4, p4 = sess4.run([u4, prob4], feed_dict={X4:img.reshape(1,-1), keep_prob4:1.0})
print('uncertainty:', np.round(uncertainty4[0,0], decimals=2))
print('Dirichlet mean: ', np.round(p4[0], decimals=3))
uncertainty: 0.93
Dirichlet mean:  [0.093 0.093 0.16  0.098 0.093 0.093 0.093 0.094 0.093 0.093]