Source code for corradjust.plots

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.ticker import MaxNLocator
from matplotlib.lines import Line2D
from adjustText import adjust_text
import numpy as np
import warnings

from corradjust.utils import *


class PCAVariancePlotter:
    """
    Two-panel figure showing the variance explained by principal components.

    The left panel shows the percentage of variance explained by each PC,
    the right panel shows the cumulative percentage.

    Parameters
    ----------
    plot_width : float, optional, default=6.4
        Width of a single panel (the figure is twice as wide).
    plot_height : float, optional, default=4.8
        Figure height.

    Attributes
    ----------
    fig : matplotlib.figure.Figure
    axs : dict
        Maps the strings ``"individual"`` and ``"cumulative"`` to the
        `matplotlib.axes.Axes` of the corresponding panel.
    """

    def __init__(self, plot_width=6.4, plot_height=4.8):
        fig, (ax_individual, ax_cumulative) = plt.subplots(
            1, 2,
            figsize=(plot_width * 2, plot_height)
        )

        panels = {"individual": ax_individual, "cumulative": ax_cumulative}
        for panel_name, panel_ax in panels.items():
            panel_ax.set_xlabel("PC")
            panel_ax.set_ylabel(f"% variance ({panel_name})")

        # Cumulative curve tops out at 100%; leave a little head room
        ax_cumulative.set_ylim([0, 105])

        self.fig = fig
        self.axs = panels

    def plot(self, PCA_model, n_PCs):
        """
        Draw both variance curves and mark the selected number of PCs.

        Parameters
        ----------
        PCA_model : sklearn.decomposition.PCA
            Fitted PCA model; its ``explained_variance_ratio_`` attribute
            is read, and every PC it contains is plotted.
        n_PCs : int
            X position of the vertical dotted red "Knee" line drawn on the
            cumulative panel (also shown in that line's legend label).
        """

        pct_variance = PCA_model.explained_variance_ratio_ * 100
        pc_numbers = np.arange(1, len(pct_variance) + 1)
        line_style = {"marker": "o", "markeredgewidth": 0}

        ax_individual = self.axs["individual"]
        sns.lineplot(x=pc_numbers, y=pct_variance, ax=ax_individual, **line_style)
        ax_individual.set_ylim([0, np.max(pct_variance) + 5])

        ax_cumulative = self.axs["cumulative"]
        sns.lineplot(x=pc_numbers, y=np.cumsum(pct_variance), ax=ax_cumulative, **line_style)
        ax_cumulative.axvline(n_PCs, ls=":", color="red", label=f"Knee (n={n_PCs})")

    def save_plot(self, out_path, title=None):
        """
        Add the legend, finalize the layout and write the figure to disk.

        `plt.close` is not called, so in a jupyter notebook the figure
        is displayed in addition to being saved.

        Parameters
        ----------
        out_path : str
            Path to the figure (with extension, e.g., ``".png"``).
        title : str or None, optional, default=None
            Short text to show at the top-left corner of the plot.
        """

        self.axs["cumulative"].legend(loc="lower right")

        if not title:
            self.fig.tight_layout()
        else:
            self.fig.text(0.01, 0.99, title, va="top", transform=self.fig.transFigure)
            # Reserve a strip at the top of the figure for the title text
            self.fig.tight_layout(rect=[0, 0, 1, 0.98])

        self.fig.savefig(out_path, dpi=300)
class GreedyOptimizationPlotter:
    """
    Create a lineplot with score optimization trajectories.

    Parameters
    ----------
    samp_group_name : str or None, optional, default=None
        Main title with sample group name to use in the plot.
    metric : {"enrichment-based", "BP@K"}, optional, default="enrichment-based"
        Metric for evaluating feature correlations.
        Only used to choose the Y-axis label.
    palette : str or list or dict, optional
        Name of matplotlib colormap or list of colors for the lines
        or dict mapping reference collection names to colors.
        The argument is directly passed to `sns.lineplot`.
    legend_loc : str, optional, default="lower right"
        Where to put the legend (follows matplotlib notation).
    legend_fontsize : int, optional, default=10
        Font size of legend text.
    plot_width : float, optional, default=6.4
        Plot width.
    plot_height : float, optional, default=4.8
        Plot height.

    Attributes
    ----------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    peak_iter : int
        Set by `plot`; position of the early stopping line.
    num_ref_feature_colls : int
        Set by `plot`; number of legend entries in the "Ref. collection" group.
    validation_pairs : bool
        Set by `plot`; ``True`` if the data contained validation columns
        (training/validation feature pairs), ``False`` otherwise
        (training/test samples).
    """

    def __init__(
        self,
        samp_group_name=None,
        metric="enrichment-based",
        palette=(lambda x: x[::-1][:2] + x[::-1][3:])(sns.color_palette("tab10")),
        legend_loc="lower right",
        legend_fontsize=10,
        plot_width=6.4,
        plot_height=4.8
    ):
        fig, ax = plt.subplots(1, 1, figsize=(plot_width, plot_height))

        ax.set_xlabel("Iteration")
        # Iterations are whole numbers, so avoid fractional ticks
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

        if metric == "enrichment-based":
            ax.set_ylabel("Average -log$_{{10}}$(adjusted p-value)")
        else:
            ax.set_ylabel("Balanced precision at K")

        if samp_group_name:
            ax.set_title(samp_group_name)

        self.palette = palette
        self.legend_loc = legend_loc
        self.legend_fontsize = legend_fontsize
        self.fig = fig
        self.ax = ax

    def plot(self, df, peak_iter):
        """
        Draw the lines.

        Parameters
        ----------
        df : pandas.DataFrame
            Data frame with scores. One can generate `df` by
            taking ``fit.tsv`` table and selecting columns
            for only one sample group. Column names are expected to look like
            ``"<sample group>;<ref. collection>;<subset>"``, with the
            ``mean`` columns placed last, and the index to be named
            ``"Iteration"``.
            See `CorrAdjust._make_best_iter_scores_plot` method for an example.
        peak_iter : int
            This number is used to draw vertical dashed red line
            at the selected early stopping iteration (the line itself
            is added in `save_plot`).
        """

        self.peak_iter = peak_iter

        # Remove sample group name from the columns
        df = df.rename(columns={col: ";".join(col.split(";")[1:]) for col in df.columns})

        # Plot is either called with training and validation (pairs)
        # or with training and test (samples)
        self.validation_pairs = "mean;validation" in df.columns
        if self.validation_pairs:
            subset_label = "Feature pairs"
            second_subset = "validation"
            # 3 columns per reference collection (training, validation, all)
            cols_per_coll = 3
        else:
            subset_label = "Samples"
            second_subset = "test"
            # 2 columns per reference collection (training, test)
            cols_per_coll = 2

        self.num_ref_feature_colls = len(df.columns) // cols_per_coll

        # Plot only training and validation/test
        kept_suffixes = (";training", f";{second_subset}")
        df = df.drop(columns=[col for col in df.columns if not col.endswith(kept_suffixes)])

        # Re-order columns so that mean goes first (it is last by default)
        df = df[df.columns[-2:].tolist() + df.columns[:-2].tolist()]

        # We don't plot mean if there is only 1 reference collection
        # self.num_ref_feature_colls already includes "mean" as a collection
        if self.num_ref_feature_colls == 2:
            df = df.iloc[:, 2:]
            self.num_ref_feature_colls = 1

        df = df.melt(var_name="metric_name", value_name="score", ignore_index=False)
        df = df.reset_index()

        name_parts = df["metric_name"].str.split(";")
        ref_coll_names = name_parts.str[0]
        # Capitalize the pseudo-collection "mean". Exact match only:
        # a substring replace would also rename real collections
        # whose names merely contain "mean" (e.g., "mean_expr")
        df["Ref. collection"] = ref_coll_names.mask(ref_coll_names == "mean", "Mean")
        df[subset_label] = name_parts.str[-1].str.capitalize()

        # Lineplot produces annoying warning that there are
        # more colors in the palette than needed; suppress it
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            sns.lineplot(
                x="Iteration", y="score",
                hue="Ref. collection", style=subset_label,
                data=df,
                palette=self.palette,
                dashes=["", (1, 1)],
                markers=["o", "X"], markeredgewidth=0, markersize=4,
                ax=self.ax
            )

        # Set highest zorder for the first drawn lines
        for draw_order, line in enumerate(self.ax.get_lines()):
            line.set_zorder(1000 - draw_order)

        iter_max = df["Iteration"].max()
        x_pad = iter_max * 0.02
        self.ax.set_xlim(-x_pad, iter_max + x_pad)

        score_max = df["score"].max()
        if score_max != 0:
            y_pad = score_max * 0.02
            self.ax.set_ylim(-y_pad, score_max + y_pad)
        else:
            # All scores are zero: fall back to a fixed range
            self.ax.set_ylim(-0.01, 1.02)

    def save_plot(self, out_path, title=None):
        """
        Save the plot. `plot` must be called first.

        This method doesn't call `plt.close`, so
        it will display the figure in jupyter notebook
        in addition to saving the file.

        Parameters
        ----------
        out_path : str
            Path to the figure (with extension, e.g., ``.png``).
        title : str or None, optional, default=None
            Short text to show at the top-left corner of the plot.
        """

        # We draw this here to make legend code easier
        # (the line's handle is then the last one)
        self.ax.axvline(
            self.peak_iter,
            ls=":", color="red",
            label=f"Iter = {self.peak_iter}"
        )

        handles, labels = self.ax.get_legend_handles_labels()
        new_handles, new_labels = [], []
        empty_handle = mpatches.Patch(color="none")
        empty_label = " "

        n_colls = self.num_ref_feature_colls
        # Every legend column has 5 rows. For reference collections the
        # first row is a header, so there are 4 collections per column
        nrows = 5
        colls_per_col = nrows - 1

        # Columns with reference collections (ceil division)
        # plus the last column, which is always for
        # training, validation/test, and early stopping
        ncols = -(-n_colls // colls_per_col) + 1

        # Legend for reference collections: header, up to 4 entries,
        # and empty padding so that an incomplete column still has 5 rows
        for start in range(0, n_colls, colls_per_col):
            stop = min(start + colls_per_col, n_colls)
            n_pad = colls_per_col - (stop - start)
            new_handles += [handles[0]] + handles[1 + start:1 + stop] + [empty_handle] * n_pad
            new_labels += [labels[0]] + labels[1 + start:1 + stop] + [empty_label] * n_pad

        # Last column: style header with its two entries,
        # then "Early stopping" header with the vertical line
        style_entries = slice(n_colls + 1, n_colls + 1 + 3)
        new_handles += handles[style_entries] + [empty_handle] + handles[-1:]
        new_labels += labels[style_entries] + ["Early stopping"] + labels[-1:]

        legend = self.ax.legend(
            new_handles, new_labels,
            ncols=ncols,
            loc=self.legend_loc,
            fontsize=self.legend_fontsize
        )
        # Keep the legend above the lines (their zorder is at most 1000)
        legend.set_zorder(2000)

        # Remove left padding for legend titles
        # first loop goes over legend columns
        for i, vpack in enumerate(legend._legend_handle_box.get_children()):
            for j, hpack in enumerate(vpack.get_children()):
                is_column_header = j == 0
                is_early_stopping_header = i == ncols - 1 and j == 3
                if is_column_header or is_early_stopping_header:
                    hpack.get_children()[0].set_width(0)

        if title:
            self.fig.text(0.01, 0.99, title, va="top", transform=self.fig.transFigure)
            # Reserve a strip at the top of the figure for the title text
            self.fig.tight_layout(rect=[0, 0, 1, 0.98])
        else:
            self.fig.tight_layout()

        self.fig.savefig(out_path, dpi=300)
class VolcanoPlotter:
    """
    Create a volcano plot with feature-wise enrichment statistics.

    Parameters
    ----------
    corr_scorer : CorrScorer
        Instance of `CorrScorer`. Its ``data`` mapping defines the
        reference collections (one row of panels each) and its
        ``metric`` selects the score shown in panel titles.
    annotate_features : int or bool, optional, default=False
        How many significant features with the lowest adjusted
        p-value to annotate. Any falsy value disables annotation.
    annot_fontsize : int, optional, default=8
        Font size of feature names when `annotate_features` is set.
    feature_name_fmt : function or None, optional, default=None
        Function that maps feature names to labels to show on the plot
        when `annotate_features` is set.
        If ``None``, shows unmodified feature names.
    signif_color : matplotlib color, optional
        Color to draw statistically significant features.
    nonsignif_color : matplotlib color, optional
        Color to draw non-significant features.
    panel_size : float, optional, default=4.8
        Size of each (square) panel.

    Attributes
    ----------
    fig : matplotlib.figure.Figure
    axs : dict
        Keys of `axs` are tuples ``(ref_feature_coll, state)``, where
        ``state`` is either ``"Raw"`` or ``"Clean"``.
        Values of `axs` are instances of `matplotlib.axes.Axes`.
    annotations : dict
        Maps ``(ref_feature_coll, state)`` to the list of text objects
        created for annotated features (filled by `plot`).
    """

    def __init__(
        self,
        corr_scorer,
        annotate_features=False,
        annot_fontsize=8,
        feature_name_fmt=None,
        signif_color=(*sns.color_palette("tab10")[1], 0.9),
        nonsignif_color=(0.6, 0.6, 0.6, 0.5),
        panel_size=4.8
    ):
        states = ["Raw", "Clean"]
        n_rows = len(corr_scorer.data)
        n_columns = len(states)

        # squeeze=False keeps the axes array 2D even for a single row,
        # so that indexing doesn't need a special case
        fig, axs_grid = plt.subplots(
            n_rows, n_columns,
            figsize=(panel_size * n_columns, panel_size * n_rows),
            squeeze=False
        )

        # We convert the 2D array to dict to have meaningful keys
        axs = {}
        for i, ref_feature_coll in enumerate(corr_scorer.data):
            for j, state in enumerate(states):
                ax = axs_grid[i, j]
                ax.set_title(f"{ref_feature_coll}, {state.lower()} corr.")
                ax.set_xlabel("Balanced precision")
                ax.set_xlim([-0.03, 1.03])
                ax.set_xticks(np.arange(0, 1.1, 0.1))
                ax.set_ylabel("$-$log$_{{10}}$(adjusted p-value)")
                axs[(ref_feature_coll, state)] = ax

        self.fig = fig
        self.axs = axs
        self.annotate_features = annotate_features
        self.annot_fontsize = annot_fontsize
        self.feature_name_fmt = feature_name_fmt
        self.annotations = {}
        self.metric = corr_scorer.metric
        self.signif_color = signif_color
        self.nonsignif_color = nonsignif_color
        self.min_marker_size = 10  # seaborn default 18
        self.max_marker_size = 100  # seaborn default 72

    @staticmethod
    def _parse_num_pairs(df):
        """Extract the integer after ``/`` from the ``ref_pairs@K`` column."""

        return df["ref_pairs@K"].str.split("/").str[1].astype("int64")

    def plot(self, feature_scores):
        """
        Create a volcano plot.

        Parameters
        ----------
        feature_scores : dict
            Dict with feature-wise enrichment scores generated by
            the `CorrAdjust.compute_feature_scores` method.
            ``feature_scores["Clean"]`` may be ``None``, in which case
            only the ``"Raw"`` panels are filled with points.
        """

        signif_threshold = -np.log10(0.05)
        scores_by_state = {
            state: scores
            for state, scores in (("Raw", feature_scores["Raw"]), ("Clean", feature_scores["Clean"]))
            if scores is not None
        }

        # We need to have identical y-axis limits and legends for numbers of
        # pairs between raw and clean corrs
        num_pairs_max = {}
        for ref_feature_coll in feature_scores["Raw"]:
            padj_maxima, num_pairs_maxima = [], []
            for scores in scores_by_state.values():
                df_state = scores[ref_feature_coll].dropna()
                padj_maxima.append(np.max(-np.log10(df_state["padj"])))
                num_pairs_maxima.append(np.max(self._parse_num_pairs(df_state)))

            padj_max = max(padj_maxima)
            num_pairs_max[ref_feature_coll] = max(num_pairs_maxima)

            if padj_max != 0:
                # So the points won't stick to ax bottom/top
                y_pad = padj_max * 0.03
            else:
                y_pad = 1
            y_top = padj_max + y_pad

            for state in ["Raw", "Clean"]:
                ax = self.axs[(ref_feature_coll, state)]
                ax.set_ylim([-y_pad, y_top])

                # Plot boundaries of significance
                # (only if at least one feature is above the threshold)
                if padj_max > signif_threshold:
                    ax.fill_between(
                        [0.5, 1.03],
                        [y_top, y_top],
                        [signif_threshold, signif_threshold],
                        facecolor=self.signif_color, alpha=0.05, zorder=0
                    )
                    ax.plot(
                        [0.5, 1.03],
                        [signif_threshold, signif_threshold],
                        color=self.signif_color, alpha=0.5, lw=1., zorder=0
                    )
                    ax.plot(
                        [0.5, 0.5],
                        [signif_threshold, y_top],
                        color=self.signif_color, alpha=0.5, lw=1.0, zorder=0
                    )

        for state, scores in scores_by_state.items():
            self._plot_one_state(state, scores, num_pairs_max)

    def _plot_one_state(
        self,
        state,
        feature_scores_state,
        num_pairs_max
    ):
        """
        Make the plot for ``"Raw"`` or ``"Clean"`` data.

        Parameters
        ----------
        state : {"Raw", "Clean"}
        feature_scores_state : {feature_scores["Raw"], feature_scores["Clean"]}
        num_pairs_max : dict
            Maps each reference collection name to the maximum number
            for ``K_j``, used to scale the markers and display in legend.
        """

        for ref_feature_coll in feature_scores_state:
            ax = self.axs[(ref_feature_coll, state)]
            coll_num_pairs_max = num_pairs_max[ref_feature_coll]

            df_plot = feature_scores_state[ref_feature_coll].dropna().copy()
            # Significance is decided on the untransformed p-values:
            # converting to -log10 and back may flip features
            # sitting at the 0.05 boundary due to floating-point error
            is_signif = df_plot["padj"] <= 0.05
            df_plot["padj"] = -np.log10(df_plot["padj"])
            df_plot["num_pairs"] = self._parse_num_pairs(df_plot)

            if self.metric == "enrichment-based":
                avg_log_padj = df_plot["padj"].mean()
                score_label = f"{avg_log_padj:.2f}"
            else:
                agg_BP_at_K, _, _ = compute_aggregated_scores(df_plot)
                score_label = f"{agg_BP_at_K:.2f}"

            perc_signif = is_signif.sum() / len(df_plot) * 100
            signif_label = f"Yes ({np.round(perc_signif, 1)}%)"
            nonsignif_label = f"No ({np.round(100 - np.round(perc_signif, 1), 1)}%)"
            df_plot["Adj. p ≤ 0.05"] = is_signif.map({True: signif_label, False: nonsignif_label})

            sns.scatterplot(
                x="balanced_precision", y="padj",
                hue="Adj. p ≤ 0.05", size="num_pairs",
                data=df_plot,
                hue_order=[signif_label, nonsignif_label],
                palette={
                    signif_label: self.signif_color,
                    nonsignif_label: self.nonsignif_color
                },
                sizes=(self.min_marker_size, self.max_marker_size),
                size_norm=(0, coll_num_pairs_max),
                ax=ax
            )
            ax.set_title(ax.get_title() + f", score = {score_label}")

            # Fix legend
            empty_handle = mpatches.Patch(color="none")
            empty_label = " "

            handles, labels = ax.get_legend_handles_labels()
            # Keep the legend part about statistical significance
            handles, labels = handles[:3], labels[:3]

            # Empty line
            handles.append(empty_handle)
            labels.append(empty_label)

            # Title
            handles.append(empty_handle)
            labels.append("# highly ranked\npairs ($K_j$)")

            # Markers for the smallest and the largest point size
            # (scatter sizes are areas, Line2D marker sizes are lengths)
            handles += [
                Line2D(
                    [], [],
                    markersize=np.sqrt(marker_area), markeredgewidth=1,
                    color="black", markeredgecolor="white",
                    linestyle="None", marker="o"
                )
                for marker_area in (self.min_marker_size, self.max_marker_size)
            ]
            labels += ["0", str(coll_num_pairs_max)]

            legend = ax.legend(handles, labels, loc="upper left")

            # Remove left padding for legend titles
            # first loop is just 1 iteration since legend is 1 column
            for vpack in legend._legend_handle_box.get_children():
                for j, hpack in enumerate(vpack.get_children()):
                    if j in {0, 4}:
                        hpack.get_children()[0].set_width(0)

            if self.annotate_features:
                # Annotate significant features with the lowest padj.
                # The column holds -log10(padj), hence the descending order;
                # stable sort keeps the input order among ties
                df_top_features = (
                    df_plot.loc[is_signif]
                    .sort_values("padj", ascending=False, kind="stable")
                    .iloc[:self.annotate_features]
                )

                annotations = []
                for _, row in df_top_features.iterrows():
                    text = row["feature_name"]
                    if self.feature_name_fmt:
                        text = self.feature_name_fmt(text)

                    annotations.append(ax.text(
                        x=row["balanced_precision"],
                        y=row["padj"],
                        s=text,
                        ha="center", va="center",
                        fontsize=self.annot_fontsize
                    ))

                self.annotations[(ref_feature_coll, state)] = annotations

    def save_plot(self, out_path, title=None):
        """
        Save the plot.

        This method doesn't call `plt.close`, so
        it will display the figure in jupyter notebook
        in addition to saving the file.

        Parameters
        ----------
        out_path : str
            Path to the figure (with extension, e.g., ``.png``).
        title : str or None, optional, default=None
            Short text to show at the top-left corner of the plot.
        """

        if title:
            self.fig.text(0.01, 0.99, title, va="top", transform=self.fig.transFigure)
            # Reserve a strip at the top of the figure for the title text
            self.fig.tight_layout(rect=[0, 0, 1, 0.98])
        else:
            self.fig.tight_layout()

        # adjust_text should be called after everything else is done
        if self.annotate_features:
            for panel_key, panel_annotations in self.annotations.items():
                adjust_text(
                    panel_annotations,
                    ax=self.axs[panel_key],
                    arrowprops=dict(arrowstyle="-", color="gray", alpha=0.5, lw=0.5, shrinkA=1, shrinkB=1),
                    explode_radius=10
                )

        self.fig.savefig(out_path, dpi=300)
class CorrDistrPlotter:
    """
    Create KDE and CDF plots of correlations.

    Parameters
    ----------
    corr_scorer : CorrScorer
        Instance of `CorrScorer`. For each reference collection in its
        ``data`` mapping, the ``"sign"`` and ``"high_corr_frac"``
        entries are read.
    pairs_subset : {"all", "training", "validation"}, optional, default="all"
        Which set of feature pairs to use for computing scores.
    color_raw_ref : color, optional, default=sns.color_palette("tab20")[6]
        Color for raw correlations, reference feature pairs.
    color_raw_non_ref : color, optional, default=sns.color_palette("tab20")[7]
        Color for raw correlations, non-reference feature pairs.
    color_clean_ref : color, optional, default=sns.color_palette("tab20")[4]
        Color for clean correlations, reference feature pairs.
    color_clean_non_ref : color, optional, default=sns.color_palette("tab20")[5]
        Color for clean correlations, non-reference feature pairs.
    legend_fontsize : int, optional, default=10
        Font size of legend text.
    panel_size : float, optional, default=4.8
        Size of each (square) panel.

    Attributes
    ----------
    fig : matplotlib.figure.Figure
    axs : dict
        Keys of `axs` are tuples ``(ref_feature_coll, plot_name)``, where
        ``plot_name`` is either ``"corr-KDE"`` or ``"corr-CDF"``.
        Values of `axs` are instances of `matplotlib.axes.Axes`.
    """

    def __init__(
        self,
        corr_scorer,
        pairs_subset="all",
        color_raw_ref=sns.color_palette("tab20")[6],
        color_raw_non_ref=sns.color_palette("tab20")[7],
        color_clean_ref=sns.color_palette("tab20")[4],
        color_clean_non_ref=sns.color_palette("tab20")[5],
        legend_fontsize=10,
        panel_size=4.8
    ):
        plot_names = ["corr-KDE", "corr-CDF"]
        n_rows = len(corr_scorer.data)
        n_columns = len(plot_names)

        # squeeze=False keeps the axes array 2D even for a single row,
        # so that indexing doesn't need a special case
        fig, axs_grid = plt.subplots(
            n_rows, n_columns,
            figsize=(panel_size * n_columns, panel_size * n_rows),
            squeeze=False
        )

        # We convert the 2D array to dict to have meaningful keys
        axs = {
            (ref_feature_coll, plot_name): axs_grid[i, j]
            for i, ref_feature_coll in enumerate(corr_scorer.data)
            for j, plot_name in enumerate(plot_names)
        }

        signs = {}
        for ref_feature_coll, coll_data in corr_scorer.data.items():
            sign = coll_data["sign"]
            signs[ref_feature_coll] = sign
            high_corr_frac = coll_data["high_corr_frac"]

            # KDE panel always shows signed correlations
            ax_kde = axs[(ref_feature_coll, "corr-KDE")]
            ax_kde.set_title(ref_feature_coll)
            ax_kde.set_xlabel("Correlation")
            ax_kde.set_xlim(-1.05, 1.05)
            ax_kde.set_ylabel("Density")

            # CDF panel shows absolute values when sign is "absolute"
            if sign == "absolute":
                corr_label = "Absolute correlation"
                corr_lim = [-0.05, 1.05]
            else:
                corr_label = "Correlation"
                corr_lim = [-1.05, 1.05]

            ax_cdf = axs[(ref_feature_coll, "corr-CDF")]
            ax_cdf.set_title(ref_feature_coll)
            ax_cdf.set_xlabel(corr_label)
            ax_cdf.set_xlim(*corr_lim)
            ax_cdf.set_ylabel("Cumulative fraction of feature pairs")
            ax_cdf.set_yscale("log")
            ax_cdf.axhline(
                high_corr_frac,
                ls=":", color="grey",
                label="Highly ranked pairs"
            )

        self.legend_fontsize = legend_fontsize
        # Keys are (state, ref_flag); ref_flag 1 means reference pairs
        self.colors = {
            ("Raw", 0): color_raw_non_ref,
            ("Raw", 1): color_raw_ref,
            ("Clean", 0): color_clean_non_ref,
            ("Clean", 1): color_clean_ref
        }
        self.fig = fig
        self.axs = axs
        self.pairs_subset = pairs_subset
        self.signs = signs

    def add_plots(
        self,
        corr_scores,
        state,
        num_points=100000
    ):
        """
        Make KDE and eCDF plots.

        Parameters
        ----------
        corr_scores : dict
            Results of `CorrAdjust.compute_corr_scores` method.
            For each reference collection, the ``"corrs"``, ``"mask"``
            and ``"train_val_mask"`` arrays are used.
        state : {"Raw", "Clean"}
        num_points : int, optional, default=100000
            How many correlations to sample for plotting.
        """

        for ref_feature_coll in corr_scores:
            corrs = corr_scores[ref_feature_coll]["corrs"]
            mask = corr_scores[ref_feature_coll]["mask"]
            train_val_mask = corr_scores[ref_feature_coll]["train_val_mask"]

            # Limit to training/validation pairs if needed
            if self.pairs_subset == "training":
                corrs = corrs[train_val_mask == 0]
                mask = mask[train_val_mask == 0]
            elif self.pairs_subset == "validation":
                corrs = corrs[train_val_mask == 1]
                mask = mask[train_val_mask == 1]

            for ref_flag in [1, 0]:
                corrs_subset = corrs[mask == ref_flag]
                n_corrs = len(corrs_subset)
                if n_corrs == 0:
                    # Nothing to draw; the CDF normalization below
                    # would fail on an empty array
                    continue

                color = self.colors[(state, ref_flag)]
                linestyle = "-" if ref_flag == 1 else "--"
                label = f"{state}, {'ref. pairs' if ref_flag == 1 else 'non-ref. pairs'}"
                needs_downsampling = n_corrs >= num_points

                # Make KDE plot
                # Downsample corrs to allow plotting in adequate time
                if needs_downsampling:
                    idx_grid = np.linspace(0, n_corrs - 1, num=num_points, dtype=int)
                    # dtype=int might cause duplicates because of rounding - kill them
                    idx_grid = np.unique(idx_grid)
                    corrs_for_KDE = corrs_subset[idx_grid]
                else:
                    corrs_for_KDE = corrs_subset

                sns.kdeplot(
                    x=corrs_for_KDE,
                    label=label,
                    color=color,
                    linestyle=linestyle,
                    ax=self.axs[(ref_feature_coll, "corr-KDE")]
                )

                # Make CDF plot
                # NOTE(review): the i-th value is plotted against fraction
                # (i + 1) / n, which presumes corrs are already ordered
                # by rank - confirm against `compute_corr_scores`
                ranks = np.arange(1, n_corrs + 1)
                fractions = ranks / np.max(ranks)

                # Downsample corrs to allow plotting in adequate time
                # Since Y axis is log, we use logspace of indices
                if needs_downsampling:
                    idx_grid = np.logspace(0, np.log10(n_corrs), num=num_points, base=10, dtype=int) - 1
                    # dtype=int might cause duplicates because of rounding - kill them
                    idx_grid = np.unique(idx_grid)
                    corrs_for_CDF = corrs_subset[idx_grid]
                    fractions = fractions[idx_grid]
                else:
                    corrs_for_CDF = corrs_subset

                if self.signs[ref_feature_coll] == "absolute":
                    corrs_for_CDF = np.abs(corrs_for_CDF)

                self.axs[(ref_feature_coll, "corr-CDF")].plot(
                    corrs_for_CDF, fractions,
                    label=label,
                    color=color,
                    linestyle=linestyle
                )

    def save_plot(self, out_path, title=None):
        """
        Save the plot.

        This method doesn't call `plt.close`, so
        it will display the figure in jupyter notebook
        in addition to saving the file.

        Parameters
        ----------
        out_path : str
            Path to the figure (with extension, e.g., ``.png``).
        title : str or None, optional, default=None
            Short text to show at the top-left corner of the plot.
        """

        for (ref_feature_coll, plot_name), ax in self.axs.items():
            sign = self.signs[ref_feature_coll]
            handles, labels = ax.get_legend_handles_labels()

            if plot_name == "corr-CDF":
                # Put K last (its line is created first, in __init__)
                handles = handles[1:] + [handles[0]]
                labels = labels[1:] + [labels[0]]
                loc = "lower right" if sign == "negative" else "lower left"
            elif plot_name == "corr-KDE":
                loc = "upper left"

            legend = ax.legend(handles, labels, loc=loc, fontsize=self.legend_fontsize)

            # For KDE plot, we increase ylim to fit the legend on top
            if plot_name == "corr-KDE":
                bbox = legend.get_window_extent()
                # Legend height as a fraction of the axes height
                legend_frac = bbox.transformed(ax.transAxes.inverted()).height
                y_min, y_max = ax.get_ylim()
                # Rescale so that the curves occupy only the part
                # of the panel below the legend
                y_max_new = (y_max - y_min) / (1 - legend_frac)
                ax.set_ylim([-y_max_new * 0.01, y_max_new])

        if title:
            self.fig.text(0.01, 0.99, title, va="top", transform=self.fig.transFigure)
            # Reserve a strip at the top of the figure for the title text
            self.fig.tight_layout(rect=[0, 0, 1, 0.98])
        else:
            self.fig.tight_layout()

        self.fig.savefig(out_path, dpi=300)