N-gram算法的pytorch代码实现

代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

def tri_gramizer(test_sentence):
    # 将单词序列转化为数据元组列表,
    # 其中的每个元组格式为([ word_i-2, word_i-1 ], target word)
    trigrams = [ ([test_sentence[i], test_sentence[i+1]], test_sentence[i+2]) for i in range(len(test_sentence) - 2) ]

    # 给14行诗建立单词表
    # set 即去除重复的词
    vocab = set(test_sentence)
    # 建立词典,它比单词表多了每个词的索引
    word_to_ix = { word: i for i, word in enumerate(vocab) }
    
    print('The vocab length:', len(vocab))
    
    return trigrams, vocab, word_to_ix

class NGramLanguageModeler(nn.Module):
    # 初始化时需要指定:单词表大小、想要嵌入的维度大小、上下文的长度
    def __init__(self, vocab_size, embedding_dim, context_size):
        # 继承自nn.Module,例行执行父类super 初始化方法
        super(NGramLanguageModeler, self).__init__()
        # 建立词嵌入模块
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # 线性层1
        self.linear1 = nn.Linear(context_size * embedding_dim, 128)
        # 线性层2,隐藏层 hidden_size 为128
        self.linear2 = nn.Linear(128, vocab_size)

    # 重写的网络正向传播方法
    # 只要正确定义了正向传播
    # PyTorch 可以自动进行反向传播
    def forward(self, inputs):
        # 将输入进行“嵌入”,并转化为“行向量”
        embeds = self.embeddings(inputs).view((1, -1))
        # 嵌入后的数据通过线性层1后,进行非线性函数 ReLU 的运算
        out = F.relu(self.linear1(embeds))
        # 通过线性层2后
        out = self.linear2(out)
        # 通过 log_softmax 方法将结果映射为概率的log
        # log 概率是用于下面计算负对数似然损失函数时方便使用的
        return out

def train(trigrams, vocab, word_to_ix):
    print('Training...')
    
    # 上下文大小
    # 即 前两个词
    CONTEXT_SIZE = 2
    # 嵌入维度
    EMBEDDING_DIM = 10

    # 计算损失
    losses = []
    # 损失函数为 交叉熵损失函数(Cross Entropy Loss)
    loss_function = nn.CrossEntropyLoss()  # 将NLLLoss替换为CrossEntropyLoss
    # 实例化我们的模型,传入:
    # 单词表的大小、嵌入维度、上下文长度
    model = NGramLanguageModeler(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)
    # 优化函数使用随机梯度下降算法,学习率设置为0.001
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    for epoch in range(1000):
        print(f'epoch: {epoch}')
        total_loss = 0
        # 循环context上下文,比如:['When', 'forty']
        # target,比如:winters
        for context, target in trigrams:

            # 步骤1:准备数据
            # 将context如“['When', 'forty']”
            # 转化为索引,如[68, 15]
            # 不再需要建立为 PyTorch Variable 变量,张量默认支持自动求导
            context_idxs = torch.LongTensor(list(map(lambda w: word_to_ix[w], context)))

            # 步骤2:清空梯度值,防止上次的梯度累计
            model.zero_grad()

            # 步骤3:运行网络的正向传播,获得 log 概率
            out = model(context_idxs)

            # 步骤4:计算损失函数
            # 不再需要传入 autograd.Variable
            loss = loss_function(out, torch.LongTensor([word_to_ix[target]]))

            # 步骤5:进行反向传播并更新梯度
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        losses.append(total_loss)

    print('Finished')    
    # 保存模型的状态字典和相关信息
    torch.save(model.state_dict(), 'model_state_dict.pth')
    return model, losses

def plot_losses(losses):
    plt.figure()
    plt.plot(losses)


def predict(input_data, model):
    first_word, second_word = input_data
    if first_word not in vocab or second_word not in vocab:
        print('Unknown word')
        return '-1'
    input_tensor = torch.LongTensor([word_to_ix[first_word], word_to_ix[second_word]])
    predict_idx = torch.argmax(model(input_tensor)).item()
    predict_word = list(vocab)[predict_idx]
    print('input words:', first_word, second_word)
    print('predicted word:', predict_word)
    return predict_word

if __name__ == '__main__':
    # 数据我们使用的是莎士比亚的14行诗
    test_sentence = """When forty winters shall besiege thy brow,
    And dig deep trenches in thy beauty's field,
    Thy youth's proud livery so gazed on now,
    Will be a totter'd weed of small worth held:
    Then being asked, where all thy beauty lies,
    Where all the treasure of thy lusty days;
    To say, within thine own deep sunken eyes,
    Were an all-eating shame, and thriftless praise.
    How much more praise deserv'd thy beauty's use,
    If thou couldst answer 'This fair child of mine
    Shall sum my count, and make my old excuse,'
    Proving his beauty by succession thine!
    This were to be new made when thou art old,
    And see thy blood warm when thou feel'st it cold.""".split()    # 按空格切分 


    trigrams, vocab, word_to_ix = tri_gramizer(test_sentence)

    # model, losses = train(trigrams, vocab, word_to_ix)
    # plot_losses(losses)
    
    # 上下文大小
    # 即 前两个词
    CONTEXT_SIZE = 2
    # 嵌入维度
    EMBEDDING_DIM = 10    
    model = NGramLanguageModeler(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)
    model.load_state_dict(torch.load('model_state_dict.pth'))
    
    input_data = ['When', 'forty']
    word = predict(input_data, model)

    

参考文章:深度学习新手必学:使用 Pytorch 搭建一个 N-Gram 模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/776079.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LabVIEW环境下OCR文字识别的实现策略与挑战解析

引言 在自动化测试领域,OCR(Optical Character Recognition,光学字符识别)技术扮演着重要角色,它能够将图像中的文字转换成机器可编辑的格式。对于使用LabVIEW约5个月,主要进行仪器控制与数据采集的你而言…

什么是T0策略?有没有可以持仓自动做T的策略软件?

​​行情低迷,持仓被套,不想被动等待?长期持股,想要增厚持仓收益?有没有可以自动做T的工具或者策略?日内T0交易,做到降低持仓成本,优化收益预期。 什么是T0策略? 可以提…

知识图谱和 LLM:多跳问答

检索增强生成(RAG)应用程序通过将外部来源的数据集成到 LLM 中,擅长回答简单的问题。但他们很难回答涉及将相关信息之间的点连接起来的多部分问题。这是因为 RAG 应用程序需要一个数据库,该数据库旨在存储数据,以便轻松…

c++ 里如何检测内存泄露:比如用了 new ,但没有用 delete

(1 方法一) 用 MFC 框架的 F5 不带断点的调试。可以在输出窗口提示是否有内存泄露。 (2 方法二) ,在 main 函数中添加如下代码,用 F5 不带断点的调试: int main() {_CrtSetDbgFlag( _CRTDBG_A…

JAVA 集合+对象复制工具类

JAVA 集合对象复制工具类 import jakarta.annotation.Nullable;import java.util.ArrayList; import java.util.List; import java.util.function.BiFunction; import java.util.function.Consumer;public class BeanUtil extends cn.hutool.core.bean.BeanUtil {/*** 数据拷贝…

Linux高并发服务器开发(十三)Web服务器开发

文章目录 1 使用的知识点2 http请求get 和 post的区别 3 整体功能介绍4 基于epoll的web服务器开发流程5 服务器代码6 libevent版本的本地web服务器 1 使用的知识点 2 http请求 get 和 post的区别 http协议请求报文格式: 1 请求行 GET /test.txt HTTP/1.1 2 请求行 健值对 3 空…

SQL索引事务

SQL索引事务 索引 创建主键约束(primary key),唯一约束(unique),外键约束(foreign key)时,会自动创建对应列的索引 1.1 查看索引 show index from 表名 现在这个表中没有索引,那么我们现在将这几个表删除之后创建新表 我们现在建立一个班级表一个学生表,并且学生表与班级表存…

EVM-MLIR:以MLIR编写的EVM

1. 引言 EVM_MLIR: 以MLIR编写的EVM。 开源代码实现见: https://github.com/lambdaclass/evm_mlir(Rust) 为使用MLIR和LLVM,将EVM-bytecode,转换为,machine-bytecode。LambdaClass团队在2周…

无人机水运应用场景

航行运输 通航管理(海事通航管理处) 配员核查流程 海事员通过VHF(甚高频)系统与船长沟通核查时间。 无人机根据AIS(船舶自动识别系统)报告的船舶位置,利用打点定位 功能飞抵船舶上方。 使用…

大型能源电力集团需要什么样的总部数据下发系统?

能源电力集团的组织结构是一个复杂的系统,包括多个职能部门和子分公司。这些子分公司负责具体的电力生产、销售、运维等业务。这些部门和公司协同工作,确保电力生产的顺利进行,同时关注公司的长期发展、市场拓展、人力资源管理、财务管理和公…

SCI一区级 | Matlab实现BO-Transformer-LSTM多特征分类预测/故障诊断

SCI一区级 | Matlab实现BO-Transformer-LSTM多特征分类预测/故障诊断 目录 SCI一区级 | Matlab实现BO-Transformer-LSTM多特征分类预测/故障诊断效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.【SCI一区级】Matlab实现BO-Transformer-LSTM特征分类预测/故障诊断&…

winform2

12.TabControl 导航控制条 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Threading.Tasks; using System.Windows.Forms; namespace zhiyou_…

发现CPU占用过高,该如何排查解决?

1.使用top命令 查看cpu占用最多的进程 2.使用 top -H -p pid 发现有两个线程占用比较大 3.将线程id转换为16进制 使用命令 printf 0x%x\n pid 4.使用 jstack pid | grep 线程id(16进制) -A 20 (显示20行) 根据代码显示进行错误排查

2024年7月5日 (周五) 叶子游戏新闻

老板键工具来唤去: 它可以为常用程序自定义快捷键,实现一键唤起、一键隐藏的 Windows 工具,并且支持窗口动态绑定快捷键(无需设置自动实现)。 卸载工具 HiBitUninstaller: Windows上的软件卸载工具 《乐高地平线大冒险》为何不登陆…

娱乐圈惊爆已婚男星刘端端深夜幽会

【娱乐圈惊爆!已婚男星刘端端深夜幽会,竟是《庆余年》二皇子“戏外风云”】在这个信息爆炸的时代,娱乐圈的每一次风吹草动都能瞬间点燃公众的热情。今日,知名娱乐博主刘大锤的一则预告如同投入湖中的巨石,激起了层层涟…

关于下载obsidian SimpRead Sync中报错的问题

参考Kenshin的配置方法,我却在输入简悦的配置文件目录时多次报错。 bug如下: 我发现导出来的配置文件格式如下: 然后根据报错的bug对此文件名进行修改,如下: 解决。

Java数据结构-树的面试题

目录 一.谈谈树的种类 二.红黑树如何实现 三.二叉树的题目 1.求一个二叉树的高度,有两种方法。 2.寻找二叉搜索树当中第K大的值 3、查找与根节点距离K的节点 4.二叉树两个结点的公共最近公共祖先 本专栏全是博主自己收集的面试题,仅可参考&#xf…

暑假前端知识速成【CSS】系列一

坚持就是希望! 什么是CSS? CSS 指的是层叠样式表* (Cascading Style Sheets)CSS 描述了如何在屏幕、纸张或其他媒体上显示 HTML 元素CSS 节省了大量工作。它可以同时控制多张网页的布局外部样式表存储在 CSS 文件中 *:也称级联样式表。 CSS语法 在此例…

微信小程序的智慧物流平台-计算机毕业设计源码49796

目 录 摘要 1 绪论 1.1 研究背景 1.2 研究意义 1.3研究方法 1.4开发技术 1.4.1 微信开发者工具 1.4.2 Node.JS框架 1.4.3 MySQL数据库 1.5论文结构与章节安排 2系统分析 2.1 可行性分析 2.2 系统流程分析 2.2.1 用户登录流程 2.2.2 数据删除流程 2.3 系统功能分…

Windows 上帝模式是什么?开启之后有什么用处?

Windows 上帝模式是什么 什么是上帝模式?Windows 上帝模式(God Mode)是一个隐藏的文件夹,通过启用它,用户可以在一个界面中访问操作系统的所有管理工具和设置选项。这个功能最早出现在 Windows Vista 中,并…