Skip to content

Technical Reference

Data Import

Defines the central data type along with importing logic.

LimeSurveyData(structure_file, responses_file)

Base LimeSurvey class.

Parameters:

  • structure_file (Path) –

    path to the structure XML file

  • responses_file (Path) –

    path to the responses CSV file

Source code in src/survey_framework/data_import/data_import.py
def __init__(
    self,
    structure_file: Path,
    responses_file: Path,
) -> None:
    """Initialize an instance of the Survey.

    Args:
        structure_file: path to the structure XML file
        responses_file: path to the responses CSV file
    """
    # NOTE(review): despite the old "store path" comment, no path is stored;
    # _read_structure/_read_responses presumably parse the files and populate
    # instance state (questions/responses/sections) -- confirm in their defs.
    self._read_structure(structure_file)
    self._read_responses(responses_file)

__str__()

Print all questions, responses and sections for debugging.

Source code in src/survey_framework/data_import/data_import.py
def __str__(self) -> str:
    """Print all questions, responses and sections for debugging."""
    sections = [
        f"QUESTIONS\n{self.questions}",
        f"RESPONSES\n{self.responses}",
        f"SECTIONS\n{self.sections}",
    ]
    return "\n".join(sections) + "\n"

export_Qs_to_CSV(output_path)

Export the question sheet from the survey to CSV.

Parameters:

  • output_path (Path) –

    output path to where CSV is saved

Source code in src/survey_framework/data_import/data_import.py
def export_Qs_to_CSV(self, output_path: Path) -> None:
    """Export the question sheet from the survey to CSV.

    Args:
        output_path: output directory in which `Q.csv` is saved
    """
    # ensure the target directory exists before writing
    output_path.mkdir(parents=True, exist_ok=True)

    # `output_path / "Q.csv"` is already a Path; the old extra
    # `Path(...)` wrapper was redundant
    output = output_path / "Q.csv"
    self.questions.to_csv(output)

get_choices(question)

Get choices of a question.

  • For multiple-choice group, format is <subquestion code: subquestion title>, for example, {"C3_SQ001": "I do not like scientific work.", "C3_SQ002": ...}
  • For all other fixed questions (i.e. array, single choice, subquestion), returns choices of that question or column
  • For free and contingent, returns None

Parameters:

  • question (str) –

    Name of question or subquestion to retrieve

Returns:

Source code in src/survey_framework/data_import/data_import.py
def get_choices(self, question: str) -> dict[str, str]:
    """Get choices of a question.

    * For multiple-choice group, format is `<subquestion code: subquestion title>`,
      for example, {"C3_SQ001": "I do not like scientific work.", "C3_SQ002": ...}
    * For all other fixed questions (i.e. array, single choice, subquestion),
      returns choices of that question or column
    * For free and contingent, returns None

    Args:
        question: Name of question or subquestion to retrieve

    Returns:
        dict of choices mappings
    """
    # look only at the non-contingent rows of this question
    info = self.get_question(question)
    info = info[~info.is_contingent]
    qtype = self.get_question_type(question)

    # a multi-row multiple-choice group: flatten the nested per-row dicts
    # into one {subquestion code: choice text} mapping
    if qtype == QuestionType.MULTIPLE_CHOICE and info.shape[0] > 1:
        return {
            cast(str, code): row.choices["Y"] for code, row in info.iterrows()
        }

    # single-choice, free, individual subquestion, or array
    return info.choices.iloc[0]

get_question(question, drop_other=False)

Get question structure (i.e. subset from self.questions).

Parameters:

  • question (str) –

    Name of question or subquestion

  • drop_other (bool, default: False ) –

    Whether to exclude contingent question (i.e. "other")

Raises:

  • ValueError

    There is no such question or subquestion

Returns:

  • DataFrame

    pd.DataFrame: Subset from self.questions with corresponding rows

Source code in src/survey_framework/data_import/data_import.py
def get_question(self, question: str, drop_other: bool = False) -> pd.DataFrame:
    """Get question structure (i.e. subset from self.questions).

    Args:
        question: Name of question or subquestion
        drop_other: Whether to exclude contingent question (i.e. "other")

    Raises:
        ValueError: There is no such question or subquestion
    Returns:
        pd.DataFrame: Subset from `self.questions` with corresponding rows
    """
    # match either the whole question group or an individual subquestion code
    in_group = self.questions["question_group"] == question
    exact_code = self.questions.index == question
    subset = self.questions[in_group | exact_code]

    if subset.empty:
        raise ValueError(f"Unexpected question code '{question}'")

    if drop_other:
        subset = subset[~subset.is_contingent]

    return subset

get_question_type(question)

Get question type and validate it.

Parameters:

  • question (str) –

    question or column code

Raises:

  • AssertionError –

    Inconsistent question types within question

  • ValueError –

    Unexpected question type

Returns:

  • QuestionType ( QuestionType ) –

    Question type like "single-choice", "array", etc.

Source code in src/survey_framework/data_import/data_import.py
def get_question_type(self, question: str) -> QuestionType:
    """Get question type and validate it.

    Args:
        question: question or column code

    Raises:
        AssertionError: Inconsistent question types within question
        ValueError: Unexpected question type

    Returns:
        QuestionType: Question type like "single-choice", "array", etc.
    """
    # all rows belonging to one question group must share the same type
    unique_types = self.get_question(question).type.unique()

    if len(unique_types) > 1:
        raise AssertionError(
            f"Question {question} has multiple types {list(unique_types)}."
        )

    # the enum constructor raises ValueError for unknown type strings
    return QuestionType(unique_types[0])

get_questions_by_type(type)

Get a list of all questions with the given QuestionType.

Parameters:

  • type (QuestionType) –

    Desired QuestionType, e.g. SINGLE_CHOICE.

Returns:

  • list[str]

    Question IDs for all matching questions.

Source code in src/survey_framework/data_import/data_import.py
def get_questions_by_type(self, type: QuestionType) -> list[str]:
    """Get a list of all questions with the given QuestionType.

    Args:
        type: Desired QuestionType, e.g. SINGLE_CHOICE.

    Returns:
        Question IDs for all matching questions.
    """
    # `unique()` keeps order of appearance and `tolist()` already yields a
    # Python list, so the previous extra `list(...)` wrapper was redundant
    matching = self.questions.loc[self.questions["type"] == type.value]
    return matching["question_group"].unique().tolist()

get_responses(question, drop_other=False)

Get responses for given question with or without contingent questions.

Parameters:

  • question (str) –

    Question to get the responses for.

  • drop_other (bool, default: False ) –

    Whether to exclude contingent question (i.e. "other")

Raises:

  • ValueError

    Inconsistent question types within question groups.

  • ValueError

    Unknown question types.

Returns:

  • DataFrame

    The response data for the selected question.

Source code in src/survey_framework/data_import/data_import.py
def get_responses(
    self,
    question: str,
    drop_other: bool = False,
) -> pd.DataFrame:
    """Get responses for given question with or without contingent questions.

    Args:
        question: Question to get the responses for.
        drop_other: Whether to exclude contingent question (i.e. "other")

    Raises:
        ValueError: Inconsistent question types within question groups.
        ValueError: Unknown question types.

    Returns:
        The response data for the selected question.
    """
    group = self.get_question(question, drop_other=drop_other)
    qtype = self.get_question_type(question)

    # select the response columns belonging to this question group
    responses = self.responses.loc[:, list(group.index)]

    if qtype == QuestionType.MULTIPLE_CHOICE:
        # ASSUME: question response consists of multiple columns with
        #         'Y' or NaN as entries.
        # Convert only the non-contingent columns to booleans
        # (answered vs. not answered); contingent free-text stays as-is.
        fixed = ~group.is_contingent
        responses[group.index[fixed]] = responses.loc[:, fixed].notnull()

    assert isinstance(responses, pd.DataFrame)
    return responses

query(expr)

Filter responses DataFrame with a boolean expression.

Parameters:

  • expr (str) –

    Condition str for pd.DataFrame.query(). E.g. "A6 == 'A3' & "B2 == 'A5'"

Returns:

  • DataFrame

    pd.DataFrame: Filtered responses

Source code in src/survey_framework/data_import/data_import.py
def query(self, expr: str) -> pd.DataFrame:
    """Filter responses DataFrame with a boolean expression.

    Args:
        expr: Condition str for pd.DataFrame.query().
            E.g. "A6 == 'A3' & B2 == 'A5'"

    Returns:
        pd.DataFrame: Filtered responses
    """
    filtered = self.responses.query(expr)
    return filtered

QuestionType

Bases: StrEnum

Each type of question has a distinct data format.

Data Aggregation

Convert raw dataframes into cleaned, sorted, counted dataframes ready for plotting.

We have specific functions for single-choice and multiple-choice questions, as well as "grouped" variants for both (which can be used for comparison barplots).

prepare_df_comparison(responses_df_all, responses_df_comparison, q, q_comparison, ordering)

Compare groups of participants (determined by comparison_series).

This function is for single-choice questions.

The output dataframe contains the following columns
  • q: The answer options
  • q_comparison: The groups
  • count: number of participants (in this group) that gave this answer
  • proportion: share of participants (within this group) that gave this answer

(Group sizes are not a column; they are returned as the second element of the tuple.)

Parameters:

  • responses_df_all (DataFrame) –

    DataFrame with answers for the base question

  • responses_df_comparison (Series[str]) –

    Answers for the intersecting question

  • q (str) –

    name of the output column for answer options

  • q_comparison (str) –

    name of the output column for groups

  • ordering (dict[str, list[str]]) –

    Answer re-ordering dict, e.g. ORDER from order/order2024.py

Returns:

  • tuple[DataFrame, dict[Hashable, int]]

    Tuple of [DataFrame, group size dict]. The latter is used as N in plots.

Source code in src/survey_framework/data_analysis/count_responses.py
def prepare_df_comparison(
    responses_df_all: pd.DataFrame,
    responses_df_comparison: "pd.Series[str]",
    q: str,
    q_comparison: str,
    ordering: dict[str, list[str]],
) -> tuple[pd.DataFrame, dict[Hashable, int]]:
    """Compare groups of participants (determined by comparison_series).

    This function is for single-choice questions.

    The output dataframe contains the following columns:
        - q: The answer options
        - q_comparison: The groups
        - count: number of participants (in this group) that gave this answer
        - proportion: share of participants (within this group) that gave
          this answer

    Note: group sizes are not a column of the DataFrame; they are returned
    as the second element of the tuple.

    Args:
        responses_df_all: DataFrame with answers for the base question
        responses_df_comparison: Answers for the intersecting question
        q: name of the output column for answer options
        q_comparison: name of the output column for groups
        ordering: Answer re-ordering dict, e.g. ORDER from `order/order2024.py`

    Returns:
        Tuple of [DataFrame, group size dict]. The latter is used as N in plots.
    """
    # responses must be indexed by participant; an explicit "id" column
    # would collide with the index-based join below
    assert "id" not in responses_df_all.columns
    responses_joined = responses_df_all.join(responses_df_comparison)

    # per group: absolute counts and normalized proportions per answer
    grouped_by_center = responses_joined.groupby(q_comparison, observed=False)[q]
    responses_df_counts = pd.concat(
        [
            grouped_by_center.value_counts(normalize=True).rename("proportion"),
            grouped_by_center.value_counts().rename("count"),
        ],
        axis=1,
    ).reset_index()

    # sort DF: an ordered Categorical makes sort_values respect the given order
    order_left = ordering.get(q)
    if order_left:
        responses_df_counts[q] = pd.Categorical(
            responses_df_counts[q], categories=order_left, ordered=True
        )
    order_right = ordering.get(q_comparison)
    if order_right:
        responses_df_counts[q_comparison] = pd.Categorical(
            responses_df_counts[q_comparison], categories=order_right, ordered=True
        )
    responses_df_counts_sorted = responses_df_counts.sort_values(by=[q_comparison, q])

    return responses_df_counts_sorted, grouped_by_center.count().to_dict()

prepare_df_comparison_multiple(responses_df, comparison_series, q, q_comparison, ordering)

Compare groups of participants (determined by comparison_series).

This function is for multiple-choice questions.

The output dataframe contains the following columns
  • q: The answer options
  • q_comparison: The groups
  • total: total number of participants in this group
  • count: number of participants (in this group) that gave this answer
  • proportion: share of participants (relative to "total") that gave this answer

Parameters:

  • responses_df (DataFrame) –

    The main DataFrame of answers

  • comparison_series (Series[str]) –

    Participant group (shares index with the main DF)

  • q (str) –

    name of the output column for answer options

  • q_comparison (str) –

    name of the output column for groups

  • ordering (dict[str, list[str]]) –

    Answer re-ordering dict, e.g. ORDER from order/order2024.py

Returns:

  • tuple[DataFrame, dict[Hashable, int]]

    Tuple of [DataFrame, group size dict]. The latter is used as N in plots.

Source code in src/survey_framework/data_analysis/count_responses.py
def prepare_df_comparison_multiple(
    responses_df: pd.DataFrame,
    comparison_series: "pd.Series[str]",
    q: str,
    q_comparison: str,
    ordering: dict[str, list[str]],
) -> tuple[pd.DataFrame, dict[Hashable, int]]:
    """Compare groups of participants (determined by comparison_series).

    This function is for multiple-choice questions.

    The output dataframe contains the following columns:
        - q: The answer options
        - q_comparison: The groups
        - total: total number of participants in this group
        - count: number of participants (in this group) that gave this answer
        - proportion: share of participants (relative to "total") that gave
          this answer

    Args:
        responses_df: The main DataFrame of answers
        comparison_series: Participant group (shares index with the main DF)
        q: name of the output column for answer options
        q_comparison: name of the output column for groups
        ordering: Answer re-ordering dict, e.g. ORDER from `order/order2024.py`

    Returns:
        Tuple of [DataFrame, group size dict]. The latter is used as N in plots.
    """
    # work on a copy so the caller's DataFrame is not mutated by the
    # "total" column added below
    responses_df = responses_df.copy()
    # boolean value: participants who answered anything (summed up per group later)
    responses_df["total"] = responses_df.sum(axis="columns").gt(0)
    # melt into long form, merge with comparison
    responses_melt = pd.melt(
        responses_df.reset_index(),
        id_vars=["id", "total"],
        var_name=q,
        value_name="count",
    ).join(comparison_series, on="id")

    # for each subquestion, count `True` values, and normalize per group
    responses_counts = (
        responses_melt.groupby([q_comparison, q]).sum().drop(columns=["id"])
    )
    responses_counts["proportion"] = (
        responses_counts["count"] / responses_counts["total"]
    )

    # re-index for sorting
    responses_clean = responses_counts.reset_index()

    # get the number of participants per group in q_comparison;
    # deduplicate by *group*, not by value -- otherwise two groups with the
    # same participant count would collapse into a single dict entry
    participants = responses_clean.drop_duplicates(subset=q_comparison).set_index(
        q_comparison
    )["total"]

    # ordering (copied from `prepare_df_comparison` above)
    order_left = ordering.get(q)
    if order_left:
        responses_clean[q] = pd.Categorical(
            responses_clean[q], categories=order_left, ordered=True
        )
    order_right = ordering.get(q_comparison)
    if order_right:
        responses_clean[q_comparison] = pd.Categorical(
            responses_clean[q_comparison], categories=order_right, ordered=True
        )
    responses_sort = responses_clean.sort_values(by=[q_comparison, q])

    return responses_sort, participants.astype(int).to_dict()

prepare_df_multiple(data, q, ordering)

Count participants in the data. This function is for multiple-choice questions.

The output dataframe contains the following columns
  • q: The answer options
  • count: number of participants (in this group) that gave this answer
  • proportion: share of participants (relative to "total") that gave this answer

Parameters:

  • data (DataFrame) –

    The main DataFrame of answers

  • q (str) –

    name of the output column for answer options

  • ordering (dict[str, list[str]]) –

    Answer re-ordering dict, e.g. ORDER from order/order2024.py

Returns:

  • tuple[DataFrame, int]

    Tuple of [DataFrame, participant number]. The latter is used as N in plots.

Source code in src/survey_framework/data_analysis/count_responses.py
def prepare_df_multiple(
    data: pd.DataFrame, q: str, ordering: dict[str, list[str]]
) -> tuple[pd.DataFrame, int]:
    """Count participants in the data. This function is for multiple-choice questions.

    The output dataframe contains the following columns:
        - q: The answer options
        - total: number of participants (same value in every row)
        - count: number of participants (in this group) that gave this answer
        - proportion: share of participants (relative to "total") that gave
          this answer

    Args:
        data: The main DataFrame of answers
        q: name of the output column for answer options
        ordering: Answer re-ordering dict, e.g. ORDER from `order/order2024.py`

    Returns:
        Tuple of [DataFrame, participant number]. The latter is used as N in plots.
    """
    # work on a copy so the caller's DataFrame is not mutated by the
    # "total" column added below
    data = data.copy()
    # boolean value: participants who answered anything (summed up later)
    data["total"] = data.sum(axis="columns").gt(0)
    # melt into long form
    responses_melted = pd.melt(data, id_vars=["total"], value_name="count", var_name=q)
    responses_counts = responses_melted.groupby(q).sum()

    # add percentages column
    responses_counts["proportion"] = (
        responses_counts["count"] / responses_counts["total"]
    )

    # re-index for sorting
    responses_clean = responses_counts.reset_index()
    # get number of participants
    participants = responses_clean["total"].drop_duplicates().iloc[0]

    # sort the DF
    orderlist = ordering.get(q)
    if orderlist:
        # sort with given order
        responses_clean[q] = pd.Categorical(
            responses_clean[q], categories=orderlist, ordered=True
        )
        responses_sorted = responses_clean.sort_values(by=q)
    else:
        # no order given, sort by descending values
        responses_sorted = responses_clean.sort_values(by="count", ascending=False)
        # TODO: sorting by value is unstable between centers. We probably want
        #       to define a fixed order for all questions.

    return responses_sorted, participants

prepare_df_single(data, q, ordering)

Count participants in the data. This function is for single-choice questions.

The output dataframe contains the following columns
  • q: The answer options
  • count: number of participants (in this group) that gave this answer
  • proportion: share of participants (relative to "total") that gave this answer

Parameters:

  • data (DataFrame) –

    The main DataFrame of answers

  • q (str) –

    name of the output column for answer options

  • ordering (dict[str, list[str]]) –

    Answer re-ordering dict, e.g. ORDER from order/order2024.py

Returns:

  • tuple[DataFrame, int]

    Tuple of [DataFrame, participant number]. The latter is used as N in plots.

Source code in src/survey_framework/data_analysis/count_responses.py
def prepare_df_single(
    data: pd.DataFrame, q: str, ordering: dict[str, list[str]]
) -> tuple[pd.DataFrame, int]:
    """Count participants in the data. This function is for single-choice questions.

    The output dataframe contains the following columns:
        - q: The answer options
        - count: number of participants (in this group) that gave this answer
        - proportion: share of participants (relative to "total") that gave
          this answer

    Args:
        data: The main DataFrame of answers
        q: name of the output column for answer options
        ordering: Answer re-ordering dict, e.g. ORDER from `order/order2024.py`

    Returns:
        Tuple of [DataFrame, participant number]. The latter is used as N in plots.
    """
    # an explicit "id" column would break the rename below
    assert "id" not in data.columns
    # number of participants who answered this question (non-null entries)
    n_participants = data.count().iloc[0]

    # need to reset the index, otherwise count returns an empty DF.
    counts = (
        data.reset_index()
        .groupby(q, observed=False)
        .count()
        .rename(columns={"id": "count"})
        .reset_index()
    )

    # apply the configured answer order, if any
    answer_order = ordering.get(q)
    if answer_order:
        counts[q] = pd.Categorical(counts[q], categories=answer_order, ordered=True)
        counts = counts.sort_values(by=q)

    # add percentages column
    counts["proportion"] = counts["count"] / n_participants

    return counts, n_participants

Basic filtering and aggregation of survey data.

filter_by_center(survey, responses, center_code)

Filter responses by center.

Parameters:

  • survey (LimeSurveyData) –

    The survey object

  • responses (DataFrame) –

    DataFrame with responses

  • center_code (str) –

    ID of the center to filter by (like 'A01')

Returns:

  • tuple[DataFrame, DataFrame]

    Tuple of filtered DataFrame and remainder DataFrame

Source code in src/survey_framework/data_analysis/analysis.py
def filter_by_center(
    survey: LimeSurveyData, responses: pd.DataFrame, center_code: str
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Filter responses by center.

    Args:
        survey: The survey object
        responses: DataFrame with responses
        center_code: ID of the center to filter by (like 'A01')

    Returns:
        Tuple of filtered DataFrame and remainder DataFrame
    """
    # look up the participant IDs belonging to the requested center
    centers = survey.get_responses(CENTER)
    center_students = centers[centers[CENTER] == center_code]

    # responses may carry the participant ID either as a column or as index
    if "id" in responses.columns:
        in_center = responses["id"].astype(int).isin(center_students.index)
    else:
        in_center = responses.index.isin(center_students.index)
    filtered = responses.loc[in_center]
    remainder = responses.loc[~in_center]

    assert len(filtered) == len(center_students)
    assert len(responses) == len(filtered) + len(remainder)
    assert type(filtered) is pd.DataFrame and type(remainder) is pd.DataFrame
    return filtered, remainder

get_as_numeric(survey, q_code, blocklist)

Get numeric answers for the requested question code.

Raises:

  • ValueError

    if non-numeric answer codes are not in the given blocklist.

Parameters:

  • survey (LimeSurveyData) –

    The survey object

  • q_code (str) –

    The question ID to be queried

  • blocklist (list[str]) –

    Answer codes to be excluded from the result

Returns:

  • Series[float]

    Numeric Series

Source code in src/survey_framework/data_analysis/analysis.py
def get_as_numeric(
    survey: LimeSurveyData, q_code: str, blocklist: list[str]
) -> "pd.Series[float]":
    """Get numeric answers for the requested question code.

    Raises:
        ValueError: if non-numeric answer codes are not in the given blocklist.

    Args:
        survey: The survey object
        q_code: The question ID to be queried
        blocklist: Answer codes to be excluded from the result

    Returns:
        Numeric Series
    """
    # drop rows whose answer code is blocklisted
    answers = survey.get_responses(q_code)
    keep = ~answers[q_code].isin(blocklist)
    filtered = answers.loc[keep]

    # translate answer codes into their (numeric) choice labels
    choices = survey.get_choices(q_code)
    as_text = filtered.map(lambda code: choices[code])
    return cast("pd.Series[float]", as_text.apply(pd.to_numeric).squeeze())

get_center_series(survey, center_code)

Get a series that contains the center name for every participant.

All centers other than center_code are replaced by "Other Centers". The output of this function can be nicely used with the histogram plot.

Parameters:

  • survey (LimeSurveyData) –

    The survey object

  • center_code (str) –

    ID of the center to filter by (like 'A01')

Returns:

  • tuple[Series[str], Sequence[str]]

    Tuple of the Series and a 2-element list for center ordering.

Source code in src/survey_framework/data_analysis/analysis.py
def get_center_series(
    survey: LimeSurveyData, center_code: str
) -> tuple["pd.Series[str]", Sequence[str]]:
    """Get a series that contains the center name for every participant.

    All centers *other than* `center_code` are replaced by "Other Centers".
    The output of this function can be nicely used with the histogram plot.

    Args:
        survey: The survey object
        center_code: ID of the center to filter by (like 'A01')

    Returns:
        Tuple of the Series and a 2-element list for center ordering.
    """
    center_name = shorten_center_name(survey.get_choices(CENTER)[center_code])
    assert center_name is not None

    # one row per participant; replace every center by one of two labels
    centers = survey.get_responses(CENTER)[CENTER].astype(str).rename("Center")
    is_selected = centers.isin([center_code])
    centers.loc[~is_selected] = "Other Centers"
    centers.loc[is_selected] = center_name

    return centers, [center_name, "Other Centers"]

get_data_for_q(survey, question_number)

Returns a DataFrame with the responses for the given question ID.

Contingent questions (free text fields, shown as "other") are removed.

Deprecated

Use LimeSurveyData.get_responses(drop_other=True) instead.

Parameters:

  • survey (LimeSurveyData) –

    The LimeSurvey object

  • question_number (str) –

    The question ID (like 'A1')

Returns:

  • DataFrame

    A DataFrame with all answers to the specified question

Source code in src/survey_framework/data_analysis/analysis.py
def get_data_for_q(survey: LimeSurveyData, question_number: str) -> pd.DataFrame:
    """Returns a DataFrame with the responses for the given question ID.

    Contingent questions (free text fields, shown as "other") are removed.

    Deprecated:
        Use `LimeSurveyData.get_responses(drop_other=True)` instead.

    Args:
        survey: The LimeSurvey object
        question_number: The question ID (like 'A1')

    Returns:
        A DataFrame with all answers to the specified question
    """
    warn(
        "get_data_for_q is deprecated, use survey.get_responses instead",
        DeprecationWarning,
        stacklevel=2,
    )

    # fetch, widen dtypes to object, expose the index as an "id" column,
    # and store the id as string -- same pipeline as before, just chained
    return (
        survey.get_responses(question_number, drop_other=True)
        .astype("object")
        .reset_index()
        .astype({"id": "string"})
    )

get_phd_duration(survey)

Calculate relevant durations from questions A8 and A9.

We calculate:

  • How long has the participant been a doctoral researcher [years]?
  • How long do they estimate their project to last in total [months]?

Parameters:

  • survey (LimeSurveyData) –

    The survey object

Returns:

  • tuple[Series[int], Series[int]]

    Tuple of current phd year and total duration estimation.

Source code in src/survey_framework/data_analysis/analysis.py
def get_phd_duration(
    survey: LimeSurveyData,
) -> tuple["pd.Series[int]", "pd.Series[int]"]:
    """Calculate relevant durations from questions A8 and A9.

    We calculate:
    * How long has the participant been a doctoral researcher [years]?
    * How long do they estimate their project to last *in total* [months]?

    Args:
        survey: The survey object

    Returns:
        Tuple of current phd year and total duration estimation.
    """
    # question codes: PhD start (year/month) and expected end (year/month)
    Q_START = {"year": "A8", "month": "A8a"}
    Q_END = {"year": "A9", "month": "A9a"}

    # we collected the data in April / May 2024
    SURVEY_YEAR = 2024
    SURVEY_MONTH = 5

    # filter out "before 2015" / I don't know / IDWA
    # (the listed codes are the non-numeric answer options of each question)
    startyear = get_as_numeric(survey, Q_START["year"], ["A8", "A9", "A13"])
    startmonth = get_as_numeric(survey, Q_START["month"], ["A13", "A14"])
    endyear = get_as_numeric(survey, Q_END["year"], ["A11", "A12", "A13", "A14", "A15"])
    endmonth = get_as_numeric(survey, Q_END["month"], ["A13", "A14", "A15"])

    # calculate phd year relative to survey time (0 means not yet started)
    # months elapsed = (SURVEY_YEAR - startyear) * 12 + (SURVEY_MONTH - startmonth);
    # dropna() removes participants missing either component
    phd_current_month = (
        startyear.rsub(SURVEY_YEAR)
        .mul(12)
        .add(
            startmonth.rsub(SURVEY_MONTH),
        )
        .dropna()
        .astype(int)
    )
    # 1-based year index: months 0-11 -> year 1, 12-23 -> year 2, ...
    phd_current_year = phd_current_month.floordiv(12).add(1)
    phd_current_year.clip(0, 6, inplace=True)  # clamp after 6 years
    phd_current_year = phd_current_year[phd_current_year != 0]  # remove year 0
    # NOTE(review): year 0 can only arise from a negative month count
    # (start date after the survey); the clip above maps those to 0 first.

    # calculate estimated total phd duration in months
    # if end month is missing, assume December
    # NOTE(review): `fill_value=12` substitutes 12 for a missing value on
    # *either* operand, so a missing *start* month is also treated as
    # December -- confirm this is intended.
    phd_estimation_months = (
        endyear.sub(startyear)
        .mul(12)
        .add(
            endmonth.sub(startmonth, fill_value=12),
        )
        .dropna()
        .astype(int)
    )

    return (phd_current_year, phd_estimation_months)

Scoring

Logic for converting answers on standardized scales into scores.

Condition

Bases: StrEnum

Enumeration of mental health conditions, to be used with rate_mental_health.

Profile

Bases: StrEnum

The five burnout profiles defined in the MBI manual.

Scale

Bases: StrEnum

The three burnout scales defined by the MBI.

rate_burnout(responses)

Calculate burnout scores from participants' answers.

This uses the MBI-GS scale according to the Maslach Burnout Inventory (MBI) Manual, Fourth Edition.

Parameters:

  • responses (DataFrame) –

    responses to question D3d (burnout)

Returns:

  • DataFrame

    SUM scores for each Scale (3 ints) and a burnout Profile (1 string)

Source code in src/survey_framework/data_analysis/scoring.py
def rate_burnout(responses: pd.DataFrame) -> pd.DataFrame:
    """Calculate burnout scores from participants' answers.

    Scoring follows the MBI-GS scale of the Maslach Burnout Inventory (MBI)
    Manual, Fourth Edition.

    Args:
        responses: responses to question D3d (burnout)

    Returns:
        SUM scores for each `Scale` (3 ints) and a burnout `Profile` (1 string)
    """
    # answer code -> frequency score (0 = "Never" ... 6 = "Every day")
    answer_to_score = {
        "A2": 0,  # "Never"
        "A3": 1,  # "A few times a year or less"
        "A4": 2,  # "Once a month or less"
        "A5": 3,  # "A few times a month"
        "A6": 4,  # "Once a week"
        "A7": 5,  # "A few times a week"
        "A8": 6,  # "Every day"
    }

    # MBI-GS scale membership of the 16 items, in questionnaire order
    item_scales = [
        Scale.EX,  # I feel emotionally drained from my work.
        Scale.EX,  # I feel used up at the end of the workday.
        Scale.EX,  # I feel tired when I get up in the morning and have to ...
        Scale.EX,  # Working all day is really a strain for me.
        Scale.PE,  # I can effectively solve the problems that arise in my work.
        Scale.EX,  # I feel burned out from my work.
        Scale.PE,  # I feel I am making an effective contribution to what ...
        Scale.CY,  # I have become less interested in my work since I ...
        Scale.CY,  # I have become less enthusiastic about my work.
        Scale.PE,  # In my opinion, I am good at my job.
        Scale.PE,  # I feel exhilarated when I accomplish something at work.
        Scale.PE,  # I have accomplished many worthwhile things in this job.
        Scale.CY,  # I just want to do my job and not be bothered.
        Scale.CY,  # I have become more cynical about whether my work ...
        Scale.CY,  # I doubt the significance of my work.
        Scale.PE,  # At my work, I feel confident that I am effective at ...
    ]

    # start from the participant index and accumulate one SUM column per scale
    result = pd.DataFrame(responses.index)
    result[Scale.EX] = 0
    result[Scale.CY] = 0
    result[Scale.PE] = 0

    # strict zip guards against a mismatch between items and scale assignments
    for column, scale in zip(responses.columns, item_scales, strict=True):
        item_scores = responses[column].map(answer_to_score, na_action="ignore")
        result[scale] = result[scale] + item_scores

    # boolean classification thresholds (average per item) from Table 3 in
    # the manual; for PE, critical == good, hence the ">" instead of ">=".
    # na_action keeps NaN rows NaN so dropna() below can exclude them.
    result["EX_critical"] = (result[Scale.EX] / 5).map(
        lambda avg: avg >= 2.90, na_action="ignore"
    )
    result["CY_critical"] = (result[Scale.CY] / 5).map(
        lambda avg: avg >= 2.86, na_action="ignore"
    )
    result["PE_critical"] = (result[Scale.PE] / 6).map(
        lambda avg: avg > 4.30, na_action="ignore"
    )

    # burnout profiles according to Table 1 in the manual,
    # keyed by (exhausted, cynical, effective)
    profile_table = {
        (False, False, True): Profile.ENGAGED,
        (False, False, False): Profile.INEFFECTIVE,
        (True, False, True): Profile.OVEREXTENDED,
        (True, False, False): Profile.OVEREXTENDED,
        (False, True, True): Profile.DISENGAGED,
        (False, True, False): Profile.DISENGAGED,
        (True, True, True): Profile.BURNOUT,
        (True, True, False): Profile.BURNOUT,
    }

    def classify(row: "pd.Series[Any]") -> Profile:
        """Look up the burnout `Profile` for a single participant.

        Args:
            row: a single participant

        Returns:
            burnout `Profile` of the participant
        """
        exhausted = row["EX_critical"]
        cynical = row["CY_critical"]
        effective = row["PE_critical"]
        assert isinstance(exhausted, bool)
        assert isinstance(cynical, bool)
        assert isinstance(effective, bool)
        return profile_table[(exhausted, cynical, effective)]

    result["Profile"] = result.dropna().apply(classify, axis=1, result_type="reduce")
    return result

rate_mental_health(responses, condition, keep_subscores=False)

Calculate State/Trait Anxiety or Depression score based on responses to question.

Scoring is based on
  • K. Kroenke, R. L. Spitzer, J. B. Williams, and B. Löwe., The Patient Health Questionnaire somatic, anxiety, and depressive symptom scales: a systematic review. General Hospital Psychiatry, 32(4):345-359, 2010.
  • T. M. Marteau and H. Bekker., The development of a six-item short-form of the state scale of the spielberger state-trait anxiety inventory (STAI). British Journal of Clinical Psychology, 31(3):301-306, 1992.

Parameters:

  • responses (DataFrame) –

    DataFrame containing responses data

  • condition (Condition) –

    Which kind of mental health condition to rate

  • keep_subscores (bool, default: False ) –

    Whether to include scores from subquestions in the output DataFrame, or only total score and classification. Default False.

Returns:

  • DataFrame

    Mental health condition ratings ("score") and classifications ("class").

Source code in src/survey_framework/data_analysis/scoring.py
def rate_mental_health(
    responses: pd.DataFrame,
    condition: Condition,
    keep_subscores: bool = False,
) -> pd.DataFrame:
    """Calculate State/Trait Anxiety or Depression score based on responses to question.

    Scoring is based on:
      * K. Kroenke, R. L. Spitzer, J. B. Williams, and B. Löwe., The Patient
        Health Questionnaire somatic, anxiety, and depressive symptom scales:
        a systematic review. General Hospital Psychiatry, 32(4):345-359, 2010.
      * T. M. Marteau and H. Bekker., The development of a six-item short-form
        of the state scale of the spielberger state-trait anxiety inventory
        (STAI). British Journal of Clinical Psychology, 31(3):301-306, 1992.

    Args:
        responses: DataFrame containing responses data
        condition: Which kind of mental health condition to rate
        keep_subscores: Whether to include scores from subquestions
            in the output DataFrame, or only total score and classification.
            Default False.

    Returns:
        Mental health condition ratings ("<label>_score") and classifications
        ("<label>_class"), e.g. "state_anxiety_score" / "state_anxiety_class".
    """
    # Set up condition-specific parameters:
    #   num_subquestions: number of survey items for this condition
    #   base_score: per-step increment that rescales the item sum onto the
    #       instrument's standard range
    #   conversion: per-item scoring direction; "pos" items are positively
    #       worded and therefore reverse-scored, "freq" items use the
    #       frequency scale
    #   classification_boundaries / classes: bin edges and labels for pd.cut
    match condition:
        case Condition.STATE_ANXIETY:
            # STAI-6 short form; 10/3 maps the 6-item sum onto the full
            # STAI range of 20..80
            num_subquestions = 6
            base_score = 10 / 3
            conversion = ["pos", "neg", "neg", "pos", "pos", "neg"]
            label = "state_anxiety"
            classification_boundaries = [20, 40, 60, 80]
            classes = ["no or low anxiety", "moderate anxiety", "high anxiety"]

        case Condition.TRAIT_ANXIETY:
            # 8 items; 5/2 likewise maps the sum onto 20..80
            num_subquestions = 8
            base_score = 5 / 2
            conversion = [
                "pos",
                "neg",
                "neg",
                "pos",
                "neg",
                "neg",
                "pos",
                "neg",
            ]
            label = "trait_anxiety"
            classification_boundaries = [20, 40, 60, 80]
            classes = ["no or low anxiety", "moderate anxiety", "high anxiety"]

        case Condition.DEPRESSION:
            # frequency-scored items, 0..3 points each, total range 0..24
            num_subquestions = 8
            base_score = 1
            conversion = ["freq" for _ in range(8)]
            label = "depression"
            classification_boundaries = [0, 4, 9, 14, 19, 24]
            classes = [
                "no to minimal depression",
                "mild depression",
                "moderate depression",
                "moderately severe depression",
                "severe depression",
            ]

        case _:
            raise AssertionError("unreachable")

    # sanity check: columns are named "<q_code>_<subquestion>", and Condition
    # is a StrEnum whose value is expected to equal the question code
    q_code = responses.columns[0].split("_")[0]
    if q_code != condition:
        raise ValueError(f"expected question {condition}, got {q_code}")

    # Set up score conversion dicts (answer code -> points)
    pos_direction_scores = {
        "A1": 4 * base_score,  # "Not at all"
        "A2": 3 * base_score,  # "Somewhat"
        "A3": 2 * base_score,  # "Moderately"
        "A4": 1 * base_score,  # "Very much"
    }
    neg_direction_scores = {
        "A1": 1 * base_score,  # "Not at all"
        "A2": 2 * base_score,  # "Somewhat"
        "A3": 3 * base_score,  # "Moderately"
        "A4": 4 * base_score,  # "Very much"
    }
    frequency_scores = {
        "A1": 0 * base_score,  # "Not at all"
        "A2": 1 * base_score,  # "Several days"
        "A3": 2 * base_score,  # "More than half the days"
        "A4": 3 * base_score,  # "Nearly every day"
    }
    conversion_dicts = {
        "pos": pos_direction_scores,
        "neg": neg_direction_scores,
        "freq": frequency_scores,
    }

    # Map responses from code to score; missing/unmapped answers stay NaN.
    # strict zip raises if the item count differs from len(conversion).
    df = pd.DataFrame()
    for column, conv in zip(responses.columns, conversion, strict=True):
        df[f"{column}_score"] = responses[column].map(
            conversion_dicts[conv], na_action="ignore"
        )

    # Calculate total anxiety or depression scores
    # scaled by number of non-NaN responses
    # e.g. scale by 8/5 if 5/8 subquestions answered
    responses_counts = df.notna().sum(axis=1)
    df[f"{label}_score"] = (
        df.sum(axis=1, skipna=True).div(responses_counts).mul(num_subquestions)
    )

    # Suppress entries with less than half of all subquestions answered
    # TODO: we might want to be more strict here
    df.loc[responses_counts < num_subquestions / 2, f"{label}_score"] = None

    # Classify into categories; include_lowest keeps the minimum possible
    # score inside the first bin
    df[f"{label}_class"] = pd.cut(
        df[f"{label}_score"],
        bins=classification_boundaries,
        include_lowest=True,
        labels=classes,
    )

    if not keep_subscores:
        # keep only the last two columns: total score and classification
        df = df.drop(df.columns[:-2], axis=1)

    return df

rate_satisfaction(responses, calc_average=True)

Calculate satisfaction rating for each subquestion and calculate the average.

Uses a numeric scale from 1 = very dissatisfied to 5 = very satisfied.

Parameters:

  • responses (DataFrame) –

    DataFrame containing responses data

  • calc_average (bool, default: True ) –

    Whether to calculate average satisfaction. Default True.

Returns:

  • DataFrame

    Satisfaction ratings for each component (and overall average)

Source code in src/survey_framework/data_analysis/scoring.py
def rate_satisfaction(
    responses: pd.DataFrame,
    calc_average: bool = True,
) -> pd.DataFrame:
    """Calculate satisfaction rating for each subquestion and calculate the average.

    Uses a numeric scale from 1 = very dissatisfied to 5 = very satisfied.

    Args:
        responses: DataFrame containing responses data
        calc_average: Whether to calculate average satisfaction. Default True.

    Returns:
        Satisfaction ratings for each component (and overall average)
    """
    # safety check: only questions verified to use a scale from
    # A1 (very satisfied) to A5 (very dissatisfied) may be listed here
    KNOWN_SATISFACTION_QUESTIONS = ["C1"]

    question_code = responses.columns[0].split("_")[0]
    if question_code not in KNOWN_SATISFACTION_QUESTIONS:
        raise ValueError(f"{question_code} is not a satisfaction-scale question")

    # answer code -> numeric rating for individual subquestions
    code_to_score = {
        "A1": 5.0,  # "Very satisfied"
        "A2": 4.0,  # "Satisfied"
        "A3": 3.0,  # "Neither/nor"
        "A4": 2.0,  # "Dissatisfied"
        "A5": 1.0,  # "Very dissatisfied"
    }

    # convert every subquestion column; missing answers remain NaN
    scored = pd.DataFrame(
        {
            f"{column}_score": responses[column].map(
                code_to_score, na_action="ignore"
            )
            for column in responses.columns
        }
    )

    if calc_average:
        # mean rating per participant, rounded to the nearest scale step
        # (NaN answers are ignored), then mapped back to an answer code
        mean_rating = scored.mean(axis=1, skipna=True).round()
        scored[f"{question_code}_score"] = mean_rating
        score_to_code = {score: code for code, score in code_to_score.items()}
        scored[f"{question_code}_class"] = mean_rating.map(score_to_code)

    return scored

rate_somatic(responses, keep_subscores=False)

Calculate Patient Health Questionnaire (PHQ15) from participant responses.

Scoring is based on
  • K. Kroenke, R. L. Spitzer, J. B. Williams, and B. Löwe., The Patient Health Questionnaire somatic, anxiety, and depressive symptom scales: a systematic review. General Hospital Psychiatry, 32(4):345-359, 2010.

Parameters:

  • responses (DataFrame) –

    DataFrame containing responses data

  • keep_subscores (bool, default: False ) –

    Whether to include scores from subquestions in the output DataFrame, or only total score and classification.

Returns:

  • DataFrame

    PHQ15 classifications in two columns ("somatic_class" and "somatic_score").

Source code in src/survey_framework/data_analysis/scoring.py
def rate_somatic(
    responses: pd.DataFrame,
    keep_subscores: bool = False,
) -> pd.DataFrame:
    """Calculate Patient Health Questionnaire (PHQ-15) scores from participant responses.

    Scoring is based on:
      * K. Kroenke, R. L. Spitzer, J. B. Williams, and B. Löwe., The Patient
        Health Questionnaire somatic, anxiety, and depressive symptom scales:
        a systematic review. General Hospital Psychiatry, 32(4):345-359, 2010.

    Args:
        responses: DataFrame containing responses data (question D4)
        keep_subscores: Whether to include scores from subquestions
            in the output DataFrame, or only total score and classification.

    Returns:
        PHQ-15 ratings in two columns, "somatic_score" and "somatic_class"
        (plus one "<column>_score" column per subquestion if
        ``keep_subscores`` is True).
    """
    PHQ15 = "D4"
    label = "somatic"

    # sanity check: this scoring is only valid for the PHQ-15 question
    q_code = responses.columns[0].split("_")[0]
    if q_code != PHQ15:
        raise ValueError(f"expected question {PHQ15}, got {q_code}")

    # only 14 of the 15 PHQ items are asked here; base_score rescales the
    # maximum total back onto the standard PHQ-15 range of 0-30
    num_subquestions = 14
    base_score = 15 / 14
    classification_boundaries = [0, 4, 9, 14, 30]
    classes = [
        "No somatic symptoms",
        "Mild somatic symptoms",
        "Moderate somatic symptoms",
        "Severe somatic symptoms",
    ]

    # Set up score conversion dict (answer code -> severity score)
    scores = {
        "A2": 0 * base_score,  # "Not bothered"
        "A3": 1 * base_score,  # "Bothered a little"
        "A4": 2 * base_score,  # "Bothered a lot"
    }

    # Map responses from code to score; missing/unmapped answers stay NaN
    df = pd.DataFrame()
    for column in responses.columns:
        df[f"{column}_score"] = responses[column].map(scores, na_action="ignore")

    # Calculate total scores scaled by number of non-NaN responses
    # e.g. scale by 14/10 if 10/14 subquestions answered
    responses_counts = df.notna().sum(axis=1)
    df[f"{label}_score"] = (
        df.sum(axis=1, skipna=True).div(responses_counts).mul(num_subquestions)
    )

    # Suppress entries with less than half of all subquestions answered
    # TODO: we might want to be more strict here
    df.loc[responses_counts < num_subquestions / 2, f"{label}_score"] = None

    # Classify into categories; include_lowest keeps a 0 score in the
    # first bin
    df[f"{label}_class"] = pd.cut(
        df[f"{label}_score"],
        bins=classification_boundaries,
        include_lowest=True,
        labels=classes,
    )

    if not keep_subscores:
        # drop everything except the last two columns (total score + class)
        df = df.drop(df.columns[:-2], axis=1)

    return df

Bar Plots

Functions for basic bar plots.

plot_bar(survey, data_df, question, n_question, label_q_data='', orientation=Orientation.HORIZONTAL, stat=PlotStat.COUNT, width=6, height=4, bar_labels=BarLabels.NONE, bar_label_size=None, tick_label_size=None, tick_label_wrap=25)

Plot bar plots (single and multiple).

Parameters:

  • survey (LimeSurveyData) –

    The LimeSurvey object

  • data_df (DataFrame) –

    DataFrame with responses to be plotted

  • question (str) –

    The question code

  • n_question (int) –

    Number of participants

  • label_q_data (str, default: '' ) –

    Label for axis with data from question.

  • orientation (Orientation, default: HORIZONTAL ) –

    Plot orientation.

  • stat (PlotStat, default: COUNT ) –

    Plot absolute values or percentages?

  • width (float, default: 6 ) –

    Width of figure.

  • height (float, default: 4 ) –

    Height of figure.

  • bar_labels (BarLabels, default: NONE ) –

    How to format bar labels.

  • bar_label_size (int | None, default: None ) –

    Font size for bar labels, if enabled.

  • tick_label_size (int | None, default: None ) –

    Font size for tick labels.

  • tick_label_wrap (int, default: 25 ) –

    How many characters are allowed per line in tick labels.

Returns:

  • tuple[Figure, Axes]

    New matplotlib Figure and Axes for the bar plot.

Source code in src/survey_framework/plotting/barplots.py
def plot_bar(
    survey: LimeSurveyData,
    data_df: pd.DataFrame,
    question: str,
    n_question: int,
    label_q_data: str = "",
    orientation: Orientation = Orientation.HORIZONTAL,
    stat: PlotStat = PlotStat.COUNT,
    width: float = 6,
    height: float = 4,
    bar_labels: BarLabels = BarLabels.NONE,
    bar_label_size: int | None = None,
    tick_label_size: int | None = None,
    tick_label_wrap: int = 25,
) -> tuple[Figure, Axes]:
    """Plot bar plots (single and multiple).

    Args:
        survey: The LimeSurvey object
        data_df: DataFrame with responses to be plotted
        question: The question code
        n_question: Number of participants
        label_q_data: Label for axis with data from question.
        orientation: Plot orientation.
        stat: Plot absolute values or percentages?
        width: Width of figure.
        height: Height of figure.
        bar_labels: How to format bar labels.
        bar_label_size: Font size for bar labels, if enabled.
        tick_label_size: Font size for tick labels.
        tick_label_wrap: How many characters are allowed per line in tick labels.

    Returns:
        New matplotlib Figure and Axes for the bar plot.
    """
    # plot barplot (figure/axes creation and count/percent aggregation are
    # delegated to barplot_internal)
    fig, ax = barplot_internal(
        data_df=data_df,
        question=question,
        orient=orientation,
        stat=stat,
        width=width,
        height=height,
    )

    # add number of participants to top right corner
    # NOTE(review): plt.text draws on the *current* pyplot axes; this assumes
    # barplot_internal left `ax` as the active axes — confirm
    plt.text(
        0.99,
        0.99,
        f"N = {n_question}",
        ha="right",
        va="top",
        transform=ax.transAxes,
        # fontsize=fontsize,
    )

    # add bar labels (the ones on top or next to bars within the plot);
    # n_question is needed so counts can be converted to percentages
    ax = add_bar_labels(
        ax=ax,
        show_axes_labels=bar_labels,
        percentcount=stat,
        n_question=n_question,
        fontsize=bar_label_size,
    )

    # add tick labels (the ones below or next to the bars outside of the plot)
    ax = add_tick_labels(
        survey=survey,
        ax=ax,
        question=question,
        orientation=orientation,
        fontsize=tick_label_size,
        text_wrap=tick_label_wrap,
    )

    # add general labels to axes
    label_axes(ax=ax, orientation=orientation, label_q_data=label_q_data, stat=stat)

    return fig, ax

plot_bar_comparison(survey, data_df, question, hue, hue_order=None, n_participants=None, label_q_data='', orient=Orientation.HORIZONTAL, stat=PlotStat.COUNT, width=6, height=4, bar_labels=BarLabels.NONE, bar_label_size=None, tick_label_size=None, tick_label_wrap=25)

Plot comparison bar plots (single and multiple).

Parameters:

  • survey (LimeSurveyData) –

    The LimeSurvey object

  • data_df (DataFrame) –

    DataFrame with responses to be plotted

  • question (str) –

    Question code for the first question

  • hue (str) –

    Question code for the second question

  • hue_order (Sequence[str] | None, default: None ) –

    Order of answer options for the second question.

  • n_participants (dict[Hashable, int] | None, default: None ) –

    Number of participants per hue group (usually centers), or None to suppress printing N.

  • label_q_data (str, default: '' ) –

    Label for axis with data from question.

  • orient (Orientation, default: HORIZONTAL ) –

    Plot orientation.

  • stat (PlotStat, default: COUNT ) –

    Plot absolute values or percentages?

  • width (float, default: 6 ) –

    Width of figure.

  • height (float, default: 4 ) –

    Height of figure.

  • bar_labels (BarLabels, default: NONE ) –

    How to format bar labels.

  • bar_label_size (int | None, default: None ) –

    Font size for bar labels.

  • tick_label_size (int | None, default: None ) –

    Font size for tick labels.

  • tick_label_wrap (int, default: 25 ) –

    Number of letters after which tick labels wrap.

Returns:

  • tuple[Figure, Axes]

    New matplotlib Figure and Axes for the bar plot.

Source code in src/survey_framework/plotting/barplots.py
def plot_bar_comparison(
    survey: LimeSurveyData,
    data_df: pd.DataFrame,
    question: str,
    hue: str,
    hue_order: Sequence[str] | None = None,
    n_participants: dict[Hashable, int] | None = None,
    label_q_data: str = "",
    orient: Orientation = Orientation.HORIZONTAL,
    stat: PlotStat = PlotStat.COUNT,
    width: float = 6,
    height: float = 4,
    bar_labels: BarLabels = BarLabels.NONE,
    bar_label_size: int | None = None,
    tick_label_size: int | None = None,
    tick_label_wrap: int = 25,
) -> tuple[Figure, Axes]:
    """Plot comparison bar plots (single and multiple).

    Args:
        survey: The LimeSurvey object
        data_df: DataFrame with responses to be plotted
        question: Question code for the first question
        hue: Question code for the second question
        hue_order: Order of answer options for the second question.
        n_participants: Number of participants per hue group (usually centers),
            or None to suppress printing N.
        label_q_data: Label for axis with data from question.
        orient: Plot orientation.
        stat: Plot absolute values or percentages?
        width: Width of figure.
        height: Height of figure.
        bar_labels: How to format bar labels.
        bar_label_size: Font size for bar labels.
        tick_label_size: Font size for tick labels.
        tick_label_wrap: Number of letters after which tick labels wrap.

    Returns:
        New matplotlib Figure and Axes for the bar plot.
    """
    # plot barplot; the presence of per-group Ns selects the plot type
    # (single question compared across groups vs. multi-question plot)
    fig, ax = barplot_internal(
        data_df=data_df,
        question=question,
        orient=orient,
        stat=stat,
        width=width,
        height=height,
        comparison=PlotType.SINGLE_Q_COMPARISON
        if n_participants is not None
        else PlotType.MULTI_Q,
        hue=hue,
        hue_order=hue_order,
    )

    # adapt legend (legend entry wrap width is fixed at 40 characters here)
    ax = adapt_legend(
        survey=survey, ax=ax, question=hue, text_wrap=40, group_n=n_participants
    )

    # add bar labels (the ones on top or next to bars within the plot);
    # n_question=None since no single total N applies across hue groups
    ax = add_bar_labels(
        ax=ax,
        show_axes_labels=bar_labels,
        percentcount=stat,
        n_question=None,
        # rotation=45 if orientation == Orientation.VERTICAL else None,
        fontsize=bar_label_size,
    )

    # add tick labels (the ones below or next to the bars outside of the plot)
    ax = add_tick_labels(
        survey=survey,
        ax=ax,
        question=question,
        orientation=orient,
        fontsize=tick_label_size,
        text_wrap=tick_label_wrap,
    )

    # add general labels to axes
    label_axes(ax=ax, orientation=orient, label_q_data=label_q_data, stat=stat)
    return fig, ax

Functions for side-by-side horizontal bar plots.

plot_bar_side_by_side(survey, data_left, data_right, y_left, y_right, stat=PlotStat.PERCENT, color_left=helmholtzblue, color_right=helmholtzgreen, title_left=None, title_right=None, width=12, height=10)

Plot two horizontal bar plots side-by-side, sharing a common y axis.

Parameters:

  • survey (LimeSurveyData) –

    The survey object.

  • data_left (DataFrame) –

    Data for the left plot.

  • data_right (DataFrame) –

    Data for the right plot.

  • y_left (str) –

    Question ID (left)

  • y_right (str) –

    Question ID (right)

  • stat (PlotStat, default: PERCENT ) –

    Which metric to plot (percent / count)

  • color_left (str, default: helmholtzblue ) –

    Bar color (left).

  • color_right (str, default: helmholtzgreen ) –

    Bar color (right).

  • title_left (str | None, default: None ) –

    Left plot title.

  • title_right (str | None, default: None ) –

    Right plot title.

  • width (float, default: 12 ) –

    Total plot width.

  • height (float, default: 10 ) –

    Total plot height.

Returns:

  • tuple[Figure, tuple[Axes, Axes]]

    New Figure and Axes

Source code in src/survey_framework/plotting/barplots_sidebyside.py
def plot_bar_side_by_side(
    survey: LimeSurveyData,
    data_left: pd.DataFrame,
    data_right: pd.DataFrame,
    y_left: str,
    y_right: str,
    stat: PlotStat = PlotStat.PERCENT,
    color_left: str = helmholtzblue,
    color_right: str = helmholtzgreen,
    title_left: str | None = None,
    title_right: str | None = None,
    width: float = 12,
    height: float = 10,
) -> tuple[Figure, tuple[Axes, Axes]]:
    """Plot two horizontal bar plots side-by-side, sharing a common y axis.

    Args:
        survey: The survey object.
        data_left: Data for the left plot.
        data_right: Data for the right plot.
        y_left: Question ID (left)
        y_right: Question ID (right)
        stat: Which metric to plot (percent / count)
        color_left: Bar color (left).
        color_right: Bar color (right).
        title_left: Left plot title.
        title_right: Right plot title.
        width: Total plot width.
        height: Total plot height.

    Returns:
        New Figure and Axes
    """
    # set seaborn theme
    set_plotstyle()

    # define figure and axis
    # nrows, ncols = number of rows, columns of the subplot grid
    # sharey = share the Y axis
    # https://stackoverflow.com/questions/16150819/common-xlabel-ylabel-for-matplotlib-subplots
    figure, axs = plt.subplots(
        nrows=1, ncols=2, dpi=300, figsize=(width, height), sharey=True
    )
    ax_left, ax_right = cast(tuple[Axes, Axes], axs)

    # determine bar order according to the questions' answer choices
    order_left = [i for i in survey.questions.choices[y_left]]
    order_right = [i for i in survey.questions.choices[y_right]]

    # .loc[:,var] -> left side is for index, right side for column
    # make countplots for total numbers
    plot_left = sns.countplot(
        ax=ax_left,
        data=data_left,
        y=y_left,
        color=color_left,
        order=order_left,
        stat=stat.value,
    )
    plot_right = sns.countplot(
        ax=ax_right,
        data=data_right,
        y=y_right,
        color=color_right,
        order=order_right,
        stat=stat.value,
    )

    # remove spines from figure
    # ax_left.spines["top"].set_visible(False)
    # ax_left.spines["right"].set_visible(False)
    # ax_left.spines["bottom"].set_visible(False)
    # ax_left.spines["left"].set_visible(False)
    # ax_right.spines["top"].set_visible(False)
    # ax_right.spines["right"].set_visible(False)
    # ax_right.spines["bottom"].set_visible(False)
    # ax_right.spines["left"].set_visible(False)

    # set xlim equal on both sides so the two plots are visually comparable
    # (for COUNT, each side keeps its own auto-scaled limits)
    if stat == PlotStat.PERCENT:
        ax_left.set_xlim((0, 100))
        ax_right.set_xlim((0, 100))
    elif stat == PlotStat.PROPORTION:
        ax_left.set_xlim((0, 1))
        ax_right.set_xlim((0, 1))

    # flip left side
    # https://stackoverflow.com/questions/68858330/right-align-horizontal-seaborn-barplot
    ax_left.invert_xaxis()
    ax_left.yaxis.tick_right()
    ax_left.yaxis.set_ticks_position("none")

    # population size per side
    # NOTE(review): counts all rows of the passed data, including rows whose
    # answer may be NaN — "answered this question" is approximate; confirm
    N_left = len(data_left.index)
    N_right = len(data_right.index)

    # show percentages behind bars
    # bar_container = cast(BarContainer, plot_left.containers[0])
    # bar_labels_left = [
    #     f"{i / N_left * 100:.1f}%" for i in list(bar_container.datavalues)
    # ]
    # plot_left.bar_label(bar_container, labels=bar_labels_left)

    # bar_container = cast(BarContainer, plot_right.containers[0])
    # bar_labels_right = [
    #     f"{i / N_right * 100:.1f}%" for i in list(bar_container.datavalues)
    # ]
    # plot_right.bar_label(bar_container, labels=bar_labels_right)

    add_bar_labels(ax_left, BarLabels.PERCENT, stat, n_question=N_left)
    add_bar_labels(ax_right, BarLabels.PERCENT, stat, n_question=N_right)

    # get titles (fall back to the question labels from the survey)
    if title_left is None:
        title_left = survey.questions.label[y_left]
    if title_right is None:
        title_right = survey.questions.label[y_right]

    # set titles, wrapped at 40 characters
    plot_left.set_title("\n".join(wrap(title_left, 40)))
    plot_right.set_title("\n".join(wrap(title_right, 40)))

    # capitalize x axis label
    ax_left.set_xlabel(stat.capitalize())
    ax_right.set_xlabel(stat.capitalize())

    # empty y axis label
    ax_left.set_ylabel("")
    ax_right.set_ylabel("")

    # set y axis tick labels; labels on the right side are not shown
    # https://stackoverflow.com/questions/11244514/modify-tick-label-text
    # NOTE(review): tick texts taken from the LEFT plot are looked up in the
    # RIGHT question's choices — this assumes both questions use the same
    # answer codes; confirm for new question pairs
    y_ticklabels = [item.get_text() for item in plot_left.get_yticklabels()]
    for i in range(0, len(y_ticklabels)):
        label = survey.questions.choices[y_right][y_ticklabels[i]]
        y_ticklabels[i] = "\n".join(wrap(label, 20))
    plot_right.set_yticks(range(len(y_ticklabels)))
    plot_right.set_yticklabels(y_ticklabels)

    # more space between both subfigures
    # https://www.geeksforgeeks.org/how-to-set-the-spacing-between-subplots-in-matplotlib-in-python/
    figure.tight_layout(pad=0.5)

    # annotate the population size in the bottom corners of each subplot
    plt.text(
        0, 0.01, f"N = {N_left}", ha="left", va="bottom", transform=ax_left.transAxes
    )
    plt.text(
        0.99,
        0.01,
        f"N = {N_right}",
        ha="right",
        va="bottom",
        transform=ax_right.transAxes,
    )

    return figure, (ax_left, ax_right)

plot_sidebyside_comparison_singleQ(survey, data_left, data_right, base_q_left, base_q_right, comp_q, N_left, N_right, title_left='', title_right='', width=12, height=10, fontsize=10, plot_stat=PlotStat.COUNT, bar_labels=BarLabels.NONE, fontsize_bar_labels=10, text_wrap=25)

Plot two barplots side by side, with an additional grouping given by comp_q.

Parameters:

  • survey (LimeSurveyData) –

    The main survey object

  • data_left (DataFrame) –

    Data for the left question

  • data_right (DataFrame) –

    Data for the right question

  • base_q_left (str) –

    Question code for the left question

  • base_q_right (str) –

    Question code for the right question

  • comp_q (str) –

    Question code that is used as a grouper (hue) for both questions

  • N_left (int) –

    Size of the left population

  • N_right (int) –

    Size of the right population

  • title_left (str, default: '' ) –

    Title for the left plot

  • title_right (str, default: '' ) –

    Title for the right plot

  • width (float, default: 12 ) –

    Total width of the plot

  • height (float, default: 10 ) –

    Total height of the plot

  • fontsize (int, default: 10 ) –

    Font size for titles and axis (tick) labels

  • plot_stat (PlotStat, default: COUNT ) –

    Whether to plot absolute values (COUNT) or relative values (PERCENT)

  • bar_labels (BarLabels, default: NONE ) –

    Whether to label each bar

  • fontsize_bar_labels (int, default: 10 ) –

    Font size for the bar labels, if enabled above

  • text_wrap (int, default: 25 ) –

    After how many characters axis and legend labels should wrap

Returns:

  • tuple[Figure, tuple[Axes, Axes]]

    New matplotlib Figure and Axes for this bar plot.

Source code in src/survey_framework/plotting/barplots_sidebyside.py
def plot_sidebyside_comparison_singleQ(
    survey: LimeSurveyData,
    data_left: pd.DataFrame,
    data_right: pd.DataFrame,
    base_q_left: str,
    base_q_right: str,
    comp_q: str,
    N_left: int,
    N_right: int,
    title_left: str = "",
    title_right: str = "",
    width: float = 12,
    height: float = 10,
    fontsize: int = 10,
    plot_stat: PlotStat = PlotStat.COUNT,
    bar_labels: BarLabels = BarLabels.NONE,
    fontsize_bar_labels: int = 10,
    text_wrap: int = 25,
) -> tuple[Figure, tuple[Axes, Axes]]:
    """Plot two barplots side by side, with an additional grouping given by `comp_q`.

    Args:
        survey: The main survey object
        data_left: Data for the left question
        data_right: Data for the right question
        base_q_left: Question code for the left question
        base_q_right: Question code for the right question
        comp_q: Question code that is used as a grouper (hue) for both questions
        N_left: Size of the left population
        N_right: Size of the right population
        title_left: Title for the left plot (looked up from the survey if empty)
        title_right: Title for the right plot (looked up from the survey if empty)
        width: Total width of the plot
        height: Total height of the plot
        fontsize: Font size for titles and axis (tick) labels
        plot_stat: Whether to plot relative (PERCENT / PROPORTION) or
            absolute (COUNT) values
        bar_labels: Whether to label each bar
        fontsize_bar_labels: Font size for the bar labels, if enabled above
        text_wrap: After how many characters axis and legend labels should wrap

    Returns:
        New matplotlib Figure and Axes for this bar plot.
    """
    # set seaborn theme
    set_plotstyle()

    # sharey: both plots use the same answer options on the shared y axis
    figure, axs = plt.subplots(
        nrows=1, ncols=2, dpi=300, figsize=(width, height), sharey=True, layout="tight"
    )
    ax_left, ax_right = cast(tuple[Axes, Axes], axs)

    hue_input_left, colors_left = get_hue_left(data_left, comp_q)
    hue_input_right, colors_right = get_hue_right(data_right, comp_q)

    # left
    plot_left = sns.barplot(
        ax=ax_left,
        x=data_left[plot_stat.value],
        y=list(data_left[base_q_left]),
        hue=hue_input_left,
        palette=colors_left,
        orient="h",
    )
    # right
    plot_right = sns.barplot(
        ax=ax_right,
        x=data_right[plot_stat.value],
        y=list(data_right[base_q_right]),
        hue=hue_input_right,
        palette=colors_right,
        orient="h",
    )

    # set xlim equal on both sides so bar lengths are directly comparable
    if plot_stat == PlotStat.PERCENT:
        ax_left.set_xlim((0, 100))
        ax_right.set_xlim((0, 100))
    elif plot_stat == PlotStat.PROPORTION:
        ax_left.set_xlim((0, 1))
        ax_right.set_xlim((0, 1))

    # flip left side so both plots grow away from the shared center axis
    # https://stackoverflow.com/questions/68858330/right-align-horizontal-seaborn-barplot
    ax_left.invert_xaxis()
    ax_left.yaxis.tick_right()

    # add answer options as y tick labels between the two bar plots
    ax_left = add_tick_labels(
        survey=survey,
        ax=ax_left,
        question=base_q_left,
        orientation=Orientation.HORIZONTAL,
        fontsize=fontsize,
        text_wrap=text_wrap,
    )

    # default titles: "<question code>: <question label>" from the survey data
    if title_left == "":
        title_left = (
            base_q_left
            + ": "
            + survey.questions.loc[
                survey.questions["question_group"] == base_q_left
            ].question_label.iloc[0]
        )
    if title_right == "":
        title_right = (
            base_q_right
            + ": "
            + survey.questions.loc[
                survey.questions["question_group"] == base_q_right
            ].question_label.iloc[0]
        )

    # set title
    plot_left.set_title(
        "\n".join(wrap(title_left, 60)),
        fontsize=fontsize,
    )
    plot_right.set_title(
        "\n".join(wrap(title_right, 60)),
        fontsize=fontsize,
    )

    # add bar labels (the ones on top or next to bars within the plot)
    # fix: previously only the left plot received bar labels
    add_bar_labels(
        ax=ax_left,
        show_axes_labels=bar_labels,
        percentcount=plot_stat,
        n_question=N_left,
        fontsize=fontsize_bar_labels,
    )
    add_bar_labels(
        ax=ax_right,
        show_axes_labels=bar_labels,
        percentcount=plot_stat,
        n_question=N_right,
        fontsize=fontsize_bar_labels,
    )

    # adapt legend
    ax_left = adapt_legend(
        survey=survey,
        ax=ax_left,
        question=comp_q,
        text_wrap=text_wrap,
        anchor=(0.4, 0.18),
    )

    ax_right = adapt_legend(
        survey=survey,
        ax=ax_right,
        question=comp_q,
        text_wrap=text_wrap,
        anchor=(1, 0.18),
    )

    # clear the per-axes y labels; the shared tick labels carry the information
    ax_left.set_ylabel("")
    ax_right.set_ylabel("")

    # add number of participants to the bottom outer corner of each plot
    ax_left.text(
        0, 0.01, f"N = {N_left}", ha="left", va="bottom", transform=ax_left.transAxes
    )
    ax_right.text(
        0.99,
        0.01,
        f"N = {N_right}",
        ha="right",
        va="bottom",
        transform=ax_right.transAxes,
    )

    return figure, (ax_left, ax_right)

Stacked Bar Plots

Stacked bar plots. This is very quick-and-dirty and needs a proper cleanup.

plot_stacked_bar_categorical(df, classes_column, category_column, na_values=False, label_q_data='', width=6, height=4, fontsize=None, fontsize_axes_labels=None, legend_title='', category_order=None)

Plot a stacked barplot with an arbitrary number of bars.

This might be what we want instead of the above -- Code needs a cleanup and proper documentation though.

Parameters:

  • df (DataFrame) –

    Input data; must contain classes_column and category_column

  • classes_column (str) –

    Column whose values form the stacked segments of each bar (missing values are shown as an "NA" segment)

  • category_column (str) –

    Column whose values define the bars along the x axis

  • na_values (bool, default: False ) –

    If True, reserve grey for the "NA" segment and use shades of blue for the rest; otherwise all segments use blues

  • label_q_data (str, default: '' ) –

    Label for the x axis

  • width (int, default: 6 ) –

    Width of the figure in inches

  • height (int, default: 4 ) –

    Height of the figure in inches

  • fontsize (int | None, default: None ) –

    Font size for tick labels, axis labels, legend entries and the per-bar "N = ..." annotations

  • fontsize_axes_labels (int | None, default: None ) –

    Font size for the in-segment percentage labels and the legend title

  • legend_title (str, default: '' ) –

    Title for the figure legend

  • category_order (list[str] | None, default: None ) –

    Explicit order of the stacked segments ("NA" is added if missing); derived from the data if None

Returns:

  • tuple[Figure, Axes]

    New matplotlib figure and axes.

Source code in src/survey_framework/plotting/stacked.py
def plot_stacked_bar_categorical(
    df: pd.DataFrame,
    classes_column: str,
    category_column: str,
    na_values: bool = False,
    label_q_data: str = "",
    width: int = 6,
    height: int = 4,
    fontsize: int | None = None,
    fontsize_axes_labels: int | None = None,
    legend_title: str = "",
    category_order: list[str] | None = None,
) -> tuple[Figure, Axes]:
    """Plot a stacked barplot with an arbitrary number of bars.

    This might be what we want instead of the above --
    Code needs a cleanup and proper documentation though.

    Args:
        df: Input data; must contain `classes_column` and `category_column`.
        classes_column: Column whose values form the stacked segments of each
            bar. Missing values are shown as an "NA" segment.
        category_column: Column whose values define the bars along the x axis
            (one bar per category, e.g. one per survey year).
        na_values: If True, reserve grey for the "NA" segment and use shades
            of blue for the remaining segments; otherwise all segments use
            shades of blue.
        label_q_data: Label for the x axis.
        width: Width of the figure in inches.
        height: Height of the figure in inches.
        fontsize: Font size for tick labels, axis labels, legend entries and
            the per-bar "N = ..." annotations.
        fontsize_axes_labels: Font size for the in-segment percentage labels
            and the legend title.
        legend_title: Title for the figure legend.
        category_order: Explicit bottom-to-top order of the stacked segments
            ("NA" is added if missing). Derived from the data if None.

    Returns:
        New matplotlib figure and axes.
    """
    hc.set_plotstyle()

    # One bar per category value; respect categorical dtype ordering if present.
    year_categories = (
        list(df[category_column].cat.categories)
        if isinstance(df[category_column].dtype, pd.CategoricalDtype)
        else sorted(df[category_column].unique())
    )
    n_years = len(year_categories)
    bar_width, bar_gap = 0.8, 0.1
    x_positions = np.arange(n_years) * (bar_width + bar_gap)
    fig, ax = plt.subplots(figsize=(width, height), layout="constrained")

    # Determine the segment order; "NA" must always be a valid segment because
    # missing values are mapped to "NA" further down.
    if category_order:
        category_order_mod = category_order.copy()
        if "NA" not in category_order_mod:
            category_order_mod.append("NA")
    else:
        all_classes = df[classes_column]
        if isinstance(all_classes.dtype, pd.CategoricalDtype):
            category_order_mod = list(all_classes.cat.categories)
            if "NA" not in category_order_mod:
                category_order_mod.append("NA")
        else:
            # NOTE(review): "NA" is prepended here but appended in the two
            # branches above -- confirm whether this asymmetry is intended.
            category_order_mod = ["NA"] + sorted(all_classes.dropna().unique())

    n_resp = len(category_order_mod)
    # Grey is reserved for the "NA" segment when na_values is set; the
    # remaining segments get shades of blue (reversed so darkest is first).
    colors_list = (
        [to_rgb(hc.grey), *hc.get_blues(n_resp - 1)[::-1]]
        if na_values
        else list(hc.get_blues(n_resp)[::-1])
    )
    color_mapping = {
        cat: color for cat, color in zip(category_order_mod, colors_list, strict=True)
    }

    # Draw one stacked bar per category value.
    for i, yr in enumerate(year_categories):
        subset = df[df[category_column] == yr]
        n_question = len(subset)
        classes = subset[classes_column].fillna("NA")
        class_counts = classes.value_counts().reindex(category_order_mod, fill_value=0)
        class_percentages = class_counts / class_counts.sum() * 100

        bottom = 0
        for cat in category_order_mod:
            perc = class_percentages[cat]
            # skip empty segments so they do not produce zero-height artists
            if perc == 0:
                continue
            ax.bar(
                x_positions[i],
                perc,
                bottom=bottom,
                width=bar_width,
                color=color_mapping[cat],
                # label only once per category to avoid duplicate legend entries
                label=cat if i == 0 else None,
            )

            # always percentage label, centered in segment
            # (skipped for segments below 5% where the text would not fit)
            if perc > 5:
                ax.text(
                    x_positions[i],
                    bottom + perc / 2,
                    f"{perc:.1f}%",
                    ha="center",
                    va="center",
                    fontsize=fontsize_axes_labels,
                    # black text on light segments, white on dark ones
                    color="black" if _rgb2gray(color_mapping[cat]) > 0.5 else "white",
                )
            bottom += perc

        # number of responses annotated just above each bar
        ax.text(
            x_positions[i],
            bottom + 1,
            f"N = {n_question}",
            ha="center",
            va="bottom",
            fontsize=fontsize,
        )

    ax.set_xticks(x_positions)
    ax.set_xticklabels(year_categories, fontsize=fontsize)
    ax.set_ylabel("Percentage", fontsize=fontsize)
    ax.set_xlabel(label_q_data, fontsize=fontsize)

    handles, labels_leg = ax.get_legend_handles_labels()
    fig.legend(
        handles,
        labels_leg,
        title=legend_title,
        loc="outside lower center",
        fontsize=fontsize,
        title_fontsize=fontsize_axes_labels,
    )

    return fig, ax

plot_stacked_bar_comparison(df1, df2, col_of_interest, plot_title, order, legend_loc, group_labels=('Group 1', 'Group 2'), colors=None, width=3, height=4, ax=None, n_y_pos=0.95)

Create a vertical stacked bar plot of mental health classes for two groups.

Parameters:

  • df1 (DataFrame) –

    DataFrame for group 1 (left bar).

  • df2 (DataFrame) –

    DataFrame for group 2 (right bar).

  • col_of_interest (str) –

    Data column to plot (stacked)

  • plot_title (str) –

    Heading for this plot

  • order (list[str]) –

    Ordered list of categories for stacked bars and legend.

  • legend_loc (Literal['top', 'right', 'bottom'] | None) –

    Position of the legend, or None for no legend.

  • group_labels (tuple[str, str], default: ('Group 1', 'Group 2') ) –

    Tuple of group labels.

  • colors (list[list[tuple[float, float, float]]] | None, default: None ) –

    Optional list of colors (auto-generated from helmholtzcolors if None).

  • width (float, default: 3 ) –

    Width of the plot.

  • height (float, default: 4 ) –

    Height of the plot.

  • ax (Axes | None, default: None ) –

    Axes to draw the plot on. Generate a new Axes if None (default).

  • n_y_pos (float, default: 0.95 ) –

    vertical position of the N labels (default: 95% plot height)

Returns:

  • tuple[Figure, Axes]

    Matplotlib figure and axis.

Source code in src/survey_framework/plotting/stacked.py
def plot_stacked_bar_comparison(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    col_of_interest: str,
    plot_title: str,
    order: list[str],
    legend_loc: Literal["top", "right", "bottom"] | None,
    group_labels: tuple[str, str] = ("Group 1", "Group 2"),
    colors: list[list[tuple[float, float, float]]] | None = None,
    width: float = 3,
    height: float = 4,
    ax: Axes | None = None,
    n_y_pos: float = 0.95,
) -> tuple[Figure, Axes]:
    """Create a vertical stacked bar plot of mental health classes for two groups.

    Args:
        df1: DataFrame for group 1 (left bar).
        df2: DataFrame for group 2 (right bar).
        col_of_interest: Data column to plot (stacked)
        plot_title: Heading for this plot
        order: Ordered list of categories for stacked bars and legend.
        legend_loc: Position of the legend, or None for no legend.
        group_labels: Tuple of group labels.
        colors: Optional list of colors (auto-generated from helmholtzcolors if None).
        width: Width of the plot.
        height: Height of the plot.
        ax: Axes to draw the plot on. Generate a new Axes if None (default).
        n_y_pos: vertical position of the N labels (default: 95% plot height)

    Returns:
        Matplotlib figure and axis.
    """
    hc.set_plotstyle()
    # Use Helmholtz color palette if not provided
    if colors is None:
        # Create group-specific color palettes
        colors = [
            list(reversed(hc.get_blues(len(order)))),  # Group 1 (df1): blues
            list(reversed(hc.get_greens(len(order)))),  # Group 2 (df2): greens
        ]

    # Normalize counts to proportions
    def get_distribution(df: pd.DataFrame) -> "pd.Series[float]":
        return (
            df[f"{col_of_interest}"]
            .value_counts(normalize=True)
            .reindex(order, fill_value=0)
        )

    dist1 = get_distribution(df1)
    dist2 = get_distribution(df2)
    distributions = [dist1, dist2]

    if ax is None:
        fig, ax = plt.subplots(figsize=(width, height), dpi=300, layout="constrained")
    else:
        fig = cast(Figure, ax.figure)
    x_positions = [0, 1]

    for i, dist in enumerate(distributions):
        # Count valid responses (non-NaN) for each group
        n_group1 = df1[col_of_interest].notna().sum()
        n_group2 = df2[col_of_interest].notna().sum()

        # Add N labels inside each bar
        ax.text(
            x_positions[0],
            n_y_pos,
            f"N = {n_group1}",
            ha="center",
            va="center",
            # color="black",
            fontsize=9,
        )

        ax.text(
            x_positions[1],
            n_y_pos,
            f"N = {n_group2}",
            ha="center",
            va="center",
            # color="black",
            fontsize=9,
        )
        bottom = 0.0
        for j, category in enumerate(order):
            height_pct = dist[category]
            # plot a stacked bar
            ax.bar(
                x=x_positions[i],
                height=height_pct,
                bottom=bottom,
                color=colors[i][j],
                edgecolor="white",
                width=0.6,
                label=category if i == 0 else None,
            )
            if height_pct > 0.05:
                # put a percent label into the bar
                ax.text(
                    x_positions[i],
                    bottom + height_pct / 2,
                    f"{height_pct * 100:.1f}%",
                    ha="center",
                    va="center",
                    fontsize=9,
                    color="black" if _rgb2gray(colors[i][j]) > 0.5 else "white",
                )
            bottom += height_pct

    ax.set_xticks(x_positions)
    ax.set_xticklabels(group_labels)
    ax.set_ylim(0, 1)
    ax.set_ylabel("Percent")
    ax.yaxis.set_major_formatter(PercentFormatter(xmax=1.0))
    ax.set_title(f"{plot_title}")

    if legend_loc is not None:
        legend_handles = []
        for i, label in enumerate(order):
            patch1 = Patch(color=colors[0][i])
            patch2 = Patch(color=colors[1][i])
            legend_handles.append(((patch1, patch2), label))

        # Now create the legend with a custom handler
        match legend_loc:
            case "right":
                ax.legend(
                    handles=[h[0] for h in legend_handles][::-1],
                    labels=[h[1] for h in legend_handles][::-1],
                    handler_map={tuple: HandlerTuple(ndivide=None)},
                    loc="upper left",  # Align legend's top-left corner
                    bbox_to_anchor=(1.02, 1),  # Place it just outside top-right of plot
                    ncol=1,  # Vertical legend
                )
            case "bottom":
                ax.legend(
                    handles=[h[0] for h in legend_handles][::-1],
                    labels=[h[1] for h in legend_handles][::-1],
                    handler_map={tuple: HandlerTuple(ndivide=None)},
                    loc="upper center",  # yes, "upper" is intentional
                    bbox_to_anchor=(0.5, -0.1),  # legend below the plot
                    ncol=1,  # Spread legend entries in one row
                )
            case "top":
                ax.legend(
                    handles=[h[0] for h in legend_handles][::-1],
                    labels=[h[1] for h in legend_handles][::-1],
                    handler_map={tuple: HandlerTuple(ndivide=None)},
                    loc="lower center",  # yes, "lower" is intentional
                    bbox_to_anchor=(0.5, 1.03),  # legend above the plot
                    ncol=1,  # Spread legend entries in one row
                )

    return fig, ax

plot_stacked_bar_single(df1, col_of_interest, plot_title, order, legend_loc, colors=None, width=3, height=4, ax=None, n_y_pos=0.95)

Create a vertical stacked bar plot of mental health classes for one group.

NOTE: This is a quick-fix. It might make sense to merge it with histplot.py/simple_histplot() where stacking should be implemented anyway

Parameters:

  • df1 (DataFrame) –

    DataFrame for group 1 (left bar).

  • col_of_interest (str) –

    Data column to plot (stacked)

  • plot_title (str) –

    Heading for this plot

  • order (list[str]) –

    Ordered list of categories for stacked bars and legend.

  • legend_loc (Literal['top', 'right', 'bottom'] | None) –

    Position of the legend, or None for no legend.

  • colors (list[list[tuple[float, float, float]]] | None, default: None ) –

    Optional list of colors (auto-generated from helmholtzcolors if None).

  • width (float, default: 3 ) –

    Width of the plot.

  • height (float, default: 4 ) –

    Height of the plot.

  • ax (Axes | None, default: None ) –

    Axes to draw the plot on. Generate a new Axes if None (default).

  • n_y_pos (float, default: 0.95 ) –

    vertical position of the N labels (default: 95% plot height)

Returns:

  • tuple[Figure, Axes]

    Matplotlib figure and axis.

Source code in src/survey_framework/plotting/stacked.py
def plot_stacked_bar_single(
    df1: pd.DataFrame,
    col_of_interest: str,
    plot_title: str,
    order: list[str],
    legend_loc: Literal["top", "right", "bottom"] | None,
    colors: list[list[tuple[float, float, float]]] | None = None,
    width: float = 3,
    height: float = 4,
    ax: Axes | None = None,
    n_y_pos: float = 0.95,
) -> tuple[Figure, Axes]:
    """Create a vertical stacked bar plot of mental health classes for one group.

    NOTE: This is a quick-fix. It might make sense to merge it with
    histplot.py/simple_histplot() where stacking should be implemented anyway

    Args:
        df1: DataFrame for group 1 (left bar).
        col_of_interest: Data column to plot (stacked)
        plot_title: Heading for this plot
        order: Ordered list of categories for stacked bars and legend.
        legend_loc: Position of the legend, or None for no legend.
        colors: Optional list of colors (auto-generated from helmholtzcolors if None).
        width: Width of the plot.
        height: Height of the plot.
        ax: Axes to draw the plot on. Generate a new Axes if None (default).
        n_y_pos: vertical position of the N labels (default: 95% plot height)

    Returns:
        Matplotlib figure and axis.
    """
    hc.set_plotstyle()
    # Use Helmholtz color palette if not provided
    if colors is None:
        colors = [
            list(reversed(hc.get_blues(len(order)))),  # Group 1 (df1): blues
        ]

    # Normalize counts to proportions, in the requested category order
    dist = (
        df1[col_of_interest]
        .value_counts(normalize=True)
        .reindex(order, fill_value=0)
    )

    if ax is None:
        fig, ax = plt.subplots(figsize=(width, height), dpi=300, layout="constrained")
    else:
        fig = cast(Figure, ax.figure)
    x_pos = 0

    # Add the N label (count of valid, non-NaN responses) inside the bar
    n_group1 = df1[col_of_interest].notna().sum()
    ax.text(
        x_pos,
        n_y_pos,
        f"N = {n_group1}",
        ha="center",
        va="center",
        fontsize=9,
    )

    bottom = 0.0
    for j, category in enumerate(order):
        height_pct = dist[category]
        # plot one stacked segment
        ax.bar(
            x=x_pos,
            height=height_pct,
            bottom=bottom,
            color=colors[0][j],
            edgecolor="white",
            width=0.6,
            label=category,
        )
        if height_pct > 0.05:
            # put a percent label into the segment; black text on light
            # colors, white on dark ones
            ax.text(
                x_pos,
                bottom + height_pct / 2,
                f"{height_pct * 100:.1f}%",
                ha="center",
                va="center",
                fontsize=9,
                color="black" if _rgb2gray(colors[0][j]) > 0.5 else "white",
            )
        bottom += height_pct

    ax.set_xticks([x_pos])
    ax.set_xticklabels([""])
    ax.set_ylim(0, 1)
    ax.set_ylabel("Percent")
    ax.yaxis.set_major_formatter(PercentFormatter(xmax=1.0))
    ax.set_title(f"{plot_title}")

    if legend_loc is not None:
        # legend entries reversed so they read top-to-bottom like the bar
        handles = [Patch(color=colors[0][i]) for i in range(len(order))][::-1]
        labels = list(order)[::-1]

        # map the requested legend position to matplotlib placement args
        placement = {
            # Align legend's top-left corner just outside top-right of plot
            "right": {"loc": "upper left", "bbox_to_anchor": (1.02, 1)},
            # "upper" is intentional: anchor the legend's top below the plot
            "bottom": {"loc": "upper center", "bbox_to_anchor": (0.5, -0.1)},
            # "lower" is intentional: anchor the legend's bottom above the plot
            "top": {"loc": "lower center", "bbox_to_anchor": (0.5, 1.03)},
        }[legend_loc]
        ax.legend(handles=handles, labels=labels, ncol=1, **placement)

    return fig, ax

Special Plots

Survival Plots. Typically used to visualize change over time.

plot_survival_plot(df, category=None, ticks=None, tick_map=str, legend_replace=None, legend_title=None, colors=None, width=6, height=4)

Plots the given DataFrame as a survival plot, approaching zero.

Parameters:

  • df (DataFrame) –

    DataFrame with a column of numerical data called "data".

  • category (str | None, default: None ) –

    Column in df to categorize the data.

  • ticks (Iterable[int] | None, default: None ) –

    Iterable of x axis ticks.

  • tick_map (Callable[[int], str], default: str ) –

    Function to generate strings from ticks.

  • legend_replace (dict[str, str] | None, default: None ) –

    Replacements for legend entries.

  • legend_title (str | None, default: None ) –

    Heading for the legend.

  • colors (list[tuple[float, float, float]] | None, default: None ) –

    Line colors, instead of shades of blue.

  • width (int, default: 6 ) –

    Horizontal figure size.

  • height (int, default: 4 ) –

    Vertical figure size.

Returns:

  • tuple[Figure, Axes]

    The matplotlib figure and axes.

Source code in src/survey_framework/plotting/survivalplot.py
def plot_survival_plot(
    df: pd.DataFrame,
    category: str | None = None,
    ticks: Iterable[int] | None = None,
    tick_map: Callable[[int], str] = str,
    legend_replace: dict[str, str] | None = None,
    legend_title: str | None = None,
    colors: list[tuple[float, float, float]] | None = None,
    width: int = 6,
    height: int = 4,
) -> tuple[Figure, Axes]:
    """Plots the given DataFrame as a survival plot, approaching zero.

    Args:
        df: DataFrame with a column of numerical data called "data".
        category: Column in `df` to categorize the data.
        ticks: Iterable of x axis ticks (materialized once; may be a generator).
        tick_map: Function to generate strings from ticks.
        legend_replace: Replacements for legend entries.
        legend_title: Heading for the legend.
        colors: Line colors, instead of shades of blue.
        width: Horizontal figure size.
        height: Vertical figure size.

    Returns:
        The matplotlib figure and axes.
    """
    if legend_replace is None:
        legend_replace = {}
    set_plotstyle()

    figure, ax = plt.subplots(dpi=300, figsize=(width, height), layout="constrained")

    # default palette: one shade of blue per category
    if not colors and category:
        n_colors = len(pd.unique(df[category]))
        colors = get_blues(n_colors)

    sns.ecdfplot(
        data=df,
        ax=ax,
        x="data",  # hard-coded for simplicity
        hue=category,
        stat="percent",
        complementary=True,  # survival function: 1 - ECDF
        palette=colors,
    )

    if ticks:
        # materialize first: `ticks` may be a one-shot iterator, and it is
        # needed twice (once for the positions, once for the labels)
        tick_positions = list(ticks)
        ax.set_xticks(tick_positions, [tick_map(t) for t in tick_positions])

    legend = ax.get_legend()
    if legend:
        # place the number of participants behind each category
        counts = df.groupby(category, observed=False).count()

        for tt in legend.get_texts():
            label = tt.get_text()
            # get number of participants...
            try:
                # ...for string keys
                group_n = counts.loc[label]["data"]
            except KeyError:
                # ...for integer keys (legend labels are always strings)
                group_n = counts.loc[int(label)]["data"]
            # also replace the label if it's in the replacement dictionary
            replacement = legend_replace.get(label, label)

            new_label = "\n".join(wrap(f"{replacement} ({group_n})", 25))
            tt.set_text(new_label)
        if legend_title is not None:
            legend.set_title(legend_title)
    else:
        # place the number of participants in the top right corner
        ax.text(
            0.99,
            0.99,
            f"N = {len(df)}",
            ha="right",
            va="top",
            transform=ax.transAxes,
        )

    return figure, ax

Likert Plots (typically used for data on a 5-point scale).

plot_likertplot(survey, data_df, question, order, bar_labels=BarLabels.PERCENT, width=6, height=4, percent_cutoff=8, text_wrap=30, relabel_subquestions=True)

Plot the given data as a Likert plot.

Parameters:

  • survey (LimeSurveyData) –

    the LimeSurvey object

  • data_df (DataFrame) –

    dataframe containing answers to be plotted

  • question (str) –

    question code (e.g. 'D4')

  • order (list[str]) –

    ordered list of answer options. The rest will be dropped!

  • bar_labels (BarLabels, default: PERCENT ) –

    which kind of bar labels to use.

  • width (float, default: 6 ) –

    width of the figure.

  • height (float, default: 4 ) –

    height of the figure.

  • percent_cutoff (int, default: 8 ) –

    If groups are smaller than x percent, they don't get a label.

  • text_wrap (int, default: 30 ) –

    wrap question labels after x characters.

  • relabel_subquestions (bool, default: True ) –

    Whether to rewrite y axis labels using the question data.

Returns:

  • tuple[Figure, Axes]

    The matplotlib figure and axis

Source code in src/survey_framework/plotting/likertplot.py
def plot_likertplot(
    survey: LimeSurveyData,
    data_df: pd.DataFrame,
    question: str,
    order: list[str],
    bar_labels: BarLabels = BarLabels.PERCENT,
    width: float = 6,
    height: float = 4,
    percent_cutoff: int = 8,
    text_wrap: int = 30,
    relabel_subquestions: bool = True,
) -> tuple[Figure, Axes]:
    """Plot the given data as a Likert plot.

    Args:
        survey: the LimeSurvey object
        data_df: dataframe containing answers to be plotted
        question: question code (e.g. 'D4')
        order: ordered list of answer options. **The rest will be dropped!**
        bar_labels: which kind of bar labels to use.
        width: width of the figure.
        height: height of the figure.
        percent_cutoff: If groups are smaller than x percent, they don't get a label.
        text_wrap: wrap question labels after x characters.
        relabel_subquestions: Whether to rewrite y axis labels using the question data.

    Returns:
        The matplotlib figure and axis
    """
    assert "id" not in data_df.columns
    set_plotstyle()
    colors = palette[len(order)]

    fig, ax = plt.subplots(dpi=300, figsize=(width, height), layout="constrained")

    # remove values not present in order
    drop = set(survey.get_choices(question).keys()).difference(order)
    dropped_df = data_df.map(lambda d: d if d not in drop else None)

    # use external library to actually draw the plot
    # silence FutureWarnings (already fixed upstream, not yet in PyPI)
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=FutureWarning)
        ax = _likert(dropped_df, order, colors=colors, ax=ax)

    # set the title (overarching question)
    title = survey.questions.loc[survey.questions["question_group"] == question][
        "question_label"
    ].unique()
    assert len(title) == 1, "Multiple question_labels found, check data correctness."
    # ax.set_title(title[0])

    # set subquestion labels (y ticks)
    if relabel_subquestions:
        new_labels = []
        for old_label in ax.get_yticklabels():
            label = cast(str, survey.questions.loc[old_label.get_text()]["label"])
            clean_str = label.replace("/", " / ")
            new_labels.append("\n".join(wrap(clean_str, text_wrap, max_lines=3)))
        ax.set_yticklabels(new_labels, linespacing=0.9)

    # reposition legend (draw new legend and remove the old one)
    (handles, labels) = ax.get_legend_handles_labels()
    legend = fig.legend(
        handles, labels, loc="outside upper center", ncol=2 if len(handles) == 4 else 3
    )
    old_legend = ax.get_legend()
    if old_legend is not None:
        old_legend.remove()

    # set the legend labels
    choices = survey.questions.loc[survey.questions["question_group"] == question][
        "choices"
    ].iloc[0]  # underlying assumption: all subquestions use the same scale
    for text in legend.get_texts():
        text.set_text(choices[text.get_text()])

    # add number of participants
    n_question = data_df.count().iloc[0]  # don't count NaNs, but count dropped answers
    plt.text(
        0.99, 0.99, f"N = {n_question}", ha="right", va="top", transform=ax.transAxes
    )

    def cutoff_fmt(x: float) -> str:
        """Formatter for bar labels, with a 5% cutoff (no label for small bars)."""
        percentage = x * 100 / n_question
        if percentage < percent_cutoff:
            return ""
        match bar_labels:
            case BarLabels.PERCENT:
                return f"{percentage:.1f}%"
            case BarLabels.COUNT:
                return f"{x:g}"
            case BarLabels.NONE:
                return ""
            case _:
                raise AssertionError("unreachable")

    # set bar labels
    for bc in ax.containers[1:]:
        ax.bar_label(
            container=cast(BarContainer, bc),
            label_type="center",
            weight="bold",
            fontsize="7",
            color="white",
            # path_effects=[PathEffects.withStroke(linewidth=1, foreground='black')],
            fmt=cutoff_fmt,
        )

    return fig, ax

A heatmap, used to visualize correlation strength.

CorrMethod

Bases: StrEnum

Correlation Method used by Pandas.

plot_heatmap(df, survey, width=6.5, height=6, method=CorrMethod.SPEARMAN)

Correlation heatmap of the input dataframe vs. all (mental) health scores.

Parameters:

  • df (DataFrame) –

    Dataframe with numeric columns that should be correlated against health

  • survey (LimeSurveyData) –

    main survey object

  • width (float, default: 6.5 ) –

    Horizontal figure size.

  • height (float, default: 6 ) –

    Vertical figure size.

  • method (CorrMethod, default: SPEARMAN ) –

    Statistical correlation method.

Returns:

  • tuple[Figure, Axes]

    tuple of matplotlib figure and axes for the heatmap

Source code in src/survey_framework/plotting/heatmap.py
def plot_heatmap(
    df: pd.DataFrame,
    survey: LimeSurveyData,
    width: float = 6.5,
    height: float = 6,
    method: CorrMethod = CorrMethod.SPEARMAN,
) -> tuple[Figure, Axes]:
    """Correlation heatmap of the input dataframe vs. all (mental) health scores.

    Args:
        df: Dataframe with numeric columns that should be correlated against health
        survey: main survey object
        width: Horizontal figure size.
        height: Vertical figure size.
        method: Statistical correlation method.

    Returns:
        tuple of matplotlib figure and axes for the heatmap
    """
    # question codes of the somatic-symptom and burnout blocks
    SOMATIC = "D4"
    BURNOUT = "D3d"

    # health scores
    sta = rate_mental_health(
        survey.get_responses(Condition.STATE_ANXIETY), Condition.STATE_ANXIETY
    )
    tra = rate_mental_health(
        survey.get_responses(Condition.TRAIT_ANXIETY), Condition.TRAIT_ANXIETY
    )
    depr = rate_mental_health(
        survey.get_responses(Condition.DEPRESSION), Condition.DEPRESSION
    )
    somatic = rate_somatic(survey.get_responses(SOMATIC))
    # NOTE: burnout subscales (Exhaustion, Cynicism, Professional Efficacy) are
    # computed but currently not included in the heatmap columns below.
    _bout = rate_burnout(survey.get_responses(BURNOUT)).set_index("id")

    correlations = pd.DataFrame(
        {
            "State Anxiety": df.corrwith(
                sta["state_anxiety_score"], method=method.value
            ),
            "Trait Anxiety": df.corrwith(
                tra["trait_anxiety_score"], method=method.value
            ),
            "Depression": df.corrwith(depr["depression_score"], method=method.value),
            "Somatic Symptoms": df.corrwith(
                somatic["somatic_score"], method=method.value
            ),
        }
    )

    # order rows by their correlation with state anxiety for readability
    correlations.sort_values(by="State Anxiety", inplace=True)

    hc.set_plotstyle()
    figure, ax = plt.subplots(dpi=300, figsize=(width, height), layout="constrained")

    ax = sns.heatmap(correlations, annot=True, ax=ax)
    # slanted x labels so long score names don't overlap
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

    return figure, ax

A very simple, but versatile histogram plot.

This basically produces bar plots, but can add extras like a density curve.

simple_histplot(data_df, question_code, order_dict, hue_series=None, hue_order=None, kde=False, log_scale=False, binwidth=None, width=10, height=6, bar_labels=BarLabels.NONE)

Plot a histogram of values in data_df[question_code].

Currently experimental; may later be expanded to support stacked bar plots.

Parameters:

  • data_df (DataFrame) –

    DataFrame to be plotted

  • question_code (str) –

    Column in data_df to be plotted

  • order_dict (dict[str, list[str]]) –

    answer ordering, can be empty (ORDER from order/order2024.py)

  • hue_series (Series[str] | None, default: None ) –

    Separator for data_df (needs same index).

  • hue_order (Sequence[str] | None, default: None ) –

    How to sort hues in the legend and plot

  • kde (bool, default: False ) –

    Whether to plot a density curve.

  • log_scale (bool, default: False ) –

    whether the y axis should be log-scaled.

  • binwidth (int | None, default: None ) –

    Width of bins; automatically inferred if not given.

  • width (float, default: 10 ) –

    Plot width.

  • height (float, default: 6 ) –

    Plot height.

  • bar_labels (BarLabels, default: NONE ) –

    How to label each bar (NONE by default, or PERCENT)

Returns:

  • tuple[Figure, Axes]

    New figure and axes of the histogram

Source code in src/survey_framework/plotting/histplot.py
def simple_histplot(
    data_df: DataFrame,
    question_code: str,
    order_dict: dict[str, list[str]],
    hue_series: "Series[str] | None" = None,
    hue_order: Sequence[str] | None = None,
    kde: bool = False,
    log_scale: bool = False,
    binwidth: int | None = None,
    width: float = 10,
    height: float = 6,
    bar_labels: BarLabels = BarLabels.NONE,
) -> tuple[Figure, Axes]:
    """Plot a histogram of values in `data_df[question_code]`.

    Currently in experimental state, to be expanded for stacked barplots?

    Args:
        data_df: DataFrame to be plotted
        question_code: Column in `data_df` to be plotted
        order_dict: answer ordering, can be empty (ORDER from order/order2024.py)
        hue_series: Separator for data_df (needs same index).
        hue_order: How to sort hues in the legend and plot
        kde: Whether to plot a density curve.
        log_scale: whether the y axis should be log-scaled.
        binwidth: Width of bins; automatically inferred if not given.
        width: Plot width.
        height: Plot height.
        bar_labels: How to label each bar (NONE by default, or PERCENT)

    Returns:
        New figure and axes of the histogram
    """
    orderlist = order_dict.get(question_code)
    if orderlist:
        data_df[question_code] = pd.Categorical(
            data_df[question_code], categories=orderlist, ordered=True
        )

    hc.set_plotstyle()
    figure, ax = plt.subplots(dpi=300, figsize=(width, height), layout="constrained")

    if hue_series is not None:
        df = data_df.join(hue_series, on="id")
        hue = hue_series.name
        assert isinstance(hue, str)

        ax = sns.histplot(
            data=df,
            ax=ax,
            x=question_code,
            hue=hue,
            hue_order=hue_order,
            palette=[hc.helmholtzblue, hc.helmholtzgreen],
            stat="percent",
            common_norm=False,
            multiple="dodge",
            shrink=0.8,
            binwidth=binwidth,
            kde=kde,
            log_scale=log_scale,
        )

        # remove legend title, add group sizes behind labels
        legend = ax.get_legend()
        assert legend is not None, "Legend is guaranteed to exist"
        legend.set_title("")
        counts = df.groupby(hue).count()[question_code].to_dict()
        for text in legend.texts:
            name = text.get_text()
            text.set_text(f"{name} ({counts[name]})")
    else:
        ax = sns.histplot(
            data=data_df,
            ax=ax,
            x=question_code,
            stat="percent",
            color=hc.helmholtzblue,
            shrink=0.8,
            binwidth=binwidth,
            kde=kde,
            log_scale=log_scale,
        )

    match bar_labels:
        case BarLabels.NONE:
            pass
        case BarLabels.COUNT:
            raise ValueError("count bar labels not supported on histogram plots")
        case BarLabels.PERCENT:
            ax.bar_label(cast(BarContainer, ax.containers[0]), fmt="{:.1f}%")

    return figure, ax

Sankey Plot -- visualizes the "flow" of participants between questions.

plot_sankey(data_df, titles=None, title='', width=6, height=8, fontsize=None, plot_fractions=True)

Plots a two-stage Sankey diagram.

Parameters:

  • data_df (DataFrame) –

    Data containing rows like (label_left, count, label_right, same_count)

  • titles (list[str] | None, default: None ) –

    Titles of both stages.

  • title (str, default: '' ) –

    deprecated, unused

  • width (float, default: 6 ) –

    Total plot width.

  • height (float, default: 8 ) –

    Total plot height.

  • fontsize (int | None, default: None ) –

    Font size for the plot

  • plot_fractions (bool, default: True ) –

    Whether group sizes should be displayed

Returns:

  • Figure

    New Figure and Axes

Source code in src/survey_framework/plotting/sankeyplots.py
def plot_sankey(
    data_df: pd.DataFrame,
    titles: list[str] | None = None,
    title: str = "",
    width: float = 6,
    height: float = 8,
    fontsize: int | None = None,
    plot_fractions: bool = True,
) -> Figure:
    """Plots a two staged sankey diagram.

    Args:
        data_df: Data containing rows like
            (label_left, count, label_right, same_count)
        titles: Titles of both stages.
        title: _deprecated, unused_
        width: Total plot width.
        height: Total plot height.
        fontsize: Font size for the plot
        plot_fractions: Whether group sizes should be displayed

    Returns:
        New Figure and Axes
    """
    set_plotstyle()

    # Assign one palette color to each distinct label on the left side.
    hex_palette = sns.color_palette("Paired").as_hex()
    color_dict = {
        label: hex_palette[idx] for idx, label in enumerate(data_df[0].unique())
    }

    # Create the figure and draw the diagram onto it.
    fig, ax = plt.subplots(dpi=300, figsize=(width, height), layout="constrained")

    sky.sankey(
        data_df,
        ax=ax,
        sort="none",  # keep dataframe ordering
        titles=titles,
        valign="center",
        color_dict=color_dict,
        node_gap=0.02,
        frame_gap=0,  # remove whitespace above/below
        value_loc="both" if plot_fractions else "none",
        value_thresh_ofmax=0.01,  # prevent label clashes
        fontsize=fontsize,
    )

    return fig