CatBoost快速入门

56 篇文章 4 订阅
订阅专栏

本文代码已上传CSDN,点我下载

文章目录

  • 备注
  • 推荐阅读
  • 简介
  • 安装
  • 初试
  • 可视化
  • 决策树
  • 特征重要性
  • 最优模型
  • 调用GPU
  • 参考文献




备注

该库貌似仍不稳定,我在继续训练的时候找到一个BUG Can Not Training continuation(2020.2.27 版本0.21),现在已修复了




推荐阅读

MNIST & CatBoost保存模型并预测
快速掌握CatBoost基本用法




简介

CatBoost是一款高性能机器学习开源库,基于GBDT,由俄罗斯搜索巨头Yandex在2017年开源。

那么CatBoost与其他Boosting算法如LightGBM和XGBoost相比如何呢?

在质量上,无论是fine-tuned后还是默认情况下,CatBoost的loss优于其他三个框架。
在这里插入图片描述
在速度上,CatBoost在Epsilon和Higgs数据集上与对手进行了比较,在GPU训练下完胜对手,在CPU训练下与LightGBM平分秋色。

Epsilon数据集(二分类2001个特征)
在这里插入图片描述
Higgs数据集(二分类29个特征)
在这里插入图片描述
CatBoost特点有:

  1. 免调参高质量
  2. 支持类别特征
  3. 快速和可用GPU
  4. 提高准确性
  5. 快速预测

更多对比参见 Battle of the Boosting Algos: LGB, XGB, Catboost,建议自己运行一遍,本人运行与原文有出入—— XGBoost、LightGBM、Catboost对比




安装

GPU开箱即用,不用额外安装其他

pip install catboost

Jupyter可视化配置

pip install ipywidgets
jupyter nbextension enable --py widgetsnbextension




初试

CatBoost内置数据集 Titanic,该数据集为二分类任务。

在这里插入图片描述
导入必要的包

from catboost.datasets import titanic
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split

读取数据集

# 数据集
titanic_train, titanic_test = titanic()
titanic_train.head(10)

在这里插入图片描述
有数据为空NaN,例如乘客编号为6的年龄。
有数据是离散值,例如姓名和船票编号。
认为对模型训练作用性不大,去掉。

remove = ['PassengerId', 'Name', 'Ticket', 'Cabin'] 
X = titanic_train.drop(remove, axis=1)  # 去掉无关信息
X = X.dropna(how='any', axis='rows')  # 去掉空值
y = X.pop('Survived')  # 标签
X.head()

在这里插入图片描述
结果如上,其中船舱等级、性别和登船码头(下标为0,1,6)显然为类别特征,而恰好CatBoost支持类别特征训练。

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

创建Pool对象,这是CatBoost自带的类,便于CatBoost库进行处理。
当然,CatBoost实现了sklearn的接口,直接使用pd.DataFrame类型的X_train, X_test, y_train, y_test训练也行。

# 定义池(CatBoost最快的处理方式)
cat_features = [0, 1, 6]  # 分类特征
train_pool = Pool(X_train, y_train, cat_features=cat_features)
test_pool = Pool(X_test, y_test, cat_features=cat_features)

定义CatBoost分类模型

# 定义模型
model = CatBoostClassifier()

训练,参数含义分别是:train_pool训练数据,eval_set验证集,plot可视化,silent不输出训练过程,use_best_model使用最优模型

# 训练
model.fit(train_pool, eval_set=test_pool, plot=True, silent=True, use_best_model=True)  #可视化,不输出过程,最优模型

在这里插入图片描述
查看最优结果和准确率

model.get_best_score()  # 最优loss
{'learn': {'Logloss': 0.14129628504561498},
 'validation': {'Logloss': 0.471373085990394}}
model.score(test_pool) #准确率
0.8111888111888111

最后保存模型

model.save_model('titanic.model') # 保存模型

加载模型

del model
model = CatBoostClassifier()
model.load_model('titanic.model')

查看测试集数据

print(X_test[:10])
print(y_test[:10])
     Pclass     Sex   Age  SibSp  Parch      Fare Embarked
641       1  female  24.0      0      0   69.3000        C
496       1  female  54.0      1      0   78.2667        C
262       1    male  52.0      1      1   79.6500        S
311       1  female  18.0      2      2  262.3750        C
551       2    male  27.0      0      0   26.0000        S
550       1    male  17.0      0      2  110.8833        C
279       3  female  35.0      1      1   20.2500        S
268       1  female  58.0      0      1  153.4625        S
110       1    male  47.0      0      0   52.0000        S
554       3  female  22.0      0      0    7.7750        S
641    1
496    1
262    0
311    1
551    0
550    1
279    1
268    1
110    0
554    1
Name: Survived, dtype: int64

使用模型进行预测

model.predict(X_test[:10])  #预测

可以看到前5个都对了,后5个错得有点多

array([1, 1, 0, 1, 0, 0, 0, 1, 0, 0], dtype=int64)

使用模型进行概率预测

model.predict_proba(X_test[:10])  #预测概率
array([[0.02731782, 0.97268218],
       [0.03240048, 0.96759952],
       [0.63710499, 0.36289501],
       [0.03272136, 0.96727864],
       [0.80136214, 0.19863786],
       [0.64224485, 0.35775515],
       [0.64860225, 0.35139775],
       [0.06276485, 0.93723515],
       [0.64481127, 0.35518873],
       [0.58364375, 0.41635625]])

继续训练

new_model = CatBoostClassifier()
new_model.fit(test_pool, plot=True, silent=True, init_model='titanic.model') # 继续训练

在这里插入图片描述




可视化

fit()时加入参数plot=True

model.fit(X_train, y_train, plot=True)

在这里插入图片描述




决策树

调用 plot_tree()tree_idx为树的索引

model.plot_tree(tree_idx=0, pool=test_pool)

在这里插入图片描述




特征重要性

调用模型属性model.feature_importances_

for i,j in zip(X.columns, model.feature_importances_):
    print('{}: {:.2f}%'.format(i,j))
Pclass: 18.62%
Sex: 46.79%
Age: 12.47%
SibSp: 4.68%
Parch: 2.16%
Fare: 10.65%
Embarked: 4.63%
%matplotlib inline
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def feature_importances(df, model):
    max_num_features=10
    feature_importances = pd.DataFrame(columns = ['feature', 'importance'])
    feature_importances['feature'] = df.columns
    feature_importances['importance'] = model.feature_importances_
    feature_importances.sort_values(by='importance', ascending=False, inplace=True)
    feature_importances = feature_importances[:max_num_features]
    plt.figure(figsize=(12, 6));
    sns.barplot(x="importance", y="feature", data=feature_importances);
    plt.title('CatBoost features importance');
feature_importances(X, model)

在这里插入图片描述
看来最决定生死的前三个因素是性别、船舱等级和年龄。




最优模型

fit()时加入参数use_best_model=True

model.fit(X_train, y_train, use_best_model=True)




调用GPU

定义模型时加入参数task_type="GPU"

model = CatBoostClassifier(task_type="GPU")
model.fit(X_train, y_train)

如果需要GPU支持,系统编译器必须与CUDA Toolkit兼容。
若报错请自行编译 CatBoost Build from source on Windows




参考文献

  1. CatBoost - open-source gradient boosting library
  2. Quick start - CatBoost. Documentation
  3. CatBoost tutorials
  4. 机器学习算法之Catboost
  5. MNIST & Catboost保存模型并预测
大数据分析案例-基于Catboost+LGBM算法构建银行客户流失预测模型
m0_64336780的博客
04-26 6187
本项目旨在通过分析某银行客户数据集,通过可视化分析找出影响客户流失的因素,最后实验机器学习中的Catboost、XGBoost、LGBM等集成算法构建银行客户流失预测模型,提高银行客户管理水平。心得与体会:通过这次Python项目实战,我学到了许多新的知识,这是一个让我把书本上的理论知识运用于实践中的好机会。原先,学的时候感叹学的资料太难懂,此刻想来,有些其实并不难,关键在于理解。在这次实战中还锻炼了我其他方面的潜力,提高了我的综合素质。
catboost参数详解及实战(强推)
机器学习、深度学习、文本分类、异常检测、风控等知识的积累和分享
07-04 2万+
catboost参数详解(史上最细),以及实战贝叶斯调参
Catboost
kuxingseng123的博客
01-24 1445
慢慢的将集成学习全部都将其搞定。
最详细的Catboost参数详解与实例应用
代码届的小白的博客
12-16 4万+
集成学习的两大准则:基学习器的准确性和多样性。 算法:串行的Boosting和并行的Bagging,前者通过错判训练样本重新赋权来重复训练,来提高基学习器的准确性,降低偏差!后者通过采样方法,训练出多样性的基学习器,降低方差。 文章目录1.CatBoost简介1.1CatBoost介绍1.2CatBoost优缺点1.3CatBoost安装2.参数详解2.1通用参数:2.2默认参数2.3性能参数2.4参数调优3.CatBoost实战应用3.1回归案例3.2使用Pool加载数据集并进行预测3.3多分类案例..
Catboost原理详解
机器学习、深度学习、文本分类、异常检测、风控等知识的积累和分享
07-25 3294
对于类别型变量而言,xgb需要先自行编码、才能输入模型;lgb极大地简化了一步,只需要将相应的变量列转化为category、或指定类别型变量名即可输入模型;catboost进一步处理,不仅嵌入了对类别型变量的处理,并附带类别型特征交叉功能、还加入了部分文本数据的处理。本文深入浅出地详解catboost,全篇通俗易懂帮助大家掌握原理。...
catboost使用方法
weixin_44414593的博客
08-11 2435
惊呆了,这么多参数,见鬼。 import catboost as cb 官方文档 参考CSDN文章 参数 loss_function损失函数, 可选RMSE,Logloss,MAE,CrossEntropy random_seed随机性种子 one_hot_max_size 是否对某些特征进行one-hot编码 custom_metric 自定义监控指标, 可选RMSE,Logloss,MAE,CrossEntropy,Recall,Precision,F1,Accuracy,AUC,R2(具体怎么用我还
CatBoost快速入门.ipynb
02-27
简介、安装、初试、可视化、绘出决策树、最优模型、调用GPU等 https://blog.csdn.net/lly1122334/article/details/104517076
R语言catboost离线安装源码
09-08
欢迎大家下载catboost源码文件,原始文件在Github上,Github下载时间很长,所以在这里分享出来供大家下载。
CatBoost.pdf
04-29
catboost原论文,方便自己使用,也同时方便大家的使用,其实网络上也很好找的,这个可能不是随时都方便的。如有侵权,联系删除。
Catboost-MNIST.ipynb
02-20
MNIST & Catboost保存模型并预测 https://blog.csdn.net/lly1122334/article/details/104407869
tutorials:CatBoost教程资料库
05-14
CatBoost教程 基本的 最好从此基础教程开始进行CatBoost探索。 Python 本教程介绍了使用CatBoost的一些基本情况,例如模型训练,交叉验证和预测,以及一些有用的功能,如提早停止,快照支持,功能重要性和参数调整...
catboost】官方调参教程
xiangxiang613的专栏
05-20 1万+
CatBoost官方教程:调参 本文翻译至官方原文:https://catboost.ai/docs/concepts/parameter-tuning.html CatBoost为参数调整提供了灵活的界面,可以对其进行配置以适合不同的任务。 本节包含有关可能的参数设置的一些提示。 catBoost提供了为Python、R语言和命令行都提供了可使用的参数,其中Python和R的完全相同,命令行参数格式则有点不同。 如L2正则化参数,python和R中为:l2_leaf_reg ,命令行中为–l2-lea
如何使用 CatBoost 进行快速梯度提升
关注我!带你一路 "狂飙" 到底!
10-21 354
我们将仔细研究一个名为CatBoost的梯度增强库。在梯度提升中,预测是由一群弱学习者做出的。与为每个样本创建决策树的随机森林不同,在梯度增强中,树是一个接一个地创建的。模型中的先前树不会更改。前一棵树的结果用于改进下一棵树。在本文中,我们将仔细研究一个名为CatBoost的梯度增强库。CatBoost 是Yandex开发的深度方向梯度增强库。它使用遗忘的决策树来生成平衡树。相同的功能用于对树的每个级别进行左右拆分。
CatBoost参数解释和实战
热门推荐
林夕
06-18 5万+
据开发者所说超越Lightgbm和XGBoost的又一个神器,不过具体性能,还要看在比赛中的表现了。 整理一下里面简单的教程和参数介绍,很多参数不是那种重要,只解释部分重要的参数,训练时需要重点考虑的。 Quick start CatBoostClassifier import numpy as np import catboost as cb train_data = ...
安装 catboost 的正确方式
pertain99的博客
09-02 1万+
conda install CatBoost 原生支持 GPU。 首先添加频道。 conda config --add channels conda-forge 安装 CatBoost: conda install catboost 安装 visualization 工具: Install the ipywidgets Python package (version 7.x or higher...
CatBoost算法和调参
weixin_33881140的博客
07-01 1万+
 欢迎关注博主主页,学习python视频资源 sklearn实战-乳腺癌细胞数据挖掘(博主亲自录制视频) https://study.163.com/course/introduction.htm?courseId=1005269003&utm_campaign=commission&utm_source=cp-400000000398149&utm_medium=s...
catboost案例
u013939918的博客
06-25 1070
from catboost import CatBoostClassifier # 数据集 cat_features = [0, 1] # 类别特征下标 train_data = [["a", "b", 1, 4, 5, 6], ["a", "b", 4, 5, 6, 7], ["c", "d", 30, 40, 50, 60]] train_labels = [1, 1, -1] eval_data = [["a", "b", 2, 4, 6,...
决策树模型,XGBoost,LightGBM和CatBoost模型可视化
Sylvester
08-09 3万+
决策树模型,XGBoost和LightGBM模型可视化 安装 graphviz 参考文档 http://graphviz.readthedocs.io/en/stable/manual.html#installation graphviz安装包下载地址 https://www.graphviz.org/download/ 将graphviz的安装位置添加到系统环境变量 使用pip ins...
Catboost-算法原理
八刀一闪的专栏
03-15 2400
总结一下catboost关键的知识点 Target Statistics 常规处理类别特征的方法是one-hot,但是也可以将类别特征转化为和label相关的数值特征,也就是target statistics,最简单的方法就是计算概率值。(A target statistic is a simple statistical model itself, and it can also cause target leakage and a prediction shift.) 几种计算TS的方法: Gr
catboost算法
最新发布
09-12
CatBoost是一种能够很好地处理类别型特征的梯度提升算法库。它基于GPU实现学习算法,而打分算法则基于CPU实现。CatBoost具有以下主要特点: 1. 高效处理类别型特征:CatBoost能够直接处理类别型特征,无需进行独热编码等预处理操作,可以更好地捕捉类别型特征中的信息。 2. 自动处理缺失值:CatBoost能够自动处理缺失值,无需额外的处理步骤。 3. 自动特征转换:CatBoost可以自动将类别型特征转换为数值型特征,并且在模型训练中进行优化。 4. 支持多种评估指标:CatBoost支持多种评估指标,包括分类任务的准确率、AUC和F1-score等,以及回归任务的RMSE和MAE等。 5. 可解释性强:CatBoost可以提供特征重要性排序,帮助用户理解模型对特征的贡献程度。 6. 支持C++ API:CatBoost还提供了C++ API,可以在C++环境中使用CatBoost模型进行预测和推理。 总之,CatBoost是一种强大的梯度提升算法库,特别适用于处理类别型特征的机器学习任务。它具有高效处理类别型特征、自动处理缺失值和特征转换、支持多种评估指标以及强解释性等优点。此外,CatBoost还提供了C++ API,方便在C++环境中使用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
写文章

热门文章

  • MX Player不支持此音频格式(EAC3) 66055
  • Python类和方法注释规范 64791
  • matplotlib.pyplot.colormaps色彩图cmap 61510
  • Python构建快速高效的中文文字识别OCR 45561
  • Selenium实现点击click() 38766

分类专栏

  • Python 524篇
  • 数据库 12篇
  • Linux 7篇
  • Git 5篇
  • 笔记
  • Golang 3篇
  • 前端 14篇
  • Redis 5篇
  • 自动化测试 7篇
  • PyQt 11篇
  • C++ 10篇
  • Java 16篇
  • 机器学习 56篇
  • Tensorflow 37篇
  • Keras 16篇
  • imgaug 13篇
  • OpenCV 21篇
  • 其他 67篇

最新评论

  • imgaug数据增强神器:第四章 增强关键点/界标

    nobut: 哥们你好,请问找到类似代码没

  • Python照片隐写术——照片内嵌信息(含模型、测试图片、测试视频)

    九度微凉: 博主你好,我encode运行成功,但decode始终有问题,问题是Secret bits: [1. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 1. 1. 1. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 1. 0. 1. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0. 0. 1. 0. 0. 1. 0. 1. 1. 0. 1. 0. 0. 1. 0. 1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 1. 1. 0. 1. 0. 1. 0. 0. 0.] Packet binary: 100000100101111001010000110010101101010101010010010110100101100100000111000100010001000110101000 Data: bytearray(b'\x82^P\xca\xd5RZ'), ECC: bytearray(b'Y\x07\x11\x11\xa8') Bitflips: -1,是哪里的问题呢

  • from scipy.interpolate import spline报错ImportError: cannot import name ‘spline‘

    2401_83177883: 靴靴大佬

  • MX Player不支持此音频格式(EAC3)

    2401_84606608: 老板在官网上没看到下载位置,能出个详细教程吗

  • MX Player不支持此音频格式(EAC3)

    2401_84606608: 老板,我的 mx 播放器 版本号,是 1.75.5(armv8 neon)在使用时候,音频不支持 eac3,网上查找需要下载一个 MX_FFmpeg,但是我在官网没找到下载位置,求给个下载位置

您愿意向朋友推荐“博客详情页”吗?

  • 强烈不推荐
  • 不推荐
  • 一般般
  • 推荐
  • 强烈推荐
提交

最新文章

  • Python json.dumps()添加转义符号
  • 羽毛球移动步法训练
  • Python国际化L10N方案
2024年7篇
2023年16篇
2022年63篇
2021年122篇
2020年265篇
2019年182篇
2018年21篇
2017年9篇
2016年1篇

目录

目录

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43元 前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

XerCis

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或 充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值

深圳SEO优化公司爱联关键词按天收费龙岗外贸网站建设南联关键词排名包年推广大浪百度竞价包年推广龙岗网站改版坂田如何制作网站坪地百姓网标王推广罗湖网站推广工具丹竹头网站优化推广西乡网站优化软件宝安网站推广荷坳优化石岩百搜词包盐田百姓网标王福永网络广告推广民治网站优化推广双龙阿里店铺托管布吉网站改版福永网站推广系统南澳网络广告推广坑梓百度爱采购荷坳品牌网站设计木棉湾网站改版民治外贸网站制作观澜网络广告推广布吉设计网站爱联百度网站优化排名福永SEO按天收费福田SEO按天收费南山网站优化按天扣费歼20紧急升空逼退外机英媒称团队夜以继日筹划王妃复出草木蔓发 春山在望成都发生巨响 当地回应60岁老人炒菠菜未焯水致肾病恶化男子涉嫌走私被判11年却一天牢没坐劳斯莱斯右转逼停直行车网传落水者说“没让你救”系谣言广东通报13岁男孩性侵女童不予立案贵州小伙回应在美国卖三蹦子火了淀粉肠小王子日销售额涨超10倍有个姐真把千机伞做出来了近3万元金手镯仅含足金十克呼北高速交通事故已致14人死亡杨洋拄拐现身医院国产伟哥去年销售近13亿男子给前妻转账 现任妻子起诉要回新基金只募集到26元还是员工自购男孩疑遭霸凌 家长讨说法被踢出群充个话费竟沦为间接洗钱工具新的一天从800个哈欠开始单亲妈妈陷入热恋 14岁儿子报警#春分立蛋大挑战#中国投资客涌入日本东京买房两大学生合买彩票中奖一人不认账新加坡主帅:唯一目标击败中国队月嫂回应掌掴婴儿是在赶虫子19岁小伙救下5人后溺亡 多方发声清明节放假3天调休1天张家界的山上“长”满了韩国人?开封王婆为何火了主播靠辱骂母亲走红被批捕封号代拍被何赛飞拿着魔杖追着打阿根廷将发行1万与2万面值的纸币库克现身上海为江西彩礼“减负”的“试婚人”因自嘲式简历走红的教授更新简介殡仪馆花卉高于市场价3倍还重复用网友称在豆瓣酱里吃出老鼠头315晚会后胖东来又人满为患了网友建议重庆地铁不准乘客携带菜筐特朗普谈“凯特王妃P图照”罗斯否认插足凯特王妃婚姻青海通报栏杆断裂小学生跌落住进ICU恒大被罚41.75亿到底怎么缴湖南一县政协主席疑涉刑案被控制茶百道就改标签日期致歉王树国3次鞠躬告别西交大师生张立群任西安交通大学校长杨倩无缘巴黎奥运

深圳SEO优化公司 XML地图 TXT地图 虚拟主机 SEO 网站制作 网站优化