“What kind of bear is best?” Building a bear classifier with fast.ai
Usually, when my dad and I go hiking in the mountains and we spot a bear (from a safe distance), we have a tough time figuring out if it's a brown bear or a black bear. How you respond to a bear encounter depends on the species. For brown bears, you're supposed to pretend you're dead. For black bears, you're supposed to be loud and intimidating to scare them off. Being unable to tell the two apart could be the last mistake you ever make.
Unfortunately, telling them apart isn't as simple as the difference between black and brown, as the National Park Service points out:
The name “black bear” is misleading, however. This species can range from black to gray to cinnamon to white depending on the location and the individual. To ensure proper identification of an American black bear, do not depend on the bear's coloration.
The NPS also provides a helpful infographic to help distinguish the two:
Side note: I learned that grizzly bears and brown bears are the same species. Their only distinguishing feature is their location. Grizzly bears live inland, while brown bears live on the coast, e.g. in Alaska.
You can see it's not always trivial to tell the difference. So I decided to build a simple image classifier using fast.ai to help me settle my hiking debates and potentially save my life.
All good classifiers need a lot of data. Since I'll be using resnet34 and resnet50 as a base to build my classifier, I won't need a ton of new data, but the more data, the better. In this case, I'll need a bunch of images classified as “brown” or “black” to train and test my classifier.
Luckily, ImageNet has a decent collection of images classified accordingly:
Now, I couldn't download the images directly as an independent party, but I did have access to a list of URLs for each image set:
import urllib.request

black_bear_images_link = 'http://image-net.org/api/text/imagenet.synset.geturls?wnid=n02133161'
black_bear_image_urls = urllib.request.urlopen(black_bear_images_link).read().decode()

brown_bear_images_link = 'http://image-net.org/api/text/imagenet.synset.geturls?wnid=n02132136'
brown_bear_image_urls = urllib.request.urlopen(brown_bear_images_link).read().decode()
Once I did that, I could go through each image and download it to my server. I decided to save all brown bear images under
data/bears/brown and all black bear images under data/bears/black:
import os
import pathlib

if not os.path.exists('../data/bears'):
    os.makedirs('../data/bears')
if not os.path.exists('../data/bears/brown'):
    os.makedirs('../data/bears/brown')
if not os.path.exists('../data/bears/black'):
    os.makedirs('../data/bears/black')

path = pathlib.Path('../data/bears')
brown_path = path/'brown'
black_path = path/'black'
Once I had the directory structure established, I could download the images with the following Python snippet:
from PIL import Image

IMAGE_NO_LONGER_AVAILABLE = 2051  # byte size of ImageNet's placeholder image

# download black bear images
for pic_num, i in enumerate(black_bear_image_urls.split('\n')):
    file_name = "../data/bears/black/" + str(pic_num + 1) + ".jpg"
    try:
        if not os.path.exists(file_name):
            urllib.request.urlretrieve(i, file_name)
        # Check to make sure the image is valid
        img = Image.open(file_name)
        img.verify()
        assert os.path.getsize(file_name) != IMAGE_NO_LONGER_AVAILABLE
    except Exception as e:
        print(pic_num + 1, e)
        if os.path.exists(file_name):
            os.remove(file_name)

# download brown bear images
for pic_num, i in enumerate(brown_bear_image_urls.split('\n')):
    file_name = "../data/bears/brown/" + str(pic_num + 1) + ".jpg"
    try:
        if not os.path.exists(file_name):
            urllib.request.urlretrieve(i, file_name)
        # Check to make sure the image is valid
        img = Image.open(file_name)
        img.verify()
        assert os.path.getsize(file_name) != IMAGE_NO_LONGER_AVAILABLE
    except Exception as e:
        print(pic_num + 1, e)
        if os.path.exists(file_name):
            os.remove(file_name)
There are a few things I'd like to note about this script:
- I include Pillow's Image module to call img.verify(). This raises an exception whenever an image is corrupted or cannot be opened.
- I also check to make sure the image size makes sense. I found several “This image no longer exists” placeholder images, and they all had the same size: 2,051 bytes. The assert raises an exception whenever it downloads one of these images.
- If one of the above exceptions is raised, it prints out the exception and deletes the file (if it exists).
This, to my knowledge, ensures that the ImageNet data set is good quality.
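As an aside, the two near-identical loops above could be consolidated into a single helper. Here's a sketch, assuming the same placeholder size and directory layout (the name download_and_verify is mine, not part of any library):

```python
import os
import urllib.request

from PIL import Image

IMAGE_NO_LONGER_AVAILABLE = 2051  # byte size of ImageNet's placeholder image


def download_and_verify(url, file_name):
    """Download url to file_name, then validate it. Returns True if the
    file is a usable image; deletes it and returns False otherwise."""
    try:
        if not os.path.exists(file_name):
            urllib.request.urlretrieve(url, file_name)
        # Raises if the file is corrupt or not an image at all
        img = Image.open(file_name)
        img.verify()
        # Reject ImageNet's "this image no longer exists" placeholder
        assert os.path.getsize(file_name) != IMAGE_NO_LONGER_AVAILABLE
        return True
    except Exception as e:
        print(file_name, e)
        if os.path.exists(file_name):
            os.remove(file_name)
        return False
```

Each class then needs only one loop over its URL list calling download_and_verify, instead of a copy-pasted try/except block.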
On the server
Following the fast.ai Lesson 1 format:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai import *
from fastai.vision import *

bs = 64
fn_paths = [brown_path, black_path]
fnames = get_image_files(brown_path) + get_image_files(black_path)

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

np.random.seed(2)

# create the data bunch based on path
data = ImageDataBunch.from_name_func(path, fnames, ds_tfms=get_transforms(), size=224, bs=bs,
                                     label_func=lambda x: 'brown' if '/brown/' in str(x) else 'black')
data.classes  # ['brown', 'black']
data.normalize(imagenet_stats)
After running this, I sanity checked the labels using data.show_batch(rows=3, figsize=(7,6)). There didn't appear to be anything blatantly wrong.
Again, following the fast.ai Lesson 1 format:
learn = create_cnn(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)
This yielded the following results:
Total time: 01:03

epoch  train_loss  valid_loss  error_rate
1      0.354983    0.098601    0.036939   (00:19)
2      0.208561    0.131596    0.039578   (00:14)
3      0.156123    0.127603    0.036939   (00:14)
4      0.125690    0.121351    0.036939   (00:14)
That seemed consistent with what I expected: about one minute of training and roughly 96.3% classification accuracy (an error rate of 0.0369).
Checking the top confusions:
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_top_losses(9, figsize=(15,11))
interp.most_confused(min_val=2)
From the first plot, I was able to check some of the errors. I happened to notice an interesting classification:
ImageNet had this one classified as a black bear. I checked the watermarked website, and lo and behold:
Looks like the dataset had a mis-classified entry! Good thing I checked. It may not have affected the accuracy too much, but it's good to know.
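When I find a mislabeled file like this, the simplest fix is to move it into the correct class folder before retraining. A sketch (relabel is my own helper, and the example file name is hypothetical):

```python
import shutil
from pathlib import Path


def relabel(data_root, wrong_class, right_class, file_name):
    """Move a mislabeled image from wrong_class's folder to
    right_class's folder, keeping its file name."""
    src = Path(data_root) / wrong_class / file_name
    dst = Path(data_root) / right_class / file_name
    shutil.move(str(src), str(dst))
    return dst

# e.g. relabel('../data/bears', 'black', 'brown', '42.jpg')  # hypothetical file
```

Since label_func derives the label from the path, the image picks up its new label automatically the next time the data bunch is built.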
Let's try to fine-tune the results using a learning rate varied over layers:
learn.unfreeze()
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4))
Total time: 00:38

epoch  train_loss  valid_loss  error_rate
1      0.082740    0.113588    0.036939   (00:19)
2      0.079821    0.118543    0.036939   (00:18)
Hmmm, not much change. Let's try training on resnet50 this time.
data = ImageDataBunch.from_name_func(path, fnames, ds_tfms=get_transforms(), size=299, bs=bs//2,
                                     label_func=lambda x: 'brown' if '/brown/' in str(x) else 'black')
learn = create_cnn(data, models.resnet50, metrics=error_rate)
learn.fit_one_cycle(8)
Total time: 07:17

epoch  train_loss  valid_loss  error_rate
1      0.220857    0.175432    0.044321   (01:24)
2      0.162850    0.182251    0.052632   (00:49)
3      0.133017    0.143796    0.049862   (00:50)
4      0.095450    0.154293    0.047091   (00:50)
5      0.072514    0.144659    0.041551   (00:50)
6      0.069240    0.112004    0.033241   (00:50)
7      0.042691    0.118620    0.036011   (00:50)
8      0.032191    0.119663    0.036011   (00:50)
learn.unfreeze()
learn.fit_one_cycle(3, max_lr=slice(1e-6,1e-4))
Total time: 03:26

epoch  train_loss  valid_loss  error_rate
1      0.036145    0.122415    0.024931   (01:15)
2      0.029445    0.118053    0.030471   (01:05)
3      0.025955    0.118540    0.027701   (01:05)
Looks like we improved our error rate by about one percentage point over resnet34.
Recomputing interp from the new learner, interp.most_confused(min_val=2) now yields only ten mistakes in the whole dataset (400+ images).
Overall, this was a pretty solid learning experience. I tried to solve an interesting problem, and it worked pretty well.
Most of my learning came from the dataset creation portion. I had initially assumed that the ImageNet data was in reasonable shape. It turns out only about 15% of the images were still available and non-corrupt! I'll try downloading directly from Google or Bing next time to see if I can get higher-quality data for my next classifier.
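To quantify that yield for next time, one could compare the synset URL list against what actually survived on disk. A rough sketch (download_yield is my own helper name, not a library function):

```python
import os


def download_yield(url_text, image_dir):
    """Fraction of URLs from an ImageNet synset list that survived
    download and validation as .jpg files in image_dir."""
    n_urls = len([u for u in url_text.split('\n') if u.strip()])
    n_files = len([f for f in os.listdir(image_dir) if f.endswith('.jpg')])
    return n_files / n_urls

# e.g. download_yield(black_bear_image_urls, '../data/bears/black')
```

Run against both class folders after the download-and-validate pass, this gives the ~15% figure directly rather than by eyeballing the error printouts.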
I didn't have to change much of fast.ai's boilerplate learning code to achieve my 97.5% accuracy. Of course, I haven't actually tested this on outside data yet, so I'm making a big assumption that the results hold. If they do, it really shows the power of training for specific classification on a pre-trained, general image model!