Matplotlib绘图基础

Matplotlib是Python的一个工具包,提供了丰富的数据绘图工具,主要用于绘制一些统计图形。现将网络上多个学习Matplotlib的总结搬运至此,留作备忘。
简单绘图中,使用最多的是 matplotlib.pyplot.plot 函数。 Matplotlib中函数的使用方式,大致和Matlab一致的。
Notes:

  • from pylab import * 会包含 np, mpl, plt等模块,甚至可以直接使用 plot()函数而不是使用plt.plot()。
  • ipython --pylab #使用--pylab参数启动ipython,这个参数允许使用matplotlib交互,直接绘图。或者进入ipython 之后,输入 %pylab
  • import matplotlib 之后输入 matplotlib.use('agg')。此后,使用 plt.plot() 则不会绘图,再使用 savefig() 保存图片。
  • color maps

Getting Started with Matplotlib

import numpy as np 
import matplotlib as mpl  
import matplotlib.pyplot as plt 

x=np.linspace(-np.pi,np.pi,256,endpoint=True)
C,S=np.cos(x),np.sin(x)

使用默认的绘图属性

plt.plot(x,C)   # 默认为 lines 
plt.plot(x,S)
plt.show()

plt.plot(x,C,x,S)   # 同时画多条线
plt.show()

plt.plot(x,C,'.')   # 只画点,不画线
plt.show()

Grid, legend, and axis

plt.grid(True)
plt.plot(x,C)
plt.show()

plt.plot([1,2,3,4],label=r'$cos(t)$')
plt.legend()

plt.legend(loc='best', fontsize = 9, frameon=False) # 字体大小除了可以用数字外,还可以用预设的值 'small', 'x-small', 'xx-small'等值
loc表示legend的位置,frameon指定legend是否有边框,bbox_to_anchor能够指定legend偏移的位置
numpoints指定marker点的个数(for lines),scatterpoints指定marker点的个数(for scatter)

# 调整坐标轴范围 
plt.axis()               # 输出[xmin, xmax, ymin, ymax] 
plt.axis([0, 4, 0, 4])   # 改变取值范围

axis label and figure title

plt.xlabel('XXXXX', fontsize = 10)    
plt.ylabel('XXXXXX', fontsize = 10)   
plt.title('XXXXX', fontsize = 10)

Handling X and Y ticks

ticks是指坐标轴上的标注(刻度)

x = [5, 3, 7, 2, 4, 1] 
plt.plot(x) 
plt.xticks(range(len(x)), ['a', 'b', 'c', 'd', 'e', 'f']) 
plt.yticks(range(1, 8, 2))

坐标轴的范围以及标注

plt.xlim(X.min()*1.1, X.max()*1.1)
plt.ylim(C.min()*1.1, C.max()*1.1)

通过ticks对横轴和纵轴的含义进行设置和定制。

plt.xlim(x.min()*1.1, x.max()*1.1)
plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi], [r'$-\pi$', r'$-\pi/2$', r'$0$', r'$+\pi/2$', r'$+\pi$'])

plt.ylim(C.min()*1.1,C.max()*1.1)
plt.yticks([-1, 0, +1], [r'$-1$', r'$0$', r'$+1$'])

设置边框(spines)和刻度的位置。

边框(spines)是围绕着图像边界的线,分为上下左右四个边框。
通过将它们的颜色设置成None,可以来调整是否显示边框。

fig=plt.figure()
ax = fig.add_subplot(111)

ax.spines['right'].set_color('none')    # 不显示右边框 
ax.spines['top'].set_color('none')      # 不显示上边框 

ax.spines['bottom'].set_position(('data',0))  # 设置位置
ax.spines['left'].set_position(('data',0))    # 设置位置

## 对应x轴 1,2分别表示下和上;y轴 1,2分别表示左和右 
for t in ax.xaxis.get_major_ticks():
    t.tick1On = True
    t.tick2On = False
for t in ax.yaxis.get_major_ticks():
    t.tick1On = True
    t.tick2On = False

ax.xaxis.set_ticks_position('bottom')   # 设定x轴在底部
ax.yaxis.set_ticks_position('left')     # 设定y轴在左边 


plt.plot(x,C,color='red')

plt.show()

Saving plots to a file

plt.savefig('plot123.png')
plt.show()

import matplotlib as mpl 
mpl.rcParams['figure.figsize']    # [8.0, 6.0] 
mpl.rcParams['savefig.dpi']       # 100 

plt.savefig('plot123_2.png', dpi=200) # 设置 dpi=200,分辨率变为1600×1200 
默认值: an 8x6 inches figure with 100 DPI results in an 800x600 pixels image

Decorate Graphs with Plot Styles

Markers and line styles

Marker就是指形成线的那些点。
plot通过第三个string参数可以用来指定Colors,Line styles,Marker styles
可以用string format单独或混合的表示所有的style,

y = np.arange(1, 3, 0.3) 
plt.plot(y, 'cx--', y+1, 'mo:', y+2, 'kp-.');
plt.show()

一般用string format已经足够,但也可以用具体的keyword参数进行更多的个性化:

Keyword argument  Description 

color or c       Sets the color of the line;accepts any matplotlib color.
linestyle        Sets the line style;accepts the line styles seen previously.
linewidth        Sets the line width;accepts a float value in points. 
marker           Sets the line marker style.
markeredgecolor  Sets the marker edge color;accepts any matplotlib color.
markerdegewidth  Sets the marker edge width;accpets float value in points.
markerfacecolor  Sets the marker face color;accpets any matplotlib color.
markersize       Sets the marker size in points;accepts float values.
标注和线条的颜色

character    color

'b'            blue
'g'            green
'r'            red
'c'            cyan
'm'            magenta
'y'            yellow
'k'            black
'w'            white

线的style

character    description

'-'           solid line style
'--'       dashed line style
'-.'       dash-dot line style
':'           dotted line style

Marker的style 

'.'           point marker
','           pixel marker
'o'           circle marker
'v'           triangle_down marker
'^'           triangle_up marker
'<'           triangle_left marker
'>'           triangle_right marker
'1'           tri_down marker
'2'           tri_up marker
'3'           tri_left marker
'4'           tri_right marker
's'           square marker
'p'           pentagon marker
'*'           star marker
'h'           hexagon1 marker
'H'           hexagon2 marker
'+'           plus marker
'x'           x marker
'D'           diamond marker
'd'           thin_diamond marker
'|'           vline marker
'_'           hline marker

Text inside figure, annotations, and arrows

增加text

plt.text(x, y, text)

x = np.arange(0, 2*np.pi, .01) 
y = np.sin(x) 
plt.plot(x, y); 
plt.text(0.1, -0.04, 'sin(0)=0');

annotate,标注一些感兴趣的点

参数
- xy,需要添加注释的坐标 
- xytext,注释本身的坐标 
- arrowprops,箭头的类型和属性

y = [13, 11, 13, 12, 13, 10, 30, 12, 11, 13, 12, 12, 12, 11,12] 
plt.plot(y); 
plt.ylim(ymax=35);         # 增大y的空间,否则注释放不下 
plt.annotate('this spot must really\nmean something', xy=(6, 30), xytext=(8, 31.5), arrowprops=dict(facecolor='black', shrink=0.05));

现在使用annotate命令注解一些我们感兴趣的点。我们选择2π/3作为我们想要注解的正弦和余弦值。我们将在曲线上做一个标记和一个垂直的虚线。然后,使用annotate命令来显示一个箭头和一些文本

t = 2*np.pi/3
plt.plot([t,t],[0,np.cos(t)], color ='blue', linewidth=2.5, linestyle="--")
plt.scatter([t,],[np.cos(t),], 50, color ='blue')

plt.annotate(r'$sin(\frac{2\pi}{3})=\frac{\sqrt{3}}{2}$',
         xy=(t, np.sin(t)), xycoords='data',
         xytext=(+10, +30), textcoords='offset points', fontsize=16,
         arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))

plt.plot([t,t],[0,np.sin(t)], color ='red', linewidth=2.5, linestyle="--")
plt.scatter([t,],[np.sin(t),], 50, color ='red')

plt.annotate(r'$cos(\frac{2\pi}{3})=-\frac{1}{2}$',
         xy=(t, np.cos(t)), xycoords='data',
         xytext=(-90, -50), textcoords='offset points', fontsize=16,
         arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))


箭头种类

plt.axis([0, 10, 0, 20]); 
arrstyles = ['-', '->', '-[', '<-', '<->', 'fancy', 'simple','wedge'] 
for i, style in enumerate(arrstyles): 
      plt.annotate(style, xytext=(1, 2+2*i), xy=(4, 1+2*i), arrowprops=dict(arrowstyle=style));

connstyles=["arc", "arc,angleA=10,armA=30,rad=15", "arc3,rad=.2", "arc3,rad=-.2", "angle", "angle3"] 
for i, style in enumerate(connstyles): 
      plt.annotate("", xytext=(6, 2+2*i), xy=(8, 1+2*i), arrowprops=dict(arrowstyle='->', connectionstyle=style));

Subplots

matplotlib中,默认会帮我们创建figure和subplot

fig = plt.figure() 
ax = fig.add_subplot(111)

其实我们可以显式的创建,这样的好处是我们可以在一个figure中画多个subplot
其中subplot的参数,

fig.add_subplot(numrows, numcols, fignum) 
  - numrows 代表subplot的总行数
  - numcols 代表subplot的总列数
  - fignum 代表此subplot的编号,范围从1到 numrows*numcols
我们会产生numrows×numcols个subplot,fignum表示编号  

fig = plt.figure() 
ax1 = fig.add_subplot(211) 
ax1.plot([1, 2, 3], [1, 2, 3])  
ax2 = fig.add_subplot(212) 
ax2.plot([1, 2, 3], [3, 2, 1]) 

plt.show() 


fig = plt.figure()
fig.subplots_adjust(bottom=0.025, left=0.025, top = 0.975, right=0.975)

plt.subplot(2,1,1)
plt.xticks([]), plt.yticks([])

plt.subplot(2,3,4)
plt.xticks([]), plt.yticks([])

plt.subplot(2,3,5)
plt.xticks([]), plt.yticks([])

plt.subplot(2,3,6)
plt.xticks([]), plt.yticks([])

plt.show()

gridspec命令能够创建布局更为复杂的subplot。
axes和subplot非常相似,但是允许把图片放置到图像(figure)中的任何地方。
所以如果我们想要在一个大图片中嵌套一个小点的图片,我们通过axes来完成。

Plotting dates

日期比较长,直接画在坐标轴上,没法看
产生x轴数据,利用mpl.dates.drange产生x轴坐标

import datetime as dt 

date2_1 = dt.datetime(2008, 9, 23) 
date2_2 = dt.datetime(2008, 10, 3) 
delta2 = dt.timedelta(days=1) 
dates2 = mpl.dates.drange(date2_1, date2_2, delta2)

随机产生y轴坐标,画出polt图

y2 = np.random.rand(len(dates2)) 
fig = plt.figure() 
ax2 = fig.add_subplot(111) 
ax2.plot_date(dates2, y2, linestyle='-') 

关键步骤来了,我们要设置xaxis的locator和formatter来显示时间 
首先设置formatter,

dateFmt = mpl.dates.DateFormatter('%Y-%m-%d') 
ax2.xaxis.set_major_formatter(dateFmt)

再设置locator,

daysLoc = mpl.dates.DayLocator() 
hoursLoc = mpl.dates.HourLocator(interval=6) 
ax2.xaxis.set_major_locator(daysLoc) 
ax2.xaxis.set_minor_locator(hoursLoc)

注意这里major和minor,major就是大的tick,minor是比较小的tick(默认是null) 
比如date是大的tick,但是想看的细点,所以再设个hour的tick,但是画24个太多了,所以interval=6,只画4个 
而formatter只是设置major的,所以minor的是没有label的


再看个例子
产生x轴坐标,y轴坐标,画出plot

date1_1 = dt.datetime(2008, 9, 23) 
date1_2 = dt.datetime(2009, 2, 16) 
delta1 = dt.timedelta(days=10) 
dates1 = mpl.dates.drange(date1_1, date1_2, delta1) 
y1 = np.random.rand(len(dates1)) 
fig = plt.figure() 
ax1 = fig.add_subplot(111) 
ax1.plot_date(dates1, y1, linestyle='-') 

设置locator 
major的是Month,minor的是week

monthsLoc = mpl.dates.MonthLocator() 
weeksLoc = mpl.dates.WeekdayLocator() 
ax1.xaxis.set_major_locator(monthsLoc) 
ax1.xaxis.set_minor_locator(weeksLoc)

设置Formatter
monthsFmt = mpl.dates.DateFormatter('%b') 
ax1.xaxis.set_major_formatter(monthsFmt)

Using LaTeX formatting

支持显示LaTex的公式,需要用两个$包裹起来。
python raw string需要r””,表示不转义。

plt.plot([5, 3, 7, 2, 4, 1]) 

plt.text(2, 4, r'$-\pi$')
plt.show()

绘制水平线、竖直线

plt.axvline(x,ymin,ymax,**kwargs)   # 添加竖直线,ymin,ymax分别是起始和终止的比例值 
plt.axhline(y,xmin,xmax,**kwargs)   # 添加水平线,xmin,xmax分别是起始和终止的比例值

Plot types

上面介绍了很多,都是以plot作为例子,matplotlib还提供了很多其他类型的图
plot_sort.jpg
上面这张图很赞,描述所有图的用法

Bar charts

对于bar,需要设定3个参数
左起始坐标,高度,宽度(可选,默认0.8)

plt.bar([1, 2, 3], [3, 2, 5])    # 指定起始点和高度参数

看个复杂的例子,bar图一般用于比较多个数据值

data1 = 10*np.random.rand(5) 
data2 = 10*np.random.rand(5) 
data3 = 10*np.random.rand(5) 
e2 = 0.5 * np.abs(np.random.randn(len(data2))) 
locs = np.arange(1, len(data1)+1) 
width = 0.27 
plt.bar(locs, data1, width=width); 
plt.bar(locs+width, data2, yerr=e2, width=width, color='red') 
plt.bar(locs+2*width, data3, width=width, color='green') 
plt.xticks(locs + width*1.5, locs) 

plt.show()

需要学习的是,如何指定多个bar的起始位置,后一个bar的loc = 前一个bar的loc + width 
如何设置ticks的label,让它在一组bars的中间位置,locs + width*1.5

Scatter plots

只画点,不连线,用来描述两个变量之间的关系,比如在进行数据拟合之前,看看变量间是线性还是非线性

x = np.random.randn(1000) 
y = np.random.randn(1000) 
plt.scatter(x, y);

plt.show() 

通过s来指定size,c来指定color, marker来指定点的形状

size = 50*np.random.randn(1000) 
colors = np.random.rand(1000) 
plt.scatter(x, y, s=size, c=colors)

Pie charts

plt.figure(figsize=(3,3)); 
x = [45, 35, 20] 
labels = ['Cats', 'Dogs', 'Fishes'] 
plt.pie(x, labels=labels) 

来个复杂的, 
增加explode,即突出某些wedges,可以设置explode来增加offset the wedge from the center of the pie, 即radius fraction 
0表示不分离,越大表示离pie center越远,需要显式指定每个wedges的explode

增加autopct,即在wedges上显示出具体的比例

plt.figure(figsize=(3,3)); 
x = [4, 9, 21, 55, 30, 18] 
labels = ['Swiss', 'Austria', 'Spain', 'Italy', 'France', 'Benelux'] 
explode = [0.2, 0.1, 0, 0, 0.1, 0] 
plt.pie(x, labels=labels, explode=explode, autopct='%1.1f%%');

plt.show()

Histogram charts

直方图是用来离散的统计数据分布的,会把整个数据集,根据取值范围,分成若干类,称为bins,然后统计中每个bin中的数据个数
hist默认是分为10类,即bins=10

y = np.random.randn(1000) 
plt.hist(y); 
plt.show()

plt.hist(y, 25)   # 设定为25 个bins

Error bar charts

x = np.arange(0, 4, 0.2) 
y = np.exp(-x) 
e1 = 0.1 * np.abs(np.random.randn(len(y))) 
e2 = 0.1 * np.abs(np.random.randn(len(y))) 
plt.errorbar(x, y, yerr=e1, xerr=e2, fmt='.-', capsize=0) 

plt.errorbar(x, y, yerr=[e1, e2], fmt='.-')     # 非对称的误差

等高线图(contour plots)

需要使用contour, contourf, clabel命令。
其中,contour画等高线,contourf进行填充(f,filled),clabel进行标注

输入X,Y,Z均为矩阵(m*n),等高线图反映(x,y)处z值的大小。X Y Z 后面紧跟的这个数字代表等高线的数目,如果后面跟着 [z1,z2,z3]则对应画出这些Z值的等高线(levels)。
extend 的参数为 neither, both, min, max,表示等高线最内侧以内,或最外侧以外的部分是否填充。
cmap = mpl.cm.gray # 指定颜色的系列
alpha 值表示颜色透明度
linewidths linestyles 表示等高线的style

填充方案(不包含下面的边界) z1 < z <= z2

def f(x,y):
    return (1-x/2+x**5+y**3)*np.exp(-x**2-y**2)

n = 256
x = np.linspace(-3,3,n)
y = np.linspace(-3,3,n)
X,Y = np.meshgrid(x,y)


C = plt.contour(X, Y, f(X,Y), 8, colors='black', linewidth=.5)
# C = plt.contour(X, Y, f(X,Y), levels=[-0.4,0,0.4,0.8], colors='black', linewidth=.5)
CS=plt.contourf(X, Y, f(X,Y), 8, alpha=.75, cmap=plt.cm.hot)
plt.clabel(C, inline=1, fontsize=10)

plt.xticks([]), plt.yticks([])
# savefig('../figures/contour_ex.png',dpi=48)
plt.show()

from matplotlib.ticker import MaxNLocator
levels = MaxNLocator(nbins=15).tick_values(f(X,Y).min(), f(X,Y).max())  # 会自动设定不多于15个bin构成的levels 
C = plt.contour(X, Y, f(X,Y), levels=levels, colors='black', linewidth=.5)

添加 colorbar

cb = plt.colorbar(CS,orientation='horizontal',shrink=0.9)  # 默认 colorbar 是竖直放置,可将其设置为水平放置。shrink指定colorbar的相对长度
cb.set_label('meters')      # 设置 label 

## 对于超出colorbar边界的处理 
extend    [ ‘neither’ | ‘both’ | ‘min’ | ‘max’ ] 
If not ‘neither’, make pointed end(s) for out-of- range values. 
These are set for a given colormap using the colormap set_under and set_over methods.


## 调整colorbar  
ax = fig.add_subplot(224)  
cmap = mpl.cm.winter  
norm = mpl.colors.Normalize(vmin=-1, vmax=1)  
im=ax.imshow(data,cmap=cmap)  
plt.colorbar(im,cmap=cmap, norm=norm,ticks=[-1,0,1])

Imshow

输入为矩阵Z(mn),Imshow反映Z的元素的大小。
图中x,y轴的刻度为输入Z矩阵的index,颜色反映Z矩阵的元素的值。比如,Z为3
3维,则图中x,y对应的刻度为0 1 2,需要进一步调整为合适的范围。
插值方式需要注意
origin : [‘upper’ | ‘lower’], optional, default: None. Place the [0,0] index of the array in the upper left or lower left corner of the axes. 默认放在左上角
如果Z中有元素为None,则imshow将不会对此元素进行绘制。

def f(x,y):
    return (1-x/2+x**5+y**3)*np.exp(-x**2-y**2)

n = 200
x = np.linspace(-3,3,3.5*n)
y = np.linspace(-3,3,3.0*n)
X,Y = np.meshgrid(x,y)
Z = f(X,Y)

plt.imshow(Z,interpolation='nearest', cmap='bone', origin='lower')
plt.colorbar(shrink=.92)

plt.xticks([]), plt.yticks([])
# savefig('../figures/imshow_ex.png', dpi=48)
plt.show()

##
A1 = np.zeros((3,3))
A2 = np.zeros((3,3))
A1[1,1] = 1
A2[2,2] = 1

# Apply a mask to filter out unused values
A1[A1==0] = None
A2[A2==0] = None

# Use different colormaps for each layer
plt.imshow(A1,cmap=plt.cm.jet,interpolation='nearest')
plt.imshow(A2,cmap=plt.cm.hsv,interpolation='nearest')
plt.show()

pcolor 绘制2D map

输入为矩阵Z(mn),pcolor反映Z的元素的大小。
pcolormesh是pcolor的升级版。二者不仅能输入Z,还能输入X,Y,而imshow却不能输入X,Y。
pcolor与imshow的差异: Z为 `m
n维时,如果使用imshow只需输入Z即可;如果使用pcolor,可以选择只输入Z;也可以选择输入X,Y,Z,此时X,Y应当为(m+1)*(n+1)`维

fig=plt.figure()
ax=fig.add_subplot('111')

np.random.seed(0)
Z = np.random.randint(4, size=(4, 4))

a=arange(-2.0,2.001,1.0)  
b=arange(-2.0,2.001,1.0)  
x,y=meshgrid(a,b) 

#heatmap = plt.imshow(Z, cmap=matplotlib.cm.Blues,interpolation='None')

heatmap = plt.pcolormesh(Z, cmap=matplotlib.cm.Blues)

plt.show()

fill_between 命令填充图形

plt.fill(x,y,color='r')
plt.fill_between(x,y1,y2,color='g')
plt.fill_between(x,y1,y2,y2>0,color='g')

n = 256
X = np.linspace(-np.pi,np.pi,n,endpoint=True)
Y = np.sin(2*X)

plt.axes([0.025,0.025,0.95,0.95])

plt.plot (X, Y+1, color='blue', alpha=1.00)
plt.fill_between(X, 1, Y+1, color='blue', alpha=.25)

plt.plot (X, Y-1, color='blue', alpha=1.00)
plt.fill_between(X, -1, Y-1, (Y-1) > -1, color='blue', alpha=.25)
plt.fill_between(X, -1, Y-1, (Y-1) < -1, color='red',  alpha=.25)

plt.xlim(-np.pi,np.pi), plt.xticks([])
plt.ylim(-2.5,2.5), plt.yticks([])
# savefig('../figures/plot_ex.png',dpi=48)
plt.show()

三维绘图

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
ax = Axes3D(fig)
X = np.arange(-4, 4, 0.25)
Y = np.arange(-4, 4, 0.25)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)

ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=plt.cm.hot)
ax.contourf(X, Y, Z, zdir='z', offset=-2, cmap=plt.cm.hot)
ax.set_zlim(-2,2)

# savefig('../figures/plot3d_ex.png',dpi=48)
plt.show()

其它图表类型

实时绘图

需要导入 matplotlib.animation

Matplotlib 实例

2D图表

Matplotlib中最基础的模块是pyplot。先从最简单的点图和线图开始,比如我们有一组数据,还有一个拟合模型,通过下面的代码图来可视化:

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# 通过rcParams设置全局横纵轴字体大小
mpl.rcParams['xtick.labelsize'] = 24
mpl.rcParams['ytick.labelsize'] = 24

np.random.seed(42)

# x轴的采样点
x = np.linspace(0, 5, 100)

# 通过下面曲线加上噪声生成数据,所以拟合模型就用y了……
y = 2*np.sin(x) + 0.3*x**2
y_data = y + np.random.normal(scale=0.3, size=100)

# figure()指定图表名称
plt.figure('data')

# '.'表明画散点图,每个散点的形状是个圆
plt.plot(x, y_data, '.')

# 画模型的图,plot函数默认画连线图
plt.figure('model')
plt.plot(x, y)

# 两个图画一起
plt.figure('data & model')

# 通过'k'指定线的颜色,lw指定线的宽度
# 第三个参数除了颜色也可以指定线形,比如'r--'表示红色虚线
# 更多属性可以参考官网:http://matplotlib.org/api/pyplot_api.html
plt.plot(x, y, 'k', lw=3)

# scatter可以更容易地生成散点图
plt.scatter(x, y_data)

# 将当前figure的图保存到文件result.png
plt.savefig('result.png')

# 一定要加上这句才能让画好的图显示在屏幕上
plt.show()

点和线图表只是最基本的用法,有的时候我们获取了分组数据要做对比,柱状或饼状类型的图或许更合适:

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rcParams['axes.titlesize'] = 20
mpl.rcParams['xtick.labelsize'] = 16
mpl.rcParams['ytick.labelsize'] = 16
mpl.rcParams['axes.labelsize'] = 16
mpl.rcParams['xtick.major.size'] = 0
mpl.rcParams['ytick.major.size'] = 0

# 包含了狗,猫和猎豹的最高奔跑速度,还有对应的可视化颜色
speed_map = {
    'dog': (48, '#7199cf'),
    'cat': (45, '#4fc4aa'),
    'cheetah': (120, '#e1a7a2')
}

# 整体图的标题
fig = plt.figure('Bar chart & Pie chart')

# 在整张图上加入一个子图,121的意思是在一个1行2列的子图中的第一张
ax = fig.add_subplot(121)
ax.set_title('Running speed - bar chart')

# 生成x轴每个元素的位置
xticks = np.arange(3)

# 定义柱状图每个柱的宽度
bar_width = 0.5

# 动物名称
animals = speed_map.keys()

# 奔跑速度
speeds = [x[0] for x in speed_map.values()]

# 对应颜色
colors = [x[1] for x in speed_map.values()]

# 画柱状图,横轴是动物标签的位置,纵轴是速度,定义柱的宽度,同时设置柱的边缘为透明
bars = ax.bar(xticks, speeds, width=bar_width, edgecolor='none')

# 设置y轴的标题
ax.set_ylabel('Speed(km/h)')

# x轴每个标签的具体位置,设置为每个柱的中央
ax.set_xticks(xticks+bar_width/2)

# 设置每个标签的名字
ax.set_xticklabels(animals)

# 设置x轴的范围
ax.set_xlim([bar_width/2-0.5, 3-bar_width/2])

# 设置y轴的范围
ax.set_ylim([0, 125])

# 给每个bar分配指定的颜色
for bar, color in zip(bars, colors):
    bar.set_color(color)

# 在122位置加入新的图
ax = fig.add_subplot(122)
ax.set_title('Running speed - pie chart')

# 生成同时包含名称和速度的标签
labels = ['{}\n{} km/h'.format(animal, speed) for animal, speed in zip(animals, speeds)]

# 画饼状图,并指定标签和对应颜色
ax.pie(speeds, labels=labels, colors=colors)

plt.show()

3D图表

Matplotlib中也能支持一些基础的3D图表,比如曲面图,散点图和柱状图。这些3D图表需要使用mpl_toolkits模块,先来看一个简单的曲面图的例子:

import matplotlib.pyplot as plt
import numpy as np

# 3D图标必须的模块,project='3d'的定义
from mpl_toolkits.mplot3d import Axes3D     

np.random.seed(42)

n_grids = 51            # x-y平面的格点数 
c = n_grids / 2         # 中心位置
nf = 2                  # 低频成分的个数

# 生成格点
x = np.linspace(0, 1, n_grids)
y = np.linspace(0, 1, n_grids)

# x和y是长度为n_grids的array
# meshgrid会把x和y组合成n_grids*n_grids的array,X和Y对应位置就是所有格点的坐标
X, Y = np.meshgrid(x, y)

# 生成一个0值的傅里叶谱
spectrum = np.zeros((n_grids, n_grids), dtype=np.complex)

# 生成一段噪音,长度是(2*nf+1)**2/2
noise = [np.complex(x, y) for x, y in np.random.uniform(-1,1,((2*nf+1)**2/2, 2))]

# 傅里叶频谱的每一项和其共轭关于中心对称
noisy_block = np.concatenate((noise, [0j], np.conjugate(noise[::-1])))

# 将生成的频谱作为低频成分
spectrum[c-nf:c+nf+1, c-nf:c+nf+1] = noisy_block.reshape((2*nf+1, 2*nf+1))

# 进行反傅里叶变换
Z = np.real(np.fft.ifft2(np.fft.ifftshift(spectrum)))

# 创建图表
fig = plt.figure('3D surface & wire')

# 第一个子图,surface图
ax = fig.add_subplot(1, 2, 1, projection='3d')

# alpha定义透明度,cmap是color map
# rstride和cstride是两个方向上的采样,越小越精细,lw是线宽
ax.plot_surface(X, Y, Z, alpha=0.7, cmap='jet', rstride=1, cstride=1, lw=0)

# 第二个子图,网线图
ax = fig.add_subplot(1, 2, 2, projection='3d')
ax.plot_wireframe(X, Y, Z, rstride=3, cstride=3, lw=0.5)

plt.show()

这个例子中先生成一个所有值均为0的复数array作为初始频谱,然后把频谱中央部分用随机生成,但同时共轭关于中心对称的子矩阵进行填充。这相当于只有低频成分的一个随机频谱。最后进行反傅里叶变换就得到一个随机波动的曲面

3D的散点图也是常常用来查看空间样本分布的一种手段,并且画起来比表面图和网线图更加简单,来看例子:

import matplotlib.pyplot as plt
import numpy as np

from mpl_toolkits.mplot3d import Axes3D

np.random.seed(42)

# 采样个数500
n_samples = 500
dim = 3

# 先生成一组3维正态分布数据,数据方向完全随机
samples = np.random.multivariate_normal(
    np.zeros(dim),
    np.eye(dim),
    n_samples
)

# 通过把每个样本到原点距离和均匀分布吻合得到球体内均匀分布的样本
for i in range(samples.shape[0]):
    r = np.power(np.random.random(), 1.0/3.0)
    samples[i] *= r / np.linalg.norm(samples[i])

upper_samples = []
lower_samples = []

for x, y, z in samples:
    # 3x+2y-z=1作为判别平面
    if z > 3*x + 2*y - 1:
        upper_samples.append((x, y, z))
    else:
        lower_samples.append((x, y, z))

fig = plt.figure('3D scatter plot')
ax = fig.add_subplot(111, projection='3d')

uppers = np.array(upper_samples)
lowers = np.array(lower_samples)

# 用不同颜色不同形状的图标表示平面上下的样本
# 判别平面上半部分为红色圆点,下半部分为绿色三角
ax.scatter(uppers[:, 0], uppers[:, 1], uppers[:, 2], c='r', marker='o')
ax.scatter(lowers[:, 0], lowers[:, 1], lowers[:, 2], c='g', marker='^')

plt.show()

这个例子中,为了方便,直接先采样了一堆3维的正态分布样本,保证方向上的均匀性。然后归一化,让每个样本到原点的距离为1,相当于得到了一个均匀分布在球面上的样本。再接着把每个样本都乘上一个均匀分布随机数的开3次方,这样就得到了在球体内均匀分布的样本,最后根据判别平面3x+2y-z-1=0对平面两侧样本用不同的形状和颜色画出

图像显示

Matplotlib也支持图像的存取和显示,并且和OpenCV一类的接口比起来,对于一般的二维矩阵的可视化要方便很多,来看例子:

import matplotlib.pyplot as plt

# 读取一张小白狗的照片并显示
plt.figure('A Little White Dog')
little_dog_img = plt.imread('little_white_dog.jpg')
plt.imshow(little_dog_img)

# Z是上小节生成的随机图案,img0就是Z,img1是Z做了个简单的变换
img0 = Z
img1 = 3*Z + 4

# cmap指定为'gray'用来显示灰度图
fig = plt.figure('Auto Normalized Visualization')
ax0 = fig.add_subplot(121)
ax0.imshow(img0, cmap='gray')

ax1 = fig.add_subplot(122)
ax1.imshow(img1, cmap='gray')

plt.show()

这段代码中第一个例子是读取一个本地图片并显示,第二个例子中直接把上小节中反傅里叶变换生成的矩阵作为图像拿过来,原图和经过乘以3再加4变换的图直接绘制了两个形状一样,但是值的范围不一样的图案。显示的时候imshow会自动进行归一化,把最亮的值显示为纯白,最暗的值显示为纯黑。这是一种非常方便的设定,尤其是查看深度学习中某个卷积层的响应图时。

基本设置

plt.close('all')     # 关闭当前所有窗口

# 到了需要绘图时敲入以下代码

# 生成一个3.2 * 2.8 inch的图片
fig=plt.figure(figsize=(3.2, 2.8),frameon=True)
ax=fig.add_subplot(111)

# 设置背景透明度,0为全透明,1为不透明
fig.patch.set_alpha(0.5)

# 调整tick的字体大小 
plt.rc('xtick', labelsize = 9)
plt.rc('ytick', labelsize = 9)

# 调整x坐标tick的数目,这里显示5个坐标值
plt.locator_params(axis = 'x', nbins = 5)
plt.locator_params(axis = 'y', nbins = 5)

# 设置坐标轴label,图片title  
plt.xlabel('XXXXX', fontsize = 10)    
plt.ylabel('XXXXXX', fontsize = 10)   
plt.title('XXXXX', fontsize = 10)     


# 调整legend的位置,字体大小和边框。
plt.legend(loc='best', fontsize = 9, frameon=False) # 字体大小除了可以用数字外,还可以用预设的值 'small', 'x-small', 'xx-small'等值
plt.tight_layout()   # 紧凑显示图片,居中显示


# 坐标tick默认为科学计数格式,但对于小于10^6的数,还是将其全部展开显示,导致数字很长,影响美观。进行修改
ax1 = plt.gca()  # 获取当前图像的坐标轴信息
ax1.yaxis.get_major_formatter().set_powerlimits((0,1))  # 将坐标轴的base number设置为一位。

实例

import matplotlib.pyplot as plt

fig = plt.figure()
ax1 = fig.add_subplot(111)
ax2 = ax1.twiny()         # 创建一个 twin axes,使用不同的x轴刻度,共用同一个y轴刻度。

ax1.set_ylabel("ns/day")

cpu_cs = [20, 40, 60, 80, 100, 120]
cpu_per = [5.383, 8.282, 10.384, 16.545, 13.954, 14.223]
ax1.plot(cpu_cs, cpu_per, 'ro-', label='cpu')
my_cpu_cs = [20, 24]
my_cpu_per = [3.978, 4.294]
ax1.plot(my_cpu_cs, my_cpu_per, 'go-', label='my_cpu')
aliyun_cs = [16, 32, 48, 64, 80, 96]
aliyun_per = [3.697, 6.294, 7.383, 9.554, 16.324, 13.954]
ax1.plot(aliyun_cs, aliyun_per, 'yo-', label='aliyun')
ax1.set_xlabel(r"cpu cores")
ax1.axis([0, 140, 0, 25])
ax1.legend(bbox_to_anchor=(1.05, 1), loc=2)    # bbox_to_anchor 指定legend偏移的位置 

gpu_cs = [1, 2, 4, 5]
gpu_per = [3.514, 10.663, 14.770, 22.775]
ax2.plot(gpu_cs, gpu_per, 'b^-', label='gpu')
ax2.axis([0, 6, 0, 25])
ax2.set_xlabel(r"gpu cores")
ax2.legend(bbox_to_anchor=(1.05, 1), loc=3)

seaborn plotting

seaborn plotting


Sources: