<kbd id="5sdj3"></kbd>
<th id="5sdj3"></th>

  • <dd id="5sdj3"><form id="5sdj3"></form></dd>
    <td id="5sdj3"><form id="5sdj3"><big id="5sdj3"></big></form></td><del id="5sdj3"></del>

  • <dd id="5sdj3"></dd>
    <dfn id="5sdj3"></dfn>
  • <th id="5sdj3"></th>
    <tfoot id="5sdj3"><menuitem id="5sdj3"></menuitem></tfoot>

  • <td id="5sdj3"><form id="5sdj3"><menu id="5sdj3"></menu></form></td>
  • <kbd id="5sdj3"><form id="5sdj3"></form></kbd>

    Apache SparkMLlib構(gòu)建機(jī)器學(xué)習(xí)分類模型

    共 27447字,需瀏覽 55分鐘

     ·

    2024-04-11 14:41

    一、引言

    1.1 Spark MLlib簡介

    Apache Spark MLlib(Machine Learning library)是一個開源機(jī)器學(xué)習(xí)框架,建立在Apache Spark之上,支持分布式計(jì)算和大規(guī)模數(shù)據(jù)處理。它提供了許多經(jīng)典機(jī)器學(xué)習(xí)算法和工具,如分類、回歸、聚類、協(xié)同過濾、特征提取和數(shù)據(jù)預(yù)處理等。

    Spark MLlib使用基于DataFrame的API,提供了一個易于使用的高級API,使得用戶能夠快速構(gòu)建、訓(xùn)練和調(diào)整機(jī)器學(xué)習(xí)模型,而無需擔(dān)心底層分布式計(jì)算的復(fù)雜性。它還支持分布式模型選擇和調(diào)整,以及與其他Apache Spark組件的集成,如Spark SQL、Spark Streaming和GraphX。

    Spark MLlib還提供了Python、Java和Scala等多種編程語言的API,使得不同的開發(fā)人員可以使用他們最喜歡的編程語言來開發(fā)機(jī)器學(xué)習(xí)應(yīng)用程序。

    總之,Spark MLlib是一個非常強(qiáng)大和靈活的機(jī)器學(xué)習(xí)框架,適用于處理大規(guī)模數(shù)據(jù)和需要分布式計(jì)算的場景。

    1.2 為什么選擇使用Spark MLlib

    1. 處理大規(guī)模數(shù)據(jù):Spark MLlib支持分布式計(jì)算和大規(guī)模數(shù)據(jù)處理,使得處理大規(guī)模數(shù)據(jù)集變得容易。1. 豐富的算法庫:Spark MLlib包含了許多經(jīng)典的機(jī)器學(xué)習(xí)算法和工具,如分類、回歸、聚類、協(xié)同過濾、特征提取和數(shù)據(jù)預(yù)處理等,覆蓋了大部分機(jī)器學(xué)習(xí)應(yīng)用場景。1. 高性能:Spark MLlib基于Apache Spark,使用內(nèi)存計(jì)算和RDD(彈性分布式數(shù)據(jù)集)等優(yōu)化技術(shù),可以在處理大規(guī)模數(shù)據(jù)時提供高性能和可擴(kuò)展性。1. 易于使用:Spark MLlib提供了一個易于使用的高級API,使得用戶可以快速構(gòu)建、訓(xùn)練和調(diào)整機(jī)器學(xué)習(xí)模型,而無需擔(dān)心底層分布式計(jì)算的復(fù)雜性。1. 多語言支持:Spark MLlib支持多種編程語言的API,包括Python、Java和Scala等,使得不同的開發(fā)人員可以使用他們最喜歡的編程語言來開發(fā)機(jī)器學(xué)習(xí)應(yīng)用程序。

    二、Spark MLlib基礎(chǔ)

    2.1 RDD和DataFrame的比較

    1. 數(shù)據(jù)類型:基礎(chǔ)RDD可以包含任意類型的數(shù)據(jù),包括對象、原始類型、數(shù)組和集合等;DataFrame則是一種表格化的數(shù)據(jù)結(jié)構(gòu),其數(shù)據(jù)類型必須是統(tǒng)一的,且可以使用SQL-like的語法進(jìn)行查詢。1. 內(nèi)存計(jì)算:DataFrame利用內(nèi)存計(jì)算技術(shù),相比基礎(chǔ)RDD更加高效。1. 可讀性:DataFrame比基礎(chǔ)RDD更加易于閱讀和理解,可以使用SQL-like的語法進(jìn)行查詢,更加直觀。1. 類型安全:DataFrame是類型安全的,可以在編譯期間捕獲類型錯誤,避免運(yùn)行時錯誤;而基礎(chǔ)RDD則是類型不安全的,需要在運(yùn)行時進(jìn)行類型檢查。1. 執(zhí)行計(jì)劃:基礎(chǔ)RDD提供了更加靈活的執(zhí)行計(jì)劃,用戶可以控制計(jì)算的方式和順序,但這也增加了開發(fā)復(fù)雜度;而DataFrame則有一個自動優(yōu)化的執(zhí)行計(jì)劃,可以自動優(yōu)化查詢性能。 總之,基礎(chǔ)RDD更加靈活和可控,但需要開發(fā)人員自己掌握計(jì)算的方式和順序;而DataFrame則更加易于使用和高效,適合快速開發(fā)和迭代。選擇使用哪種數(shù)據(jù)結(jié)構(gòu),取決于具體的場景和需求。

    2.2 數(shù)據(jù)準(zhǔn)備和預(yù)處理

    在使用Spark MLlib進(jìn)行機(jī)器學(xué)習(xí)之前,需要對原始數(shù)據(jù)進(jìn)行預(yù)處理和準(zhǔn)備。以下是一些常見的數(shù)據(jù)準(zhǔn)備和預(yù)處理步驟:

    1. 數(shù)據(jù)清洗:刪除缺失值、處理異常值和重復(fù)值等。1. 特征選擇:選擇對模型有用的特征,去除冗余和無關(guān)的特征。1. 特征縮放:對特征進(jìn)行縮放,以便它們具有相似的范圍和重要性。1. 特征變換:將原始特征轉(zhuǎn)換為更有意義的特征,如使用對數(shù)、指數(shù)、平方根等函數(shù)進(jìn)行變換。1. 特征歸一化:將特征值歸一化為標(biāo)準(zhǔn)正態(tài)分布,使得模型更容易學(xué)習(xí)。1. 數(shù)據(jù)轉(zhuǎn)換:將數(shù)據(jù)轉(zhuǎn)換為適合模型訓(xùn)練的格式,如將分類變量轉(zhuǎn)換為二進(jìn)制變量、將文本轉(zhuǎn)換為向量等。 在Spark MLlib中,可以使用各種預(yù)處理和數(shù)據(jù)準(zhǔn)備工具,如:
    2. Imputer:用于填充缺失值。1. StandardScaler:用于特征縮放和歸一化。1. VectorAssembler:用于將多個特征列組合成一個向量列。1. OneHotEncoder:用于將分類變量轉(zhuǎn)換為二進(jìn)制變量。1. StringIndexer和IndexToString:用于將字符串類型的變量轉(zhuǎn)換為數(shù)字類型的變量。1. Tokenizer和StopWordsRemover:用于將文本轉(zhuǎn)換為向量。 總之,在使用Spark MLlib進(jìn)行機(jī)器學(xué)習(xí)之前,需要對原始數(shù)據(jù)進(jìn)行預(yù)處理和準(zhǔn)備。Spark MLlib提供了許多工具和功能,可以幫助我們輕松地完成這些任務(wù)。

    2.3 特征提取和轉(zhuǎn)換

    在Spark MLlib中,有許多常用的特征提取和轉(zhuǎn)換工具,包括:

    1. Tokenizer:用于將文本轉(zhuǎn)換為單詞或詞條。1. StopWordsRemover:用于去除文本中的停用詞,如“the”、“and”等。1. CountVectorizer:用于將文本轉(zhuǎn)換為詞頻向量。1. HashingTF:用于將文本轉(zhuǎn)換為哈希向量,可以減少維度并提高計(jì)算效率。1. IDF:用于計(jì)算逆文檔頻率,可以減少常見詞語的權(quán)重,提高稀有詞語的權(quán)重。1. Word2Vec:用于將文本轉(zhuǎn)換為向量,可以捕捉詞語之間的語義關(guān)系。1. PCA:用于將高維特征空間降維,可以提高計(jì)算效率并避免過擬合。1. StringIndexer:用于將分類變量轉(zhuǎn)換為數(shù)字類型的變量。1. OneHotEncoder:用于將數(shù)字類型的變量轉(zhuǎn)換為二進(jìn)制變量。 以上這些工具都可以用于特征提取和轉(zhuǎn)換,幫助我們將原始數(shù)據(jù)轉(zhuǎn)換為模型可以處理的格式。我們可以根據(jù)具體的任務(wù)和數(shù)據(jù)類型選擇適當(dāng)?shù)墓ぞ撸垣@得更好的結(jié)果。值得注意的是,這些工具的使用通常需要進(jìn)行適當(dāng)?shù)膮?shù)設(shè)置和調(diào)整,以達(dá)到最佳的效果。

    三、監(jiān)督學(xué)習(xí)

    3.1 分類問題

    3.1.1 邏輯回歸

    邏輯回歸是一種二元分類模型,它的目標(biāo)是根據(jù)已知數(shù)據(jù)對一個事物進(jìn)行分類。邏輯回歸的輸出是一個概率值,代表該事物屬于某個類別的概率。如果概率值大于閾值,則將其分類為正類,否則分類為負(fù)類。

    在 Spark MLlib 中,可以使用 LogisticRegression 類來實(shí)現(xiàn)邏輯回歸。下面是一個 Java 版本的示例代碼:

    pom引用:

          
          <dependencies>
        <!-- Spark core dependencies -->
        <dependency>
          <groupId>org.apache.spark</groupId>
          <artifactId>spark-core_2.12</artifactId>
          <version>3.2.0</version>
        </dependency>
        <dependency>
          <groupId>org.apache.spark</groupId>
          <artifactId>spark-sql_2.12</artifactId>
          <version>3.2.0</version>
        </dependency>
        <dependency>
          <groupId>org.apache.spark</groupId>
          <artifactId>spark-mllib_2.12</artifactId>
          <version>3.2.0</version>
        </dependency>

        <!-- Spark testing dependencies (optional) -->
        <dependency>
          <groupId>org.apache.spark</groupId>
          <artifactId>spark-streaming_2.12</artifactId>
          <version>3.2.0</version>
          <scope>test</scope>
        </dependency>
        <dependency>
          <groupId>org.apache.spark</groupId>
          <artifactId>spark-streaming-kafka-0-10_2.12</artifactId>
          <version>3.2.0</version>
          <scope>test</scope>
        </dependency>
        <dependency>
          <groupId>org.apache.spark</groupId>
          <artifactId>spark-sql-kafka-0-10_2.12</artifactId>
          <version>3.2.0</version>
          <scope>test</scope>
        </dependency>
      </dependencies>
          
          import org.apache.spark.ml.classification.LogisticRegression;
    import org.apache.spark.ml.classification.LogisticRegressionModel;
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
    import org.apache.spark.ml.feature.VectorAssembler;
    import org.apache.spark.ml.linalg.Vector;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    public class LogisticRegressionDemo {

        public static void main(String[] args) {
            SparkSession spark = SparkSession
                    .builder()
                    .appName("LogisticRegressionDemo")
                    .master("local[*]")
                    .getOrCreate();

            // 加載數(shù)據(jù)
            Dataset<Row> data = spark.read().format("libsvm").load("data/sample_libsvm_data.txt");

            // 將特征向量轉(zhuǎn)換成一列
            VectorAssembler assembler = new VectorAssembler()
                    .setInputCols(new String[]{"features"})
                    .setOutputCol("feature");

            Dataset<Row> newData = assembler.transform(data).select("label""feature");

            // 將數(shù)據(jù)集分為訓(xùn)練集和測試集
            Dataset<Row>[] splits = newData.randomSplit(new double[]{0.70.3});
            Dataset<Row> trainData = splits[0];
            Dataset<Row> testData = splits[1];

            // 創(chuàng)建邏輯回歸模型
            LogisticRegression lr = new LogisticRegression();

            // 訓(xùn)練模型
            LogisticRegressionModel lrModel = lr.fit(trainData);

            // 在測試集上進(jìn)行預(yù)測
            Dataset<Row> predictions = lrModel.transform(testData);

            // 計(jì)算模型評估指標(biāo)
            BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator();
            double auc = evaluator.evaluate(predictions);

            System.out.println("Area under ROC curve = " + auc);

            spark.stop();
        }
    }

    這個示例代碼首先加載了一個 libsvm 格式的數(shù)據(jù)集,然后將特征向量轉(zhuǎn)換成一列,將數(shù)據(jù)集分為訓(xùn)練集和測試集,創(chuàng)建邏輯回歸模型并訓(xùn)練模型,最后在測試集上進(jìn)行預(yù)測并計(jì)算模型評估指標(biāo)。在這個例子中,我們使用了 BinaryClassificationEvaluator 來計(jì)算模型的 AUC 指標(biāo),它是評估二元分類器性能的一種常用指標(biāo)。

    需要注意的是,以上代碼僅供參考,實(shí)際情況可能需要根據(jù)數(shù)據(jù)集的特點(diǎn)和任務(wù)的要求進(jìn)行相應(yīng)的修改。

    3.1.2 決策樹

    Spark MLlib 分類決策樹是一種基于樹結(jié)構(gòu)的分類算法,通過一系列特征對數(shù)據(jù)進(jìn)行劃分和分類。該算法在 Spark MLlib 中的實(shí)現(xiàn)采用 CART(Classification And Regression Tree)算法,使用信息熵或 Gini 系數(shù)等指標(biāo)進(jìn)行特征選擇和劃分。Spark MLlib 分類決策樹可用于二分類、多分類和概率預(yù)測問題。

          
          import org.apache.spark.ml.Pipeline;
    import org.apache.spark.ml.PipelineModel;
    import org.apache.spark.ml.PipelineStage;
    import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
    import org.apache.spark.ml.classification.DecisionTreeClassifier;
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
    import org.apache.spark.ml.feature.IndexToString;
    import org.apache.spark.ml.feature.StringIndexer;
    import org.apache.spark.ml.feature.StringIndexerModel;
    import org.apache.spark.ml.feature.VectorAssembler;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    public class DecisionTreeClassificationExample {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
          .appName("DecisionTreeClassificationExample")
          .master("local[*]")
          .getOrCreate();

        // 讀取數(shù)據(jù)集
        Dataset<Row> data = spark.read().format("csv")
          .option("header""true")
          .option("inferSchema""true")
          .load("path/to/data.csv");

        // 將標(biāo)簽列轉(zhuǎn)換為數(shù)值類型
        StringIndexerModel labelIndexer = new StringIndexer()
          .setInputCol("label")
          .setOutputCol("indexedLabel")
          .fit(data);
        data = labelIndexer.transform(data);

        // 將特征列轉(zhuǎn)換為特征向量
        VectorAssembler featureAssembler = new VectorAssembler()
          .setInputCols(new String[]{"feature1""feature2""feature3"})
          .setOutputCol("features");
        data = featureAssembler.transform(data);

        // 將數(shù)據(jù)集分為訓(xùn)練集和測試集
        Dataset<Row>[] splits = data.randomSplit(new double[]{0.70.3}, 12345);
        Dataset<Row> trainData = splits[0];
        Dataset<Row> testData = splits[1];

        // 創(chuàng)建決策樹分類器
        DecisionTreeClassifier dt = new DecisionTreeClassifier()
          .setLabelCol("indexedLabel")
          .setFeaturesCol("features");

        // 將標(biāo)簽數(shù)值轉(zhuǎn)換回原始標(biāo)簽
        IndexToString labelConverter = new IndexToString()
          .setInputCol("prediction")
          .setOutputCol("predictedLabel")
          .setLabels(labelIndexer.labels());

        // 創(chuàng)建管道并擬合模型
        Pipeline pipeline = new Pipeline()
          .setStages(new PipelineStage[]{labelIndexer, featureAssembler, dt, labelConverter});
        PipelineModel model = pipeline.fit(trainData);

        // 在測試集上進(jìn)行預(yù)測和評估
        Dataset<Row> predictions = model.transform(testData);
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("indexedLabel")
          .setPredictionCol("prediction")
          .setMetricName("accuracy");
        double accuracy = evaluator.evaluate(predictions);
        System.out.println("Test Error = " + (1.0 - accuracy));
        // 輸出決策樹結(jié)構(gòu)
        DecisionTreeClassificationModel treeModel =
        (DecisionTreeClassificationModel) (model.stages()[2]);
        System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());

        spark.stop();
        }
    }

    以上示例中,我們首先使用 SparkSession 讀取 CSV 格式的數(shù)據(jù)集。然后,使用 StringIndexer 將標(biāo)簽列轉(zhuǎn)換為數(shù)值類型,并使用 VectorAssembler 將特征列轉(zhuǎn)換為特征向量。接著,將數(shù)據(jù)集分為訓(xùn)練集和測試集,并創(chuàng)建 DecisionTreeClassifier 決策樹分類器。最后,將管道中的各個階段組合在一起,擬合模型并在測試集上進(jìn)行預(yù)測和評估。

    3.1.3 隨機(jī)森林

    隨機(jī)森林是一種集成學(xué)習(xí)算法,它將多棵決策樹組合起來,通過投票或平均來決定分類結(jié)果。該算法在 Spark MLlib 中的實(shí)現(xiàn)使用基于 CART(Classification And Regression Tree)算法的決策樹作為基分類器,可以用于二分類、多分類和概率預(yù)測問題。

    以下是一個基于 Java 的 Spark MLlib 分類隨機(jī)森林示例:

          
          import org.apache.spark.ml.Pipeline;
    import org.apache.spark.ml.PipelineModel;
    import org.apache.spark.ml.PipelineStage;
    import org.apache.spark.ml.classification.RandomForestClassificationModel;
    import org.apache.spark.ml.classification.RandomForestClassifier;
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
    import org.apache.spark.ml.feature.IndexToString;
    import org.apache.spark.ml.feature.StringIndexer;
    import org.apache.spark.ml.feature.StringIndexerModel;
    import org.apache.spark.ml.feature.VectorAssembler;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    public class RandomForestClassificationExample {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
          .appName("RandomForestClassificationExample")
          .master("local[*]")
          .getOrCreate();

        // 讀取數(shù)據(jù)集
        Dataset<Row> data = spark.read().format("csv")
          .option("header""true")
          .option("inferSchema""true")
          .load("path/to/data.csv");

        // 將標(biāo)簽列轉(zhuǎn)換為數(shù)值類型
        StringIndexerModel labelIndexer = new StringIndexer()
          .setInputCol("label")
          .setOutputCol("indexedLabel")
          .fit(data);
        data = labelIndexer.transform(data);

        // 將特征列轉(zhuǎn)換為特征向量
        VectorAssembler featureAssembler = new VectorAssembler()
          .setInputCols(new String[]{"feature1""feature2""feature3"})
          .setOutputCol("features");
        data = featureAssembler.transform(data);

        // 將數(shù)據(jù)集分為訓(xùn)練集和測試集
        Dataset<Row>[] splits = data.randomSplit(new double[]{0.70.3}, 12345);
        Dataset<Row> trainData = splits[0];
        Dataset<Row> testData = splits[1];

        // 創(chuàng)建隨機(jī)森林分類器
        RandomForestClassifier rf = new RandomForestClassifier()
          .setLabelCol("indexedLabel")
          .setFeaturesCol("features")
          .setNumTrees(10);

        // 將標(biāo)簽數(shù)值轉(zhuǎn)換回原始標(biāo)簽
        IndexToString labelConverter = new IndexToString()
          .setInputCol("prediction")
          .setOutputCol("predictedLabel")
          .setLabels(labelIndexer.labels());

        // 創(chuàng)建管道并擬合模型
        Pipeline pipeline = new Pipeline()
          .setStages(new PipelineStage[]{labelIndexer, featureAssembler, rf, labelConverter});
        PipelineModel model = pipeline.fit(trainData);

        // 在測試集上進(jìn)行預(yù)測和評估
        Dataset<Row> predictions = model.transform(testData);
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("indexedLabel")
          .setPredictionCol("prediction")
          .setMetricName("accuracy");
        double accuracy = evaluator.evaluate(predictions);
        System.out.println("Test Error = " + (1.0 - accuracy));
        // 獲取訓(xùn)練好的隨機(jī)森林模型并打印樹的重要性
        RandomForestClassificationModel rfModel = (RandomForestClassificationModel) model.stages()[2];
        System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());

        spark.stop();
      }
    }

    該示例代碼首先使用 SparkSession 讀取 CSV 格式的數(shù)據(jù)集。接下來,使用 StringIndexer 將標(biāo)簽列轉(zhuǎn)換為數(shù)值類型,并使用 VectorAssembler 將特征列轉(zhuǎn)換為特征向量。然后,將數(shù)據(jù)集分為訓(xùn)練集和測試集。創(chuàng)建 RandomForestClassifier,并將其作為管道的一部分進(jìn)行擬合。擬合后,使用 MulticlassClassificationEvaluator 對測試集進(jìn)行預(yù)測和評估。最后,獲取訓(xùn)練好的隨機(jī)森林模型并打印樹的重要性。

    請注意,上面的示例中,數(shù)據(jù)集的路徑應(yīng)該被替換為實(shí)際數(shù)據(jù)集的路徑,特征列的名稱也應(yīng)該被替換為實(shí)際特征列的名稱。

    3.1.4 梯度提升樹

    Spark MLlib 提供了一個強(qiáng)大的算法——分類梯度提升樹(Gradient-Boosted Trees, GBT),它可以用于二元分類和多類分類。GBT 是一種集成學(xué)習(xí)算法,它通過在先前樹的殘差上逐步擬合一系列決策樹來提高模型的準(zhǔn)確性。

    在 Spark MLlib 中,可以使用 GBTClassifier 類來構(gòu)建分類 GBT 模型。GBT 分類器使用一系列決策樹來逐步提高模型的準(zhǔn)確性,每個決策樹都是在之前決策樹的殘差上訓(xùn)練得到的。通過這種方式,GBT 可以在更少的迭代次數(shù)下得到比隨機(jī)森林更準(zhǔn)確的模型。

    與其他 Spark MLlib 分類器類似,GBT 分類器也使用管道(Pipeline)來處理數(shù)據(jù)。管道通常包括以下幾個步驟:

    1. 數(shù)據(jù)預(yù)處理:包括數(shù)據(jù)清洗、特征提取、特征轉(zhuǎn)換等操作。1. 特征工程:根據(jù)特定的特征工程需求,對特征進(jìn)行過濾、選擇、轉(zhuǎn)換等操作。1. 模型訓(xùn)練:使用訓(xùn)練集對模型進(jìn)行擬合。1. 模型評估:使用測試集對模型進(jìn)行評估。1. 模型應(yīng)用:將模型應(yīng)用到新的數(shù)據(jù)集上進(jìn)行預(yù)測。 在使用 GBT 分類器時,你需要指定以下參數(shù):
    • featuresCol:特征列的名稱。- labelCol:標(biāo)簽列的名稱。- maxIter:訓(xùn)練迭代次數(shù)。- maxDepth:決策樹的最大深度。- minInstancesPerNode:每個節(jié)點(diǎn)上的最小實(shí)例數(shù)。- stepSize:每個迭代步驟的步長。- subsamplingRate:用于訓(xùn)練每棵樹的數(shù)據(jù)子樣本的比例。
          
          import org.apache.spark.ml.Pipeline;
    import org.apache.spark.ml.PipelineModel;
    import org.apache.spark.ml.PipelineStage;
    import org.apache.spark.ml.classification.GBTClassificationModel;
    import org.apache.spark.ml.classification.GBTClassifier;
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
    import org.apache.spark.ml.feature.IndexToString;
    import org.apache.spark.ml.feature.StringIndexer;
    import org.apache.spark.ml.feature.StringIndexerModel;
    import org.apache.spark.ml.feature.VectorIndexer;
    import org.apache.spark.ml.feature.VectorIndexerModel;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;


    public class GBTExample {
        public static void main(String[] args) {
            // 創(chuàng)建一個 SparkSession
            SparkSession spark = SparkSession
                    .builder()
                    .appName("GBTExample")
                    .getOrCreate();

            // 讀取數(shù)據(jù)集
            Dataset<Row> data = spark.read()
                    .format("libsvm")
                    .load("data/mllib/sample_libsvm_data.txt");

            // 對標(biāo)簽列進(jìn)行索引
            StringIndexerModel labelIndexer = new StringIndexer()
                    .setInputCol("label")
                    .setOutputCol("indexedLabel")
                    .fit(data);

            // 對特征列進(jìn)行索引
            VectorIndexerModel featureIndexer = new VectorIndexer()
                    .setInputCol("features")
                    .setOutputCol("indexedFeatures")
                    .setMaxCategories(4// 特征具有少于 4 個不同的值
                    .fit(data);

            // 將數(shù)據(jù)集拆分為訓(xùn)練集和測試集
            Dataset<Row>[] splits = data.randomSplit(new double[]{0.70.3});
            Dataset<Row> trainingData = splits[0];
            Dataset<Row> testData = splits[1];

            // 定義 GBT 分類器
            GBTClassifier gbt = new GBTClassifier()
                    .setLabelCol("indexedLabel")
                    .setFeaturesCol("indexedFeatures")
                    .setMaxIter(10)
                    .setFeatureSubsetStrategy("auto");

            // 將索引的標(biāo)簽轉(zhuǎn)換回原始標(biāo)簽
            IndexToString labelConverter = new IndexToString()
                    .setInputCol("prediction")
                    .setOutputCol("predictedLabel")
                    .setLabels(labelIndexer.labels());

            // 創(chuàng)建管道
            Pipeline pipeline = new Pipeline()
                    .setStages(new PipelineStage[]{
                            labelIndexer,
                            featureIndexer,
                            gbt,
                            labelConverter
                    });

            // 訓(xùn)練模型
            PipelineModel model = pipeline.fit(trainingData);

            // 進(jìn)行預(yù)測
            Dataset<Row> predictions = model.transform(testData);

            // 選擇樣例行顯示
            predictions.select("predictedLabel""label""features").show(5);

            // 評估模型
            MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                    .setLabelCol("indexedLabel")
                    .setPredictionCol("prediction")
                    .setMetricName("accuracy");
            double accuracy = evaluator.evaluate(predictions);
            System.out.println("Test Error = " + (1.0 - accuracy));

            // 獲取訓(xùn)練得到的 GBT 模型
            GBTClassificationModel gbtModel = (GBTClassificationModel) (model.stages()[2]);
            System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString());

            spark.stop();
        }
    }

    該示例使用了 Spark MLlib 內(nèi)置的 sample_libsvm_data.txt 數(shù)據(jù)集。首先,將數(shù)據(jù)集加載到 DataFrame 中。接下來,對標(biāo)簽列和特征列進(jìn)行索引。然后,將數(shù)據(jù)集拆分為訓(xùn)練集和測試集。接下來,創(chuàng)建 GBT 分類器,并使用管道將標(biāo)簽轉(zhuǎn)換回原始標(biāo)簽。最后,使用訓(xùn)練數(shù)據(jù)擬合管道并進(jìn)行預(yù)測。最終評估模型并輸出模型學(xué)習(xí)到的 GBT 分類模型的調(diào)試字符串。該字符串顯示了樹的結(jié)構(gòu)和分裂標(biāo)準(zhǔn),以及在每個節(jié)點(diǎn)處對特征的使用情況和分裂點(diǎn)。


    瀏覽 43
    點(diǎn)贊
    評論
    收藏
    分享

    手機(jī)掃一掃分享

    分享
    舉報
    評論
    圖片
    表情
    推薦
    點(diǎn)贊
    評論
    收藏
    分享

    手機(jī)掃一掃分享

    分享
    舉報

    <kbd id="5sdj3"></kbd>
    <th id="5sdj3"></th>

  • <dd id="5sdj3"><form id="5sdj3"></form></dd>
    <td id="5sdj3"><form id="5sdj3"><big id="5sdj3"></big></form></td><del id="5sdj3"></del>

  • <dd id="5sdj3"></dd>
    <dfn id="5sdj3"></dfn>
  • <th id="5sdj3"></th>
    <tfoot id="5sdj3"><menuitem id="5sdj3"></menuitem></tfoot>

  • <td id="5sdj3"><form id="5sdj3"><menu id="5sdj3"></menu></form></td>
  • <kbd id="5sdj3"><form id="5sdj3"></form></kbd>
    亚洲日韩国产成人精品 | 苍井空在线视频一区二区三区 | 国产性爱免费视频 | 色男人男人天堂 | 欧美操逼视频免费观看 |