按照行索引合并(join)
df1 = pd.DataFrame(np.zeros((3, 3)), index=list("ABC"), columns=list("abc"))
df2 = pd.DataFrame(np.ones((2, 2)), index=list("AB"), columns=list("mn"))
df1 和 df2 如下:
我们合并一下两个DataFrame
print(df1.join(df2))
print(df2.join(df1))
合并结果:
注:
- 以df1为基准合并时,df2没有行索引C,用NaN补齐
- 以df2为基准合并时,df1比df2多了行索引C,直接舍弃
按照列索引合并(merge)
构造两个DataFrame
四种连接方式结果如下(内外连接,左右连接):
星巴克案例
下面通过一个案例来说明分组聚合的用法
数据主要是全球星巴克门店的信息,包括所在地址,门店编号,联系方式等等
数据来源:星巴克数据
1、将csv文件下载后放到工作路径,接下来进行读取,查看数据
= "./directory.csv"
df = pd.read_csv(file_path)
print(df.head())
print(df.info())
前5行数据
数据的基本信息如下:
2、按照国家进行分组
先利用groupby方法分组,然后打印相关信息
= df.groupby(by="Country") # 按照国家分组后,多个DataFrame组成的集合
for i,j in grouped:
print("*" * 50)
print(i)
print("*"*50)
print(j)
print("-"*50)
打印结果:
于是,我们懂了,两个星花之间的表示分组依据,星花和横线之间的是每个分组的相关数据,也就是一个DataFrame。那么代码中groupby方法返回的就是多个DataFrame组成的一个DataFrameGroupBy对象
3、使用聚合函数
print(grouped.count())
聚合方法count返回的是各分类的某属性有多少条数据,例如US的Brand一栏有13608条数据
当然聚合方法还有
- sum:计算非NaN值的和
- mean:计算非NaN值的平均值
- median:计算非NaN值的算数中位数
- std、var:计算无偏(分母为n-1)的标准差和方差
- min、max:计算非NaN值的最小、最大值
4、统计中国每个省份门店的数量
= df[df["Country"]=="CN"]
grouped_province = cn_data.groupby(by="State/Province").count()["Brand"]
print(grouped_province)
按多个条件分组:统计每个省份各个城市门店的数量
cn_data = df[df["Country"]=="CN"]
# 按照多个条件分组
grouped = cn_data.groupby(by=["State/Province", "City"]).count()["Brand"]
print(grouped)
分组结果:
我们得到了有两个索引的Series
同样的,我们也可以返回与上述Series相同数据的DataFrame
= df[df["Country"]=="CN"]
grouped = cn_data.groupby(by=["State/Province", "City"]).count()[["Brand"]]
print(grouped)
print(type(grouped))
5、绘制中国星巴克数量排名前十的城市
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import font_manager
def main():
font_20 = font_manager.FontProperties(fname=r"C:\Windows\Fonts\simkai.ttf", size=20)
font_30 = font_manager.FontProperties(fname=r"C:\Windows\Fonts\timesi.ttf", size=30)
file_path = "./directory.csv"
df = pd.read_csv(file_path)
df = df[df["Country"] == "CN"]
# 获取中国星巴克数量排名前十的城市以及响应的门店数量
data1 = df.groupby(by="City").count()["Brand"].sort_values(ascending=False)[:10]
x = data1.index[::-1]
y = data1.values[::-1]
plt.figure(figsize=(20, 12), dpi=80)
plt.title("StarBucks in China's cities", fontproperties=font_30)
plt.barh(range(len(x)), y, height=0.4, color="orange")
plt.xticks(range(0,600,100), fontproperties=font_20)
plt.yticks(range(len(x)), x, fontproperties=font_20)
plt.savefig(fname="./starBucks_cn", dpi=500)
# plt.show()
if __name__ == '__main__':
main()