Training Chinese Word Vectors with DL4J
Published: 2019-06-27


1 Preprocessing

Preprocessing the Chinese corpus mainly involves word segmentation, stop-word removal, and a few rules tailored to the specific application scenario.

package ai.mole.test;

import org.ansj.domain.Term;
import org.ansj.splitWord.analysis.ToAnalysis;
import org.nlpcn.commons.lang.tire.domain.Forest;
import org.nlpcn.commons.lang.tire.library.Library;
import org.nlpcn.commons.lang.util.IOUtil; // assumed import for the IOUtil.readLines helper used below

import java.io.*;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Pattern;

public class Preprocess {
    private static final Pattern NUMERIC_PATTERN = Pattern.compile("^[.\\d]+$");
    private static final Pattern ENGLISH_WORD_PATTERN = Pattern.compile("^[a-z]+$");

    public static void main(String[] args) {
        String inPath1 = "D:\\MyData\\XUGP3\\Desktop\\测试分词\\test1.txt";
        String inPath2 = "D:\\MyData\\XUGP3\\Desktop\\测试分词\\stop_words.txt";
        String outPath = "D:\\MyData\\XUGP3\\Desktop\\测试分词\\result1.txt";
        String encoding = "utf-8";

        PrintWriter writer = null;
        Forest forest = null;

        try {
            writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(outPath), encoding));
            // Load the ansj user dictionary from the classpath
            forest = Library.makeForest(Preprocess.class.getResourceAsStream("/library/userLibrary.dic"));

            List<String> lineList = IOUtil.readLines(new FileInputStream(inPath1), encoding);
            List<String> stopWordList = IOUtil.readLines(new FileInputStream(inPath2), encoding);

            for (String line : lineList) {
                String[] cols = line.split("\\t", -1);
                if (cols.length < 2) {
                    continue;
                }
                String text = cols[0].trim().toLowerCase() + " " + cols[1].trim().toLowerCase();

                // Word segmentation with the user dictionary
                List<Term> termList = ToAnalysis.parse(text, forest).getTerms();
                List<String> wordList = new LinkedList<>();
                for (Term term : termList) {
                    String word = term.getName();
                    if (word.length() < 2) {
                        continue;
                    }
                    if (stopWordList.contains(word)) {
                        continue;
                    }
                    if (isNumeric(word)) {
                        continue;
                    }
                    if (isEnglishWord(word)) {
                        continue;
                    }
                    wordList.add(word);
                }

                // Keep only documents with more than 5 remaining tokens
                if (wordList.size() > 5) {
                    String outStr = listToLine(wordList);
                    writer.println(outStr);
                }
            }
        } catch (FileNotFoundException e) {
            System.out.println("The file does not exist or the path is not correct!!!");
            System.exit(-1);
        } catch (UnsupportedEncodingException e) {
            System.out.println("Does not support the current character set!!!");
        } catch (IOException e) {
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (writer != null) {
                writer.close();
            }
        }
    }

    private static boolean isNumeric(String text) {
        return NUMERIC_PATTERN.matcher(text).matches();
    }

    private static boolean isEnglishWord(String text) {
        return ENGLISH_WORD_PATTERN.matcher(text).matches();
    }

    private static String listToLine(List<String> list) {
        // Join the tokens with spaces so that each document becomes one line in the output corpus
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < list.size(); i++) {
            sb.append(list.get(i));
            if (i < list.size() - 1) {
                sb.append(" ");
            }
        }
        return sb.toString();
    }
}
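A small side note on the code above: stopWordList.contains(word) scans a List for every token, which gets slow when the stop-word list is large. A minimal alternative is to load the file into a HashSet, assuming one stop word per line; the class and method names below are illustrative, not from the original post:

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashSet;
import java.util.Set;

public class StopWords {
    // Load a stop-word file (one word per line) into a HashSet so the
    // per-token check in the segmentation loop is O(1) instead of a
    // linear scan over a List.
    public static Set<String> load(String path, String encoding) throws IOException {
        Set<String> stopWords = new HashSet<>();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(path), encoding))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String word = line.trim();
                if (!word.isEmpty()) {
                    stopWords.add(word);
                }
            }
        }
        return stopWords;
    }
}

A Set built this way can replace stopWordList in the contains check with identical semantics.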

2 Training

The training code is very simple and can be read as-is; for the principles behind word2vec, see 皮提果's blog post.

package ai.mole.test;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.Collection;

public class TrainWord2VecModel {
    private static Logger log = LoggerFactory.getLogger(TrainWord2VecModel.class);

    public static void main(String[] args) throws IOException {
        String corpusPath = "/data/analyze/xgp/words.txt";
        String vectorsPath = "/data/analyze/xgp/word_vectors.txt";

        log.info("Start Training...");
        long st = System.currentTimeMillis();

        log.info("Load & vectorize sentences...");
        SentenceIterator iter = new BasicLineIterator(new File(corpusPath));
        TokenizerFactory t = new DefaultTokenizerFactory();
//        t.setTokenPreProcessor(new CommonPreprocessor());

        log.info("Building model...");
        Word2Vec vec = new Word2Vec.Builder()
                .minWordFrequency(50)
                .iterations(1)
                .epochs(100)
                .layerSize(500)
                .seed(42)
                .windowSize(5)
                .iterate(iter)
                .tokenizerFactory(t)
                .build();

        log.info("Fitting word2vec model...");
        vec.fit();

        log.info("Writing word vectors to text file...");
//        WordVectorSerializer.writeWord2VecModel(vec, vectorsPath);
        WordVectorSerializer.writeWordVectors(vec, vectorsPath);

        log.info("Closest words:");
        Collection<String> bydWordList = vec.wordsNearest("比亚迪", 10);
        Collection<String> changanWordList = vec.wordsNearest("长安", 10);
        System.out.print(bydWordList);
        System.out.println(changanWordList);
        log.info("10 words closest to '比亚迪': {}", bydWordList);
        log.info("10 words closest to '长安': {}", changanWordList);

        long et = System.currentTimeMillis();
        log.info("Training is completed, and the time taken is " + (et - st) + " ms.");
        System.out.println("Training is completed, and the time taken is " + (et - st) + " ms.");
    }
}
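One thing to note about the configuration above: minWordFrequency(50) drops every token that occurs fewer than 50 times in the corpus, so wordsNearest quietly returns an empty collection when the query word never made it into the vocabulary. A small hedged sketch of a guard, meant to slot into the main method right after vec.fit(); it reuses the vec and log variables defined above:

        // Query neighbours only for words that survived the minWordFrequency filter.
        String[] queries = {"比亚迪", "长安"};
        for (String query : queries) {
            if (vec.hasWord(query)) {
                log.info("10 words closest to '{}': {}", query, vec.wordsNearest(query, 10));
            } else {
                log.info("'{}' is not in the vocabulary (frequency below minWordFrequency?).", query);
            }
        }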

3 Using the trained word vectors

Using the trained word vectors is just as simple: call the static method readWord2VecModel of the WordVectorSerializer class, passing the path of the trained word vectors as the input argument.

Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel("D:\\MyData\\XUGP3\\Desktop\\测试分词\\vectors.txt");
Collection<String> bydWordList = word2Vec.wordsNearest("比亚迪", 10);
Collection<String> changanWordList = word2Vec.wordsNearest("长安", 10);
System.out.println(bydWordList);
System.out.println(changanWordList);
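Besides wordsNearest, the loaded model offers a couple of other lookups that are often handy. A short sketch continuing from the word2Vec instance above; the word pair is just the example pair used in this post, and similarity and getWordVector are standard DL4J WordVectors methods:

// Cosine similarity between two in-vocabulary words (roughly in the range -1..1).
double sim = word2Vec.similarity("比亚迪", "长安");
System.out.println("similarity(比亚迪, 长安) = " + sim);

// Raw embedding of a single word: a double[] of length layerSize (500 in the training setup above).
double[] vector = word2Vec.getWordVector("比亚迪");
System.out.println("vector length = " + (vector == null ? 0 : vector.length));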

Appendix - Maven dependencies

<dependencies>
    <dependency>
        <groupId>org.apdplat</groupId>
        <artifactId>word</artifactId>
        <version>1.3</version>
    </dependency>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>${nd4j.backend}</artifactId>
        <version>${nd4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nlp</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-zoo</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-ui_${scala.binary.version}</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-parallel-wrapper_${scala.binary.version}</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-hadoop</artifactId>
        <version>${datavec.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-common</artifactId>
        <version>${hadoop.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>arbiter-deeplearning4j</artifactId>
        <version>${arbiter.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>arbiter-ui_2.11</artifactId>
        <version>${arbiter.version}</version>
    </dependency>
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-data-codec</artifactId>
        <version>${datavec.version}</version>
    </dependency>
</dependencies>
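The ${...} placeholders above assume a properties block elsewhere in the POM, which the original post does not show. A minimal sketch with purely illustrative values; pick the DL4J release you actually target, and keep the DL4J/ND4J/DataVec/Arbiter versions in sync:

<properties>
    <!-- Illustrative versions only; align them with the DL4J release you use -->
    <dl4j.version>1.0.0-beta3</dl4j.version>
    <nd4j.version>1.0.0-beta3</nd4j.version>
    <datavec.version>1.0.0-beta3</datavec.version>
    <arbiter.version>1.0.0-beta3</arbiter.version>
    <nd4j.backend>nd4j-native-platform</nd4j.backend>
    <scala.binary.version>2.11</scala.binary.version>
    <hadoop.version>2.7.3</hadoop.version>
</properties>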

Reposted from: https://my.oschina.net/airship/blog/2994425
