博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
JavaScript机器学习之KNN算法
阅读量:7211 次
发布时间:2019-06-29

本文共 7172 字,大约阅读时间需要 23 分钟。

hot3.png

译者按: 机器学习原来很简单啊,不妨动手试试!

原文:

译者:

为了保证可读性,本文采用意译而非直译。另外,本文版权归原作者所有,翻译仅用于学习。另外,我们修正了原文代码中的错误

上图使用所画。

上次我们用JavaScript实现了,这次我们来聊聊KNN算法。

KNN是k-Nearest-Neighbours的缩写,它是一种监督学习算法。KNN算法可以用来做分类,也可以用来解决回归问题。

GitHub仓库:

KNN算法简介

简单地说,KNN算法由那离自己最近的K个点来投票决定待分类数据归为哪一类

如果待分类的数据有这些邻近数据,NY: 7, NJ: 0, IN: 4,即它有7个NY邻居,0个NJ邻居,4个IN邻居,则这个数据应该归类为NY

假设你在邮局工作,你的任务是为邮递员分配信件,目标是最小化到各个社区的投递旅程。不妨假设一共有7个街区。这就是一个实际的分类问题。你需要将这些信件分类,决定它属于哪个社区,比如上东城曼哈顿下城等。

最坏的方案是随意分配信件分配给邮递员,这样每个邮递员会拿到各个社区的信件。

最佳的方案是根据信件地址进行分类,这样每个邮递员只需要负责邻近社区的信件。

也许你是这样想的:"将邻近3个街区的信件分配给同一个邮递员"。这时,邻近街区的个数就是k。你可以不断增加k,直到获得最佳的分配方案。这个k就是分类问题的最佳值。

KNN代码实现

像一样,我们将使用的KNN模块来实现。

每一个机器学习算法都需要数据,这次我将使用IRIS数据集。其数据集包含了150个样本,都属于下的三个亚属,分别是、和。四个特征被用作样本的定量分析,它们分别是和的长度和宽度。

1. 安装模块

$ npm install ml-knn@2.0.0 csvtojson prompt

: k-Nearest-Neighbours模块,不同版本的接口可能不同,这篇博客使用了2.0.0

: 用于将CSV数据转换为JSON

: 在控制台输入输出数据

2. 初始化并导入数据

****由加州大学欧文分校提供。

curl https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data > iris.csv

假设你已经初始化了一个NPM项目,请在index.js中输入以下内容:

const KNN = require('ml-knn');const csv = require('csvtojson');const prompt = require('prompt');var knn;const csvFilePath = 'iris.csv'; // 数据集const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type'];let seperationSize; // 分割训练和测试数据let data = [],    X = [],    y = [];let trainingSetX = [],    trainingSetY = [],    testSetX = [],    testSetY = [];
  • seperationSize用于分割数据和测试数据

使用csvtojson模块的fromFile方法加载数据:

csv(    {        noheader: true,        headers: names    })    .fromFile(csvFilePath)    .on('json', (jsonObj) =>    {        data.push(jsonObj); // 将数据集转换为JS对象数组    })    .on('done', (error) =>    {        seperationSize = 0.7 * data.length;        data = shuffleArray(data);        dressData();    });

我们将seperationSize设为样本数目的0.7倍。注意,如果训练数据集太小的话,分类效果将变差。

由于数据集是根据种类排序的,所以需要使用shuffleArray函数对数据进行混淆,这样才能方便分割出训练数据。这个函数的定义请参考StackOverflow的提问:

function shuffleArray(array){    for (var i = array.length - 1; i > 0; i--)    {        var j = Math.floor(Math.random() * (i + 1));        var temp = array[i];        array[i] = array[j];        array[j] = temp;    }    return array;}

3. 转换数据

数据集中每一条数据可以转换为一个JS对象:

{ sepalLength: ‘5.1’, sepalWidth: ‘3.5’, petalLength: ‘1.4’, petalWidth: ‘0.2’, type: ‘Iris-setosa’ }

在使用KNN算法训练数据之前,需要对数据进行这些处理:

  1. 将属性(sepalLength, sepalWidth,petalLength,petalWidth)由字符串转换为浮点数. (parseFloat)
  2. 将分类 (type)用数字表示
function dressData(){    let types = new Set();     data.forEach((row) =>    {        types.add(row.type);    });    let typesArray = [...types];     data.forEach((row) =>    {        let rowArray, typeNumber;        rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4);        typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number)        X.push(rowArray);        y.push(typeNumber);    });    trainingSetX = X.slice(0, seperationSize);    trainingSetY = y.slice(0, seperationSize);    testSetX = X.slice(seperationSize);    testSetY = y.slice(seperationSize);    train();}

4. 训练数据并测试

function train(){    knn = new KNN(trainingSetX, trainingSetY,    {        k: 7    });    test();}

train方法需要2个必须的参数: 输入数据,即和的长度和宽度;实际分类,即、和。另外,第三个参数是可选的,用于提供调整KNN算法的内部参数。我将k参数设为7,其默认值为5。

训练好模型之后,就可以使用测试数据来检查准确性了。我们主要对预测出错的个数比较感兴趣。

function test(){    const result = knn.predict(testSetX);    const testSetLength = testSetX.length;    const predictionError = error(result, testSetY);    console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`);    predict();}

比较预测值与真实值,就可以得到出错个数:

function error(predicted, expected){    let misclassifications = 0;    for (var index = 0; index < predicted.length; index++)    {        if (predicted[index] !== expected[index])        {            misclassifications++;        }    }    return misclassifications;}

5. 进行预测(可选)

任意输入属性值,就可以得到预测值

function predict(){    let temp = [];    prompt.start();    prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result)    {        if (!err)        {            for (var key in result)            {                temp.push(parseFloat(result[key]));            }            console.log(`With ${temp} -- type =  ${knn.predict(temp)}`);        }    });}

6. 完整程序

完整的程序index.js是这样的:

const KNN = require('ml-knn');const csv = require('csvtojson');const prompt = require('prompt');var knn;const csvFilePath = 'iris.csv'; // 数据集const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type'];let seperationSize; // 分割训练和测试数据let data = [],    X = [],    y = [];let trainingSetX = [],    trainingSetY = [],    testSetX = [],    testSetY = [];csv(    {        noheader: true,        headers: names    })    .fromFile(csvFilePath)    .on('json', (jsonObj) =>    {        data.push(jsonObj); // 将数据集转换为JS对象数组    })    .on('done', (error) =>    {        seperationSize = 0.7 * data.length;        data = shuffleArray(data);        dressData();    });function dressData(){    let types = new Set();     data.forEach((row) =>    {        types.add(row.type);    });    let typesArray = [...types];     data.forEach((row) =>    {        let rowArray, typeNumber;        rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4);        typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number)        X.push(rowArray);        y.push(typeNumber);    });    trainingSetX = X.slice(0, seperationSize);    trainingSetY = y.slice(0, seperationSize);    testSetX = X.slice(seperationSize);    testSetY = y.slice(seperationSize);    train();}// 使用KNN算法训练数据function train(){    knn = new KNN(trainingSetX, trainingSetY,    {        k: 7    });    test();}// 测试训练的模型function test(){    const result = knn.predict(testSetX);    const testSetLength = testSetX.length;    const predictionError = error(result, testSetY);    console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`);    predict();}// 计算出错个数function error(predicted, expected){    let misclassifications = 0;    for (var index = 0; index < predicted.length; index++)    {        if (predicted[index] !== expected[index])        {            misclassifications++;        }    }    return misclassifications;}// 根据输入预测结果function predict(){    let temp = [];    prompt.start();    prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result)    {        if (!err)        {            for (var key in result)            {                temp.push(parseFloat(result[key]));            }            console.log(`With ${temp} -- type =  ${knn.predict(temp)}`);        }    });}// 混淆数据集的顺序function shuffleArray(array){    for (var i = array.length - 1; i > 0; i--)    {        var j = Math.floor(Math.random() * (i + 1));        var temp = array[i];        array[i] = array[j];        array[j] = temp;    }    return array;}

在控制台执行node index.js

$ node index.js

输出如下:

Test Set Size = 45 and number of Misclassifications = 2prompt: Sepal Length:  1.7prompt: Sepal Width:  2.5prompt: Petal Length:  0.5prompt: Petal Width:  3.4With 1.7,2.5,0.5,3.4 -- type =  2

参考链接

欢迎加入的全栈BUG监控交流群: 622902485

版权声明:

转载时请注明作者Fundebug以及本文地址:

转载于:https://my.oschina.net/u/3375885/blog/1353429

你可能感兴趣的文章
php请求页面将返回的页面发送email
查看>>
#土豆记事#教你开发Android App之 —— Hello Android
查看>>
安全机构 abuse.ch 公布近10万个恶意网站
查看>>
若依后台管理系统 3.3 发布,新增多项功能
查看>>
三步教你学会git
查看>>
高防服务器,高防免费服务器,阿里云服务器高防价格
查看>>
JavaScript中8个常见的陷阱
查看>>
使用脚本在Linux服务器上自动安装Kubernetes的包管理器Helm
查看>>
15个顶级Java多线程面试题及回答
查看>>
STL中deque,queue,stack,list的学习
查看>>
带protobuf的通用型makefile
查看>>
Linux环境变量总结
查看>>
jquery里prop和attr的区别
查看>>
scala maven idea 第一个scala 程序
查看>>
Spring Web MVC框架(十一) Spring Web MVC测试框架
查看>>
Linux设备模型 (2)
查看>>
Spring源码剖析2:Spring IOC容器的加载过程
查看>>
Throttle, more throttle
查看>>
部分手机Toast不显示的解决办法
查看>>
解决win10系统下,git Bash闪退的问题
查看>>