{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quantifying Classification Uncertainty in Deep Neural Networks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This page provides an easy-to-run demo, with low computational requirements, of the ideas proposed in the paper _Evidential Deep Learning to Quantify Classification Uncertainty_. Using the MNIST dataset, I demonstrate how to create neural networks that are able to quantify classification uncertainty. The paper can be accessed at http://arxiv.org/abs/1806.01768\n",
"\n",
"You can run this notebook in Colab using the colab icon below: \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The notebook can also be downloaded using https://muratsensoy.github.io/uncertainty.ipynb"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Neural Networks Trained with Softmax Cross Entropy Loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following code demonstrates how softmax-based deep neural networks fail when they encounter out-of-sample queries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# use this while running this notebook in Colab\n",
"%tensorflow_version 1.x"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"#import necessary libraries\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"import scipy.ndimage as nd\n",
"\n",
"%matplotlib inline\n",
"import pylab as pl\n",
"from IPython import display\n",
"\n",
"from tensorflow.examples.tutorials.mnist import input_data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"base_uri": "https://localhost:8080/",
"height": 431
},
"colab_type": "code",
"executionInfo": {
"elapsed": 14869,
"status": "ok",
"timestamp": 1527923826590,
"user": {
"displayName": "Murat Sensoy",
"photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
"userId": "102692943223630372304"
},
"user_tz": -180
},
"id": "MROFqBJQ1naS",
"outputId": "56a3f37d-9b80-400b-c6da-ec68bcaa5cd8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
]
}
],
"source": [
"# Download the MNIST dataset\n",
"mnist = input_data.read_data_sets('MNIST_data', one_hot=True)\n",
"\n",
"K = 10  # number of classes"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADDxJREFUeJzt3V2sHPV5x/Hvg3NspwZSSKhjHFKniCBRq3XQkVsJ0lJRIuJGNdzQuFLqSqhOpVA1EhdF5KJcoqpJFFURlQlWnIoCkQjCkWgaatEi1BZxsBxeQhIomMausYMcFRMV45enF2ccncDZOce7sy/Hz/cjrXZ3ntmZR2P/zszO7O4/MhNJ9Zwz7gYkjYfhl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9U1HtGubLlsSJXsmqUq5RKeYuf8XYei8XMO1D4I+J64CvAMuBrmXln2/wrWcVvxbWDrFJSiydz96Ln7fuwPyKWAV8FPglcAWyJiCv6XZ6k0RrkPf9G4KXMfDkz3wbuBzZ305akYRsk/GuBH895vr+Z9gsiYltEzETEzHGODbA6SV0a+tn+zNyemdOZOT3FimGvTtIiDRL+A8Alc55/qJkmaQkYJPxPAZdFxEciYjnwaWBXN21JGra+L/Vl5omIuAX4Z2Yv9e3IzOc760zSUA10nT8zHwEe6agXSSPkx3ulogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKmqkQ3Rr6Xnl/t9orT9x1V2t9T/+k7/oWVv22J6+elI33PNLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlEDXeePiH3AUeAkcCIzp7toSpMj/3tVa/39H39va/3I5St61i56rK+W1JEuPuTze5n5egfLkTRCHvZLRQ0a/gS+GxFPR8S2LhqSNBqDHvZfnZkHIuJXgEcj4geZ+fjcGZo/CtsAVvJLA65OUlcG2vNn5oHm/jDwELBxnnm2Z+Z0Zk5P0fvkj6TR6jv8EbEqIs47/Rj4BPBcV41JGq5BDvtXAw9FxOnl/GNmfqeTriQNXd/hz8yXgd/ssBdNoFX7Y6DXf/CPXu1ZO/n3Ay1aA/JSn1SU4ZeKMvxSUYZfKsrwS0UZfqkof7pbQ/V/J6Z61paPsA+9m3t+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK6/xqdf4fHBzo9f/74MU9axfR++u+Gj73/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlNf5izt5zZWt9W//+ldb63vfXtZaX31v73FcTrW+UsPmnl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXilrwOn9E7AA+BRzOzPXNtAuBB4B1wD7gpsz86fDa1LCcXNH+9//cWNFaP57ZWj919OgZ96TRWMye/+vA9e+YdhuwOzMvA3Y3zyUtIQuGPzMfB468Y/JmYGfzeCdwQ8d9SRqyft/zr87M07/v9BqwuqN+JI3IwCf8MjOBnm/8ImJbRMxExMxxjg26Okkd6Tf8hyJiDUBzf7jXjJm5PTOnM3N6ivaTR5JGp9/w7wK2No+3Ag93046kUVkw/BFxH/AfwOURsT8ibgbuBK6LiBeB32+eS1pCFrzOn5lbepSu7bgXjcG+G/2cV1X+y0tFGX6pKMMvFWX4paIMv1SU4ZeK8qe7izvvg37ltir3/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU3+c/y52zcmVr/eq1rwy0/LsP/+4Cc7w50PI1PO75paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqmoBa/zR8QO4FPA4cxc30y7A/gz4CfNbLdn5iPDalL9O+eX39da/7uL/2mg5f/bE+tb65fynwMtX8OzmD3/14Hr55n+5czc0N
wMvrTELBj+zHwcODKCXiSN0CDv+W+JiGciYkdEXNBZR5JGot/w3wVcCmwADgJf7DVjRGyLiJmImDnOsT5XJ6lrfYU/Mw9l5snMPAXcDWxsmXd7Zk5n5vQUK/rtU1LH+gp/RKyZ8/RG4Llu2pE0Kou51HcfcA3wgYjYD/w1cE1EbAAS2Ad8dog9ShqCBcOfmVvmmXzPEHrREJxYt3qoy//wd44PdfkaHj/hJxVl+KWiDL9UlOGXijL8UlGGXyrKn+4+y73+hbcGev2mH/xha335v36vtZ4DrV3D5J5fKsrwS0UZfqkowy8VZfilogy/VJThl4ryOv9Z7q719y4wx7LW6v+8cX5r/eIT+8+wI00K9/xSUYZfKsrwS0UZfqkowy8VZfilogy/VJTX+c8C71n34Z618+LfW1+7LKa6bkdLhHt+qSjDLxVl+KWiDL9UlOGXijL8UlGGXypqwev8EXEJ8A1gNbM/w749M78SERcCDwDrgH3ATZn50+G1ql7e+lrv2kenVra+9mSeaq2f+8327/Nr6VrMnv8EcGtmXgH8NvC5iLgCuA3YnZmXAbub55KWiAXDn5kHM3NP8/go8AKwFtgM7Gxm2wncMKwmJXXvjN7zR8Q64GPAk8DqzDzYlF5j9m2BpCVi0eGPiHOBB4HPZ+Ybc2uZmfQYli0itkXETETMHOfYQM1K6s6iwh8RU8wG/97M/FYz+VBErGnqa4DD8702M7dn5nRmTk+xooueJXVgwfBHRAD3AC9k5pfmlHYBW5vHW4GHu29P0rAs5iu9VwGfAZ6NiL3NtNuBO4FvRsTNwKvATcNpUcs+emlr/dZ1u/pe9pZXrmutn3//k30vW5NtwfBn5hNA9Chf2207kkbFT/hJRRl+qSjDLxVl+KWiDL9UlOGXivKnu5eAt9e+r7V+7Xv7/9j0jx64vLW+Ott/+ltLl3t+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK6/xnuT/f//HW+sX3/bC1frLLZjRR3PNLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlFe518Clj22p7W+ae2VLdWfLbD0heo6W7nnl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiFgx/RFwSEY9FxPcj4vmI+Mtm+h0RcSAi9ja3TcNvV1JXFvMhnxPArZm5JyLOA56OiEeb2pcz82+H156kYVkw/Jl5EDjYPD4aES8Aa4fdmKThOqP3/BGxDvgY8GQz6ZaIeCYidkTEBT1esy0iZiJi5jj9DyslqVuLDn9EnAs8CHw+M98A7gIuBTYwe2Twxflel5nbM3M6M6enWNFBy5K6sKjwR8QUs8G/NzO/BZCZhzLzZGaeAu4GNg6vTUldW8zZ/gDuAV7IzC/Nmb5mzmw3As91356kYVnM2f6rgM8Az0bE3mba7cCWiNgAJLAP+OxQOpQ0FIs52/8EEPOUHum+HUmj4if8pKIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRhl8qyvBLRUVmjm5lET8BXp0z6QPA6yNr4MxMam+T2hfYW7+67O1XM/Oixcw40vC/a+URM5k5PbYGWkxqb5PaF9hbv8bVm4f9UlGGXypq3OHfPub1t5nU3ia1L7C3fo2lt7G+55c0PuPe80sak7GEPyKuj4gfRsRLEXHbOHroJSL2RcSzzcjDM2PuZUdEHI6I5+ZMuzAiHo2IF5v7eYdJG1NvEzFyc8vI0mPddpM24vXID/sjYhnwI+A6YD/wFLAlM78/0kZ6iIh9wHRmjv2acET8DvAm8I3MXN9M+xvgSGbe2fzhvCAz/2pCersDeHPcIzc3A8qsmTuyNHAD8KeMcdu19HUTY9hu49jzbwReysyXM/Nt4H5g8xj6mHiZ+Thw5B2TNwM7m8c7mf3PM3I9epsImXkwM/c0j48Cp0eWHuu2a+lrLMYR/rXAj+c8389kDfmdwHcj4umI2DbuZuaxuhk2HeA1YPU4m5nHgi
M3j9I7RpaemG3Xz4jXXfOE37tdnZlXAp8EPtcc3k6knH3PNkmXaxY1cvOozDOy9M+Nc9v1O+J118YR/gPAJXOef6iZNhEy80Bzfxh4iMkbffjQ6UFSm/vDY+7n5yZp5Ob5RpZmArbdJI14PY7wPwVcFhEfiYjlwKeBXWPo410iYlVzIoaIWAV8gskbfXgXsLV5vBV4eIy9/IJJGbm518jSjHnbTdyI15k58huwidkz/v8FfGEcPfTo69eA7zW358fdG3Afs4eBx5k9N3Iz8H5gN/Ai8C/AhRPU2z8AzwLPMBu0NWPq7WpmD+mfAfY2t03j3nYtfY1lu/kJP6koT/hJRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrq/wFWuLZ5z/d+bwAAAABJRU5ErkJggg==\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"digit_one = mnist.train.images[4].copy()\n",
"plt.imshow(digit_one.reshape(28,28)) \n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# define some utility functions\n",
"def var(name, shape, init=None):\n",
"    if init is None:\n",
"        init = tf.truncated_normal_initializer(stddev=(2/shape[0])**0.5)\n",
"    return tf.get_variable(name=name, shape=shape, dtype=tf.float32,\n",
"                           initializer=init)\n",
"\n",
"def conv(Xin, f, strides=[1, 1, 1, 1], padding='SAME'):\n",
"    return tf.nn.conv2d(Xin, f, strides, padding)\n",
"\n",
"def max_pool(Xin, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME'):\n",
"    return tf.nn.max_pool(Xin, ksize, strides, padding)\n",
"\n",
"def rotate_img(x, deg):\n",
"    import scipy.ndimage as nd\n",
"    return nd.rotate(x.reshape(28,28),deg,reshape=False).ravel()"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Create a LeNet network with softmax cross entropy loss function\n",
"def LeNet_softmax(lmb=0.005):\n",
"    g = tf.Graph()\n",
"    with g.as_default():\n",
"        X = tf.placeholder(shape=[None,28*28], dtype=tf.float32)\n",
"        Y = tf.placeholder(shape=[None,10], dtype=tf.float32)\n",
"        keep_prob = tf.placeholder(dtype=tf.float32)\n",
"\n",
"        # first hidden layer - conv\n",
"        W1 = var('W1', [5,5,1,20])\n",
"        b1 = var('b1', [20])\n",
"        out1 = max_pool(tf.nn.relu(conv(tf.reshape(X, [-1, 28,28, 1]),\n",
"                                        W1, strides=[1, 1, 1, 1]) + b1))\n",
"        # second hidden layer - conv\n",
"        W2 = var('W2', [5,5,20,50])\n",
"        b2 = var('b2', [50])\n",
"        out2 = max_pool(tf.nn.relu(conv(out1, W2, strides=[1, 1, 1, 1]) + b2))\n",
"        # flatten the output\n",
"        Xflat = tf.contrib.layers.flatten(out2)\n",
"        # third hidden layer - fully connected\n",
"        W3 = var('W3', [Xflat.get_shape()[1].value, 500])\n",
"        b3 = var('b3', [500])\n",
"        out3 = tf.nn.relu(tf.matmul(Xflat, W3) + b3)\n",
"        out3 = tf.nn.dropout(out3, keep_prob=keep_prob)\n",
"        # output layer\n",
"        W4 = var('W4', [500,10])\n",
"        b4 = var('b4',[10])\n",
"        logits = tf.matmul(out3, W4) + b4\n",
"\n",
"        prob = tf.nn.softmax(logits=logits)\n",
"\n",
"        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))\n",
"        l2_loss = (tf.nn.l2_loss(W3)+tf.nn.l2_loss(W4)) * lmb\n",
"\n",
"        step = tf.train.AdamOptimizer().minimize(loss + l2_loss)\n",
"\n",
"        # Calculate accuracy\n",
"        pred = tf.argmax(logits, 1)\n",
"        truth = tf.argmax(Y, 1)\n",
"        acc = tf.reduce_mean(tf.cast(tf.equal(pred, truth), tf.float32))\n",
"\n",
"        return g, step, X, Y, keep_prob, prob, acc, loss"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# get the LeNet network\n",
"g1, step1, X1, Y1, keep_prob1, prob1, acc1, loss1 = LeNet_softmax()"
]
},
{
"cell_type": "code",
"execution_count": 169,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"collapsed": true,
"id": "xTVDRunk1Wsq",
"scrolled": true
},
"outputs": [],
"source": [
"sess1 = tf.Session(graph=g1)\n",
"with g1.as_default(): \n",
" sess1.run(tf.global_variables_initializer())"
]
},
{
"cell_type": "code",
"execution_count": 170,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "QwYkUP6L2q-n",
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1 - 100%) training accuracy: 0.9235 \t testing accuracy: 0.9255\n",
"epoch 2 - 100%) training accuracy: 0.9514 \t testing accuracy: 0.9510\n",
"epoch 3 - 100%) training accuracy: 0.9648 \t testing accuracy: 0.9644\n",
"epoch 4 - 100%) training accuracy: 0.9699 \t testing accuracy: 0.9685\n",
"epoch 5 - 100%) training accuracy: 0.9765 \t testing accuracy: 0.9732\n",
"epoch 6 - 100%) training accuracy: 0.9796 \t testing accuracy: 0.9744\n",
"epoch 7 - 100%) training accuracy: 0.9800 \t testing accuracy: 0.9766\n",
"epoch 8 - 100%) training accuracy: 0.9816 \t testing accuracy: 0.9767\n",
"epoch 9 - 100%) training accuracy: 0.9824 \t testing accuracy: 0.9787\n",
"epoch 10 - 100%) training accuracy: 0.9863 \t testing accuracy: 0.9812\n",
"epoch 11 - 100%) training accuracy: 0.9862 \t testing accuracy: 0.9822\n",
"epoch 12 - 100%) training accuracy: 0.9873 \t testing accuracy: 0.9831\n",
"epoch 13 - 100%) training accuracy: 0.9866 \t testing accuracy: 0.9816\n",
"epoch 14 - 100%) training accuracy: 0.9875 \t testing accuracy: 0.9826\n",
"epoch 15 - 100%) training accuracy: 0.9883 \t testing accuracy: 0.9837\n",
"epoch 16 - 100%) training accuracy: 0.9895 \t testing accuracy: 0.9850\n",
"epoch 17 - 100%) training accuracy: 0.9892 \t testing accuracy: 0.9838\n",
"epoch 18 - 100%) training accuracy: 0.9877 \t testing accuracy: 0.9838\n",
"epoch 19 - 100%) training accuracy: 0.9899 \t testing accuracy: 0.9851\n",
"epoch 20 - 100%) training accuracy: 0.9903 \t testing accuracy: 0.9866\n",
"epoch 21 - 100%) training accuracy: 0.9901 \t testing accuracy: 0.9868\n",
"epoch 22 - 100%) training accuracy: 0.9909 \t testing accuracy: 0.9864\n",
"epoch 23 - 100%) training accuracy: 0.9915 \t testing accuracy: 0.9855\n",
"epoch 24 - 100%) training accuracy: 0.9914 \t testing accuracy: 0.9856\n",
"epoch 25 - 100%) training accuracy: 0.9904 \t testing accuracy: 0.9837\n",
"epoch 26 - 100%) training accuracy: 0.9913 \t testing accuracy: 0.9873\n",
"epoch 27 - 100%) training accuracy: 0.9931 \t testing accuracy: 0.9892\n",
"epoch 28 - 100%) training accuracy: 0.9910 \t testing accuracy: 0.9874\n",
"epoch 29 - 100%) training accuracy: 0.9920 \t testing accuracy: 0.9877\n",
"epoch 30 - 100%) training accuracy: 0.9923 \t testing accuracy: 0.9881\n",
"epoch 31 - 100%) training accuracy: 0.9922 \t testing accuracy: 0.9870\n",
"epoch 32 - 100%) training accuracy: 0.9934 \t testing accuracy: 0.9886\n",
"epoch 33 - 100%) training accuracy: 0.9937 \t testing accuracy: 0.9893\n",
"epoch 34 - 100%) training accuracy: 0.9946 \t testing accuracy: 0.9904\n",
"epoch 35 - 100%) training accuracy: 0.9937 \t testing accuracy: 0.9890\n",
"epoch 36 - 100%) training accuracy: 0.9933 \t testing accuracy: 0.9874\n",
"epoch 37 - 100%) training accuracy: 0.9944 \t testing accuracy: 0.9883\n",
"epoch 38 - 100%) training accuracy: 0.9944 \t testing accuracy: 0.9890\n",
"epoch 39 - 100%) training accuracy: 0.9948 \t testing accuracy: 0.9895\n",
"epoch 40 - 100%) training accuracy: 0.9947 \t testing accuracy: 0.9896\n",
"epoch 41 - 100%) training accuracy: 0.9935 \t testing accuracy: 0.9893\n",
"epoch 42 - 100%) training accuracy: 0.9943 \t testing accuracy: 0.9898\n",
"epoch 43 - 100%) training accuracy: 0.9936 \t testing accuracy: 0.9884\n",
"epoch 44 - 100%) training accuracy: 0.9945 \t testing accuracy: 0.9901\n",
"epoch 45 - 100%) training accuracy: 0.9945 \t testing accuracy: 0.9887\n",
"epoch 46 - 100%) training accuracy: 0.9949 \t testing accuracy: 0.9908\n",
"epoch 47 - 100%) training accuracy: 0.9950 \t testing accuracy: 0.9893\n",
"epoch 48 - 100%) training accuracy: 0.9949 \t testing accuracy: 0.9896\n",
"epoch 49 - 100%) training accuracy: 0.9940 \t testing accuracy: 0.9891\n",
"epoch 50 - 100%) training accuracy: 0.9949 \t testing accuracy: 0.9886\n"
]
}
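],
"source": [
"bsize = 1000 #batch size\n",
"n_batches = mnist.train.num_examples // bsize\n",
"for epoch in range(50):\n",
"    for i in range(n_batches):\n",
"        data, label = mnist.train.next_batch(bsize)\n",
"        feed_dict={X1:data, Y1:label, keep_prob1:.5}\n",
"        sess1.run(step1,feed_dict)\n",
"        print('epoch %d - %d%%) '% (epoch+1, (100*(i+1))//n_batches), end='\\r' if i < n_batches-1 else '')\n",
"\n",
"    # evaluate on the full training and test sets with dropout disabled\n",
"    train_acc = sess1.run(acc1, feed_dict={X1:mnist.train.images, Y1:mnist.train.labels, keep_prob1:1.})\n",
"    test_acc = sess1.run(acc1, feed_dict={X1:mnist.test.images, Y1:mnist.test.labels, keep_prob1:1.})\n",
"    print('training accuracy: %2.4f \\t testing accuracy: %2.4f' % (train_acc, test_acc))\n",
"\n",
"# Rotate an image in steps, classify each rotated copy, and plot the probability\n",
"# of every class that reaches the threshold for at least one rotation.\n",
"def rotating_image_classification(img, sess, prob, X, keep_prob, uncertainty=None, threshold=0.5):\n",
"    Mdeg = 180\n",
"    Ndeg = int(Mdeg/10)+1\n",
"    ldeg = []\n",
"    lp = []\n",
"    lu = []\n",
"    scores = np.zeros((1,K))\n",
"    rimgs = np.zeros((28,28*Ndeg))\n",
"    for i,deg in enumerate(np.linspace(0,Mdeg,Ndeg)):\n",
"        nimg = rotate_img(img,deg).reshape(28,28)\n",
"        nimg = np.clip(a=nimg,a_min=0,a_max=1)\n",
"        rimgs[:,i*28:(i+1)*28] = nimg\n",
"        feed_dict={X:nimg.reshape(1,-1), keep_prob:1.0}\n",
"        if uncertainty is None:\n",
"            p_pred_t = sess.run(prob, feed_dict=feed_dict)\n",
"        else:\n",
"            p_pred_t,u = sess.run([prob,uncertainty], feed_dict=feed_dict)\n",
"            lu.append(u.mean())\n",
"        scores += p_pred_t >= threshold\n",
"        ldeg.append(deg)\n",
"        lp.append(p_pred_t[0])\n",
"\n",
"    labels = np.arange(10)[scores[0].astype(bool)]\n",
"    lp = np.array(lp)[:,labels]\n",
"    c = ['black','blue','red','brown','purple','cyan']\n",
"    marker = ['s','^','o']*2\n",
"    labels = labels.tolist()\n",
"    for i in range(len(labels)):\n",
"        plt.plot(ldeg,lp[:,i],marker=marker[i],c=c[i])\n",
"\n",
"    if uncertainty is not None:\n",
"        labels += ['uncertainty']\n",
"        plt.plot(ldeg,lu,marker='<',c='red')\n",
"\n",
"    plt.legend(labels)\n",
"\n",
"    plt.xlim([0,Mdeg])\n",
"    plt.xlabel('Rotation Degree')\n",
"    plt.ylabel('Classification Probability')\n",
"    plt.show()\n",
"\n",
"    plt.figure(figsize=[6.2,100])\n",
"    plt.imshow(1-rimgs,cmap='gray')\n",
"    plt.axis('off')\n",
"    plt.show()"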
]
},
{
"cell_type": "code",
"execution_count": 323,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAA1CAYAAACp8OvZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADTxJREFUeJzt3V1MHGUXB/Czy7J8G1pWQitWUk3TIIloGkIoNpC0EmKkkqgtCQWJRiBsmxKCXJDQxA80VWsN0WqRi4oGmoJKiL1zbdwUTWmxSmi1dRvshu62tWLBrmVndv7vBZl5l9qPmWcQoXN+SS+62Tl75tln/vOx06kNADHGGLMW+3/dAGOMsYXH4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbkWODPW5BnSSiKQna7uf2aLMvkcJgbnkgkQjExMaZqMMbuHpIkUWxs7L/9MTY9b7orj/ztdjtNTk6aquFwOCgUCpmqMR/BH4lETNdgjJknSZLpGgsQ/LotuvB/7bXX6ODBg6ZqnDx5kjZt2kTHjx8XDs9QKERJSUm0a9cu4T52795NNpuNjh8/LlwjJiaG9u7dS+Pj40QktjMIh8NzlpVlWbgfFT8QkC0ERVH+8zpXrlyhF198kS5evCj82cFgkGpra2nTpk00MzMj3Munn35q+qqGatGF/wMPPEAHDhwwVSMrK4tOnDhBfX19QiEFgBITEyk7O5va29uFelAUhZ599lmqq6ujhoYGCgaDQr1cvHiRDh8+TBUVFfThhx8KnU04nU4iInrllVdoeHiYHA6H0MYgSZK2A7HZbLwDYP8KRVG0uWW3202Fpd/vp8cff1woMNW53traSl1dXfT7778L9aAoCvX09NCXX35JpaWl2vZoFAA6cOAArVy5Umj5mxZcwD93VF5ejt7eXj1vvS0igsfjMVWjrq4Os0MkLhQKgYjw1VdfIRwOC9eprq6Gw+EAAMzMzAjVaGtrg8vlMjUuRIS4uDiUlpYCgOF1UnsfGRnBt99+K9wHm1+SJCESiUBRFMPLRiIRRCIR7e+nTp3C0aNH57ymtw4AjI+Po7i4GE6nE42NjYbrALPrk5WVBSJCRUWF4eUVRUEoFMILL7wAIsJDDz2kvW5UOBwGESE5OdnwstG2bdsGIkIwGLzTW3Xl8aIK/5GRETidToRCoTuPxC0oioLLly+DiDA6OgpZloXqyLKMiooKU+GvThQiQk1NjdDEUZfp7OwU7kUdg8uXL+O9997DM888A0VRhHZGg4ODyM3NndOLJEmGaoTDYSQmJuKee+5BX1+f9pqRfhRFgSRJCIVC8Pv92vpIkiQ0zkuFum6yLCMSiRge++g6N4b29PS0qd6++OILNDU1weVyGZ6r6nr19/ejqKgINHtzCM6cOSPUy6FDh0BEcLlcOHv2rOHlJUmC1+tFTk4OiouL4fP5hOfVxMQEiAh5eXnC3xcwG/42m03PW5de+Dc3N+tdudsaGxsDEaGlpUX4KBkAurq6kJiYiNHRUVNf2tNPP236DEKWZfT09ICI8NZbbwGA8JGVekQ0PDwsVAcAent7QUSIj4/H1atXARg7Krp69SrOnDkDm82m7Ry//vpr3ctLkoSamhosX74caWlpICK43W54vV5cunRJe5+edTtx4gR6e3uxefPmeTnrvJmpqSn8+eef/3j9dv2pO+3oOXzu3DmcP38eo6OjePvtt/HRRx+hs7MTIyMjuvpQa4VCIRw5cgT79u3DihUrtLCNj4/HH3/8ccc66vbg9/vhdru1A4LMzExs375dd+Cq6//DDz+gvb0ddrsdRARJkhAOh4XmZn5+PogI9fX1AIwfrUciEUxNTYGIkJCQgOvXrxvuQf3ciYkJlJWVoba2FoFAQGh9VDab7e4N/8rKynkJ/7Nnz4KIUFJSInypJRKJoLu7G6tWrT
K9E6mrq0NSUhI6OzuFawDA8PCwtpGqPYrYu3cv0tPTkZqaKtxLOBzWjtBWr16NgYEB3ctGn411dHRg3bp12np1d3djampK187W5/NhZGQEHo8HVVVVWo2YmBh0dHRgYmJCVz8JCQnahmWz2ZCenq792bBhA44dO6b7DHJ4eBgtLS3weDwYHR1FS0sLSkpKUFhYiNraWnR2dqKhoQH9/f0IBAK6agLQahUUFCA/Px+5ubnIyMgAESElJUW7HHCzOaGGX/RZZGFhoTZeashlZGTA4/EYOlvesGGDVuPJJ5+cs+PVQ/2eCwsLkZCQgPLycuzevRuA8UuKADAwMAAiQmxsrPCZTDgcRnd395xLRiJH/bIso6OjAytXrjS0fdzKXR3+69evNx3+0ZdaCgsLTdUCgKqqKhQUFMypLdJPQ0MDHnvsMRw7dkyoD3WDbGpqAhFhcHBQqI5qcnISDocD7e3tAMR3JKdPn0ZycjKICL/++qtwrcHBQezcuVMLEq/XCwC3DSL1GrU6xn///TfefPNN1NbWanWKioqwZcsWQ718//33aG5uRmVlpTYn1TMUm82G7u5u7Wwn2tTUFHJyclBfX4+tW7fi4Ycf1sL6qaee0kI7Pj4eSUlJc3bkN65nMBiEw+HQ3pOYmIhVq1ahqqoKra2t+Oyzz7B//354vV50dXXhwoULt5yfgUAAhw8fRk1NDYgISUlJWLduneE5FN2j2+0GEWHt2rXw+Xza60YCe2ZmBoFAAGvXrtUOsgDjlxHVZVpaWkBEePnllwGIzUO1fyLCzp07DS+vUr+LmJgYOJ3OOa+JstlsqKys1PPWpRf+BvZst6ROUCJCbW2tqVoA0Nraivz8fNN1BgcHQUQ4cuSIcNBKkgSfz4fU1FS43W4AYhNcHaOysjKkpaXh6NGjQrXUybxnzx7ExsaivLwchw4dMtxP9EZRUVGBhIQE5OTk6Plh65YGBgbm7AT27Nlj6EhbNTMzg7GxMWzduhWpqanaHL3xwCL68sXk5CSA2d9YxsbG5lwC+fnnnzE9PY33338fdXV1aGxs/MdnXrt2DXl5ecjOzobb7UZFRQW6urrQ3d1tuH8AyMvLQ0JCAogITqcTdXV1+OabbwD8f+zvFLjq+3788UdtLjudTnzwwQe6lr+VJ554QrtEMzk5KVRHDez4+HgQEU6fPi3US3QtswdY6noQEZYtWwbg9gcyenD430H0oO/atcvUNTYAGBoawurVqwGI77mjz0b27dtnuqf+/n7k5uZiaGjI1I5kZmYGdrsdJSUlwr2o4+31ehEbG4s1a9YAMD5W0Rt9W1sbiAjV1dXw+/2611FRFMiyrH12IBBAXl4ekpOTkZ6ejpdeeslQTzcKBoMoKyvTPU+jx0DtLdqNP7hG27JlC4aGhm56/V0dK7Xe7ULll19+ARFh/fr1ePXVV+HxeLRLmCJB++CDD8LhcCAzMxM9PT1CdSRJwrVr1+DxeEBEyMrKwvj4uOFegLljrJ7p3fi6UeodeleuXBGuEb0TycnJATA/4d/c3KznrUsz/Ldt26Zn5W7rp59+QkpKClpbW00P+P79+5GRkQFA/NKIioiwYsUKUzWia6WlpZmu8/zzz4OI8PHHH5uu9fnnn8PlcqG8vBznz58XqqFutD6fTztqF92Qo++Geeedd7Bx40YAYteS/wvRl7REybKMc+fOafVE5/D169dRWVmJxsZG7TKPmd7uvfdeuFwu+P1+4Rrq55eWloKI0Nvba+q7VRQFfX19yMrKQkdHh3AdVTAYBBHh4MGDpmuFQiE4nU69P+wvrfBXB2o+wn96ehp2ux0NDQ2mawUCARCRqeBXA0g99Y5+zSh1cmdmZpq+g0i9NTI3NxfFxcVaSIjWmp6e1n4Efv31102FNgDU19ebHi+VoijCd20sVWbH7EbvvvuudknFzPYgyzIcDgfa2tpM1wKAuLg45Obmmqqh9pCdnY24uDhTtYDZ+dbU1ISqqirTB6DA7N
115eXlet++tML/vvvuw8aNG+flHu3nnnsO999/P0ZHR01PrK6uLqSkpJjuCYB2d8b09LTpvrZv3w4iwqlTp0z35ff7QUTYsWOHqTpq2HR0dODkyZOm7pBii0t0gM3HNvrdd9/NSy1ZluflBghVYmLivPxW6Pf7UVBQoPsW3DspLS3FJ598ovftuvJ40Tze4cKFC1RdXU02m64H0t2Soii0fPlyevTRR8nlcpl+DkZ2djb99ddfpmoQzf5T8c2bNxMRUXJysqm+IpEI7dixg9asWWP6AXaKolBmZiY1NjbSG2+8YeoZKOpTUN1uNz3yyCPC/4ydLT7RjxUxu42Gw2HKz8+fl1oxMTHk9XqpqKhoXp4DJMsypaSkmK6TlpY2r/O/urqafvvtt3mrR0RkAxb0+SwL8mE+n48yMzMpLi7OdC0AdOnSJUpPTzc9USORCE1NTdGyZctM96Waj8dXM8Zmt3Wz2/gioWsl7srwX+z4Of+M3d0AkCRJ/9XZL4c/Y4xZkK7wX+j/yeuuOKdijLGlji8WM8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBf0PwIACOHRGfV4AAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"rotating_image_classification(digit_one, sess1, prob1, X1, keep_prob1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As shown above, a neural network trained to produce softmax probabilities fails badly when it encounters a sample that differs from the training examples. The softmax forces the network to pick one of the known classes, even when the input belongs to none of them. This is what happens when the digit one is rotated by between 60 and 130 degrees."
]
},
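{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short worked equation may help explain this behaviour. Writing $z_k$ for the logits (notation introduced only for this note), the softmax output is\n",
"\n",
"$$p_k = \\frac{e^{z_k}}{\\sum_{j=1}^{K} e^{z_j}}, \\qquad \\sum_{k=1}^{K} p_k = 1 .$$\n",
"\n",
"Because the probabilities always sum to one, all of the mass has to be spread over the $K$ known classes, so the output has no way to say that none of them fits. As a hypothetical example with $K = 2$, the logits $z = (2, 0)$ give $p_1 = e^{2}/(e^{2}+1) \\approx 0.88$, whether or not the input resembles any training digit."
]
},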
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Classification with Evidential Deep Learning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the following sections, we train the same neural network using the loss functions introduced in the paper."
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## Using the Expected Mean Square Error (Eq. 5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"As described in the paper, a neural network can be trained to learn the parameters of a Dirichlet distribution instead of softmax probabilities. A Dirichlet distribution with parameters $\\alpha \\geq 1$ behaves like a generative model for softmax probabilities (categorical distributions): it associates a likelihood value with each categorical distribution."
]
},
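{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reading aid, the relations that the code in this notebook implements can be summarised as follows. This is a sketch in the notation of the code, not a full restatement of the paper. The network outputs non-negative evidence $e_k$ for each of the $K$ classes and sets\n",
"\n",
"$$\\alpha_k = e_k + 1, \\qquad S = \\sum_{k=1}^{K} \\alpha_k, \\qquad \\hat{p}_k = \\frac{\\alpha_k}{S}, \\qquad u = \\frac{K}{S} .$$\n",
"\n",
"Here $\\hat{p}_k$ is the expected probability of class $k$ and $u$ is the uncertainty. With no evidence at all ($e_k = 0$ for every class) we get $S = K$, $\\hat{p}_k = 1/K$ and $u = 1$; as the total evidence grows, $u$ shrinks toward $0$."
]
},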
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Some functions to convert logits to evidence"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# This function to generate evidence is used for the first example\n",
"def relu_evidence(logits):\n",
"    return tf.nn.relu(logits)\n",
"\n",
"# This one usually works better and is used for the second and third examples.\n",
"# For general settings and different datasets, you may try this one first.\n",
"def exp_evidence(logits):\n",
"    return tf.exp(tf.clip_by_value(logits/10,-10,10))\n",
"\n",
"# This one is another alternative and\n",
"# usually behaves better than relu_evidence\n",
"def softplus_evidence(logits):\n",
"    return tf.nn.softplus(logits)"
]
},
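{
"cell_type": "markdown",
"metadata": {},
"source": [
"Written as equations, the three mappings from a logit $z$ to evidence defined above are (the subscripts are labels chosen for this note):\n",
"\n",
"$$e_{\\mathrm{relu}}(z) = \\max(0, z), \\qquad e_{\\mathrm{exp}}(z) = \\exp\\big(\\mathrm{clip}(z/10,\\,-10,\\,10)\\big), \\qquad e_{\\mathrm{softplus}}(z) = \\ln\\big(1 + e^{z}\\big) .$$\n",
"\n",
"All three are non-negative, so $\\alpha = e + 1 \\geq 1$ always holds. Because of the clipping, the exponential variant is bounded between $e^{-10}$ and $e^{10}$."
]
},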
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the loss function"
]
},
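{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a single sample with one-hot label $y$, the `mse_loss` function in the next cell can be read as the expression below. This is a restatement of the code, with $S = \\sum_k \\alpha_k$ and $\\hat{p}_k = \\alpha_k / S$; the symbols $t$, $T$ and $\\lambda_t$ are names chosen here for `global_step`, `annealing_step` and `annealing_coef`.\n",
"\n",
"$$\\mathcal{L} = \\sum_{k=1}^{K} \\left[ (y_k - \\hat{p}_k)^2 + \\frac{\\hat{p}_k (1 - \\hat{p}_k)}{S + 1} \\right] + \\lambda_t \\, \\mathrm{KL}\\big[ D(p \\mid \\tilde{\\alpha}) \\,\\Vert\\, D(p \\mid \\mathbf{1}) \\big], \\qquad \\lambda_t = \\min(1,\\, t/T), \\qquad \\tilde{\\alpha}_k = y_k + (1 - y_k)\\,\\alpha_k .$$\n",
"\n",
"The first term is the squared error of the expected probabilities, the second is their variance under the Dirichlet distribution, and the third penalises evidence assigned to the wrong classes. With $\\tilde{S} = \\sum_k \\tilde{\\alpha}_k$, the `KL` function evaluates\n",
"\n",
"$$\\mathrm{KL} = \\ln \\Gamma(\\tilde{S}) - \\sum_{k=1}^{K} \\ln \\Gamma(\\tilde{\\alpha}_k) - \\ln \\Gamma(K) + \\sum_{k=1}^{K} (\\tilde{\\alpha}_k - 1) \\big[ \\psi(\\tilde{\\alpha}_k) - \\psi(\\tilde{S}) \\big] ,$$\n",
"\n",
"where $\\psi$ is the digamma function."
]
},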
{
"cell_type": "code",
"execution_count": 201,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def KL(alpha):\n",
"    beta=tf.constant(np.ones((1,K)),dtype=tf.float32)\n",
"    S_alpha = tf.reduce_sum(alpha,axis=1,keep_dims=True)\n",
"    S_beta = tf.reduce_sum(beta,axis=1,keep_dims=True)\n",
"    lnB = tf.lgamma(S_alpha) - tf.reduce_sum(tf.lgamma(alpha),axis=1,keep_dims=True)\n",
"    lnB_uni = tf.reduce_sum(tf.lgamma(beta),axis=1,keep_dims=True) - tf.lgamma(S_beta)\n",
"\n",
"    dg0 = tf.digamma(S_alpha)\n",
"    dg1 = tf.digamma(alpha)\n",
"\n",
"    kl = tf.reduce_sum((alpha - beta)*(dg1-dg0),axis=1,keep_dims=True) + lnB + lnB_uni\n",
"    return kl\n",
"\n",
"def mse_loss(p, alpha, global_step, annealing_step):\n",
"    S = tf.reduce_sum(alpha, axis=1, keep_dims=True)\n",
"    E = alpha - 1\n",
"    m = alpha / S\n",
"\n",
"    A = tf.reduce_sum((p-m)**2, axis=1, keep_dims=True)\n",
"    B = tf.reduce_sum(alpha*(S-alpha)/(S*S*(S+1)), axis=1, keep_dims=True)\n",
"\n",
"    annealing_coef = tf.minimum(1.0,tf.cast(global_step/annealing_step,tf.float32))\n",
"\n",
"    alp = E*(1-p) + 1\n",
"    C = annealing_coef * KL(alp)\n",
"    return (A + B) + C"
]
},
{
"cell_type": "code",
"execution_count": 199,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# train LeNet network with expected mean square error loss\n",
"def LeNet_EDL(logits2evidence=relu_evidence,loss_function=mse_loss, lmb=0.005):\n",
"    g = tf.Graph()\n",
"    with g.as_default():\n",
"        X = tf.placeholder(shape=[None,28*28], dtype=tf.float32)\n",
"        Y = tf.placeholder(shape=[None,10], dtype=tf.float32)\n",
"        keep_prob = tf.placeholder(dtype=tf.float32)\n",
"        global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)\n",
"        annealing_step = tf.placeholder(dtype=tf.int32)\n",
"\n",
"        # first hidden layer - conv\n",
"        W1 = var('W1', [5,5,1,20])\n",
"        b1 = var('b1', [20])\n",
"        out1 = max_pool(tf.nn.relu(conv(tf.reshape(X, [-1, 28,28, 1]),\n",
"                                        W1, strides=[1, 1, 1, 1]) + b1))\n",
"        # second hidden layer - conv\n",
"        W2 = var('W2', [5,5,20,50])\n",
"        b2 = var('b2', [50])\n",
"        out2 = max_pool(tf.nn.relu(conv(out1, W2, strides=[1, 1, 1, 1]) + b2))\n",
"        # flatten the output\n",
"        Xflat = tf.contrib.layers.flatten(out2)\n",
"        # third hidden layer - fully connected\n",
"        W3 = var('W3', [Xflat.get_shape()[1].value, 500])\n",
"        b3 = var('b3', [500])\n",
"        out3 = tf.nn.relu(tf.matmul(Xflat, W3) + b3)\n",
"        out3 = tf.nn.dropout(out3, keep_prob=keep_prob)\n",
"        # output layer\n",
"        W4 = var('W4', [500,10])\n",
"        b4 = var('b4',[10])\n",
"        logits = tf.matmul(out3, W4) + b4\n",
"\n",
"        evidence = logits2evidence(logits)\n",
"        alpha = evidence + 1\n",
"\n",
"        u = K / tf.reduce_sum(alpha, axis=1, keep_dims=True) #uncertainty\n",
"\n",
"        prob = alpha/tf.reduce_sum(alpha, 1, keepdims=True)\n",
"\n",
"        loss = tf.reduce_mean(loss_function(Y, alpha, global_step, annealing_step))\n",
"        l2_loss = (tf.nn.l2_loss(W3)+tf.nn.l2_loss(W4)) * lmb\n",
"\n",
"        step = tf.train.AdamOptimizer().minimize(loss + l2_loss, global_step=global_step)\n",
"\n",
"        # Calculate accuracy\n",
"        pred = tf.argmax(logits, 1)\n",
"        truth = tf.argmax(Y, 1)\n",
"        match = tf.reshape(tf.cast(tf.equal(pred, truth), tf.float32),(-1,1))\n",
"        acc = tf.reduce_mean(match)\n",
"\n",
"        total_evidence = tf.reduce_sum(evidence,1, keepdims=True)\n",
"        mean_ev = tf.reduce_mean(total_evidence)\n",
"        mean_ev_succ = tf.reduce_sum(tf.reduce_sum(evidence,1, keepdims=True)*match) / tf.reduce_sum(match+1e-20)\n",
"        mean_ev_fail = tf.reduce_sum(tf.reduce_sum(evidence,1, keepdims=True)*(1-match)) / (tf.reduce_sum(tf.abs(1-match))+1e-20)\n",
"\n",
"        return g, step, X, Y, annealing_step, keep_prob, prob, acc, loss, u, evidence, mean_ev, mean_ev_succ, mean_ev_fail"
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"g2, step2, X2, Y2, annealing_step, keep_prob2, prob2, acc2, loss2, u, evidence, \\\n",
" mean_ev, mean_ev_succ, mean_ev_fail= LeNet_EDL()"
]
},
{
"cell_type": "code",
"execution_count": 172,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"sess2 = tf.Session(graph=g2)\n",
"with g2.as_default():\n",
" sess2.run(tf.global_variables_initializer())"
]
},
{
"cell_type": "code",
"execution_count": 173,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1 - 100%) training: 0.9470 (29.9496 - 6.0470) \t testing: 0.9525 (30.3331 - 6.2809)\n",
"epoch 2 - 100%) training: 0.9674 (33.9131 - 6.2303) \t testing: 0.9695 (34.4653 - 7.0255)\n",
"epoch 3 - 100%) training: 0.9756 (33.1459 - 4.3576) \t testing: 0.9762 (33.7443 - 4.1752)\n",
"epoch 4 - 100%) training: 0.9745 (33.9431 - 3.5602) \t testing: 0.9749 (34.4402 - 3.7496)\n",
"epoch 5 - 100%) training: 0.9807 (36.8166 - 4.0170) \t testing: 0.9789 (37.4320 - 4.1132)\n",
"epoch 6 - 100%) training: 0.9791 (36.6413 - 3.0833) \t testing: 0.9803 (37.2622 - 3.1119)\n",
"epoch 7 - 100%) training: 0.9782 (39.2723 - 3.2930) \t testing: 0.9778 (40.0590 - 3.1863)\n",
"epoch 8 - 100%) training: 0.9808 (37.4109 - 1.9068) \t testing: 0.9800 (38.1145 - 2.2643)\n",
"epoch 9 - 100%) training: 0.9815 (39.3951 - 2.7377) \t testing: 0.9805 (40.1664 - 2.9485)\n",
"epoch 10 - 100%) training: 0.9831 (40.2205 - 2.0345) \t testing: 0.9830 (40.8734 - 2.4930)\n",
"epoch 11 - 100%) training: 0.9839 (40.3787 - 1.5182) \t testing: 0.9840 (41.1746 - 1.7158)\n",
"epoch 12 - 100%) training: 0.9833 (39.9269 - 1.6233) \t testing: 0.9836 (40.7200 - 2.0022)\n",
"epoch 13 - 100%) training: 0.9844 (42.0865 - 1.8232) \t testing: 0.9835 (42.9469 - 2.2165)\n",
"epoch 14 - 100%) training: 0.9816 (40.6318 - 1.4549) \t testing: 0.9811 (41.4050 - 2.1860)\n",
"epoch 15 - 100%) training: 0.9851 (44.5367 - 2.1988) \t testing: 0.9842 (45.3868 - 3.0271)\n",
"epoch 16 - 100%) training: 0.9837 (44.3451 - 1.9175) \t testing: 0.9829 (45.0711 - 3.2015)\n",
"epoch 17 - 100%) training: 0.9839 (47.7866 - 2.7627) \t testing: 0.9824 (48.6517 - 3.8829)\n",
"epoch 18 - 100%) training: 0.9856 (44.8029 - 1.7993) \t testing: 0.9855 (45.7763 - 3.0277)\n",
"epoch 19 - 100%) training: 0.9841 (44.6586 - 1.8779) \t testing: 0.9834 (45.5720 - 3.4869)\n",
"epoch 20 - 100%) training: 0.9875 (45.7646 - 1.8881) \t testing: 0.9877 (46.7228 - 3.3131)\n",
"epoch 21 - 100%) training: 0.9866 (46.3577 - 1.9462) \t testing: 0.9861 (47.1481 - 3.0685)\n",
"epoch 22 - 100%) training: 0.9861 (46.5912 - 1.9597) \t testing: 0.9863 (47.3737 - 2.7471)\n",
"epoch 23 - 100%) training: 0.9869 (48.9247 - 2.2133) \t testing: 0.9867 (49.9383 - 2.7514)\n",
"epoch 24 - 100%) training: 0.9870 (46.9884 - 2.0665) \t testing: 0.9855 (48.0439 - 2.1623)\n",
"epoch 25 - 100%) training: 0.9873 (50.8303 - 2.4647) \t testing: 0.9854 (51.8676 - 3.1680)\n",
"epoch 26 - 100%) training: 0.9880 (49.6770 - 2.3419) \t testing: 0.9879 (50.5636 - 3.7460)\n",
"epoch 27 - 100%) training: 0.9871 (49.9567 - 2.2482) \t testing: 0.9862 (51.1154 - 3.4443)\n",
"epoch 28 - 100%) training: 0.9877 (50.9868 - 2.4529) \t testing: 0.9869 (52.1459 - 3.8382)\n",
"epoch 29 - 100%) training: 0.9883 (46.9654 - 1.3623) \t testing: 0.9873 (47.9654 - 2.1033)\n",
"epoch 30 - 100%) training: 0.9882 (51.7587 - 2.4527) \t testing: 0.9871 (52.7797 - 4.4101)\n",
"epoch 31 - 100%) training: 0.9883 (52.9645 - 3.1079) \t testing: 0.9873 (54.1205 - 3.6993)\n",
"epoch 32 - 100%) training: 0.9881 (51.1556 - 2.3109) \t testing: 0.9875 (52.0979 - 4.0251)\n",
"epoch 33 - 100%) training: 0.9880 (48.3095 - 1.5548) \t testing: 0.9874 (49.4096 - 2.0053)\n",
"epoch 34 - 100%) training: 0.9885 (51.3298 - 2.3731) \t testing: 0.9864 (52.2786 - 3.8313)\n",
"epoch 35 - 100%) training: 0.9892 (53.0268 - 2.6064) \t testing: 0.9875 (54.1511 - 3.8138)\n",
"epoch 36 - 100%) training: 0.9902 (50.2909 - 1.8811) \t testing: 0.9884 (51.5063 - 2.5316)\n",
"epoch 37 - 100%) training: 0.9885 (51.8024 - 2.0432) \t testing: 0.9879 (52.7857 - 2.8535)\n",
"epoch 38 - 100%) training: 0.9888 (51.6903 - 1.5005) \t testing: 0.9881 (52.7587 - 2.3378)\n",
"epoch 39 - 100%) training: 0.9896 (54.4732 - 1.9507) \t testing: 0.9876 (55.7110 - 2.9229)\n",
"epoch 40 - 100%) training: 0.9885 (49.8664 - 1.7354) \t testing: 0.9866 (50.7666 - 2.3627)\n",
"epoch 41 - 100%) training: 0.9893 (54.7296 - 2.3521) \t testing: 0.9878 (55.9593 - 3.4784)\n",
"epoch 42 - 100%) training: 0.9888 (55.0189 - 2.9684) \t testing: 0.9884 (55.9647 - 3.5898)\n",
"epoch 43 - 100%) training: 0.9907 (54.7551 - 2.2131) \t testing: 0.9887 (55.9601 - 4.9516)\n",
"epoch 44 - 100%) training: 0.9889 (55.8486 - 2.6489) \t testing: 0.9880 (57.0468 - 4.3089)\n",
"epoch 45 - 100%) training: 0.9895 (56.3373 - 2.6319) \t testing: 0.9888 (57.3626 - 4.7427)\n",
"epoch 46 - 100%) training: 0.9886 (52.5418 - 1.4541) \t testing: 0.9882 (53.6143 - 2.9547)\n",
"epoch 47 - 100%) training: 0.9908 (53.2042 - 1.7521) \t testing: 0.9876 (54.3501 - 2.6650)\n",
"epoch 48 - 100%) training: 0.9893 (57.3922 - 2.6419) \t testing: 0.9877 (58.5497 - 4.1028)\n",
"epoch 49 - 100%) training: 0.9889 (55.2722 - 2.0515) \t testing: 0.9876 (56.4727 - 2.7934)\n",
"epoch 50 - 100%) training: 0.9887 (55.1349 - 2.0851) \t testing: 0.9880 (56.2424 - 2.6495)\n"
]
}
],
"source": [
"bsize = 1000 #batch size\n",
"n_batches = mnist.train.num_examples // bsize\n",
"L_train_acc1=[]\n",
"L_train_ev_s=[]\n",
"L_train_ev_f=[]\n",
"L_test_acc1=[]\n",
"L_test_ev_s=[]\n",
"L_test_ev_f=[]\n",
"for epoch in range(50): \n",
"    for i in range(n_batches):\n",
"        data, label = mnist.train.next_batch(bsize)\n",
"        feed_dict={X2:data, Y2:label, keep_prob2:.5, annealing_step:10*n_batches}\n",
"        sess2.run(step2,feed_dict)\n",
"        print('epoch %d - %d%%) '% (epoch+1, (100*(i+1))//n_batches), end='\\r' if i"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"draw_EDL_results(L_train_acc1, L_train_ev_s, L_train_ev_f, L_test_acc1, L_test_ev_s, L_test_ev_f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The figure above indicates that the proposed approach generates a much smaller amount of evidence for the misclassified samples than for the correctly classified ones. The uncertainty of the misclassified samples is around 0.8, while it is around 0.1 for the correctly classified ones, for both the training and testing sets. This means that the neural network is very uncertain about the misclassified samples and makes confident predictions only for the correctly classified ones. In other words, the neural network also signals when it fails by assigning high uncertainty to its wrong predictions."
]
},
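{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough check on these numbers, the sketch below maps an evidence vector to belief masses, uncertainty and expected class probabilities using the paper's definitions (`alpha = evidence + 1`, `S = sum(alpha)`, `u = K / S`). The helper name `dirichlet_summary` and the two example evidence vectors are illustrative assumptions; they are not part of the TensorFlow graph used in this notebook.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def dirichlet_summary(evidence):\n",
"    # evidence: non-negative array of shape (K,), one entry per class\n",
"    evidence = np.asarray(evidence, dtype=float)\n",
"    K = evidence.shape[-1]\n",
"    alpha = evidence + 1.0   # Dirichlet parameters\n",
"    S = alpha.sum()          # Dirichlet strength\n",
"    belief = evidence / S    # belief mass per class\n",
"    u = K / S                # uncertainty mass; sum(belief) + u == 1\n",
"    prob = alpha / S         # expected class probabilities\n",
"    return belief, u, prob\n",
"\n",
"\n",
"# Hypothetical outputs: strong evidence for one class vs. almost no evidence.\n",
"confident = np.zeros(10)\n",
"confident[3] = 56.0\n",
"unsure = np.zeros(10)\n",
"unsure[7] = 2.6\n",
"\n",
"for name, ev in [('confident', confident), ('unsure', unsure)]:\n",
"    belief, u, prob = dirichlet_summary(ev)\n",
"    print(name, 'u = %.3f' % u, 'max prob = %.3f' % prob.max())\n",
"```\n",
"\n",
"Assuming that the two numbers in parentheses in the training log are the mean total evidence for correctly and incorrectly classified samples, a total evidence of about 56 gives `u = 10 / 66` (roughly 0.15) and about 2.6 gives `u = 10 / 12.6` (roughly 0.79), which is in line with the values quoted above. The mean uncertainty is not exactly the uncertainty at the mean evidence, so this is only an approximate check."
]
},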
{
"cell_type": "code",
"execution_count": 174,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAA1CAYAAACp8OvZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADTxJREFUeJzt3V1MHGUXB/Czy7J8G1pWQitWUk3TIIloGkIoNpC0EmKkkqgtCQWJRiBsmxKCXJDQxA80VWsN0WqRi4oGmoJKiL1zbdwUTWmxSmi1dRvshu62tWLBrmVndv7vBZl5l9qPmWcQoXN+SS+62Tl75tln/vOx06kNADHGGLMW+3/dAGOMsYXH4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbkWODPW5BnSSiKQna7uf2aLMvkcJgbnkgkQjExMaZqMMbuHpIkUWxs7L/9MTY9b7orj/ztdjtNTk6aquFwOCgUCpmqMR/BH4lETNdgjJknSZLpGgsQ/LotuvB/7bXX6ODBg6ZqnDx5kjZt2kTHjx8XDs9QKERJSUm0a9cu4T52795NNpuNjh8/LlwjJiaG9u7dS+Pj40QktjMIh8NzlpVlWbgfFT8QkC0ERVH+8zpXrlyhF198kS5evCj82cFgkGpra2nTpk00MzMj3Munn35q+qqGatGF/wMPPEAHDhwwVSMrK4tOnDhBfX19QiEFgBITEyk7O5va29uFelAUhZ599lmqq6ujhoYGCgaDQr1cvHiRDh8+TBUVFfThhx8KnU04nU4iInrllVdoeHiYHA6H0MYgSZK2A7HZbLwDYP8KRVG0uWW3202Fpd/vp8cff1woMNW53traSl1dXfT7778L9aAoCvX09NCXX35JpaWl2vZoFAA6cOAArVy5Umj5mxZcwD93VF5ejt7eXj1vvS0igsfjMVWjrq4Os0MkLhQKgYjw1VdfIRwOC9eprq6Gw+EAAMzMzAjVaGtrg8vlMjUuRIS4uDiUlpYCgOF1UnsfGRnBt99+K9wHm1+SJCESiUBRFMPLRiIRRCIR7e+nTp3C0aNH57ymtw4AjI+Po7i4GE6nE42NjYbrALPrk5WVBSJCRUWF4eUVRUEoFMILL7wAIsJDDz2kvW5UOBwGESE5OdnwstG2bdsGIkIwGLzTW3Xl8aIK/5GRETidToRCoTuPxC0oioLLly+DiDA6OgpZloXqyLKMiooKU+GvThQiQk1NjdDEUZfp7OwU7kUdg8uXL+O9997DM888A0VRhHZGg4ODyM3NndOLJEmGaoTDYSQmJuKee+5BX1+f9pqRfhRFgSRJCIVC8Pv92vpIkiQ0zkuFum6yLCMSiRge++g6N4b29PS0qd6++OILNDU1weVyGZ6r6nr19/ejqKgINHtzCM6cOSPUy6FDh0BEcLlcOHv2rOHlJUmC1+tFTk4OiouL4fP5hOfVxMQEiAh5eXnC3xcwG/42m03PW5de+Dc3N+tdudsaGxsDEaGlpUX4KBkAurq6kJiYiNHRUVNf2tNPP236DEKWZfT09ICI8NZbbwGA8JGVekQ0PDwsVAcAent7QUSIj4/H1atXARg7Krp69SrOnDkDm82m7Ry//vpr3ctLkoSamhosX74caWlpICK43W54vV5cunRJe5+edTtx4gR6e3uxefPmeTnrvJmpqSn8+eef/3j9dv2pO+3oOXzu3DmcP38eo6OjePvtt/HRRx+hs7MTIyMjuvpQa4VCIRw5cgT79u3DihUrtLCNj4/HH3/8ccc66vbg9/vhdru1A4LMzExs375dd+Cq6//DDz+gvb0ddrsdRARJkhAOh4XmZn5+PogI9fX1AIwfrUciEUxNTYGIkJCQgOvXrxvuQf3ciYkJlJWVoba2FoFAQGh9VDab7e4N/8rKynkJ/7Nnz4KIUFJSInypJRKJoLu7G6tWrT
K9E6mrq0NSUhI6OzuFawDA8PCwtpGqPYrYu3cv0tPTkZqaKtxLOBzWjtBWr16NgYEB3ctGn411dHRg3bp12np1d3djampK187W5/NhZGQEHo8HVVVVWo2YmBh0dHRgYmJCVz8JCQnahmWz2ZCenq792bBhA44dO6b7DHJ4eBgtLS3weDwYHR1FS0sLSkpKUFhYiNraWnR2dqKhoQH9/f0IBAK6agLQahUUFCA/Px+5ubnIyMgAESElJUW7HHCzOaGGX/RZZGFhoTZeashlZGTA4/EYOlvesGGDVuPJJ5+cs+PVQ/2eCwsLkZCQgPLycuzevRuA8UuKADAwMAAiQmxsrPCZTDgcRnd395xLRiJH/bIso6OjAytXrjS0fdzKXR3+69evNx3+0ZdaCgsLTdUCgKqqKhQUFMypLdJPQ0MDHnvsMRw7dkyoD3WDbGpqAhFhcHBQqI5qcnISDocD7e3tAMR3JKdPn0ZycjKICL/++qtwrcHBQezcuVMLEq/XCwC3DSL1GrU6xn///TfefPNN1NbWanWKioqwZcsWQ718//33aG5uRmVlpTYn1TMUm82G7u5u7Wwn2tTUFHJyclBfX4+tW7fi4Ycf1sL6qaee0kI7Pj4eSUlJc3bkN65nMBiEw+HQ3pOYmIhVq1ahqqoKra2t+Oyzz7B//354vV50dXXhwoULt5yfgUAAhw8fRk1NDYgISUlJWLduneE5FN2j2+0GEWHt2rXw+Xza60YCe2ZmBoFAAGvXrtUOsgDjlxHVZVpaWkBEePnllwGIzUO1fyLCzp07DS+vUr+LmJgYOJ3OOa+JstlsqKys1PPWpRf+BvZst6ROUCJCbW2tqVoA0Nraivz8fNN1BgcHQUQ4cuSIcNBKkgSfz4fU1FS43W4AYhNcHaOysjKkpaXh6NGjQrXUybxnzx7ExsaivLwchw4dMtxP9EZRUVGBhIQE5OTk6Plh65YGBgbm7AT27Nlj6EhbNTMzg7GxMWzduhWpqanaHL3xwCL68sXk5CSA2d9YxsbG5lwC+fnnnzE9PY33338fdXV1aGxs/MdnXrt2DXl5ecjOzobb7UZFRQW6urrQ3d1tuH8AyMvLQ0JCAogITqcTdXV1+OabbwD8f+zvFLjq+3788UdtLjudTnzwwQe6lr+VJ554QrtEMzk5KVRHDez4+HgQEU6fPi3US3QtswdY6noQEZYtWwbg9gcyenD430H0oO/atcvUNTYAGBoawurVqwGI77mjz0b27dtnuqf+/n7k5uZiaGjI1I5kZmYGdrsdJSUlwr2o4+31ehEbG4s1a9YAMD5W0Rt9W1sbiAjV1dXw+/2611FRFMiyrH12IBBAXl4ekpOTkZ6ejpdeeslQTzcKBoMoKyvTPU+jx0DtLdqNP7hG27JlC4aGhm56/V0dK7Xe7ULll19+ARFh/fr1ePXVV+HxeLRLmCJB++CDD8LhcCAzMxM9PT1CdSRJwrVr1+DxeEBEyMrKwvj4uOFegLljrJ7p3fi6UeodeleuXBGuEb0TycnJATA/4d/c3KznrUsz/Ldt26Zn5W7rp59+QkpKClpbW00P+P79+5GRkQFA/NKIioiwYsUKUzWia6WlpZmu8/zzz4OI8PHHH5uu9fnnn8PlcqG8vBznz58XqqFutD6fTztqF92Qo++Geeedd7Bx40YAYteS/wvRl7REybKMc+fOafVE5/D169dRWVmJxsZG7TKPmd7uvfdeuFwu+P1+4Rrq55eWloKI0Nvba+q7VRQFfX19yMrKQkdHh3AdVTAYBBHh4MGDpmuFQiE4nU69P+wvrfBXB2o+wn96ehp2ux0NDQ2mawUCARCRqeBXA0g99Y5+zSh1cmdmZpq+g0i9NTI3NxfFxcVaSIjWmp6e1n4Efv31102FNgDU19ebHi+VoijCd20sVWbH7EbvvvuudknFzPYgyzIcDgfa2tpM1wKAuLg45Obmmqqh9pCdnY24uDhTtYDZ+dbU1ISqqirTB6DA7N
115eXlet++tML/vvvuw8aNG+flHu3nnnsO999/P0ZHR01PrK6uLqSkpJjuCYB2d8b09LTpvrZv3w4iwqlTp0z35ff7QUTYsWOHqTpq2HR0dODkyZOm7pBii0t0gM3HNvrdd9/NSy1ZluflBghVYmLivPxW6Pf7UVBQoPsW3DspLS3FJ598ovftuvJ40Tze4cKFC1RdXU02m64H0t2Soii0fPlyevTRR8nlcpl+DkZ2djb99ddfpmoQzf5T8c2bNxMRUXJysqm+IpEI7dixg9asWWP6AXaKolBmZiY1NjbSG2+8YeoZKOpTUN1uNz3yyCPC/4ydLT7RjxUxu42Gw2HKz8+fl1oxMTHk9XqpqKhoXp4DJMsypaSkmK6TlpY2r/O/urqafvvtt3mrR0RkAxb0+SwL8mE+n48yMzMpLi7OdC0AdOnSJUpPTzc9USORCE1NTdGyZctM96Waj8dXM8Zmt3Wz2/gioWsl7srwX+z4Of+M3d0AkCRJ/9XZL4c/Y4xZkK7wX+j/yeuuOKdijLGlji8WM8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBf0PwIACOHRGfV4AAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"rotating_image_classification(digit_one, sess2, prob2, X2, keep_prob2, u)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using the Expected Cross Entropy (Eq. 4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we train a neural network using the loss function described in Eq. 4 of the paper. This loss function is derived by taking the expected value of the cross entropy loss over the predicted Dirichlet distribution."
]
},
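{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make this derivation concrete, here is a minimal NumPy/SciPy sketch, separate from the TensorFlow graph used in this notebook. It evaluates the expected cross entropy in closed form, `sum_j y_j * (digamma(S) - digamma(alpha_j))`, and compares it with a Monte Carlo estimate obtained by sampling class probabilities from the Dirichlet distribution. The function names, the example `alpha` and the sample size are assumptions chosen for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.special import digamma\n",
"\n",
"\n",
"def expected_cross_entropy(y, alpha):\n",
"    # Closed form of E[-sum_j y_j * log p_j] for p ~ Dir(alpha)\n",
"    S = alpha.sum(axis=-1, keepdims=True)\n",
"    return np.sum(y * (digamma(S) - digamma(alpha)), axis=-1)\n",
"\n",
"\n",
"def monte_carlo_cross_entropy(y, alpha, n_samples=200000, seed=0):\n",
"    # Sample p ~ Dir(alpha) and average the ordinary cross entropy\n",
"    rng = np.random.default_rng(seed)\n",
"    p = rng.dirichlet(alpha, size=n_samples)\n",
"    return np.mean(-np.sum(y * np.log(p), axis=-1))\n",
"\n",
"\n",
"alpha = np.array([9.0, 2.0, 1.5])  # hypothetical Dirichlet parameters, K = 3\n",
"y = np.array([1.0, 0.0, 0.0])      # one-hot label for class 0\n",
"\n",
"print('closed form :', expected_cross_entropy(y, alpha))\n",
"print('Monte Carlo :', monte_carlo_cross_entropy(y, alpha))\n",
"```\n",
"\n",
"The two printed values should agree up to sampling noise, which is why the loss can be computed with `digamma` directly instead of sampling from the Dirichlet during training."
]
},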
{
"cell_type": "code",
"execution_count": 219,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"collapsed": true,
"id": "pzgZFuUx2vM-"
},
"outputs": [],
"source": [
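"def loss_EDL(func=tf.digamma):\n",
"    # func=tf.digamma gives the expected cross entropy (Eq. 4);\n",
"    # func=tf.log gives the negative log of the expected likelihood (Eq. 3).\n",
"    def loss_func(p, alpha, global_step, annealing_step):\n",
"        S = tf.reduce_sum(alpha, axis=1, keepdims=True)  # Dirichlet strength\n",
"        E = alpha - 1  # evidence\n",
"\n",
"        # data-fit term\n",
"        A = tf.reduce_sum(p * (func(S) - func(alpha)), axis=1, keepdims=True)\n",
"\n",
"        # weight of the KL regularizer grows linearly from 0 to 1\n",
"        annealing_coef = tf.minimum(1.0, tf.cast(global_step/annealing_step, tf.float32))\n",
"\n",
"        # remove the evidence of the true class before penalizing\n",
"        alp = E*(1-p) + 1\n",
"        B = annealing_coef * KL(alp)\n",
"\n",
"        return (A + B)\n",
"    return loss_func"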
]
},
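{
"cell_type": "markdown",
"metadata": {},
"source": [
"The regularization part of `loss_func` can be unpacked with a small NumPy/SciPy sketch. It assumes that the `KL` helper defined earlier in the notebook computes the KL divergence from `Dir(alpha)` to the uniform Dirichlet `Dir(1, ..., 1)`, as in the paper; `kl_to_uniform` below is a stand-in written for this illustration, and the example values, including `n_batches = 55`, are assumptions.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.special import digamma, gammaln\n",
"\n",
"\n",
"def kl_to_uniform(alpha):\n",
"    # KL( Dir(alpha) || Dir(1, ..., 1) ) for alpha of shape (batch, K)\n",
"    K = alpha.shape[-1]\n",
"    S = alpha.sum(axis=-1, keepdims=True)\n",
"    log_norm = gammaln(S) - gammaln(alpha).sum(axis=-1, keepdims=True) - gammaln(K)\n",
"    weighted = ((alpha - 1.0) * (digamma(alpha) - digamma(S))).sum(axis=-1, keepdims=True)\n",
"    return log_norm + weighted\n",
"\n",
"\n",
"def annealing_coef(global_step, annealing_step):\n",
"    # Linear ramp from 0 to 1, then constant (mirrors the tf.minimum call)\n",
"    return min(1.0, global_step / annealing_step)\n",
"\n",
"\n",
"p = np.array([[0.0, 1.0, 0.0]])       # one-hot label, true class is 1\n",
"alpha = np.array([[3.0, 20.0, 1.0]])  # hypothetical network output\n",
"E = alpha - 1.0\n",
"\n",
"# Drop the evidence of the true class; only misleading evidence remains.\n",
"alp = E * (1.0 - p) + 1.0\n",
"print('alp =', alp)\n",
"print('KL  =', kl_to_uniform(alp))\n",
"\n",
"n_batches = 55  # assumed: 55000 training images with a batch size of 1000\n",
"for epoch in (1, 5, 10, 20):\n",
"    step = epoch * n_batches\n",
"    print(epoch, annealing_coef(step, 10 * n_batches))\n",
"```\n",
"\n",
"Because the true class is reset to `alpha = 1` before the KL term is evaluated, only evidence assigned to wrong classes is penalized. With `annealing_step = 10 * n_batches`, the weight of this penalty reaches 1 after the tenth epoch, so the regularizer is phased in gradually."
]
},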
{
"cell_type": "code",
"execution_count": 210,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"g3, step3, X3, Y3, annealing_step3, keep_prob3, prob3, acc3, loss3, u3, evidence3, \\\n",
" mean_ev3, mean_ev_succ3, mean_ev_fail3 = LeNet_EDL(exp_evidence, loss_EDL(tf.digamma), lmb=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 211,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"collapsed": true,
"id": "PQ2AqucY20f2"
},
"outputs": [],
"source": [
"sess3 = tf.Session(graph=g3)\n",
"with g3.as_default():\n",
" sess3.run(tf.global_variables_initializer())"
]
},
{
"cell_type": "code",
"execution_count": 212,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"base_uri": "https://localhost:8080/",
"height": 179
},
"colab_type": "code",
"executionInfo": {
"elapsed": 2222,
"status": "ok",
"timestamp": 1527923836119,
"user": {
"displayName": "Murat Sensoy",
"photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
"userId": "102692943223630372304"
},
"user_tz": -180
},
"id": "FDhvxNKN25VE",
"outputId": "b96b936e-955f-4338-faa3-ccd2f0a7e36b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1 - 100%) training: 0.8303 (42.6223 - 16.4772) \t testing: 0.8412 (42.8285 - 16.6934)\n",
"epoch 2 - 100%) training: 0.8992 (123.3957 - 11.1870) \t testing: 0.9062 (122.5605 - 11.3092)\n",
"epoch 3 - 100%) training: 0.9206 (208.4110 - 9.0025) \t testing: 0.9272 (203.0646 - 9.2725)\n",
"epoch 4 - 100%) training: 0.9413 (264.3557 - 8.2034) \t testing: 0.9446 (259.2374 - 8.0735)\n",
"epoch 5 - 100%) training: 0.9485 (298.0337 - 6.7932) \t testing: 0.9540 (297.0534 - 6.4097)\n",
"epoch 6 - 100%) training: 0.9526 (300.2455 - 5.8476) \t testing: 0.9573 (301.2290 - 5.1179)\n",
"epoch 7 - 100%) training: 0.9595 (536.3224 - 6.1392) \t testing: 0.9631 (548.7296 - 5.8329)\n",
"epoch 8 - 100%) training: 0.9632 (616.2153 - 5.7967) \t testing: 0.9669 (642.6508 - 5.4167)\n",
"epoch 9 - 100%) training: 0.9664 (691.5225 - 5.5155) \t testing: 0.9705 (711.5176 - 4.6776)\n",
"epoch 10 - 100%) training: 0.9671 (743.4854 - 4.3620) \t testing: 0.9693 (765.5163 - 4.1878)\n",
"epoch 11 - 100%) training: 0.9695 (1386.9512 - 6.1451) \t testing: 0.9735 (1369.8014 - 5.8578)\n",
"epoch 12 - 100%) training: 0.9723 (1498.5973 - 6.5162) \t testing: 0.9747 (1543.9032 - 6.2582)\n",
"epoch 13 - 100%) training: 0.9717 (1548.7797 - 4.6337) \t testing: 0.9743 (1684.6056 - 4.7657)\n",
"epoch 14 - 100%) training: 0.9735 (1498.0391 - 5.4553) \t testing: 0.9757 (1613.8787 - 5.6297)\n",
"epoch 15 - 100%) training: 0.9759 (1493.4908 - 5.2220) \t testing: 0.9779 (1563.8809 - 4.5270)\n",
"epoch 16 - 100%) training: 0.9764 (2248.6760 - 5.7876) \t testing: 0.9777 (2307.0652 - 5.2200)\n",
"epoch 17 - 100%) training: 0.9769 (1995.6049 - 4.3072) \t testing: 0.9804 (2252.7351 - 4.1795)\n",
"epoch 18 - 100%) training: 0.9765 (2430.6807 - 4.4716) \t testing: 0.9795 (2640.8728 - 4.7834)\n",
"epoch 19 - 100%) training: 0.9776 (3172.2893 - 5.9542) \t testing: 0.9792 (3272.4224 - 5.0041)\n",
"epoch 20 - 100%) training: 0.9788 (3570.1475 - 4.9005) \t testing: 0.9812 (4264.9502 - 5.8081)\n",
"epoch 21 - 100%) training: 0.9787 (4876.1646 - 6.1165) \t testing: 0.9808 (5475.2358 - 5.8976)\n",
"epoch 22 - 100%) training: 0.9799 (3180.2634 - 3.9057) \t testing: 0.9813 (3317.7156 - 3.9012)\n",
"epoch 23 - 100%) training: 0.9806 (2328.6206 - 3.7536) \t testing: 0.9826 (2397.2944 - 3.6987)\n",
"epoch 24 - 100%) training: 0.9798 (4677.1987 - 4.4783) \t testing: 0.9823 (5023.0850 - 4.6507)\n",
"epoch 25 - 100%) training: 0.9811 (3927.5339 - 4.8530) \t testing: 0.9830 (4024.1538 - 5.2015)\n",
"epoch 26 - 100%) training: 0.9798 (4298.3862 - 5.0999) \t testing: 0.9831 (4730.0220 - 4.9904)\n",
"epoch 27 - 100%) training: 0.9811 (4639.9775 - 5.3606) \t testing: 0.9831 (4931.9014 - 5.6236)\n",
"epoch 28 - 100%) training: 0.9825 (4483.1514 - 4.3366) \t testing: 0.9844 (4708.0806 - 4.3383)\n",
"epoch 29 - 100%) training: 0.9829 (4930.3579 - 4.2995) \t testing: 0.9836 (5377.6284 - 4.9327)\n",
"epoch 30 - 100%) training: 0.9829 (6808.6553 - 3.7501) \t testing: 0.9822 (7229.0522 - 3.9000)\n",
"epoch 31 - 100%) training: 0.9829 (6299.5278 - 5.7005) \t testing: 0.9845 (7308.3398 - 5.1458)\n",
"epoch 32 - 100%) training: 0.9829 (6325.0591 - 5.7004) \t testing: 0.9854 (6570.1626 - 6.2992)\n",
"epoch 33 - 100%) training: 0.9834 (5909.4917 - 4.6970) \t testing: 0.9843 (6029.9980 - 5.9591)\n",
"epoch 34 - 100%) training: 0.9841 (7869.9731 - 5.8773) \t testing: 0.9849 (8628.2305 - 5.6427)\n",
"epoch 35 - 100%) training: 0.9836 (8436.3779 - 4.7038) \t testing: 0.9850 (9239.3271 - 6.6275)\n",
"epoch 36 - 100%) training: 0.9841 (5976.3433 - 3.9183) \t testing: 0.9854 (6832.0786 - 4.6664)\n",
"epoch 37 - 100%) training: 0.9835 (7943.0410 - 4.1992) \t testing: 0.9848 (9035.4375 - 5.9112)\n",
"epoch 38 - 100%) training: 0.9856 (7578.3315 - 4.6529) \t testing: 0.9869 (8479.9932 - 5.4112)\n",
"epoch 39 - 100%) training: 0.9851 (6393.1548 - 4.1753) \t testing: 0.9859 (7391.0679 - 4.8497)\n",
"epoch 40 - 100%) training: 0.9861 (7625.8906 - 4.5776) \t testing: 0.9862 (8205.8027 - 5.7368)\n",
"epoch 41 - 100%) training: 0.9859 (8331.8799 - 3.7653) \t testing: 0.9870 (9274.6816 - 4.7081)\n",
"epoch 42 - 100%) training: 0.9859 (13754.2227 - 4.9862) \t testing: 0.9875 (14675.8350 - 5.6302)\n",
"epoch 43 - 100%) training: 0.9855 (15424.7334 - 6.8568) \t testing: 0.9864 (16225.4473 - 7.8045)\n",
"epoch 44 - 100%) training: 0.9862 (12365.5068 - 4.6459) \t testing: 0.9866 (13331.6621 - 5.5387)\n",
"epoch 45 - 100%) training: 0.9861 (12153.3447 - 4.4586) \t testing: 0.9876 (14052.3027 - 6.7540)\n",
"epoch 46 - 100%) training: 0.9865 (17505.6484 - 4.8752) \t testing: 0.9870 (18024.4199 - 6.1390)\n",
"epoch 47 - 100%) training: 0.9859 (15554.4268 - 4.5205) \t testing: 0.9860 (17220.3281 - 5.6510)\n",
"epoch 48 - 100%) training: 0.9870 (12240.6279 - 3.7974) \t testing: 0.9873 (14007.5693 - 5.2733)\n",
"epoch 49 - 100%) training: 0.9867 (13552.8066 - 3.8983) \t testing: 0.9875 (15117.2031 - 4.9112)\n",
"epoch 50 - 100%) training: 0.9858 (13994.8203 - 3.8285) \t testing: 0.9854 (14965.4561 - 4.7617)\n"
]
}
],
"source": [
"bsize = 1000 #batch size\n",
"n_batches = mnist.train.num_examples // bsize\n",
"L3_train_acc1=[]\n",
"L3_train_ev_s=[]\n",
"L3_train_ev_f=[]\n",
"L3_test_acc1=[]\n",
"L3_test_ev_s=[]\n",
"L3_test_ev_f=[]\n",
"for epoch in range(50): \n",
"    for i in range(n_batches):\n",
"        data, label = mnist.train.next_batch(bsize)\n",
"        feed_dict={X3:data, Y3:label, keep_prob3:.5, annealing_step3:10*n_batches}\n",
"        sess3.run(step3,feed_dict)\n",
"        print('epoch %d - %d%%) '% (epoch+1, (100*(i+1))//n_batches), end='\\r' if i"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"draw_EDL_results(L3_train_acc1, L3_train_ev_s, L3_train_ev_f, L3_test_acc1, L3_test_ev_s, L3_test_ev_f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The figure above indicates that the neural network generates much more evidence for the correctly classified samples than for the misclassified ones. As a result, its uncertainty is very low (around zero) for the correctly classified samples and very high (around 0.7) for the misclassified samples."
]
},
{
"cell_type": "code",
"execution_count": 214,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "yrZTPQ563PlZ"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAA1CAYAAACp8OvZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADTxJREFUeJzt3V1MHGUXB/Czy7J8G1pWQitWUk3TIIloGkIoNpC0EmKkkqgtCQWJRiBsmxKCXJDQxA80VWsN0WqRi4oGmoJKiL1zbdwUTWmxSmi1dRvshu62tWLBrmVndv7vBZl5l9qPmWcQoXN+SS+62Tl75tln/vOx06kNADHGGLMW+3/dAGOMsYXH4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbkWODPW5BnSSiKQna7uf2aLMvkcJgbnkgkQjExMaZqMMbuHpIkUWxs7L/9MTY9b7orj/ztdjtNTk6aquFwOCgUCpmqMR/BH4lETNdgjJknSZLpGgsQ/LotuvB/7bXX6ODBg6ZqnDx5kjZt2kTHjx8XDs9QKERJSUm0a9cu4T52795NNpuNjh8/LlwjJiaG9u7dS+Pj40QktjMIh8NzlpVlWbgfFT8QkC0ERVH+8zpXrlyhF198kS5evCj82cFgkGpra2nTpk00MzMj3Munn35q+qqGatGF/wMPPEAHDhwwVSMrK4tOnDhBfX19QiEFgBITEyk7O5va29uFelAUhZ599lmqq6ujhoYGCgaDQr1cvHiRDh8+TBUVFfThhx8KnU04nU4iInrllVdoeHiYHA6H0MYgSZK2A7HZbLwDYP8KRVG0uWW3202Fpd/vp8cff1woMNW53traSl1dXfT7778L9aAoCvX09NCXX35JpaWl2vZoFAA6cOAArVy5Umj5mxZcwD93VF5ejt7eXj1vvS0igsfjMVWjrq4Os0MkLhQKgYjw1VdfIRwOC9eprq6Gw+EAAMzMzAjVaGtrg8vlMjUuRIS4uDiUlpYCgOF1UnsfGRnBt99+K9wHm1+SJCESiUBRFMPLRiIRRCIR7e+nTp3C0aNH57ymtw4AjI+Po7i4GE6nE42NjYbrALPrk5WVBSJCRUWF4eUVRUEoFMILL7wAIsJDDz2kvW5UOBwGESE5OdnwstG2bdsGIkIwGLzTW3Xl8aIK/5GRETidToRCoTuPxC0oioLLly+DiDA6OgpZloXqyLKMiooKU+GvThQiQk1NjdDEUZfp7OwU7kUdg8uXL+O9997DM888A0VRhHZGg4ODyM3NndOLJEmGaoTDYSQmJuKee+5BX1+f9pqRfhRFgSRJCIVC8Pv92vpIkiQ0zkuFum6yLCMSiRge++g6N4b29PS0qd6++OILNDU1weVyGZ6r6nr19/ejqKgINHtzCM6cOSPUy6FDh0BEcLlcOHv2rOHlJUmC1+tFTk4OiouL4fP5hOfVxMQEiAh5eXnC3xcwG/42m03PW5de+Dc3N+tdudsaGxsDEaGlpUX4KBkAurq6kJiYiNHRUVNf2tNPP236DEKWZfT09ICI8NZbbwGA8JGVekQ0PDwsVAcAent7QUSIj4/H1atXARg7Krp69SrOnDkDm82m7Ry//vpr3ctLkoSamhosX74caWlpICK43W54vV5cunRJe5+edTtx4gR6e3uxefPmeTnrvJmpqSn8+eef/3j9dv2pO+3oOXzu3DmcP38eo6OjePvtt/HRRx+hs7MTIyMjuvpQa4VCIRw5cgT79u3DihUrtLCNj4/HH3/8ccc66vbg9/vhdru1A4LMzExs375dd+Cq6//DDz+gvb0ddrsdRARJkhAOh4XmZn5+PogI9fX1AIwfrUciEUxNTYGIkJCQgOvXrxvuQf3ciYkJlJWVoba2FoFAQGh9VDab7e4N/8rKynkJ/7Nnz4KIUFJSInypJRKJoLu7G6tWrT
K9E6mrq0NSUhI6OzuFawDA8PCwtpGqPYrYu3cv0tPTkZqaKtxLOBzWjtBWr16NgYEB3ctGn411dHRg3bp12np1d3djampK187W5/NhZGQEHo8HVVVVWo2YmBh0dHRgYmJCVz8JCQnahmWz2ZCenq792bBhA44dO6b7DHJ4eBgtLS3weDwYHR1FS0sLSkpKUFhYiNraWnR2dqKhoQH9/f0IBAK6agLQahUUFCA/Px+5ubnIyMgAESElJUW7HHCzOaGGX/RZZGFhoTZeashlZGTA4/EYOlvesGGDVuPJJ5+cs+PVQ/2eCwsLkZCQgPLycuzevRuA8UuKADAwMAAiQmxsrPCZTDgcRnd395xLRiJH/bIso6OjAytXrjS0fdzKXR3+69evNx3+0ZdaCgsLTdUCgKqqKhQUFMypLdJPQ0MDHnvsMRw7dkyoD3WDbGpqAhFhcHBQqI5qcnISDocD7e3tAMR3JKdPn0ZycjKICL/++qtwrcHBQezcuVMLEq/XCwC3DSL1GrU6xn///TfefPNN1NbWanWKioqwZcsWQ718//33aG5uRmVlpTYn1TMUm82G7u5u7Wwn2tTUFHJyclBfX4+tW7fi4Ycf1sL6qaee0kI7Pj4eSUlJc3bkN65nMBiEw+HQ3pOYmIhVq1ahqqoKra2t+Oyzz7B//354vV50dXXhwoULt5yfgUAAhw8fRk1NDYgISUlJWLduneE5FN2j2+0GEWHt2rXw+Xza60YCe2ZmBoFAAGvXrtUOsgDjlxHVZVpaWkBEePnllwGIzUO1fyLCzp07DS+vUr+LmJgYOJ3OOa+JstlsqKys1PPWpRf+BvZst6ROUCJCbW2tqVoA0Nraivz8fNN1BgcHQUQ4cuSIcNBKkgSfz4fU1FS43W4AYhNcHaOysjKkpaXh6NGjQrXUybxnzx7ExsaivLwchw4dMtxP9EZRUVGBhIQE5OTk6Plh65YGBgbm7AT27Nlj6EhbNTMzg7GxMWzduhWpqanaHL3xwCL68sXk5CSA2d9YxsbG5lwC+fnnnzE9PY33338fdXV1aGxs/MdnXrt2DXl5ecjOzobb7UZFRQW6urrQ3d1tuH8AyMvLQ0JCAogITqcTdXV1+OabbwD8f+zvFLjq+3788UdtLjudTnzwwQe6lr+VJ554QrtEMzk5KVRHDez4+HgQEU6fPi3US3QtswdY6noQEZYtWwbg9gcyenD430H0oO/atcvUNTYAGBoawurVqwGI77mjz0b27dtnuqf+/n7k5uZiaGjI1I5kZmYGdrsdJSUlwr2o4+31ehEbG4s1a9YAMD5W0Rt9W1sbiAjV1dXw+/2611FRFMiyrH12IBBAXl4ekpOTkZ6ejpdeeslQTzcKBoMoKyvTPU+jx0DtLdqNP7hG27JlC4aGhm56/V0dK7Xe7ULll19+ARFh/fr1ePXVV+HxeLRLmCJB++CDD8LhcCAzMxM9PT1CdSRJwrVr1+DxeEBEyMrKwvj4uOFegLljrJ7p3fi6UeodeleuXBGuEb0TycnJATA/4d/c3KznrUsz/Ldt26Zn5W7rp59+QkpKClpbW00P+P79+5GRkQFA/NKIioiwYsUKUzWia6WlpZmu8/zzz4OI8PHHH5uu9fnnn8PlcqG8vBznz58XqqFutD6fTztqF92Qo++Geeedd7Bx40YAYteS/wvRl7REybKMc+fOafVE5/D169dRWVmJxsZG7TKPmd7uvfdeuFwu+P1+4Rrq55eWloKI0Nvba+q7VRQFfX19yMrKQkdHh3AdVTAYBBHh4MGDpmuFQiE4nU69P+wvrfBXB2o+wn96ehp2ux0NDQ2mawUCARCRqeBXA0g99Y5+zSh1cmdmZpq+g0i9NTI3NxfFxcVaSIjWmp6e1n4Efv31102FNgDU19ebHi+VoijCd20sVWbH7EbvvvuudknFzPYgyzIcDgfa2tpM1wKAuLg45Obmmqqh9pCdnY24uDhTtYDZ+dbU1ISqqirTB6DA7N
115eXlet++tML/vvvuw8aNG+flHu3nnnsO999/P0ZHR01PrK6uLqSkpJjuCYB2d8b09LTpvrZv3w4iwqlTp0z35ff7QUTYsWOHqTpq2HR0dODkyZOm7pBii0t0gM3HNvrdd9/NSy1ZluflBghVYmLivPxW6Pf7UVBQoPsW3DspLS3FJ598ovftuvJ40Tze4cKFC1RdXU02m64H0t2Soii0fPlyevTRR8nlcpl+DkZ2djb99ddfpmoQzf5T8c2bNxMRUXJysqm+IpEI7dixg9asWWP6AXaKolBmZiY1NjbSG2+8YeoZKOpTUN1uNz3yyCPC/4ydLT7RjxUxu42Gw2HKz8+fl1oxMTHk9XqpqKhoXp4DJMsypaSkmK6TlpY2r/O/urqafvvtt3mrR0RkAxb0+SwL8mE+n48yMzMpLi7OdC0AdOnSJUpPTzc9USORCE1NTdGyZctM96Waj8dXM8Zmt3Wz2/gioWsl7srwX+z4Of+M3d0AkCRJ/9XZL4c/Y4xZkK7wX+j/yeuuOKdijLGlji8WM8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBf0PwIACOHRGfV4AAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"rotating_image_classification(digit_one, sess3, prob3, X3, keep_prob3, u3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"collapsed": true,
"id": "x1G0RMxw3SWj"
},
"source": [
"## Using Negative Log of the Expected Likelihood (Eq. 3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we repeat our experiments using the loss function based on Eq. 3 in the paper."
]
},
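{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison with the previous section, the sketch below evaluates both per-sample data terms in NumPy. Eq. 3 reduces to `log(S) - log(alpha_j)` for the true class `j`, while Eq. 4 uses `digamma` in place of `log`. By Jensen's inequality, the Eq. 3 value can never exceed the Eq. 4 value. The function name `data_term` and the example `alpha` vectors are assumptions made for illustration; the notebook itself obtains the two variants by passing `tf.log` or `tf.digamma` to `loss_EDL`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.special import digamma\n",
"\n",
"\n",
"def data_term(y, alpha, func):\n",
"    # sum_j y_j * (func(S) - func(alpha_j)), the data-fit part of loss_EDL\n",
"    S = alpha.sum(axis=-1, keepdims=True)\n",
"    return np.sum(y * (func(S) - func(alpha)), axis=-1)\n",
"\n",
"\n",
"y = np.array([[1.0, 0.0, 0.0],\n",
"              [1.0, 0.0, 0.0]])\n",
"alpha = np.array([[50.0, 1.0, 1.0],   # confident and correct\n",
"                  [1.0, 6.0, 1.0]])   # evidence points to the wrong class\n",
"\n",
"eq3 = data_term(y, alpha, np.log)     # negative log of the expected likelihood\n",
"eq4 = data_term(y, alpha, digamma)    # expected cross entropy\n",
"print('Eq. 3:', eq3)\n",
"print('Eq. 4:', eq4)\n",
"```\n",
"\n",
"Both terms are close to zero for the confident, correct prediction and grow when the evidence supports a wrong class; the gap between them is largest when the true class has little evidence."
]
},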
{
"cell_type": "code",
"execution_count": 221,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"base_uri": "https://localhost:8080/",
"height": 454
},
"colab_type": "code",
"collapsed": true,
"id": "qZeZ8M2-3U2o",
"outputId": "ba07af58-3193-44a1-bcfe-e0a699affbdc"
},
"outputs": [],
"source": [
"g4, step4, X4, Y4, annealing_step4, keep_prob4, prob4, acc4, loss4, u4, evidence4, \\\n",
" mean_ev4, mean_ev_succ4, mean_ev_fail4 = LeNet_EDL(exp_evidence, loss_EDL(tf.log), lmb=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 225,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"base_uri": "https://localhost:8080/",
"height": 1291
},
"colab_type": "code",
"collapsed": true,
"executionInfo": {
"elapsed": 460,
"status": "ok",
"timestamp": 1527923232289,
"user": {
"displayName": "Murat Sensoy",
"photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
"userId": "102692943223630372304"
},
"user_tz": -180
},
"id": "aVltdhRR5dNG",
"outputId": "fab9365b-f64b-4e47-de90-e48e6eda4aab"
},
"outputs": [],
"source": [
"sess4 = tf.Session(graph=g4)\n",
"with g4.as_default():\n",
" sess4.run(tf.global_variables_initializer())"
]
},
{
"cell_type": "code",
"execution_count": 226,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "ZnmB0--c351F"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1 - 100%) training: 0.8389 (47.7661 - 14.2396) \t testing: 0.8481 (48.1106 - 14.5359)\n",
"epoch 2 - 100%) training: 0.9120 (98.9960 - 8.6988) \t testing: 0.9194 (99.6993 - 8.8808)\n",
"epoch 3 - 100%) training: 0.9304 (180.1434 - 7.9925) \t testing: 0.9390 (182.9537 - 7.6880)\n",
"epoch 4 - 100%) training: 0.9454 (357.4596 - 8.9753) \t testing: 0.9515 (353.7061 - 7.7063)\n",
"epoch 5 - 100%) training: 0.9546 (374.0661 - 6.4429) \t testing: 0.9585 (369.2362 - 6.1300)\n",
"epoch 6 - 100%) training: 0.9579 (717.8140 - 7.7555) \t testing: 0.9627 (698.8102 - 7.6518)\n",
"epoch 7 - 100%) training: 0.9617 (820.1666 - 5.8483) \t testing: 0.9646 (819.3505 - 5.0063)\n",
"epoch 8 - 100%) training: 0.9646 (685.4586 - 4.7092) \t testing: 0.9682 (700.0871 - 4.2215)\n",
"epoch 9 - 100%) training: 0.9676 (864.0211 - 4.7806) \t testing: 0.9700 (882.0850 - 4.2597)\n",
"epoch 10 - 100%) training: 0.9699 (1145.9790 - 4.9029) \t testing: 0.9728 (1166.5842 - 4.3104)\n",
"epoch 11 - 100%) training: 0.9698 (1265.3774 - 3.9921) \t testing: 0.9730 (1349.9486 - 3.4476)\n",
"epoch 12 - 100%) training: 0.9730 (1586.4016 - 5.1049) \t testing: 0.9747 (1696.7997 - 4.9364)\n",
"epoch 13 - 100%) training: 0.9735 (2014.4788 - 5.1080) \t testing: 0.9763 (2115.8186 - 4.7999)\n",
"epoch 14 - 100%) training: 0.9735 (2741.4673 - 4.3296) \t testing: 0.9752 (2957.9802 - 3.9254)\n",
"epoch 15 - 100%) training: 0.9752 (2673.9426 - 4.4678) \t testing: 0.9772 (2692.9707 - 4.9921)\n",
"epoch 16 - 100%) training: 0.9768 (2388.1035 - 4.0882) \t testing: 0.9781 (2634.2166 - 3.7764)\n",
"epoch 17 - 100%) training: 0.9764 (2701.2002 - 4.6162) \t testing: 0.9791 (3023.1316 - 5.1292)\n",
"epoch 18 - 100%) training: 0.9773 (2878.8640 - 4.0546) \t testing: 0.9792 (3105.1746 - 3.6160)\n",
"epoch 19 - 100%) training: 0.9768 (3326.2048 - 4.1775) \t testing: 0.9787 (3591.4688 - 4.3881)\n",
"epoch 20 - 100%) training: 0.9788 (3257.1694 - 4.1882) \t testing: 0.9791 (3451.4497 - 3.3292)\n",
"epoch 21 - 100%) training: 0.9788 (4058.1035 - 4.2671) \t testing: 0.9801 (4339.0176 - 4.3674)\n",
"epoch 22 - 100%) training: 0.9795 (4905.1646 - 5.0746) \t testing: 0.9806 (5299.4648 - 4.9685)\n",
"epoch 23 - 100%) training: 0.9789 (4146.0679 - 4.8996) \t testing: 0.9794 (4280.2456 - 4.3978)\n",
"epoch 24 - 100%) training: 0.9789 (5102.7090 - 5.0765) \t testing: 0.9811 (5573.3687 - 4.9651)\n",
"epoch 25 - 100%) training: 0.9797 (4721.9268 - 3.4639) \t testing: 0.9824 (4908.5527 - 3.7816)\n",
"epoch 26 - 100%) training: 0.9794 (4624.8179 - 3.5023) \t testing: 0.9810 (4622.0835 - 4.1510)\n",
"epoch 27 - 100%) training: 0.9803 (7247.1953 - 4.8597) \t testing: 0.9823 (7369.5410 - 5.0344)\n",
"epoch 28 - 100%) training: 0.9820 (6480.2974 - 4.0996) \t testing: 0.9825 (7157.7944 - 4.7370)\n",
"epoch 29 - 100%) training: 0.9813 (7673.8716 - 4.6725) \t testing: 0.9817 (7892.9888 - 4.8215)\n",
"epoch 30 - 100%) training: 0.9819 (7318.6362 - 4.3854) \t testing: 0.9841 (7933.3677 - 4.6528)\n",
"epoch 31 - 100%) training: 0.9825 (8063.1187 - 4.8635) \t testing: 0.9827 (8506.7451 - 4.2365)\n",
"epoch 32 - 100%) training: 0.9816 (6621.0513 - 3.1235) \t testing: 0.9819 (7588.6938 - 3.7367)\n",
"epoch 33 - 100%) training: 0.9826 (10533.4658 - 4.9056) \t testing: 0.9843 (11875.2090 - 5.4827)\n",
"epoch 34 - 100%) training: 0.9833 (10538.7793 - 4.8260) \t testing: 0.9830 (11016.7715 - 4.3061)\n",
"epoch 35 - 100%) training: 0.9830 (9445.5898 - 4.3326) \t testing: 0.9830 (10311.8994 - 4.2545)\n",
"epoch 36 - 100%) training: 0.9824 (9012.2568 - 3.6567) \t testing: 0.9831 (10201.7939 - 3.7273)\n",
"epoch 37 - 100%) training: 0.9838 (8231.3916 - 3.4212) \t testing: 0.9848 (9213.5693 - 4.3222)\n",
"epoch 38 - 100%) training: 0.9835 (9698.7676 - 3.6955) \t testing: 0.9844 (12237.3604 - 4.4046)\n",
"epoch 39 - 100%) training: 0.9813 (12806.3682 - 3.5367) \t testing: 0.9833 (13110.8662 - 4.3908)\n",
"epoch 40 - 100%) training: 0.9834 (12758.0078 - 3.7434) \t testing: 0.9852 (14355.3516 - 5.7558)\n",
"epoch 41 - 100%) training: 0.9841 (18760.6660 - 4.6124) \t testing: 0.9847 (18847.6367 - 4.9546)\n",
"epoch 42 - 100%) training: 0.9844 (14055.1133 - 4.2950) \t testing: 0.9850 (15240.1768 - 5.1640)\n",
"epoch 43 - 100%) training: 0.9839 (17531.1875 - 4.8606) \t testing: 0.9856 (20936.6172 - 6.0384)\n",
"epoch 44 - 100%) training: 0.9842 (11528.1709 - 4.0646) \t testing: 0.9857 (12673.5449 - 4.8653)\n",
"epoch 45 - 100%) training: 0.9838 (13236.9697 - 4.4013) \t testing: 0.9848 (14658.7588 - 6.3498)\n",
"epoch 46 - 100%) training: 0.9841 (16241.7793 - 3.5962) \t testing: 0.9871 (18014.9746 - 4.5276)\n",
"epoch 47 - 100%) training: 0.9851 (15670.2637 - 5.8087) \t testing: 0.9855 (17604.8535 - 5.6529)\n",
"epoch 48 - 100%) training: 0.9855 (14058.3115 - 4.6745) \t testing: 0.9865 (15463.1846 - 5.6835)\n",
"epoch 49 - 100%) training: 0.9854 (16627.9102 - 4.3638) \t testing: 0.9851 (18772.2539 - 4.5671)\n",
"epoch 50 - 100%) training: 0.9860 (18039.0078 - 3.9769) \t testing: 0.9875 (20124.0410 - 4.8621)\n"
]
}
],
"source": [
"bsize = 1000 #batch size\n",
"n_batches = mnist.train.num_examples // bsize\n",
"L4_train_acc1=[]\n",
"L4_train_ev_s=[]\n",
"L4_train_ev_f=[]\n",
"L4_test_acc1=[]\n",
"L4_test_ev_s=[]\n",
"L4_test_ev_f=[]\n",
"for epoch in range(50): \n",
"    for i in range(n_batches):\n",
"        data, label = mnist.train.next_batch(bsize)\n",
"        feed_dict={X4:data, Y4:label, keep_prob4:.5, annealing_step4:10*n_batches}\n",
"        sess4.run(step4,feed_dict)\n",
"        print('epoch %d - %d%%) '% (epoch+1, (100*(i+1))//n_batches), end='\\r' if i"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"draw_EDL_results(L4_train_acc1, L4_train_ev_s, L4_train_ev_f, L4_test_acc1, L4_test_ev_s, L4_test_ev_f)"
]
},
{
"cell_type": "code",
"execution_count": 228,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAA1CAYAAACp8OvZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADTxJREFUeJzt3V1MHGUXB/Czy7J8G1pWQitWUk3TIIloGkIoNpC0EmKkkqgtCQWJRiBsmxKCXJDQxA80VWsN0WqRi4oGmoJKiL1zbdwUTWmxSmi1dRvshu62tWLBrmVndv7vBZl5l9qPmWcQoXN+SS+62Tl75tln/vOx06kNADHGGLMW+3/dAGOMsYXH4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbE4c8YYxbkWODPW5BnSSiKQna7uf2aLMvkcJgbnkgkQjExMaZqMMbuHpIkUWxs7L/9MTY9b7orj/ztdjtNTk6aquFwOCgUCpmqMR/BH4lETNdgjJknSZLpGgsQ/LotuvB/7bXX6ODBg6ZqnDx5kjZt2kTHjx8XDs9QKERJSUm0a9cu4T52795NNpuNjh8/LlwjJiaG9u7dS+Pj40QktjMIh8NzlpVlWbgfFT8QkC0ERVH+8zpXrlyhF198kS5evCj82cFgkGpra2nTpk00MzMj3Munn35q+qqGatGF/wMPPEAHDhwwVSMrK4tOnDhBfX19QiEFgBITEyk7O5va29uFelAUhZ599lmqq6ujhoYGCgaDQr1cvHiRDh8+TBUVFfThhx8KnU04nU4iInrllVdoeHiYHA6H0MYgSZK2A7HZbLwDYP8KRVG0uWW3202Fpd/vp8cff1woMNW53traSl1dXfT7778L9aAoCvX09NCXX35JpaWl2vZoFAA6cOAArVy5Umj5mxZcwD93VF5ejt7eXj1vvS0igsfjMVWjrq4Os0MkLhQKgYjw1VdfIRwOC9eprq6Gw+EAAMzMzAjVaGtrg8vlMjUuRIS4uDiUlpYCgOF1UnsfGRnBt99+K9wHm1+SJCESiUBRFMPLRiIRRCIR7e+nTp3C0aNH57ymtw4AjI+Po7i4GE6nE42NjYbrALPrk5WVBSJCRUWF4eUVRUEoFMILL7wAIsJDDz2kvW5UOBwGESE5OdnwstG2bdsGIkIwGLzTW3Xl8aIK/5GRETidToRCoTuPxC0oioLLly+DiDA6OgpZloXqyLKMiooKU+GvThQiQk1NjdDEUZfp7OwU7kUdg8uXL+O9997DM888A0VRhHZGg4ODyM3NndOLJEmGaoTDYSQmJuKee+5BX1+f9pqRfhRFgSRJCIVC8Pv92vpIkiQ0zkuFum6yLCMSiRge++g6N4b29PS0qd6++OILNDU1weVyGZ6r6nr19/ejqKgINHtzCM6cOSPUy6FDh0BEcLlcOHv2rOHlJUmC1+tFTk4OiouL4fP5hOfVxMQEiAh5eXnC3xcwG/42m03PW5de+Dc3N+tdudsaGxsDEaGlpUX4KBkAurq6kJiYiNHRUVNf2tNPP236DEKWZfT09ICI8NZbbwGA8JGVekQ0PDwsVAcAent7QUSIj4/H1atXARg7Krp69SrOnDkDm82m7Ry//vpr3ctLkoSamhosX74caWlpICK43W54vV5cunRJe5+edTtx4gR6e3uxefPmeTnrvJmpqSn8+eef/3j9dv2pO+3oOXzu3DmcP38eo6OjePvtt/HRRx+hs7MTIyMjuvpQa4VCIRw5cgT79u3DihUrtLCNj4/HH3/8ccc66vbg9/vhdru1A4LMzExs375dd+Cq6//DDz+gvb0ddrsdRARJkhAOh4XmZn5+PogI9fX1AIwfrUciEUxNTYGIkJCQgOvXrxvuQf3ciYkJlJWVoba2FoFAQGh9VDab7e4N/8rKynkJ/7Nnz4KIUFJSInypJRKJoLu7G6tWrT
K9E6mrq0NSUhI6OzuFawDA8PCwtpGqPYrYu3cv0tPTkZqaKtxLOBzWjtBWr16NgYEB3ctGn411dHRg3bp12np1d3djampK187W5/NhZGQEHo8HVVVVWo2YmBh0dHRgYmJCVz8JCQnahmWz2ZCenq792bBhA44dO6b7DHJ4eBgtLS3weDwYHR1FS0sLSkpKUFhYiNraWnR2dqKhoQH9/f0IBAK6agLQahUUFCA/Px+5ubnIyMgAESElJUW7HHCzOaGGX/RZZGFhoTZeashlZGTA4/EYOlvesGGDVuPJJ5+cs+PVQ/2eCwsLkZCQgPLycuzevRuA8UuKADAwMAAiQmxsrPCZTDgcRnd395xLRiJH/bIso6OjAytXrjS0fdzKXR3+69evNx3+0ZdaCgsLTdUCgKqqKhQUFMypLdJPQ0MDHnvsMRw7dkyoD3WDbGpqAhFhcHBQqI5qcnISDocD7e3tAMR3JKdPn0ZycjKICL/++qtwrcHBQezcuVMLEq/XCwC3DSL1GrU6xn///TfefPNN1NbWanWKioqwZcsWQ718//33aG5uRmVlpTYn1TMUm82G7u5u7Wwn2tTUFHJyclBfX4+tW7fi4Ycf1sL6qaee0kI7Pj4eSUlJc3bkN65nMBiEw+HQ3pOYmIhVq1ahqqoKra2t+Oyzz7B//354vV50dXXhwoULt5yfgUAAhw8fRk1NDYgISUlJWLduneE5FN2j2+0GEWHt2rXw+Xza60YCe2ZmBoFAAGvXrtUOsgDjlxHVZVpaWkBEePnllwGIzUO1fyLCzp07DS+vUr+LmJgYOJ3OOa+JstlsqKys1PPWpRf+BvZst6ROUCJCbW2tqVoA0Nraivz8fNN1BgcHQUQ4cuSIcNBKkgSfz4fU1FS43W4AYhNcHaOysjKkpaXh6NGjQrXUybxnzx7ExsaivLwchw4dMtxP9EZRUVGBhIQE5OTk6Plh65YGBgbm7AT27Nlj6EhbNTMzg7GxMWzduhWpqanaHL3xwCL68sXk5CSA2d9YxsbG5lwC+fnnnzE9PY33338fdXV1aGxs/MdnXrt2DXl5ecjOzobb7UZFRQW6urrQ3d1tuH8AyMvLQ0JCAogITqcTdXV1+OabbwD8f+zvFLjq+3788UdtLjudTnzwwQe6lr+VJ554QrtEMzk5KVRHDez4+HgQEU6fPi3US3QtswdY6noQEZYtWwbg9gcyenD430H0oO/atcvUNTYAGBoawurVqwGI77mjz0b27dtnuqf+/n7k5uZiaGjI1I5kZmYGdrsdJSUlwr2o4+31ehEbG4s1a9YAMD5W0Rt9W1sbiAjV1dXw+/2611FRFMiyrH12IBBAXl4ekpOTkZ6ejpdeeslQTzcKBoMoKyvTPU+jx0DtLdqNP7hG27JlC4aGhm56/V0dK7Xe7ULll19+ARFh/fr1ePXVV+HxeLRLmCJB++CDD8LhcCAzMxM9PT1CdSRJwrVr1+DxeEBEyMrKwvj4uOFegLljrJ7p3fi6UeodeleuXBGuEb0TycnJATA/4d/c3KznrUsz/Ldt26Zn5W7rp59+QkpKClpbW00P+P79+5GRkQFA/NKIioiwYsUKUzWia6WlpZmu8/zzz4OI8PHHH5uu9fnnn8PlcqG8vBznz58XqqFutD6fTztqF92Qo++Geeedd7Bx40YAYteS/wvRl7REybKMc+fOafVE5/D169dRWVmJxsZG7TKPmd7uvfdeuFwu+P1+4Rrq55eWloKI0Nvba+q7VRQFfX19yMrKQkdHh3AdVTAYBBHh4MGDpmuFQiE4nU69P+wvrfBXB2o+wn96ehp2ux0NDQ2mawUCARCRqeBXA0g99Y5+zSh1cmdmZpq+g0i9NTI3NxfFxcVaSIjWmp6e1n4Efv31102FNgDU19ebHi+VoijCd20sVWbH7EbvvvuudknFzPYgyzIcDgfa2tpM1wKAuLg45Obmmqqh9pCdnY24uDhTtYDZ+dbU1ISqqirTB6DA7N
115eXlet++tML/vvvuw8aNG+flHu3nnnsO999/P0ZHR01PrK6uLqSkpJjuCYB2d8b09LTpvrZv3w4iwqlTp0z35ff7QUTYsWOHqTpq2HR0dODkyZOm7pBii0t0gM3HNvrdd9/NSy1ZluflBghVYmLivPxW6Pf7UVBQoPsW3DspLS3FJ598ovftuvJ40Tze4cKFC1RdXU02m64H0t2Soii0fPlyevTRR8nlcpl+DkZ2djb99ddfpmoQzf5T8c2bNxMRUXJysqm+IpEI7dixg9asWWP6AXaKolBmZiY1NjbSG2+8YeoZKOpTUN1uNz3yyCPC/4ydLT7RjxUxu42Gw2HKz8+fl1oxMTHk9XqpqKhoXp4DJMsypaSkmK6TlpY2r/O/urqafvvtt3mrR0RkAxb0+SwL8mE+n48yMzMpLi7OdC0AdOnSJUpPTzc9USORCE1NTdGyZctM96Waj8dXM8Zmt3Wz2/gioWsl7srwX+z4Of+M3d0AkCRJ/9XZL4c/Y4xZkK7wX+j/yeuuOKdijLGlji8WM8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBXH4M8aYBf0PwIACOHRGfV4AAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"rotating_image_classification(digit_one, sess4, prob4, X4, keep_prob4, u4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Some Other Data Uncertainty Experiments"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Consider the case that we mix two digits from the MNIST dataset and query a classifier trained on MNIST dataset to classify it. For example, the following image is created by overlaying digit 0 with digit 6. The resulting image have similarities to both digits but neither 0 nor 6."
]
},
{
"cell_type": "code",
"execution_count": 317,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAACFCAYAAABL2gNbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAErtJREFUeJzt3Xt0VdWdB/DvLzcJgQQCIUAxBAgQAiioGB8oWhStSjviWo5dPnCQUukssFKrXWLtOKvLqWOnykyt7SgWRUdHsVoU0dElNL6qIIoSBYQgDwkSIEQghEdee/7gevb5Xb1JvLn33Ht3vp+1WPx2fif3bPJLNif77rOPGGNARETpLyPZHSAiovjggE5E5AgO6EREjuCATkTkCA7oRESO4IBOROQIDuhERI7o1IAuIpeIyEYR2Swi8+LVKUou1tVdrK3bJNYbi0QkBGATgIsAVANYDeBqY8z6+HWPgsa6uou1dV9mJz73DACbjTFbAEBEngYwFUDUb45s6WZykNuJU1I8HEUDGs0xiZJmXdNUO3UFvmVtWdfUUY8va40x/do7rjMDehGAHb52NYAz2/qEHOTiTJnciVNSPKwyK9pKs65pqp26At+ytqxr6lhunt3ekeM6M6B3iIjMAjALAHLQI9Gno4Cwrm5iXdNbZ94U3Qmg2NceFP6YYoxZYIwpN8aUZ6FbJ05HAWFd3dVubVnX9NaZAX01gFIRKRGRbABXAVgan25RErGu7mJtHRfzlIsxpllEbgTwKoAQgEeMMevi1jNKCtbVXayt+zo1h26MeRnAy3HqC6UI1tVdrK3beKcoEZEjOKATETmCAzoRkSM4oBMROSLhNxYRpYqWSeNVO/PO3V78Yplevbe75YhqXz/9Ji8OVaxJQO8oVpklQ1T7i+8XefHB4a0ql9Gkd0YY+eAuL27esi3+nQsYr9CJiBzBAZ2IyBGccgk7+g9neHH3/9O/UpvyMV689TK9+9y5F3ys2m/9bWzUcwx8t8WLc158L6Z+Uuzyfq13MFg8YpkXt0Ycu71Z72Oyd66dgvlORdy7Rt9SaHSpF1d/T29CWH/aUS8uH6b3tLqg4FPV/m3BpV48ZIl+nR5b9ntxy4aq2DsbIF6hExE5ggM6EZEjOKATETmiS82hhwr7enHL4u4q93TpfC/e3ZKlcvkZr3vx4Mx29oie/mbU1J5ph734i/uzVe4nd8/14r4Pv9v2OajDGq6wz2+4a/ADUY8btXSOapf9uUG1C/pxb/CgZfSwX/OaGaeo3JHv1nvxM2fMV7khmfaxmnfsmqRy/7H6YtXO7tnoxTsu1DXOaCr0YmkuVLmSpb7vj5WV39T9pOAVOhGRIzigExE5oktNuWz6/WAv3jhqYUTW/rrVP6Qzf9o/0ovX1A9WueqG3lHPFxK9GO6lshejnmPxr37nxf+84UaVy3j7o6jnIC2z6ATVfuDe+714dLa+frlk/T96cdlN+mtsmhpVW0+QURDqrjjZi/ePbVa51RMe9OLCkF5KPHvnWV781mJ9d3C+vgFYi3i8trrLVM/CYsvl9pwja/SY0Lzt8zZOkli8QicicgQHdCIiR3BAJyJyhNNz6GbCyaq9+OyHfC39T3/liJ1Dv+cX01Wu57pa29hbp3IZX+6Ifv4MPVE+8r7ZXrz+h39QueFZeV585FcHVS7/+gFe3FyzGxRd7QV6573IeXPlvv5eaJq2Rz+OApFZPEi1vxxt4z4nHFC5A612aeJ5q6apXPG9tuaDPtLbeLQePYqoRE+i511utwP54lyda82x5991SZHKDfifffa4Br38NdF4hU5E5AgO6EREjnB6yqUpXy82OyXb/nNbYVTuF4/+yIuLl7yjci2IUav+zBE3r/Ti0dl6aWLl1N978Rtjn1W5cy60UzX5T3DKpS17JurlbRm+a5bzP75S5XJfWR1In6hjTPduqt3S3f6MTjxhi8rN3WprmbO8p3
6hlfZO68hdNNvugB4TeixZ5cUDQ2eqXM0EOwVzoEx/Xv+RQ23jw3Xfpgedxit0IiJHcEAnInIEB3QiIkc4PYfekiNRc+PeuV61B//mnW8+MEFK56xS7WUXDvTiK/P2qdz+y+zSp/wnEtuvdJQ5pNiL5058TeVafbOo5pH+0LaAUkimXubb2sO+B1VztJfKbVhjl6eOfOdL/XkJ6Frus/rnNevEs724sbc+Y93Jtq99PkxAZ9rAK3QiIke0O6CLyCMiskdEPvF9rEBEXhORqvDffRLbTYo31tVdrG3X1ZEpl0UAHgDwuO9j8wCsMMbcIyLzwu3b4t+9zim7PfqSodAHPaPmkuGO1Zd78ZXn650g55xoH5qxDHH7OVyENK1rpM032DsMn+/9vMptbba7Jnbfq3dQdNgipGFtqy/VD5E4/cSNXnywMUflMhrtdGrGIX33ZyKmXCLlf2bPsvc0nTtUZPsW9P+a7V6hG2PeBFAX8eGpAB4Lx48BuByUVlhXd7G2XVesb4oOMMbsCsc1AAZEO1BEZgGYBQA54GO8Uhzr6q4O1ZZ1TW+dflPUGGOAiNsudX6BMabcGFOehW7RDqMUw7q6q63asq7pLdYr9N0iMtAYs0tEBgLYE89OdUbGuFFePKm3XsK2qcnOtRVWNgXWp47o84ZvjvD8pHUjZevaFhl5KGruqf2ne3GoYk3U47qAlKxt5nfsLwpNeTpX3N0uR3zpixNVLvuAnadu3rItIX1rS/5m+8D3vaelzm8ysV6hLwXw1R6z0wG8EJ/uUJKxru5ibbuAjixbfArAuwDKRKRaRGYCuAfARSJSBeDCcJvSCOvqLta262p3ysUYc3WU1OQ49yUuqqbbhzZflbdX5SZWXufFvV7u2jvtpVtd2/LgaU9Gzf3lfyd5cRGCvRs4WdKptvVn2Ts+I++4zMyw7aad+kHQRauOJbZjaYp3ihIROYIDOhGRIzigExE5wrndFm++9CUv9i9TBIDsP/b1tT4LqEeUaBli51qzRO/YN+hi+/DnjSVnqFzZyJ1e/GLZUpWLfJ0mY3f+W9JQoHL/9uC1XnzC/e+pnGnWT1AirfYkOwSFjuml8X+pHO/FEc9vRt0Yu0a+R+FZKne4v71OPTg8YiOAyA1YfafMPqCvb4tet+NH6I2At02MEa/QiYgcwQGdiMgRzk25+D207zzVzln2XpQjKZ21Gntd0mT0HcAvlPl2Xyxr4zUi2vPrSlV7Th+789/U3FqVm3qLfcD3uL43qdzQO94FdUy3fXo+RJrttEpzrp6OOTis1RdHvpKvmhEbHOTs1dewRwvtsY299HfB1svsQ+aLc8sj+pqayyZ5hU5E5AgO6EREjuCATkTkiLSfQw/1zlftnhnVSeoJpaLtvicW3brtCpXbsdhOvvao1fOnvZZVqvZzP/ieF7fO0FtKvD7uaS+uvP5+lTt7u51TL1zA+XTpprfkNTFeUvqfWJS7U8+9919jd0IM1euly2b9Zt2fMSO8eM8E/Xyh/WV2An7HZL2Mtfen/q0Iou4yHTheoRMROYIDOhGRIzigExE5Iu3n0Ktn6ieZXNuzwovXNAwNuDexOzblQNTc4dbsqDkCZrz6Yy/eNPW/Ve77f5/jxcOu+Ujl+qMm6mtGrkvPe2alF4dW6Fv/H31rqO1L/jaVOzTYxvqZ9l1T8wT989rYx36lM4/o68vCtTZXc7aeJy9Yb+Pej0ffFjmyjpFM5ade3H9TjsodudVuPXCsQL9SXbnd0iGjQc+vS3snTSBeoRMROYIDOhGRI9J+yiVdNV9wmmo/feoDvpZe2rXkt/ZBM/lYCdKy60JRc9NOtNs9vIP4TF217KtT7fmVtj4zzl0Yl3MQkNXgn7vQNT5UZKdgeiM+Wo/qJY55O+xyxKYxOjdigN3+YXPlIJUrfsVOnwa9oJFX6EREjuCATkTkCA7oRESO4Bx6gPzz5nVzG1RuVJadN5+98xyV6714jRenzk3GqSPzsJ1PjX
zSUM+QnfvM6KFnW1sPH0YsWiaNV+3Hz3jYnoPXSHGT0eibQ4+oq//JQ5Kph7FYnxKVWTJEtfeeZ7di/um4N1Vu5f4SL86t1jU3az9FsvC7j4jIERzQiYgckfZTLr22taj2tubYfo1OhMhfBfffXO/F749/WuVeO9Ldizf9i76bLrvp/QT0zh2D7rZ3Co4ZP03lPpzwqBc/uPBclRs+a4sXt9bXoy2h0fYJRqW/W6dyp3azUwORNwnm1EY+lbhry6rVP58Zx7pFORLIWv6BFxcUT1C5urH2K/3lNaerXJ+n7OeZpka0JdSvnxd/fmWRyv30zFe8+OcFW1RubJV9MHXxUn3HcUurHpOCxCt0IiJHtDugi0ixiFSIyHoRWScic8MfLxCR10SkKvx3n/Zei1IH6+om1rVr68gVejOAW4wxYwCcBWCOiIwBMA/ACmNMKYAV4TalD9bVTaxrF9buHLoxZheAXeG4XkQ2ACgCMBXApPBhjwF4HcBtCellG3KfW6Xar9w12ouH5+gny1QNOsmLm6t3xuX8rRNPUe2ts218xWi9u9/d/fW8ucrdOt2Lu7/6XtTj4iXV6xqrkht1zSvezvPij8/7s8qNXWB3aSx8vrvKNeXque/f/NJ+7ne763ngiiP2HLNfmqFypf8VfSfAREj1urZ+opf0ZR0624tbInZmCPXq5cWFz+gnSNWXjPPi2lMiF/Pa5cEFa/erzKFhvVR7xG1228Zb+j2kcqVZ9hb+G3ZcrHLFd9ilkS1Ven49mb7Vm6IiMhTAqQBWARgQ/uYBgBoAA6J8ziwAswAgBz1i7SclEOvqJta16+nwm6IikgfgOQA/M8Yc9OeMMQZR7nkxxiwwxpQbY8qzEP0dbUoO1tVNrGvX1KErdBHJwvFvjieNMX8Nf3i3iAw0xuwSkYEA9iSqk7Ga3Xurau9eZn/der9ucOThMbmnZIFqn5Id/Uv6QaNdznTdezNVbvjf7K+iQS16Ste6tqW5Zrdq33vDtbbx8JMqp6ZgztOvE3nHZ6tvQeLVn01RuYP/WuzFpRXJ3w0zXet6rFAv+vx8tp0izTqkjzUZ9v+jaZPfUrmSKXrazW9Krh4T+ofsw56fOaQfOP/j1+30Wb83s1Su4LM1SEUdWeUiABYC2GCMme9LLQXw1cTvdAAvxL97lCisq5tY166tI1fo5wC4DsDHIvLVu3y/BHAPgGdEZCaA7QB+mJguUoKwrm5iXbuwjqxyeRtqKxxlcpSPU4pjXd3EunZtaX/rf6RF9/7Ai/fM1Tuk/brfWtvwx52iv4TNvhnwtRF3HU9bfJMXl8x7V+WSd7Ow20IVdq7zvpnXqNx/3mnnWpeOWqJyM7brsW91hV0OO+yuD/U5jqbmfGo6GPyq3XKhenJPlTsyoNUXR3+NJ9fpW//vHL/Mi/+pV63KPd+gX+iJGnsL/9q/l6rc4LftT2XOMv3zmqq7nvLWfyIiR3BAJyJyhBxfkhqMXlJgzpTgpvFCI0pU+/zn7d1mP+9TFZdzjHrjR6qd/bG9GWPQvwd7l2BHrTIrcNDUxW0bwKDrSt8s3esa6lug2tt/MsqLj/aL3McyNoUf6i9P3k47L5q54oPIw1PGcvPsB8aY8vaO4xU6EZEjOKATETmCAzoRkSOcW7bo17JZ3+a7/CS7LGo5xkceHpNh+Kj9g4ioXS376lTb/yQq6hheoRMROYIDOhGRIzigExE5ggM6EZEjOKATETmCAzoRkSM4oBMROYIDOhGRIzigExE5ggM6EZEjOKATETmCAzoRkSM4oBMROSLQJxaJyF4A2wEUAqht5/CgdMW+DDHG9IvXi7Gu7WJd46er9qVDtQ10QPdOKvJ+Rx6nFAT2JX5Sqf/sS/ykUv/Zl7ZxyoWIyBEc0ImIHJGsAX1Bks77TdiX+Eml/rMv8ZNK/Wdf2pCUOXQiIoo/TrkQETki0AFdRC4RkY0isllE5gV57vD5HxGRPSLyie9jBSLymohUhf/uE0A/ik
WkQkTWi8g6EZmbrL7EA+uq+uJMbVlX1Ze0qGtgA7qIhAD8EcClAMYAuFpExgR1/rBFAC6J+Ng8ACuMMaUAVoTbidYM4BZjzBgAZwGYE/5aJKMvncK6fo0TtWVdvyY96mqMCeQPgAkAXvW1bwdwe1Dn9513KIBPfO2NAAaG44EANiahTy8AuCgV+sK6srasa/rWNcgplyIAO3zt6vDHkm2AMWZXOK4BMCDIk4vIUACnAliV7L7EiHWNIs1ry7pGkcp15ZuiPub4f7OBLfsRkTwAzwH4mTHmYDL74rJkfC1Z28RjXb8uyAF9J4BiX3tQ+GPJtltEBgJA+O89QZxURLJw/BvjSWPMX5PZl05iXSM4UlvWNUI61DXIAX01gFIRKRGRbABXAVga4PmjWQpgejiejuNzYwklIgJgIYANxpj5yexLHLCuPg7VlnX1SZu6BvxGwhQAmwB8BuCOJLyR8RSAXQCacHxOcCaAvjj+7nQVgOUACgLox0Qc/9WsEsBH4T9TktEX1pW1ZV3dqSvvFCUicgTfFCUicgQHdCIiR3BAJyJyBAd0IiJHcEAnInIEB3QiIkdwQCcicgQHdCIiR/w/JhQsBvSk4w8AAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"im0 = mnist.test.images[10]\n",
"im6 = mnist.test.images[21]\n",
"img = im0 + im6\n",
"img /= img.max()\n",
"plt.subplot(1,3,1)\n",
"plt.imshow(im0.reshape(28,28))\n",
"plt.subplot(1,3,2)\n",
"plt.imshow(im6.reshape(28,28))\n",
"plt.subplot(1,3,3)\n",
"plt.imshow(img.reshape(28,28))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The neural network trained with softmax cross entropy loss has the following prediction for the classification of this image, where the image is classifed as 0 with probability 0.9."
]
},
{
"cell_type": "code",
"execution_count": 318,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"softmax prob: [0.901 0. 0. 0. 0. 0.017 0.078 0. 0.003 0. ]\n"
]
}
],
"source": [
"p1 = sess1.run(prob1, feed_dict={X1:img[None,:], keep_prob1:1.0})\n",
"print('softmax prob: ', np.round(p1[0], decimals=3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we do the same experiments on the neural net trained using the loss function in Eq. 7, we have a much different results. The neural network could not generate any evidence to classify the image into one of 10 digits. Hence, it provides uniform distribution as its prediction. It implies I do not know by providing maximum uncertainty."
]
},
{
"cell_type": "code",
"execution_count": 324,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"uncertainty: 1.0\n",
"Dirichlet mean: [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]\n"
]
}
],
"source": [
"uncertainty2, p2 = sess2.run([u, prob2], feed_dict={X2:img[None,:], keep_prob2:1.0})\n",
"print('uncertainty:', np.round(uncertainty2[0,0], decimals=2))\n",
"print('Dirichlet mean: ', np.round(p2[0], decimals=3))"
]
},
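{
"cell_type": "markdown",
"metadata": {},
"source": [
"The uncertainty and the Dirichlet mean reported by these networks follow directly from the evidence vector. As defined in the paper, the Dirichlet parameters are $\\alpha_k = e_k + 1$, the Dirichlet strength is $S = \\sum_k \\alpha_k$, the expected probability is $\\hat{p}_k = \\alpha_k / S$, and the uncertainty is $u = K / S$ for $K$ classes. The cell below is a minimal NumPy sketch of that computation. The helper name `dirichlet_summary` and the evidence values are illustrative assumptions; they are not read from the trained networks. With zero evidence for every class, the sketch reproduces the uniform mean and the uncertainty of 1.0 printed above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def dirichlet_summary(evidence):\n",
"    # Map a non-negative evidence vector to (expected probabilities, uncertainty).\n",
"    evidence = np.asarray(evidence, dtype=float)\n",
"    alpha = evidence + 1.0   # Dirichlet parameters: alpha_k = e_k + 1\n",
"    S = alpha.sum()          # Dirichlet strength\n",
"    K = alpha.size           # number of classes\n",
"    return alpha / S, K / S\n",
"\n",
"# No evidence for any class: uniform mean and maximum uncertainty.\n",
"p_zero, u_zero = dirichlet_summary(np.zeros(10))\n",
"print('uncertainty:', np.round(u_zero, 2))\n",
"print('Dirichlet mean: ', np.round(p_zero, 3))\n",
"\n",
"# Hypothetical evidence of 3.0 for digit 0 only (illustrative value).\n",
"p_some, u_some = dirichlet_summary([3.0] + [0.0] * 9)\n",
"print('uncertainty:', np.round(u_some, 2))\n",
"print('Dirichlet mean: ', np.round(p_some, 3))"
]
},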
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we use the loss function in Eq. 5, the exepcted probability is highest for digit 0. It is around 0.32, however, the associated uncertainty is quite high around 0.73 as shown below."
]
},
{
"cell_type": "code",
"execution_count": 325,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"uncertainty: 0.73\n",
"Dirichlet mean: [0.325 0.073 0.075 0.074 0.073 0.081 0.078 0.073 0.074 0.074]\n"
]
}
],
"source": [
"uncertainty3, p3 = sess3.run([u3, prob3], feed_dict={X3:img[None,:], keep_prob3:1.0})\n",
"print('uncertainty:', np.round(uncertainty3[0,0], decimals=2))\n",
"print('Dirichlet mean: ', np.round(p3[0], decimals=3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The uncertainty increase to 0.85 while the expected probability for the digit 0 decreases to 0.184 when the loss function in Eq. 6 is used."
]
},
{
"cell_type": "code",
"execution_count": 326,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"uncertainty: 0.85\n",
"Dirichlet mean: [0.184 0.085 0.085 0.085 0.085 0.097 0.123 0.085 0.087 0.085]\n"
]
}
],
"source": [
"uncertainty4, p4 = sess4.run([u4, prob4], feed_dict={X4:img[None,:], keep_prob4:1.0})\n",
"print('uncertainty:', np.round(uncertainty4[0,0], decimals=2))\n",
"print('Dirichlet mean: ', np.round(p4[0], decimals=3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lets try another settings where each of these two digits can be recognizable easily. You can see below an image which is created by combining images for digit 0 and digit 6 without any overlap. "
]
},
{
"cell_type": "code",
"execution_count": 330,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAETJJREFUeJzt3X2UVPV9x/HPd5ddQJ6RSCigoCI+F8wWayUGfIpSK+ZYPVL1YLRijtgQa3NqTWvsSTW0jebYxGhQMViNktagSKwJJSixGmQlBkSj+LAcobCgqwURYR++/WOHuOre3x3n6c7ye7/O2cPsfOd379eRD3dmfnPvz9xdAOJTk3UDALJB+IFIEX4gUoQfiBThByJF+IFIEX4gUoQfiBThByLVq5I7q7fe3kf9KrlLICofaKf2+G7L57FFhd/MzpB0q6RaSXe5+9zQ4/uon463U4rZJYCAlb4s78cW/LLfzGol3SbpTElHSpphZkcWuj0AlVXMe/5Jkl5199fdfY+kByVNL01bAMqtmPCPlPRml9835u77CDObZWaNZtbYqt1F7A5AKZX90353n+fuDe7eUKfe5d4dgDwVE/5NkkZ3+X1U7j4APUAx4V8laZyZjTWzekkXSFpcmrYAlFvBU33u3mZmV0n6uTqn+ua7+7qSdQagrIqa53f3xyQ9VqJeAFQQX+8FIkX4gUgRfiBShB+IFOEHIkX4gUgRfiBShB+IFOEHIkX4gUgRfiBShB+IFOEHIlXRS3cD1aJ9ynHBeq/rm4P1R8eHL13R3L4rWL9k5lcTa7XLVwfHlgpHfiBShB+IFOEHIkX4gUgRfiBShB+IFOEHIsU8f84HfzYpWO/7X8lzr94QXp/0jbPDy5J//uS1wfqvfnlMsB4y4pn2YL3Po88WvO2erP8/hteXWXjokmC9I2X7G9r2C9a3zUn+HsBnl6dsvEQ48gORIvxApAg/ECnCD0SK8AORIvxApAg/EKmi5vnNrEnSDkntktrcvaEUTRWidtj+wXr7wr7B+oPjbgnWm9vrEmuDap4Ijj2wV3jON9XMFQUP3XrR+8H6//5bfbB+xU1zgvX973zmU/dUKTvPPT6x9q0Dv1/Utg9fPDtYH3/XzmB96GeK/DtRAqX4ks9Ud3+rBNsBUEG87AciVWz4XdIvzOw5M5tVioYAVEaxL/snu/smMztA0lIz+527f+QNau4fhVmS1EfZv88B0KmoI7+7b8r9uVXSIkmfODvG3ee5e4O7N9SpdzG7A1BCBYffzPqZ2YC9tyWdLumFUjUGoLyKedk/XNIiM9u7nR+7++Ml6QpA2Zm7V2xnA22oH2+nlGXbr90/MVh/ecrdZdmvJP3g3bHB+uodBwbrG3cOLmr/tZZ8dvnPxj9a1Lab2sLfE/jKhVcF6zVPPV/U/kN6jfyDYP3GpxYl1o6oD7/oPePFPw/We0/bHKx7655gvVxW+jJt9xbL57FM9QGRIvxApAg/ECnCD0SK8AORIvxApHrUpbv9hD9MrC38kx+mjA7/pz6+K/zV47lfn5lYG7Au5aTGbS3Bcs07b4bHp/Ca2sTaYTdfGRz74vnfC9YPqesfrO/6++3B+qBLhifW2raEl8FO89bJBwXradN5QTcfECx764bCt10lOPIDkSL8QKQIPxApwg9EivADkSL8QKQIPxCpHjXP3zoo+TLTE+rD/ykdCp+6/PV7Lg3WRy96OrEWXgS7AjqSOzj06l8Hhx5RHz4ld830W4P1J4/5z2D9xFOTv2cw6L7i5vm3Tm4L1msCx7apa88Lju33+KqCeupJOPIDkSL8QKQIPxApwg9EivADkSL8QKQIPxCpHjXP394nrysSd+vYpy8J1g+8MXkef182bvbKYH3JqSOC9fP6vx2sv3t28lLVg+4LDlWvg0YH63MmLw3WO5R8SXOfHz5fX3o9pd7zceQHIkX4gUgRfiBShB+IFOEHIkX4gU
gRfiBSqfP8ZjZf0lmStrr70bn7hkpaKGmMpCZJ57v7O+Vrs9P4v1tX8Nja5waUsJN4fGPVOcH6eVPDS5/PPmpFYm2JhgTHvnr5qGD94cEPB+tvtCUvk913WzZLaFeTfI78P5J0xsfuu1bSMncfJ2lZ7ncAPUhq+N19haSPLzkzXdKC3O0FksKHBwBVp9D3/MPdfXPu9hZJyWsyAahKRX/g5+4uJV8gz8xmmVmjmTW2anexuwNQIoWGv9nMRkhS7s+tSQ9093nu3uDuDXXqXeDuAJRaoeFfLGnvsrUzJT1SmnYAVEpq+M3sAUnPSBpvZhvN7DJJcyWdZmbrJZ2a+x1AD5I6z+/uMxJKp5S4F9Uce3iwPmVw8vnbr7R+EBw7bE1rQT3FbsiTfcIPmFq+fdth7xU1/oF3/yixVrt8dVHb3hfwDT8gUoQfiBThByJF+IFIEX4gUoQfiFRVXbp7/czBwfoF/bcl1iavuTg4duBj+/6Sy/uaOz53f1Hj/+PHUxJrIxXnpdq74sgPRIrwA5Ei/ECkCD8QKcIPRIrwA5Ei/ECkqmqe/+ozfxash07brb9t/5Stv1ZAR8hSjSUvsS1JdVYbrI/64obE2stjJwXHjj9sU7D+6PjFwXpab63enlhbtHNocOw/3XFh8nbv+3VwbFcc+YFIEX4gUoQfiBThByJF+IFIEX4gUoQfiFRVzfOn+eHbJyXW+ix5toKdoBI6PHxsavXw5dgfGR9Ywnt8IR19KPwNBOmWlnHB+uwhLyfWpvd7Kzh2+jW3JtZOXJq4eNYncOQHIkX4gUgRfiBShB+IFOEHIkX4gUgRfiBSqfP8ZjZf0lmStrr70bn7bpB0uaS9F9K/zt0fS91Wba1qBw5KrA+o2ZhHy0B+NrTtSaz9TdO5wbFvLjw4WN/vrfBM/8Ala4L1h846PbHW8eXk9Skk6YljHwzW85XPkf9Hks7o5v7vuvuE3E9q8AFUl9Twu/sKSS0V6AVABRXznv8qM1tjZvPNbEjJOgJQEYWG/3ZJh0iaIGmzpJuTHmhms8ys0cwa9/iuAncHoNQKCr+7N7t7u7t3SLpTUuLVEN19nrs3uHtDvfUttE8AJVZQ+M1sRJdfvyTphdK0A6BS8pnqe0DSFEnDzGyjpG9KmmJmEyS5pCZJV5SxRwBlkBp+d5/Rzd13F7KzPcP6aOPFRyXWLxywPDh+9c4xhewWRdg97f+KGv9+R33BY7/8878M1l+Zfnuw/qf/MzuxdvBfPB8ce4C2BOtp0s737/+T5Ovr1y4LX7f/nl+NSay93f52yp4/xDf8gEgRfiBShB+IFOEHIkX4gUgRfiBSPerS3Si9tpM/F6w/OPH7KVvoHawu+udTEmuDFF5Our4lvMx1mouOSr6c+9MqfAqy3NrfDp9Hd8ua5Oe0edf6vPfDkR+IFOEHIkX4gUgRfiBShB+IFOEHIkX4gUgxz7+PS5vHb5mzM1g/vC48j3/lphOD9cELVyfWPDhS6vW+Bet1Fv4ewIDaDxJrNfsNDo7teP/9YL2c2qccF6zfO+nOxNqlKct7d8WRH4gU4QciRfiBSBF+IFKEH4gU4QciRfiBSFV0nr92jzSwqT2x3tSW3dxqT2a9kv83vnv1juDYxuPCyz0v3RVeZemVf0i+FLsk1bc2Busho256Olg/8riLgvXfnHBPYu2Ouz8fHHvIrNeD9Y4d4ec1Te0R4xJr4/51XXDsxN7JFwbfL/zViI/gyA9EivADkSL8QKQIPxApwg9EivADkSL8QKTMPXxWtZmNlnSvpOHqPAV7nrvfamZDJS2UNEZSk6Tz3f2d0LYG2lA/3pKvOX72i+HlhTs8+d+qx794dHBs28ZNwXqWOiZPCNbfuDI8/twjkpebvumA5PPp8zF19leC9b4PJ18bv9x6fXZ4sP5XTz2RWJva973g2GNWhJcHH/Zw+PsPrf3CE+43XndXYu0LfcPfd1m+q39ibc7017R+7a68ZvvzOfK3SbrG3Y
+U9MeSZpvZkZKulbTM3cdJWpb7HUAPkRp+d9/s7qtzt3dIeknSSEnTJS3IPWyBpHPK1SSA0vtU7/nNbIykiZJWShru7ptzpS3qfFsAoIfIO/xm1l/SQ5K+5u7bu9a884ODbj88MLNZZtZoZo2t2l1UswBKJ6/wm1mdOoN/v7v/NHd3s5mNyNVHSNra3Vh3n+fuDe7eUJeyqCOAykkNv5mZpLslveTut3QpLZY0M3d7pqRHSt8egHLJ55TeEyVdLGmtme2dU7pO0lxJPzGzyyRtkHR+eVr80JWD30isNS8ZGBzb2HJgqdspmblj5wXrE+oLP/P6uT3Jp1BL0sXPXhasH/LL3wXr4a2XV9uW5mD9O5dfmFy88/7g2LUnJU/FSZJOCpdrUo6rHUo+LXfGa9OCY7d/c3RirXnDbeHGukj9W+XuT0lKmjdMnrQHUNX4hh8QKcIPRIrwA5Ei/ECkCD8QKcIPRCr1lN5SSjult+XSE4Ljp81ZkVi7ftjagvuqdm0ps+m/3ZNcu2jhV4Njx177TCEt9XgdX5gYrNdcvy1YX3z4omD9sg2nBeurlh+RWDv4W78Jju34IHnp8ZW+TNu9pWSn9ALYBxF+IFKEH4gU4QciRfiBSBF+IFKEH4hUVc3zp6k9dGxiberDa4Jj/3rI+oL3W26HP3lpsF6/dr9gfdS3w0tZIx7M8wNIRfiBSBF+IFKEH4gU4QciRfiBSBF+IFI9ap4fQBjz/ABSEX4gUoQfiBThByJF+IFIEX4gUoQfiFRq+M1stJktN7MXzWydmc3J3X+DmW0ys+dzP+FFxQFUlV55PKZN0jXuvtrMBkh6zsyW5mrfdffvlK89AOWSGn533yxpc+72DjN7SdLIcjcGoLw+1Xt+MxsjaaKklbm7rjKzNWY238yGJIyZZWaNZtbYqt1FNQugdPIOv5n1l/SQpK+5+3ZJt0s6RNIEdb4yuLm7ce4+z90b3L2hTr1L0DKAUsgr/GZWp87g3+/uP5Ukd29293Z375B0p6RJ5WsTQKnl82m/Sbpb0kvufkuX+0d0ediXJL1Q+vYAlEs+n/afKOliSWvN7PncfddJmmFmEyS5pCZJV5SlQwBlkc+n/U9J6u784MdK3w6ASuEbfkCkCD8QKcIPRIrwA5Ei/ECkCD8QKcIPRIrwA5Ei/ECkCD8QKcIPRIrwA5Ei/ECkCD8QqYou0W1m2yRt6HLXMElvVayBT6dae6vWviR6K1QpezvI3T+TzwMrGv5P7Nys0d0bMmsgoFp7q9a+JHorVFa98bIfiBThByKVdfjnZbz/kGrtrVr7kuitUJn0lul7fgDZyfrIDyAjmYTfzM4ws5fN7FUzuzaLHpKYWZOZrc2tPNyYcS/zzWyrmb3Q5b6hZrbUzNbn/ux2mbSMequKlZsDK0tn+txV24rXFX/Zb2a1kl6RdJqkjZJWSZrh7i9WtJEEZtYkqcHdM58TNrOTJL0n6V53Pzp3379IanH3ubl/OIe4+99WSW83SHov65WbcwvKjOi6srSkcyRdogyfu0Bf5yuD5y2LI/8kSa+6++vuvkfSg5KmZ9BH1XP3FZJaPnb3dEkLcrcXqPMvT8Ul9FYV3H2zu6/O3d4hae/K0pk+d4G+MpFF+EdKerPL7xtVXUt+u6RfmNlzZjYr62a6MTy3bLokbZE0PMtmupG6cnMlfWxl6ap57gpZ8brU+MDvkya7+3GSzpQ0O/fytip553u2apquyWvl5krpZmXp38vyuSt0xetSyyL8mySN7vL7qNx9VcHdN+X+3Cppkapv9eHmvYuk5v7cmnE/v1dNKzd3t7K0quC5q6YVr7MI/ypJ48xsrJnVS7pA0uIM+vgEM+uX+yBGZtZP0umqvtWHF0uambs9U9IjGfbyEdWycnPSytLK+LmruhWv3b3iP5KmqfMT/9ckfSOLHhL6OljSb3M/67LuTdID6nwZ2KrOz0Yuk7S/pGWS1kv6b0lDq6i3f5e0VtIadQZtREa9TV
bnS/o1kp7P/UzL+rkL9JXJ88Y3/IBI8YEfECnCD0SK8AORIvxApAg/ECnCD0SK8AORIvxApP4fYskg6RI7YOkAAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = np.zeros((28,28))\n",
"img[:,:-6] += mnist.test.images[10].reshape(28,28)[:,6:]\n",
"img[:,14:] += mnist.test.images[21].reshape(28,28)[:,5:19]\n",
"img /= img.max()\n",
"plt.imshow(img)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below, you can see the prediction of the neural network trained with softmax cross entropy for this example. The prediction of the network is digit 2 with probability 0.775. Hence, the network associates quite high probability with the wrong label. "
]
},
{
"cell_type": "code",
"execution_count": 331,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"softmax prob: [0. 0.199 0.775 0.007 0.003 0. 0. 0.015 0. 0. ]\n"
]
}
],
"source": [
"p1 = sess1.run(prob1, feed_dict={X1:img.reshape(1,-1), keep_prob1:1.0})\n",
"print('softmax prob: ', np.round(p1[0], decimals=3))"
]
},
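{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to quantify this overconfidence is the predictive entropy of the softmax output. The cell below is a small sketch that compares the entropy of the probability vector printed above with the entropy of a uniform distribution over ten classes, which is the largest possible value. The helper name `entropy` is an illustrative choice, and the probabilities are copied from the rounded output rather than recomputed from the network. A softmax output always sums to one, so the network has no separate way to report that it lacks evidence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def entropy(p):\n",
"    # Shannon entropy in nats; zero-probability entries contribute nothing.\n",
"    p = np.asarray(p, dtype=float)\n",
"    nz = p[p > 0]\n",
"    return float(-(nz * np.log(nz)).sum())\n",
"\n",
"# Rounded softmax probabilities copied from the output above.\n",
"softmax_out = np.array([0., 0.199, 0.775, 0.007, 0.003, 0., 0., 0.015, 0., 0.])\n",
"uniform = np.full(10, 0.1)\n",
"\n",
"print('entropy of softmax output:', np.round(entropy(softmax_out), 3))\n",
"print('maximum entropy (uniform):', np.round(entropy(uniform), 3))"
]
},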
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On the otherhand, when we do the same using the network trained based on the loss in Eq. 7, the output of the neural network is uniform distribution with uncertainty 1.0, as shown below."
]
},
{
"cell_type": "code",
"execution_count": 332,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"uncertainty: 1.0\n",
"Dirichlet mean: [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]\n"
]
}
],
"source": [
"uncertainty2, p2 = sess2.run([u, prob2], feed_dict={X2:img.reshape(1,-1), keep_prob2:1.0})\n",
"print('uncertainty:', np.round(uncertainty2[0,0], decimals=2))\n",
"print('Dirichlet mean: ', np.round(p2[0], decimals=3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The neural networks, trained using the loss functions defined in Eq. 5 and Eq. 6 in the paper, also have very high uncertainty for their predictions. These networks assing small amount of evidence for the classification of the image as digit 2. However, they associate very high uncertainty with their misclassifications of the image."
]
},
{
"cell_type": "code",
"execution_count": 333,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"uncertainty: 0.92\n",
"Dirichlet mean: [0.092 0.094 0.143 0.12 0.092 0.092 0.092 0.093 0.092 0.092]\n"
]
}
],
"source": [
"uncertainty3, p3 = sess3.run([u3, prob3], feed_dict={X3:img.reshape(1,-1), keep_prob3:1.0})\n",
"print('uncertainty:', np.round(uncertainty3[0,0], decimals=2))\n",
"print('Dirichlet mean: ', np.round(p3[0], decimals=3))"
]
},
{
"cell_type": "code",
"execution_count": 334,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"uncertainty: 0.93\n",
"Dirichlet mean: [0.093 0.093 0.16 0.098 0.093 0.093 0.093 0.094 0.093 0.093]\n"
]
}
],
"source": [
"uncertainty4, p4 = sess4.run([u4, prob4], feed_dict={X4:img.reshape(1,-1), keep_prob4:1.0})\n",
"print('uncertainty:', np.round(uncertainty4[0,0], decimals=2))\n",
"print('Dirichlet mean: ', np.round(p4[0], decimals=3))"
]
},
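{
"cell_type": "markdown",
"metadata": {},
"source": [
"The uncertainty value can also be used to decide whether a prediction should be trusted at all. The cell below is a minimal sketch of such a rule: it returns the most probable class only when the uncertainty is at or below a threshold, and abstains otherwise. The helper name `predict_or_abstain` and the threshold of 0.5 are assumptions made for illustration; they do not come from the paper. The probabilities are copied from the last rounded output above, and the second call uses a hypothetical low uncertainty to show the other branch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def predict_or_abstain(prob, uncertainty, threshold=0.5):\n",
"    # Return the predicted class index, or None when the uncertainty exceeds the threshold.\n",
"    if uncertainty > threshold:\n",
"        return None\n",
"    return int(np.argmax(prob))\n",
"\n",
"# Rounded Dirichlet mean copied from the last output above.\n",
"p_last = np.array([0.093, 0.093, 0.16, 0.098, 0.093, 0.093, 0.093, 0.094, 0.093, 0.093])\n",
"\n",
"# Uncertainty 0.93 as printed above: the rule abstains.\n",
"print(predict_or_abstain(p_last, 0.93))\n",
"\n",
"# Hypothetical low uncertainty: the rule returns the most probable class.\n",
"print(predict_or_abstain(p_last, 0.10))"
]
},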
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"default_view": {},
"name": "EDL MNIST Demo.ipynb",
"provenance": [
{
"file_id": "1AAN37ioPFkhPfTxBaLeFywBwIliguMu1",
"timestamp": 1527923401234
}
],
"version": "0.3.2",
"views": {}
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 1
}