给定一个 Polars DataFrame

data = pl.DataFrame({"user_id": [1, 1, 1, 1, 1, 2, 2, 2, 2], "event": [False, True, True, False, True, True, True, False, False]

我希望计算一列event_chain,用于统计用户发生事件的次数,其中前 4 行中的任何一行也发生过事件。每次发生新事件时,如果用户已经有一个连续事件,则连续计数器会递增,如果接下来的 4 行没有发生其他事件,则应将其设置为零

用户身份 事件 事件链 价值理由
1 错误的 0 还没有活动
1 真的 0 最后 4 行(不包括当前行)没有事件
1 真的 1 此行有 1 个事件,最后 4 行有 1 个事件
1 错误的 1 由于接下来的 4 行内有事件,因此不会重置为 0
1 真的 2 如果发生此行或发生最后 4 行,则增加连胜次数
2 真的 0 之前没有活动
2 真的 1 事件此行和最后 4 行的用户
2 错误的 0 此行无事件,且用户接下来 4 行无事件,重置为 0
2 错误的 0

我有如下工作代码来执行此操作,但我认为应该有一种更简洁的方法来执行此操作

        data.with_columns(
         rows_since_last_event=pl.int_range(pl.len()).over("user_id")
          - pl.when("event").then(pl.int_range(pl.len())).forward_fill()
          .over("user_id"),
          rows_till_next_event=pl.when("event").then(pl.int_range(pl.len()))
          .backward_fill().over("user_id") - pl.int_range(pl.len()).over("athlete_id")
         )
        .with_columns(
            chain_event=pl.when(
                pl.col("event")
                .fill_null(0)
                .rolling_sum(window_size=4, min_periods=1)
                .over("user_id")
                - pl.col("event").fill_null(0)
                > 0
            )
            .then(1)
            .otherwise(0)
        )
        .with_columns(
            chain_event_change=pl.when(
                pl.col("chain_event").eq(1),
                pl.col("chain_event").shift().eq(0),
                pl.col("rows_since_last_event").fill_null(5) > 3,
            )
            .then(1)
            .when(
                pl.col("congested_event").eq(0),
                pl.col("congested_event").shift().eq(1),
                pl.col("rows_till_next_event").fill_null(5) > 3,
            )
            .then(1)
            .otherwise(0)
        )
        .with_columns(
            chain_event_identifier=pl.col("chain_event_change")
            .cum_sum()
            .over("user_id")
        )
        .with_columns(
            event_chain=pl.col("chain_event")
            .cum_sum()
            .over("user_id", "chain_event_identifier")
        )
    )

4

  • 我不明白为什么最后两行有event_chain = 0。难道它们不应该有吗,event_chain = 1因为前面的行user_id = 2发生了事件?


    – 


  • event_change如果用户在接下来的 4 行中没有事件,并且当前行不是事件,则应设置为 0。我知道围绕这一点设置的规则有点复杂


    – 

  • 好的,是next4 行还是previous原始问题中写的 4 行?也许你可以在每一行添加一条注释,解释为什么分配了特定的数字


    – 

  • 两者都有,我现在要补充的是


    – 


最佳答案
2

您可以使用+来计算事件

df = (
   df.with_columns(
      previous = 
         pl.when(event = True) 
           .then(pl.sum_horizontal(pl.col.event.shift(N) for N in range(1, 5))
           .over("user_id")),
      next = 
         pl.when(event = False)
           .then(pl.sum_horizontal(pl.col.event.shift(-N) for N in range(1, 5))
           .over("user_id"))
   )
)
shape: (9, 4)
┌─────────┬───────┬──────────┬──────┐
│ user_id ┆ event ┆ previous ┆ next
│ ---     ┆ ---   ┆ ---      ┆ ---  │
│ i64     ┆ bool  ┆ u32      ┆ u32  │
╞═════════╪═══════╪══════════╪══════╡1       ┆ false ┆ null     ┆ 31       ┆ true  ┆ 0        ┆ null │1       ┆ true  ┆ 1        ┆ null │1       ┆ false ┆ null     ┆ 11       ┆ true  ┆ 2        ┆ null │2       ┆ true  ┆ 0        ┆ null │2       ┆ true  ┆ 1        ┆ null │2       ┆ false ┆ null     ┆ 02       ┆ false ┆ null     ┆ 0
└─────────┴───────┴──────────┴──────┘

那么看起来你想向前填充,event = False 除非 next = 0

df.with_columns(
   pl.when(event = False, next = 0)
     .then(0)
     .when(event = False)
     .then(pl.col.previous.forward_fill().over("user_id"))
     .otherwise("previous")
     .fill_null(0)
     .alias("event_chain")
)
shape: (9, 5)
┌─────────┬───────┬──────────┬──────┬─────────────┐
│ user_id ┆ event ┆ previous ┆ next ┆ event_chain │
│ ---     ┆ ---   ┆ ---      ┆ ---  ┆ ---         │
│ i64     ┆ bool  ┆ u32      ┆ u32  ┆ u32         │
╞═════════╪═══════╪══════════╪══════╪═════════════╡1       ┆ false ┆ null     ┆ 301       ┆ true  ┆ 0        ┆ null ┆ 01       ┆ true  ┆ 1        ┆ null ┆ 11       ┆ false ┆ null     ┆ 111       ┆ true  ┆ 2        ┆ null ┆ 22       ┆ true  ┆ 0        ┆ null ┆ 02       ┆ true  ┆ 1        ┆ null ┆ 12       ┆ false ┆ null     ┆ 002       ┆ false ┆ null     ┆ 00
└─────────┴───────┴──────────┴──────┴─────────────┘

由于它有点复杂,可能需要使用中间变量/函数来构建表达式。

例如

window_size = 4
window = pl.Series(range(1, window_size + 1))

event_count = lambda window: pl.sum_horizontal(map(pl.col.event.shift, window))
event_count_previous = lambda: event_count(window)
event_count_next = lambda: event_count(-window)

waiting_for_first_event = pl.col.event.cast(int).cum_max() == 0

df.with_columns(event_chain = 
   pl.when(event=True)
     .then(event_count_previous())
     .when(waiting_for_first_event)
     .then(0)
     .otherwise(event_count_next())
     .over("user_id")
)
  • (使用略有不同的方法将其压缩为单个 when/then 链)

3

  • 1
    简洁明了,我喜欢您预先计算计数器的方式,同时检查它是否真的要与移位总和一起使用。您可能也可以使用.rolling_sum()它,但您需要.reverse()计算下一行。


    – 

  • 不应该waiting_for_first_event是 over(“user_id”) 吗?这会引发错误window expression not allowed in aggregation,但如果没有这个,waiting_for_first_event当任何用户遇到第一个事件时,每个用户都会变为 true,对吗?


    – 

  • 1
    @cdkdrf 末尾的 overpl.when().then().otherwise().over()将应用于每个 then/otherwise 表达式。它本质上是.over()每个分支内重复的简写语法。我无法复制任何错误 – 它对我来说运行良好。


    – 


更新版本
我看了@jqurious 的回答,我认为你可以让它更简洁

  • 在检查前 N 行时预先计算计数器。我们只需要sum前几行,对于后几行,我们只需要知道它们是否存在,这样max就足够了。
  • 还请注意,我们使用大小为 5 的窗口(包括当前行),因此我们不需要“开始”事件的特殊情况。
(
    data
    .with_columns(
        chain_event =
           pl.sum_horizontal(pl.col.event.shift(i) for i in range(5))
             .over('user_id'),
        next =
           pl.max_horizontal(pl.col.event.shift(-i) for i in range(1,5))
             .over('user_id').fill_null(False)
    ).with_columns(
        pl
        .when(event = False, next = False).then(0)
        .when(event = False, chain_event = 0).then(0)
        .otherwise(pl.col.chain_event - 1)
        .alias('chain_event')
        # or even shorter but a bit more cryptic
        # pl
        # .when(event = False, next = False).then(0)
        # .otherwise(pl.col.chain_event - pl.col.event)
        # .alias('chain_event')
    )
)

┌─────────┬───────┬──────┬───────┬─────────────┐
│ user_id ┆ event ┆ prev ┆ next  ┆ chain_event │
│ ---     ┆ ---   ┆ ---  ┆ ---   ┆ ---         │
│ i64     ┆ bool  ┆ u32  ┆ bool  ┆ i64         │
╞═════════╪═══════╪══════╪═══════╪═════════════╡1       ┆ false ┆ 0    ┆ true  ┆ 01       ┆ true  ┆ 1    ┆ true  ┆ 01       ┆ true  ┆ 2    ┆ true  ┆ 11       ┆ false ┆ 2    ┆ true  ┆ 11       ┆ true  ┆ 3    ┆ false ┆ 22       ┆ true  ┆ 1    ┆ true  ┆ 02       ┆ true  ┆ 2    ┆ false ┆ 12       ┆ false ┆ 2    ┆ false ┆ 02       ┆ false ┆ 2    ┆ false ┆ 0
└─────────┴───────┴──────┴───────┴─────────────┘

先前版本

  • 获取前 4 行和后 4 行。
  • 所以我们知道这些窗口内是否有事件发生。
  • 创建连续的事件组,以便我们可以重新启动计数器。
  • 增加计数器。
  • 仅考虑有事件的群组。

基本上,这里最重要的部分是,如果满足以下任一条件,我们就将行视为在链内:

  • 当前行有事件。
  • 前 4 行内有事件(否则我们已经重新启动计数器)并且接下来 4 行内有事件(否则我们将重置计数器)。
(
    data
    .with_columns(
        pl.max_horizontal(pl.col("event").shift(i + 1).over('user_id') for i in range(4)).alias("max_lag").fill_null(False),
        pl.max_horizontal(pl.col("event").shift(-i - 1).over('user_id') for i in range(4)).alias("max_lead").fill_null(False)
    ).with_columns(
        event_chain = (pl.col("max_lag") & pl.col("max_lead")) | pl.col('event')
    ).select(
        pl.col('user_id','event'),
        pl.when(pl.col('event_chain'))
        .then(
            pl.col('event').cum_sum().over('user_id', pl.col('event_chain').rle_id().over('user_id')) - 1
        ).otherwise(0)
        .alias('event_chain')
    )
)

┌─────────┬───────┬─────────────┐
│ user_id ┆ event ┆ event_chain │
│ ---     ┆ ---   ┆ ---         │
│ i64     ┆ bool  ┆ i64         │
╞═════════╪═══════╪═════════════╡1       ┆ false ┆ 01       ┆ true  ┆ 01       ┆ true  ┆ 11       ┆ false ┆ 11       ┆ true  ┆ 22       ┆ true  ┆ 02       ┆ true  ┆ 12       ┆ false ┆ 02       ┆ false ┆ 0
└─────────┴───────┴─────────────┘

或者

  • 计算前 4 行中是否有事件
  • 与计算接下来 4 行内是否有事件相同
(
    data
    .with_columns(
        (pl.col('event').cast(pl.Int32).shift(1).rolling_max(4, min_periods=0)).over('user_id').fill_null(0).alias('max_lag'),
        (pl.col('event').reverse().cast(pl.Int32).shift(1).rolling_max(4, min_periods=0).reverse()).over('user_id').fill_null(0).alias('max_lead')
    ).with_columns(
        event_chain = ((pl.col("max_lag") == 1) & (pl.col("max_lead") == 1)) | pl.col('event')
    ).select(
        pl.col('user_id','event'),
        pl.when(pl.col('event_chain'))
        .then(
            pl.col('event').cum_sum().over('user_id', pl.col('event_chain').rle_id().over('user_id')) - 1
        ).otherwise(0)
        .alias('event_chain')
    )
)