博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch解决鸢尾花分类
阅读量:6543 次
发布时间:2019-06-24

本文共 1958 字,大约阅读时间需要 6 分钟。

半年前用numpy写了个鸢尾花分类200行。。每一步计算都是手写的 

现在用pytorch简单写一遍,pytorch语法解释请看上一篇

1 import pandas as pd 2 import torch.nn as nn 3 import torch 4  5  6 class MyNet(nn.Module): 7     def __init__(self): 8         super(MyNet, self).__init__() 9         self.fc = nn.Sequential(10             nn.Linear(4, 3),11             nn.Sigmoid(),12             nn.Linear(3, 3),13             nn.Sigmoid(),14             nn.Linear(3, 1),15         )16         self.mls = nn.MSELoss()17         self.opt = torch.optim.Adam(params=self.parameters(), lr=0.001)18 19     def get_data(self):20         inputs = []21         labels = []22         with open('flower.csv') as file:23             df = pd.read_csv(file, header=None)24             x = df.iloc[:, 0:4].values25             y = df.iloc[:, 4].values26             for i in range(len(x)):27                 inputs.append(x[i])28             for j in range(len(y)):29                 a = []30                 a.append(y[j])31                 labels.append(a)32 33         return inputs, labels34 35     def forward(self, inputs):36         out = self.fc(inputs)37         return out38 39     def train(self, x, label):40         out = self.forward(x)41         loss = self.mls(out, label)42         self.opt.zero_grad()43         loss.backward()44         self.opt.step()45 46     def test(self, x):47         return self.fc(x)48 49 50 if __name__ == '__main__':51     net = MyNet()52     inputs, labels = net.get_data()53     for i in range(1000):54         for index, input in enumerate(inputs):55             # 这里不加.float()会报错,可能是数据格式的问题吧56             input = torch.from_numpy(input).float()57             label = torch.Tensor(labels[index])58             net.train(input, label)59     # 简单测试一下60     c = torch.Tensor([[5.6, 2.7, 4.2, 1.3]])61     print(net.test(c))

运行结果趋近于0.5  正确,单纯练一下pytorch,就没有分训练集,测试集

1 tensor([[0.5392]], grad_fn=
)

不用手写反向传播和梯度下降 是多么幸福一件事~

转载于:https://www.cnblogs.com/MC-Curry/p/10109138.html

你可能感兴趣的文章
Android存储方式之SQLite的使用
查看>>
洛谷P1287 盒子与球 数学
查看>>
Bootstrap vs Foundation如何选择靠谱前端框架
查看>>
与、或、异或、取反、左移和右移
查看>>
vue常用的指令
查看>>
matlab练习程序(随机游走图像)
查看>>
Linux命令行下运行java.class文件
查看>>
input文本框实现宽度自适应代码实例
查看>>
protocol buffers的编码原理
查看>>
行为型设计模式之命令模式(Command)
查看>>
减少死锁的几个常用方法
查看>>
HDFS 核心原理
查看>>
正确配置jstl的maven依赖,jar包冲突的问题终于解决啦
查看>>
利用KMP算法解决串的模式匹配问题(c++) -- 数据结构
查看>>
登录内网账号后,连接不上内网网址
查看>>
安装 MariaDB
查看>>
【deep learning学习笔记】注释yusugomori的DA代码 --- dA.h
查看>>
纯手工打造漂亮的垂直时间轴,使用最简单的HTML+CSS+JQUERY完成100个版本更新记录的华丽转身!...
查看>>
java 为啥变量名前要加个m?
查看>>
探索Android中的Parcel机制(上)
查看>>