WeNet tutorial: running the AISHELL sample with a custom dataset size

乐百川, 2022-04-24

Downloading the dataset

Download the archive from the download link and unzip it.
Command: unzip AISHELL-1_sample.zip

1. Prepare wav.scp and text

#!/bin/bash
. ./path.sh

# Where the dataset is stored
sample_data=/home/asr/data/wenet/examples/aishell/s0/datasets/AISHELL-1_sample
# Where the generated files go
data=/home/asr/data/wenet/examples/aishell/s0/data_
mkdir -p "$data"
# Start from a clean state
rm -f "$data/wav.scp" "$data/text"
# 1. Prepare wav.scp and text
for spk_dir in "$sample_data"/*/; do
    sub_dir=$(basename "$spk_dir")
    wav_txt_dir=$sample_data/$sub_dir/${sub_dir}_mic
    echo "$wav_txt_dir"
    # Glob the wav files directly instead of parsing the output of ls
    for wav in "$wav_txt_dir"/*.wav; do
        [ -f "$wav" ] || continue  # the glob matched nothing
        key=$(basename "$wav" .wav)
        # wav.scp line: <key> <path to wav>
        echo "$key $wav" >> "$data/wav.scp"
        # text line: <key> <transcript>
        txt=$(cat "$wav_txt_dir/$key.txt")
        echo "$key $txt" >> "$data/text"
    done
done
echo "wav.scp and text done!"
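
The same directory walk can also be sketched in Python. This is only a minimal sketch, assuming the layout used above (<sample_data>/<speaker>/<speaker>_mic/<utt>.wav with a matching <utt>.txt); the function names collect_utterances and write_kaldi_files are hypothetical and not part of WeNet.

```python
from pathlib import Path


def collect_utterances(sample_data):
    """Return sorted (key, wav_path, transcript) tuples found under sample_data.

    Assumes the layout <sample_data>/<speaker>/<speaker>_mic/<utt>.wav with the
    transcript stored in <utt>.txt next to each wav file.
    """
    entries = []
    for spk_dir in sorted(Path(sample_data).iterdir()):
        mic_dir = spk_dir / f"{spk_dir.name}_mic"
        if not mic_dir.is_dir():
            continue
        for wav in sorted(mic_dir.glob("*.wav")):
            txt = wav.with_suffix(".txt")
            if not txt.is_file():
                continue  # skip audio that has no transcript
            transcript = txt.read_text(encoding="utf-8").strip()
            entries.append((wav.stem, str(wav), transcript))
    return entries


def write_kaldi_files(entries, data_dir):
    """Write wav.scp and text (one "key value" pair per line) into data_dir."""
    out = Path(data_dir)
    out.mkdir(parents=True, exist_ok=True)
    with open(out / "wav.scp", "w", encoding="utf-8") as scp, \
            open(out / "text", "w", encoding="utf-8") as text:
        for key, wav, transcript in entries:
            scp.write(f"{key} {wav}\n")
            text.write(f"{key} {transcript}\n")
```

Skipping wav files without a transcript keeps wav.scp and text aligned line by line, which the next step relies on.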

2. Prepare data.list

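Read wav.scp and text in parallel, one line from each at a time, to generate data.list. Both files must list the utterances in the same order.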

# 2. Prepare data.list
rm -f "$data/data.list"
# Read wav.scp (fd 3) and text (fd 4) in lockstep.
# "read -r key rest" puts everything after the first field into the second
# variable, so transcripts that contain spaces are kept whole.
while read -r key wav <&3 && read -r _ txt <&4
do
    echo "{\"key\": \"${key}\", \"wav\": \"${wav}\", \"txt\": \"${txt}\"}" >> "$data/data.list"
done 3< "$data/wav.scp" 4< "$data/text"
echo "data.list done!"
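
Assembling JSON with echo breaks as soon as a transcript contains a double quote or a backslash. As a hedged alternative, the pairing can be sketched in Python with json.dumps, matching entries by key rather than by line position. The function name make_data_list is hypothetical; the key, wav and txt field names are the ones used in the script above.

```python
import json


def make_data_list(wav_scp_lines, text_lines):
    """Pair wav.scp and text lines by key; return one JSON string per utterance."""
    wavs = {}
    for line in wav_scp_lines:
        parts = line.strip().split(maxsplit=1)
        if len(parts) == 2:
            wavs[parts[0]] = parts[1]

    result = []
    for line in text_lines:
        parts = line.strip().split(maxsplit=1)
        if len(parts) != 2:
            continue  # skip empty lines and empty transcripts
        key, txt = parts
        if key not in wavs:
            continue  # no audio for this transcript
        entry = {"key": key, "wav": wavs[key], "txt": txt}
        # ensure_ascii=False keeps Chinese characters readable in the output
        result.append(json.dumps(entry, ensure_ascii=False))
    return result
```

Each returned string is one line of data.list; writing them out joined by newlines gives the same file layout as the shell loop.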

3. Prepare the dictionary

Use a Python script to generate the dictionary:

python get_lang_char.py >  $data/lang_char.txt

get_lang_char.py

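# Build the character dictionary from the text file and print it to stdout.
text_path = "./data_/text"
lang_char = set()
with open(text_path, "r", encoding="utf-8") as rfile:
    for line in rfile:
        # Each line is "<key> <transcript>"; keep the whole transcript,
        # even if it contains spaces
        parts = line.strip().split(maxsplit=1)
        if len(parts) < 2:
            continue
        # Collect every non-whitespace character
        lang_char.update(ch for ch in parts[1] if not ch.isspace())

print("<blank> 0")
print("<unk> 1")
# Sort so that the ids are identical on every run (set order is not stable)
chars = sorted(lang_char)
for idx, char in enumerate(chars, start=2):
    print(char, idx)
print("<sos/eos>", len(chars) + 2)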
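
To illustrate what the dictionary is used for, here is a minimal sketch that parses the "symbol id" lines printed above and maps a transcript to ids, falling back to <unk> for characters that are not in the table. The functions read_symbol_table and encode are hypothetical helpers written for this example, not WeNet's own code.

```python
def read_symbol_table(lines):
    """Parse "symbol id" lines (the lang_char.txt format) into a dict."""
    table = {}
    for line in lines:
        parts = line.split()
        if len(parts) == 2:
            table[parts[0]] = int(parts[1])
    return table


def encode(text, table):
    """Map each non-space character to its id, using <unk> for unseen ones."""
    unk = table["<unk>"]
    return [table.get(ch, unk) for ch in text if not ch.isspace()]
```

For example, with a table containing only 你 and 好, a transcript with any other character gets the <unk> id at that position.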

4. Compute CMVN

config=conf/train_u2++_conformer.yaml

# 4. Compute CMVN
tools/compute_cmvn_stats.py \
    --num_workers 8 \
    --train_config $config \
    --in_scp data_/wav.scp \
    --out_cmvn data_/global_cmvn
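
CMVN (cepstral mean and variance normalization) rescales each feature dimension to zero mean and unit variance using statistics gathered over the training set. The arithmetic can be sketched in plain Python as below. This is an illustration of the idea only, assuming the statistics are accumulated as per-dimension sums, sums of squares and a frame count; the function names are hypothetical and the code is not taken from tools/compute_cmvn_stats.py.

```python
import math


def accumulate_stats(frames):
    """Accumulate per-dimension sum, sum of squares, and the frame count."""
    dim = len(frames[0])
    total = [0.0] * dim
    total_sq = [0.0] * dim
    for frame in frames:
        for i, value in enumerate(frame):
            total[i] += value
            total_sq[i] += value * value
    return total, total_sq, len(frames)


def mean_and_std(total, total_sq, count, floor=1e-20):
    """Turn accumulated sums into per-dimension mean and standard deviation."""
    means = [s / count for s in total]
    stds = []
    for sq, mean in zip(total_sq, means):
        # Var[x] = E[x^2] - E[x]^2, floored to avoid dividing by zero later
        variance = max(sq / count - mean * mean, floor)
        stds.append(math.sqrt(variance))
    return means, stds


def apply_cmvn(frame, means, stds):
    """Normalize one feature frame with the global mean and std."""
    return [(x - m) / s for x, m, s in zip(frame, means, stds)]
```

Accumulating sums instead of storing all frames is what lets the statistics be computed over a whole corpus in one pass.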

5. Train

python3 wenet/bin/train.py \
    --config $config \
    --data_type raw \
    --symbol_table data_/lang_char.txt \
    --train_data data_/data.list \
    --model_dir data_/model \
    --cv_data data_/data.list \
    --num_workers 8 \
    --cmvn data_/global_cmvn \
    --pin_memory

6. Average the models

# 6. Average the model checkpoints
python wenet/bin/average_model.py \
      --dst_model data_/average.pt \
      --src_path data_/model \
      --num 30 \
      --val_best
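
The idea behind this step, as suggested by the --num and --val_best options, is to pick the checkpoints with the lowest validation loss and average their parameters element by element. A minimal sketch in plain Python follows; the checkpoint layout (a dict with val_loss and params) and the name average_best are assumptions made for illustration and do not reflect how average_model.py stores or loads checkpoints.

```python
def average_best(checkpoints, num):
    """Average the parameters of the `num` checkpoints with the lowest val_loss.

    Each checkpoint is a dict of the (assumed) form
    {"val_loss": float, "params": {name: [floats]}}.
    """
    best = sorted(checkpoints, key=lambda c: c["val_loss"])[:num]
    if not best:
        raise ValueError("no checkpoints to average")
    averaged = {}
    for name in best[0]["params"]:
        # zip(*...) walks the same position of this parameter in every checkpoint
        columns = zip(*(c["params"][name] for c in best))
        averaged[name] = [sum(col) / len(best) for col in columns]
    return averaged
```

If fewer than `num` checkpoints exist, the slice simply averages all of them.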

7. Test the model

python3 wenet/bin/recognize.py \
    --mode "attention_rescoring" \
    --config data_/model/train.yaml \
    --data_type raw \
    --test_data data_/data.list \
    --checkpoint data_/model/final.pt \
    --beam_size 10 \
    --batch_size 1 \
    --penalty 0.0 \
    --dict data_/lang_char.txt \
    --ctc_weight 1.0 \
    --reverse_weight 0 \
    --result_file data_/result.txt

python tools/compute-wer.py --char=1 --v=1 \
    data_/text data_/result.txt > data_/wer.txt
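
With --char=1 the comparison is done per character, so the reported figure is a character error rate: the edit distance between reference and hypothesis divided by the reference length. A minimal sketch of that calculation is shown below; edit_distance and cer are hypothetical names, and the code illustrates the metric rather than reproducing compute-wer.py.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (two-row dynamic programming)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(
                prev[j] + 1,             # deletion
                cur[j - 1] + 1,          # insertion
                prev[j - 1] + (r != h),  # substitution, or match at no cost
            ))
        prev = cur
    return prev[-1]


def cer(ref, hyp):
    """Character error rate: edit distance divided by the reference length."""
    ref_chars = [c for c in ref if not c.isspace()]
    hyp_chars = [c for c in hyp if not c.isspace()]
    if not ref_chars:
        raise ValueError("reference is empty")
    return edit_distance(ref_chars, hyp_chars) / len(ref_chars)
```

Because insertions count as errors, the rate can exceed 1.0 when the hypothesis is much longer than the reference.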

Using a smaller dataset

If you only want to run through the whole pipeline quickly, you can shrink the dataset:

head -n 100 data.list > data.list.100
head -n 100 text > text.100
head -n 100 wav.scp >  wav.scp.100
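
Taking the first 100 lines of each file only yields a consistent subset if all three files list the utterances in the same order. As a hedged alternative, the subset can be chosen by key, as in the sketch below; take_subset is a hypothetical helper, and it assumes data.list lines are JSON objects with a key field as generated in step 2.

```python
import json


def take_subset(wav_scp_lines, text_lines, data_list_lines, n):
    """Keep the first n utterances of wav.scp and the matching lines elsewhere."""
    wav_subset = [line for line in wav_scp_lines if line.strip()][:n]
    keys = {line.split(maxsplit=1)[0] for line in wav_subset}
    text_subset = [line for line in text_lines
                   if line.strip() and line.split(maxsplit=1)[0] in keys]
    list_subset = [line for line in data_list_lines
                   if line.strip() and json.loads(line)["key"] in keys]
    return wav_subset, text_subset, list_subset
```

The three returned lists can then be written to wav.scp.100, text.100 and data.list.100.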