注册 登录  
 加关注
   显示下一条  |  关闭
温馨提示!由于新浪微博认证机制调整,您的新浪微博帐号绑定已过期,请重新绑定!立即重新绑定新浪微博》  |  关闭

BeyondEgo

Welcome to Winsolider's yard! 超越自我,谁与争锋?

 
 
 

日志

 
 
关于我

本博为记事、畅聊、交友博客,邀你共同探讨人生、探讨成长,广交天下有志之士!愿与有相同兴趣爱好的你,共同学习、一起成长、收获喜悦!

网易考拉推荐

TLD源码理解之TLD.cpp(二)  

2013-05-20 15:53:58|  分类: TLD算法 |  标签: |举报 |字号 订阅

  下载LOFTER 我的照片书  |
  1.   
  2. //网格均匀撒点,box共10*10=100个特征点   
  3. void TLD::bbPoints(vector<cv::Point2f>& points, const BoundingBox& bb){  
  4.   int max_pts=10;  
  5.   int margin_h=0; //采样边界  
  6.   int margin_v=0;  
  7.   //网格均匀撒点   
  8.   int stepx = ceil((bb.width-2*margin_h)/max_pts);  //ceil返回大于或者等于指定表达式的最小整数  
  9.   int stepy = ceil((bb.height-2*margin_v)/max_pts);  
  10.   //网格均匀撒点,box共10*10=100个特征点   
  11.   for (int y=bb.y+margin_v; y<bb.y+bb.height-margin_v; y+=stepy){  
  12.       for (int x=bb.x+margin_h;x<bb.x+bb.width-margin_h;x+=stepx){  
  13.           points.push_back(Point2f(x,y));  
  14.       }  
  15.   }  
  16.   
  17. //利用剩下的这不到一半的跟踪点输入来预测bounding box在当前帧的位置和大小  
  18. void TLD::bbPredict(const vector<cv::Point2f>& points1,const vector<cv::Point2f>& points2,  
  19.                     const BoundingBox& bb1,BoundingBox& bb2)    {  
  20.   int npoints = (int)points1.size();  
  21.   vector<float> xoff(npoints);  //位移  
  22.   vector<float> yoff(npoints);  
  23.   printf("tracked points : %d\n", npoints);  
  24.   for (int i=0;i<npoints;i++){   //计算每个特征点在两帧之间的位移  
  25.       xoff[i]=points2[i].x - points1[i].x;  
  26.       yoff[i]=points2[i].y - points1[i].y;  
  27.   }  
  28.   float dx = median(xoff);   //计算位移的中值  
  29.   float dy = median(yoff);  
  30.   float s;  
  31.   //计算bounding box尺度scale的变化:通过计算 当前特征点相互间的距离 与 先前(上一帧)特征点相互间的距离 的  
  32.   //比值,以比值的中值作为尺度的变化因子   
  33.   if (npoints>1){  
  34.       vector<float> d;  
  35.       d.reserve(npoints*(npoints-1)/2);  //等差数列求和:1+2+...+(npoints-1)  
  36.       for (int i=0;i<npoints;i++){  
  37.           for (int j=i+1;j<npoints;j++){  
  38.           //计算 当前特征点相互间的距离 与 先前(上一帧)特征点相互间的距离 的比值(位移用绝对值)  
  39.               d.push_back(norm(points2[i]-points2[j])/norm(points1[i]-points1[j]));  
  40.           }  
  41.       }  
  42.       s = median(d);  
  43.   }  
  44.   else {  
  45.       s = 1.0;  
  46.   }  
  47.   
  48.   float s1 = 0.5*(s-1)*bb1.width;  
  49.   float s2 = 0.5*(s-1)*bb1.height;  
  50.   printf("s= %f s1= %f s2= %f \n", s, s1, s2);  
  51.     
  52.   //得到当前bounding box的位置与大小信息   
  53.   //当前box的x坐标 = 前一帧box的x坐标 + 全部特征点位移的中值(可理解为box移动近似的位移) - 当前box宽的一半  
  54.   bb2.x = round( bb1.x + dx - s1);  
  55.   bb2.y = round( bb1.y + dy -s2);  
  56.   bb2.width = round(bb1.width*s);  
  57.   bb2.height = round(bb1.height*s);  
  58.   printf("predicted bb: %d %d %d %d\n",bb2.x,bb2.y,bb2.br().x,bb2.br().y);  
  59. }  
  60.   
  61. void TLD::detect(const cv::Mat& frame){  
  62.   //cleaning   
  63.   dbb.clear();  
  64.   dconf.clear();  
  65.   dt.bb.clear();  
  66.   //GetTickCount返回从操作系统启动到现在所经过的时间  
  67.   double t = (double)getTickCount();  
  68.   Mat img(frame.rows, frame.cols, CV_8U);  
  69.   integral(frame,iisum,iisqsum);   //计算frame的积分图   
  70.   GaussianBlur(frame,img,Size(9,9),1.5);  //高斯模糊,去噪?  
  71.   int numtrees = classifier.getNumStructs();  
  72.   float fern_th = classifier.getFernTh(); //getFernTh()返回thr_fern; 集合分类器的分类阈值  
  73.   vector <int> ferns(10);  
  74.   float conf;  
  75.   int a=0;  
  76.   Mat patch;  
  77.   //级联分类器模块一:方差检测模块,利用积分图计算每个待检测窗口的方差,方差大于var阈值(目标patch方差的50%)的,  
  78.   //则认为其含有前景目标   
  79.   for (int i=0; i<grid.size(); i++){  //FIXME: BottleNeck 瓶颈  
  80.       if (getVar(grid[i],iisum,iisqsum) >= var){  //计算每一个扫描窗口的方差  
  81.           a++;  
  82.           //级联分类器模块二:集合分类器检测模块  
  83.           patch = img(grid[i]);  
  84.           classifier.getFeatures(patch,grid[i].sidx,ferns); //得到该patch特征(13位的二进制代码)  
  85.           conf = classifier.measure_forest(ferns);  //计算该特征值对应的后验概率累加值  
  86.           tmp.conf[i]=conf;   //Detector data中定义TempStruct tmp;   
  87.           tmp.patt[i]=ferns;  
  88.           //如果集合分类器的后验概率的平均值大于阈值fern_th(由训练得到),就认为含有前景目标  
  89.           if (conf > numtrees*fern_th){    
  90.               dt.bb.push_back(i);  //将通过以上两个检测模块的扫描窗口记录在detect structure中  
  91.           }  
  92.       }  
  93.       else  
  94.         tmp.conf[i]=0.0;  
  95.   }  
  96.   int detections = dt.bb.size();  
  97.   printf("%d Bounding boxes passed the variance filter\n",a);  
  98.   printf("%d Initial detection from Fern Classifier\n", detections);  
  99.     
  100.   //如果通过以上两个检测模块的扫描窗口数大于100个,则只取后验概率大的前100个  
  101.   if (detections>100){   //CComparator(tmp.conf)指定比较方式???  
  102.       nth_element(dt.bb.begin(), dt.bb.begin()+100, dt.bb.end(), CComparator(tmp.conf));  
  103.       dt.bb.resize(100);  
  104.       detections=100;  
  105.   }  
  106. //  for (int i=0;i<detections;i++){  
  107. //        drawBox(img,grid[dt.bb[i]]);   
  108. //    }   
  109. //  imshow("detections",img);   
  110.   if (detections==0){  
  111.         detected=false;  
  112.         return;  
  113.       }  
  114.   printf("Fern detector made %d detections ",detections);  
  115.     
  116.   //两次使用getTickCount(),然后再除以getTickFrequency(),计算出来的是以秒s为单位的时间(opencv 2.0 以前是ms)  
  117.   t=(double)getTickCount()-t;    
  118.   printf("in %gms\n", t*1000/getTickFrequency());  //打印以上代码运行使用的毫秒数  
  119.     
  120.   //  Initialize detection structure  
  121.   dt.patt = vector<vector<int> >(detections,vector<int>(10,0));        //  Corresponding codes of the Ensemble Classifier  
  122.   dt.conf1 = vector<float>(detections);                                //  Relative Similarity (for final nearest neighbour classifier)  
  123.   dt.conf2 =vector<float>(detections);                                 //  Conservative Similarity (for integration with tracker)  
  124.   dt.isin = vector<vector<int> >(detections,vector<int>(3,-1));        //  Detected (isin=1) or rejected (isin=0) by nearest neighbour classifier  
  125.   dt.patch = vector<Mat>(detections,Mat(patch_size,patch_size,CV_32F));//  Corresponding patches  
  126.   int idx;  
  127.   Scalar mean, stdev;  
  128.   float nn_th = classifier.getNNTh();  
  129.   //级联分类器模块三:最近邻分类器检测模块   
  130.   for (int i=0;i<detections;i++){                                         //  for every remaining detection  
  131.       idx=dt.bb[i];                                                       //  Get the detected bounding box index  
  132.       patch = frame(grid[idx]);  
  133.       getPattern(patch,dt.patch[i],mean,stdev);                //  Get pattern within bounding box  
  134.       //计算图像片pattern到在线模型M的相关相似度和保守相似度  
  135.       classifier.NNConf(dt.patch[i],dt.isin[i],dt.conf1[i],dt.conf2[i]);  //  Evaluate nearest neighbour classifier  
  136.       dt.patt[i]=tmp.patt[idx];  
  137.       //printf("Testing feature %d, conf:%f isin:(%d|%d|%d)\n",i,dt.conf1[i],dt.isin[i][0],dt.isin[i][1],dt.isin[i][2]);  
  138.       //相关相似度大于阈值,则认为含有前景目标   
  139.       if (dt.conf1[i]>nn_th){                                               //  idx = dt.conf1 > tld.model.thr_nn; % get all indexes that made it through the nearest neighbour  
  140.           dbb.push_back(grid[idx]);                                         //  BB    = dt.bb(:,idx); % bounding boxes  
  141.           dconf.push_back(dt.conf2[i]);                                     //  Conf  = dt.conf2(:,idx); % conservative confidences  
  142.       }  
  143.   }  
  144.   //打印检测到的可能存在目标的扫描窗口数(可以通过三个级联检测器的)  
  145.   if (dbb.size()>0){  
  146.       printf("Found %d NN matches\n",(int)dbb.size());  
  147.       detected=true;  
  148.   }  
  149.   else{  
  150.       printf("No NN matches found.\n");  
  151.       detected=false;  
  152.   }  
  153. }  
  154.   
  155. //作者已经用python脚本../datasets/evaluate_vis.py来完成算法评估功能,具体见README  
  156. void TLD::evaluate(){  
  157. }  
  158.   
  159. void TLD::learn(const Mat& img){  
  160.   printf("[Learning] ");  
  161.     
  162.   ///Check consistency   
  163.   //检测一致性   
  164.   BoundingBox bb;  
  165.   bb.x = max(lastbox.x,0);  
  166.   bb.y = max(lastbox.y,0);  
  167.   bb.width = min(min(img.cols-lastbox.x,lastbox.width),min(lastbox.width,lastbox.br().x));  
  168.   bb.height = min(min(img.rows-lastbox.y,lastbox.height),min(lastbox.height,lastbox.br().y));  
  169.   Scalar mean, stdev;  
  170.   Mat pattern;  
  171.   //归一化img(bb)对应的patch的size(放缩至patch_size = 15*15),存入pattern  
  172.   getPattern(img(bb), pattern, mean, stdev);  
  173.   vector<int> isin;  
  174.   float dummy, conf;  
  175.   //计算输入图像片(跟踪器的目标box)与在线模型之间的相关相似度conf   
  176.   classifier.NNConf(pattern,isin,conf,dummy);  
  177.   if (conf<0.5) {   //如果相似度太小了,就不训练  
  178.       printf("Fast change..not training\n");  
  179.       lastvalid =false;  
  180.       return;  
  181.   }  
  182.   if (pow(stdev.val[0], 2)< var){  //如果方差太小了,也不训练  
  183.       printf("Low variance..not training\n");  
  184.       lastvalid=false;  
  185.       return;  
  186.   }  
  187.   if(isin[2]==1){   //如果被被识别为负样本,也不训练  
  188.       printf("Patch in negative data..not traing");  
  189.       lastvalid=false;  
  190.       return;  
  191.   }  
  192.     
  193.   /// Data generation  样本产生   
  194.   for (int i=0;i<grid.size();i++){   //计算所有的扫描窗口与目标box的重叠度  
  195.       grid[i].overlap = bbOverlap(lastbox, grid[i]);  
  196.   }  
  197.   //集合分类器   
  198.   vector<pair<vector<int>,int> > fern_examples;  
  199.   good_boxes.clear();    
  200.   bad_boxes.clear();  
  201.   //此函数根据传入的lastbox,在整帧图像中的全部窗口中寻找与该lastbox距离最小(即最相似,  
  202.   //重叠度最大)的num_closest_update个窗口,然后把这些窗口 归入good_boxes容器(只是把网格数组的索引存入)  
  203.   //同时,把重叠度小于0.2的,归入 bad_boxes 容器   
  204.   getOverlappingBoxes(lastbox, num_closest_update);  
  205.   if (good_boxes.size()>0)  
  206.     generatePositiveData(img, num_warps_update);  //用仿射模型产生正样本(类似于第一帧的方法,但只产生10*10=100个)  
  207.   else{  
  208.     lastvalid = false;  
  209.     printf("No good boxes..Not training");  
  210.     return;  
  211.   }  
  212.   fern_examples.reserve(pX.size() + bad_boxes.size());  
  213.   fern_examples.assign(pX.begin(), pX.end());  
  214.   int idx;  
  215.   for (int i=0;i<bad_boxes.size();i++){  
  216.       idx=bad_boxes[i];  
  217.       if (tmp.conf[idx]>=1){   //加入负样本,相似度大于1??相似度不是出于0和1之间吗?  
  218.           fern_examples.push_back(make_pair(tmp.patt[idx],0));  
  219.       }  
  220.   }  
  221.   //最近邻分类器   
  222.   vector<Mat> nn_examples;  
  223.   nn_examples.reserve(dt.bb.size()+1);  
  224.   nn_examples.push_back(pEx);  
  225.   for (int i=0;i<dt.bb.size();i++){  
  226.       idx = dt.bb[i];  
  227.       if (bbOverlap(lastbox,grid[idx]) < bad_overlap)  
  228.         nn_examples.push_back(dt.patch[i]);  
  229.   }  
  230.     
  231.   /// Classifiers update  分类器训练   
  232.   classifier.trainF(fern_examples,2);  
  233.   classifier.trainNN(nn_examples);  
  234.   classifier.show(); //把正样本库(在线模型)包含的所有正样本显示在窗口上  
  235. }  
  236.  
  237.   
  238. //检测器采用扫描窗口的策略   
  239. //此函数根据传入的box(目标边界框)在传入的图像中构建全部的扫描窗口,并计算每个窗口与box的重叠度  
  240. void TLD::buildGrid(const cv::Mat& img, const cv::Rect& box){  
  241.   const float SHIFT = 0.1;  //扫描窗口步长为 宽高的 10%  
  242.   //尺度缩放系数为1.2 (0.16151*1.2=0.19381),共21种尺度变换  
  243.   const float SCALES[] = {0.16151,0.19381,0.23257,0.27908,0.33490,0.40188,0.48225,  
  244.                           0.57870,0.69444,0.83333,1,1.20000,1.44000,1.72800,  
  245.                           2.07360,2.48832,2.98598,3.58318,4.29982,5.15978,6.19174};  
  246.   int width, height, min_bb_side;  
  247.   //Rect bbox;   
  248.   BoundingBox bbox;  
  249.   Size scale;  
  250.   int sc=0;  
  251.     
  252.   for (int s=0; s < 21; s++){  
  253.     width = round(box.width*SCALES[s]);  
  254.     height = round(box.height*SCALES[s]);  
  255.     min_bb_side = min(height,width);  //bounding box最短的边  
  256.     //由于图像片(min_win 为15x15像素)是在bounding box中采样得到的,所以box必须比min_win要大  
  257.     //另外,输入的图像肯定得比 bounding box 要大了  
  258.     if (min_bb_side < min_win || width > img.cols || height > img.rows)  
  259.       continue;  
  260.     scale.width = width;  
  261.     scale.height = height;  
  262.     //push_back在vector类中作用为在vector尾部加入一个数据  
  263.     //scales在类TLD中定义:std::vector<cv::Size> scales;  
  264.     scales.push_back(scale);  //把该尺度的窗口存入scales容器,避免在扫描时计算,加快检测速度  
  265.     for (int y=1; y<img.rows-height; y+=round(SHIFT*min_bb_side)){  //按步长移动窗口  
  266.       for (int x=1; x<img.cols-width; x+=round(SHIFT*min_bb_side)){  
  267.         bbox.x = x;  
  268.         bbox.y = y;  
  269.         bbox.width = width;  
  270.         bbox.height = height;  
  271.         //判断传入的bounding box(目标边界框)与 传入图像中的此时窗口的 重叠度,  
  272.         //以此来确定该图像窗口是否含有目标   
  273.         bbox.overlap = bbOverlap(bbox, BoundingBox(box));  
  274.         bbox.sidx = sc;  //属于第几个尺度   
  275.         //grid在类TLD中定义:std::vector<BoundingBox> grid;  
  276.         //把本位置和本尺度的扫描窗口存入grid容器   
  277.         grid.push_back(bbox);  
  278.       }  
  279.     }  
  280.     sc++;  
  281.   }  
  282. }  
  283.   
  284. //此函数计算两个bounding box 的重叠度   
  285. //重叠度定义为 两个box的交集 与 它们的并集 的比   
  286. float TLD::bbOverlap(const BoundingBox& box1, const BoundingBox& box2){  
  287.   //先判断坐标,假如它们都没有重叠的地方,就直接返回0   
  288.   if (box1.x > box2.x + box2.width) { return 0.0; }  
  289.   if (box1.y > box2.y + box2.height) { return 0.0; }  
  290.   if (box1.x + box1.width < box2.x) { return 0.0; }  
  291.   if (box1.y + box1.height < box2.y) { return 0.0; }  
  292.   
  293.   float colInt =  min(box1.x + box1.width, box2.x + box2.width) - max(box1.x, box2.x);  
  294.   float rowInt =  min(box1.y + box1.height, box2.y + box2.height) - max(box1.y, box2.y);  
  295.   
  296.   float intersection = colInt * rowInt;  
  297.   float area1 = box1.width * box1.height;  
  298.   float area2 = box2.width * box2.height;  
  299.   return intersection / (area1 + area2 - intersection);  
  300. }  
  301.   


《TLD源码理解之TLD.cpp(三)》 
  评论这张
 
阅读(812)| 评论(0)
推荐 转载

历史上的今天

在LOFTER的更多文章

评论

<#--最新日志,群博日志--> <#--推荐日志--> <#--引用记录--> <#--博主推荐--> <#--随机阅读--> <#--首页推荐--> <#--历史上的今天--> <#--被推荐日志--> <#--上一篇,下一篇--> <#-- 热度 --> <#-- 网易新闻广告 --> <#--右边模块结构--> <#--评论模块结构--> <#--引用模块结构--> <#--博主发起的投票-->
 
 
 
 
 
 
 
 
 
 
 
 
 
 

页脚

网易公司版权所有 ©1997-2017