How to save a model in sklearn

星巢文化, 2022-05-24 (41 views)

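Problem

After training a model with sklearn, how can the fitted model (including its learned parameters) be saved so that it can be loaded and reused later?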

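Model

from sklearn.ensemble import GradientBoostingRegressor as GBR
gbr = GBR(random_state=1412)  # instantiate
gbr.fit(X, y.ravel())  # train the model (X, y: your existing features and targets)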

Methods

The two commonly used options are the joblib and pickle libraries.

Saving the model

  • joblib
# from sklearn.externals import joblib  # older scikit-learn (before 0.21)
import joblib  # newer scikit-learn
joblib.dump(gbr, "train_model.m")
  • pickle
import pickle
with open('train_model.pkl', 'wb') as f:
    pickle.dump(gbr, f)

Loading the model

  • joblib
import joblib
gbr = joblib.load("train_model.m")
  • pickle
import pickle
with open('train_model.pkl', 'rb') as f:
    gbr = pickle.load(f)
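
The save and load steps above can be checked end to end with a small round trip. This is a minimal sketch, not the author's code: the synthetic X and y are assumptions standing in for the real training data, and the files are written to a temporary directory so nothing is left behind.

```python
import os
import pickle
import tempfile

import joblib
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor as GBR

# Synthetic data standing in for the article's X and y (an assumption).
rng = np.random.RandomState(0)
X = rng.rand(100, 3)
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.rand(100)

gbr = GBR(random_state=1412).fit(X, y)
expected = gbr.predict(X)

with tempfile.TemporaryDirectory() as tmp:
    joblib_path = os.path.join(tmp, "train_model.m")
    pickle_path = os.path.join(tmp, "train_model.pkl")

    # Save the fitted model with both libraries.
    joblib.dump(gbr, joblib_path)
    with open(pickle_path, "wb") as f:
        pickle.dump(gbr, f)

    # Load both copies back while the files still exist.
    from_joblib = joblib.load(joblib_path)
    with open(pickle_path, "rb") as f:
        from_pickle = pickle.load(f)

# A restored model should reproduce the original predictions.
joblib_matches = np.array_equal(from_joblib.predict(X), expected)
pickle_matches = np.array_equal(from_pickle.predict(X), expected)
print(joblib_matches, pickle_matches)
```

Comparing predictions before and after the round trip is a cheap sanity check to run whenever the saving code or the environment changes.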

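Other platforms (Java, C++, etc.)

  • The official documentation suggests exporting to the Open Neural Network Exchange (ONNX) format or the Predictive Model Markup Language (PMML) format
  • Reference 3 shows a way to store a model as JSON: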
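import json
import numpy as np
from sklearn.linear_model import LogisticRegression

# Stores hyperparameters and training data only; fitted coefficients are not saved
class MyLogReg(LogisticRegression):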
    # Override the class constructor
    def __init__(self, C=1.0, solver='liblinear', max_iter=100, X_train=None, Y_train=None):
        LogisticRegression.__init__(self, C=C, solver=solver, max_iter=max_iter)
        self.X_train = X_train
        self.Y_train = Y_train
    # A method for saving object data to JSON file
    def save_json(self, filepath):
        dict_ = {}
        dict_['C'] = self.C
        dict_['max_iter'] = self.max_iter
        dict_['solver'] = self.solver
        dict_['X_train'] = self.X_train.tolist() if self.X_train is not None else 'None'
        dict_['Y_train'] = self.Y_train.tolist() if self.Y_train is not None else 'None'
        # Create JSON and save to file
        json_txt = json.dumps(dict_, indent=4)
        with open(filepath, 'w') as file:
            file.write(json_txt)
    # A method for loading data from JSON file
    def load_json(self, filepath):
        with open(filepath, 'r') as file:
            dict_ = json.load(file)
        self.C = dict_['C']
        self.max_iter = dict_['max_iter']
        self.solver = dict_['solver']
        self.X_train = np.asarray(dict_['X_train']) if dict_['X_train'] != 'None' else None
        self.Y_train = np.asarray(dict_['Y_train']) if dict_['Y_train'] != 'None' else None

Saving and inspecting

filepath = "mylogreg.json"
# Create a model that holds the training data (Xtrain, Ytrain: existing NumPy arrays) and save it
mylogreg = MyLogReg(X_train=Xtrain, Y_train=Ytrain)
mylogreg.save_json(filepath)
# Create a new object and load its data from JSON file
json_mylogreg = MyLogReg()
json_mylogreg.load_json(filepath)
json_mylogreg
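
The MyLogReg class above writes hyperparameters and training data to JSON, so a model restored from that file has to be fitted again before it can predict. If the goal is to reuse a trained model without refitting, the learned attributes can be stored as well. The following is a minimal sketch under stated assumptions: the helpers save_fitted_logreg and load_fitted_logreg are hypothetical names, the data is synthetic, and restoring a model by assigning coef_, intercept_ and classes_ directly is a workaround rather than an officially supported persistence route.

```python
import json
import os
import tempfile

import numpy as np
from sklearn.linear_model import LogisticRegression


def save_fitted_logreg(model, filepath):
    """Write hyperparameters and learned attributes to a JSON file."""
    payload = {
        "params": {"C": model.C, "solver": model.solver, "max_iter": model.max_iter},
        "coef_": model.coef_.tolist(),
        "intercept_": model.intercept_.tolist(),
        "classes_": model.classes_.tolist(),
    }
    with open(filepath, "w") as f:
        json.dump(payload, f, indent=4)


def load_fitted_logreg(filepath):
    """Rebuild a LogisticRegression and restore its learned attributes."""
    with open(filepath) as f:
        payload = json.load(f)
    model = LogisticRegression(**payload["params"])
    model.coef_ = np.asarray(payload["coef_"])
    model.intercept_ = np.asarray(payload["intercept_"])
    model.classes_ = np.asarray(payload["classes_"])
    model.n_features_in_ = model.coef_.shape[1]
    return model


# Synthetic binary classification data (an assumption for illustration).
rng = np.random.RandomState(0)
X = rng.rand(100, 3)
y = (X[:, 0] + X[:, 1] > 1.0).astype(int)

original = LogisticRegression().fit(X, y)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "fitted_logreg.json")
    save_fitted_logreg(original, path)
    restored = load_fitted_logreg(path)

same_predictions = np.array_equal(original.predict(X), restored.predict(X))
print(same_predictions)
```

This only works for estimators whose fitted state is a handful of arrays; tree ensembles such as the gradient boosting model used earlier hold nested structures that are impractical to serialize by hand.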

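Limitations

Pickle and Joblib

  • Compatibility
    The biggest drawback of Pickle and Joblib is compatibility: a saved model may fail to load, or behave differently, under a different version of scikit-learn or Python.
  • Security
    Pickle (and, by extension, Joblib) has maintainability and security problems: loading a pickle file can execute arbitrary code, so only load files from sources you trust.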
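
One way to soften both limitations is to store a small metadata file next to the model. The sketch below is an illustration, not a complete solution: dump_with_metadata, load_with_checks and the ".meta.json" sidecar naming are hypothetical, and the checksum only helps if the sidecar itself comes from a trusted place. It records the scikit-learn and Python versions plus a SHA-256 digest, then checks them before unpickling.

```python
import hashlib
import json
import os
import platform
import tempfile
import warnings

import joblib
import numpy as np
import sklearn
from sklearn.ensemble import GradientBoostingRegressor as GBR


def sha256_of(path):
    """Return the hex SHA-256 digest of a file's contents."""
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()


def dump_with_metadata(model, model_path):
    """Save the model and a JSON sidecar with versions and a checksum."""
    joblib.dump(model, model_path)
    meta = {
        "sklearn_version": sklearn.__version__,
        "python_version": platform.python_version(),
        "sha256": sha256_of(model_path),
    }
    with open(model_path + ".meta.json", "w") as f:
        json.dump(meta, f, indent=4)
    return meta


def load_with_checks(model_path):
    """Verify checksum and versions from the sidecar, then load the model."""
    with open(model_path + ".meta.json") as f:
        meta = json.load(f)
    if sha256_of(model_path) != meta["sha256"]:
        raise ValueError("model file does not match the recorded checksum")
    if meta["sklearn_version"] != sklearn.__version__:
        warnings.warn(
            "model was saved with scikit-learn %s, running %s"
            % (meta["sklearn_version"], sklearn.__version__)
        )
    return joblib.load(model_path)


# Synthetic data (an assumption for illustration).
rng = np.random.RandomState(0)
X = rng.rand(50, 3)
y = X.sum(axis=1)
gbr = GBR(random_state=1412).fit(X, y)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train_model.m")
    dump_with_metadata(gbr, path)
    restored_ok = np.array_equal(load_with_checks(path).predict(X), gbr.predict(X))

    # Simulate a modified file: the checksum no longer matches.
    with open(path, "ab") as f:
        f.write(b"x")
    try:
        load_with_checks(path)
        tamper_detected = False
    except ValueError:
        tamper_detected = True

print(restored_ok, tamper_detected)
```

Keeping the metadata in a separate JSON file means it can be read and checked before any unpickling happens.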

JSON

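  • Lower security: the file is plain text that anyone can read or edit
  • Practical only for objects with few instance variables, since each one must be serialized by hand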

References:

  1. Model persistence
  2. sklearn2pmml
  3. Saving and loading sklearn models (sklearn 模型的保存与加载)