NumPy数组的变形(改变数组形状)
在机器学习以及深度学习的任务中,通常需要将处理好的数据以模型能接收的格式输入给模型,然后由模型通过一系列的运算,最终返回一个处理结果。然而,由于不同模型所接收的输入格式不一样,往往需要先对其进行一系列的变形和运算,从而将数据处理成符合模型要求的格式。
在矩阵或者数组的运算中,经常会遇到需要把多个向量或矩阵按某轴方向合并,或展平(如在卷积或循环神经网络中,在全连接层之前,需要把矩阵展平)的情况。下面介绍几种常用的数组变形方法。
修改指定数组的形状是 NumPy 中最常见的操作之一,常见的方法有很多,下表列出了一些常用函数和属性。
下面来看一些示例。
请看下面的代码:
请看下面的代码:
在矩阵或者数组的运算中,经常会遇到需要把多个向量或矩阵按某轴方向合并,或展平(如在卷积或循环神经网络中,在全连接层之前,需要把矩阵展平)的情况。下面介绍几种常用的数组变形方法。
修改指定数组的形状是 NumPy 中最常见的操作之一,常见的方法有很多,下表列出了一些常用函数和属性。
函数/属性 | 描述 |
---|---|
arr.reshape() | 重新将向量 arr 维度进行改变,不修改向量本身 |
arr.resize() | 重新将向量 arr 维度进行改变,修改向量本身 |
arr.T | 对向量 arr 进行转置 |
arr.ravel() | 对向量 arr 进行展平,即将多维数组变成1维数组,不会产生原数组的副本 |
arr.flatten() | 对向量 arr 进行展平,即将多维数组变成1维数组,返回原数组的副本 |
arr.squeeze() | 只能对维数为1的维度降维。对多维数组使用时不会报错,但是不会产生任何影响 |
arr.transpose() | 对高维矩阵进行轴对换 |
下面来看一些示例。
reshape() 函数
reshape() 函数用来改变向量的维度(不修改向量本身),请看下面的代码:import numpy as np arr =np.arange(10) print(arr) # 将向量 arr 维度变换为2行5列 print(arr.reshape(2, 5)) # 指定维度时可以只指定行数或列数, 其他用 -1 代替 print(arr.reshape(5, -1)) print(arr.reshape(-1, 5))输出结果:
[0 1 2 3 4 5 6 7 8 9] [[0 1 2 3 4] [5 6 7 8 9]] [[0 1] [2 3] [4 5] [6 7] [8 9]] [[0 1 2 3 4] [5 6 7 8 9]]值得注意的是,reshape() 函数不支持指定行数或列数,所以 -1 在这里是必要的。且所指定的行数或列数一定要能被整除,例如上面代码如果修改为 arr.reshape(3,-1) 即为错误的。
resize() 函数
resize() 函数用来改变向量的维度(修改向量本身),请看下面的代码:import numpy as np arr =np.arange(10) print(arr) # 将向量 arr 维度变换为2行5列 arr.resize(2, 5) print(arr)输出结果:
[0 1 2 3 4 5 6 7 8 9] [[0 1 2 3 4] [5 6 7 8 9]]
T 属性
T 属性用来对向量进行转置,请看下面的的代码:import numpy as np arr =np.arange(12).reshape(3,4) # 向量 arr 为3行4列 print(arr) # 将向量 arr 进行转置为4行3列 print(arr.T)输出结果:
[[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11]] [[ 0 4 8] [ 1 5 9] [ 2 6 10] [ 3 7 11]]
ravel() 函数
ravel() 函数用于向量展平,请看下面的代码:import numpy as np arr =np.arange(6).reshape(2, -1) print(arr) # 按照列优先, 展平 print("按照列优先, 展平") print(arr.ravel('F')) # 按照行优先, 展平 print("按照行优先, 展平") print(arr.ravel())输出结果:
[[0 1 2] [3 4 5]] 按照列优先,展平 [0 3 1 4 2 5] 按照行优先,展平 [0 1 2 3 4 5]
flatten() 函数
flatten() 函数用来把矩阵转换为向量,这种需求经常出现在卷积网络与全连接层之间。请看下面的代码:
import numpy as np a =np.floor(10*np.random.random((3,4))) print(a) print(a.flatten())输出结果:
[[4. 0. 8. 5.] [1. 0. 4. 8.] [8. 2. 3. 7.]] [4. 0. 8. 5. 1. 0. 4. 8. 8. 2. 3. 7.]
squeeze() 函数
这是一个主要用来降维的函数,把矩阵中含1的维度去掉,请看下面的代码:import numpy as np arr =np.arange(3).reshape(3, 1) print(arr.shape) #(3,1) print(arr.squeeze().shape) #(3,) arr1 =np.arange(6).reshape(3,1,2,1) print(arr1.shape) #(3, 1, 2, 1) print(arr1.squeeze().shape) #(3, 2)
transpose() 函数
对高维矩阵进行轴对换,这个在深度学习中经常使用,比如把图片中表示颜色顺序的 RGB 改为 GBR。请看下面的代码:
import numpy as np arr2 = np.arange(24).reshape(2,3,4) print(arr2.shape) #(2, 3, 4) print(arr2.transpose(1,2,0).shape) #(3, 4, 2)
所有教程
- C语言入门
- C语言编译器
- C语言项目案例
- 数据结构
- C++
- STL
- C++11
- socket
- GCC
- GDB
- Makefile
- OpenCV
- Qt教程
- Unity 3D
- UE4
- 游戏引擎
- Python
- Python并发编程
- TensorFlow
- Django
- NumPy
- Linux
- Shell
- Java教程
- 设计模式
- Java Swing
- Servlet
- JSP教程
- Struts2
- Maven
- Spring
- Spring MVC
- Spring Boot
- Spring Cloud
- Hibernate
- Mybatis
- MySQL教程
- MySQL函数
- NoSQL
- Redis
- MongoDB
- HBase
- Go语言
- C#
- MATLAB
- JavaScript
- Bootstrap
- HTML
- CSS教程
- PHP
- 汇编语言
- TCP/IP
- vi命令
- Android教程
- 区块链
- Docker
- 大数据
- 云计算