Other notes for this book: see the reading-notes index for Fast Data Processing with Spark 2 (Third Edition).

The Workhorse of Data Analysis: Datasets/DataFrames

Overview of Datasets

In Spark, a Dataset is a collection of columns of various types, much like an Excel spreadsheet or a table in a relational database. It can be used for type checking and for expressing semantic queries.

In R and Python you still work with the DataFrame class, which carries all of the Dataset APIs. In other words, in Python and R a Dataset is simply called a DataFrame.

In Scala and Java you use the Dataset interface, and there is no separate DataFrame class (in Scala, DataFrame is just an alias for Dataset[Row]).
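
For example, in the PySpark shell everything comes back as a DataFrame (a minimal sketch):

>>> df = spark.range(3)   # a trivial DataFrame with a single 'id' column

>>> type(df)
<class 'pyspark.sql.dataframe.DataFrame'>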

Dataset API Overview

(Figure: Dataset APIs overview)

The data files used in the examples below can be downloaded from:

https://github.com/xsankar/fdps-v3

Common Dataset interfaces and functions

Read and write operations

The CSV file read in this example looks like this:

$ cat data/spark-csv/cars.csv
year,make,model,comment,blank
2012,Tesla,S,No comment,
1997,Ford,E350,Go get one now they are going fast,
2015,Chevy,Volt,,
2016,Volvo,XC90,Good Car !,

Reading and saving

# -*- coding: utf-8 -*-

from pyspark.sql import SparkSession
import os

spark = SparkSession.builder.appName("Datasets APIs Demo")\
    .master("local")\
    .config("spark.logConf", "true")\
    .config("spark.logLevel", "ERROR")\
    .getOrCreate()
print(f"Running Spark Version {spark.version}")

filePath = "."
cars = spark.read.option("header", "true")\
        .option("inferSchema", "true")\
        .csv(os.path.join(filePath, "data/spark-csv/cars.csv"))

print(f"Cars has {cars.count()} rows")

cars.printSchema()


# Save as CSV
cars.write.mode("overwrite").option("header", "true").csv(
    os.path.join(filePath, "data/cars-out-csv.csv")
)

# Partition by year and save as Parquet
cars.write.mode("overwrite").partitionBy("year").parquet(
    os.path.join(filePath, "data/cars-out-pqt"))
spark.stop()

The mode argument accepts the following values (a short sketch of the non-default modes follows this list):

  • overwrite: overwrite the data if it already exists
  • append: append to the data that already exists
  • ignore: if the data already exists, silently skip the write without saving anything, so use this mode with care
  • error (the default): throw an exception if the data already exists
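
A minimal sketch of the non-default modes, reusing the cars DataFrame and the CSV output path from the script above (it assumes that path already exists):

# Append additional files to the existing output directory
cars.write.mode("append").option("header", "true").csv(
    os.path.join(filePath, "data/cars-out-csv.csv"))

# The path already exists, so nothing is written and no error is raised
cars.write.mode("ignore").csv(os.path.join(filePath, "data/cars-out-csv.csv"))

# The path already exists, so an AnalysisException is raised
cars.write.mode("error").csv(os.path.join(filePath, "data/cars-out-csv.csv"))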

Inspecting the saved output

$ tree data/cars-out-csv.csv data/cars-out-pqt
data/cars-out-csv.csv
├── _SUCCESS
└── part-00000-aa27bc5f-aaeb-4f0c-804f-2d5320803d0d-c000.csv
data/cars-out-pqt
├── _SUCCESS
├── year=1997
│   └── part-00000-160cc90f-7c7a-4d88-b1c8-16c4fd8ec831.c000.snappy.parquet
├── year=2012
│   └── part-00000-160cc90f-7c7a-4d88-b1c8-16c4fd8ec831.c000.snappy.parquet
├── year=2015
│   └── part-00000-160cc90f-7c7a-4d88-b1c8-16c4fd8ec831.c000.snappy.parquet
└── year=2016
    └── part-00000-160cc90f-7c7a-4d88-b1c8-16c4fd8ec831.c000.snappy.parquet

4 directories, 7 files

As the output shows, the CSV result is a directory rather than a single CSV file.

When saving as Parquet partitioned by year, a subdirectory is created for each of the four years. The benefit is reduced query time: when a query filters on year=2015, only the data under the year=2015 directory is read. On large datasets this greatly improves query efficiency.
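
For example, a query that filters on the partition column only scans the matching subdirectory. A minimal sketch, assuming the Parquet output written above is still in place:

>>> pqt = spark.read.parquet("data/cars-out-pqt")

>>> pqt.filter(pqt['year'] == 2015).show()

# In the physical plan, the filter on year = 2015 shows up as a partition
# filter, i.e. only the year=2015 directory is read.
>>> pqt.filter(pqt['year'] == 2015).explain()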

Reading the saved data back

# -*- coding: utf-8 -*-

from pyspark.sql import SparkSession
import os

spark = SparkSession.builder.appName("Datasets APIs Demo")\
    .master("local")\
    .config("spark.logConf", "true")\
    .config("spark.logLevel", "ERROR")\
    .getOrCreate()
print(f"Running Spark Version {spark.version}")

filePath = "."
cars = spark.read.parquet(os.path.join(filePath, "data/cars-out-pqt"))

print(f"Cars has {cars.count()} rows")
cars.printSchema()

spark.stop()

Aggregation functions

First, let's load the test data:

>>> import os

>>> filePath = "."

>>> cars = spark.read.option('header', 'true')\
...         .option('inferSchema', 'true')\
...         .csv(os.path.join(filePath, 'data/car-data/car-mileage.csv'))

>>> print(f"Cars has {cars.count()} rows")
Cars has 32 rows

>>> cars.printSchema()
root
 |-- mpg: double (nullable = true)
 |-- displacement: double (nullable = true)
 |-- hp: integer (nullable = true)
 |-- torque: integer (nullable = true)
 |-- CRatio: double (nullable = true)
 |-- RARatio: double (nullable = true)
 |-- CarbBarrells: integer (nullable = true)
 |-- NoOfSpeed: integer (nullable = true)
 |-- length: double (nullable = true)
 |-- width: double (nullable = true)
 |-- weight: integer (nullable = true)
 |-- automatic: integer (nullable = true)

>>> cars.show(5)
+-----+------------+---+------+------+-------+------------+---------+------+-----+------+---------+
|  mpg|displacement| hp|torque|CRatio|RARatio|CarbBarrells|NoOfSpeed|length|width|weight|automatic|
+-----+------------+---+------+------+-------+------------+---------+------+-----+------+---------+
| 18.9|       350.0|165|   260|   8.0|   2.56|           4|        3| 200.3| 69.9|  3910|        1|
| 17.0|       350.0|170|   275|   8.5|   2.56|           4|        3| 199.6| 72.9|  3860|        1|
| 20.0|       250.0|105|   185|  8.25|   2.73|           1|        3| 196.7| 72.2|  3510|        1|
|18.25|       351.0|143|   255|   8.0|    3.0|           2|        3| 199.9| 74.0|  3890|        1|
|20.07|       225.0| 95|   170|   8.4|   2.76|           1|        3| 194.1| 71.8|  3365|        0|
+-----+------------+---+------+------+-------+------------+---------+------+-----+------+---------+
only showing top 5 rows

Summarizing columns with describe

>>> cars.describe("mpg","hp","weight","automatic").show()
+-------+-----------------+-----------------+----------------+-------------------+
|summary|              mpg|               hp|          weight|          automatic|
+-------+-----------------+-----------------+----------------+-------------------+
|  count|               32|               32|              32|                 32|
|   mean|        20.223125|          136.875|       3586.6875|            0.71875|
| stddev|6.318289089312789|44.98082028541039|947.943187269323|0.45680340939917435|
|    min|             11.2|               70|            1905|                  0|
|    max|             36.5|              223|            5430|                  1|
+-------+-----------------+-----------------+----------------+-------------------+
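
Calling describe() with no arguments summarizes all numeric and string columns at once, which is handy for a first look at a new DataFrame:

>>> cars.describe().show()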

Grouped aggregation

>>> cars.groupBy("automatic").avg("mpg","torque").show()
+---------+------------------+-----------------+
|automatic|          avg(mpg)|      avg(torque)|
+---------+------------------+-----------------+
|        1|17.324782608695646|257.3636363636364|
|        0|27.630000000000006|          109.375|
+---------+------------------+-----------------+

# Without grouping: aggregate over the entire DataFrame
>>> cars.groupBy().avg("mpg","torque").show()
+---------+-----------+
| avg(mpg)|avg(torque)|
+---------+-----------+
|20.223125|      217.9|
+---------+-----------+

Other aggregation functions

>>> from pyspark.sql.functions import avg, mean

>>> cars.agg(avg(cars["mpg"]), mean(cars["torque"]) ).show()
+---------+-----------+
| avg(mpg)|avg(torque)|
+---------+-----------+
|20.223125|      217.9|
+---------+-----------+
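
Other aggregates from pyspark.sql.functions can be combined in the same agg() call. A minimal sketch (the column choices are only for illustration):

# note: importing min/max/sum shadows Python's built-ins in this shell session
>>> from pyspark.sql.functions import min, max, sum, countDistinct

>>> cars.agg(min(cars['mpg']), max(cars['mpg']), sum(cars['weight']),
...          countDistinct(cars['CarbBarrells'])).show()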

Statistical functions

As the Dataset API overview figure earlier shows, the statistics-related functions all live under sql.stat.*.

Here is an example that uses the statistical functions:

>>> import os

>>> filePath = "."

>>> cars = spark.read.option('header', 'true')\
...         .option('inferSchema', 'true')\
...         .csv(os.path.join(filePath, 'data/car-data/car-mileage.csv'))

>>> cars.show(5)
+-----+------------+---+------+------+-------+------------+---------+------+-----+------+---------+
|  mpg|displacement| hp|torque|CRatio|RARatio|CarbBarrells|NoOfSpeed|length|width|weight|automatic|
+-----+------------+---+------+------+-------+------------+---------+------+-----+------+---------+
| 18.9|       350.0|165|   260|   8.0|   2.56|           4|        3| 200.3| 69.9|  3910|        1|
| 17.0|       350.0|170|   275|   8.5|   2.56|           4|        3| 199.6| 72.9|  3860|        1|
| 20.0|       250.0|105|   185|  8.25|   2.73|           1|        3| 196.7| 72.2|  3510|        1|
|18.25|       351.0|143|   255|   8.0|    3.0|           2|        3| 199.9| 74.0|  3890|        1|
|20.07|       225.0| 95|   170|   8.4|   2.76|           1|        3| 194.1| 71.8|  3365|        0|
+-----+------------+---+------+------+-------+------------+---------+------+-----+------+---------+
only showing top 5 rows

# Compute the correlation
>>> cor = cars.stat.corr("hp", "weight")

>>> print("hp to weight: Correlation = %.4f" % cor)
hp to weight: Correlation = 0.8834

# Compute the covariance
>>> cov = cars.stat.cov("hp", "weight")

>>> print("hp to weight: Covariance = %.4f" % cov)
hp to weight: Covariance = 37667.5403

# Display a cross-tabulation (contingency table)
>>> cars.stat.crosstab("automatic", "NoOfSpeed").show()
+-------------------+---+---+---+
|automatic_NoOfSpeed|  3|  4|  5|
+-------------------+---+---+---+
|                  1| 23|  0|  0|
|                  0|  1|  5|  3|
+-------------------+---+---+---+

A crosstab is essentially a grouped count. The crosstab above is equivalent to the following SQL:

>>> cars.createOrReplaceGlobalTempView("Cars")

>>> spark.sql("Select automatic, NoOfSpeed, count(*) from global_temp.Cars group by automatic, NoOfSpeed").show()
+---------+---------+--------+
|automatic|NoOfSpeed|count(1)|
+---------+---------+--------+
|        0|        5|       3|
|        1|        3|      23|
|        0|        3|       1|
|        0|        4|       5|
+---------+---------+--------+

Crosstabs are very useful in exploratory analysis because they make the relationships between groups easy to see. Below, crosstabs are used to examine how survival on the Titanic relates to passenger attributes.

>>> import os

>>> filePath = "."

>>> passengers = spark.read.option("header", "true")\
...     .option("inferSchema", "true")\
...     .csv(os.path.join(filePath, 'data/titanic3_02.csv'))

>>> result = passengers.select(passengers['Pclass'], passengers['Survived'], passengers['Gender'], passengers['Age'],
...                            passengers['SibSp'], passengers['Parch'],passengers['Fare'])
...

# Pclass: ticket class
# SibSp: number of siblings/spouses aboard
# Parch: number of parents/children aboard
# Fare: ticket fare
>>> result.show(5)
+------+--------+------+------+-----+-----+--------+
|Pclass|Survived|Gender|   Age|SibSp|Parch|    Fare|
+------+--------+------+------+-----+-----+--------+
|     1|       1|female|  29.0|    0|    0|211.3375|
|     1|       1|  male|0.9167|    1|    2|  151.55|
|     1|       0|female|   2.0|    1|    2|  151.55|
|     1|       0|  male|  30.0|    1|    2|  151.55|
|     1|       0|female|  25.0|    1|    2|  151.55|
+------+--------+------+------+-----+-----+--------+
only showing top 5 rows

>>> result.printSchema()
root
 |-- Pclass: integer (nullable = true)
 |-- Survived: integer (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Fare: double (nullable = true)

>>> result.groupBy('Gender').count().show()
+------+-----+
|Gender|count|
+------+-----+
|female|  466|
|  male|  843|
+------+-----+

>>> result.stat.crosstab("Survived", "Gender").show()
+---------------+------+----+
|Survived_Gender|female|male|
+---------------+------+----+
|              1|   339| 161|
|              0|   127| 682|
+---------------+------+----+

>>> result.stat.crosstab("Survived","SibSp").show()
+--------------+---+---+---+---+---+---+---+
|Survived_SibSp|  0|  1|  2|  3|  4|  5|  8|
+--------------+---+---+---+---+---+---+---+
|             1|309|163| 19|  6|  3|  0|  0|
|             0|582|156| 23| 14| 19|  6|  9|
+--------------+---+---+---+---+---+---+---+

# Age is a double, so a crosstab built directly on Age would be far too wide to read. Bucket it into 10-year brackets and cast to int instead.
>>> ageDist = result.select(result["Survived"], (result["age"] - result["age"] % 10).cast("int").name("AgeBracket"))

>>> ageDist.show(5)
+--------+----------+
|Survived|AgeBracket|
+--------+----------+
|       1|        20|
|       1|         0|
|       0|         0|
|       0|        30|
|       0|        20|
+--------+----------+
only showing top 5 rows

# Passengers in their 20s and 30s are the largest groups, and many passengers have no recorded age
>>> ageDist.crosstab("Survived", "AgeBracket").show()
+-------------------+---+---+---+---+---+---+---+---+---+----+
|Survived_AgeBracket|  0| 10| 20| 30| 40| 50| 60| 70| 80|null|
+-------------------+---+---+---+---+---+---+---+---+---+----+
|                  1| 50| 56|127| 98| 52| 32| 10|  1|  1|  73|
|                  0| 32| 87|217|134| 83| 38| 22|  6|  0| 190|
+-------------------+---+---+---+---+---+---+---+---+---+----+
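
Besides corr, cov, and crosstab, the stat namespace also offers frequent-item detection and approximate quantiles. A minimal sketch on the cars DataFrame loaded earlier in this section:

# Values that occur in at least 40% of the rows, per column (may include false positives)
>>> cars.stat.freqItems(["CarbBarrells", "NoOfSpeed"], 0.4).show()

# Approximate quartiles of mpg, computed with a relative error of 0.01
>>> cars.approxQuantile("mpg", [0.25, 0.5, 0.75], 0.01)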

Mathematical functions

The mathematical functions are almost all found under sql.functions.

For example: log(), log10(), sqrt(), cbrt(), exp(), pow(), sin(), cos(), tan(), acos(), asin(), atan(), toDegrees(), toRadians()

Some usage examples follow.

Create a DataFrame

>>> from pyspark.sql import Row

>>> row = Row("val")

>>> l = [1,2,3]

>>> rdd = sc.parallelize(l)

# There are two ways to convert an RDD to a DataFrame
# Option 1:
>>> df = spark.createDataFrame(rdd.map(row))

# Option 2:
# >>> df = rdd.map(row).toDF()

>>> df.show()
+---+
|val|
+---+
|  1|
|  2|
|  3|
+---+

Some examples

>>> from pyspark.sql.functions import log, log10, sqrt, log1p

>>> df.select(df['val'], log(df['val']).name('ln')).show()
+---+------------------+
|val|                ln|
+---+------------------+
|  1|               0.0|
|  2|0.6931471805599453|
|  3|1.0986122886681098|
+---+------------------+

>>> df.select(df['val'], log10(df['val']).name('log10')).show()
+---+-------------------+
|val|              log10|
+---+-------------------+
|  1|                0.0|
|  2| 0.3010299956639812|
|  3|0.47712125471966244|
+---+-------------------+

>>> df.select(df['val'], sqrt(df['val']).name('sqrt')).show()
+---+------------------+
|val|              sqrt|
+---+------------------+
|  1|               1.0|
|  2|1.4142135623730951|
|  3|1.7320508075688772|
+---+------------------+

>>> df.select(df['val'], log1p(df['val']).name('ln1p')).show()
+---+------------------+
|val|              ln1p|
+---+------------------+
|  1|0.6931471805599453|
|  2|1.0986122886681096|
|  3|1.3862943611198906|
+---+------------------+
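
The remaining functions in the list above (exponential, power, trigonometric, and so on) follow the same pattern. A minimal sketch:

# note: pow here shadows Python's built-in pow in the shell session
>>> from pyspark.sql.functions import exp, pow, sin

>>> df.select(df['val'],
...           exp(df['val']).name('exp'),
...           pow(df['val'], 2).name('squared'),
...           sin(df['val']).name('sin')).show()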

Given the two legs of a right triangle, compute the length of the hypotenuse:

>>> import os

>>> filePath = '.'

>>> data = spark.read.option("header", "true")\
...                  .option("inferSchema", "true")\
...                  .csv(os.path.join(filePath, "data/hypot.csv"))

>>>

>>> data.show()
+---+---+
|  X|  Y|
+---+---+
|  3|  4|
|  5| 12|
|  7| 24|
|  9| 40|
| 11| 60|
| 13| 84|
+---+---+

>>> from pyspark.sql.functions import hypot

>>> data.select(data["X"], data["Y"],
...         hypot(data["X"], data["Y"]).name("hypot")).show()
+---+---+-----+
|  X|  Y|hypot|
+---+---+-----+
|  3|  4|  5.0|
|  5| 12| 13.0|
|  7| 24| 25.0|
|  9| 40| 41.0|
| 11| 60| 61.0|
| 13| 84| 85.0|
+---+---+-----+

Hands-on exercise

Now let's work through a hands-on exercise that uses the interfaces and functions covered above.

We will use the Northwind sales dataset to analyze the following questions:

  • How many orders were placed by each customer?
  • How many orders were placed in each country?
  • How many orders were placed for each month/year?
  • What is the total number of sales for each customer, year-wise?
  • What is the average order by customer, year-wise?

Loading the data

>>> import os

>>> filePath = '.'

>>> orders = spark.read.option("header", "true")\
...                    .option("inferSchema", "true")\
...                    .csv(os.path.join(filePath, "data/NW/NW-Orders-01.csv")
... )

>>> print(f"Orders has {orders.count()} rows")
Orders has 830 rows

>>> orders.show(5)
+-------+----------+----------+-------------------+-----------+
|OrderID|CustomerID|EmployeeID|          OrderDate|ShipCountry|
+-------+----------+----------+-------------------+-----------+
|  10248|     VINET|         5|1996-07-02 00:00:00|     France|
|  10249|     TOMSP|         6|1996-07-03 00:00:00|    Germany|
|  10250|     HANAR|         4|1996-07-06 00:00:00|     Brazil|
|  10251|     VICTE|         3|1996-07-06 00:00:00|     France|
|  10252|     SUPRD|         4|1996-07-07 00:00:00|    Belgium|
+-------+----------+----------+-------------------+-----------+
only showing top 5 rows

>>> orders.printSchema()
root
 |-- OrderID: integer (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- EmployeeID: integer (nullable = true)
 |-- OrderDate: timestamp (nullable = true)
 |-- ShipCountry: string (nullable = true)


>>> orderDetails = spark.read.option("header", "true")\
...                    .option("inferSchema", "true")\
...                    .csv(os.path.join(filePath, "data/NW/NW-Order-Details.csv"))

>>> print(f"Order Details has {orderDetails.count()} rows")
Order Details has 2155 rows

>>> orderDetails.show(5)
+-------+---------+---------+---+--------+
|OrderID|ProductId|UnitPrice|Qty|Discount|
+-------+---------+---------+---+--------+
|  10248|       11|     14.0| 12|     0.0|
|  10248|       42|      9.8| 10|     0.0|
|  10248|       72|     34.8|  5|     0.0|
|  10249|       14|     18.6|  9|     0.0|
|  10249|       51|     42.4| 40|     0.0|
+-------+---------+---------+---+--------+
only showing top 5 rows


>>> orderDetails.printSchema()
root
 |-- OrderID: integer (nullable = true)
 |-- ProductId: integer (nullable = true)
 |-- UnitPrice: double (nullable = true)
 |-- Qty: integer (nullable = true)
 |-- Discount: double (nullable = true)

Question 1: How many orders were placed by each customer?

>>> orderByCustomer = orders.groupBy("CustomerID").count()

>>> from pyspark.sql.functions import desc

>>> orderByCustomer.sort(desc("count")).show(5)
+----------+-----+
|CustomerID|count|
+----------+-----+
|     SAVEA|   31|
|     ERNSH|   30|
|     QUICK|   28|
|     HUNGO|   19|
|     FOLKO|   19|
+----------+-----+

Question 2: How many orders were placed in each country?

>>> orderByCountry = orders.groupBy("ShipCountry").count()

>>> orderByCountry.sort(desc("count")).show(5)
+-----------+-----+
|ShipCountry|count|
+-----------+-----+
|    Germany|  122|
|        USA|  122|
|     Brazil|   82|
|     France|   77|
|         UK|   56|
+-----------+-----+
only showing top 5 rows

The remaining three questions can be answered as follows:

# -*- coding: utf-8 -*-

from pyspark.sql import SparkSession
from pyspark.sql.functions import to_date, month, year
import os

spark = SparkSession.builder.appName("Datasets APIs Demo") \
    .master("local") \
    .config("spark.logConf", "true") \
    .config("spark.logLevel", "ERROR") \
    .getOrCreate()
print(f"Running Spark Version {spark.version}")

filePath = "."
orders = spark.read.option("header", "true") \
    .option("inferSchema", "true") \
    .csv(os.path.join(filePath, "data/NW/NW-Orders-01.csv"))

orderDetails = spark.read.option("header", "true") \
    .option("inferSchema", "true") \
    .csv(os.path.join(filePath, "data/NW/NW-Order-Details.csv"))

result = orderDetails.select(orderDetails['OrderID'],
                             ((orderDetails['UnitPrice'] * orderDetails['Qty'] - (
                                 orderDetails['UnitPrice'] * orderDetails['Qty'] * orderDetails['Discount'])).
                              name('OrderPrice')))

result.show(5)

# Sum OrderPrice per order and alias the resulting DataFrame as OrderTotal
orderTot = result.groupBy('OrderID').sum('OrderPrice').alias('OrderTotal')
orderTot.sort("OrderID").show(5)

orders_df = orders.join(orderTot, orders["OrderID"] == orderTot["OrderID"], "inner") \
    .select(orders["OrderID"],
            orders["CustomerID"],
            orders["OrderDate"],
            orders["ShipCountry"].alias("ShipCountry"),
            orderTot["sum(OrderPrice)"].alias("Total"))

orders_df.sort("CustomerID").show()

# Filter for rows whose Total is null
orders_df.filter(orders_df["Total"].isNull()).show()

# Add a Date column
orders_df2 = orders_df.withColumn('Date', to_date(orders_df['OrderDate']))
orders_df2.show(5)
orders_df2.printSchema()

# Add Month and Year columns
orders_df3 = orders_df2.withColumn("Month", month(orders_df2["OrderDate"])) \
    .withColumn("Year", year(orders_df2["OrderDate"]))
orders_df3.show(5)
orders_df3.printSchema()

# Question 3: orders per month/year
ordersByYM = orders_df3.groupBy("Year", "Month").sum("Total").alias("Total")
ordersByYM.sort(ordersByYM["Year"], ordersByYM["Month"]).show()

# Question 4: total sales per customer, year-wise
ordersByCY = orders_df3.groupBy("CustomerID", "Year").sum("Total").alias("Total")
ordersByCY.sort(ordersByCY["CustomerID"], ordersByCY["Year"]).show()

# Question 5: average order amount per customer
ordersCA = orders_df3.groupBy("CustomerID").avg("Total").alias("Total")
ordersCA.sort("avg(Total)", ascending=False).show()

spark.stop()