Visualisation and Plotting of Causal Impact

Inspired by this tweet:

Been working on some visualizations in @matplotlib to highlight statistical concepts. This one intends to demonstrate (via simulation) the Central Limit Theorem: the sampling distribution of the mean is gaussian if given a large enough sample size regardless* of the population. pic.twitter.com/vvUhEERQ50

— Cameron Riddell (@RiddleMeCam) August 10, 2022

I figured I’d create some plots to highlight and explain the bootstrapped A/B tests I’ve been working on for a bit. I think it came together nicely. Do you have any favourite visualisations for showing causal impact?

from matplotlib.lines import Line2D

# Simulated demo data: 500 exponential draws, 500 Gumbel draws, and their
# element-wise difference (aligned on the default integer index).
x = pd.Series(np.random.exponential(scale=1, size=500))
x1 = pd.Series(np.random.gumbel(loc=1, scale=.6, size=500))
x2 = x.sub(x1)

def _add_row_legend(ax, title, handles=None, labels=None):
    """Attach a titled legend to ``ax``, anchored to the left of the axes.

    When ``handles`` is None the legend is built from the labelled artists
    already on ``ax``; otherwise the given handles/labels are used.
    """
    # x = -1.5 (axes coordinates) pushes the legend out to the left of the axes.
    options = dict(bbox_to_anchor=(-1.5, 1, 0, 0), fontsize=20,
                   title=title, loc='upper left')
    if handles is None:
        legend = ax.legend(**options)
    else:
        legend = ax.legend(handles, labels, **options)
    plt.setp(legend.get_title(), fontsize=25)


def _plot_sample_hist(ax, data, title, with_labels=False):
    """Draw the histogram of ``data`` with a dashed vertical line at its mean.

    ``with_labels`` adds the legend labels ('Distribution' / 'Mean') to the
    two artists; the y ticks and the top/left/right spines are hidden.
    """
    hist_kwargs = {'label': 'Distribution'} if with_labels else {}
    mean_kwargs = {'label': 'Mean'} if with_labels else {}
    ax.hist(data, edgecolor='black', color='lightskyblue', alpha=0.5,
            **hist_kwargs)
    ax.axvline(data.mean(), linestyle='--', color='black', **mean_kwargs)
    ax.set_title(title, fontsize=20)
    ax.set_yticks([])
    for side in ("top", "left", "right"):
        ax.spines[side].set_visible(False)


def _plot_bootstrap_band(ax, data, levels):
    """Draw one resample per entry of ``levels`` plus the mean/quantile band.

    ``data`` must offer a pandas-style ``sample(replace=True, n=...)``. Each
    resample has the same size as ``data``; up to its first 100 values are
    scattered in the horizontal strip that ends at that resample's level.
    """
    resamples = [data.sample(replace=True, n=len(data)).values
                 for _ in range(len(levels))]
    means = [np.mean(r) for r in resamples]
    # 2.5% / 97.5% quantiles so the shaded band matches the "95%" legend label.
    lower = [np.quantile(r, 0.025) for r in resamples]
    upper = [np.quantile(r, 0.975) for r in resamples]

    for idx, resample in enumerate(resamples):
        # The first strip has no predecessor; without this guard the index
        # idx - 1 == -1 would wrap to the last level and smear the points
        # over the whole y-range.
        strip_bottom = levels[idx - 1] if idx > 0 else levels[0]
        points = resample[:100]
        # Size the y coordinates from the slice so inputs with fewer than
        # 100 observations still plot.
        ax.plot(points, np.linspace(strip_bottom, levels[idx], len(points)),
                'o', alpha=0.3, color='slateblue')

    ax.plot(means, levels, linestyle='dashdot', color='darksalmon', linewidth=4)
    ax.plot(lower, levels, color='darksalmon', linewidth=4)
    ax.plot(upper, levels, color='darksalmon', linewidth=4)
    ax.fill_betweenx(levels, means, lower, color='darksalmon')
    ax.fill_betweenx(levels, means, upper, color='darksalmon')
    ax.set_ylim(10, 100)


def _plot_mean_distribution(ax, data, with_mean_label=False):
    """Histogram the means of 1000 resamples of ``data`` and mark their mean."""
    boot_means = [data.sample(replace=True, n=len(data)).mean()
                  for _ in range(1000)]
    mean_kwargs = {'label': "Est'd Pop Mean"} if with_mean_label else {}
    ax.hist(boot_means, edgecolor='black', color='yellow',
            label='Sampling Distribution of Means')
    ax.axvline(np.mean(boot_means), linestyle='--', color='black',
               **mean_kwargs)
    ax.set_yticks([])


def make_bootstrap_plot(x, x1, x2, titles=('Treatment', 'Control', 'Diff')):
    """Draw a 3x3 figure illustrating bootstrapped estimation for three samples.

    Columns correspond to ``x``, ``x1`` and ``x2`` (titled by ``titles``).
    Rows show, top to bottom: the original sample histogram with its mean,
    100 bootstrap resamples with a mean line and quantile band, and the
    histogram of 1000 bootstrapped means.

    Args:
        x, x1, x2: non-empty samples exposing pandas-style ``sample`` and
            ``mean`` (e.g. ``pandas.Series``).
        titles: three column titles, in the order of ``x``, ``x1``, ``x2``.

    Raises:
        ValueError: if any of the three samples is empty.

    Note:
        Calls ``np.random.seed(100)``, which resets NumPy's global random
        state, so repeated calls draw the same resamples.
    """
    samples = (x, x1, x2)
    if any(len(sample) == 0 for sample in samples):
        raise ValueError("x, x1 and x2 must each contain at least one value")

    fig, axs = plt.subplots(
        3, 3, figsize=(20, 15),
        gridspec_kw={'width_ratios': [2, 2, 2], 'height_ratios': [2, 4, 2]})
    axs = axs.flatten()
    np.random.seed(100)

    # Row 1: original samples.
    for col, sample in enumerate(samples):
        _plot_sample_hist(axs[col], sample, titles[col], with_labels=(col == 0))
    # NOTE(review): only the middle histogram has a fixed y-range; bars taller
    # than 100 are cut off - confirm this is intentional.
    axs[1].set_ylim([0, 100])
    _add_row_legend(axs[0], "Original Sample")

    # Row 2: bootstrap resamples, drawn column by column so the seeded
    # random stream is consumed in a fixed order.
    levels = np.linspace(0, 100, 100)
    for col, sample in enumerate(samples):
        _plot_bootstrap_band(axs[3 + col], sample, levels)
    # Pin tick positions before labelling them; labelling auto-located ticks
    # makes the text depend on whichever ticks the locator happens to choose.
    tick_positions = np.arange(10, 101, 10)
    axs[3].set_yticks(tick_positions)
    axs[3].set_yticklabels(
        ['Bootstrap_Sample_{i}'.format(i=i) for i in tick_positions],
        fontsize=12)
    axs[3].set_ylim(10, 100)
    axs[4].set_yticklabels([], fontsize=12)
    axs[5].set_yticklabels([], fontsize=12)
    legend_handles = [Line2D([0], [0], color='darksalmon', lw=4),
                      Line2D([0], [0], color='slateblue', lw=4)]
    _add_row_legend(axs[3], "Bootstrap Samples \n N=100",
                    handles=legend_handles,
                    labels=['Mean & 95% CI', 'Sample'])

    # Row 3: sampling distribution of the bootstrapped means. Every column
    # resamples at its own length.
    for col, sample in enumerate(samples):
        _plot_mean_distribution(axs[6 + col], sample,
                                with_mean_label=(col == 0))
    _add_row_legend(axs[6], "Bootstrapped Estimates Distribution")

    fig.suptitle("Bootstrapped Estimation for A/B Testing", fontsize=30)

# Render the figure for the simulated series defined above, using the
# default column titles.
make_bootstrap_plot(x=x, x1=x1, x2=x2)


This is SO COOL.

Thank you for sharing!

1 Like