“What kind of bear is best?” Building a bear classifier with fast.ai

That's debatable...

Motivation

Usually, when my dad and I go hiking in the mountains and spot a bear (from a safe distance), we have a tough time figuring out whether it's a brown bear or a black bear. How you respond to a bear encounter depends on the species: for brown bears, you're supposed to play dead; for black bears, you're supposed to be loud and intimidating to scare them off. Being unable to tell the two apart could be the last mistake you ever make.

Unfortunately, telling them apart isn't as simple as the difference between black and brown, as the National Park Service points out:

The name “black bear” is misleading, however. This species can range from black to gray to cinnamon to white depending on the location and the individual. To ensure proper identification of an American black bear, do not depend on the bear's coloration.

The NPS also provides a helpful infographic to help distinguish the two: Brown/Grizzly vs. Black Bear

Side note: I learned that grizzly bears and brown bears are the same species; the main distinction is where they live. Grizzlies live inland, while brown bears live on the coast, e.g. in Alaska.

You can see it's not always trivial to tell the difference. So I decided to build a simple image classifier using fast.ai to help me settle my hiking debates and potentially save my life.

The Data

All good classifiers need a lot of data. Since I'll be building mine on top of pre-trained resnet34 and resnet50 models, I won't need a ton of new data, but the more, the better. In this case, I need a bunch of images labeled “brown” or “black” to train and test my classifier.

Luckily, ImageNet has a decent collection of images classified accordingly:

Now, ImageNet doesn't let an independent party download the images directly, but it does publish a list of URLs for each image set:

import urllib.request

black_bear_images_link = 'http://image-net.org/api/text/imagenet.synset.geturls?wnid=n02133161'   
black_bear_image_urls = urllib.request.urlopen(black_bear_images_link).read().decode()

brown_bear_images_link = 'http://image-net.org/api/text/imagenet.synset.geturls?wnid=n02132136'   
brown_bear_image_urls = urllib.request.urlopen(brown_bear_images_link).read().decode()

Once I had the URL lists, I could go through each image and download it to my server. I decided to save all brown bear images under data/bears/brown and all black bear images under data/bears/black.

import os
import pathlib

# Create the directory structure (makedirs is recursive; exist_ok makes it idempotent)
os.makedirs('../data/bears/brown', exist_ok=True)
os.makedirs('../data/bears/black', exist_ok=True)

path = pathlib.Path('../data/bears')
brown_path = path/'brown'
black_path = path/'black'

Once I had the directory structure established, I could download the images with the following python snippet:

from PIL import Image

# Dead links serve a "no longer available" placeholder image that is 2051 bytes
IMAGE_NO_LONGER_AVAILABLE = 2051

def download_images(image_urls, dest_dir):
    for pic_num, url in enumerate(image_urls.split('\n'), start=1):
        file_name = os.path.join(dest_dir, str(pic_num) + '.jpg')
        try:
            if not os.path.exists(file_name):
                urllib.request.urlretrieve(url, file_name)
                # Check to make sure the image is valid
                img = Image.open(file_name)
                img.verify()
                assert os.path.getsize(file_name) != IMAGE_NO_LONGER_AVAILABLE
        except Exception as e:
            print(pic_num, e)
            if os.path.exists(file_name):
                os.remove(file_name)

# Download each class into its respective directory
download_images(black_bear_image_urls, black_path)
download_images(brown_bear_image_urls, brown_path)

There are a few things I'd like to note about this script:

- It skips any file that already exists, so the download can be stopped and resumed.
- It opens each downloaded file with PIL and calls verify() to catch corrupt or truncated images.
- Dead links serve a “no longer available” placeholder image that is exactly 2051 bytes, which the assert filters out.
- Any image that fails a check is deleted, and its index and error are printed.

This, to my knowledge, ensures that the ImageNet data I end up with is good quality.

On the server

Following the fast.ai Lesson 1 format:

%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai import *
from fastai.vision import *

bs = 64

fnames = get_image_files(brown_path) + get_image_files(black_path)

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

np.random.seed(2)
# create the data bunch based on path
data = ImageDataBunch.from_name_func(path, fnames, ds_tfms=get_transforms(), size=224, bs=bs,
        label_func = lambda x: 'brown' if '/brown/' in str(x) else 'black')
data.classes # ['brown', 'black']

data.normalize(imagenet_stats)
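
A detail worth calling out: from_name_func holds back a random 20% of the images for validation by default (valid_pct=0.2), and that held-out set is what all the error_rate numbers below are measured on. The np.random.seed(2) call above just makes the split reproducible. A quick way to confirm the split sizes:

# Confirm the default 80/20 train/validation split
print(len(data.train_ds), len(data.valid_ds))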

After running this, I sanity checked the classifications using data.show_batch(rows=3, figsize=(7,6)). There didn't appear to be anything blatantly wrong.

Training

Again, following the fast.ai Lesson 1 format:

learn = create_cnn(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)

This yielded the following results:

Total time: 01:03
epoch  train_loss  valid_loss  error_rate
1      0.354983    0.098601    0.036939    (00:19)
2      0.208561    0.131596    0.039578    (00:14)
3      0.156123    0.127603    0.036939    (00:14)
4      0.125690    0.121351    0.036939    (00:14)

That seemed consistent with what I expected: one minute of training and roughly 96.3% accuracy (an error rate of about 0.037).

Trying to improve

Checking the top confusions:

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_top_losses(9, figsize=(15,11))
interp.most_confused(min_val=2)

From the first plot, I was able to check some of the errors. I happened to notice an interesting classification:

Classified black bear, actually brown

ImageNet had this one classified as a black bear. The image was watermarked, so I checked the photographer's website, and lo and behold:

Actual screenshot from photographer's website

Looks like the dataset had a misclassified entry! Good thing I checked. It may not have affected the accuracy too much, but it's good to know.
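
Before fine-tuning, it's worth mentioning fastai's learning-rate finder, the standard Lesson 1 tool for choosing a sensible rate range (a quick sketch; I'm not reproducing my actual plot here):

# Run a mock training sweep over learning rates, then plot loss vs. LR.
# Good slice bounds sit where the loss is still decreasing steeply.
learn.lr_find()
learn.recorder.plot()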

Let's try to fine-tune the results by unfreezing the whole network and using discriminative learning rates, i.e. smaller rates for the early layers and larger rates for the later ones:

learn.unfreeze()
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4))

Output:

Total time: 00:38
epoch  train_loss  valid_loss  error_rate
1      0.082740    0.113588    0.036939    (00:19)
2      0.079821    0.118543    0.036939    (00:18)

Hmmm, not much change. Let's try training on resnet50 this time.

# Larger 299px inputs and half the batch size so resnet50 fits in GPU memory
data = ImageDataBunch.from_name_func(path, fnames, ds_tfms=get_transforms(), size=299, bs=bs//2,
        label_func = lambda x: 'brown' if '/brown/' in str(x) else 'black')
learn = create_cnn(data, models.resnet50, metrics=error_rate)
learn.fit_one_cycle(8)

Output:

Total time: 07:17
epoch  train_loss  valid_loss  error_rate
1      0.220857    0.175432    0.044321    (01:24)
2      0.162850    0.182251    0.052632    (00:49)
3      0.133017    0.143796    0.049862    (00:50)
4      0.095450    0.154293    0.047091    (00:50)
5      0.072514    0.144659    0.041551    (00:50)
6      0.069240    0.112004    0.033241    (00:50)
7      0.042691    0.118620    0.036011    (00:50)
8      0.032191    0.119663    0.036011    (00:50)

And tuning:

learn.unfreeze()
learn.fit_one_cycle(3, max_lr=slice(1e-6,1e-4))

Output:

Total time: 03:26
epoch  train_loss  valid_loss  error_rate
1      0.036145    0.122415    0.024931    (01:15)
2      0.029445    0.118053    0.030471    (01:05)
3      0.025955    0.118540    0.027701    (01:05)

Looks like we improved our error rate by about one percentage point over resnet34. Re-running interp.most_confused(min_val=2) now yields ten mistakes in the whole dataset (400+ images).
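
One detail: the interpretation object is tied to a specific learner, so it has to be rebuilt from the resnet50 learner before querying it:

# Rebuild the interpreter from the new learner, then list the confusions
interp = ClassificationInterpretation.from_learner(learn)
interp.most_confused(min_val=2)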

Reflections

Overall, this was a solid learning experience. I picked an interesting problem, and the classifier worked well.

Most of my learning came from the dataset-creation portion. I had initially assumed that the ImageNet data was actually in reasonable shape. It turns out only about 15% of the images were still available and non-corrupt! I'll try downloading directly from Google or Bing next time to see if I can get higher-quality data for my next classifier.
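
If you want to measure that survival rate yourself, a quick count after the download pass will do it (a sketch, assuming the directory layout from earlier):

import pathlib

# Compare images that survived download + validation against candidate URLs
kept = len(list(pathlib.Path('../data/bears').glob('*/*.jpg')))
total = len(black_bear_image_urls.splitlines()) + len(brown_bear_image_urls.splitlines())
print(f"kept {kept} of {total} candidate images ({kept / total:.1%})")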

I didn't have to change much of fast.ai's boilerplate learning code to achieve my 97.5% accuracy. Of course, I haven't actually tested this on outside data yet, so I'm making a big assumption that the results hold. If they do, it really shows the power of fine-tuning a pre-trained, general image model for a specific classification task!
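
When I do test it in the field, fastai v1 makes single-image inference straightforward. A minimal sketch, assuming a hypothetical photo called my_hike_photo.jpg:

# Classify one outside image with the trained learner
# ('my_hike_photo.jpg' is a hypothetical example, not part of the dataset)
img = open_image('my_hike_photo.jpg')
pred_class, pred_idx, probs = learn.predict(img)
print(pred_class, probs[pred_idx])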