查看原文
其他

理解高斯混合模型

gloomyfish OpenCV学堂 2019-03-29

理解高斯混合模型

一:概述

高斯混合模型(GMM)在图像分割、对象识别、视频分析等方面均有应用,对于任意给定的数据样本集合,根据其分布概率, 可以计算每个样本数据向量的概率分布,从而根据概率分布对其进行分类,但是这些概率分布是混合在一起的,要从中分离出单个样本的概率分布就实现了样本数据聚类,而概率分布描述我们可以使用高斯函数实现,这个就是高斯混合模型-GMM。

二:数学原理

这种方法也称为D-EM即基于距离的期望最大化。

三 算法步骤:

  1. 初始化变量定义-指定的聚类数目K与数据维度D

  2. 初始化均值、协方差、先验概率分布

  3. 迭代E-M步骤

  • E步计算期望

  • M步更新均值、协方差、先验概率分布

  • 检测是否达到停止条件(最大迭代次数与最小误差满足),达到则退出迭代,否则继续E-M步骤

    4.  输出最终分类

四:代码详解

计算单个样本的先验概率分布

  1. public double getProbability(double[] sample)

  2. {

  3.    double p = 0;

  4.    for (int i = 0; i < mixNum; i++)

  5.    {

  6.        p += weights[i] * getProbability(sample, i);

  7.    }

  8.    return p;

  9. }

每个维度的概率分布

  1. /**

  2. * Gaussian Model -> PDF

  3. * @param x - 表示采样数据点向量

  4. * @param j - 表示对对应的第J个分类的概率密度分布

  5. * @return - 返回概率密度分布可能性值

  6. */

  7. public double getProbability(double[] x, int j)

  8. {

  9.    double p = 1;

  10.    for (int d = 0; d < dimNum; d++)

  11.    {

  12.        p *= 1 / Math.sqrt(2 * 3.14159 * m_vars[j][d]);

  13.        p *= Math.exp(-0.5 * (x[d] - m_means[j][d]) * (x[d] - m_means[j][d]) / m_vars[j][d]);

  14.    }

  15.    return p;

  16. }

参数初始化步骤

  1. private void initParameters(double[] data) {

  2.    // 随机方法初始化均值

  3.    int size = data.length;

  4.    for (int i = 0; i < mixNum; i++)

  5.    {

  6.        for (int d = 0; d < dimNum; d++)

  7.        {

  8.            m_means[i][d] = data[(int)(Math.random()*size)];

  9.        }

  10.    }

  11.    // 根据均值获取分类

  12.    int[] types = new int[size];

  13.    for (int k = 0; k < size; k++)

  14.    {

  15.        double max = 0;

  16.        for (int i = 0; i < mixNum; i++)

  17.        {

  18.            double v = 0;

  19.            for(int j=0;j<dimNum;j++) {

  20.                v += Math.abs(data[k*dimNum+j] - m_means[i][j]);

  21.            }

  22.            if(v > max) {

  23.                max = v;

  24.                types[k] = i;

  25.            }

  26.        }

  27.    }

  28.    double[] counts = new double[mixNum];

  29.    for(int i=0; i<types.length; i++) {

  30.        counts[types[i]]++;

  31.    }

  32.    // 计算先验概率权重

  33.    for (int i = 0; i < mixNum; i++)

  34.    {

  35.        weights[i] = counts[i] / size;

  36.    }

  37.    // 计算每个分类的方差

  38.    int label = -1;

  39.    int[] Label = new int[size];

  40.    double[] overMeans = new double[dimNum];

  41.    double[] x = new double[dimNum];

  42.    for (int i = 0; i < size; i++)

  43.    {

  44.        for(int j=0;j<dimNum;j++)

  45.            x[j]=data[i*dimNum+j];

  46.        label=Label[i];

  47.        // Count each Gaussian

  48.        counts[label]++;

  49.        for (int d = 0; d < dimNum; d++)

  50.        {

  51.            m_vars[label][d] += (x[d] - m_means[types[i]][d]) * (x[d] - m_means[types[i]][d]);

  52.        }

  53.        // Count the overall mean and variance.

  54.        for (int d = 0; d < dimNum; d++)

  55.        {

  56.            overMeans[d] += x[d];

  57.            m_minVars[d] += x[d] * x[d];

  58.        }

  59.    }

  60.    // Compute the overall variance (* 0.01) as the minimum variance.

  61.    for (int d = 0; d < dimNum; d++)

  62.    {

  63.        overMeans[d] /= size;

  64.        m_minVars[d] = Math.max(MIN_VAR, 0.01 * (m_minVars[d] / size - overMeans[d] * overMeans[d]));

  65.    }

  66.    // Initialize each Gaussian.

  67.    for (int i = 0; i < mixNum; i++)

  68.    {

  69.        if (weights[i] > 0)

  70.        {

  71.            for (int d = 0; d < dimNum; d++)

  72.            {

  73.                m_vars[i][d] = m_vars[i][d] / counts[i];

  74.                // A minimum variance for each dimension is required.

  75.                if (m_vars[i][d] < m_minVars[d])

  76.                {

  77.                    m_vars[i][d] = m_minVars[d];

  78.                }

  79.            }

  80.        }

  81.    }

  82. }

E-Step

  1. for (int k = 0; k < size; k++)

  2. {

  3.    for(int j=0;j<dimNum;j++)

  4.        x[j]=data[k*dimNum+j];

  5.    double p = getProbability(x); // 总的概率密度分布

  6.    DataNode dn = new DataNode(x);

  7.    dn.index = k;

  8.    cList.add(dn);

  9.    double maxp = 0;

  10.    for (int j = 0; j < mixNum; j++)

  11.    {

  12.        double pj = getProbability(x, j) * weights[j] / p; // 每个分类的概率密度分布百分比

  13.        if(maxp < pj) {

  14.            maxp = pj;

  15.            dn.cindex = j;

  16.        }

  17.        next_weights[j] += pj; // 得到后验概率

  18.        for (int d = 0; d < dimNum; d++)

  19.        {

  20.            next_means[j][d] += pj * x[d];

  21.            next_vars[j][d] += pj* x[d] * x[d];

  22.        }

  23.    }

  24.    currL += (p > 1E-20) ? Math.log10(p) : -20;

  25. }

M-Step

  1. for (int j = 0; j < mixNum; j++)

  2. {

  3.    weights[j] = next_weights[j] / size;

  4.    if (weights[j] > 0)

  5.    {

  6.        for (int d = 0; d < dimNum; d++)

  7.        {

  8.            m_means[j][d] = next_means[j][d] / next_weights[j];

  9.            m_vars[j][d] = next_vars[j][d] / next_weights[j] - m_means[j][d] * m_means[j][d];

  10.            if (m_vars[j][d] < m_minVars[d])

  11.            {

  12.                m_vars[j][d] = m_minVars[d];

  13.            }

  14.        }

  15.    }

  16. }

这里初始中心均值的方法我是通过随机数来实现,GMM算法运行结果跟初始化有很大关系,常见初始化中心点的方法是通过K-Means来计算出中心点。大家可以尝试修改代码基于K-Means初始化参数,我之所以选择随机参数初始,主要是为了方便大家理解!

完整源代码可以从作者博客获取:

http://blog.csdn.net/jia20003/article/details/72771737


世间最容易的事是坚持。

最难的事也是坚持!


关注【OpenCV学堂】

长按或者扫描二维码即可关注

+OpenCV学习群 376281510

进群暗号:OpenCV


    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存