EM(最大期望)算法推導、GMM的應用與代碼實現

  EM算法是一種迭代算法,用於含有隱變量的概率模型參數的極大似然估計。

使用EM算法的原因

  首先舉李航老師《統計學習方法》中的例子來說明為什麼要用EM算法估計含有隱變量的概率模型參數。

  假設有三枚硬幣,分別記作A, B, C。這些硬幣正面出現的概率分別是$\pi,p,q$。進行如下擲硬幣試驗:先擲硬幣A,根據其結果選出硬幣B或C,正面選硬幣B,反面邊硬幣C;然後擲選出的硬幣,擲硬幣的結果出現正面記作1,反面記作0;獨立地重複$n$次試驗,觀測結果為$\{y_1,y_2,…,y_n\}$。問三硬幣出現正面的概率。

  三硬幣模型(也就是第二枚硬幣正反面的概率)可以寫作

$ \begin{aligned} &P(y|\pi,p,q) \\ =&\sum\limits_z P(y,z|\pi,p,q)\\ =&\sum\limits_z P(y|z,\pi,p,q)P(z|\pi,p,q)\\ =&\pi p^y(1-p)^{1-y}+(1-\pi)q^y(1-q)^{1-y} \end{aligned} $

  其中$z$表示硬幣A的結果,也就是前面說的隱變量。通常我們直接使用極大似然估計,即最大化似然函數

$ \begin{aligned} &\max\limits_{\pi,p,q}\prod\limits_{i=1}^n P(y_i|\pi,p,q) \\ =&\max\limits_{\pi,p,q}\prod\limits_{i=1}^n[\pi p^{y_i}(1-p)^{1-y_i}+(1-\pi)q^{y_i}(1-q)^{1-y_i}]\\ =&\max\limits_{\pi,p,q}\sum\limits_{i=1}^n\log[\pi p^{y_i}(1-p)^{1-y_i}+(1-\pi)q^{y_i}(1-q)^{1-y_i}]\\ =&\max\limits_{\pi,p,q}L(\pi,p,q) \end{aligned} $

  分別對$\pi,p,q$求偏導並等於0,求解線性方程組來估計這三個參數。但是,由於它是帶有隱變量的,在獲取最終的隨機變量之前有一個分支選擇的過程,導致這個$\log$的內部是加和的形式,計算導數十分困難,而待求解的方程組不是線性方程組。當複雜度一高,解這種方程組幾乎成為不可能的事。以下推導EM算法,它以迭代的方式來求解這些參數,應該也算一種貪心吧。

算法導出與理解

  對於參數為$\theta$且含有隱變量$Z$的概率模型,進行$n$次抽樣。假設隨機變量$Y$的觀察值為$\mathcal{Y} = \{y_1,y_2,…,y_n\}$,隱變量$Z$的$m$個可能的取值為$\mathcal{Z}=\{z_1,z_2,…,z_m\}$。

  寫出似然函數:

$ \begin{aligned} L(\theta) &= \sum\limits_{Y\in\mathcal{Y}}\log P(Y|\theta)\\ &=\sum\limits_{Y\in\mathcal{Y}}\log \sum\limits_{Z\in \mathcal{Z}} P(Y,Z|\theta)\\ \end{aligned} $

  EM算法首先初始化參數$\theta = \theta^0$,然後每一步迭代都會使似然函數增大,即$L(\theta^{k+1})\ge L(\theta^k)$。如何做到不斷變大呢?考慮迭代前的似然函數(為了方便不用$\theta^{k+1}$):

$ \begin{gather} \begin{aligned} L(\theta)=&\sum\limits_{Y\in \mathcal{Y}} \log\sum\limits_{Z\in \mathcal{Z}} P(Y,Z|\theta)\\ =&\sum\limits_{Y\in \mathcal{Y}} \log\sum\limits_{Z\in \mathcal{Z}} P(Z|Y,\theta^k)\frac{P(Y,Z|\theta)}{P(Z|Y,\theta^k)}\\ \end{aligned} \label{} \end{gather} $

  至於上式的第二個等式為什麼取出$P(Z|Y,\theta^k)$而不是別的,正向的原因我想不出來,馬後炮原因在後面記錄。

  考慮其中的求和

$ \sum\limits_{Z\in \mathcal{Z}} P(Z|Y,\theta^k)=1$

  且由於$\log$函數是凹函數,因此由Jenson不等式得

$ \begin{gather} \begin{aligned} L(\theta) \ge&\sum\limits_{Y\in \mathcal{Y}}\sum\limits_{Z\in \mathcal{Z}} P(Z|Y,\theta^k)\log\frac{P(Y,Z|\theta)}{P(Z|Y,\theta^k)}\\ =&B(\theta,\theta^k) \end{aligned}\label{} \end{gather} $

  當$\theta = \theta^k$時,有

$ \begin{gather} \begin{aligned} L(\theta^k) \ge& B(\theta^k,\theta^k)\\ =&\sum\limits_{Y\in \mathcal{Y}}\sum\limits_{Z\in \mathcal{Z}} P(Z|Y,\theta^k)\log\frac{P(Y,Z|\theta^k)}{P(Z|Y,\theta^k)}\\ =&\sum\limits_{Y\in \mathcal{Y}}\sum\limits_{Z\in \mathcal{Z}} P(Z|Y,\theta^k)\log P(Y|\theta^k)\\ =&\sum\limits_{Y\in \mathcal{Y}}\log P(Y|\theta^k)\\ =&L(\theta^k)\\ \end{aligned} \label{} \end{gather} $

  也就是在這時,$(2)$式取等,即$L(\theta^k) = B(\theta^k,\theta^k)$。取

$ \begin{gather} \theta^*=\text{arg}\max\limits_{\theta}B(\theta,\theta^k)\label{} \end{gather} $

  可得不等式

$L(\theta^*)\ge B(\theta^*,\theta^k)\ge B(\theta^k,\theta^k) = L(\theta^k)$

  所以,我們只要優化$(4)$式,讓$\theta^{k+1} = \theta^*$,即可保證每次迭代的非遞減勢頭,有$L(\theta^{k+1})\ge L(\theta^k)$。而由於似然函數是概率乘積的對數,一定有$L(\theta) < 0$,所以迭代有上界並且會收斂。以下是《統計學習方法》中EM算法一次迭代的示意圖:

  進一步簡化$(4)$式,去掉優化無關項:

$ \begin{aligned} \theta^*=&\text{arg}\max\limits_{\theta}B(\theta,\theta^k) \\ =&\text{arg}\max\limits_{\theta}\sum\limits_{Y\in \mathcal{Y}}\sum\limits_{Z\in \mathcal{Z}} P(Z|Y,\theta^k)\log\frac{P(Y,Z|\theta)}{P(Z|Y,\theta^k)} \\ =&\text{arg}\max\limits_{\theta}\sum\limits_{Y\in \mathcal{Y}}\sum\limits_{Z\in \mathcal{Z}} P(Z|Y,\theta^k)\log P(Y,Z|\theta) \\ =&\text{arg}\max\limits_{\theta}Q(\theta,\theta^k) \\ \end{aligned} $

  $Q$函數使用導數求極值的方程與沒有隱變量的方程類似,容易求解。

  綜上,EM算法的流程為:

  1. 設置$\theta^0$的初值。EM算法對初值是敏感的,不同初值迭代出來的結果可能不同。

  2. 更新$\theta^k = \text{arg}\max\limits_{\theta}Q(\theta,\theta^{k-1})$。理解上來說,通常將這一步分為計算$Q$與極大化$Q$兩步,即求期望E與求極大M,但在代碼中並不會將它們分出來,因此這裏濃縮為一步。另外,如果這個優化很難計算的話,因為有不等式的保證,直接取$\theta^k$為某個$\hat{\theta}$,只要有$Q(\hat{\theta},\theta^{k-1})\ge Q(\theta^{k-1},\theta^{k-1})$即可。

  3. 比較$\theta^k$與$\theta^{k-1}$的差異,比如求它們的差的二范數,若小於一定閾值就結束迭代,否則重複步驟2。

  下面記錄一下我對$(1)$式取出$P(Z|Y,\theta^k)$而不取別的$P$的理解:

  經過以上的推導,我認為這是為了給不等式取等創造條件。如果不能確定$L(\theta^k)$與$Q(\theta^k,\theta^k)$能否取等,那麼取$Q$的最大值$Q(\theta^*,\theta^k)$時,儘管有$Q(\theta^*,\theta^k)\ge Q(\theta^k,\theta^k)$,但並不能保證$L(\theta^*)\ge L(\theta^k)$,迭代的不減性質就就沒了。

  我這裏暫且把它看做一種巧合,是研究EM算法的大佬,碰巧想用Jenson不等式來迭代而構造出來的一種做法。本人段位還太弱,無法正向理解其中的緣故,只能以這種方式來揣度大佬的思路了。知乎大佬發的EM算法九層理解(點擊鏈接),我當前只能到第3層,有時間一定要拜讀一下深度學習之父的著作。

高斯混合模型的應用

迭代式推導

  假設高斯混合模型混合了$m$個高斯分佈,參數為$\theta = (\alpha_1,\theta_1,\alpha_2,\theta_2,…,\alpha_m,\theta_m),\theta_i=(\mu_i,\sigma_i)$則整個概率分佈為:

$\displaystyle P(y|\theta) = \sum\limits_{i=1}^m\alpha_i \phi(y|\theta_i) =  \sum\limits_{i=1}^m\frac{\alpha_i }{\sqrt{2\pi}\sigma_i}\exp\left(-\frac{(y-\mu_i)^2}{2\sigma_i^2}\right),\;\text{where}\;\sum\limits_{j=1}^m\alpha_j = 1$

  對混合分佈抽樣$n$次得到$\{y_1,…,y_n\}$,則在第$k+1$次迭代,待優化式為:

$\begin{gather}\begin{aligned} &\max\limits_{\theta}Q(\theta,\theta^k) \\ =&\max\limits_{\theta}\sum\limits_{Y\in \mathcal{Y}}\sum\limits_{Z\in \mathcal{Z}} P(Z|Y,\theta^k)\log P(Y,Z|\theta) \\ =&\max\limits_{\theta}\sum\limits_{Y\in \mathcal{Y}}\sum\limits_{Z\in \mathcal{Z}} \frac{P(Z,Y|\theta^k)}{P(Y|\theta^k)}\log P(Y,Z|\theta) \\ =&\max\limits_{\theta}\sum\limits_{i=1}^n\sum\limits_{j=1}^m \frac{\alpha_j^k\phi(y_i|\theta_j^k)} {\sum\limits_{l=1}^m \alpha_l^k\phi(y_i|\theta_l^k)} \log \left[\alpha_j\phi(y_i|\theta_j)\right] \\ =&\max\limits_{\theta}\sum\limits_{i=1}^n\sum\limits_{j=1}^m \frac{\alpha_j^k\phi(y_i|\theta_j^k)} {\sum\limits_{l=1}^m \alpha_l^k\phi(y_i|\theta_l^k)} \log \left[ \frac{\alpha_j}{\sqrt{2\pi}\sigma_j}\exp\left(-\frac{(y_i-\mu_j)^2}{2\sigma_j^2}\right) \right]\\ =&\max\limits_{\theta}\sum\limits_{j=1}^m \sum\limits_{i=1}^n \frac{\alpha_j^k\phi(y_i|\theta_j^k)} {\sum\limits_{l=1}^m \alpha_l^k\phi(y_i|\theta_l^k)} \left[ \log \alpha_j – \log \sigma_j-\frac{(y_i-\mu_j)^2}{2\sigma_j^2} \right]\\  \end{aligned} \label{}\end{gather}$

計算α

  定義

$\displaystyle n_j = \sum\limits_{i=1}^n \frac{\alpha_j^k\phi(y_i|\theta_j^k)} {\sum\limits_{l=1}^m \alpha_l^k\phi(y_i|\theta_l^k)}$

  則對於$\alpha$,優化式為

$\begin{gather} \begin{aligned} \max\limits_{\alpha}\sum\limits_{j=1}^m n_j \log \alpha_j \end{aligned} \label{}\end{gather}$

  又因為$\sum\limits_{j=1}^m \alpha_j=1$,所以只需優化$m-1$個參數,上式變為:

$ \max\limits_\alpha \left[ \begin{matrix} n_1&n_2&\cdots &n_{m-1}&n_{m}\\ \end{matrix} \right] \cdot \left[ \begin{matrix} \log\alpha_1\\ \log\alpha_2\\ \vdots\\ \log\alpha_{m-1}\\ \log(1-\alpha_1-\cdots-\alpha_{m-1})\\ \end{matrix} \right] $

  對每個$\alpha_j$求導並等於0,得到線性方程組:

$\left[\begin{matrix}n_1+n_m&n_1&n_1&\cdots&n_1\\n_2&n_2+n_m&n_2&\cdots&n_2\\n_3&n_3&n_3+n_m&\cdots&n_3\\&&&\vdots&\\n_{m-1}&n_{m-1}&n_{m-1}&\cdots&n_{m-1}+n_m\\\end{matrix}\right]\cdot\left[\begin{matrix}\alpha_1\\\alpha_2\\\alpha_3\\\vdots\\\alpha_{m-1}\\\end{matrix}\right]=\left[\begin{matrix}n_1\\n_2\\n_3\\\vdots\\n_{m-1}\\\end{matrix}\right]$

  求解這個爪形線性方程組,得到

$\left[\begin{matrix}\sum_{j=1}^mn_j/n_1&0&0&\cdots&0\\-n_2/n_1&1&0&\cdots&0\\-n_3/n_1&0&1&\cdots&0\\&&&\vdots&\\-n_{m-1}/n_1&0&0&\cdots&1\\\end{matrix}\right]\cdot\left[\begin{matrix}\alpha_1\\\alpha_2\\\alpha_3\\\vdots\\\alpha_{m-1}\\\end{matrix}\right]=\left[\begin{matrix}1\\0\\0\\\vdots\\0\\\end{matrix}\right]$

  因為

$\displaystyle \sum\limits_{j=1}^m n_j =   \sum\limits_{j=1}^m\sum\limits_{i=1}^n \frac{\alpha_j^k\phi(y_i|\theta_j^k)} {\sum\limits_{l=1}^m \alpha_l^k\phi(y_i|\theta_l^k)}=\sum\limits_{i=1}^n \sum\limits_{j=1}^m \frac{\alpha_j^k\phi(y_i|\theta_j^k)} {\sum\limits_{l=1}^m \alpha_l^k\phi(y_i|\theta_l^k)} =\sum\limits_{i=1}^n 1 =  n$

  解得

$\displaystyle\alpha_j = \frac{n_j}{n} = \frac{1}{n}\sum\limits_{i=1}^n \frac{\alpha_j^k\phi(y_i|\theta_j^k)} {\sum\limits_{l=1}^m \alpha_l^k\phi(y_i|\theta_l^k)}$

計算σ與μ

  與$\alpha$不同,它的方程組是所有$\alpha_j$之間聯立的;而$\sigma,\mu$的方程組則是$\sigma_j$與$\mu_j$之間聯立的。定義

$\displaystyle p_{ji} = \frac{\alpha_j^k\phi(y_i|\theta_j^k)} {\sum\limits_{l=1}^m \alpha_l^k\phi(y_i|\theta_l^k)}$

  則對於$\sigma_j,\mu_j$,優化式為(比較$(6),(7)$式的區別)

$\begin{gather}\displaystyle\min\limits_{\sigma_j,\mu_j}\sum\limits_{i=1}^n p_{ji} \left(\log \sigma_j+\frac{(y_i-\mu_j)^2}{2\sigma_j^2} \right)\label{}\end{gather}$

  對上式求導等於0,解得

$ \begin{aligned} &\mu_j = \frac{\sum\limits_{i=1}^np_{ji}y_i}{\sum\limits_{i=1}^np_{ji}} = \frac{\sum\limits_{i=1}^np_{ji}y_i}{n_j} = \frac{\sum\limits_{i=1}^np_{ji}y_i}{n\alpha_j}\\ &\sigma^2_j = \frac{\sum\limits_{i=1}^np_{ji}(y_i-\mu_j)^2}{\sum\limits_{i=1}^np_{ji}} = \frac{\sum\limits_{i=1}^np_{ji}(y_i-\mu_j)^2}{n_j} = \frac{\sum\limits_{i=1}^np_{ji}(y_i-\mu_j)^2}{n\alpha_j} \end{aligned} $

代碼實現

  對於概率密度為$P(x) = −2x+2,x\in (0,1)$的隨機變量,以下代碼實現GMM對這一概率密度的的擬合。共10000個抽樣,GMM混合了100個高斯分佈。

#%%定義參數、函數、抽樣
import numpy as np
import matplotlib.pyplot as plt

dis_num = 100 #用於擬合的分佈數量
sample_num = 10000 #用於擬合的分佈數量
alphas = np.random.rand(dis_num) 
alphas /= np.sum(alphas)  
mus = np.random.rand(dis_num)
sigmas = np.random.rand(dis_num)**2#方差,不是標準差
samples = 1-(1-np.random.rand(sample_num))**0.5 #樣本
C_pi = (2*np.pi)**0.5

dis_val = np.zeros([sample_num,dis_num])    #每個樣本在每個分佈成員上都有值,形成一個sample_num*dis_num的矩陣
pij = np.zeros([sample_num,dis_num])        #pij矩陣
def calc_dis_val(sample,alpha,mu,sigma,c_pi):
    return alpha*np.exp(-(sample[:,np.newaxis]-mu)**2/(2*sigma))/(c_pi*sigma**0.5) 
def calc_pij(dis_v):  
    return dis_v / dis_v.sum(axis = 1)[:,np.newaxis]      
#%%優化 
for i in range(1000):
    print(i)
    dis_val = calc_dis_val(samples,alphas,mus,sigmas,C_pi)
    pij = calc_pij(dis_val)  
    nj = pij.sum(axis = 0)
    alphas_before = alphas
    alphas = nj / sample_num
    mus = (pij*samples[:,np.newaxis]).sum(axis=0)/nj
    sigmas = (pij*(samples[:,np.newaxis] - mus)**2 ).sum(axis=0)/nj
    a = np.linalg.norm(alphas_before - alphas)
    print(a)
    if  a< 0.001:
        break

#%%繪圖 
plt.rcParams['font.sans-serif']=['SimHei'] #用來正常显示中文標籤
plt.rcParams['axes.unicode_minus']=False #用來正常显示負號
def get_dis_val(x,alpha,sigma,mu,c_pi):
    y = np.zeros([len(x)]) 
    for a,s,m in zip(alpha,sigma,mu):   
        y += a*np.exp(-(x-m)**2/(2*s))/(c_pi*s**0.5)   
    return y
def paint(alpha,sigma,mu,c_pi,samples):
    x = np.linspace(-1,2,500)
    y = get_dis_val(x,alpha,sigma,mu,c_pi) 
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.hist(samples,density = True,label = '抽樣分佈') 
    ax.plot(x,y,label = "擬合的概率密度")
    ax.legend(loc = 'best')
    plt.show()
paint(alphas,sigmas,mus,C_pi,samples)

  以下是擬合結果圖,有點像是核函數估計,但是完全不同:

EM算法的推廣

  EM算法的推廣是對EM算法的另一種解釋,最終的結論是一樣的,它可以使我們對EM算法的理解更加深入。它也解釋了我在$(1)$式下方提出的疑問:為什麼取出$P(Z|Y,\theta^k)$而不是別的。

  定義$F$函數,即所謂Free energy自由能(自由能具體是啥先不研究了):

$ \begin{aligned} F(\tilde{P},\theta) &= E_{\tilde{P}}(\log P(Y,Z|\theta)) + H(\tilde{P})\\ &= \sum\limits_{Z\in \mathcal{Z}} \tilde{P}(Z)\log P(Y,Z|\theta) – \sum\limits_{Z\in \mathcal{Z}} \tilde{P}(Z)\log \tilde{P}(Z)\\ \end{aligned} $

  其中$\tilde{P}$是$Z$的某個概率分佈(不一定是單獨的分佈,可能是在某個條件下的分佈),$E_{\tilde{P}}$表示分佈$\tilde{P}$下的期望,$H$表示信息熵。

  我們計算一下,對於固定的$\theta$,什麼樣的$\tilde{P}$會使$F(\tilde{P},\theta) $最大。也就是找到一個函數$\tilde{P}_{\theta}$,使$F$極大,寫成優化的形式就是(這裡是找函數而不是找參數哦,理解上可能要用到泛函分析的內容):

$ \begin{aligned} &\max\limits_{\tilde{P}} \sum\limits_{Z\in \mathcal{Z}} \tilde{P}(Z)\log P(Y,Z|\theta) – \sum\limits_{Z\in \mathcal{Z}} \tilde{P}(Z)\log \tilde{P}(Z)\\ &\;\text{s.t.}\; \sum\limits_{Z\in \mathcal{Z}}\tilde{P}(Z) = 1 \end{aligned} $

  拉格朗日函數(拉格朗日對偶性,點擊鏈接)為:

$ \begin{aligned} L =  \sum\limits_{Z\in \mathcal{Z}} \tilde{P}(Z)\log P(Y,Z|\theta) – \sum\limits_{Z\in \mathcal{Z}} \tilde{P}(Z)\log \tilde{P}(Z)+ \lambda\left(1-\sum\limits_{Z\in \mathcal{Z}}\tilde{P}(Z)\right) \end{aligned} $

  因為每個$\tilde{P}(Z)$之間都是求和,沒有其它其它諸如乘積的操作,所以可以直接令$L$對某個$\tilde{P}(Z)$求導等於$0$來計算極值:

$ \begin{aligned} \frac{\partial L}{\partial \tilde{P}(Z)} = \log P(Y,Z|\theta) – \log \tilde{P}(Z) -1 -\lambda = 0 \end{aligned} $

  於是可以推出:

$ \begin{aligned} P(Y,Z|\theta) = e^{1+\lambda}\tilde{P}(Z) \end{aligned} $

  又由約束$\sum\limits_{Z\in \mathcal{Z}}\tilde{P}(Z) = 1$:

$P(Y|\theta) = e^{1+\lambda}$

  於是得到

$\begin{gather}\tilde{P}_{\theta}(Z) = P(Z|Y,\theta)\label{}\end{gather}$

  代回$F(\tilde{P},\theta)$,得到

$ \begin{aligned} F(\tilde{P}_\theta,\theta) &= \sum\limits_{Z\in \mathcal{Z}} P(Z|Y,\theta)\log P(Y,Z|\theta) – \sum\limits_{Z\in \mathcal{Z}} P(Z|Y,\theta)\log P(Z|Y,\theta)\\ &= \sum\limits_{Z\in \mathcal{Z}} P(Z|Y,\theta)\log \frac{P(Y,Z|\theta)}{P(Z|Y,\theta)}\\ &= \log P(Y|\theta)\\ \end{aligned} $

  也就是說,對$F$關於$\tilde{P}$進行最大化后,$F$就是待求分佈的對數似然;然後再關於$\theta$最大化,也就算得了最終要估計的參數$\hat{\theta}$。所以,EM算法也可以解釋為$F$的極大-極大算法。優化結果$(8)$式也解釋了我之前在$(1)$式下方的提問。

  那麼,怎麼使用$F$函數進行估計呢?還是要用迭代來算,迭代方式是和前面介紹的一樣的(懶得記錄了,統計學習方法上直接看吧)。實際上,$F$函數的方法只是提供了EM算法的另一種解釋,具體方法上並沒有提升之處。

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理

【其他文章推薦】

網頁設計一頭霧水該從何著手呢? 台北網頁設計公司幫您輕鬆架站!

網頁設計公司推薦不同的風格,搶佔消費者視覺第一線

※想知道購買電動車哪裡補助最多?台中電動車補助資訊懶人包彙整

南投搬家公司費用,距離,噸數怎麼算?達人教你簡易估價知識!

※教你寫出一流的銷售文案?

※超省錢租車方案

重識Java8函數式編程

前言

最近真的是太忙忙忙忙忙了,很久沒有更新文章了。最近工作中看到了幾段關於函數式編程的代碼,但是有點費解,於是就準備總結一下函數式編程。很多東西很簡單,但是如果不總結,可能會被它的各種變體所困擾。接觸Lambda表達式已經很久了,但是也一直是處於照葫蘆畫瓢的階段,所以想自己去編寫相關代碼,也有些捉襟見肘。

1. Lambda表達式的不同形式

// 基本形式
參數 -> 主體

1.1 形式一

Runnable noArguments = () -> System.out.println("Hello World");

該形式的Lambda表達式不包含參數,使用空括號()表示沒有參數。它實現了Runnable接口,該接口也只有一個run方法,沒有桉樹,且返回類型為void。

1.2 形式二

ActionListener oneArgument = event -> System.out.println("button clicked");

該形式的Lambda表達式包含且只包含一個參數,可省略參數的符號。

1.3 形式三

Runnable multiStatement = () -> {
	System.out.print("Hello"); 
    System.out.println(" World"); 
};

Lambda表達式的主體不僅可以使一個表達式,而且也可以是一段代碼塊,使用大括號{}將代碼塊括起來。該代碼塊和普通方法遵循的規則別無二致,可以用返回或拋出異常來退出。只有以行代碼的Lambda表達式也可以使用大括號,用以明確Lambda表達式從何處開始,到哪裡結束。

1.4 形式四

BinaryOperator<Long> add = (x, y) -> x + y;

Lambda表達式也可以表示包含多個參數的方法,上面的Lambda表達式並不是將兩個数字相加,而是創建了一個函數,用來計算兩個数字相加的結果。變量add的類型時BinaryOperator ,它不是兩個数字的和,而是將兩個数字相加的那行代碼。

1.5 形式五

BinaryOperator<Long> addExplicit = (Long x, Long y) -> x + y;

到目前為止,所有Lambda表達式中的參數類型都是由編譯器推斷得出的。但有時最好也可以显示聲明參數類型,此時就需要使用小括號將參數括起來,多個參數的情況也是如此。

2. 引用值,而不是變量

如果你曾使用過匿名內部類,也許遇到過這樣的情況:需要引用它所在方法里的變量。這是,需要將變量聲明為final。

final String name = getUserName(); 
button.addActionListener(new ActionListener() {
	public void actionPerformed(ActionEvent event) { 
        System.out.println("hi " + name); 
    } 
});

將變量聲明為 final,意味着不能為其重複賦 值。同時也意味着在使用 final 變量時,實際上是在使用賦給該變量的一個特定的值。

Java 8 雖然放鬆了這一限制,可以引用非 final 變量,但是該變量在既成事實上必須是 final(意思就是你不能再次對該變量賦值)。雖然無需將變量聲明為 final,但在 Lambda 表達式中,也無法用作非終態變量。如 果堅持用作非終態變量,編譯器就會報錯。 既成事實上的 final 是指只能給該變量賦值一次。換句話說,Lambda 表達式引用的是值, 而不是變量。

例如:

String name = getUserName(); 
button.addActionListener(event -> System.out.println("hi " + name));

3. 函數接口

在 Java 里,所有方法參數都有固定的類型。假設將数字 3 作為參數傳給一個方法,則參數 的類型是 int。那麼,Lambda 表達式的類型又是什麼呢?

使用只有一個方法的接口來表示某特定方法並反覆使用,是很早就有的習慣。使用 Swing 編寫過用戶界面的人對這種方式都不陌生,這裏無需再標新立異,Lambda 表達式也使用同樣的技巧,並將這種接口稱為函數接口。

接口中單一方法的命名並不重要,只要方法簽名和 Lambda 表達式的類型匹配即可。可在函數接口中為參數起一個有意義的名字,增加代碼易讀性,便於更透徹 地理解參數的用途。

3.1 Java中重要的函數接口

接口 參數 返回類型 示例
Predicate T boolean 判斷是否
Consumer T void 輸出一個值
Function<T,R> T T 獲得對象的名字
Supplier None T 工廠方法
UnaryOperator T T 邏輯非(!)
BinaryOperator (T, T) T 求兩個數的乘積(*)

3.2 函數接口定義

定義函數接口需要使用到註解@FunctionalInterface

例如:

@FunctionalInterface
public interface MyFuncInterface {
	void print();
}

使用:

public class MyFunctionalInterfaceTest {
    public static void main(String[] args) {
        doPrint(() -> System.out.println("java"));
    }

    public static void doPrint(MyFuncInterface my) {
        System.out.println("請問你喜歡什麼編程語言?");
        my.print();
    }
}

說明:

這隻是一個很簡單的例子,有人覺得為什麼要搞這麼複雜,去定義一個接口?這個問題還是讀者在平時的工作中去感悟吧,總之,先學會怎麼用它。不至於看了別人寫的代碼都看不懂。

至於我個人的理解,可以簡單聊聊。以前寫過JavaScript,裏面有一種語法就是將自定義函數B作為參數傳遞到另外一個函數A裏面,在函數A裏面會執行你自定義的函數B邏輯,我當時就非常喜歡這種特性,因為每個人關於函數B的實現可能不一樣,亦或者場景不一樣也會導致函數B的實現不一樣。我覺得Java8的這個函數式編程就是對這一特性的補充。

4. 流

流的常用操作有很多,例如collect(toList())mapfiltermaxmin等,下面介紹一下flatMapreduce

4.1 flatMap

flatMap 方法可用 Stream 替換值,然後將多個 Stream 連接成一個 Stream。

List<Integer> together = Stream.of(asList(1, 2), asList(3, 4)) 				 
    .flatMap(numbers -> numbers.stream())
    .collect(toList()); 
assertEquals(asList(1, 2, 3, 4), together);

調用 stream 方法,將每個列錶轉換成 Stream 對象,其餘部分由 flatMap 方法處理。 flatMap 方法的相關函數接口和 map 方法的一樣,都是 Function 接口,只是方法的返回值 限定為 Stream 類型罷了。

4.2 reduce

reduce 操作可以實現從一組值中生成一個值。對於 count、min 和 max 方 法,因為常用而被納入標準庫中。事實上,這些方法都是 reduce 操作。

如何通過 reduce 操作對 Stream 中的数字求和。以 0 作起點——一個空Stream 的求和結果,每一步都將 Stream 中的元素累加至 accumulator,遍歷至 Stream 中的 最後一個元素時,accumulator 的值就是所有元素的和。

int count = Stream.of(1, 2, 3)
    .reduce(0, (acc, element) -> acc + element); 
assertEquals(6, count);

Lambda 表達式的返回值是最新的 acc,是上一輪 acc 的值和當前元素相加的結果。reducer 的類型是前面已介紹過的 BinaryOperator。

5. Optional

reduce 方法的一個重點尚未提及:reduce 方法有兩種形式,一種如前面出現的需要有一 個初始值,另一種變式則不需要有初始值。沒有初始值的情況下,reduce 的第一步使用 Stream 中的前兩個元素。有時,reduce 操作不存在有意義的初始值,這樣做就是有意義的,此時,reduce 方法返回一個 Optional 對象。

Optional 是為核心類庫新設計的一個數據類型,用來替換 null 值。人們對原有的 null 值有很多抱怨。人們常常使用 null 值表示值不存在,Optional 對象能更好地表達這個概念。使用 null 代 表值不存在的最大問題在於 NullPointerException。一旦引用一個存儲 null 值的變量,程 序會立即崩潰。使用 Optional 對象有兩個目的:首先,Optional 對象鼓勵程序員適時檢查變量是否為空,以避免代碼缺陷;其次,它將一個類的 API 中可能為空的值文檔化,這比閱讀實現代碼要簡單很多。

下面我們舉例說明 Optional 對象的 API,從而切身體會一下它的使用方法。使用工廠方法 of,可以從某個值創建出一個 Optional 對象。Optional 對象相當於值的容器,而該值可以 通過 get 方法提取。

Optional<String> a = Optional.of("a"); 
assertEquals("a", a.get());

Optional 對象也可能為空,因此還有一個對應的工廠方法 empty,另外一個工廠方法 ofNullable 則可將一個空值轉換成 Optional 對象。下面的代碼同時展示 了第三個方法 isPresent 的用法(該方法表示一個 Optional 對象里是否有值)。

Optional emptyOptional = Optional.empty(); 
Optional alsoEmpty = Optional.ofNullable(null); assertFalse(emptyOptional.isPresent());

使用 Optional 對象的方式之一是在調用 get() 方法前,先使用 isPresent 檢查 Optional 對象是否有值。使用 orElse 方法則更簡潔,當 Optional 對象為空時,該方法提供了一個 備選值。如果計算備選值在計算上太過繁瑣,即可使用 orElseGet 方法。該方法接受一個 Supplier 對象,只有在 Optional 對象真正為空時才會調用。

assertEquals("b", emptyOptional.orElse("b")); 
assertEquals("c", emptyOptional.orElseGet(() -> "c"));

最後

實踐是檢驗真理的唯一標準,多寫代碼,多思考,你的代碼才會越來越好。

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理

【其他文章推薦】

網頁設計一頭霧水該從何著手呢? 台北網頁設計公司幫您輕鬆架站!

網頁設計公司推薦不同的風格,搶佔消費者視覺第一線

※Google地圖已可更新顯示潭子電動車充電站設置地點!!

※廣告預算用在刀口上,台北網頁設計公司幫您達到更多曝光效益

※別再煩惱如何寫文案,掌握八大原則!

網頁設計最專業,超強功能平台可客製化

RabbitMQ入門,我是動了心的

人一輩子最值得炫耀的不應該是你的財富有多少(雖然這話說得有點違心,呵呵),而是你的學習能力。技術更新迭代的速度非常快,那作為程序員,我們就應該擁有一顆擁抱變化的心,积極地跟進。

在 RabbitMQ 入門之前,我已經入門了 Redis、Elasticsearch 和 MongoDB,這讓我感覺自己富有極客精神,非常良好。

小夥伴們在繼續閱讀之前,我必須要聲明一點,我對 RabbitMQ 並沒有進行很深入的研究,僅僅是因為要用,就學一下。但作為一名負責任的技術博主,我是動了心的,這篇入門教程,小夥伴們讀完后絕對會感到滿意,忍不住無情地點贊,以及赤裸裸地轉發。

當然了,小夥伴們遇到文章中有錯誤的地方,不要手下留情,可以組團過來捶我,但要保證一點,不要打臉,我怕毀容。

01、RabbitMQ 是什麼

首先,我知道,Rabbit 是一隻兔子(哎呀媽呀,忍不住秀了一波自己的英語功底),可愛的形象已經躍然於我的腦海中了。那 MQ 又是什麼呢?是 Message Queue 的首字母縮寫,也就是說 RabbitMQ 是一款開源的消息隊列系統。

RabbitMQ 的主要特點在於健壯性好、易於使用、高性能、高併發、集群易擴展,以及強大的開源社區支持。反正就是很牛逼的樣子。

九年前我做大宗期貨交易的時候,也需要消息推送,那時候還不知道去找這種現成的中間件,就用自定義的隊列實現,結果搞了不少 bug,有些到現在還沒有解決,真的是不堪回首的往事啊。

下圖是 RabbitMQ 的消息模型圖(來源於網絡,侵刪),小夥伴們來感受下。

1)P 是 Producer,代表生產者,也就是消息的發送者,可以將消息發送到 X

2)X 是 Exchange(為啥不是 E,我也很好奇),代表交換機,可以接受生產者發送的消息,並根據路由將消息發送給指定的隊列

3)Q 是 Queue,也就是隊列,存放交換機發送來的消息

4)C 是 Consumer,代表消費者,也就是消息的接受者,從隊列中獲取消息

聽我這樣一解釋,是不是對 RabbitMQ 的印象就很具象化了?小夥伴們,學起來吧!

02、安裝 Erlang

咦,怎麼不是安裝 RabbitMQ 啊?先來看看官方的解釋。

英文看不太懂,沒關係,我來補充兩人話。RabbitMQ 服務器是用 Erlang 語言編寫的,它的安裝包里並沒有集成 Erlang 的環境,因此需要先安裝 Erlang。小夥伴們不要擔心,Erlang 安裝起來沒有任何難度。

Erlang 下載地址如下:

https://erlang.org/download/otp_versions_tree.html

最新的版本是 23.0.1,我選擇的是 64 位的版本,104M 左右。下載完就可以雙擊運行安裝,傻瓜式的。

需要注意的是,我安裝的過程中,電腦重啟了一次,好像要安裝一個什麼庫,重啟之前忘記保存圖片了(sorry)。重啟后,重新雙擊運行 otp_win64_23.0.1.exe 文件完成 Erlang 安裝。

03、安裝 RabbitMQ

Erlang 安裝成功后,就可以安裝 RabbitMQ 了。下載地址如下所示:

https://www.rabbitmq.com/install-windows.html

找到下圖中的位置,選擇紅色框中的文件進行下載。

安裝包只有 16.5M 大小,還是非常輕量級的。下載完后直接雙擊運行 exe 文件就可以傻瓜式地安裝了。

安裝成功后,就可以將 RabbitMQ 作為 Windows 服務啟動,可以從“開始”菜單管理 RabbitMQ Windows 服務。

點擊「RabbitMQ Command Prompt (sbin dir)」,進入命令行,輸入 rabbitmqctl.bat status 可確認 RabbitMQ 的啟動狀態。

可以看到 RabbitMQ 一些狀態信息:

  • 進程 ID,也就是 PID 為 2816
  • 操作系統為 Windows
  • 當前的版本號為 3.8.4
  • Erlang 的配置信息

命令行界面看起來不夠優雅,因此我們可以輸入以下命令來啟用客戶端管理 UI 插件:

rabbitmq-plugins enable rabbitmq_management

看到以下信息就可以確認插件啟用成功了。

在瀏覽器地址欄輸入 http://localhost:15672/ 可以進入管理端界面,如下圖所示:

04、在 Java 中使用 RabbitMQ

有些小夥伴可能會問,“二哥,我是一名 Java 程序員,我該如何在 Java 中使用 RabbitMQ 呢?”這個問題問得好,這就來,這就來。

第一步,在項目中添加 RabbitMQ 客戶端依賴:

<dependency>
    <groupId>com.rabbitmq</groupId>
    <artifactId>amqp-client</artifactId>
    <version>5.9.0</version>
</dependency>

第二步,我們來模擬一個最簡單的場景,一個生產者發送消息到隊列中,一個消費者從隊列中讀取消息並打印。

官方對 RabbitMQ 有一個很好的解釋,我就“拿來主義”的用一下。在我上高中的年代,同學們之間最流行的交流方式不是 QQ、微信,甚至短信這些,而是書信。因為那時候還沒有智能手機,況且上學期間學校也是命令禁用手機的,所以書信是情感表達的最好方式。好懷念啊。

假如我向女朋友小巷寫了一封情書,內容如下所示:

致小巷
你好呀,小巷。
你走了以後我每天都感到很悶,就像堂吉訶德一樣,每天想念托波索的達辛妮亞。我現在已經養成了一種習慣,就是每兩三天就要找你說幾句不想對別人說的話。
。。。。。。
王二,5月20日

那這封情書要寄給小巷,我就需要跑到郵局,買上郵票,投遞到郵箱當中。女朋友要收到這封情書,就需要郵遞員盡心儘力,不要弄丟了。

RabbitMQ 就像郵局一樣,只不過處理的不是郵件,而是消息。之前解釋過了,P 就是生產者,C 就是消費者。

新建生產者類 Wanger :

public class Wanger {
    private final static String QUEUE_NAME = "love";
    public static void main(String[] args) throws IOException, TimeoutException {
        ConnectionFactory factory = new ConnectionFactory();

        try (Connection connection = factory.newConnection();
             Channel channel = connection.createChannel()) {
            channel.queueDeclare(QUEUE_NAME, falsefalsefalsenull);
            String message = "小巷,我喜歡你。";
            channel.basicPublish("", QUEUE_NAME, null, message.getBytes(StandardCharsets.UTF_8));
            System.out.println(" [王二] 發送 '" + message + "'");
        }
    }
}

1)QUEUE_NAME 為隊列名,也就是說,生產者發送的消息會放到 love 隊列中。

2)通過以下方式創建服務器連接:

ConnectionFactory factory = new ConnectionFactory();
try (Connection connection = factory.newConnection();
             Channel channel = connection.createChannel()) {

ConnectionFactory 是一個非常方便的工廠類,可用來創建到 RabbitMQ 的默認連接(主機名為“localhost”)。然後,創建一個通道( Channel)來發送消息。

Connection 和 Channel 類都實現了 Closeable 接口,所以可以使用 try-with-resource 語句,如果有小夥伴對 try-with-resource 語句不太熟悉,可以查看我之前寫的我去文章。

3)在發送消息的時候,必須設置隊列名稱,通過 queueDeclare() 方法設置。

4)basicPublish() 方法用於發布消息:

  • 第一個參數為交換機(exchange),當前場景不需要,因此設置為空字符串;
  • 第二個參數為路由關鍵字(routingKey),暫時使用隊列名填充;
  • 第三個參數為消息的其他參數(BasicProperties),暫時不配置;
  • 第四個參數為消息的主體,這裏為 UTF-8 格式的字節數組,可以有效地杜絕中文亂碼。

生產者類有了,接下來新建消費者類 XiaoXiang:

public class XiaoXiang {
    private final static String QUEUE_NAME = "love";
    public static void main(String[] args) throws IOException, TimeoutException {
        ConnectionFactory factory = new ConnectionFactory();
        Connection connection = factory.newConnection();
        Channel channel = connection.createChannel();

        channel.queueDeclare(QUEUE_NAME, falsefalsefalsenull);
        System.out.println("等待接收消息");

        DeliverCallback deliverCallback = (consumerTag, delivery) -> {
            String message = new String(delivery.getBody(), "UTF-8");
            System.out.println(" [小巷] 接收到的消息 '" + message + "'");
        };
        channel.basicConsume(QUEUE_NAME, true, deliverCallback, consumerTag -> { });
    }
}

1)創建通道的代碼和生產者差不多,只不過沒有使用 try-with-resource 語句來自動關閉連接和通道,因為我們希望消費者能夠一直保持連接,直到我們強制關閉它。

2)在接收消息的時候,必須設置隊列名稱,通過 queueDeclare() 方法設置。

3)由於 RabbitMQ 將會通過異步的方式向我們推送消息,因此我們需要提供了一個回調,該回調將對消息進行緩衝,直到我們做好準備接收它們為止。

DeliverCallback deliverCallback = (consumerTag, delivery) -> {
    String message = new String(delivery.getBody(), "UTF-8");
    System.out.println(" [小巷] 接收到的消息 '" + message + "'");
};

basicConsume() 方法用於接收消息:

  • 第一個參數為隊列名(queue),和生產者相匹配(love)。

  • 第二個參數為 autoAck,如果為 true 的話,表明服務器要一次性交付消息。怎麼理解這個概念呢?小夥伴們可以在運行消費者類 XiaoXiang 類之前,先多次運行生產者類 Wanger,向隊列中發送多個消息,等到消費者類啟動后,你就會看到多條消息一次性接收到了,就像下面這樣。

等待接收消息
 [小巷] 接收到的消息 '小巷,我喜歡你。'
 [小巷] 接收到的消息 '小巷,我喜歡你。'
 [小巷] 接收到的消息 '小巷,我喜歡你。'
  • 第三個參數為 DeliverCallback,也就是消息的回調函數。

  • 第四個參數為 CancelCallback,我暫時沒搞清楚是幹嘛的。

在消息發送的過程中,也可以使用 RabbitMQ 的管理面板查看到消息的走勢圖,如下所示。

05、鳴謝

好了,我親愛的小夥伴們,以上就是本文的全部內容了,是不是看完后很想實操一把 RabbitMQ,趕快行動吧!如果你在學習的過程中遇到了問題,歡迎隨時和我交流,雖然我也是個菜鳥,但我有熱情啊。

另外,如果你想寫入門級別的文章,這篇就是最好的範例。

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理

【其他文章推薦】

網頁設計公司推薦不同的風格,搶佔消費者視覺第一線

※廣告預算用在刀口上,台北網頁設計公司幫您達到更多曝光效益

※自行創業缺乏曝光? 網頁設計幫您第一時間規劃公司的形象門面

南投搬家公司費用需注意的眉眉角角,別等搬了再說!

新北清潔公司,居家、辦公、裝潢細清專業服務

※教你寫出一流的銷售文案?

Tensorflow2 自定義數據集圖片完成圖片分類任務

對於自定義數據集的圖片任務,通用流程一般分為以下幾個步驟:

  • Load data

  • Train-Val-Test

  • Build model

  • Transfer Learning

其中大部分精力會花在數據的準備和預處理上,本文用一種較為通用的數據處理手段,並通過手動構建,簡單模型, 層數較深的resnet網絡,和基於VGG19的遷移學習。

你可以通過這個例子,快速搭建網絡,並訓練處一個較為滿意的結果。

1. Load data

數據集來自Pokemon的5分類數據, 每一種的圖片數量為200多張,是一個較小型的數據集。

官方項目鏈接:

Keras and Convolutional Neural Networks (CNNs)

1.1 數據集介紹

Pokemon文件夾中包含5個子文件,其中每個子文件夾名為對應的類別名。文件夾中包含有png, jpeg的圖片文件。

1.2 解題思路

  • 由於文件夾中沒有劃分,訓練集和測試集,所以需要構建一個csv文件讀取所有的文件,及其類別

  • shuffle數據集以後,劃分Train_val_test

  • 對數據進行預處理, 數據標準化,數據增強, 可視化處理

“””python
# 創建数字編碼錶

  import os
  import glob
  import random
  import csv
  import tensorflow as tf
  from tensorflow import keras
  import matplotlib.pyplot as plt
  import time
  
  
  def load_csv(root, filename, name2label):
      """
      將分散在各文件夾中的圖片, 轉換為圖片和label對應的一個dataset文件, 格式為csv
      :param root: 文件路徑(每個子文件夾中的文件屬於一類)
      :param filename: 文件名
      :param name2label: 類名編碼錶  {'類名1':0, '類名2':1..}
      :return: images, labels
      """
      # 判斷是否csv文件已經生成
      if not os.path.exists(os.path.join(root, filename)):  # join-將路徑與文件名何為一個路徑並返回(沒有會生成新路徑)
          images = []  # 存的是文件路徑
          for name in name2label.keys():
              # pokemon\pikachu\00000001.png
              # glob.glob() 利用通配符檢索路徑內的文件,類似於正則表達式
              images += glob.glob(os.path.join(root, name, '*'))  # png, jpg, jpeg
          print(name2label)
          print(len(images), images)
  
          random.shuffle(images)
  
          with open(os.path.join(root, filename), 'w', newline='') as f:
              writer = csv.writer(f)
              for img in images:
                  name = img.split(os.sep)[1]  # os.sep 表示分隔符 window-'\\' , linux-'/'
                  label = name2label[name]  # 0, 1, 2..
                  # 'pokemon\\bulbasaur\\00000000.png', 0
                  writer.writerow([img, label])  # 如果不設定newline='', 2個數據會分為2行寫
              print('write into csv file:', filename)
  
      # 讀取現有文件
      images, labels = [], []
      with open(os.path.join(root, filename)) as f:
          reader = csv.reader(f)
          for row in reader:
              # 'pokemon\\bulbasaur\\00000000.png', 0
              img, label = row
              label = int(label)  # str-> int
              images.append(img)
              labels.append(label)
  
      assert len(images) == len(labels)
  
      return images, labels
  
  
  def load_pokemon(root, mode='train'):
      """
      # 創建数字編碼錶
      :param root: root path
      :param mode: train, valid, test
      :return: images, labels, name2label
      """
  
      name2label = {}  # {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
      for name in sorted(os.listdir(os.path.join(root))):
          # sorted() 是為了復現結果的一致性
          # os.listdir - 返迴路徑下的所有文件(文件夾,文件)列表
          if not os.path.isdir(os.path.join(root, name)):  # 是否為文件夾且是否存在
              continue
          # 每個類別編碼一個数字
          name2label[name] = len(name2label)
  
      # 讀取label
      images, labels = load_csv(root, 'images.csv', name2label)
  
      # 劃分數據集 [6:2:2]
      if mode == 'train':
          images = images[:int(0.6 * len(images))]
          labels = labels[:int(0.6 * len(labels))]  # len(images) == len(labels)
  
      elif mode == 'valid':
          images = images[int(0.6 * len(images)):int(0.8 * len(images))]
          labels = labels[int(0.6 * len(labels)):int(0.8 * len(labels))]
  
      else:
          images = images[int(0.8 * len(images)):]
          labels = labels[int(0.8 * len(labels)):]
  
      return images, labels, name2label
  
  
  # imagenet 數據集均值, 方差
  img_mean = tf.constant([0.485, 0.456, 0.406])  # 3 channel
  img_std = tf.constant([0.229, 0.224, 0.225])
  
  def normalization(x, mean=img_mean, std=img_std):
      # [224, 224, 3]
      x = (x - mean) / std
      return x
  
  def denormalization(x, mean=img_mean, std=img_std):
      x = x * std + mean
      return x
  
  
  def preprocess(x, y):
      # x: path, y: label
      x = tf.io.read_file(x)  # 2進制
      # x = tf.image.decode_image(x)
      x = tf.image.decode_jpeg(x, channels=3)  # RGBA
      x = tf.image.resize(x, [244, 244])
  
      # data augmentation
      # x = tf.image.random_flip_up_down(x)
      x = tf.image.random_flip_left_right(x)
      x = tf.image.random_crop(x, [224, 224, 3])  # 模型縮減比例不宜過大,否則會增大訓練難度
  
      x = tf.cast(x, dtype=tf.float32) / 255. # unit8 -> float32
      # U[0,1] -> N(0,1)  # 提高訓練準確度
      x = normalization(x)
  
      y = tf.convert_to_tensor(y)
  
      return x, y
  
  def main():
      images, labels, name2label = load_pokemon('pokemon', 'train')
      print('images:', len(images), images)
      print('labels:', len(labels), labels)
      # print(name2label)
  
      # .map()函數要位於.batch()之前, 否則 x=tf.io.read_file()會一次讀取一個batch的圖片,從而報錯
      db = tf.data.Dataset.from_tensor_slices((images, labels)).map(preprocess).shuffle(1000).batch(32)
  
      # tf.summary()
      # 提供了各類方法(支持各種多種格式)用於保存訓練過程中產生的數據(比如loss_value、accuracy、整個variable),
      # 這些數據以日誌文件的形式保存到指定的文件夾中。
  
      # 數據可視化:而tensorboard可以將tf.summary()
      # 記錄下來的日誌可視化,根據記錄的數據格式,生成折線圖、統計直方圖、圖片列表等多種圖。
      # tf.summary()
      # 通過遞增的方式更新日誌,這讓我們可以邊訓練邊使用tensorboard讀取日誌進行可視化,從而實時監控訓練過程。
      writer = tf.summary.create_file_writer('logs')
      for step, (x, y) in enumerate(db):
          with writer.as_default():
              x = denormalization(x)
              tf.summary.image('img', x, step=step, max_outputs=9)  # STEP:默認選項,指的是橫軸显示的是訓練迭代次數
  
              time.sleep(5)
  
  
  
  if __name__ == '__main__':
      main()

“””

2. 構建模型進行訓練

2.1 自定義小型網絡

由於數據集數量較少,大型網絡的訓練中往往會出現過擬合情況,這裏就定義了一個2層卷積的小型網絡。
引入early_stopping回調函數后,3個epoch沒有較大變化的情況下,模型訓練的準確率為0.8547

“””
# 1. 自定義小型網絡
model = keras.Sequential([
layers.Conv2D(16, 5, 3),
layers.MaxPool2D(3, 3),
layers.ReLU(),
layers.Conv2D(64, 5, 3),
layers.MaxPool2D(2, 2),
layers.ReLU(),
layers.Flatten(),
layers.Dense(64),
layers.ReLU(),
layers.Dense(5)
])

  model.build(input_shape=(None, 224, 224, 3))  
  model.summary()
  
  early_stopping = EarlyStopping(
      monitor='val_loss',
      patience=3,
      min_delta=0.001
  )
  
  
  model.compile(optimizer=optimizers.Adam(lr=1e-3),
                 loss=losses.CategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])
  model.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
             callbacks=[early_stopping])
  model.evaluate(db_test)

“””

2.2 自定義的Resnet網絡

resnet 網絡對於層次較深的網絡的可訓練型提升很大,主要是通過一個identity layer保證了深層次網絡的訓練效果不會弱於淺層網絡。
其他文章中有詳細介紹resnet的搭建,這裏就不做贅述, 這裏構建了一個resnet18網絡, 準確率0.7607。

“””
import os

  import numpy as np
  import tensorflow as tf
  from tensorflow import keras
  from tensorflow.keras import layers
  
  tf.random.set_seed(22)
  np.random.seed(22)
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  assert tf.__version__.startswith('2.')
  
  
  class ResnetBlock(keras.Model):
  
      def __init__(self, channels, strides=1):
          super(ResnetBlock, self).__init__()
  
          self.channels = channels
          self.strides = strides
  
          self.conv1 = layers.Conv2D(channels, 3, strides=strides,
                                     padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
          self.bn1 = keras.layers.BatchNormalization()
          self.conv2 = layers.Conv2D(channels, 3, strides=1,
                                     padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
          self.bn2 = keras.layers.BatchNormalization()
  
          if strides != 1:
              self.down_conv = layers.Conv2D(channels, 1, strides=strides, padding='valid')
              self.down_bn = tf.keras.layers.BatchNormalization()
  
      def call(self, inputs, training=None):
          residual = inputs
  
          x = self.conv1(inputs)
          x = tf.nn.relu(x)
          x = self.bn1(x, training=training)
          x = self.conv2(x)
          x = tf.nn.relu(x)
          x = self.bn2(x, training=training)
  
          # 殘差連接
          if self.strides != 1:
              residual = self.down_conv(inputs)
              residual = tf.nn.relu(residual)
              residual = self.down_bn(residual, training=training)
  
          x = x + residual
          x = tf.nn.relu(x)
          return x
  
  
  class ResNet(keras.Model):
  
      def __init__(self, num_classes, initial_filters=16, **kwargs):
          super(ResNet, self).__init__(**kwargs)
  
          self.stem = layers.Conv2D(initial_filters, 3, strides=3, padding='valid')
  
          self.blocks = keras.models.Sequential([
              ResnetBlock(initial_filters * 2, strides=3),
              ResnetBlock(initial_filters * 2, strides=1),
              # layers.Dropout(rate=0.5),
  
              ResnetBlock(initial_filters * 4, strides=3),
              ResnetBlock(initial_filters * 4, strides=1),
  
              ResnetBlock(initial_filters * 8, strides=2),
              ResnetBlock(initial_filters * 8, strides=1),
  
              ResnetBlock(initial_filters * 16, strides=2),
              ResnetBlock(initial_filters * 16, strides=1),
          ])
  
          self.final_bn = layers.BatchNormalization()
          self.avg_pool = layers.GlobalMaxPool2D()
          self.fc = layers.Dense(num_classes)
  
      def call(self, inputs, training=None):
          # print('x:',inputs.shape)
          out = self.stem(inputs, training = training)
          out = tf.nn.relu(out)
  
          # print('stem:',out.shape)
  
          out = self.blocks(out, training=training)
          # print('res:',out.shape)
  
          out = self.final_bn(out, training=training)
          # out = tf.nn.relu(out)
  
          out = self.avg_pool(out)
  
          # print('avg_pool:',out.shape)
          out = self.fc(out)
  
          # print('out:',out.shape)
  
          return out
  
  
  def main():
      num_classes = 5
  
      resnet18 = ResNet(5)
      resnet18.build(input_shape=(None, 224, 224, 3))
      resnet18.summary()
  
  
  if __name__ == '__main__':
      main()

“””

“””
# 2.resnet18訓練, 圖片數量較小,訓練結果不是特別好
# resnet = ResNet(5) # 0.7607
# resnet.build(input_shape=(None, 224, 224, 3))
# resnet.summary()
“””

2.3 VGG19遷移學習

遷移學習利用了數據集之間的相似性,對於數據集數量較少的時候,訓練效果會遠優於其他。
在訓練過程中,使用include_top=False, 去掉最後分類的基層Dense, 重新構建並訓練就可以了。準確率0.9316

“””
# 3. VGG19遷移學習,遷移學習利用數據集之間的相似性, 結果遠好於其他2種
# 為了方便,這裏仍然使用resnet命名
net = tf.keras.applications.VGG19(weights=’imagenet’, include_top=False, pooling=’max’ )
net.trainable = False
resnet = keras.Sequential([
net,
layers.Dense(5)
])
resnet.build(input_shape=(None, 224, 224, 3)) # 0.9316
resnet.summary()

  early_stopping = EarlyStopping(
      monitor='val_loss',
      patience=3,
      min_delta=0.001
  )
  
  
  resnet.compile(optimizer=optimizers.Adam(lr=1e-3),
                 loss=losses.CategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])
  resnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
             callbacks=[early_stopping])
  resnet.evaluate(db_test)

“””

附錄:

train_scratch.py 代碼

“””

import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping

tf.random.set_seed(22)
np.random.seed(22)
assert tf.__version__.startswith('2.')

# 設置GPU顯存按需分配
# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
#     try:
#         # Currently, memory growth needs to be the same across GPUs
#         for gpu in gpus:
#             tf.config.experimental.set_memory_growth(gpu, True)
#         logical_gpus = tf.config.experimental.list_logical_devices('GPU')
#         print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
#     except RuntimeError as e:
#         # Memory growth must be set before GPUs have been initialized
#         print(e)

from pokemon import load_pokemon, normalization
from resnet import ResNet


def preprocess(x, y):
    # x: 圖片的路徑,y:圖片的数字編碼
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3)  # RGBA
    # 圖片縮放
    # x = tf.image.resize(x, [244, 244])
    # 圖片旋轉
    # x = tf.image.rot90(x,2)
    # 隨機水平翻轉
    x = tf.image.random_flip_left_right(x)
    # 隨機豎直翻轉
    # x = tf.image.random_flip_up_down(x)

    # 圖片先縮放到稍大尺寸
    x = tf.image.resize(x, [244, 244])
    # 再隨機裁剪到合適尺寸
    x = tf.image.random_crop(x, [224, 224, 3])

    # x: [0,255]=> -1~1
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalization(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=5)

    return x, y


batchsz = 32

# create train db
images1, labels1, table = load_pokemon('pokemon', 'train')
db_train = tf.data.Dataset.from_tensor_slices((images1, labels1))
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
# create validation db
images2, labels2, table = load_pokemon('pokemon', 'valid')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
# create test db
images3, labels3, table = load_pokemon('pokemon', mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)


# 1. 自定義小型網絡
# resnet = keras.Sequential([
#     layers.Conv2D(16, 5, 3),
#     layers.MaxPool2D(3, 3),
#     layers.ReLU(),
#     layers.Conv2D(64, 5, 3),
#     layers.MaxPool2D(2, 2),
#     layers.ReLU(),
#     layers.Flatten(),
#     layers.Dense(64),
#     layers.ReLU(),
#     layers.Dense(5)
# ])  # 0.8547


# 2.resnet18訓練, 圖片數量較小,訓練結果不是特別好
# resnet = ResNet(5)  # 0.7607
# resnet.build(input_shape=(None, 224, 224, 3))
# resnet.summary()


# 3. VGG19遷移學習,遷移學習利用數據集之間的相似性, 結果遠好於其他2種
net = tf.keras.applications.VGG19(weights='imagenet', include_top=False, pooling='max' )
net.trainable = False
resnet = keras.Sequential([
    net,
    layers.Dense(5)
])
resnet.build(input_shape=(None, 224, 224, 3))   # 0.9316
resnet.summary()

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=3,
    min_delta=0.001
)


resnet.compile(optimizer=optimizers.Adam(lr=1e-3),
               loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])
resnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
           callbacks=[early_stopping])
resnet.evaluate(db_test)

“””

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理

【其他文章推薦】

※帶您來了解什麼是 USB CONNECTOR  ?

※自行創業缺乏曝光? 網頁設計幫您第一時間規劃公司的形象門面

※如何讓商品強力曝光呢? 網頁設計公司幫您建置最吸引人的網站,提高曝光率!

※綠能、環保無空污,成為電動車最新代名詞,目前市場使用率逐漸普及化

※廣告預算用在刀口上,台北網頁設計公司幫您達到更多曝光效益

※教你寫出一流的銷售文案?

這篇文章,我們來談一談Spring中的屬性注入

本系列文章:

讀源碼,我們可以從第一行讀起

你知道Spring是怎麼解析配置類的嗎?

配置類為什麼要添加@Configuration註解?

談談Spring中的對象跟Bean,你知道Spring怎麼創建對象的嗎?

推薦閱讀:

Spring官網閱讀 | 總結篇

Spring雜談

本系列文章將會帶你一行行的將Spring的源碼吃透,推薦閱讀的文章是閱讀源碼的基礎!

前言

在前面的文章中已經知道了Spring是如何將一個對象創建出來的,那麼緊接着,Spring就需要將這個對象變成一個真正的Bean了,這個過程主要分為兩步

  1. 屬性注入
  2. 初始化

在這兩個過程中,Bean的後置處理器會穿插執行,其中有些後置處理器是為了幫助完成屬性注入或者初始化的,而有些後置處理器是Spring提供給程序員進行擴展的,當然,這二者並不衝突。整個Spring創建對象並將對象變成Bean的過程就是我們經常提到了Spring中Bean的生命周期。當然,本系列源碼分析的文章不會再對生命周期的概念做過多闡述了,如果大家有這方面的需求的話可以參考我之前的文章,或者關注我的公眾號:程序員DMZ

Spring官網閱讀(九)Spring中Bean的生命周期(上)

Spring官網閱讀(十)Spring中Bean的生命周期(下)

源碼分析

閑話不再多說,我們正式進入源碼分析階段,本文重點要分析的方法就是org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory#doCreateBean,其源碼如下:

doCreateBean

	protected Object doCreateBean(final String beanName, final RootBeanDefinition mbd, final @Nullable Object[] args)
			throws BeanCreationException {

		// 創建對象的過程在上篇文章中我們已經介紹過了,這裏不再贅述
		BeanWrapper instanceWrapper = null;
		if (mbd.isSingleton()) {
			instanceWrapper = this.factoryBeanInstanceCache.remove(beanName);
		}
		if (instanceWrapper == null) {
			instanceWrapper = createBeanInstance(beanName, mbd, args);
		}
        
        // 獲取到創建的這個對象
		final Object bean = instanceWrapper.getWrappedInstance();
		Class<?> beanType = instanceWrapper.getWrappedClass();
		if (beanType != NullBean.class) {
			mbd.resolvedTargetType = beanType;
		}

		// Allow post-processors to modify the merged bean definition.
        // 按照官方的註釋來說,這個地方是Spring提供的一個擴展點,對程序員而言,我們可以通過一個實現了MergedBeanDefinitionPostProcessor的後置處理器來修改bd中的屬性,從而影響到後續的Bean的生命周期
        // 不過官方自己實現的後置處理器並沒有去修改bd,而是調用了applyMergedBeanDefinitionPostProcessors方法
        // 這個方法名直譯過來就是-應用合併后的bd,也就是說它這裏只是對bd做了進一步的使用而沒有真正的修改
		synchronized (mbd.postProcessingLock) {
           // bd只允許被處理一次
			if (!mbd.postProcessed) {
				try {
                    // 應用合併后的bd
					applyMergedBeanDefinitionPostProcessors(mbd, beanType, beanName);
				}
				catch (Throwable ex) {
					throw new BeanCreationException(mbd.getResourceDescription(), beanName,
							"Post-processing of merged bean definition failed", ex);
				}
                // 標註這個bd已經被MergedBeanDefinitionPostProcessor的後置處理器處理過
                // 那麼在第二次創建Bean的時候,不會再次調用applyMergedBeanDefinitionPostProcessors
				mbd.postProcessed = true;
			}
		}

		// 這裡是用來出來循環依賴的,關於循環以來,在介紹完正常的Bean的創建后,單獨用一篇文章說明
        // 這裏不做過多解釋
		boolean earlySingletonExposure = (mbd.isSingleton() && this.allowCircularReferences &&
				isSingletonCurrentlyInCreation(beanName));
		if (earlySingletonExposure) {
			if (logger.isTraceEnabled()) {
				logger.trace("Eagerly caching bean '" + beanName +
						"' to allow for resolving potential circular references");
			}
			addSingletonFactory(beanName, () -> getEarlyBeanReference(beanName, mbd, bean));
		}


		Object exposedObject = bean;
		try {
            // 我們這篇文章重點要分析的就是populateBean方法,在這個方法中完成了屬性注入
			populateBean(beanName, mbd, instanceWrapper);
            // 初始化
			exposedObject = initializeBean(beanName, exposedObject, mbd);
		}
		catch (Throwable ex) {
			// 省略異常代碼
		}

		// 後續代碼不在本文探討範圍內了,暫不考慮

		return exposedObject;
	}

applyMergedBeanDefinitionPostProcessors

源碼如下:

// 可以看到這個方法的代碼還是很簡單的,就是調用了MergedBeanDefinitionPostProcessor的postProcessMergedBeanDefinition方法
protected void applyMergedBeanDefinitionPostProcessors(RootBeanDefinition mbd, Class<?> beanType, String beanName) {
    for (BeanPostProcessor bp : getBeanPostProcessors()) {
        if (bp instanceof MergedBeanDefinitionPostProcessor) {
            MergedBeanDefinitionPostProcessor bdp = (MergedBeanDefinitionPostProcessor) bp;
            bdp.postProcessMergedBeanDefinition(mbd, beanType, beanName);
        }
    }
}

這個時候我們就要思考一個問題,容器中現在有哪些後置處理器是MergedBeanDefinitionPostProcessor呢?

查看這個方法的實現類我們會發現總共就這麼幾個類實現了MergedBeanDefinitionPostProcessor接口。實際上除了ApplicationListenerDetector之外,其餘的後置處理器的邏輯都差不多。我們在這裏我們主要就分析兩個後置處理

  1. ApplicationListenerDetector
  2. AutowiredAnnotationBeanPostProcessor

ApplicationListenerDetector

首先,我們來ApplicationListenerDetector,這個類在之前的文章中也多次提到過了,它的作用是用來處理嵌套Bean的情況,主要是保證能將嵌套在Bean標籤中的ApplicationListener也能添加到容器的監聽器集合中去。我們先通過一個例子來感受下這個後置處理器的作用吧

配置文件:

<?xml version="1.0" encoding="UTF-8"?>
<beans xmlns="http://www.springframework.org/schema/beans"
	   xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	   xsi:schemaLocation="http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd">

	<bean class="com.dmz.source.populate.service.DmzService" id="dmzService">
		<constructor-arg name="orderService">
			<bean class="com.dmz.source.populate.service.OrderService"/>
		</constructor-arg>
	</bean>
</beans>

示例代碼:

// 事件
public class DmzEvent extends ApplicationEvent {
	public DmzEvent(Object source) {
		super(source);
	}
}

public class DmzService {

	OrderService orderService;

	public DmzService(OrderService orderService) {
		this.orderService = orderService;
	}
}
// 實現ApplicationListener接口
public class OrderService implements ApplicationListener<DmzEvent> {
	@Override
	public void onApplicationEvent(DmzEvent event) {
		System.out.println(event.getSource());
	}
}

public class Main {
	public static void main(String[] args) {
		ClassPathXmlApplicationContext cc = new ClassPathXmlApplicationContext("application-populate.xml");
		cc.publishEvent(new DmzEvent("my name is dmz"));
	}
}

// 程序運行結果,控制台打印:my name is dmz

說明OrderService已經被添加到了容器的監聽器集合中。但是請注意,在這種情況下,如果要使OrderService能夠執行監聽的邏輯,必須要滿足下面這兩個條件

  • 外部的Bean要是單例的,對於我們的例子而言就是dmzService
  • 內嵌的Bean也必須是單例的,在上面的例子中也就是orderService必須是單例

另外需要注意的是,這種嵌套的Bean比較特殊,它雖然由Spring創建,但是確不存在於容器中,就是說我們不能將其作為依賴注入到別的Bean中。

AutowiredAnnotationBeanPostProcessor

對應源碼如下:

public void postProcessMergedBeanDefinition(RootBeanDefinition beanDefinition, Class<?> beanType, String beanName) {
    // 找到注入的元數據,第一次是構建,後續可以直接從緩存中拿
    // 註解元數據其實就是當前這個類中的所有需要進行注入的“點”的集合,
    // 注入點(InjectedElement)包含兩種,字段/方法
    // 對應的就是AutowiredFieldElement/AutowiredMethodElement
    InjectionMetadata metadata = findAutowiringMetadata(beanName, beanType, null);
    // 排除掉被外部管理的注入點
    metadata.checkConfigMembers(beanDefinition);
}

上面代碼的核心邏輯就是

  • 找到所有的注入點,其實就是被@Autowired註解修飾的方法以及字段,同時靜態的方法以及字段也會被排除
  • 排除掉被外部管理的注入點,在後續的源碼分析中我們再細說

findAutowiringMetadata

// 這個方法的核心邏輯就是先從緩存中獲取已經解析好的注入點信息,很明顯,在原型情況下才會使用緩存
// 創建注入點的核心邏輯在buildAutowiringMetadata方法中
private InjectionMetadata findAutowiringMetadata(String beanName, Class<?> clazz, @Nullable PropertyValues pvs) {
    String cacheKey = (StringUtils.hasLength(beanName) ? beanName : clazz.getName());
    InjectionMetadata metadata = this.injectionMetadataCache.get(cacheKey);
    // 可能我們會修改bd中的class屬性,那麼InjectionMetadata中的注入點信息也需要刷新
    if (InjectionMetadata.needsRefresh(metadata, clazz)) {
        synchronized (this.injectionMetadataCache) {
            metadata = this.injectionMetadataCache.get(cacheKey);
            if (InjectionMetadata.needsRefresh(metadata, clazz)) {
                if (metadata != null) {
                    metadata.clear(pvs);
                }
                // 這裏真正創建注入點
                metadata = buildAutowiringMetadata(clazz);
                this.injectionMetadataCache.put(cacheKey, metadata);
            }
        }
    }
    return metadata;
}

buildAutowiringMetadata

// 我們應用中使用@Autowired註解標註在字段上或者setter方法能夠完成屬性注入
// 就是因為這個方法將@Autowired註解標註的方法以及字段封裝成InjectionMetadata
// 在後續階段會調用InjectionMetadata的inject方法進行注入
private InjectionMetadata buildAutowiringMetadata(final Class<?> clazz) {
    List<InjectionMetadata.InjectedElement> elements = new ArrayList<>();
    Class<?> targetClass = clazz;

    do {
        final List<InjectionMetadata.InjectedElement> currElements = new ArrayList<>();
		// 處理所有的被@AutoWired/@Value註解標註的字段
        ReflectionUtils.doWithLocalFields(targetClass, field -> {
            AnnotationAttributes ann = findAutowiredAnnotation(field);
            if (ann != null) {
                // 靜態字段會直接跳過
                if (Modifier.isStatic(field.getModifiers())) {
                    // 省略日誌打印
                    return;
                }
                // 得到@AutoWired註解中的required屬性
                boolean required = determineRequiredStatus(ann);
                currElements.add(new AutowiredFieldElement(field, required));
            }
        });
		// 處理所有的被@AutoWired註解標註的方法,相對於字段而言,這裏需要對橋接方法進行特殊處理
        ReflectionUtils.doWithLocalMethods(targetClass, method -> {
            // 只處理一種特殊的橋接場景,其餘的橋接方法都會被忽略
            Method bridgedMethod = BridgeMethodResolver.findBridgedMethod(method);
            if (!BridgeMethodResolver.isVisibilityBridgeMethodPair(method, bridgedMethod)) {
                return;
            }
            AnnotationAttributes ann = findAutowiredAnnotation(bridgedMethod);
            // 處理方法時需要注意,當父類中的方法被子類重寫時,如果子父類中的方法都加了@Autowired
            // 那麼此時父類方法不能被處理,即不能被封裝成一個AutowiredMethodElement
            if (ann != null && method.equals(ClassUtils.getMostSpecificMethod(method, clazz))) {
                if (Modifier.isStatic(method.getModifiers())) {
                    // 省略日誌打印
                    return;
                }
                if (method.getParameterCount() == 0) {
                    // 當方法的參數數量為0時,雖然不需要進行注入,但是還是會把這個方法作為注入點使用
                    // 這個方法最終還是會被調用
                    if (logger.isInfoEnabled()) {
                        logger.info("Autowired annotation should only be used on methods with parameters: " +
                                    method);
                    }
                }
                boolean required = determineRequiredStatus(ann);
                // PropertyDescriptor: 屬性描述符
                // 就是通過解析getter/setter方法,例如void getA()會解析得到一個屬性名稱為a
                // readMethod為getA的PropertyDescriptor,
                // 在《Spring官網閱讀(十四)Spring中的BeanWrapper及類型轉換》文中已經做過解釋
                // 這裏不再贅述,這裏之所以來這麼一次查找是因為當XML中對這個屬性進行了配置后,
                // 那麼就不會進行自動注入了,XML中显示指定的屬性優先級高於註解
                PropertyDescriptor pd = BeanUtils.findPropertyForMethod(bridgedMethod, clazz);		   // 構造一個對應的AutowiredMethodElement,後續這個方法會被執行
                // 方法的參數會被自動注入,這裏不限於setter方法
                currElements.add(new AutowiredMethodElement(method, required, pd));
            }
        });
		// 會處理父類中字段上及方法上的@AutoWired註解,並且父類的優先級比子類高
        elements.addAll(0, currElements);
        targetClass = targetClass.getSuperclass();
    }
    while (targetClass != null && targetClass != Object.class);

    return new InjectionMetadata(clazz, elements);
}
難點代碼分析

上面的代碼整體來說應該很簡單,就如我們之前所說的,處理帶有@Autowired註解的字段及方法,同時會過濾掉所有的靜態字段及方法。上面複雜的地方在於對橋接方法的處理,可能大部分人都沒辦法理解這幾行代碼:

// 第一行
Method bridgedMethod = BridgeMethodResolver.findBridgedMethod(method);

// 第二行
if (!BridgeMethodResolver.isVisibilityBridgeMethodPair(method, bridgedMethod)) {
    return;
}

// 第三行
if (ann != null && method.equals(ClassUtils.getMostSpecificMethod(method, clazz))) {

}

要理解這些代碼,首先你得知道什麼是橋接,為此我已經寫好了一篇文章:

Spring雜談 | 從橋接方法到JVM方法調用

除了在上面的文章中提到的橋接方法外,還有一種特殊的情況

// A類跟B類在同一個包下,A不是public的
class A {
	public void test(){

	}
}

// 在B中會生成一個跟A中的方法描述符(參數+返回值)一模一樣的橋接方法
// 這個橋接方法實際上就是調用父類中的方法
// 具體可以參考:https://bugs.java.com/bugdatabase/view_bug.do?bug_id=63424113
public class B extends A {
}

在理解了什麼是橋接之後,那麼上邊的第一行代碼你應該就能看懂了,就以上面的代碼為例,B中會生成一個橋接方法,對應的被橋接的方法就是A中的test方法。

接着,我們看看第二行代碼

public static boolean isVisibilityBridgeMethodPair(Method bridgeMethod, Method bridgedMethod) {
    // 說明這個方法本身就不是橋接方法,直接返回true
    if (bridgeMethod == bridgedMethod) {
        return true;
    }
    // 說明是橋接方法,並且方法描述符一致
    // 當且僅當是上面例子中描述的這種橋接的時候這個判斷才會滿足
    // 正常來說橋接方法跟被橋接方法的返回值+參數類型肯定不一致
    // 所以這個判斷會過濾掉其餘的所有類型的橋接方法
    // 只會保留本文提及這種特殊情況下產生的橋接方法
    return (bridgeMethod.getReturnType().equals(bridgedMethod.getReturnType()) &&
            Arrays.equals(bridgeMethod.getParameterTypes(), bridgedMethod.getParameterTypes()));
}

最後,再來看看第三行代碼,核心就是這句 method.equals(ClassUtils.getMostSpecificMethod(method, clazz)。這句代碼的主要目的就是為了處理下面這種情況

@Component
public class D extends C {

	@Autowired
	@Override
	public void setDmzService(DmzService dmzService) {
		dmzService.init();
		this.dmzService = dmzService;
	}
}

// C不是Spring中的組件
public class C {
	DmzService dmzService;
    @Autowired
	public void setDmzService(DmzService dmzService) {
		this.dmzService = dmzService;
	}
}

這種情況下,在處理D中的@Autowired註解時,雖然我們要處理父類中的@Autowired註解,但是因為子類中的方法已經複寫了父類中的方法,所以此時應該要跳過父類中的這個被複寫的方法,這就是第三行代碼的作用。

小結

到這裏我們主要分析了applyMergedBeanDefinitionPostProcessors這段代碼的作用,它的執行時機是在創建對象之後,屬性注入之前。按照官方的定義來說,到這裏我們仍然可以使用這個方法來修改bd的定義,那麼相對於通過BeanFactoryPostProcessor的方式修改bd,applyMergedBeanDefinitionPostProcessors這個方法影響的範圍更小,BeanFactoryPostProcessor影響的是整個Bean的生命周期,而applyMergedBeanDefinitionPostProcessors只會影響屬性注入之後的生命周期。

其次,我們分析了Spring中內置的MergedBeanDefinitionPostProcessor,選取了其中兩個特殊的後置處理器進行分析,其中ApplicationListenerDetector主要處理內嵌的事件監聽器,而AutowiredAnnotationBeanPostProcessor主要用於處理@Autowired註解,實際上我們會發現,到這裏還只是完成了@Autowired註解的解析,還沒有真正開始進行注入,真正注入的邏輯在後面我們要分析的populateBean方法中,在這個方法中會使用解析好的注入元信息完成真正的屬性注入,那麼接下來我們就開始分析populateBean這個方法的源碼。

populateBean

循環依賴的代碼我們暫且跳過,後續出一篇專門文章解讀循環依賴,我們直接看看populateBean到底做了什麼。

protected void populateBean(String beanName, RootBeanDefinition mbd, @Nullable BeanWrapper bw) {

    // 處理空實例
    if (bw == null) {
        // 如果創建的對象為空,但是在XML中又配置了需要注入的屬性的話,那麼直接報錯
        if (mbd.hasPropertyValues()) {
            throw new BeanCreationException(
                mbd.getResourceDescription(), beanName, "Cannot apply property values to null instance");
        }
        else {
            // 空對象,不進行屬性注入
            return;
        }
    }

    // 滿足兩個條件,不是合成類 && 存在InstantiationAwareBeanPostProcessor
    // 其中InstantiationAwareBeanPostProcessor主要作用就是作為Bean的實例化前後的鈎子
    // 外加完成屬性注入,對於三個方法就是
    // postProcessBeforeInstantiation  創建對象前調用
    // postProcessAfterInstantiation   對象創建完成,@AutoWired註解解析后調用   
    // postProcessPropertyValues(已過期,被postProcessProperties替代) 進行屬性注入
    // 下面這段代碼的主要作用就是我們可以提供一個InstantiationAwareBeanPostProcessor
    // 提供的這個後置處理如果實現了postProcessAfterInstantiation方法並且返回false
    // 那麼可以跳過Spring默認的屬性注入,但是這也意味着我們要自己去實現屬性注入的邏輯
    // 所以一般情況下,我們也不會這麼去擴展
    if (!mbd.isSynthetic() && hasInstantiationAwareBeanPostProcessors()) {
        for (BeanPostProcessor bp : getBeanPostProcessors()) {
            if (bp instanceof InstantiationAwareBeanPostProcessor) {
                InstantiationAwareBeanPostProcessor ibp = (InstantiationAwareBeanPostProcessor) bp;
                if (!ibp.postProcessAfterInstantiation(bw.getWrappedInstance(), beanName)) {
                    return;
                }
            }
        }
    }
	
    // 這裏其實就是判斷XML是否提供了屬性相關配置
    PropertyValues pvs = (mbd.hasPropertyValues() ? mbd.getPropertyValues() : null);
	
    // 確認注入模型
    int resolvedAutowireMode = mbd.getResolvedAutowireMode();
    
    // 主要處理byName跟byType兩種注入模型,byConstructor這種注入模型在創建對象的時候已經處理過了
    // 這裏都是對自動注入進行處理,byName跟byType兩種注入模型均是依賴setter方法
    // byName,根據setter方法的名字來查找對應的依賴,例如setA,那麼就是去容器中查找名字為a的Bean
    // byType,根據setter方法的參數類型來查找對應的依賴,例如setXx(A a),就是去容器中查詢類型為A的bean
    if (resolvedAutowireMode == AUTOWIRE_BY_NAME || resolvedAutowireMode == AUTOWIRE_BY_TYPE) {
        MutablePropertyValues newPvs = new MutablePropertyValues(pvs);
        if (resolvedAutowireMode == AUTOWIRE_BY_NAME) {
            autowireByName(beanName, mbd, bw, newPvs);
        }
        if (resolvedAutowireMode == AUTOWIRE_BY_TYPE) {
            autowireByType(beanName, mbd, bw, newPvs);
        }
        // pvs是XML定義的屬性
        // 自動注入后,bean實際用到的屬性就應該要替換成自動注入后的屬性
        pvs = newPvs;
    }
	// 檢查是否有InstantiationAwareBeanPostProcessor
    // 前面說過了,這個後置處理器就是來完成屬性注入的
    boolean hasInstAwareBpps = hasInstantiationAwareBeanPostProcessors();
    
    //  是否需要依賴檢查,默認是不會進行依賴檢查的
    boolean needsDepCheck = (mbd.getDependencyCheck() != AbstractBeanDefinition.DEPENDENCY_CHECK_NONE);
	
    // 下面這段代碼有點麻煩了,因為涉及到版本問題
    // 其核心代碼就是調用了postProcessProperties完成了屬性注入
   
    PropertyDescriptor[] filteredPds = null;
    
    // 存在InstantiationAwareBeanPostProcessor,我們需要調用這類後置處理器的方法進行注入
		if (hasInstAwareBpps) {
			if (pvs == null) {
				pvs = mbd.getPropertyValues();
			}
			for (BeanPostProcessor bp : getBeanPostProcessors()) {
				if (bp instanceof InstantiationAwareBeanPostProcessor) {
					InstantiationAwareBeanPostProcessor ibp = (InstantiationAwareBeanPostProcessor) bp;
                    // 這句就是核心
					PropertyValues pvsToUse = ibp.postProcessProperties(pvs, bw.getWrappedInstance(), beanName);
					if (pvsToUse == null) {
						if (filteredPds == null) {
                            // 得到需要進行依賴檢查的屬性的集合
							filteredPds = filterPropertyDescriptorsForDependencyCheck(bw, mbd.allowCaching);
						}
                        //  這個方法已經過時了,放到這裏就是為了兼容老版本
						pvsToUse = ibp.postProcessPropertyValues(pvs, filteredPds, bw.getWrappedInstance(), beanName);
						if (pvsToUse == null) {
							return;
						}
					}
					pvs = pvsToUse;
				}
			}
		}
    // 需要進行依賴檢查
		if (needsDepCheck) {
			if (filteredPds == null) {
                // 得到需要進行依賴檢查的屬性的集合
				filteredPds = filterPropertyDescriptorsForDependencyCheck(bw, mbd.allowCaching);
			}
            // 對需要進行依賴檢查的屬性進行依賴檢查
			checkDependencies(beanName, mbd, filteredPds, pvs);
		}
    // 將XML中的配置屬性應用到Bean上
		if (pvs != null) {
			applyPropertyValues(beanName, mbd, bw, pvs);
		}
}

上面這段代碼主要可以拆分為三個部分

  1. 處理自動注入
  2. 處理屬性注入(主要指處理@Autowired註解),最重要
  3. 處理依賴檢查

處理自動注入

autowireByName

對應源碼如下:

protected void autowireByName(
    String beanName, AbstractBeanDefinition mbd, BeanWrapper bw, MutablePropertyValues pvs) {
    // 得到符合下麵條件的屬性名稱
    // 1.有setter方法
    // 2.需要進行依賴檢查
    // 3.不包含在XML配置中
    // 4.不是簡單類型(基本數據類型,枚舉,日期等)
    // 這裏可以看到XML配置優先級高於自動注入的優先級
    // 不進行依賴檢查的屬性,也不會進行屬性注入
    String[] propertyNames = unsatisfiedNonSimpleProperties(mbd, bw);
    for (String propertyName : propertyNames) {
        if (containsBean(propertyName)) {
            Object bean = getBean(propertyName);
            // 將自動注入的屬性添加到pvs中去
            pvs.add(propertyName, bean);
            // 註冊bean之間的依賴關係
            registerDependentBean(propertyName, beanName);
            // 忽略日誌
        }
        // 忽略日誌
    }
}

看到了嗎?代碼就是這麼的簡單,不是要通過名稱注入嗎?直接通過beanName調用getBean,完事兒

autowireByType

	protected void autowireByType(
			String beanName, AbstractBeanDefinition mbd, BeanWrapper bw, MutablePropertyValues pvs) {
		// 這個類型轉換器,主要是在處理@Value時需要使用
		TypeConverter converter = getCustomTypeConverter();
		if (converter == null) {
			converter = bw;
		}

		Set<String> autowiredBeanNames = new LinkedHashSet<>(4);
		// 得到符合下麵條件的屬性名稱
		// 1.有setter方法
		// 2.需要進行依賴檢查
		// 3.不包含在XML配置中
		// 4.不是簡單類型(基本數據類型,枚舉,日期等)
		// 這裏可以看到XML配置優先級高於自動注入的優先級
		String[] propertyNames = unsatisfiedNonSimpleProperties(mbd, bw);
		for (String propertyName : propertyNames) {
			try {
				PropertyDescriptor pd = bw.getPropertyDescriptor(propertyName);
				if (Object.class != pd.getPropertyType()) {
					// 這裏獲取到的就是setter方法的參數,因為我們需要按照類型進行注入嘛
					MethodParameter methodParam = BeanUtils.getWriteMethodParameter(pd);
					
                    // 如果是PriorityOrdered在進行類型匹配時不會去匹配factoryBean
					// 如果不是PriorityOrdered,那麼在查找對應類型的依賴的時候會會去匹factoryBean
				 	// 這就是Spring的一種設計理念,實現了PriorityOrdered接口的Bean被認為是一種
                    // 最高優先級的Bean,這一類的Bean在進行為了完成裝配而去檢查類型時,
                    // 不去檢查factoryBean
                    // 具體可以參考PriorityOrdered接口上的註釋文檔
					boolean eager = !(bw.getWrappedInstance() instanceof PriorityOrdered);
					// 將參數封裝成為一個依賴描述符
					// 依賴描述符會通過:依賴所在的類,字段名/方法名,依賴的具體類型等來描述這個依賴
					DependencyDescriptor desc = new AutowireByTypeDependencyDescriptor(methodParam, eager);
					// 解析依賴,這裡會處理@Value註解
                    // 另外,通過指定的類型到容器中查找對應的bean
					Object autowiredArgument = resolveDependency(desc, beanName, autowiredBeanNames, converter);
					if (autowiredArgument != null) {
						// 將查找出來的依賴屬性添加到pvs中,後面會將這個pvs應用到bean上
						pvs.add(propertyName, autowiredArgument);
					}
					// 註冊bean直接的依賴關係
					for (String autowiredBeanName : autowiredBeanNames) {
						registerDependentBean(autowiredBeanName, beanName);
						if (logger.isDebugEnabled()) {
							logger.debug("Autowiring by type from bean name '" + beanName + "' via property '" +
									propertyName + "' to bean named '" + autowiredBeanName + "'");
						}
					}
					autowiredBeanNames.clear();
				}
			}
			catch (BeansException ex) {
				throw new UnsatisfiedDependencyException(mbd.getResourceDescription(), beanName, propertyName, ex);
			}
		}
	}

resolveDependency

這個方法在Spring雜談 | 什麼是ObjectFactory?什麼是ObjectProvider?已經做過分析了,本文不再贅述。

可以看到,真正做事的方法是doResolveDependency

@Override
public Object resolveDependency(DependencyDescriptor descriptor, String requestingBeanName, Set<String> autowiredBeanNames, @Nullable TypeConverter typeConverter) throws BeansException {
	// descriptor代表當前需要注入的那個字段,或者方法的參數,也就是注入點
    // ParameterNameDiscovery用於解析方法參數名稱
    descriptor.initParameterNameDiscovery(getParameterNameDiscoverer());
    // 1. Optional<T>
    if (Optional.class == descriptor.getDependencyType()) {
        return createOptionalDependency(descriptor, requestingBeanName);
    // 2. ObjectFactory<T>、ObjectProvider<T>
    } else if (ObjectFactory.class == descriptor.getDependencyType() ||
             ObjectProvider.class == descriptor.getDependencyType()) {
        return new DependencyObjectProvider(descriptor, requestingBeanName);
    // 3. javax.inject.Provider<T>
    } else if (javaxInjectProviderClass == descriptor.getDependencyType()) {
        return new Jsr330Factory().createDependencyProvider(descriptor, requestingBeanName);
    } else {
        // 4. @Lazy
        Object result = getAutowireCandidateResolver().getLazyResolutionProxyIfNecessary(
            descriptor, requestingBeanName);
        // 5. 正常情況
        if (result == null) {
            result = doResolveDependency(descriptor, requestingBeanName, autowiredBeanNames, typeConverter);
        }
        return result;
    }
}
doResolveDependency
	public Object doResolveDependency(DependencyDescriptor descriptor, @Nullable String beanName,
			@Nullable Set<String> autowiredBeanNames, @Nullable TypeConverter typeConverter) throws BeansException {

		InjectionPoint previousInjectionPoint = ConstructorResolver.setCurrentInjectionPoint(descriptor);
		try {
			Object shortcut = descriptor.resolveShortcut(this);
			if (shortcut != null) {
				return shortcut;
			}
			// 依賴的具體類型
			Class<?> type = descriptor.getDependencyType();
			// 處理@Value註解,這裏得到的時候@Value中的值
			Object value = getAutowireCandidateResolver().getSuggestedValue(descriptor);
			if (value != null) {
				if (value instanceof String) {
					// 解析@Value中的佔位符
					String strVal = resolveEmbeddedValue((String) value);
					// 獲取到對應的bd
					BeanDefinition bd = (beanName != null && containsBean(beanName) ? getMergedBeanDefinition(beanName) : null);
					// 處理EL表達式
					value = evaluateBeanDefinitionString(strVal, bd);
				}
				// 通過解析el表達式可能還需要進行類型轉換
				TypeConverter converter = (typeConverter != null ? typeConverter : getTypeConverter());
				return (descriptor.getField() != null ?
						converter.convertIfNecessary(value, type, descriptor.getField()) :
						converter.convertIfNecessary(value, type, descriptor.getMethodParameter()));
			}
			
            // 對map,collection,數組類型的依賴進行處理
			// 最終會根據集合中的元素類型,調用findAutowireCandidates方法
			Object multipleBeans = resolveMultipleBeans(descriptor, beanName, autowiredBeanNames, typeConverter);
			if (multipleBeans != null) {
				return multipleBeans;
			}
			
            // 根據指定類型可能會找到多個bean
            // 這裏返回的既有可能是對象,也有可能是對象的類型
            // 這是因為到這裏還不能明確的確定當前bean到底依賴的是哪一個bean
            // 所以如果只會返回這個依賴的類型以及對應名稱,最後還需要調用getBean(beanName)
            // 去創建這個Bean
			Map<String, Object> matchingBeans = findAutowireCandidates(beanName, type, descriptor);
			// 一個都沒找到,直接拋出異常
			if (matchingBeans.isEmpty()) {
				if (isRequired(descriptor)) {
					raiseNoMatchingBeanFound(type, descriptor.getResolvableType(), descriptor);
				}
				return null;
			}

			String autowiredBeanName;
			Object instanceCandidate;
			// 通過類型找到了多個
			if (matchingBeans.size() > 1) {
				// 根據是否是主Bean
				// 是否是最高優先級的Bean
				// 是否是名稱匹配的Bean
				// 來確定具體的需要注入的Bean的名稱
                // 到這裏可以知道,Spring在查找依賴的時候遵循先類型再名稱的原則(沒有@Qualifier註解情況下)
				autowiredBeanName = determineAutowireCandidate(matchingBeans, descriptor);
				if (autowiredBeanName == null) {
					// 無法推斷出具體的名稱
					// 如果依賴是必須的,直接拋出異常
					// 如果依賴不是必須的,但是這個依賴類型不是集合或者數組,那麼也拋出異常
					if (isRequired(descriptor) || !indicatesMultipleBeans(type)) {
						return descriptor.resolveNotUnique(type, matchingBeans);
					}
					// 依賴不是必須的,但是依賴類型是集合或者數組,那麼返回一個null
					else {
						return null;
					}
				}
				instanceCandidate = matchingBeans.get(autowiredBeanName);
			}
			else {
				// 直接找到了一個對應的Bean
				Map.Entry<String, Object> entry = matchingBeans.entrySet().iterator().next();
				autowiredBeanName = entry.getKey();
				instanceCandidate = entry.getValue();
			}
			if (autowiredBeanNames != null) {
				autowiredBeanNames.add(autowiredBeanName);
			}
            
            // 前面已經說過了,這裏可能返回的是Bean的類型,所以需要進一步調用getBean
			if (instanceCandidate instanceof Class) {
				instanceCandidate = descriptor.resolveCandidate(autowiredBeanName, type, this);
			}
            
            // 做一些檢查,如果依賴是必須的,查找出來的依賴是一個null,那麼報錯
            // 查詢處理的依賴類型不符合,也報錯
			Object result = instanceCandidate;
			if (result instanceof NullBean) {
				if (isRequired(descriptor)) {
					raiseNoMatchingBeanFound(type, descriptor.getResolvableType(), descriptor);
				}
				result = null;
			}
			if (!ClassUtils.isAssignableValue(type, result)) {
				throw new BeanNotOfRequiredTypeException(autowiredBeanName, type, instanceCandidate.getClass());
			}
			return result;
		}
		finally {
			ConstructorResolver.setCurrentInjectionPoint(previousInjectionPoint);
		}
	}
findAutowireCandidates
protected Map<String, Object> findAutowireCandidates(
    @Nullable String beanName, Class<?> requiredType, DependencyDescriptor descriptor) {
	
    // 簡單來說,這裏就是到容器中查詢requiredType類型的所有bean的名稱的集合
    // 這裡會根據descriptor.isEager()來決定是否要匹配factoryBean類型的Bean
    // 如果isEager()為true,那麼會匹配factoryBean,反之,不會
    String[] candidateNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
        this, requiredType, true, descriptor.isEager());
   
    Map<String, Object> result = new LinkedHashMap<>(candidateNames.length);
   
    // 第一步會到resolvableDependencies這個集合中查詢是否已經存在了解析好的依賴
    // 像我們之所以能夠直接在Bean中注入applicationContext對象
    // 就是因為Spring之前就將這個對象放入了resolvableDependencies集合中
    for (Class<?> autowiringType : this.resolvableDependencies.keySet()) {
        if (autowiringType.isAssignableFrom(requiredType)) {
            Object autowiringValue = this.resolvableDependencies.get(autowiringType);
            
            // 如果resolvableDependencies放入的是一個ObjectFactory類型的依賴
            // 那麼在這裡會生成一個代理對象
            // 例如,我們可以在controller中直接注入request對象
            // 就是因為,容器啟動時就在resolvableDependencies放入了一個鍵值對
            // 其中key為:Request.class,value為:ObjectFactory
            // 在實際注入時放入的是一個代理對象
            autowiringValue = AutowireUtils.resolveAutowiringValue(autowiringValue, requiredType);
            if (requiredType.isInstance(autowiringValue)) {
                // 這裏放入的key不是Bean的名稱
                // value是實際依賴的對象
                result.put(ObjectUtils.identityToString(autowiringValue), autowiringValue);
                break;
            }
        }
    }
    
    // 接下來開始對之前查找出來的類型匹配的所有BeanName進行處理
    for (String candidate : candidateNames) {
        // 不是自引用,什麼是自引用?
        // 1.候選的Bean的名稱跟需要進行注入的Bean名稱相同,意味着,自己注入自己
        // 2.或者候選的Bean對應的factoryBean的名稱跟需要注入的Bean名稱相同,
        // 也就是說A依賴了B但是B的創建又需要依賴A
        // 要符合注入的條件
        if (!isSelfReference(beanName, candidate) && isAutowireCandidate(candidate, descriptor)) {
            // 調用addCandidateEntry,加入到返回集合中,後文有對這個方法的分析
            addCandidateEntry(result, candidate, descriptor, requiredType);
        }
    }
    
    // 排除自引用的情況下,沒有找到一個合適的依賴
    if (result.isEmpty() && !indicatesMultipleBeans(requiredType)) {
        // 1.先走fallback邏輯,Spring提供的一個擴展吧,感覺沒什麼卵用
        // 默認情況下fallback的依賴描述符就是自身
        DependencyDescriptor fallbackDescriptor = descriptor.forFallbackMatch();
        for (String candidate : candidateNames) {
            if (!isSelfReference(beanName, candidate) && isAutowireCandidate(candidate, fallbackDescriptor)) {
                addCandidateEntry(result, candidate, descriptor, requiredType);
            }
        }
        // fallback還是失敗
        if (result.isEmpty()) {
            // 處理自引用
            // 從這裏可以看出,自引用的優先級是很低的,只有在容器中真正的只有這個Bean能作為
            // 候選者的時候,才會去處理,否則自引用是被排除掉的
            for (String candidate : candidateNames) {
                if (isSelfReference(beanName, candidate) &&
                    // 不是一個集合或者
                    // 是一個集合,但是beanName跟candidate的factoryBeanName相同
                    (!(descriptor instanceof MultiElementDescriptor) || !beanName.equals(candidate)) &&
                    isAutowireCandidate(candidate, fallbackDescriptor)) {
                    addCandidateEntry(result, candidate, descriptor, requiredType);
                }
            }
        }
    }
    return result;
}


// candidates:就是findAutowireCandidates方法要返回的候選集合
// candidateName:當前的這個候選Bean的名稱
// descriptor:依賴描述符
// requiredType:依賴的類型
private void addCandidateEntry(Map<String, Object> candidates, String candidateName,
                               DependencyDescriptor descriptor, Class<?> requiredType) {
	
    // 如果依賴是一個集合,或者容器中已經包含這個單例了
    // 那麼直接調用getBean方法創建或者獲取這個Bean
    if (descriptor instanceof MultiElementDescriptor || containsSingleton(candidateName)) {
        Object beanInstance = descriptor.resolveCandidate(candidateName, requiredType, this);
        candidates.put(candidateName, (beanInstance instanceof NullBean ? null : beanInstance));
    }
    // 如果依賴的類型不是一個集合,這個時候還不能確定到底要使用哪個依賴,
    // 所以不能將這些Bean創建出來,所以這個時候,放入candidates是Bean的名稱以及類型
    else {
        candidates.put(candidateName, getType(candidateName));
    }
}

處理屬性注入(@Autowired)

postProcessProperties

// 在applyMergedBeanDefinitionPostProcessors方法執行的時候,
// 已經解析過了@Autowired註解(buildAutowiringMetadata方法)
public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName) {
    // 這裏獲取到的是解析過的緩存好的注入元數據
    InjectionMetadata metadata = findAutowiringMetadata(beanName, bean.getClass(), pvs);
    try {
        // 直接調用inject方法
        // 存在兩種InjectionMetadata
        // 1.AutowiredFieldElement
        // 2.AutowiredMethodElement
        // 分別對應字段的屬性注入以及方法的屬性注入
        metadata.inject(bean, beanName, pvs);
    }
    catch (BeanCreationException ex) {
        throw ex;
    }
    catch (Throwable ex) {
        throw new BeanCreationException(beanName, "Injection of autowired dependencies failed", ex);
    }
    return pvs;
}
字段的屬性注入
// 最終反射調用filed.set方法
protected void inject(Object bean, @Nullable String beanName, @Nullable PropertyValues pvs) throws Throwable {
    Field field = (Field) this.member;
    Object value;
    if (this.cached) {
        // 第一次注入的時候肯定沒有緩存
        // 這裏也是對原型情況的處理
        value = resolvedCachedArgument(beanName, this.cachedFieldValue);
    } else {
        DependencyDescriptor desc = new DependencyDescriptor(field, this.required);
        desc.setContainingClass(bean.getClass());
        Set<String> autowiredBeanNames = new LinkedHashSet<>(1);
        Assert.state(beanFactory != null, "No BeanFactory available");
        TypeConverter typeConverter = beanFactory.getTypeConverter();
        try {
            // 這裏可以看到,對@Autowired註解在字段上的處理
            // 跟byType下自動注入的處理是一樣的,就是調用resolveDependency方法
            value = beanFactory.resolveDependency(desc, beanName, autowiredBeanNames, typeConverter);
        } catch (BeansException ex) {
            throw new UnsatisfiedDependencyException(null, beanName, new InjectionPoint(field), ex);
        }
        synchronized (this) {
            // 沒有緩存過的話,這裏需要進行緩存
            if (!this.cached) {
                if (value != null || this.required) {
                    this.cachedFieldValue = desc;
                    // 註冊Bean之間的依賴關係
                    registerDependentBeans(beanName, autowiredBeanNames);
                    // 如果這個類型的依賴只存在一個的話,我們就能確定這個Bean的名稱
                    // 那麼直接將這個名稱緩存到ShortcutDependencyDescriptor中
                    // 第二次進行注入的時候就可以直接調用getBean(beanName)得到這個依賴了
                    // 實際上正常也只有一個,多個就報錯了
                    // 另外這裡會過濾掉@Vlaue得到的依賴
                    if (autowiredBeanNames.size() == 1) {
                        String autowiredBeanName = autowiredBeanNames.iterator().next();
                        // 通過resolvableDependencies這個集合找的依賴不滿足containsBean條件
                        // 不會進行緩存,因為緩存實際還是要調用getBean,而resolvableDependencies
                        // 是沒法通過getBean獲取的
                        if (beanFactory.containsBean(autowiredBeanName) &&
                            beanFactory.isTypeMatch(autowiredBeanName, field.getType())) {							 // 依賴描述符封裝成ShortcutDependencyDescriptor進行緩存
                            this.cachedFieldValue = new ShortcutDependencyDescriptor(
                                desc, autowiredBeanName, field.getType());
                        }
                    }
                } else {
                    this.cachedFieldValue = null;
                }
                this.cached = true;
            }
        }
    }
    if (value != null) {
        // 反射調用Field.set方法
        ReflectionUtils.makeAccessible(field);
        field.set(bean, value);
    }
}
方法的屬性注入
// 代碼看着很長,實際上邏輯跟字段注入基本一樣
protected void inject(Object bean, @Nullable String beanName, @Nullable PropertyValues pvs) throws Throwable {
    // 判斷XML中是否配置了這個屬性,如果配置了直接跳過
    // 換而言之,XML配置的屬性優先級高於@Autowired註解
    if (checkPropertySkipping(pvs)) {
        return;
    }
    Method method = (Method) this.member;
    Object[] arguments;
    if (this.cached) {
        arguments = resolveCachedArguments(beanName);
    } else {
        // 通過方法參數類型構造依賴描述符
        // 邏輯基本一樣的,最終也是調用beanFactory.resolveDependency方法
        Class<?>[] paramTypes = method.getParameterTypes();
        arguments = new Object[paramTypes.length];
        DependencyDescriptor[] descriptors = new DependencyDescriptor[paramTypes.length];
        Set<String> autowiredBeans = new LinkedHashSet<>(paramTypes.length);
        Assert.state(beanFactory != null, "No BeanFactory available");
        TypeConverter typeConverter = beanFactory.getTypeConverter();
        
        // 遍歷方法的每個參數
        for (int i = 0; i < arguments.length; i++) {
            MethodParameter methodParam = new MethodParameter(method, i);
            DependencyDescriptor currDesc = new DependencyDescriptor(methodParam, this.required);
            currDesc.setContainingClass(bean.getClass());
            descriptors[i] = currDesc;
            try {
                // 還是要調用這個方法
                Object arg = beanFactory.resolveDependency(currDesc, beanName, autowiredBeans, typeConverter);
                if (arg == null && !this.required) {
                    arguments = null;
                    break;
                }
                arguments[i] = arg;
            } catch (BeansException ex) {
                throw new UnsatisfiedDependencyException(null, beanName, new InjectionPoint(methodParam), ex);
            }
        }
        synchronized (this) {
            if (!this.cached) {
                if (arguments != null) {
                    Object[] cachedMethodArguments = new Object[paramTypes.length];
                    System.arraycopy(descriptors, 0, cachedMethodArguments, 0, arguments.length);  
                    // 註冊bean之間的依賴關係
                    registerDependentBeans(beanName, autowiredBeans);
                    
                    // 跟字段注入差不多,存在@Value註解,不進行緩存
                    if (autowiredBeans.size() == paramTypes.length) {
                        Iterator<String> it = autowiredBeans.iterator();
                        for (int i = 0; i < paramTypes.length; i++) {
                            String autowiredBeanName = it.next();
                            if (beanFactory.containsBean(autowiredBeanName) &&
                                beanFactory.isTypeMatch(autowiredBeanName, paramTypes[i])) {
                                cachedMethodArguments[i] = new ShortcutDependencyDescriptor(
                                    descriptors[i], autowiredBeanName, paramTypes[i]);
                            }
                        }
                    }
                    this.cachedMethodArguments = cachedMethodArguments;
                } else {
                    this.cachedMethodArguments = null;
                }
                this.cached = true;
            }
        }
    }
    if (arguments != null) {
        try {
            // 反射調用方法
            // 像我們的setter方法就是在這裏調用的
            ReflectionUtils.makeAccessible(method);
            method.invoke(bean, arguments);
        } catch (InvocationTargetException ex) {
            throw ex.getTargetException();
        }
    }
}

處理依賴檢查

protected void checkDependencies(
    String beanName, AbstractBeanDefinition mbd, PropertyDescriptor[] pds, PropertyValues pvs)
    throws UnsatisfiedDependencyException {

    int dependencyCheck = mbd.getDependencyCheck();
    for (PropertyDescriptor pd : pds) {
        
        // 有set方法但是在pvs中沒有對應屬性,那麼需要判斷這個屬性是否要進行依賴檢查
        // 如果需要進行依賴檢查的話,就需要報錯了
        // pvs中保存的是自動注入以及XML配置的屬性
        if (pd.getWriteMethod() != null && !pvs.contains(pd.getName())) {
           
            // 是否是基本屬性,枚舉/日期等也包括在內
            boolean isSimple = BeanUtils.isSimpleProperty(pd.getPropertyType());
           	
            // 如果DEPENDENCY_CHECK_ALL,對任意屬性都開啟了依賴檢查,報錯
            // DEPENDENCY_CHECK_SIMPLE,對基本屬性開啟了依賴檢查並且是基本屬性,報錯
            // DEPENDENCY_CHECK_OBJECTS,對非基本屬性開啟了依賴檢查並且不是非基本屬性,報錯
            boolean unsatisfied = (dependencyCheck == AbstractBeanDefinition.DEPENDENCY_CHECK_ALL) ||
                (isSimple && dependencyCheck == AbstractBeanDefinition.DEPENDENCY_CHECK_SIMPLE) ||
                (!isSimple && dependencyCheck == AbstractBeanDefinition.DEPENDENCY_CHECK_OBJECTS);
            
            if (unsatisfied) {
                throw new UnsatisfiedDependencyException(mbd.getResourceDescription(), beanName, pd.getName(),
                                                         "Set this property value or disable dependency checking for this bean.");
            }
        }
    }
}

將解析出來的屬性應用到Bean上

到這一步解析出來的屬性主要有三個來源

  1. XML中配置的
  2. 通過byName的方式自動注入的
  3. 通過byType的方式自動注入的

但是在應用到Bean前還需要做一步類型轉換,這一部分代碼實際上跟我們之前在Spring官網閱讀(十四)Spring中的BeanWrapper及類型轉換介紹的差不多,而且因為XML跟自動注入的方式都不常見,正常@Autowired的方式進行注入的話,這個方法沒有什麼用,所以本文就不再贅述。

總結

本文我們主要分析了Spring在屬性注入過程中的相關代碼,整個屬性注入可以分為兩個部分

  1. @Autowired/@Vale的方式完成屬性注入
  2. 自動注入(byType/byName

完成屬性注入的核心方法其實就是doResolveDependencydoResolveDependency這個方法的邏輯簡單來說分為兩步:

  1. 通過依賴類型查詢到所有的類型匹配的bean的名稱
  2. 如果找到了多個的話,再根據依賴的名稱匹配對應的Bean的名稱
  3. 調用getBean得到這個需要被注入的Bean
  4. 最後反射調用字段的set方法完成屬性注入

從上面也可以知道,其實整個屬性注入的邏輯是很簡單的。

如果本文對你有幫助的話,記得點個贊吧!也歡迎關注我的公眾號,微信搜索:程序員DMZ,或者掃描下方二維碼,跟着我一起認認真真學Java,踏踏實實做一個coder。

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理

【其他文章推薦】

※為什麼 USB CONNECTOR 是電子產業重要的元件?

網頁設計一頭霧水該從何著手呢? 台北網頁設計公司幫您輕鬆架站!

※台北網頁設計公司全省服務真心推薦

※想知道最厲害的網頁設計公司"嚨底家"!

新北清潔公司,居家、辦公、裝潢細清專業服務

※推薦評價好的iphone維修中心

為什麼用抓包工具看HTTPS包是明文的

測試或者開發調試的過程中,經常會進行抓包分析,並且裝上抓包工具的證書就能抓取 HTTPS 的數據包並显示。由此就產生了一個疑問,為什麼抓包工具裝上證書後就能抓到 HTTPS 的包並显示呢?不是說 HTTPS 是加密傳輸的嗎?

今天這篇文章就來探究下上面這個問題,要解釋清楚這個問題,我會通過解答以下兩個問題來講述:

  1. HTTPS 到底是什麼?
  2. 抓包工具抓包的原理?

HTTPS 到底是什麼

HTTP 作為一種被廣泛使用的傳輸協議,也存在一些的缺點:

  1. 無狀態(可以通過 Cookie 或 Session 解決);
  2. 明文傳輸;
  3. 不安全;

為了解決 “明文” 和 “不安全” 兩個問題,就產生了 HTTPSHTTPS 不是一種單獨的協議,它是由 HTTP + SSL/TLS 組成。

HTTP與HTTPS

所以要理解 HTTPS 就只需在 HTTP 的基礎上理解 SSL/TLS (TLS 是 SSL 的後續版本,現在一般使用 TLS),下面就來了解下 TLS 是什麼。

TLS

傳輸層安全性協議(英語:Transport Layer Security,縮寫:TLS)及其前身安全套接層(英語:Secure Sockets Layer,縮寫:SSL)是一種安全協議,目的是為互聯網通信提供安全及數據完整性保障。

TLS 由記錄協議、握手協議、警報協議、變更密碼規範協議、擴展協議等幾個子協議組成,綜合使用了對稱加密、非對稱加密、身份認證等許多密碼學前沿技術。

  • 記錄協議 規定
    TLS 收發數據的基本單位為:記錄。類似
    TCP 里的
    segment,所有其它子協議都需要通過記錄協議發出。
  • 警報協議 的職責是向對方發出警報信息,類似於
    HTTP 里的狀態碼。
  • 握手協議
    TLS 里最複雜的子協議,瀏覽器和服務器在握手過程中會協商
    TLS 版本號、隨機數、密碼套件等信息,然後交換證書和密鑰參數,最終雙方協商得到會話密鑰,用於後續的混合加密系統。
  • 變更密碼規範協議 用於告知對方,後續的數據都將使用加密傳輸。

TLS 的握手過程:

TLS握手過程

握手過程抓包显示:

TLS抓包
TLS所傳輸的數據

交換密鑰的過程為:

  1. 客戶端發起一個請求給服務器;
  2. 服務器生成一對非對稱的公鑰(
    pubkey)和私鑰(
    privatekey),然後把公鑰附加到一個
    CA数字證書 上返回給客戶端;
  3. 客戶端校驗該證書是否合法(通過瀏覽器內置的廠商根證書等手段校驗),然後從證書中提取出公鑰(
    pubkey);
  4. 客戶端生成一個隨機數(
    key),然後使用公鑰(
    pubkey)對這個隨機數進行加密后發送給服務器;
  5. 服務器利用私鑰(
    privatekey)對收到的隨機數密文進行解密得到
    key ;
  6. 後續客戶端和服務器傳輸數據使用該
    key 進行加密后再傳輸;

抓包工具抓包的原理

先來看看抓 HTTP 包的原理

HTTP抓包過程

  1. 首先抓包工具會提供出代理服務,客戶端需要連接該代理;
  2. 客戶端發出
    HTTP 請求時,會經過抓包工具的代理,抓包工具將請求的原文進行展示;
  3. 抓包工具使用該原文將請求發送給服務器;
  4. 服務器返回結果給抓包工具,抓包工具將返回結果進行展示;
  5. 抓包工具將服務器返回的結果原樣返回給客戶端;

抓包工具就相當於個透明的中間人,數據經過的時候它一隻手接到數據,然後另一隻手把數據傳出去。

再來看看 HTTPS 的抓包

HTTPS抓包過程

這個時候抓包工具對客戶端來說相當於服務器,對服務器來說相當於客戶端。在這個傳輸過程中,客戶端會以為它就是目標服務器,服務器也會以為它就是請求發起的客戶端。

  1. 客戶端連接抓包工具提供的代理服務;
  2. 客戶端需要安裝抓包工具的根證書;
  3. 客戶端發出
    HTTPS 請求,抓包工具模擬服務器與客戶端進行
    TLS 握手交換密鑰等流程;
  4. 抓包工具發送一個
    HTTPS 請求給客戶端請求的目標服務器,並與目標服務器進行
    TLS 握手交換密鑰等流程;
  5. 客戶端使用與抓包工具協定好的密鑰加密數據后發送給抓包工具;
  6. 抓包工具使用與客戶端協定好的密鑰解密數據,並將結果進行展示;
  7. 抓包工具將解密后的客戶端數據,使用與服務器協定好的密鑰進行加密后發送給目標服務器;
  8. 服務器解密數據后,做對應的邏輯處理,然後將返回結果使用與抓包工具協定好的密鑰進行加密發送給抓包工具;
  9. 抓包工具將服務器返回的結果,用與服務器協定好的密鑰解密,並將結果進行展示;
  10. 抓包工具將解密后的服務器返回數據,使用與客戶端協定好的密鑰進行加密后發送給客戶端;
  11. 客戶端解密數據;

總結

  • HTTPS 不是單獨的一個協議,它是
    HTTP +
    SSL/TLS 的組合;
  • TLS 是傳輸層安全性協議,它會對傳輸的
    HTTP 數據進行加密,使用非對稱加密和對稱加密的混合方式;
  • 抓包工具的原理就是“偽裝“,對客戶端偽裝成服務器,對服務器偽裝成客戶端;
  • 使用抓包工具抓
    HTTPS 包必須要將抓包工具的證書安裝到客戶端本地,並設置信任;
  • HTTPS 數據只是在傳輸時進行了加密,而抓包工具是接收到數據后再重新加密轉發,所以抓包工具抓到的
    HTTPS 包可以直接看到明文;

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理
【其他文章推薦】

USB CONNECTOR掌控什麼技術要點? 帶您認識其相關發展及效能

台北網頁設計公司這麼多該如何選擇?

※智慧手機時代的來臨,RWD網頁設計為架站首選

※評比南投搬家公司費用收費行情懶人包大公開

※幫你省時又省力,新北清潔一流服務好口碑

※回頭車貨運收費標準

【離散優化】覆蓋問題

覆蓋問題

我們知道設施選址問題有兩類基礎問題,分別是中值問題和覆蓋問題,下面要介紹的就是覆蓋問題。

什麼是覆蓋問題?

覆蓋問題是以所期望的服務範圍滿足大多數或者所有用戶需求為前提,確定設施的位置。覆蓋模型的思想是離服務設施較近的用戶越多,則服務越好。

覆蓋問題的分類

覆蓋問題主要分為兩類:

  • 集合覆蓋問題(Location Set Covering Problem,LSCP)
  • 最大覆蓋問題(Maximum Covering Location Problem,MCLP)

覆蓋模型常用於哪些場景?

由於 P-中值模型常以總距離或者總時間作為測度指標,使得其並不適用於一些特殊的場景,比如消防中心和救護車等應急設施的區位選址問題,而覆蓋模型則比較適用於這些場景。

如何定義覆蓋?

如果需求點 \(i\) 到備選設施點 \(j\) 的距離或者時間小於臨界值 \(D_c\),那麼稱需求點 \(i\) 被候選設施點 \(j\) 覆蓋。、

下面介紹兩類覆蓋問題的數學模型表達

集合覆蓋問題 (Location Set Covering Problem,LSCP)

目標函數:

\[\min \sum_{j \in J}x_j \]

約束:

\[\sum_{j \in N_i} x_j \geqslant 1 \quad \forall i \in I \tag{c-1} \]

\[x_j \in \{0, 1\} \quad \forall j \in J \tag{c-2} \]

其中,

  • \(N_i = \{j:a_{ij}=1\}\) 是覆蓋需求點 \(i\) 的候選設施點的集合,變量 \(a_{ij}\) 用來判斷需求點 \(i\) 是否被候選設施點 \(j\) 覆蓋,若是,則 \(a_{ij}=1\),否則 \(a_{ij}=0\)
  • 目標函數旨在尋求設施總量最小
  • 約束 \(c-1\) 保證每個需求點至少被一個設施服務範圍所覆蓋
  • 約束 \(c-2\) 是決策變量的取值範圍

在某些場景中,集合覆蓋問題有以下兩個缺點:

  • 為了保證所有需求點均被覆蓋而引入過多的設施,以至於超出預算
  • 模型無法區分需求點的需求強度

現實生活中,常常由於預算或者資源的約束,有限的設施不能保證空間中所有需求點都被覆蓋,此時,優先考慮需求強度大的需求點是十分必要的,下面要介紹的最大覆蓋模型就是為了解決這個問題而被提出。

最大覆蓋問題(Maximum Covering Location Problem,MCLP)

目標函數

\[\max \sum_{i \in N_i} \omega_iz_i \]

約束

\[z_i \leqslant \sum_{j \in N_i}x_j \quad \forall i \in I \tag{c-1} \]

\[\sum_{j\in J}x_j = p \tag{c-2} \]

\[x_j \in \{0,1\} \quad \forall j \in J \tag{c-3} \]

\[z_i = \{0, 1\} \quad \forall i \in I \tag{c-4} \]

其中,

  • \(\omega_i\) 為需求點 \(i\) 的需求強度

  • \(z_i\) 用來判斷需求點 \(i\) 是否被覆蓋,若覆蓋,則為 1,否則為 0

  • 目標函數旨在尋求有限設施(\(p\) 個)覆蓋的需求最多

  • 約束 \(c-1\) 要求除非在備選設施點中已定位一個設施可以覆蓋需求點 \(i\),否則需求點 \(i\) 將不被記作被覆蓋

  • 約束 \(c-2\) 限制設施的總數為 \(p\)

  • 約束 \(c-3, c-4\) 是決策變量的取值範圍

更多種類的選址問題

以上介紹的覆蓋問題的基礎模型框架,然而具體問題一般是較為複雜的設施選址問題,這就需要我們對基礎模型設置不同的條件從而進行擴展,比如:

  • 用於環境污染防治的鄰避型設施選址問題
  • 用於不同服務等級的層次型設置選址問題
  • 用於商業競爭的競爭型設施選址問題
  • 選址問題也開始考慮動態、不確定性等因素

總結

總結以上兩類問題,我們可以發現最大覆蓋模型和集合覆蓋模型的主要區別在於對設施數量和需求強度的關注不同,前者一般適用於建設經費充足或者設施成本相同的情況,後者則適用於有設施成本約束的選址決策。

參考文獻

本文內容主要從論文《設施選址問題中的基礎模型與求解方法比較》總結而來。

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理

【其他文章推薦】

網頁設計一頭霧水該從何著手呢? 台北網頁設計公司幫您輕鬆架站!

網頁設計公司推薦不同的風格,搶佔消費者視覺第一線

※想知道購買電動車哪裡補助最多?台中電動車補助資訊懶人包彙整

南投搬家公司費用,距離,噸數怎麼算?達人教你簡易估價知識!

※教你寫出一流的銷售文案?

※超省錢租車方案

記一次uboot升級過程的兩個坑

背景

之前做過一次uboot的升級,當時留下了一些記錄,本文摘錄其中比較有意思的兩個問題。

啟動失敗問題

問題簡述

uboot代碼中用到了一個庫,考慮到庫本身跟uboot版本沒什麼關係,就直接把舊的庫文件拷貝過來使用。結果編譯鏈接是沒問題,啟動卻會卡住。

消失的打印

為了明確卡住的位置,就去修改了庫的源碼,添加一些打印(此時還是在舊版本uboot下編譯的),結果發現卡住的位置或隨着添加打印的變化而變化,且有些打印語句,添加后未打印出來。

我決定先從這些神秘消失的打印入手。

分析下uboot中的printf實現,最底層就是寫寄存器,是一個同步的函數,也沒什麼可疑的地方。

為了確認打印不出來的時候,到底有沒有調用到printf,我決定給printf增加一個計數器,在gd結構體中,增加一個printf_count字段,初始化為0,每次打印時執行printf_count++並打印出值。

設計這個試驗,本意是確認未打印出來時是否確實也調用到了printf,但卻有了別的發現,實驗結果中printf_count值會異常變化,不是按打印順序遞增,而是會突變成很大的異常值。

printf_countgd結構體的成員,那就是gd的問題了。進一步將uboot全局結構體gd的地址打印出來。確認了原因是gd結構體的指針變化了。

這也可以解釋部分打印消失的現象,原因是我們在gd中有另一個字段,用於控制打印等級。當gd被改動了,printf就可能解析出錯,誤以為打印等級為0而提前返回。

gd的實現

那麼好端端的,gd為什麼會被改了呢?這就要先看看gd到底是怎麼實現的了。

uboot中維護了一個全局的結構體gd。在代碼中加入

DECLARE_GLOBAL_DATA_PTR;

即可使用gd指針訪問這個全局結構體,許多地方都會藉助gd來保存傳遞信息。

進一步看看這個宏的定義

舊版本uboot:
#define DECLARE_GLOBAL_DATA_PTR        register volatile gd_t *gd asm ("r8")

新版本uboot:
#define DECLARE_GLOBAL_DATA_PTR        register volatile gd_t *gd asm ("r9")

居然不一樣,一個是將gd的值放到r8寄存器,一個是放在r9寄存器。

那麼就可以猜測到,庫是在舊版本uboot中編譯出來的,可能使用了r9,那麼放到新版本uboot中去,就會破壞r9寄存器中保存的gd值,導致一系列依賴gd的代碼不能正常工作。

驗證改動

為了求證,將庫反彙編出來,發現確實避開了r8寄存器,但使用了r9寄存器。

說明uboot在指定gd寄存器的同時,還有某種方法讓其他代碼不使用這個寄存器。

那是不是把舊uboot中的這個r8改成r9,重新編譯庫就可以了呢?試一下,還是不行。

那麼禁止其他代碼使用r8寄存器肯定就是通過別的方式實現的了。簡單粗暴地在舊版本uboot下搜索r8,去掉.c .h等類型后,很容易發現了

./arch/arm/cpu/armv7/config.mk:24:PLATFORM_RELFLAGS += -fno-common -ffixed-r8 -msoft-floa

-ffixed-r8修改為-ffixed-r9,重新編譯出庫,這回就可以正常工作了,打印正常,啟動正常。反彙編出來也可以看到,新編譯出來的庫用了r8沒有用r9

當然更好的改法,是直接在新版本的uboot中編譯,這是最可靠的。

追本溯源

話說回來,為什麼兩個版本的uboot,會使用不同的寄存器呢?難道有什麼坑?

這就得去翻一下git記錄了。

commit fe1378a961e508b31b1f29a2bb08ba1dac063155
Author: Jeroen Hofstee <jeroen@myspectrum.nl>
Date:   Sat Sep 21 14:04:41 2013 +0200

    ARM: use r9 for gd
    
    To be more EABI compliant and as a preparation for building
    with clang, use the platform-specific r9 register for gd
    instead of r8.
    
    note: The FIQ is not updated since it is not used in u-boot,
    and under discussion for the time being.
    
    The following checkpatch warning is ignored:
    WARNING: Use of volatile is usually wrong: see
    Documentation/volatile-considered-harmful.txt
    
    Signed-off-by: Jeroen Hofstee <jeroen@myspectrum.nl>
    cc: Albert ARIBAUD <albert.u.boot@aribaud.net>

git記錄中,也可以確認完整地將r8切換到r9,都需要做哪些修改

diff --git a/arch/arm/config.mk b/arch/arm/config.mk
index 16c2e3d1e0..d0cf43ff41 100644
--- a/arch/arm/config.mk
+++ b/arch/arm/config.mk
@@ -17,7 +17,7 @@ endif
 
 LDFLAGS_FINAL += --gc-sections
 PLATFORM_RELFLAGS += -ffunction-sections -fdata-sections \
-                     -fno-common -ffixed-r8 -msoft-float
+                     -fno-common -ffixed-r9 -msoft-float
 
 # Support generic board on ARM
 __HAVE_ARCH_GENERIC_BOARD := y
diff --git a/arch/arm/cpu/armv7/lowlevel_init.S b/arch/arm/cpu/armv7/lowlevel_init.S
index 82b2b86520..69e3053a42 100644
--- a/arch/arm/cpu/armv7/lowlevel_init.S
+++ b/arch/arm/cpu/armv7/lowlevel_init.S
@@ -22,11 +22,11 @@ ENTRY(lowlevel_init)
        ldr     sp, =CONFIG_SYS_INIT_SP_ADDR
        bic     sp, sp, #7 /* 8-byte alignment for ABI compliance */
 #ifdef CONFIG_SPL_BUILD
-       ldr     r8, =gdata
+       ldr     r9, =gdata
 #else
        sub     sp, #GD_SIZE
        bic     sp, sp, #7
-       mov     r8, sp
+       mov     r9, sp
 #endif
        /*
         * Save the old lr(passed in ip) and the current lr to stack
diff --git a/arch/arm/include/asm/global_data.h b/arch/arm/include/asm/global_data.h
index 79a9597419..e126436093 100644
--- a/arch/arm/include/asm/global_data.h
+++ b/arch/arm/include/asm/global_data.h
@@ -47,6 +47,6 @@ struct arch_global_data {
 
 #include <asm-generic/global_data.h>
 
-#define DECLARE_GLOBAL_DATA_PTR     register volatile gd_t *gd asm ("r8")
+#define DECLARE_GLOBAL_DATA_PTR     register volatile gd_t *gd asm ("r9")
 
 #endif /* __ASM_GBL_DATA_H */
diff --git a/arch/arm/lib/crt0.S b/arch/arm/lib/crt0.S
index 960d12e732..ac54b9359a 100644
--- a/arch/arm/lib/crt0.S
+++ b/arch/arm/lib/crt0.S
@@ -69,7 +69,7 @@ ENTRY(_main)
        bic     sp, sp, #7      /* 8-byte alignment for ABI compliance */
        sub     sp, #GD_SIZE    /* allocate one GD above SP */
        bic     sp, sp, #7      /* 8-byte alignment for ABI compliance */
-       mov     r8, sp          /* GD is above SP */
+       mov     r9, sp          /* GD is above SP */
        mov     r0, #0
        bl      board_init_f
 
@@ -81,15 +81,15 @@ ENTRY(_main)
  * 'here' but relocated.
  */
 
-       ldr     sp, [r8, #GD_START_ADDR_SP]     /* sp = gd->start_addr_sp */
+       ldr     sp, [r9, #GD_START_ADDR_SP]     /* sp = gd->start_addr_sp */
        bic     sp, sp, #7      /* 8-byte alignment for ABI compliance */
-       ldr     r8, [r8, #GD_BD]                /* r8 = gd->bd */
-       sub     r8, r8, #GD_SIZE                /* new GD is below bd */
+       ldr     r9, [r9, #GD_BD]                /* r9 = gd->bd */
+       sub     r9, r9, #GD_SIZE                /* new GD is below bd */
 
        adr     lr, here
-       ldr     r0, [r8, #GD_RELOC_OFF]         /* r0 = gd->reloc_off */
+       ldr     r0, [r9, #GD_RELOC_OFF]         /* r0 = gd->reloc_off */
        add     lr, lr, r0
-       ldr     r0, [r8, #GD_RELOCADDR]         /* r0 = gd->relocaddr */
+       ldr     r0, [r9, #GD_RELOCADDR]         /* r0 = gd->relocaddr */
        b       relocate_code
 here:
 
@@ -111,8 +111,8 @@ clbss_l:cmp r0, r1                  /* while not at end of BSS */
        bl red_led_on
 
        /* call board_init_r(gd_t *id, ulong dest_addr) */
-       mov     r0, r8                  /* gd_t */
-       ldr     r1, [r8, #GD_RELOCADDR] /* dest_addr */
+       mov     r0, r9                  /* gd_t */
+       ldr     r1, [r9, #GD_RELOCADDR] /* dest_addr */
        /* call board_init_r */
        ldr     pc, =board_init_r       /* this is auto-relocated! */

啟動慢問題

問題簡述

填了幾個坑之後,新的uboot可以啟動到內核了,但發現啟動速度非常慢,內核啟動速度慢了接近10倍!明明是同一個內核,為什麼差異這麼大。

排查寄存器

初步排查了下設備樹配置,以及uboot跳轉內核前的一些關鍵寄存器,確實在兩個版本的uboot中有所不同,但具體去看這些不同,發現都不會影響速度,將一些驅動對齊之後寄存器差異基本就消失了。

差異的分界

那再細看,kernel的速度有差異,uboot呢?在哪個時間點之後,速度開始產生差異?

嘗試在兩個版本的uboot中插入一些操作,對比時間戳,發現兩個uboot在某個節點之後的速度確實有區別。

進一步排查,原來是在打開cache操作之後,舊uboot的速度就會比新uboot快。嘗試將舊ubootcache關掉,則二者基本一致。嘗試將舊uboot操作cache的代碼,移植到新uboot,未發生改變。

此時可確認新uboot的開cache有問題。但覺得這個跟kernel啟動慢沒關係。因為uboot進入kernel之前都會關cache,由kernel自己去重新打開。

也就是不管是用哪份uboot,也不管uboot中是否開了cache,對kernel階段都應該沒有影響才對。

於是記錄下來uboot的這個問題,待後續修復。先繼續找kernel啟動慢的原因。(注:現在看來當時的做法是有問題的,這裏的異常這麼明顯,應該設法追蹤下去找出原因才對)

鎖定uboot

uboot的嫌疑非常大,但還不能完全確認,因為uboot之前還有一級spl。是否會是spl的問題呢?

嘗試改用新spl+舊uboot,啟動速度正常。而新spl+新uboot的啟動速度則很慢,其他因素都不變,說明問題確實出在uboot階段。

多做or少做

當時到這一步就卡住了,直接比較兩份uboot的代碼不太現實,差異太大了。

後來我就給自己提了個問題,到底新uboot是多做了某件事情,還是少做了某件事情?

換個說法,目前已知

spl --> 舊uboot --> kernel(速度快)
spl --> 新uboot --> kernel(速度快)

但到底是以下的情況A還是情況B呢?

A: spl(速度慢) --> 舊uboot(做了某個會提升速度的操作) --> kernel(速度快)
   spl(速度慢) --> 新uboot(少做了某個會提升速度的操作) --> kernel(速度慢)

B: spl(速度快) --> 舊uboot(沒做特殊操作) --> kernel(速度快)
   spl(速度快) --> 新uboot(多做了某個會限制速度的操作) --> kernel(速度慢)

為了驗證,我決定讓spl直接啟動內核,看看內核到底是快是慢。

支持過程碰到了一些小問題

1.spl沒有能力加載這麼大的kernel

解決:此時不需要kernel能完全啟動,只需要能加載啟動一段,足以體現出啟動速度是否正常即可,於是裁剪出一個非常小kernel來輔助實驗。

2.kernel需要dtb

解決:內核有一個CONFIG_BUILD_ARM_APPENDED_DTB_IMAGE選項。選上重新編譯。編譯后再用ddkerneldtb拼接到一起,作為新的kernel。這樣,spl就只需要加載一個文件並跳轉過去即可。

試驗結果,spl啟動的kernel和使用新uboot啟動的kernel速度一致,均比舊uboot啟動的kernel慢。

說明,舊uboot中做了某個關鍵操作,而新uboot沒做。

找出關鍵操作

那接下來的任務就是,找出舊uboot中的這個關鍵操作了。

怎麼找呢?有了上一步的成果,我們可以使用以下方法來排查

  1. spl加載kernel和舊uboot

  2. spl跳轉到舊uboot,此時kernel其實已經在dram中準備好了,隨時可以啟動

  3. 在舊uboot的啟動流程各個階段,嘗試直接跳轉到kernel,觀察啟動速度

  4. 如果在舊ubootA點跳轉kernel啟動慢,B點跳轉啟動快,則說明關鍵操作位於AB點之間。

方法有了,很快就鎖定到start.S,進一步在start.S中揪出了這段代碼

#if defined(CONFIG_ARM_A7)
@set SMP bit
    mrc     p15, 0, r0, c1, c0, 1
    orr        r0, r0, #(1<<6)
    mcr        p15, 0, r0, c1, c0, 1
#endif

ubootstart.S中沒有這段代碼,嘗試在新ubootstart.S中添加此操作,速度立馬恢復正常了。

再全局搜索下,原來這個新版本uboot中,套路是在board_init中進行此項設置的,而這個平台從舊版本移植過來,就沒有設置 SMP bit, 補上即可。

SMP bit是什麼

SMP 是指對稱多處理器,看起來這個 bit 會影響多核的 cache一致性,此處沒有再深入研究。

但可以知道,對於單處理器的情況,也需要設置這個bit才能正常使用cache

貼下arm的圖和描述:

[6]	SMP	

Signals if the Cortex-A9 processor is taking part in coherency or not.

In uniprocessor configurations, if this bit is set, then Inner Cacheable Shared is treated as Cacheable. The reset value is zero.

搜下kernel的代碼,發現也是有地方調用了的。不過這個芯片是單核的,根本就沒配置CONFIG_SMP

#ifdef CONFIG_SMP
	ALT_SMP(mrc	p15, 0, r0, c1, c0, 1)
	ALT_UP(mov	r0, #(1 << 6))		@ fake it for UP
	tst	r0, #(1 << 6)			@ SMP/nAMP mode enabled?
	orreq	r0, r0, #(1 << 6)		@ Enable SMP/nAMP mode
	orreq	r0, r0, r10			@ Enable CPU-specific SMP bits
	mcreq	p15, 0, r0, c1, c0, 1
#endif

總結

整理出來一方面是記錄這兩個bug,另一方面也是想記錄下當時的一些操作。

畢竟同樣的bug可能以後都不會碰到了,但解bug的方法和思路卻是可以積累復用的。

blog: https://www.cnblogs.com/zqb-all/p/13172546.html
公眾號:https://sourl.cn/shT3kz

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理

【其他文章推薦】

網頁設計一頭霧水該從何著手呢? 台北網頁設計公司幫您輕鬆架站!

網頁設計公司推薦不同的風格,搶佔消費者視覺第一線

※Google地圖已可更新顯示潭子電動車充電站設置地點!!

※廣告預算用在刀口上,台北網頁設計公司幫您達到更多曝光效益

※別再煩惱如何寫文案,掌握八大原則!

網頁設計最專業,超強功能平台可客製化

Python 為什麼不支持 i++ 自增語法,不提供 ++ 操作符?

在 C/C++/Java 等等語言中,整型變量的自增或自減操作是標配,它們又可分為前綴操作(++i 和 –i)與後綴操作(i++ 和 i–),彼此存在着一些細微差別,各有不同的用途。

這些語言的使用者在接觸 Python 時,可能會疑惑為什麼它不提供 ++ 或 — 的操作呢?在我前不久發的《Python的十萬個為什麼?》里,就有不少同學在調查問卷中表示了對此話題感興趣。

Python 中雖然可能出現 ++i 這種前綴形式的寫法,但是它並沒有“++”自增操作符,此處只是兩個“+”(正數符號)的疊加而已,至於後綴形式的“++”,則完全不支持(SyntaxError: invalid syntax)。

本期“Python為什麼 ”欄目,我們將會從兩個主要的角度來回答:Python 為什麼不支持 i++ 自增語法? (PS:此處自增指代“自增和自減”,下同)

首先,Python 當然可以實現自增效果,即寫成i += 1 或者 i = i + 1 ,這在其它語言中也是通用的。

雖然 Python 在底層用了不同的魔術方法(__add__()__iadd__() )來完成計算,但表面上的效果完全相同。

所以,我們的問題可以轉化成:為什麼上面的兩種寫法會勝過 i++,成為 Python 的最終選擇呢?

1、Python 的整數是不可變類型

當我們定義i = 1000 時,不同語言會作出不同的處理:

  • C 之類的語言(寫法 int i = 1000)會申請一塊內存空間,並給它“綁定”一個固定的名稱 i,同時寫入一個可變的值 1000。在這裏,i 的地址以及類型是固定的,而值是可變的(在一定的表示範圍內)
  • Python(寫法i = 1000)也會申請一塊內存空間,但是它會“綁定”給数字 1000,即這個 1000 的地址以及類型是固定的(immutable),至於 i,只是一個名稱標籤貼在 1000 上,自身沒有固定的地址和類型

所以當我們令 i “自增”時(i = i + 1),它們的處理是不同的:

  • C 之類的語言先找到 i 的地址上存的數值,然後令它加 1,操作后新的數值就取代了舊的數值
  • Python 的操作過程是把 i 指向的数字加 1,然後把結果綁定到新申請的一塊內存空間,再把名稱標籤 i “貼”到新的数字上。新舊数字可以同時存在,不是取代關係

打一個不太恰當的比方:C 中的 i 就像一個宿主,数字 1000 寄生在它上面;而 Python 中的 1000 像個宿主,名稱 i 寄生在它上面。C 中的 i 與 Python 中的 1000,它們則寄生在底層的內存空間上……

還可以這樣理解:C 中的變量 i 是一等公民,数字 1000 是它的一個可變的屬性;Python 中的数字 1000 是一等公民,名稱 i 是它的一個可變的屬性。

有了以上的鋪墊,我們再來看看 i++,不難發現:

  • C 之類的語言,i++ 可以表示 i 的数字屬性的增加,它不會開闢新的內存空間,也不會產生新的一等公民
  • Python 之類的語言,i++ 如果是對其名稱屬性的操作,那樣就沒有意義了(總不能按字母表順序,把 i 變成 j 吧);如果理解成對数字本體的操作,那麼情況就會變得複雜:它會產生新的一等公民 1001,因此需要給它分配一個內存地址,此時若佔用 1000 的地址,則涉及舊對象的回收,那原有對於 1000 的引用關係都會受到影響,所以只能開闢新的內存空間給 1001

Python 若支持 i++,其操作過程要比 C 的 i++ 複雜,而且其含義也不再是“令数字增加1”(自增),而是“創建一個新的数字”(新增), 這樣的話,“自增操作符”(increment operator)就名不副實了。

Python 在理論上可以實現 i++ 操作,但它就必須重新定義“自增操作符”,還會令有其它語言經驗的人產生誤解,不如就讓大家直接寫成i += 1 或者 i = i + 1 好了。

2、Python 有可迭代對象

C/C++ 等語言設計出 i++,最主要的目的是為了方便使用三段式的 for 結構:

for(int i = 0; i < 100; i++){
    // 執行 xxx
}

這種程序關心的是数字本身的自增過程,数字做加法與程序體的執行相關聯。

Python 中沒有這種 for 結構的寫法,它提供了更為優雅的方式:

for i in range(100):
    # 執行 xxx

my_list = ["你好", "我是Python貓", "歡迎關注"]
for info in my_list:
    print(info)

這裏體現了不同的思維方式,它關心的是在一個數值範圍內的迭代遍歷,並不關心也不需要人為對数字做加法。

Python 中的可迭代對象/迭代器/生成器提供了非常良好的迭代/遍歷用法,能夠做到對 i++ 的完全替代。

例如,上例中實現了對列表內值的遍歷,Python 還可以用 enumerate() 實現對下標與具體值的同時遍歷:

my_list = ["你好", "我是Python貓", "歡迎關注"]
for i, info in enumerate(my_list):
    print(i, info)

# 打印結果:
0 你好
1 我是Python貓
2 歡迎關注

再例如對於字典的遍歷,Python 提供了 keys()、values()、items() 等遍歷方法,非常好用:

my_dict = {'a': '1', 'b': '2', 'c': '3'}
for key in my_dict.keys():
    print(key)

for key, value in my_dict.items():
    print(key, value)

有了這樣的利器,哪裡還有 i++ 的用武之地呢?

不僅如此,Python 中基本上很少使用i += 1 或者 i = i + 1 ,由於存在着隨處可見的可迭代對象,開發者們很容易實現對一個數值區間的操作,也就很少有對於某個數值作累加的訴求了。

所以,回到我們開頭的問題,其實這兩種“自增”寫法並沒有勝出 i++ 多少,只因為它們是通用型操作,又不需要引入新的操作符,所以 Python 才延續了一種基礎性的支持。真正的贏家其實是各種各樣的可迭代對象!

稍微小結下:Python 不支持自增操作符,一方面是因為它的整數是不可變類型的一等公民,自增操作(++)若要支持,則會帶來歧義;另一方面主要因為它有更合適的實現,即可迭代對象,對遍歷操作有很好的支持。

如果你覺得本文分析得不錯,那你應該會喜歡這些文章:

1、Python為什麼使用縮進來劃分代碼塊?

2、Python 的縮進是不是反人類的設計?

3、Python 為什麼不用分號作語句終止符?

4、Python 為什麼沒有 main 函數?為什麼我不推薦寫 main 函數?

5、Python 為什麼推薦蛇形命名法?

寫在最後:本文屬於“Python為什麼”系列(Python貓出品),該系列主要關注 Python 的語法、設計和發展等話題,以一個個“為什麼”式的問題為切入點,試着展現 Python 的迷人魅力。部分話題會推出視頻版,請在 B 站收看,觀看地址:視頻地址

公眾號【Python貓】, 本號連載優質的系列文章,有Python為什麼系列、喵星哲學貓系列、Python進階系列、好書推薦系列、技術寫作、優質英文推薦與翻譯等等,歡迎關注哦。

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理

【其他文章推薦】

網頁設計公司推薦不同的風格,搶佔消費者視覺第一線

※廣告預算用在刀口上,台北網頁設計公司幫您達到更多曝光效益

※自行創業缺乏曝光? 網頁設計幫您第一時間規劃公司的形象門面

南投搬家公司費用需注意的眉眉角角,別等搬了再說!

新北清潔公司,居家、辦公、裝潢細清專業服務

※教你寫出一流的銷售文案?

手把手教你學Numpy,搞定數據處理——收官篇

本文始發於個人公眾號:TechFlow,原創不易,求個關注

今天是Numpy專題第6篇文章,我們一起來看看Numpy庫當中剩餘的部分。

數組的持久化

在我們做機器學習模型的研究或者是學習的時候,在完成了訓練之後,有時候會希望能夠將相應的參數保存下來。否則的話,如果是在Notebook當中,當Notebook關閉的時候,這些值就丟失了。一般的解決方案是將我們需要的值或者是數組“持久化”,通常的做法是存儲在磁盤上。

Python當中讀寫文件稍稍有些麻煩,我們還需要創建文件句柄,然後一行行寫入,寫入完成之後需要關閉句柄。即使是用with語句,也依然不夠簡便。針對這個問題,numpy當中自帶了寫入文件的api,我們直接調用即可。

通過numpy當中save的文件是二進制格式的,所以我們是無法讀取其中內容的,即使強行打開也會是亂碼。

以二進制的形式存儲數據避免了數據類型轉化的過程,尤其是numpy底層的數據是以C++實現的,如果使用Python的文件接口的話,勢必要先轉化成Python的格式,這會帶來大量開銷。既然可以存儲,自然也可以讀取,我們可以調用numpy的load函數將numpy文件讀取進來。

要注意我們保存的時候沒有添加文件後綴,numpy會自動為我們添加後綴,但是讀取的時候必須要指定文件的全名,否則會numpy無法找到,會引發報錯。

不僅如此,numpy還支持我們同時保存多個數組進入一個文件當中。

我們使用savez來完成,在這個api當中我們傳入了a=arr,b=arr,其實是以類似字典的形式傳入的。在文件當中,numpy會將變量名和數組的值映射起來。這樣我們在讀入的時候,就可以通過變量名訪問到對應的值了。

如果要存儲的數據非常大的話,我們還可以對數據進行壓縮,我們只需要更換savez成savez_compressed即可。

線性代數

Numpy除了科學計算之外,另外一大強大的功能就是支持矩陣運算,這也是它廣為流行並且在機器學習當中大受歡迎的原因之一。我們在之前的線性代數的文章當中曾經提到過Numpy這方面的一些應用,我們今天再在這篇文章當中匯總一些常用的線性代數的接口。

點乘

說起來矩陣點乘應該是最常用的線代api了,比如在神經網絡當中,如果拋開激活函數的話,一層神經元對於當前數據的影響,其實等價於特徵矩陣點乘了一個係數矩陣。再比如在邏輯回歸當中,我們計算樣本的加權和的時候,也是通過矩陣點乘來實現的。

在Andrew的深度學習課上,他曾經做過這樣的實現,對於兩個巨大的矩陣進行矩陣相乘的運算。一次是通過Python的循環來實現,一次是通過Numpy的dot函數實現,兩者的時間開銷相差了足足上百倍。這當中的效率差距和Python語言的特性以及併發能力有關,所以在機器學習領域當中,我們總是將樣本向量化或者矩陣化,通過點乘來計算加權求和,或者是係數相乘。

在Numpy當中我們採用dot函數來計算兩個矩陣的點積,既可以寫成a.dot(b),也可以寫成np.dot(a, b)。一般來說我更加喜歡前者,因為寫起來更加方便清晰。如果你喜歡後者也問題不大,這個只是個人喜好。

注意不要寫成*,這個符號代表兩個矩陣元素兩兩相乘,而不是進行點積運算。它等價於np當中的multiply函數。

轉置與逆矩陣

轉置我們曾經在之前的文章當中提到過,可以通過.T或者是np.transpose來完成。

Numpy中還提供了求解逆矩陣的操作,這個函數在numpy的linalg路徑下,這個路徑下實現了許多常用的線性代數函數。根據線性代數當中的知識,只有滿秩的方陣才有逆矩陣。我們可以通過numpy.linalg.det先來計算行列式來判斷,否則如果直接調用的話,對於沒有逆矩陣的矩陣會報錯。

在這個例子當中,由於矩陣b的行列式為0,說明它並不是滿秩的,所以我們求它的逆矩陣會報錯。

除了這些函數之外,linalg當中還封裝了其他一些常用的函數。比如進行qr分解的qr函數,進行奇異值分解的svd函數,求解線性方程組的solve函數等。相比之下,這些函數的使用頻率相對不高,所以就不展開一一介紹了,我們可以用到的時候再去詳細研究。

隨機

Numpy當中另外一個常用的領域就是隨機數,我們經常使用Numpy來生成各種各樣的隨機數。這一塊在Numpy當中其實也有很多的api以及很複雜的用法,同樣,我們不過多深入,挑其中比較重要也是經常使用的和大家分享一下。

隨機數的所有函數都在numpy.random這個路徑下,我們為了簡化,就不寫完整的路徑了,大家記住就好。

randn

這個函數我們經常在代碼當中看到,尤其是我們造數據的時候。它代表的是根據輸入的shape生成一批均值為0,標準差為1的正態分佈的隨機數。

要注意的是,我們傳入的shape不是一個元組,而是每一維的大小,這一點和其他地方的用法不太一樣,需要注意一下。除了正態分佈的randn之外,還有均勻分佈的uniform和Gamma分佈的gamma,卡方分佈的chisquare。

normal

normal其實也是生成正態分佈的樣本值,但不同的是,它支持我們指定樣本的均值和標準差。如果我們想要生成多個樣本,還可以在size參數當中傳入指定的shape。

randint

顧名思義,這個函數是用來生成隨機整數的。它接受傳入隨機數的上下界,最少也要傳入一個上界(默認下界是0)。

如果想要生成多個int,我們可以在size參數傳入一個shape,它會返回一個對應大小的數組,這一點和uniform用法一樣。

shuffle

shuffle的功能是對一個數組進行亂序,返回亂序之後的結果。一般用在機器學習當中,如果存在樣本聚集的情況,我們一般會使用shuffle進行亂序,避免模型受到樣本分佈的影響。

shuffle是一個inplace的方法,它會在原本值上進行改動,而不會返回一個新值。

choice

這也是一個非常常用的api,它可以在數據當中抽取指定條數據。

但是它只支持一維的數組,一般用在批量訓練的時候,我們通過choice採樣出樣本的下標,再通過數組索引去找到這些樣本的值。比如這樣:

總結

今天我們一起研究了Numpy中數據持久化、線性代數、隨機數相關api的使用方法,由於篇幅的限制,我們只是選擇了其中比較常用,或者是比較重要的用法,還存在一些較為冷門的api和用法,大家感興趣的可以自行研究一下,一般來說文章當中提到的用法已經足夠了。

今天這篇是Numpy專題的最後一篇了,如果你堅持看完本專題所有的文章,那麼相信你對於Numpy包一定有了一個深入的理解和認識了,給自己鼓鼓掌吧。之後周四會開啟Pandas專題,敬請期待哦。

如果喜歡本文,可以的話,請點個關注,給我一點鼓勵,也方便獲取更多文章。

本文使用 mdnice 排版

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理

【其他文章推薦】

※自行創業缺乏曝光? 網頁設計幫您第一時間規劃公司的形象門面

網頁設計一頭霧水該從何著手呢? 台北網頁設計公司幫您輕鬆架站!

※想知道最厲害的網頁設計公司"嚨底家"!

※幫你省時又省力,新北清潔一流服務好口碑

※別再煩惱如何寫文案,掌握八大原則!

※產品缺大量曝光嗎?你需要的是一流包裝設計!