An easier way to generate seaborn ridgeplots

Python
Published June 1, 2024

Joy Division-inspired T-shirts and posters may have given some people the impression that ridgeline plots are not a serious visualization tool, but they are an effective way to get a general sense of distributional differences between groups in a dataset.

But while they’re pretty enough to appear on the seaborn logo, seaborn does not include an API for them. The sole ridgeline example in the seaborn gallery relies on looping through a FacetGrid, which is hardly more elegant than just doing it in matplotlib.

Seaborn has an obvious bias towards the violin plot for distributional comparisons, despite the violin plot being less informative and less attractive than the ridge plot. Luckily, there are enough tools in the violin plot API that we can torture it until it generates the desired output.

Here is the fairly self-explanatory ridge_plot function.

import matplotlib.pyplot as plt  # used by plt.show() in the examples
import pandas as pd
import seaborn as sns


def ridge_plot(data, x, y, **kwargs):
    codes, uniques = pd.factorize(data[y])
    uniques = uniques.to_list()
    uniques = ["NA"] + uniques if -1 in codes else uniques

    plot_args = {
        "data": data,
        "x": x,
        "y": codes,
        "split": True,
        "orient": "h",
        "inner": None,
        "dodge": False,
        "native_scale": True,
        "width": 1.75,
    }
    plot_args.update(**kwargs)
    ax = sns.violinplot(**plot_args)

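    # Sort the codes so they line up with the labels: -1 ("NA") first, then 0, 1, ...
    unique_codes = pd.Series(codes).sort_values().unique()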
    spacing = unique_codes[1] - unique_codes[0]
    tick_locations = unique_codes + plot_args["width"] * (spacing / 2)
    ax.set_yticks(tick_locations, labels=uniques)

    ax.invert_yaxis()
    return ax
1. Factorize the y-axis.
2. Plot the y-axis as the factor codes at native_scale.
3. Increase width to create overlap.
4. Pass all kwargs to the violinplot api.
5. Add the factor labels at the base of each distribution.
6. Invert y-axis so the distributions face upwards.
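
To make the factorize and labelling steps concrete, here is a minimal sketch of how pd.factorize encodes a column that contains a missing value, and how the codes turn into tick positions. The toy column is made up for illustration, and the width of 1.75 and spacing of 1 simply mirror the defaults used in ridge_plot. Sorting the codes before pairing them with labels is an assumption about how the "NA" label should line up with the -1 code.

```python
import pandas as pd

# Hypothetical toy column with one missing value, for illustration only.
groups = pd.Series(["b", "a", None, "b", "c"])

# factorize numbers groups in order of first appearance; missing values get -1.
codes, uniques = pd.factorize(groups)

labels = uniques.to_list()
if -1 in codes:
    labels = ["NA"] + labels

# Pair sorted codes with labels so that -1 maps to "NA", 0 to "b", and so on.
sorted_codes = sorted(set(codes.tolist()))
code_to_label = dict(zip(sorted_codes, labels))

# Tick positions: each code shifted by half the violin width (spacing is 1).
width, spacing = 1.75, 1
tick_locations = [code + width * (spacing / 2) for code in sorted_codes]

print(code_to_label)
print(tick_locations)
```

Because the codes are plain integers, plotting them with native_scale=True places each group exactly one unit apart, which is what lets a width greater than 1 produce the overlap.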

And here are a few examples.

penguins = sns.load_dataset("penguins")
ridge_plot(
    penguins, "flipper_length_mm", "species", inner="stick", color="#56ad74"
).set(title="Flipper length by species", xlabel="")
plt.show()

diamonds = sns.load_dataset("diamonds")
ridge_plot(diamonds, "price", "cut", width=4, color="#6f83a9").set(
    title="Diamond price by cut", xlabel="Price"
)
plt.show()

mpg = sns.load_dataset("mpg")

# Get common makes and format labels
mpg["make"] = mpg["name"].str.split(" ").str[0].str.title()
mpg = mpg.groupby("make").filter(lambda x: len(x) > 10)

# Sort by horsepower range
hp_range = mpg.groupby("make")["horsepower"].agg(lambda x: x.max() - x.min()).sort_values()
mpg["make"] = pd.Categorical(mpg["make"], categories=hp_range.index, ordered=True)
mpg = mpg.sort_values("make", ascending=False)

ridge_plot(
    mpg,
    "horsepower",
    "make",
    density_norm="width",
    color="#6ea0b1",
).set(title="American manufacturers offer a wider range of horsepowers", xlabel="Horsepower")
plt.show()
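
The ordering step in the last example could be pulled out into a small reusable helper. The sketch below is one way to do it, assuming a hypothetical function name, order_groups_by, and a made-up toy frame. It is not part of seaborn or pandas, just a convenience wrapper around the same groupby, Categorical and sort_values calls used above.

```python
import pandas as pd


def order_groups_by(data, group, value, stat):
    """Return a copy of data with `group` as an ordered categorical,
    ranked by `stat` applied to `value` within each group."""
    ranking = data.groupby(group)[value].agg(stat).sort_values()
    out = data.copy()
    out[group] = pd.Categorical(out[group], categories=ranking.index, ordered=True)
    # Descending sort so the group with the largest statistic comes first.
    return out.sort_values(group, ascending=False)


# Hypothetical toy data, for illustration only.
toy = pd.DataFrame(
    {
        "make": ["Ford", "Ford", "Fiat", "Fiat", "Audi", "Audi"],
        "horsepower": [80, 200, 60, 70, 90, 130],
    }
)
ordered = order_groups_by(toy, "make", "horsepower", lambda s: s.max() - s.min())
print(ordered["make"].cat.categories.to_list())  # ['Fiat', 'Audi', 'Ford']
```

The returned frame can then be passed straight to ridge_plot, since factorize numbers the groups in the order they appear in the data.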