diff options
Diffstat (limited to 'third_party/aom/av1/encoder/rdopt.c')
-rw-r--r-- | third_party/aom/av1/encoder/rdopt.c | 2289 |
1 files changed, 1239 insertions, 1050 deletions
diff --git a/third_party/aom/av1/encoder/rdopt.c b/third_party/aom/av1/encoder/rdopt.c index 6f4fced871..fef6d28755 100644 --- a/third_party/aom/av1/encoder/rdopt.c +++ b/third_party/aom/av1/encoder/rdopt.c @@ -58,8 +58,11 @@ #include "av1/encoder/tokenize.h" #include "av1/encoder/tx_prune_model_weights.h" +#define DNN_BASED_RD_INTERP_FILTER 0 + // Set this macro as 1 to collect data about tx size selection. #define COLLECT_TX_SIZE_DATA 0 + #if COLLECT_TX_SIZE_DATA static const char av1_tx_size_data_output_file[] = "tx_size_data.txt"; #endif @@ -916,9 +919,9 @@ static double od_compute_dist(uint16_t *x, uint16_t *y, int bsize_w, int activity_masking = 0; int i, j; - DECLARE_ALIGNED(16, od_coeff, e[MAX_TX_SQUARE]); - DECLARE_ALIGNED(16, od_coeff, tmp[MAX_TX_SQUARE]); - DECLARE_ALIGNED(16, od_coeff, e_lp[MAX_TX_SQUARE]); + DECLARE_ALIGNED(16, od_coeff, e[MAX_SB_SQUARE]); + DECLARE_ALIGNED(16, od_coeff, tmp[MAX_SB_SQUARE]); + DECLARE_ALIGNED(16, od_coeff, e_lp[MAX_SB_SQUARE]); for (i = 0; i < bsize_h; i++) { for (j = 0; j < bsize_w; j++) { e[i * bsize_w + j] = x[i * bsize_w + j] - y[i * bsize_w + j]; @@ -944,9 +947,9 @@ static double od_compute_dist_diff(uint16_t *x, int16_t *e, int bsize_w, int activity_masking = 0; - DECLARE_ALIGNED(16, uint16_t, y[MAX_TX_SQUARE]); - DECLARE_ALIGNED(16, od_coeff, tmp[MAX_TX_SQUARE]); - DECLARE_ALIGNED(16, od_coeff, e_lp[MAX_TX_SQUARE]); + DECLARE_ALIGNED(16, uint16_t, y[MAX_SB_SQUARE]); + DECLARE_ALIGNED(16, od_coeff, tmp[MAX_SB_SQUARE]); + DECLARE_ALIGNED(16, od_coeff, e_lp[MAX_SB_SQUARE]); int i, j; for (i = 0; i < bsize_h; i++) { for (j = 0; j < bsize_w; j++) { @@ -975,8 +978,8 @@ int64_t av1_dist_8x8(const AV1_COMP *const cpi, const MACROBLOCK *x, int i, j; const MACROBLOCKD *xd = &x->e_mbd; - DECLARE_ALIGNED(16, uint16_t, orig[MAX_TX_SQUARE]); - DECLARE_ALIGNED(16, uint16_t, rec[MAX_TX_SQUARE]); + DECLARE_ALIGNED(16, uint16_t, orig[MAX_SB_SQUARE]); + DECLARE_ALIGNED(16, uint16_t, rec[MAX_SB_SQUARE]); assert(bsw >= 8); assert(bsh >= 8); @@ -1068,8 +1071,8 @@ static int64_t dist_8x8_diff(const MACROBLOCK *x, const uint8_t *src, int i, j; const MACROBLOCKD *xd = &x->e_mbd; - DECLARE_ALIGNED(16, uint16_t, orig[MAX_TX_SQUARE]); - DECLARE_ALIGNED(16, int16_t, diff16[MAX_TX_SQUARE]); + DECLARE_ALIGNED(16, uint16_t, orig[MAX_SB_SQUARE]); + DECLARE_ALIGNED(16, int16_t, diff16[MAX_SB_SQUARE]); assert(bsw >= 8); assert(bsh >= 8); @@ -1112,7 +1115,7 @@ static int64_t dist_8x8_diff(const MACROBLOCK *x, const uint8_t *src, d = (int64_t)od_compute_dist_diff(orig, diff16, bsw, bsh, qindex); } else if (x->tune_metric == AOM_TUNE_CDEF_DIST) { int coeff_shift = AOMMAX(xd->bd - 8, 0); - DECLARE_ALIGNED(16, uint16_t, dst16[MAX_TX_SQUARE]); + DECLARE_ALIGNED(16, uint16_t, dst16[MAX_SB_SQUARE]); for (i = 0; i < bsh; i++) { for (j = 0; j < bsw; j++) { @@ -1146,11 +1149,15 @@ static void get_energy_distribution_fine(const AV1_COMP *cpi, BLOCK_SIZE bsize, const int bh = block_size_high[bsize]; unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; - const int f_index = bsize - BLOCK_16X16; - if (f_index < 0) { - const int w_shift = bw == 8 ? 1 : 2; - const int h_shift = bh == 8 ? 1 : 2; - if (cpi->common.use_highbitdepth) { + if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) { + // Special cases: calculate 'esq' values manually, as we don't have 'vf' + // functions for the 16 (very small) sub-blocks of this block. + const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3; + const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3; + assert(bw <= 32); + assert(bh <= 32); + assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15); + if (cpi->common.seq_params.use_highbitdepth) { const uint16_t *src16 = CONVERT_TO_SHORTPTR(src); const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst); for (int i = 0; i < bh; ++i) @@ -1168,43 +1175,49 @@ static void get_energy_distribution_fine(const AV1_COMP *cpi, BLOCK_SIZE bsize, (src[j + i * src_stride] - dst[j + i * dst_stride]); } } - } else { - cpi->fn_ptr[f_index].vf(src, src_stride, dst, dst_stride, &esq[0]); - cpi->fn_ptr[f_index].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride, + } else { // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks. + const int f_index = + (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16; + assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL); + const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index; + assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]); + assert(block_size_high[bsize] == 4 * block_size_high[subsize]); + cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]); + cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride, &esq[1]); - cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride, + cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride, &esq[2]); - cpi->fn_ptr[f_index].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, + cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, dst_stride, &esq[3]); src += bh / 4 * src_stride; dst += bh / 4 * dst_stride; - cpi->fn_ptr[f_index].vf(src, src_stride, dst, dst_stride, &esq[4]); - cpi->fn_ptr[f_index].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride, + cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]); + cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride, &esq[5]); - cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride, + cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride, &esq[6]); - cpi->fn_ptr[f_index].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, + cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, dst_stride, &esq[7]); src += bh / 4 * src_stride; dst += bh / 4 * dst_stride; - cpi->fn_ptr[f_index].vf(src, src_stride, dst, dst_stride, &esq[8]); - cpi->fn_ptr[f_index].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride, + cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]); + cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride, &esq[9]); - cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride, + cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride, &esq[10]); - cpi->fn_ptr[f_index].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, + cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, dst_stride, &esq[11]); src += bh / 4 * src_stride; dst += bh / 4 * dst_stride; - cpi->fn_ptr[f_index].vf(src, src_stride, dst, dst_stride, &esq[12]); - cpi->fn_ptr[f_index].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride, + cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]); + cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride, &esq[13]); - cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride, + cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride, &esq[14]); - cpi->fn_ptr[f_index].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, + cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4, dst_stride, &esq[15]); } @@ -1371,16 +1384,27 @@ static void get_energy_distribution_finer(const int16_t *diff, int stride, unsigned int esq[256]; const int w_shift = bw <= 8 ? 0 : 1; const int h_shift = bh <= 8 ? 0 : 1; - const int esq_w = bw <= 8 ? bw : bw / 2; - const int esq_h = bh <= 8 ? bh : bh / 2; + const int esq_w = bw >> w_shift; + const int esq_h = bh >> h_shift; const int esq_sz = esq_w * esq_h; int i, j; memset(esq, 0, esq_sz * sizeof(esq[0])); - for (i = 0; i < bh; i++) { - unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w; - const int16_t *cur_diff_row = diff + i * stride; - for (j = 0; j < bw; j++) { - cur_esq_row[j >> w_shift] += cur_diff_row[j] * cur_diff_row[j]; + if (w_shift) { + for (i = 0; i < bh; i++) { + unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w; + const int16_t *cur_diff_row = diff + i * stride; + for (j = 0; j < bw; j += 2) { + cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] + + cur_diff_row[j + 1] * cur_diff_row[j + 1]); + } + } + } else { + for (i = 0; i < bh; i++) { + unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w; + const int16_t *cur_diff_row = diff + i * stride; + for (j = 0; j < bw; j++) { + cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j]; + } } } @@ -1558,9 +1582,9 @@ static const float *prune_2D_adaptive_thresholds[] = { NULL, }; -static int prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size, - int blk_row, int blk_col, TxSetType tx_set_type, - TX_TYPE_PRUNE_MODE prune_mode) { +static uint16_t prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size, + int blk_row, int blk_col, TxSetType tx_set_type, + TX_TYPE_PRUNE_MODE prune_mode) { static const int tx_type_table_2D[16] = { DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT, ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST, @@ -1636,7 +1660,7 @@ static int prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size, const float score_thresh = prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness - 1]; - int prune_bitmask = 0; + uint16_t prune_bitmask = 0; for (int i = 0; i < 16; i++) { if (scores_2D[i] < score_thresh && i != max_score_i) prune_bitmask |= (1 << tx_type_table_2D[i]); @@ -1644,9 +1668,27 @@ static int prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size, return prune_bitmask; } +// ((prune >> vtx_tab[tx_type]) & 1) +static const uint16_t prune_v_mask[] = { + 0x0000, 0x0425, 0x108a, 0x14af, 0x4150, 0x4575, 0x51da, 0x55ff, + 0xaa00, 0xae25, 0xba8a, 0xbeaf, 0xeb50, 0xef75, 0xfbda, 0xffff, +}; + +// ((prune >> (htx_tab[tx_type] + 8)) & 1) +static const uint16_t prune_h_mask[] = { + 0x0000, 0x0813, 0x210c, 0x291f, 0x80e0, 0x88f3, 0xa1ec, 0xa9ff, + 0x5600, 0x5e13, 0x770c, 0x7f1f, 0xd6e0, 0xdef3, 0xf7ec, 0xffff, +}; + +static INLINE uint16_t gen_tx_search_prune_mask(int tx_search_prune) { + uint8_t prune_v = tx_search_prune & 0x0F; + uint8_t prune_h = (tx_search_prune >> 8) & 0x0F; + return (prune_v_mask[prune_v] & prune_h_mask[prune_h]); +} + static void prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x, const MACROBLOCKD *const xd, int tx_set_type) { - av1_zero(x->tx_search_prune); + x->tx_search_prune[tx_set_type] = 0; x->tx_split_prune_flag = 0; const MB_MODE_INFO *mbmi = xd->mi[0]; if (!is_inter_block(mbmi) || cpi->sf.tx_type_search.prune_mode == NO_PRUNE || @@ -1656,24 +1698,24 @@ static void prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x, int tx_set = ext_tx_set_index[1][tx_set_type]; assert(tx_set >= 0); const int *tx_set_1D = ext_tx_used_inter_1D[tx_set]; + int prune = 0; switch (cpi->sf.tx_type_search.prune_mode) { case NO_PRUNE: return; case PRUNE_ONE: if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return; - x->tx_search_prune[tx_set_type] = prune_one_for_sby(cpi, bsize, x, xd); + prune = prune_one_for_sby(cpi, bsize, x, xd); + x->tx_search_prune[tx_set_type] = gen_tx_search_prune_mask(prune); break; case PRUNE_TWO: if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) { if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return; - x->tx_search_prune[tx_set_type] = - prune_two_for_sby(cpi, bsize, x, xd, 0, 1); - } - if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) { - x->tx_search_prune[tx_set_type] = - prune_two_for_sby(cpi, bsize, x, xd, 1, 0); + prune = prune_two_for_sby(cpi, bsize, x, xd, 0, 1); + } else if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) { + prune = prune_two_for_sby(cpi, bsize, x, xd, 1, 0); + } else { + prune = prune_two_for_sby(cpi, bsize, x, xd, 1, 1); } - x->tx_search_prune[tx_set_type] = - prune_two_for_sby(cpi, bsize, x, xd, 1, 1); + x->tx_search_prune[tx_set_type] = gen_tx_search_prune_mask(prune); break; case PRUNE_2D_ACCURATE: case PRUNE_2D_FAST: break; @@ -1681,17 +1723,6 @@ static void prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x, } } -static int do_tx_type_search(TX_TYPE tx_type, int prune, - TX_TYPE_PRUNE_MODE mode) { - // TODO(sarahparker) implement for non ext tx - if (mode >= PRUNE_2D_ACCURATE) { - return !((prune >> tx_type) & 1); - } else { - return !(((prune >> vtx_tab[tx_type]) & 1) | - ((prune >> (htx_tab[tx_type] + 8)) & 1)); - } -} - static void model_rd_from_sse(const AV1_COMP *const cpi, const MACROBLOCKD *const xd, BLOCK_SIZE bsize, int plane, int64_t sse, int *rate, @@ -1764,9 +1795,11 @@ static void model_rd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bsize, for (plane = plane_from; plane <= plane_to; ++plane) { struct macroblock_plane *const p = &x->plane[plane]; struct macroblockd_plane *const pd = &xd->plane[plane]; - const BLOCK_SIZE bs = + const BLOCK_SIZE plane_bsize = get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y); - unsigned int sse; + const int bw = block_size_wide[plane_bsize]; + const int bh = block_size_high[plane_bsize]; + int64_t sse; int rate; int64_t dist; @@ -1774,14 +1807,14 @@ static void model_rd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bsize, // TODO(geza): Write direct sse functions that do not compute // variance as well. - cpi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, - &sse); + sse = aom_sum_squares_2d_i16(p->src_diff, bw, bw, bh); + sse = ROUND_POWER_OF_TWO(sse, (xd->bd - 8) * 2); - if (plane == 0) x->pred_sse[ref] = sse; + if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX); total_sse += sse; - model_rd_from_sse(cpi, xd, bs, plane, sse, &rate, &dist); + model_rd_from_sse(cpi, xd, plane_bsize, plane, sse, &rate, &dist); rate_sum += rate; dist_sum += dist; @@ -1934,7 +1967,8 @@ static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x, static INLINE int64_t pixel_diff_dist(const MACROBLOCK *x, int plane, int blk_row, int blk_col, const BLOCK_SIZE plane_bsize, - const BLOCK_SIZE tx_bsize) { + const BLOCK_SIZE tx_bsize, + int force_sse) { int visible_rows, visible_cols; const MACROBLOCKD *xd = &x->e_mbd; get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL, @@ -1944,13 +1978,17 @@ static INLINE int64_t pixel_diff_dist(const MACROBLOCK *x, int plane, #if CONFIG_DIST_8X8 int txb_height = block_size_high[tx_bsize]; int txb_width = block_size_wide[tx_bsize]; - if (x->using_dist_8x8 && plane == 0 && txb_width >= 8 && txb_height >= 8) { + if (!force_sse && x->using_dist_8x8 && plane == 0 && txb_width >= 8 && + txb_height >= 8) { const int src_stride = x->plane[plane].src.stride; const int src_idx = (blk_row * src_stride + blk_col) << tx_size_wide_log2[0]; + const int diff_idx = (blk_row * diff_stride + blk_col) + << tx_size_wide_log2[0]; const uint8_t *src = &x->plane[plane].src.buf[src_idx]; - return dist_8x8_diff(x, src, src_stride, diff, diff_stride, txb_width, - txb_height, visible_cols, visible_rows, x->qindex); + return dist_8x8_diff(x, src, src_stride, diff + diff_idx, diff_stride, + txb_width, txb_height, visible_cols, visible_rows, + x->qindex); } #endif diff += ((blk_row * diff_stride + blk_col) << tx_size_wide_log2[0]); @@ -2182,10 +2220,14 @@ static void get_2x2_normalized_sses_and_sads( for (int col = 0; col < 2; ++col) { const int16_t *const this_src_diff = src_diff + row * half_height * diff_stride + col * half_width; - sse_norm_arr[row * 2 + col] = - get_sse_norm(this_src_diff, diff_stride, half_width, half_height); - sad_norm_arr[row * 2 + col] = - get_sad_norm(this_src_diff, diff_stride, half_width, half_height); + if (sse_norm_arr) { + sse_norm_arr[row * 2 + col] = + get_sse_norm(this_src_diff, diff_stride, half_width, half_height); + } + if (sad_norm_arr) { + sad_norm_arr[row * 2 + col] = + get_sad_norm(this_src_diff, diff_stride, half_width, half_height); + } } } } else { // use function pointers to calculate stats @@ -2199,28 +2241,35 @@ static void get_2x2_normalized_sses_and_sads( const uint8_t *const this_dst = dst + row * half_height * dst_stride + col * half_width; - unsigned int this_sse; - cpi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst, - dst_stride, &this_sse); - sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half; + if (sse_norm_arr) { + unsigned int this_sse; + cpi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst, + dst_stride, &this_sse); + sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half; + } - const unsigned int this_sad = cpi->fn_ptr[tx_bsize_half].sdf( - this_src, src_stride, this_dst, dst_stride); - sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half; + if (sad_norm_arr) { + const unsigned int this_sad = cpi->fn_ptr[tx_bsize_half].sdf( + this_src, src_stride, this_dst, dst_stride); + sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half; + } } } } } #if CONFIG_COLLECT_RD_STATS -// NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values -// 0: Do not collect any RD stats -// 1: Collect RD stats for transform units -// 2: Collect RD stats for partition units + // NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values + // 0: Do not collect any RD stats + // 1: Collect RD stats for transform units + // 2: Collect RD stats for partition units + +#if CONFIG_COLLECT_RD_STATS == 1 static void PrintTransformUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats, int blk_row, int blk_col, BLOCK_SIZE plane_bsize, - TX_SIZE tx_size, TX_TYPE tx_type) { + TX_SIZE tx_size, TX_TYPE tx_type, + int64_t rd) { if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return; // Generate small sample to restrict output size. @@ -2304,9 +2353,12 @@ static void PrintTransformUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x, fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2], hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]); + fprintf(fout, " %d %" PRId64, x->rdmult, rd); + fprintf(fout, "\n"); fclose(fout); } +#endif // CONFIG_COLLECT_RD_STATS == 1 #if CONFIG_COLLECT_RD_STATS == 2 static void PrintPredictionUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x, @@ -2327,12 +2379,14 @@ static void PrintPredictionUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x, const int plane = 0; struct macroblock_plane *const p = &x->plane[plane]; const struct macroblockd_plane *const pd = &xd->plane[plane]; - const int bw = block_size_wide[plane_bsize]; - const int bh = block_size_high[plane_bsize]; + const int diff_stride = block_size_wide[plane_bsize]; + int bw, bh; + get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw, + &bh); + const int num_samples = bw * bh; const int dequant_shift = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd - 5 : 3; const int q_step = pd->dequant_Q3[1] >> dequant_shift; - const double num_samples = bw * bh; const double rate_norm = (double)rd_stats->rate / num_samples; const double dist_norm = (double)rd_stats->dist / num_samples; @@ -2343,23 +2397,28 @@ static void PrintPredictionUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x, const uint8_t *const src = p->src.buf; const int dst_stride = pd->dst.stride; const uint8_t *const dst = pd->dst.buf; - unsigned int sse; - cpi->fn_ptr[plane_bsize].vf(src, src_stride, dst, dst_stride, &sse); + const int16_t *const src_diff = p->src_diff; + const int shift = (xd->bd - 8); + + int64_t sse = aom_sum_squares_2d_i16(src_diff, diff_stride, bw, bh); + sse = ROUND_POWER_OF_TWO(sse, shift * 2); const double sse_norm = (double)sse / num_samples; const unsigned int sad = cpi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride); - const double sad_norm = (double)sad / num_samples; + const double sad_norm = + (double)sad / (1 << num_pels_log2_lookup[plane_bsize]); fprintf(fout, " %g %g", sse_norm, sad_norm); - const int diff_stride = block_size_wide[plane_bsize]; - const int16_t *const src_diff = p->src_diff; - double sse_norm_arr[4], sad_norm_arr[4]; get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst, dst_stride, src_diff, diff_stride, sse_norm_arr, sad_norm_arr); + if (shift) { + for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift)); + for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift); + } for (int i = 0; i < 4; ++i) { fprintf(fout, " %g", sse_norm_arr[i]); } @@ -2376,7 +2435,8 @@ static void PrintPredictionUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x, const double model_dist_norm = (double)model_dist / num_samples; fprintf(fout, " %g %g", model_rate_norm, model_dist_norm); - const double mean = get_mean(src_diff, diff_stride, bw, bh); + double mean = get_mean(src_diff, diff_stride, bw, bh); + mean /= (1 << shift); double hor_corr, vert_corr; get_horver_correlation(src_diff, diff_stride, bw, bh, &hor_corr, &vert_corr); fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr); @@ -2393,20 +2453,19 @@ static void PrintPredictionUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x, #endif // CONFIG_COLLECT_RD_STATS == 2 #endif // CONFIG_COLLECT_RD_STATS -static void model_rd_with_dnn(const AV1_COMP *const cpi, - const MACROBLOCK *const x, BLOCK_SIZE bsize, - int plane, unsigned int *rsse, int *rate, - int64_t *dist) { +static void model_rd_with_dnn(const AV1_COMP *const cpi, MACROBLOCK *const x, + BLOCK_SIZE plane_bsize, int plane, int64_t *rsse, + int *rate, int64_t *dist) { const MACROBLOCKD *const xd = &x->e_mbd; const struct macroblockd_plane *const pd = &xd->plane[plane]; - const BLOCK_SIZE plane_bsize = - get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y); const int log_numpels = num_pels_log2_lookup[plane_bsize]; - const int num_samples = (1 << log_numpels); const struct macroblock_plane *const p = &x->plane[plane]; - const int bw = block_size_wide[plane_bsize]; - const int bh = block_size_high[plane_bsize]; + int bw, bh; + const int diff_stride = block_size_wide[plane_bsize]; + get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw, + &bh); + const int num_samples = bw * bh; const int dequant_shift = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd - 5 : 3; const int q_step = pd->dequant_Q3[1] >> dequant_shift; @@ -2415,55 +2474,73 @@ static void model_rd_with_dnn(const AV1_COMP *const cpi, const uint8_t *const src = p->src.buf; const int dst_stride = pd->dst.stride; const uint8_t *const dst = pd->dst.buf; - unsigned int sse; - cpi->fn_ptr[plane_bsize].vf(src, src_stride, dst, dst_stride, &sse); + const int16_t *const src_diff = p->src_diff; + const int shift = (xd->bd - 8); + int64_t sse = aom_sum_squares_2d_i16(p->src_diff, diff_stride, bw, bh); + sse = ROUND_POWER_OF_TWO(sse, shift * 2); const double sse_norm = (double)sse / num_samples; - const int diff_stride = block_size_wide[plane_bsize]; - const int16_t *const src_diff = p->src_diff; + if (sse == 0) { + if (rate) *rate = 0; + if (dist) *dist = 0; + if (rsse) *rsse = sse; + return; + } + if (plane) { + int model_rate; + int64_t model_dist; + model_rd_from_sse(cpi, xd, plane_bsize, plane, sse, &model_rate, + &model_dist); + if (rate) *rate = model_rate; + if (dist) *dist = model_dist; + if (rsse) *rsse = sse; + return; + } - double sse_norm_arr[4], sad_norm_arr[4]; + double sse_norm_arr[4]; get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst, dst_stride, src_diff, diff_stride, - sse_norm_arr, sad_norm_arr); - const double mean = get_mean(src_diff, diff_stride, bw, bh); + sse_norm_arr, NULL); + double mean = get_mean(src_diff, bw, bw, bh); + if (shift) { + for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift)); + mean /= (1 << shift); + } const double variance = sse_norm - mean * mean; + assert(variance >= 0.0); const double q_sqr = (double)(q_step * q_step); - const double q_sqr_by_variance = q_sqr / variance; + const double q_sqr_by_sse_norm = q_sqr / (sse_norm + 1.0); double hor_corr, vert_corr; get_horver_correlation(src_diff, diff_stride, bw, bh, &hor_corr, &vert_corr); - double hdist[4] = { 0 }, vdist[4] = { 0 }; - get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst, - dst_stride, 1, hdist, vdist); - float features[20]; - features[0] = (float)hdist[0]; - features[1] = (float)hdist[1]; - features[2] = (float)hdist[2]; - features[3] = (float)hdist[3]; - features[4] = (float)hor_corr; - features[5] = (float)log_numpels; - features[6] = (float)mean; - features[7] = (float)q_sqr; - features[8] = (float)q_sqr_by_variance; - features[9] = (float)sse_norm_arr[0]; - features[10] = (float)sse_norm_arr[1]; - features[11] = (float)sse_norm_arr[2]; - features[12] = (float)sse_norm_arr[3]; - features[13] = (float)sse_norm_arr[3]; - features[14] = (float)variance; - features[15] = (float)vdist[0]; - features[16] = (float)vdist[1]; - features[17] = (float)vdist[2]; - features[18] = (float)vdist[3]; - features[19] = (float)vert_corr; - - float rate_f, dist_f; - av1_nn_predict(features, &av1_pustats_dist_nnconfig, &dist_f); + float features[11]; + features[0] = (float)hor_corr; + features[1] = (float)log_numpels; + features[2] = (float)q_sqr; + features[3] = (float)q_sqr_by_sse_norm; + features[4] = (float)sse_norm_arr[0]; + features[5] = (float)sse_norm_arr[1]; + features[6] = (float)sse_norm_arr[2]; + features[7] = (float)sse_norm_arr[3]; + features[8] = (float)sse_norm; + features[9] = (float)variance; + features[10] = (float)vert_corr; + + float rate_f, dist_by_sse_norm_f; + av1_nn_predict(features, &av1_pustats_dist_nnconfig, &dist_by_sse_norm_f); av1_nn_predict(features, &av1_pustats_rate_nnconfig, &rate_f); - const int rate_i = (int)(AOMMAX(0.0, rate_f * (1 << log_numpels)) + 0.5); - const int64_t dist_i = - (int64_t)(AOMMAX(0.0, dist_f * (1 << log_numpels)) + 0.5); + const float dist_f = (float)((double)dist_by_sse_norm_f * (1.0 + sse_norm)); + int rate_i = (int)(AOMMAX(0.0, rate_f * num_samples) + 0.5); + int64_t dist_i = (int64_t)(AOMMAX(0.0, dist_f * num_samples) + 0.5); + + // Check if skip is better + if (RDCOST(x->rdmult, rate_i, dist_i) >= RDCOST(x->rdmult, 0, (sse << 4))) { + dist_i = sse << 4; + rate_i = 0; + } else if (rate_i == 0) { + dist_i = sse << 4; + } + if (rate) *rate = rate_i; if (dist) *dist = dist_i; if (rsse) *rsse = sse; @@ -2488,15 +2565,18 @@ void model_rd_for_sb_with_dnn(const AV1_COMP *const cpi, BLOCK_SIZE bsize, x->pred_sse[ref] = 0; for (int plane = plane_from; plane <= plane_to; ++plane) { - unsigned int sse; + struct macroblockd_plane *const pd = &xd->plane[plane]; + const BLOCK_SIZE plane_bsize = + get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y); + int64_t sse; int rate; int64_t dist; if (x->skip_chroma_rd && plane) continue; - model_rd_with_dnn(cpi, x, bsize, plane, &sse, &rate, &dist); + model_rd_with_dnn(cpi, x, plane_bsize, plane, &sse, &rate, &dist); - if (plane == 0) x->pred_sse[ref] = sse; + if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX); total_sse += sse; rate_sum += rate; @@ -2586,27 +2666,16 @@ static int64_t search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane, int rate_cost = 0; TX_TYPE txk_start = DCT_DCT; TX_TYPE txk_end = TX_TYPES - 1; - if (!(!is_inter && x->use_default_intra_tx_type) && - !(is_inter && x->use_default_inter_tx_type)) - if (x->rd_model == LOW_TXFM_RD || x->cb_partition_scan) - if (plane == 0) txk_end = DCT_DCT; + if ((!is_inter && x->use_default_intra_tx_type) || + (is_inter && x->use_default_inter_tx_type)) { + txk_start = txk_end = get_default_tx_type(0, xd, tx_size); + } else if (x->rd_model == LOW_TXFM_RD || x->cb_partition_scan) { + if (plane == 0) txk_end = DCT_DCT; + } uint8_t best_txb_ctx = 0; const TxSetType tx_set_type = av1_get_ext_tx_set_type(tx_size, is_inter, cm->reduced_tx_set_used); - int prune = 0; - const int do_prune = plane == 0 && !fast_tx_search && txk_end != DCT_DCT && - !(!is_inter && x->use_default_intra_tx_type) && - !(is_inter && x->use_default_inter_tx_type) && - cpi->sf.tx_type_search.prune_mode > NO_PRUNE; - if (do_prune && is_inter) { - if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE) { - prune = prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, - tx_set_type, cpi->sf.tx_type_search.prune_mode); - } else { - prune = x->tx_search_prune[tx_set_type]; - } - } TX_TYPE uv_tx_type = DCT_DCT; if (plane) { @@ -2615,39 +2684,38 @@ static int64_t search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane, av1_get_tx_type(get_plane_type(plane), xd, blk_row, blk_col, tx_size, cm->reduced_tx_set_used); } - if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32) { + const uint16_t ext_tx_used_flag = av1_ext_tx_used_flag[tx_set_type]; + if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 || + ext_tx_used_flag == 0x0001) { txk_start = txk_end = DCT_DCT; } - - int8_t allowed_tx_mask[TX_TYPES] = { 0 }; // 1: allow; 0: skip. - int allowed_tx_num = 0; - if (fast_tx_search) { - allowed_tx_mask[DCT_DCT] = 1; - allowed_tx_mask[H_DCT] = 1; - allowed_tx_mask[V_DCT] = 1; + uint16_t allowed_tx_mask = 0; // 1: allow; 0: skip. + if (txk_start == txk_end) { + allowed_tx_mask = 1 << txk_start; + allowed_tx_mask &= ext_tx_used_flag; + } else if (fast_tx_search) { + allowed_tx_mask = 0x0c01; // V_DCT, H_DCT, DCT_DCT + allowed_tx_mask &= ext_tx_used_flag; } else { - memset(allowed_tx_mask + txk_start, 1, txk_end - txk_start + 1); - } - for (TX_TYPE tx_type = txk_start; tx_type <= txk_end; ++tx_type) { - if (do_prune) { - if (!do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode)) - allowed_tx_mask[tx_type] = 0; - } - if (plane == 0 && allowed_tx_mask[tx_type]) { - if (!av1_ext_tx_used[tx_set_type][tx_type]) - allowed_tx_mask[tx_type] = 0; - else if (!is_inter && x->use_default_intra_tx_type && - tx_type != get_default_tx_type(0, xd, tx_size)) - allowed_tx_mask[tx_type] = 0; - else if (is_inter && x->use_default_inter_tx_type && - tx_type != get_default_tx_type(0, xd, tx_size)) - allowed_tx_mask[tx_type] = 0; - } - allowed_tx_num += allowed_tx_mask[tx_type]; + assert(plane == 0); + allowed_tx_mask = ext_tx_used_flag; + // !fast_tx_search && txk_end != txk_start && plane == 0 + const int do_prune = cpi->sf.tx_type_search.prune_mode > NO_PRUNE; + if (do_prune && is_inter) { + if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE) { + const uint16_t prune = + prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type, + cpi->sf.tx_type_search.prune_mode); + allowed_tx_mask &= (~prune); + } else { + allowed_tx_mask &= (~x->tx_search_prune[tx_set_type]); + } + } } // Need to have at least one transform type allowed. - if (allowed_tx_num == 0) { - allowed_tx_mask[plane ? uv_tx_type : DCT_DCT] = 1; + if (allowed_tx_mask == 0) { + txk_start = txk_end = (plane ? uv_tx_type : DCT_DCT); + allowed_tx_mask = (1 << txk_start); } int use_transform_domain_distortion = @@ -2664,20 +2732,21 @@ static int64_t search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane, cpi->sf.use_transform_domain_distortion == 1 && use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD && !x->cb_partition_scan; - if (calc_pixel_domain_distortion_final && allowed_tx_num <= 1) + if (calc_pixel_domain_distortion_final && + (txk_start == txk_end || allowed_tx_mask == 0x0001)) calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0; const uint16_t *eobs_ptr = x->plane[plane].eobs; const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size]; int64_t block_sse = - pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize, tx_bsize); + pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize, tx_bsize, 1); if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2); block_sse *= 16; for (TX_TYPE tx_type = txk_start; tx_type <= txk_end; ++tx_type) { - if (!allowed_tx_mask[tx_type]) continue; + if (!(allowed_tx_mask & (1 << tx_type))) continue; if (plane == 0) mbmi->txk_type[txk_type_idx] = tx_type; RD_STATS this_rd_stats; av1_invalid_rd_stats(&this_rd_stats); @@ -2686,8 +2755,8 @@ static int64_t search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane, av1_xform_quant( cm, x, plane, block, blk_row, blk_col, plane_bsize, tx_size, tx_type, USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP); - rate_cost = av1_cost_coeffs(cm, x, plane, blk_row, blk_col, block, - tx_size, txb_ctx, use_fast_coef_costing); + rate_cost = av1_cost_coeffs(cm, x, plane, block, tx_size, tx_type, + txb_ctx, use_fast_coef_costing); } else { av1_xform_quant(cm, x, plane, block, blk_row, blk_col, plane_bsize, tx_size, tx_type, AV1_XFORM_QUANT_FP); @@ -2696,13 +2765,18 @@ static int64_t search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane, // Calculate distortion quickly in transform domain. dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist, &this_rd_stats.sse); - rate_cost = av1_cost_coeffs(cm, x, plane, blk_row, blk_col, block, - tx_size, txb_ctx, use_fast_coef_costing); + + const int64_t best_rd_ = AOMMIN(best_rd, ref_best_rd); + const int64_t dist_cost_estimate = + RDCOST(x->rdmult, 0, AOMMIN(this_rd_stats.dist, this_rd_stats.sse)); + if (dist_cost_estimate - (dist_cost_estimate >> 3) > best_rd_) continue; + + rate_cost = av1_cost_coeffs(cm, x, plane, block, tx_size, tx_type, + txb_ctx, use_fast_coef_costing); const int64_t rd_estimate = AOMMIN(RDCOST(x->rdmult, rate_cost, this_rd_stats.dist), RDCOST(x->rdmult, 0, this_rd_stats.sse)); - if (rd_estimate - (rd_estimate >> 3) > AOMMIN(best_rd, ref_best_rd)) - continue; + if (rd_estimate - (rd_estimate >> 3) > best_rd_) continue; } av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx, 1, &rate_cost); @@ -2741,7 +2815,7 @@ static int64_t search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane, #if CONFIG_COLLECT_RD_STATS == 1 if (plane == 0) { PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col, - plane_bsize, tx_size, tx_type); + plane_bsize, tx_size, tx_type, rd); } #endif // CONFIG_COLLECT_RD_STATS == 1 @@ -3097,6 +3171,7 @@ static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs, MACROBLOCK *x, int *r, int64_t *d, int *s, int64_t *sse, int64_t ref_best_rd) { RD_STATS rd_stats; + av1_subtract_plane(x, bs, 0); x->rd_model = LOW_TXFM_RD; int64_t rd = txfm_yrd(cpi, x, &rd_stats, ref_best_rd, bs, max_txsize_rect_lookup[bs], FTXS_NONE); @@ -3267,7 +3342,7 @@ static int intra_mode_info_cost_y(const AV1_COMP *cpi, const MACROBLOCK *x, const int n_cache = av1_get_palette_cache(xd, 0, color_cache); palette_mode_cost += av1_palette_color_cost_y(&mbmi->palette_mode_info, color_cache, - n_cache, cpi->common.bit_depth); + n_cache, cpi->common.seq_params.bit_depth); palette_mode_cost += av1_cost_color_map(x, 0, bsize, mbmi->tx_size, PALETTE_MAP); total_rate += palette_mode_cost; @@ -3318,8 +3393,8 @@ static int intra_mode_info_cost_uv(const AV1_COMP *cpi, const MACROBLOCK *x, write_uniform_cost(plt_size, color_map[0]); uint16_t color_cache[2 * PALETTE_MAX_SIZE]; const int n_cache = av1_get_palette_cache(xd, 1, color_cache); - palette_mode_cost += av1_palette_color_cost_uv(pmi, color_cache, n_cache, - cpi->common.bit_depth); + palette_mode_cost += av1_palette_color_cost_uv( + pmi, color_cache, n_cache, cpi->common.seq_params.bit_depth); palette_mode_cost += av1_cost_color_map(x, 1, bsize, mbmi->tx_size, PALETTE_MAP); total_rate += palette_mode_cost; @@ -3375,6 +3450,7 @@ static int64_t intra_model_yrd(const AV1_COMP *const cpi, MACROBLOCK *const x, } } // RD estimation. + av1_subtract_plane(x, bsize, 0); model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &this_rd_stats.rate, &this_rd_stats.dist, &this_rd_stats.skip, &temp_sse, NULL, NULL, NULL); @@ -3458,10 +3534,10 @@ static void palette_rd_y( return; } PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info; - if (cpi->common.use_highbitdepth) + if (cpi->common.seq_params.use_highbitdepth) for (int i = 0; i < k; ++i) - pmi->palette_colors[i] = - clip_pixel_highbd((int)centroids[i], cpi->common.bit_depth); + pmi->palette_colors[i] = clip_pixel_highbd( + (int)centroids[i], cpi->common.seq_params.bit_depth); else for (int i = 0; i < k; ++i) pmi->palette_colors[i] = clip_pixel(centroids[i]); @@ -3514,6 +3590,7 @@ static int rd_pick_palette_intra_sby( MB_MODE_INFO *const mbmi = xd->mi[0]; assert(!is_inter_block(mbmi)); assert(av1_allow_palette(cpi->common.allow_screen_content_tools, bsize)); + const SequenceHeader *const seq_params = &cpi->common.seq_params; int colors, n; const int src_stride = x->plane[0].src.stride; const uint8_t *const src = x->plane[0].src.buf; @@ -3523,9 +3600,9 @@ static int rd_pick_palette_intra_sby( &cols); int count_buf[1 << 12]; // Maximum (1 << 12) color levels. - if (cpi->common.use_highbitdepth) + if (seq_params->use_highbitdepth) colors = av1_count_colors_highbd(src, src_stride, rows, cols, - cpi->common.bit_depth, count_buf); + seq_params->bit_depth, count_buf); else colors = av1_count_colors(src, src_stride, rows, cols, count_buf); mbmi->filter_intra_mode_info.use_filter_intra = 0; @@ -3537,12 +3614,12 @@ static int rd_pick_palette_intra_sby( int centroids[PALETTE_MAX_SIZE]; int lb, ub, val; uint16_t *src16 = CONVERT_TO_SHORTPTR(src); - if (cpi->common.use_highbitdepth) + if (seq_params->use_highbitdepth) lb = ub = src16[0]; else lb = ub = src[0]; - if (cpi->common.use_highbitdepth) { + if (seq_params->use_highbitdepth) { for (r = 0; r < rows; ++r) { for (c = 0; c < cols; ++c) { val = src16[r * src_stride + c]; @@ -3576,7 +3653,7 @@ static int rd_pick_palette_intra_sby( int top_colors[PALETTE_MAX_SIZE] = { 0 }; for (i = 0; i < AOMMIN(colors, PALETTE_MAX_SIZE); ++i) { int max_count = 0; - for (int j = 0; j < (1 << cpi->common.bit_depth); ++j) { + for (int j = 0; j < (1 << seq_params->bit_depth); ++j) { if (count_buf[j] > max_count) { max_count = count_buf[j]; top_colors[i] = j; @@ -4316,6 +4393,244 @@ static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row, return (int)(score * 100); } +typedef struct { + int64_t rd; + int txb_entropy_ctx; + TX_TYPE tx_type; +} TxCandidateInfo; + +static void try_tx_block_no_split( + const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block, + TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, + const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl, + int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd, + FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node, + TxCandidateInfo *no_split) { + MACROBLOCKD *const xd = &x->e_mbd; + MB_MODE_INFO *const mbmi = xd->mi[0]; + struct macroblock_plane *const p = &x->plane[0]; + const int bw = block_size_wide[plane_bsize] >> tx_size_wide_log2[0]; + + no_split->rd = INT64_MAX; + no_split->txb_entropy_ctx = 0; + no_split->tx_type = TX_TYPES; + + const ENTROPY_CONTEXT *const pta = ta + blk_col; + const ENTROPY_CONTEXT *const ptl = tl + blk_row; + + const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size); + TXB_CTX txb_ctx; + get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx); + const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y] + .txb_skip_cost[txb_ctx.txb_skip_ctx][1]; + + rd_stats->ref_rdcost = ref_best_rd; + rd_stats->zero_rate = zero_blk_rate; + const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col); + mbmi->inter_tx_size[index] = tx_size; + tx_block_rd_b(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize, pta, + ptl, rd_stats, ftxs_mode, ref_best_rd, + rd_info_node != NULL ? rd_info_node->rd_info_array : NULL); + assert(rd_stats->rate < INT_MAX); + + if ((RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >= + RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) || + rd_stats->skip == 1) && + !xd->lossless[mbmi->segment_id]) { +#if CONFIG_RD_DEBUG + av1_update_txb_coeff_cost(rd_stats, plane, tx_size, blk_row, blk_col, + zero_blk_rate - rd_stats->rate); +#endif // CONFIG_RD_DEBUG + rd_stats->rate = zero_blk_rate; + rd_stats->dist = rd_stats->sse; + rd_stats->skip = 1; + x->blk_skip[blk_row * bw + blk_col] = 1; + p->eobs[block] = 0; + update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size, + DCT_DCT); + } else { + x->blk_skip[blk_row * bw + blk_col] = 0; + rd_stats->skip = 0; + } + + if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH) + rd_stats->rate += x->txfm_partition_cost[txfm_partition_ctx][0]; + + no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist); + no_split->txb_entropy_ctx = p->txb_entropy_ctx[block]; + const int txk_type_idx = + av1_get_txk_type_index(plane_bsize, blk_row, blk_col); + no_split->tx_type = mbmi->txk_type[txk_type_idx]; +} + +static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, + int blk_col, int block, TX_SIZE tx_size, int depth, + BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta, + ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, + TXFM_CONTEXT *tx_left, RD_STATS *rd_stats, + int64_t ref_best_rd, int *is_cost_valid, + FAST_TX_SEARCH_MODE ftxs_mode, + TXB_RD_INFO_NODE *rd_info_node); + +static void try_tx_block_split( + const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block, + TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta, + ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, + int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd, + FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node, + RD_STATS *split_rd_stats, int64_t *split_rd) { + MACROBLOCKD *const xd = &x->e_mbd; + const int max_blocks_high = max_block_high(xd, plane_bsize, 0); + const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0); + struct macroblock_plane *const p = &x->plane[0]; + const TX_SIZE sub_txs = sub_tx_size_map[tx_size]; + const int bsw = tx_size_wide_unit[sub_txs]; + const int bsh = tx_size_high_unit[sub_txs]; + const int sub_step = bsw * bsh; + RD_STATS this_rd_stats; + int this_cost_valid = 1; + int64_t tmp_rd = 0; +#if CONFIG_DIST_8X8 + int sub8x8_eob[4] = { 0, 0, 0, 0 }; + struct macroblockd_plane *const pd = &xd->plane[0]; +#endif + split_rd_stats->rate = x->txfm_partition_cost[txfm_partition_ctx][1]; + + assert(tx_size < TX_SIZES_ALL); + + int blk_idx = 0; + for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) { + for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw, ++blk_idx) { + const int offsetr = blk_row + r; + const int offsetc = blk_col + c; + if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue; + assert(blk_idx < 4); + select_tx_block( + cpi, x, offsetr, offsetc, block, sub_txs, depth + 1, plane_bsize, ta, + tl, tx_above, tx_left, &this_rd_stats, ref_best_rd - tmp_rd, + &this_cost_valid, ftxs_mode, + (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL); + +#if CONFIG_DIST_8X8 + if (!x->using_dist_8x8) +#endif + if (!this_cost_valid) goto LOOP_EXIT; +#if CONFIG_DIST_8X8 + if (x->using_dist_8x8 && tx_size == TX_8X8) { + sub8x8_eob[2 * (r / bsh) + (c / bsw)] = p->eobs[block]; + } +#endif // CONFIG_DIST_8X8 + av1_merge_rd_stats(split_rd_stats, &this_rd_stats); + + tmp_rd = RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist); +#if CONFIG_DIST_8X8 + if (!x->using_dist_8x8) +#endif + if (no_split_rd < tmp_rd) { + this_cost_valid = 0; + goto LOOP_EXIT; + } + block += sub_step; + } + } + +LOOP_EXIT : {} + +#if CONFIG_DIST_8X8 + if (x->using_dist_8x8 && this_cost_valid && tx_size == TX_8X8) { + const int src_stride = p->src.stride; + const int dst_stride = pd->dst.stride; + + const uint8_t *src = + &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]]; + const uint8_t *dst = + &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]]; + + int64_t dist_8x8; + const int qindex = x->qindex; + const int pred_stride = block_size_wide[plane_bsize]; + const int pred_idx = (blk_row * pred_stride + blk_col) + << tx_size_wide_log2[0]; + const int16_t *pred = &x->pred_luma[pred_idx]; + int i, j; + int row, col; + + uint8_t *pred8; + DECLARE_ALIGNED(16, uint16_t, pred8_16[8 * 8]); + + dist_8x8 = av1_dist_8x8(cpi, x, src, src_stride, dst, dst_stride, BLOCK_8X8, + 8, 8, 8, 8, qindex) * + 16; + +#ifdef DEBUG_DIST_8X8 + if (x->tune_metric == AOM_TUNE_PSNR && xd->bd == 8) + assert(sum_rd_stats.sse == dist_8x8); +#endif // DEBUG_DIST_8X8 + + split_rd_stats->sse = dist_8x8; + + if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) + pred8 = CONVERT_TO_BYTEPTR(pred8_16); + else + pred8 = (uint8_t *)pred8_16; + + if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) { + for (row = 0; row < 2; ++row) { + for (col = 0; col < 2; ++col) { + int idx = row * 2 + col; + int eob = sub8x8_eob[idx]; + + if (eob > 0) { + for (j = 0; j < 4; j++) + for (i = 0; i < 4; i++) + CONVERT_TO_SHORTPTR(pred8) + [(row * 4 + j) * 8 + 4 * col + i] = + pred[(row * 4 + j) * pred_stride + 4 * col + i]; + } else { + for (j = 0; j < 4; j++) + for (i = 0; i < 4; i++) + CONVERT_TO_SHORTPTR(pred8) + [(row * 4 + j) * 8 + 4 * col + i] = CONVERT_TO_SHORTPTR( + dst)[(row * 4 + j) * dst_stride + 4 * col + i]; + } + } + } + } else { + for (row = 0; row < 2; ++row) { + for (col = 0; col < 2; ++col) { + int idx = row * 2 + col; + int eob = sub8x8_eob[idx]; + + if (eob > 0) { + for (j = 0; j < 4; j++) + for (i = 0; i < 4; i++) + pred8[(row * 4 + j) * 8 + 4 * col + i] = + (uint8_t)pred[(row * 4 + j) * pred_stride + 4 * col + i]; + } else { + for (j = 0; j < 4; j++) + for (i = 0; i < 4; i++) + pred8[(row * 4 + j) * 8 + 4 * col + i] = + dst[(row * 4 + j) * dst_stride + 4 * col + i]; + } + } + } + } + dist_8x8 = av1_dist_8x8(cpi, x, src, src_stride, pred8, 8, BLOCK_8X8, 8, 8, + 8, 8, qindex) * + 16; + +#ifdef DEBUG_DIST_8X8 + if (x->tune_metric == AOM_TUNE_PSNR && xd->bd == 8) + assert(sum_rd_stats.dist == dist_8x8); +#endif // DEBUG_DIST_8X8 + + split_rd_stats->dist = dist_8x8; + tmp_rd = RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist); + } +#endif // CONFIG_DIST_8X8 + if (this_cost_valid) *split_rd = tmp_rd; +} + // Search for the best tx partition/type for a given luma block. static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block, TX_SIZE tx_size, int depth, @@ -4338,8 +4653,6 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return; const int bw = block_size_wide[plane_bsize] >> tx_size_wide_log2[0]; - ENTROPY_CONTEXT *pta = ta + blk_col; - ENTROPY_CONTEXT *ptl = tl + blk_row; MB_MODE_INFO *const mbmi = xd->mi[0]; const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row, mbmi->sb_type, tx_size); @@ -4348,64 +4661,25 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, const int try_no_split = 1; int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH; - int64_t no_split_rd = INT64_MAX; - int no_split_txb_entropy_ctx = 0; - TX_TYPE no_split_tx_type = TX_TYPES; + TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES }; + // TX no split if (try_no_split) { - const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size); - TXB_CTX txb_ctx; - get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx); - const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y] - .txb_skip_cost[txb_ctx.txb_skip_ctx][1]; + try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth, + plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd, + ftxs_mode, rd_info_node, &no_split); - rd_stats->ref_rdcost = ref_best_rd; - rd_stats->zero_rate = zero_blk_rate; - const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col); - mbmi->inter_tx_size[index] = tx_size; - tx_block_rd_b(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize, pta, - ptl, rd_stats, ftxs_mode, ref_best_rd, - rd_info_node != NULL ? rd_info_node->rd_info_array : NULL); - assert(rd_stats->rate < INT_MAX); - - if ((RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >= - RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) || - rd_stats->skip == 1) && - !xd->lossless[mbmi->segment_id]) { -#if CONFIG_RD_DEBUG - av1_update_txb_coeff_cost(rd_stats, plane, tx_size, blk_row, blk_col, - zero_blk_rate - rd_stats->rate); -#endif // CONFIG_RD_DEBUG - rd_stats->rate = zero_blk_rate; - rd_stats->dist = rd_stats->sse; - rd_stats->skip = 1; - x->blk_skip[blk_row * bw + blk_col] = 1; - p->eobs[block] = 0; - update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size, - DCT_DCT); - } else { - x->blk_skip[blk_row * bw + blk_col] = 0; - rd_stats->skip = 0; - } - - if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH) - rd_stats->rate += x->txfm_partition_cost[ctx][0]; - no_split_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist); if (cpi->sf.adaptive_txb_search_level && - (no_split_rd - - (no_split_rd >> (1 + cpi->sf.adaptive_txb_search_level))) > + (no_split.rd - + (no_split.rd >> (1 + cpi->sf.adaptive_txb_search_level))) > ref_best_rd) { *is_cost_valid = 0; return; } - no_split_txb_entropy_ctx = p->txb_entropy_ctx[block]; - const int txk_type_idx = - av1_get_txk_type_index(plane_bsize, blk_row, blk_col); - no_split_tx_type = mbmi->txk_type[txk_type_idx]; - - if (cpi->sf.txb_split_cap) + if (cpi->sf.txb_split_cap) { if (p->eobs[block] == 0) try_split = 0; + } } if (x->e_mbd.bd == 8 && !x->cb_partition_scan && try_split) { @@ -4427,155 +4701,10 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, RD_STATS split_rd_stats; av1_init_rd_stats(&split_rd_stats); if (try_split) { - const TX_SIZE sub_txs = sub_tx_size_map[tx_size]; - const int bsw = tx_size_wide_unit[sub_txs]; - const int bsh = tx_size_high_unit[sub_txs]; - const int sub_step = bsw * bsh; - RD_STATS this_rd_stats; - int this_cost_valid = 1; - int64_t tmp_rd = 0; -#if CONFIG_DIST_8X8 - int sub8x8_eob[4] = { 0, 0, 0, 0 }; - struct macroblockd_plane *const pd = &xd->plane[0]; -#endif - split_rd_stats.rate = x->txfm_partition_cost[ctx][1]; - - assert(tx_size < TX_SIZES_ALL); - - ref_best_rd = AOMMIN(no_split_rd, ref_best_rd); - - int blk_idx = 0; - for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) { - for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw, ++blk_idx) { - const int offsetr = blk_row + r; - const int offsetc = blk_col + c; - if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue; - assert(blk_idx < 4); - select_tx_block( - cpi, x, offsetr, offsetc, block, sub_txs, depth + 1, plane_bsize, - ta, tl, tx_above, tx_left, &this_rd_stats, ref_best_rd - tmp_rd, - &this_cost_valid, ftxs_mode, - (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL); - -#if CONFIG_DIST_8X8 - if (!x->using_dist_8x8) -#endif - if (!this_cost_valid) goto LOOP_EXIT; -#if CONFIG_DIST_8X8 - if (x->using_dist_8x8 && tx_size == TX_8X8) { - sub8x8_eob[2 * (r / bsh) + (c / bsw)] = p->eobs[block]; - } -#endif // CONFIG_DIST_8X8 - av1_merge_rd_stats(&split_rd_stats, &this_rd_stats); - - tmp_rd = RDCOST(x->rdmult, split_rd_stats.rate, split_rd_stats.dist); -#if CONFIG_DIST_8X8 - if (!x->using_dist_8x8) -#endif - if (no_split_rd < tmp_rd) { - this_cost_valid = 0; - goto LOOP_EXIT; - } - block += sub_step; - } - } - - LOOP_EXIT : {} - -#if CONFIG_DIST_8X8 - if (x->using_dist_8x8 && this_cost_valid && tx_size == TX_8X8) { - const int src_stride = p->src.stride; - const int dst_stride = pd->dst.stride; - - const uint8_t *src = - &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]]; - const uint8_t *dst = - &pd->dst - .buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]]; - - int64_t dist_8x8; - const int qindex = x->qindex; - const int pred_stride = block_size_wide[plane_bsize]; - const int pred_idx = (blk_row * pred_stride + blk_col) - << tx_size_wide_log2[0]; - const int16_t *pred = &x->pred_luma[pred_idx]; - int i, j; - int row, col; - - uint8_t *pred8; - DECLARE_ALIGNED(16, uint16_t, pred8_16[8 * 8]); - - dist_8x8 = av1_dist_8x8(cpi, x, src, src_stride, dst, dst_stride, - BLOCK_8X8, 8, 8, 8, 8, qindex) * - 16; - -#ifdef DEBUG_DIST_8X8 - if (x->tune_metric == AOM_TUNE_PSNR && xd->bd == 8) - assert(sum_rd_stats.sse == dist_8x8); -#endif // DEBUG_DIST_8X8 - - split_rd_stats.sse = dist_8x8; - - if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) - pred8 = CONVERT_TO_BYTEPTR(pred8_16); - else - pred8 = (uint8_t *)pred8_16; - - if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) { - for (row = 0; row < 2; ++row) { - for (col = 0; col < 2; ++col) { - int idx = row * 2 + col; - int eob = sub8x8_eob[idx]; - - if (eob > 0) { - for (j = 0; j < 4; j++) - for (i = 0; i < 4; i++) - CONVERT_TO_SHORTPTR(pred8) - [(row * 4 + j) * 8 + 4 * col + i] = - pred[(row * 4 + j) * pred_stride + 4 * col + i]; - } else { - for (j = 0; j < 4; j++) - for (i = 0; i < 4; i++) - CONVERT_TO_SHORTPTR(pred8) - [(row * 4 + j) * 8 + 4 * col + i] = CONVERT_TO_SHORTPTR( - dst)[(row * 4 + j) * dst_stride + 4 * col + i]; - } - } - } - } else { - for (row = 0; row < 2; ++row) { - for (col = 0; col < 2; ++col) { - int idx = row * 2 + col; - int eob = sub8x8_eob[idx]; - - if (eob > 0) { - for (j = 0; j < 4; j++) - for (i = 0; i < 4; i++) - pred8[(row * 4 + j) * 8 + 4 * col + i] = - (uint8_t)pred[(row * 4 + j) * pred_stride + 4 * col + i]; - } else { - for (j = 0; j < 4; j++) - for (i = 0; i < 4; i++) - pred8[(row * 4 + j) * 8 + 4 * col + i] = - dst[(row * 4 + j) * dst_stride + 4 * col + i]; - } - } - } - } - dist_8x8 = av1_dist_8x8(cpi, x, src, src_stride, pred8, 8, BLOCK_8X8, 8, - 8, 8, 8, qindex) * - 16; - -#ifdef DEBUG_DIST_8X8 - if (x->tune_metric == AOM_TUNE_PSNR && xd->bd == 8) - assert(sum_rd_stats.dist == dist_8x8); -#endif // DEBUG_DIST_8X8 - - split_rd_stats.dist = dist_8x8; - tmp_rd = RDCOST(x->rdmult, split_rd_stats.rate, split_rd_stats.dist); - } -#endif // CONFIG_DIST_8X8 - if (this_cost_valid) split_rd = tmp_rd; + try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth, + plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd, + AOMMIN(no_split.rd, ref_best_rd), ftxs_mode, + rd_info_node, &split_rd_stats, &split_rd); } #if COLLECT_TX_SIZE_DATA @@ -4626,9 +4755,11 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, } while (0); #endif // COLLECT_TX_SIZE_DATA - if (no_split_rd < split_rd) { + if (no_split.rd < split_rd) { + ENTROPY_CONTEXT *pta = ta + blk_col; + ENTROPY_CONTEXT *ptl = tl + blk_row; const TX_SIZE tx_size_selected = tx_size; - p->txb_entropy_ctx[block] = no_split_txb_entropy_ctx; + p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx; av1_set_txb_context(x, 0, block, tx_size_selected, pta, ptl); txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size, tx_size); @@ -4641,7 +4772,7 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, } mbmi->tx_size = tx_size_selected; update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size, - no_split_tx_type); + no_split.tx_type); x->blk_skip[blk_row * bw + blk_col] = rd_stats->skip; } else { *rd_stats = split_rd_stats; @@ -4707,13 +4838,19 @@ static void select_inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, } } } - int64_t zero_rd = RDCOST(x->rdmult, rd_stats->zero_rate, rd_stats->sse); - this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist); - if (zero_rd < this_rd) { - this_rd = zero_rd; - rd_stats->rate = rd_stats->zero_rate; + + const int skip_ctx = av1_get_skip_context(xd); + const int s0 = x->skip_cost[skip_ctx][0]; + const int s1 = x->skip_cost[skip_ctx][1]; + int64_t skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse); + this_rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist); + if (skip_rd <= this_rd) { + this_rd = skip_rd; + rd_stats->rate = 0; rd_stats->dist = rd_stats->sse; rd_stats->skip = 1; + } else { + rd_stats->skip = 0; } if (this_rd > ref_best_rd) is_cost_valid = 0; @@ -4921,11 +5058,15 @@ static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, } } } - int64_t zero_rd = RDCOST(x->rdmult, rd_stats->zero_rate, rd_stats->sse); - this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist); - if (zero_rd < this_rd) { - this_rd = zero_rd; - rd_stats->rate = rd_stats->zero_rate; + + const int skip_ctx = av1_get_skip_context(xd); + const int s0 = x->skip_cost[skip_ctx][0]; + const int s1 = x->skip_cost[skip_ctx][1]; + int64_t skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse); + this_rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist); + if (skip_rd < this_rd) { + this_rd = skip_rd; + rd_stats->rate = 0; rd_stats->dist = rd_stats->sse; rd_stats->skip = 1; } @@ -5159,7 +5300,7 @@ static int predict_skip_flag(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist, const MACROBLOCKD *xd = &x->e_mbd; const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd); - *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize); + *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize, 1); const int64_t mse = *dist / bw / bh; // Normalized quantizer takes the transform upscaling factor (8 for tx size // smaller than 32) into account. @@ -5215,23 +5356,7 @@ static void set_skip_flag(MACROBLOCK *x, RD_STATS *rd_stats, int bsize, mbmi->tx_size = tx_size; memset(x->blk_skip, 1, sizeof(x->blk_skip[0]) * n4); rd_stats->skip = 1; - - // Rate. - const int tx_size_ctx = get_txsize_entropy_ctx(tx_size); - ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE]; - ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE]; - av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl); - TXB_CTX txb_ctx; - // Because plane is 0, plane_bsize equal to bsize - get_txb_ctx(bsize, tx_size, 0, ctxa, ctxl, &txb_ctx); - int rate = x->coeff_costs[tx_size_ctx][PLANE_TYPE_Y] - .txb_skip_cost[txb_ctx.txb_skip_ctx][1]; - if (tx_size > TX_4X4) { - int ctx = txfm_partition_context( - xd->above_txfm_context, xd->left_txfm_context, mbmi->sb_type, tx_size); - rate += x->txfm_partition_cost[ctx][0]; - } - rd_stats->rate = rate; + rd_stats->rate = 0; if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2); rd_stats->dist = rd_stats->sse = (dist << 4); @@ -5322,6 +5447,8 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x, rd = select_tx_size_fix_type(cpi, x, &this_rd_stats, bsize, ref_best_rd, found_rd_info ? matched_rd_info : NULL); + assert(IMPLIES(this_rd_stats.skip && !this_rd_stats.invalid_rate, + this_rd_stats.rate == 0)); ref_best_rd = AOMMIN(rd, ref_best_rd); if (rd < best_rd) { @@ -5455,6 +5582,7 @@ static void rd_pick_palette_intra_sbuv(const AV1_COMP *const cpi, MACROBLOCK *x, av1_allow_palette(cpi->common.allow_screen_content_tools, mbmi->sb_type)); PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info; const BLOCK_SIZE bsize = mbmi->sb_type; + const SequenceHeader *const seq_params = &cpi->common.seq_params; int this_rate; int64_t this_rd; int colors_u, colors_v, colors; @@ -5470,11 +5598,11 @@ static void rd_pick_palette_intra_sbuv(const AV1_COMP *const cpi, MACROBLOCK *x, mbmi->uv_mode = UV_DC_PRED; int count_buf[1 << 12]; // Maximum (1 << 12) color levels. - if (cpi->common.use_highbitdepth) { + if (seq_params->use_highbitdepth) { colors_u = av1_count_colors_highbd(src_u, src_stride, rows, cols, - cpi->common.bit_depth, count_buf); + seq_params->bit_depth, count_buf); colors_v = av1_count_colors_highbd(src_v, src_stride, rows, cols, - cpi->common.bit_depth, count_buf); + seq_params->bit_depth, count_buf); } else { colors_u = av1_count_colors(src_u, src_stride, rows, cols, count_buf); colors_v = av1_count_colors(src_v, src_stride, rows, cols, count_buf); @@ -5494,7 +5622,7 @@ static void rd_pick_palette_intra_sbuv(const AV1_COMP *const cpi, MACROBLOCK *x, uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src_u); uint16_t *src_v16 = CONVERT_TO_SHORTPTR(src_v); - if (cpi->common.use_highbitdepth) { + if (seq_params->use_highbitdepth) { lb_u = src_u16[0]; ub_u = src_u16[0]; lb_v = src_v16[0]; @@ -5508,7 +5636,7 @@ static void rd_pick_palette_intra_sbuv(const AV1_COMP *const cpi, MACROBLOCK *x, for (r = 0; r < rows; ++r) { for (c = 0; c < cols; ++c) { - if (cpi->common.use_highbitdepth) { + if (seq_params->use_highbitdepth) { val_u = src_u16[r * src_stride + c]; val_v = src_v16[r * src_stride + c]; data[(r * cols + c) * 2] = val_u; @@ -5557,9 +5685,9 @@ static void rd_pick_palette_intra_sbuv(const AV1_COMP *const cpi, MACROBLOCK *x, pmi->palette_size[1] = n; for (i = 1; i < 3; ++i) { for (j = 0; j < n; ++j) { - if (cpi->common.use_highbitdepth) + if (seq_params->use_highbitdepth) pmi->palette_colors[i * PALETTE_MAX_SIZE + j] = clip_pixel_highbd( - (int)centroids[j * 2 + i - 1], cpi->common.bit_depth); + (int)centroids[j * 2 + i - 1], seq_params->bit_depth); else pmi->palette_colors[i * PALETTE_MAX_SIZE + j] = clip_pixel((int)centroids[j * 2 + i - 1]); @@ -5907,8 +6035,9 @@ static void choose_intra_uv_mode(const AV1_COMP *const cpi, MACROBLOCK *const x, *mode_uv = UV_DC_PRED; return; } - xd->cfl.is_chroma_reference = is_chroma_reference( - mi_row, mi_col, bsize, cm->subsampling_x, cm->subsampling_y); + xd->cfl.is_chroma_reference = + is_chroma_reference(mi_row, mi_col, bsize, cm->seq_params.subsampling_x, + cm->seq_params.subsampling_y); bsize = scale_chroma_bsize(bsize, xd->plane[AOM_PLANE_U].subsampling_x, xd->plane[AOM_PLANE_U].subsampling_y); // Only store reconstructed luma when there's chroma RDO. When there's no @@ -7038,7 +7167,9 @@ static int estimate_wedge_sign(const AV1_COMP *cpi, const MACROBLOCK *x, // Choose the best wedge index and sign static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x, const BLOCK_SIZE bsize, const uint8_t *const p0, - const uint8_t *const p1, int *const best_wedge_sign, + const int16_t *const residual1, + const int16_t *const diff10, + int *const best_wedge_sign, int *const best_wedge_index) { const MACROBLOCKD *const xd = &x->e_mbd; const struct buf_2d *const src = &x->plane[0].src; @@ -7056,34 +7187,22 @@ static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x, const int hbd = xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH; const int bd_round = hbd ? (xd->bd - 8) * 2 : 0; - DECLARE_ALIGNED(32, int16_t, r0[MAX_SB_SQUARE]); - DECLARE_ALIGNED(32, int16_t, r1[MAX_SB_SQUARE]); - DECLARE_ALIGNED(32, int16_t, d10[MAX_SB_SQUARE]); - DECLARE_ALIGNED(32, int16_t, ds[MAX_SB_SQUARE]); - - int64_t sign_limit; - + DECLARE_ALIGNED(32, int16_t, residual0[MAX_SB_SQUARE]); // src - pred0 if (hbd) { - aom_highbd_subtract_block(bh, bw, r0, bw, src->buf, src->stride, - CONVERT_TO_BYTEPTR(p0), bw, xd->bd); - aom_highbd_subtract_block(bh, bw, r1, bw, src->buf, src->stride, - CONVERT_TO_BYTEPTR(p1), bw, xd->bd); - aom_highbd_subtract_block(bh, bw, d10, bw, CONVERT_TO_BYTEPTR(p1), bw, + aom_highbd_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, CONVERT_TO_BYTEPTR(p0), bw, xd->bd); } else { - aom_subtract_block(bh, bw, r0, bw, src->buf, src->stride, p0, bw); - aom_subtract_block(bh, bw, r1, bw, src->buf, src->stride, p1, bw); - aom_subtract_block(bh, bw, d10, bw, p1, bw, p0, bw); + aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw); } - sign_limit = ((int64_t)aom_sum_squares_i16(r0, N) - - (int64_t)aom_sum_squares_i16(r1, N)) * - (1 << WEDGE_WEIGHT_BITS) / 2; - + int64_t sign_limit = ((int64_t)aom_sum_squares_i16(residual0, N) - + (int64_t)aom_sum_squares_i16(residual1, N)) * + (1 << WEDGE_WEIGHT_BITS) / 2; + int16_t *ds = residual0; if (N < 64) - av1_wedge_compute_delta_squares_c(ds, r0, r1, N); + av1_wedge_compute_delta_squares_c(ds, residual0, residual1, N); else - av1_wedge_compute_delta_squares(ds, r0, r1, N); + av1_wedge_compute_delta_squares(ds, residual0, residual1, N); for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) { mask = av1_get_contiguous_soft_mask(wedge_index, 0, bsize); @@ -7096,9 +7215,9 @@ static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x, mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize); if (N < 64) - sse = av1_wedge_sse_from_residuals_c(r1, d10, mask, N); + sse = av1_wedge_sse_from_residuals_c(residual1, diff10, mask, N); else - sse = av1_wedge_sse_from_residuals(r1, d10, mask, N); + sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N); sse = ROUND_POWER_OF_TWO(sse, bd_round); model_rd_from_sse(cpi, xd, bsize, 0, sse, &rate, &dist); @@ -7117,12 +7236,15 @@ static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x, } // Choose the best wedge index the specified sign -static int64_t pick_wedge_fixed_sign( - const AV1_COMP *const cpi, const MACROBLOCK *const x, - const BLOCK_SIZE bsize, const uint8_t *const p0, const uint8_t *const p1, - const int wedge_sign, int *const best_wedge_index) { +static int64_t pick_wedge_fixed_sign(const AV1_COMP *const cpi, + const MACROBLOCK *const x, + const BLOCK_SIZE bsize, + const int16_t *const residual1, + const int16_t *const diff10, + const int wedge_sign, + int *const best_wedge_index) { const MACROBLOCKD *const xd = &x->e_mbd; - const struct buf_2d *const src = &x->plane[0].src; + const int bw = block_size_wide[bsize]; const int bh = block_size_high[bsize]; const int N = bw * bh; @@ -7135,26 +7257,12 @@ static int64_t pick_wedge_fixed_sign( uint64_t sse; const int hbd = xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH; const int bd_round = hbd ? (xd->bd - 8) * 2 : 0; - - DECLARE_ALIGNED(32, int16_t, r1[MAX_SB_SQUARE]); - DECLARE_ALIGNED(32, int16_t, d10[MAX_SB_SQUARE]); - - if (hbd) { - aom_highbd_subtract_block(bh, bw, r1, bw, src->buf, src->stride, - CONVERT_TO_BYTEPTR(p1), bw, xd->bd); - aom_highbd_subtract_block(bh, bw, d10, bw, CONVERT_TO_BYTEPTR(p1), bw, - CONVERT_TO_BYTEPTR(p0), bw, xd->bd); - } else { - aom_subtract_block(bh, bw, r1, bw, src->buf, src->stride, p1, bw); - aom_subtract_block(bh, bw, d10, bw, p1, bw, p0, bw); - } - for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) { mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize); if (N < 64) - sse = av1_wedge_sse_from_residuals_c(r1, d10, mask, N); + sse = av1_wedge_sse_from_residuals_c(residual1, diff10, mask, N); else - sse = av1_wedge_sse_from_residuals(r1, d10, mask, N); + sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N); sse = ROUND_POWER_OF_TWO(sse, bd_round); model_rd_from_sse(cpi, xd, bsize, 0, sse, &rate, &dist); @@ -7166,16 +7274,14 @@ static int64_t pick_wedge_fixed_sign( best_rd = rd; } } - return best_rd - RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0); } -static int64_t pick_interinter_wedge(const AV1_COMP *const cpi, - MACROBLOCK *const x, - const BLOCK_SIZE bsize, - const uint8_t *const p0, - const uint8_t *const p1) { +static int64_t pick_interinter_wedge( + const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize, + const uint8_t *const p0, const uint8_t *const p1, + const int16_t *const residual1, const int16_t *const diff10) { MACROBLOCKD *const xd = &x->e_mbd; MB_MODE_INFO *const mbmi = xd->mi[0]; const int bw = block_size_wide[bsize]; @@ -7189,9 +7295,11 @@ static int64_t pick_interinter_wedge(const AV1_COMP *const cpi, if (cpi->sf.fast_wedge_sign_estimate) { wedge_sign = estimate_wedge_sign(cpi, x, bsize, p0, bw, p1, bw); - rd = pick_wedge_fixed_sign(cpi, x, bsize, p0, p1, wedge_sign, &wedge_index); + rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, wedge_sign, + &wedge_index); } else { - rd = pick_wedge(cpi, x, bsize, p0, p1, &wedge_sign, &wedge_index); + rd = pick_wedge(cpi, x, bsize, p0, residual1, diff10, &wedge_sign, + &wedge_index); } mbmi->interinter_comp.wedge_sign = wedge_sign; @@ -7202,10 +7310,11 @@ static int64_t pick_interinter_wedge(const AV1_COMP *const cpi, static int64_t pick_interinter_seg(const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize, const uint8_t *const p0, - const uint8_t *const p1) { + const uint8_t *const p1, + const int16_t *const residual1, + const int16_t *const diff10) { MACROBLOCKD *const xd = &x->e_mbd; MB_MODE_INFO *const mbmi = xd->mi[0]; - const struct buf_2d *const src = &x->plane[0].src; const int bw = block_size_wide[bsize]; const int bh = block_size_high[bsize]; const int N = bw * bh; @@ -7218,23 +7327,6 @@ static int64_t pick_interinter_seg(const AV1_COMP *const cpi, DIFFWTD_MASK_TYPE best_mask_type = 0; const int hbd = xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH; const int bd_round = hbd ? (xd->bd - 8) * 2 : 0; - DECLARE_ALIGNED(32, int16_t, r0[MAX_SB_SQUARE]); - DECLARE_ALIGNED(32, int16_t, r1[MAX_SB_SQUARE]); - DECLARE_ALIGNED(32, int16_t, d10[MAX_SB_SQUARE]); - - if (hbd) { - aom_highbd_subtract_block(bh, bw, r0, bw, src->buf, src->stride, - CONVERT_TO_BYTEPTR(p0), bw, xd->bd); - aom_highbd_subtract_block(bh, bw, r1, bw, src->buf, src->stride, - CONVERT_TO_BYTEPTR(p1), bw, xd->bd); - aom_highbd_subtract_block(bh, bw, d10, bw, CONVERT_TO_BYTEPTR(p1), bw, - CONVERT_TO_BYTEPTR(p0), bw, xd->bd); - } else { - aom_subtract_block(bh, bw, r0, bw, src->buf, src->stride, p0, bw); - aom_subtract_block(bh, bw, r1, bw, src->buf, src->stride, p1, bw); - aom_subtract_block(bh, bw, d10, bw, p1, bw, p0, bw); - } - // try each mask type and its inverse for (cur_mask_type = 0; cur_mask_type < DIFFWTD_MASK_TYPES; cur_mask_type++) { // build mask and inverse @@ -7247,7 +7339,7 @@ static int64_t pick_interinter_seg(const AV1_COMP *const cpi, bw, bh, bw); // compute rd for mask - sse = av1_wedge_sse_from_residuals(r1, d10, xd->seg_mask, N); + sse = av1_wedge_sse_from_residuals(residual1, diff10, xd->seg_mask, N); sse = ROUND_POWER_OF_TWO(sse, bd_round); model_rd_from_sse(cpi, xd, bsize, 0, sse, &rate, &dist); @@ -7279,14 +7371,26 @@ static int64_t pick_interintra_wedge(const AV1_COMP *const cpi, const uint8_t *const p1) { const MACROBLOCKD *const xd = &x->e_mbd; MB_MODE_INFO *const mbmi = xd->mi[0]; - - int64_t rd; - int wedge_index = -1; - assert(is_interintra_wedge_used(bsize)); assert(cpi->common.seq_params.enable_interintra_compound); - rd = pick_wedge_fixed_sign(cpi, x, bsize, p0, p1, 0, &wedge_index); + const struct buf_2d *const src = &x->plane[0].src; + const int bw = block_size_wide[bsize]; + const int bh = block_size_high[bsize]; + DECLARE_ALIGNED(32, int16_t, residual1[MAX_SB_SQUARE]); // src - pred1 + DECLARE_ALIGNED(32, int16_t, diff10[MAX_SB_SQUARE]); // pred1 - pred0 + if (get_bitdepth_data_path_index(xd)) { + aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, + CONVERT_TO_BYTEPTR(p1), bw, xd->bd); + aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(p1), bw, + CONVERT_TO_BYTEPTR(p0), bw, xd->bd); + } else { + aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw); + aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw); + } + int wedge_index = -1; + int64_t rd = + pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0, &wedge_index); mbmi->interintra_wedge_sign = 0; mbmi->interintra_wedge_index = wedge_index; @@ -7296,11 +7400,15 @@ static int64_t pick_interintra_wedge(const AV1_COMP *const cpi, static int64_t pick_interinter_mask(const AV1_COMP *const cpi, MACROBLOCK *x, const BLOCK_SIZE bsize, const uint8_t *const p0, - const uint8_t *const p1) { + const uint8_t *const p1, + const int16_t *const residual1, + const int16_t *const diff10) { const COMPOUND_TYPE compound_type = x->e_mbd.mi[0]->interinter_comp.type; switch (compound_type) { - case COMPOUND_WEDGE: return pick_interinter_wedge(cpi, x, bsize, p0, p1); - case COMPOUND_DIFFWTD: return pick_interinter_seg(cpi, x, bsize, p0, p1); + case COMPOUND_WEDGE: + return pick_interinter_wedge(cpi, x, bsize, p0, p1, residual1, diff10); + case COMPOUND_DIFFWTD: + return pick_interinter_seg(cpi, x, bsize, p0, p1, residual1, diff10); default: assert(0); return 0; } } @@ -7336,7 +7444,7 @@ static int64_t build_and_cost_compound_type( const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv, const BLOCK_SIZE bsize, const int this_mode, int *rs2, int rate_mv, BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0, uint8_t **preds1, - int *strides, int mi_row, int mi_col) { + int16_t *residual1, int16_t *diff10, int *strides, int mi_row, int mi_col) { const AV1_COMMON *const cm = &cpi->common; MACROBLOCKD *xd = &x->e_mbd; MB_MODE_INFO *const mbmi = xd->mi[0]; @@ -7348,7 +7456,8 @@ static int64_t build_and_cost_compound_type( int64_t tmp_skip_sse_sb; const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type; - best_rd_cur = pick_interinter_mask(cpi, x, bsize, *preds0, *preds1); + best_rd_cur = + pick_interinter_mask(cpi, x, bsize, *preds0, *preds1, residual1, diff10); *rs2 += get_interinter_compound_mask_rate(x, mbmi); best_rd_cur += RDCOST(x->rdmult, *rs2 + rate_mv, 0); @@ -7357,6 +7466,7 @@ static int64_t build_and_cost_compound_type( *out_rate_mv = interinter_compound_motion_search(cpi, x, cur_mv, bsize, this_mode, mi_row, mi_col); av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, ctx, bsize); + av1_subtract_plane(x, bsize, 0); model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL); rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum); @@ -7367,7 +7477,6 @@ static int64_t build_and_cost_compound_type( av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides, preds1, strides); } - av1_subtract_plane(x, bsize, 0); rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX); if (rd != INT64_MAX) @@ -7377,7 +7486,6 @@ static int64_t build_and_cost_compound_type( } else { av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides, preds1, strides); - av1_subtract_plane(x, bsize, 0); rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX); if (rd != INT64_MAX) @@ -7393,11 +7501,11 @@ typedef struct { int above_pred_stride[MAX_MB_PLANE]; uint8_t *left_pred_buf[MAX_MB_PLANE]; int left_pred_stride[MAX_MB_PLANE]; - int_mv *single_newmv; + int_mv (*single_newmv)[REF_FRAMES]; // Pointer to array of motion vectors to use for each ref and their rates // Should point to first of 2 arrays in 2D array - int *single_newmv_rate; - int *single_newmv_valid; + int (*single_newmv_rate)[REF_FRAMES]; + int (*single_newmv_valid)[REF_FRAMES]; // Pointer to array of predicted rate-distortion // Should point to first of 2 arrays in 2D array int64_t (*modelled_rd)[REF_FRAMES]; @@ -7428,14 +7536,15 @@ static int64_t handle_newmv(const AV1_COMP *const cpi, MACROBLOCK *const x, const PREDICTION_MODE this_mode = mbmi->mode; const int refs[2] = { mbmi->ref_frame[0], mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1] }; + const int ref_mv_idx = mbmi->ref_mv_idx; int i; (void)args; if (is_comp_pred) { if (this_mode == NEW_NEWMV) { - cur_mv[0].as_int = args->single_newmv[refs[0]].as_int; - cur_mv[1].as_int = args->single_newmv[refs[1]].as_int; + cur_mv[0].as_int = args->single_newmv[ref_mv_idx][refs[0]].as_int; + cur_mv[1].as_int = args->single_newmv[ref_mv_idx][refs[1]].as_int; if (cpi->sf.comp_inter_joint_search_thresh <= bsize) { joint_motion_search(cpi, x, bsize, cur_mv, mi_row, mi_col, NULL, NULL, @@ -7451,7 +7560,7 @@ static int64_t handle_newmv(const AV1_COMP *const cpi, MACROBLOCK *const x, } } } else if (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV) { - cur_mv[1].as_int = args->single_newmv[refs[1]].as_int; + cur_mv[1].as_int = args->single_newmv[ref_mv_idx][refs[1]].as_int; if (cpi->sf.comp_inter_joint_search_thresh <= bsize) { compound_single_motion_search_interinter( cpi, x, bsize, cur_mv, mi_row, mi_col, NULL, 0, rate_mv, 0, 1); @@ -7464,7 +7573,7 @@ static int64_t handle_newmv(const AV1_COMP *const cpi, MACROBLOCK *const x, } } else { assert(this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV); - cur_mv[0].as_int = args->single_newmv[refs[0]].as_int; + cur_mv[0].as_int = args->single_newmv[ref_mv_idx][refs[0]].as_int; if (cpi->sf.comp_inter_joint_search_thresh <= bsize) { compound_single_motion_search_interinter( cpi, x, bsize, cur_mv, mi_row, mi_col, NULL, 0, rate_mv, 0, 0); @@ -7480,9 +7589,9 @@ static int64_t handle_newmv(const AV1_COMP *const cpi, MACROBLOCK *const x, single_motion_search(cpi, x, bsize, mi_row, mi_col, 0, rate_mv); if (x->best_mv.as_int == INVALID_MV) return INT64_MAX; - args->single_newmv[refs[0]] = x->best_mv; - args->single_newmv_rate[refs[0]] = *rate_mv; - args->single_newmv_valid[refs[0]] = 1; + args->single_newmv[ref_mv_idx][refs[0]] = x->best_mv; + args->single_newmv_rate[ref_mv_idx][refs[0]] = *rate_mv; + args->single_newmv_valid[ref_mv_idx][refs[0]] = 1; cur_mv[0].as_int = x->best_mv.as_int; @@ -7508,12 +7617,25 @@ static INLINE void swap_dst_buf(MACROBLOCKD *xd, const BUFFER_SET *dst_bufs[2], restore_dst_buf(xd, *dst_bufs[0], num_planes); } +static INLINE int get_switchable_rate(MACROBLOCK *const x, + const InterpFilters filters, + const int ctx[2]) { + int inter_filter_cost; + const InterpFilter filter0 = av1_extract_interp_filter(filters, 0); + const InterpFilter filter1 = av1_extract_interp_filter(filters, 1); + inter_filter_cost = x->switchable_interp_costs[ctx[0]][filter0]; + inter_filter_cost += x->switchable_interp_costs[ctx[1]][filter1]; + return SWITCHABLE_INTERP_RATE_FACTOR * inter_filter_cost; +} + // calculate the rdcost of given interpolation_filter static INLINE int64_t interpolation_filter_rd( MACROBLOCK *const x, const AV1_COMP *const cpi, BLOCK_SIZE bsize, int mi_row, int mi_col, BUFFER_SET *const orig_dst, int64_t *const rd, int *const switchable_rate, int *const skip_txfm_sb, - int64_t *const skip_sse_sb, const BUFFER_SET *dst_bufs[2], int filter_idx) { + int64_t *const skip_sse_sb, const BUFFER_SET *dst_bufs[2], int filter_idx, + const int switchable_ctx[2], const int skip_pred, int *rate, + int64_t *dist) { const AV1_COMMON *cm = &cpi->common; const int num_planes = av1_num_planes(cm); MACROBLOCKD *const xd = &x->e_mbd; @@ -7523,23 +7645,136 @@ static INLINE int64_t interpolation_filter_rd( const InterpFilters last_best = mbmi->interp_filters; mbmi->interp_filters = filter_sets[filter_idx]; - const int tmp_rs = av1_get_switchable_rate(cm, x, xd); - av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, orig_dst, bsize); - model_rd_for_sb(cpi, bsize, x, xd, 0, num_planes - 1, &tmp_rate, &tmp_dist, - &tmp_skip_sb, &tmp_skip_sse, NULL, NULL, NULL); + const int tmp_rs = + get_switchable_rate(x, mbmi->interp_filters, switchable_ctx); + + if (!skip_pred) { + av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst, bsize); + av1_subtract_plane(x, bsize, 0); +#if DNN_BASED_RD_INTERP_FILTER + model_rd_for_sb_with_dnn(cpi, bsize, x, xd, 0, 0, &tmp_rate, &tmp_dist, + &tmp_skip_sb, &tmp_skip_sse, NULL, NULL, NULL); +#else + model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &tmp_rate, &tmp_dist, &tmp_skip_sb, + &tmp_skip_sse, NULL, NULL, NULL); +#endif + if (num_planes > 1) { + int64_t tmp_y_rd = RDCOST(x->rdmult, tmp_rs + tmp_rate, tmp_dist); + if (tmp_y_rd > *rd) { + mbmi->interp_filters = last_best; + return 0; + } + int tmp_rate_uv, tmp_skip_sb_uv; + int64_t tmp_dist_uv, tmp_skip_sse_uv; + av1_build_inter_predictors_sbuv(cm, xd, mi_row, mi_col, orig_dst, bsize); + for (int plane = 1; plane < num_planes; ++plane) + av1_subtract_plane(x, bsize, plane); +#if DNN_BASED_RD_INTERP_FILTER + model_rd_for_sb_with_dnn(cpi, bsize, x, xd, 1, num_planes - 1, + &tmp_rate_uv, &tmp_dist_uv, &tmp_skip_sb_uv, + &tmp_skip_sse_uv, NULL, NULL, NULL); +#else + model_rd_for_sb(cpi, bsize, x, xd, 1, num_planes - 1, &tmp_rate_uv, + &tmp_dist_uv, &tmp_skip_sb_uv, &tmp_skip_sse_uv, NULL, + NULL, NULL); +#endif + tmp_rate += tmp_rate_uv; + tmp_skip_sb &= tmp_skip_sb_uv; + tmp_dist += tmp_dist_uv; + tmp_skip_sse += tmp_skip_sse_uv; + } + } else { + tmp_rate = *rate; + tmp_dist = *dist; + } int64_t tmp_rd = RDCOST(x->rdmult, tmp_rs + tmp_rate, tmp_dist); if (tmp_rd < *rd) { *rd = tmp_rd; *switchable_rate = tmp_rs; *skip_txfm_sb = tmp_skip_sb; *skip_sse_sb = tmp_skip_sse; - swap_dst_buf(xd, dst_bufs, num_planes); + *rate = tmp_rate; + *dist = tmp_dist; + if (!skip_pred) { + swap_dst_buf(xd, dst_bufs, num_planes); + } return 1; } mbmi->interp_filters = last_best; return 0; } +// Find the best rd filter in horizontal direction +static INLINE int find_best_horiz_interp_filter_rd( + MACROBLOCK *const x, const AV1_COMP *const cpi, BLOCK_SIZE bsize, + int mi_row, int mi_col, BUFFER_SET *const orig_dst, int64_t *const rd, + int *const switchable_rate, int *const skip_txfm_sb, + int64_t *const skip_sse_sb, const BUFFER_SET *dst_bufs[2], + const int switchable_ctx[2], const int skip_hor, int *rate, int64_t *dist, + int best_dual_mode) { + int i; + const int bw = block_size_wide[bsize]; + assert(best_dual_mode == 0); + if ((bw <= 4) && (!skip_hor)) { + int skip_pred = 1; + // Process the filters in reverse order to enable reusing rate and + // distortion (calcuated during EIGHTTAP_REGULAR) for MULTITAP_SHARP + for (i = (SWITCHABLE_FILTERS - 1); i >= 1; --i) { + if (interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd, + switchable_rate, skip_txfm_sb, skip_sse_sb, + dst_bufs, i, switchable_ctx, skip_pred, rate, + dist)) { + best_dual_mode = i; + } + skip_pred = 0; + } + } else { + for (i = 1; i < SWITCHABLE_FILTERS; ++i) { + if (interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd, + switchable_rate, skip_txfm_sb, skip_sse_sb, + dst_bufs, i, switchable_ctx, skip_hor, rate, + dist)) { + best_dual_mode = i; + } + } + } + return best_dual_mode; +} + +// Find the best rd filter in vertical direction +static INLINE void find_best_vert_interp_filter_rd( + MACROBLOCK *const x, const AV1_COMP *const cpi, BLOCK_SIZE bsize, + int mi_row, int mi_col, BUFFER_SET *const orig_dst, int64_t *const rd, + int *const switchable_rate, int *const skip_txfm_sb, + int64_t *const skip_sse_sb, const BUFFER_SET *dst_bufs[2], + const int switchable_ctx[2], const int skip_ver, int *rate, int64_t *dist, + int best_dual_mode, int filter_set_size) { + int i; + const int bh = block_size_high[bsize]; + if ((bh <= 4) && (!skip_ver)) { + int skip_pred = 1; + // Process the filters in reverse order to enable reusing rate and + // distortion (calcuated during EIGHTTAP_REGULAR) for MULTITAP_SHARP + assert(filter_set_size == DUAL_FILTER_SET_SIZE); + for (i = (filter_set_size - SWITCHABLE_FILTERS + best_dual_mode); + i >= (best_dual_mode + SWITCHABLE_FILTERS); i -= SWITCHABLE_FILTERS) { + interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd, + switchable_rate, skip_txfm_sb, skip_sse_sb, + dst_bufs, i, switchable_ctx, skip_pred, rate, + dist); + skip_pred = 0; + } + } else { + for (i = best_dual_mode + SWITCHABLE_FILTERS; i < filter_set_size; + i += SWITCHABLE_FILTERS) { + interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd, + switchable_rate, skip_txfm_sb, skip_sse_sb, + dst_bufs, i, switchable_ctx, skip_ver, rate, + dist); + } + } +} + // check if there is saved result match with this search static INLINE int is_interp_filter_match(const INTERPOLATION_FILTER_STATS *st, MB_MODE_INFO *const mi) { @@ -7605,10 +7840,22 @@ static int64_t interpolation_filter_search( if (!need_search || match_found == -1) { set_default_interp_filters(mbmi, assign_filter); } - *switchable_rate = av1_get_switchable_rate(cm, x, xd); + int switchable_ctx[2]; + switchable_ctx[0] = av1_get_pred_context_switchable_interp(xd, 0); + switchable_ctx[1] = av1_get_pred_context_switchable_interp(xd, 1); + *switchable_rate = + get_switchable_rate(x, mbmi->interp_filters, switchable_ctx); av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, orig_dst, bsize); + for (int plane = 0; plane < num_planes; ++plane) + av1_subtract_plane(x, bsize, plane); +#if DNN_BASED_RD_INTERP_FILTER + model_rd_for_sb_with_dnn(cpi, bsize, x, xd, 0, num_planes - 1, &tmp_rate, + &tmp_dist, skip_txfm_sb, skip_sse_sb, NULL, NULL, + NULL); +#else model_rd_for_sb(cpi, bsize, x, xd, 0, num_planes - 1, &tmp_rate, &tmp_dist, skip_txfm_sb, skip_sse_sb, NULL, NULL, NULL); +#endif // DNN_BASED_RD_INTERP_FILTER *rd = RDCOST(x->rdmult, *switchable_rate + tmp_rate, tmp_dist); if (assign_filter != SWITCHABLE || match_found != -1) { @@ -7619,6 +7866,23 @@ static int64_t interpolation_filter_search( av1_broadcast_interp_filter(EIGHTTAP_REGULAR)); return 0; } + int skip_hor = 1; + int skip_ver = 1; + const int is_compound = has_second_ref(mbmi); + for (int k = 0; k < num_planes - 1; ++k) { + struct macroblockd_plane *const pd = &xd->plane[k]; + const int bw = pd->width; + const int bh = pd->height; + for (int j = 0; j < 1 + is_compound; ++j) { + const MV mv = mbmi->mv[j].as_mv; + const MV mv_q4 = clamp_mv_to_umv_border_sb( + xd, &mv, bw, bh, pd->subsampling_x, pd->subsampling_y); + const int sub_x = (mv_q4.col & SUBPEL_MASK) << SCALE_EXTRA_BITS; + const int sub_y = (mv_q4.row & SUBPEL_MASK) << SCALE_EXTRA_BITS; + skip_hor &= (sub_x == 0); + skip_ver &= (sub_y == 0); + } + } // do interp_filter search const int filter_set_size = DUAL_FILTER_SET_SIZE; restore_dst_buf(xd, *tmp_dst, num_planes); @@ -7629,20 +7893,16 @@ static int64_t interpolation_filter_search( int best_dual_mode = 0; // Find best of {R}x{R,Sm,Sh} // EIGHTTAP_REGULAR mode is calculated beforehand - for (i = 1; i < SWITCHABLE_FILTERS; ++i) { - if (interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd, - switchable_rate, skip_txfm_sb, skip_sse_sb, - dst_bufs, i)) { - best_dual_mode = i; - } - } + best_dual_mode = find_best_horiz_interp_filter_rd( + x, cpi, bsize, mi_row, mi_col, orig_dst, rd, switchable_rate, + skip_txfm_sb, skip_sse_sb, dst_bufs, switchable_ctx, skip_hor, + &tmp_rate, &tmp_dist, best_dual_mode); + // From best of horizontal EIGHTTAP_REGULAR modes, check vertical modes - for (i = best_dual_mode + SWITCHABLE_FILTERS; i < filter_set_size; - i += SWITCHABLE_FILTERS) { - interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd, - switchable_rate, skip_txfm_sb, skip_sse_sb, - dst_bufs, i); - } + find_best_vert_interp_filter_rd( + x, cpi, bsize, mi_row, mi_col, orig_dst, rd, switchable_rate, + skip_txfm_sb, skip_sse_sb, dst_bufs, switchable_ctx, skip_ver, + &tmp_rate, &tmp_dist, best_dual_mode, filter_set_size); } else { // EIGHTTAP_REGULAR mode is calculated beforehand for (i = 1; i < filter_set_size; ++i) { @@ -7653,7 +7913,8 @@ static int64_t interpolation_filter_search( } interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd, switchable_rate, skip_txfm_sb, skip_sse_sb, - dst_bufs, i); + dst_bufs, i, switchable_ctx, 0, &tmp_rate, + &tmp_dist); } } swap_dst_buf(xd, dst_bufs, num_planes); @@ -7848,6 +8109,7 @@ static int64_t motion_mode_rd(const AV1_COMP *const cpi, MACROBLOCK *const x, av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst, intrapred, bw); av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw); + av1_subtract_plane(x, bsize, 0); model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL); rd = RDCOST(x->rdmult, tmp_rate_mv + rate_sum + rmode, dist_sum); @@ -7861,7 +8123,6 @@ static int64_t motion_mode_rd(const AV1_COMP *const cpi, MACROBLOCK *const x, av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst, intrapred, bw); av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw); - av1_subtract_plane(x, bsize, 0); rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX); if (rd != INT64_MAX) @@ -7908,6 +8169,7 @@ static int64_t motion_mode_rd(const AV1_COMP *const cpi, MACROBLOCK *const x, mbmi->mv[0].as_int = tmp_mv.as_int; av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst, bsize); + av1_subtract_plane(x, bsize, 0); model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL); @@ -7925,7 +8187,6 @@ static int64_t motion_mode_rd(const AV1_COMP *const cpi, MACROBLOCK *const x, av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw); } // Evaluate closer to true rd - av1_subtract_plane(x, bsize, 0); rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX); @@ -8323,6 +8584,148 @@ static INLINE int get_drl_cost(const MB_MODE_INFO *mbmi, return cost; } +static INLINE int compound_type_rd(const AV1_COMP *const cpi, MACROBLOCK *x, + BLOCK_SIZE bsize, int mi_col, int mi_row, + int_mv *cur_mv, int masked_compound_used, + BUFFER_SET *orig_dst, BUFFER_SET *tmp_dst, + int *rate_mv, int64_t *rd, + RD_STATS *rd_stats, int64_t ref_best_rd) { + const AV1_COMMON *cm = &cpi->common; + MACROBLOCKD *xd = &x->e_mbd; + MB_MODE_INFO *mbmi = xd->mi[0]; + const int this_mode = mbmi->mode; + const int bw = block_size_wide[bsize]; + const int bh = block_size_high[bsize]; + int rate_sum, rs2; + int64_t dist_sum; + + int_mv best_mv[2]; + int best_tmp_rate_mv = *rate_mv; + int tmp_skip_txfm_sb; + int64_t tmp_skip_sse_sb; + INTERINTER_COMPOUND_DATA best_compound_data; + best_compound_data.type = COMPOUND_AVERAGE; + DECLARE_ALIGNED(16, uint8_t, pred0[2 * MAX_SB_SQUARE]); + DECLARE_ALIGNED(16, uint8_t, pred1[2 * MAX_SB_SQUARE]); + DECLARE_ALIGNED(32, int16_t, residual1[MAX_SB_SQUARE]); // src - pred1 + DECLARE_ALIGNED(32, int16_t, diff10[MAX_SB_SQUARE]); // pred1 - pred0 + uint8_t tmp_best_mask_buf[2 * MAX_SB_SQUARE]; + uint8_t *preds0[1] = { pred0 }; + uint8_t *preds1[1] = { pred1 }; + int strides[1] = { bw }; + int tmp_rate_mv; + const int num_pix = 1 << num_pels_log2_lookup[bsize]; + const int mask_len = 2 * num_pix * sizeof(uint8_t); + COMPOUND_TYPE cur_type; + int best_compmode_interinter_cost = 0; + int can_use_previous = cm->allow_warped_motion; + + best_mv[0].as_int = cur_mv[0].as_int; + best_mv[1].as_int = cur_mv[1].as_int; + *rd = INT64_MAX; + if (masked_compound_used) { + // get inter predictors to use for masked compound modes + av1_build_inter_predictors_for_planes_single_buf( + xd, bsize, 0, 0, mi_row, mi_col, 0, preds0, strides, can_use_previous); + av1_build_inter_predictors_for_planes_single_buf( + xd, bsize, 0, 0, mi_row, mi_col, 1, preds1, strides, can_use_previous); + const struct buf_2d *const src = &x->plane[0].src; + if (get_bitdepth_data_path_index(xd)) { + aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, + CONVERT_TO_BYTEPTR(pred1), bw, xd->bd); + aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(pred1), + bw, CONVERT_TO_BYTEPTR(pred0), bw, xd->bd); + } else { + aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, pred1, + bw); + aom_subtract_block(bh, bw, diff10, bw, pred1, bw, pred0, bw); + } + } + const int orig_is_best = xd->plane[0].dst.buf == orig_dst->plane[0]; + const BUFFER_SET *backup_buf = orig_is_best ? tmp_dst : orig_dst; + const BUFFER_SET *best_buf = orig_is_best ? orig_dst : tmp_dst; + for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) { + if (cur_type != COMPOUND_AVERAGE && !masked_compound_used) break; + if (!is_interinter_compound_used(cur_type, bsize)) continue; + tmp_rate_mv = *rate_mv; + int64_t best_rd_cur = INT64_MAX; + mbmi->interinter_comp.type = cur_type; + int masked_type_cost = 0; + + const int comp_group_idx_ctx = get_comp_group_idx_context(xd); + const int comp_index_ctx = get_comp_index_context(cm, xd); + mbmi->compound_idx = 1; + if (cur_type == COMPOUND_AVERAGE) { + mbmi->comp_group_idx = 0; + if (masked_compound_used) { + masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0]; + } + masked_type_cost += x->comp_idx_cost[comp_index_ctx][1]; + rs2 = masked_type_cost; + // No need to call av1_build_inter_predictors_sby here + // 1. COMPOUND_AVERAGE is always the first candidate + // 2. av1_build_inter_predictors_sby has been called by + // interpolation_filter_search + int64_t est_rd = + estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, + &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX); + // use spare buffer for following compound type try + restore_dst_buf(xd, *backup_buf, 1); + if (est_rd != INT64_MAX) + best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum); + } else { + mbmi->comp_group_idx = 1; + masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1]; + masked_type_cost += x->compound_type_cost[bsize][cur_type - 1]; + rs2 = masked_type_cost; + if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh && + *rd / 3 < ref_best_rd) { + best_rd_cur = build_and_cost_compound_type( + cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst, + &tmp_rate_mv, preds0, preds1, residual1, diff10, strides, mi_row, + mi_col); + } + } + if (best_rd_cur < *rd) { + *rd = best_rd_cur; + best_compound_data = mbmi->interinter_comp; + if (masked_compound_used && cur_type != COMPOUND_TYPES - 1) { + memcpy(tmp_best_mask_buf, xd->seg_mask, mask_len); + } + best_compmode_interinter_cost = rs2; + if (have_newmv_in_inter_mode(this_mode)) { + if (use_masked_motion_search(cur_type)) { + best_tmp_rate_mv = tmp_rate_mv; + best_mv[0].as_int = mbmi->mv[0].as_int; + best_mv[1].as_int = mbmi->mv[1].as_int; + } else { + best_mv[0].as_int = cur_mv[0].as_int; + best_mv[1].as_int = cur_mv[1].as_int; + } + } + } + // reset to original mvs for next iteration + mbmi->mv[0].as_int = cur_mv[0].as_int; + mbmi->mv[1].as_int = cur_mv[1].as_int; + } + if (mbmi->interinter_comp.type != best_compound_data.type) { + mbmi->comp_group_idx = + (best_compound_data.type == COMPOUND_AVERAGE) ? 0 : 1; + mbmi->interinter_comp = best_compound_data; + memcpy(xd->seg_mask, tmp_best_mask_buf, mask_len); + } + if (have_newmv_in_inter_mode(this_mode)) { + mbmi->mv[0].as_int = best_mv[0].as_int; + mbmi->mv[1].as_int = best_mv[1].as_int; + if (use_masked_motion_search(mbmi->interinter_comp.type)) { + rd_stats->rate += best_tmp_rate_mv - *rate_mv; + *rate_mv = best_tmp_rate_mv; + } + } + restore_dst_buf(xd, *best_buf, 1); + return best_compmode_interinter_cost; +} + static int64_t handle_inter_mode(const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, RD_STATS *rd_stats, RD_STATS *rd_stats_y, RD_STATS *rd_stats_uv, @@ -8344,63 +8747,24 @@ static int64_t handle_inter_mode(const AV1_COMP *const cpi, MACROBLOCK *x, int refs[2] = { mbmi->ref_frame[0], (mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1]) }; int rate_mv = 0; - const int bw = block_size_wide[bsize]; DECLARE_ALIGNED(32, uint8_t, tmp_buf_[2 * MAX_MB_PLANE * MAX_SB_SQUARE]); - uint8_t *tmp_buf; + uint8_t *tmp_buf = get_buf_by_bd(xd, tmp_buf_); int64_t rd = INT64_MAX; BUFFER_SET orig_dst, tmp_dst; int skip_txfm_sb = 0; int64_t skip_sse_sb = INT64_MAX; int16_t mode_ctx; - - mbmi->interinter_comp.type = COMPOUND_AVERAGE; - mbmi->comp_group_idx = 0; - mbmi->compound_idx = 1; - if (mbmi->ref_frame[1] == INTRA_FRAME) mbmi->ref_frame[1] = NONE_FRAME; - - mode_ctx = av1_mode_context_analyzer(mbmi_ext->mode_context, mbmi->ref_frame); - - if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) - tmp_buf = CONVERT_TO_BYTEPTR(tmp_buf_); - else - tmp_buf = tmp_buf_; - // Make sure that we didn't leave the plane destination buffers set - // to tmp_buf at the end of the last iteration - assert(xd->plane[0].dst.buf != tmp_buf); - - mbmi->num_proj_ref[0] = 0; - mbmi->num_proj_ref[1] = 0; - - if (is_comp_pred) { - for (int ref_idx = 0; ref_idx < is_comp_pred + 1; ++ref_idx) { - const int single_mode = get_single_mode(this_mode, ref_idx, is_comp_pred); - if (single_mode == NEWMV && - args->single_newmv[mbmi->ref_frame[ref_idx]].as_int == INVALID_MV) - return INT64_MAX; - } - } - - mbmi->motion_mode = SIMPLE_TRANSLATION; const int masked_compound_used = is_any_masked_compound_used(bsize) && cm->seq_params.enable_masked_compound; int64_t ret_val = INT64_MAX; const int8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame); - rd_stats->rate += args->ref_frame_cost + args->single_comp_cost; - rd_stats->rate += - get_drl_cost(mbmi, mbmi_ext, x->drl_mode_cost0, ref_frame_type); - const RD_STATS backup_rd_stats = *rd_stats; - const RD_STATS backup_rd_stats_y = *rd_stats_y; - const RD_STATS backup_rd_stats_uv = *rd_stats_uv; - const MB_MODE_INFO backup_mbmi = *mbmi; - INTERINTER_COMPOUND_DATA best_compound_data; - uint8_t tmp_best_mask_buf[2 * MAX_SB_SQUARE]; RD_STATS best_rd_stats, best_rd_stats_y, best_rd_stats_uv; int64_t best_rd = INT64_MAX; - int64_t best_ret_val = INT64_MAX; uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE]; MB_MODE_INFO best_mbmi = *mbmi; - int64_t early_terminate = 0; + int best_disable_skip; + int best_xskip; int plane_rate[MAX_MB_PLANE] = { 0 }; int64_t plane_sse[MAX_MB_PLANE] = { 0 }; int64_t plane_dist[MAX_MB_PLANE] = { 0 }; @@ -8411,387 +8775,311 @@ static int64_t handle_inter_mode(const AV1_COMP *const cpi, MACROBLOCK *x, int comp_idx; const int search_jnt_comp = is_comp_pred & cm->seq_params.enable_jnt_comp & (mbmi->mode != GLOBAL_GLOBALMV); - // If !search_jnt_comp, we need to force mbmi->compound_idx = 1. - for (comp_idx = 1; comp_idx >= !search_jnt_comp; --comp_idx) { - int rs = 0; - int compmode_interinter_cost = 0; - early_terminate = 0; - *rd_stats = backup_rd_stats; - *rd_stats_y = backup_rd_stats_y; - *rd_stats_uv = backup_rd_stats_uv; - *mbmi = backup_mbmi; - mbmi->compound_idx = comp_idx; - - if (is_comp_pred && comp_idx == 0) { - mbmi->comp_group_idx = 0; - mbmi->compound_idx = 0; - const int comp_group_idx_ctx = get_comp_group_idx_context(xd); - const int comp_index_ctx = get_comp_index_context(cm, xd); - if (masked_compound_used) { - compmode_interinter_cost += - x->comp_group_idx_cost[comp_group_idx_ctx][0]; + const int has_drl = (have_nearmv_in_inter_mode(mbmi->mode) && + mbmi_ext->ref_mv_count[ref_frame_type] > 2) || + ((mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV) && + mbmi_ext->ref_mv_count[ref_frame_type] > 1); + + // TODO(jingning): This should be deprecated shortly. + const int idx_offset = have_nearmv_in_inter_mode(mbmi->mode) ? 1 : 0; + const int ref_set = + has_drl ? AOMMIN(MAX_REF_MV_SERCH, + mbmi_ext->ref_mv_count[ref_frame_type] - idx_offset) + : 1; + + for (int ref_mv_idx = 0; ref_mv_idx < ref_set; ++ref_mv_idx) { + if (cpi->sf.reduce_inter_modes && ref_mv_idx > 0) { + if (mbmi->ref_frame[0] == LAST2_FRAME || + mbmi->ref_frame[0] == LAST3_FRAME || + mbmi->ref_frame[1] == LAST2_FRAME || + mbmi->ref_frame[1] == LAST3_FRAME) { + if (mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_idx + idx_offset] + .weight < REF_CAT_LEVEL) { + continue; + } } - compmode_interinter_cost += x->comp_idx_cost[comp_index_ctx][0]; } - int_mv cur_mv[2]; - if (!build_cur_mv(cur_mv, this_mode, cm, x)) { - early_terminate = INT64_MAX; - continue; - } - if (have_newmv_in_inter_mode(this_mode)) { - if (comp_idx == 0) { - cur_mv[0] = backup_mv[0]; - cur_mv[1] = backup_mv[1]; - rate_mv = backup_rate_mv; - } + av1_init_rd_stats(rd_stats); - // when jnt_comp_skip_mv_search flag is on, new mv will be searched once - if (!(search_jnt_comp && cpi->sf.jnt_comp_skip_mv_search && - comp_idx == 0)) { - newmv_ret_val = - handle_newmv(cpi, x, bsize, cur_mv, mi_row, mi_col, &rate_mv, args); - - // Store cur_mv and rate_mv so that they can be restored in the next - // iteration of the loop - backup_mv[0] = cur_mv[0]; - backup_mv[1] = cur_mv[1]; - backup_rate_mv = rate_mv; - } - - if (newmv_ret_val != 0) { - early_terminate = INT64_MAX; - continue; - } else { - rd_stats->rate += rate_mv; - } - } - for (i = 0; i < is_comp_pred + 1; ++i) { - mbmi->mv[i].as_int = cur_mv[i].as_int; - } + mbmi->interinter_comp.type = COMPOUND_AVERAGE; + mbmi->comp_group_idx = 0; + mbmi->compound_idx = 1; + if (mbmi->ref_frame[1] == INTRA_FRAME) mbmi->ref_frame[1] = NONE_FRAME; - // Initialise tmp_dst and orig_dst buffers to prevent "may be used - // uninitialized" warnings in GCC when the stream is monochrome. - memset(tmp_dst.plane, 0, sizeof(tmp_dst.plane)); - memset(tmp_dst.stride, 0, sizeof(tmp_dst.stride)); - memset(orig_dst.plane, 0, sizeof(tmp_dst.plane)); - memset(orig_dst.stride, 0, sizeof(tmp_dst.stride)); + mode_ctx = + av1_mode_context_analyzer(mbmi_ext->mode_context, mbmi->ref_frame); - // do first prediction into the destination buffer. Do the next - // prediction into a temporary buffer. Then keep track of which one - // of these currently holds the best predictor, and use the other - // one for future predictions. In the end, copy from tmp_buf to - // dst if necessary. - for (i = 0; i < num_planes; i++) { - tmp_dst.plane[i] = tmp_buf + i * MAX_SB_SQUARE; - tmp_dst.stride[i] = MAX_SB_SIZE; - } - for (i = 0; i < num_planes; i++) { - orig_dst.plane[i] = xd->plane[i].dst.buf; - orig_dst.stride[i] = xd->plane[i].dst.stride; - } + mbmi->num_proj_ref[0] = 0; + mbmi->num_proj_ref[1] = 0; + mbmi->motion_mode = SIMPLE_TRANSLATION; + mbmi->ref_mv_idx = ref_mv_idx; - const int ref_mv_cost = cost_mv_ref(x, this_mode, mode_ctx); -#if USE_DISCOUNT_NEWMV_TEST - // We don't include the cost of the second reference here, because there - // are only three options: Last/Golden, ARF/Last or Golden/ARF, or in other - // words if you present them in that order, the second one is always known - // if the first is known. - // - // Under some circumstances we discount the cost of new mv mode to encourage - // initiation of a motion field. - if (discount_newmv_test(cpi, x, this_mode, mbmi->mv[0])) { - // discount_newmv_test only applies discount on NEWMV mode. - assert(this_mode == NEWMV); - rd_stats->rate += AOMMIN(cost_mv_ref(x, this_mode, mode_ctx), - cost_mv_ref(x, NEARESTMV, mode_ctx)); - } else { - rd_stats->rate += ref_mv_cost; + if (is_comp_pred) { + for (int ref_idx = 0; ref_idx < is_comp_pred + 1; ++ref_idx) { + const int single_mode = + get_single_mode(this_mode, ref_idx, is_comp_pred); + if (single_mode == NEWMV && + args->single_newmv[mbmi->ref_mv_idx][mbmi->ref_frame[ref_idx]] + .as_int == INVALID_MV) + continue; + } } -#else - rd_stats->rate += ref_mv_cost; -#endif - if (RDCOST(x->rdmult, rd_stats->rate, 0) > ref_best_rd && - mbmi->mode != NEARESTMV && mbmi->mode != NEAREST_NEARESTMV) { - early_terminate = INT64_MAX; - continue; - } + rd_stats->rate += args->ref_frame_cost + args->single_comp_cost; + rd_stats->rate += + get_drl_cost(mbmi, mbmi_ext, x->drl_mode_cost0, ref_frame_type); - ret_val = interpolation_filter_search( - x, cpi, bsize, mi_row, mi_col, &tmp_dst, &orig_dst, args->single_filter, - &rd, &rs, &skip_txfm_sb, &skip_sse_sb); - if (ret_val != 0) { - early_terminate = INT64_MAX; - restore_dst_buf(xd, orig_dst, num_planes); - continue; - } else if (cpi->sf.model_based_post_interp_filter_breakout && - ref_best_rd != INT64_MAX && (rd / 6) > ref_best_rd) { - early_terminate = INT64_MAX; - restore_dst_buf(xd, orig_dst, num_planes); - if ((rd >> 4) > ref_best_rd) break; - continue; - } + const RD_STATS backup_rd_stats = *rd_stats; + const MB_MODE_INFO backup_mbmi = *mbmi; + int64_t best_rd2 = INT64_MAX; - if (is_comp_pred && comp_idx) { - int rate_sum, rs2; - int64_t dist_sum; - int64_t best_rd_compound = INT64_MAX, best_rd_cur = INT64_MAX; - int_mv best_mv[2]; - int best_tmp_rate_mv = rate_mv; - int tmp_skip_txfm_sb; - int64_t tmp_skip_sse_sb; - DECLARE_ALIGNED(16, uint8_t, pred0[2 * MAX_SB_SQUARE]); - DECLARE_ALIGNED(16, uint8_t, pred1[2 * MAX_SB_SQUARE]); - uint8_t *preds0[1] = { pred0 }; - uint8_t *preds1[1] = { pred1 }; - int strides[1] = { bw }; - int tmp_rate_mv; - const int num_pix = 1 << num_pels_log2_lookup[bsize]; - COMPOUND_TYPE cur_type; - int best_compmode_interinter_cost = 0; - int can_use_previous = cm->allow_warped_motion; - - best_mv[0].as_int = cur_mv[0].as_int; - best_mv[1].as_int = cur_mv[1].as_int; + // If !search_jnt_comp, we need to force mbmi->compound_idx = 1. + for (comp_idx = 1; comp_idx >= !search_jnt_comp; --comp_idx) { + int rs = 0; + int compmode_interinter_cost = 0; + *rd_stats = backup_rd_stats; + *mbmi = backup_mbmi; + mbmi->compound_idx = comp_idx; - if (masked_compound_used) { - // get inter predictors to use for masked compound modes - av1_build_inter_predictors_for_planes_single_buf( - xd, bsize, 0, 0, mi_row, mi_col, 0, preds0, strides, - can_use_previous); - av1_build_inter_predictors_for_planes_single_buf( - xd, bsize, 0, 0, mi_row, mi_col, 1, preds1, strides, - can_use_previous); - } - - int best_comp_group_idx = 0; - int best_compound_idx = 1; - for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) { - if (cur_type != COMPOUND_AVERAGE && !masked_compound_used) break; - if (!is_interinter_compound_used(cur_type, bsize)) continue; - tmp_rate_mv = rate_mv; - best_rd_cur = INT64_MAX; - mbmi->interinter_comp.type = cur_type; - int masked_type_cost = 0; + if (is_comp_pred && comp_idx == 0) { + mbmi->comp_group_idx = 0; + mbmi->compound_idx = 0; const int comp_group_idx_ctx = get_comp_group_idx_context(xd); const int comp_index_ctx = get_comp_index_context(cm, xd); if (masked_compound_used) { - if (cur_type == COMPOUND_AVERAGE) { - mbmi->comp_group_idx = 0; - mbmi->compound_idx = 1; - - masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0]; - masked_type_cost += x->comp_idx_cost[comp_index_ctx][1]; - } else { - mbmi->comp_group_idx = 1; - mbmi->compound_idx = 1; - - masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1]; - masked_type_cost += - x->compound_type_cost[bsize][mbmi->interinter_comp.type - 1]; - } - } else { - mbmi->comp_group_idx = 0; - mbmi->compound_idx = 1; - - masked_type_cost += x->comp_idx_cost[comp_index_ctx][1]; + compmode_interinter_cost += + x->comp_group_idx_cost[comp_group_idx_ctx][0]; } - rs2 = masked_type_cost; + compmode_interinter_cost += x->comp_idx_cost[comp_index_ctx][0]; + } - switch (cur_type) { - case COMPOUND_AVERAGE: - av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, &orig_dst, - bsize); - av1_subtract_plane(x, bsize, 0); - rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, - &tmp_skip_txfm_sb, &tmp_skip_sse_sb, - INT64_MAX); - if (rd != INT64_MAX) - best_rd_cur = - RDCOST(x->rdmult, rs2 + rate_mv + rate_sum, dist_sum); - break; - case COMPOUND_WEDGE: - if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh && - best_rd_compound / 3 < ref_best_rd) { - best_rd_cur = build_and_cost_compound_type( - cpi, x, cur_mv, bsize, this_mode, &rs2, rate_mv, &orig_dst, - &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col); - } - break; - case COMPOUND_DIFFWTD: - if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh && - best_rd_compound / 3 < ref_best_rd) { - best_rd_cur = build_and_cost_compound_type( - cpi, x, cur_mv, bsize, this_mode, &rs2, rate_mv, &orig_dst, - &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col); - } - break; - default: assert(0); return INT64_MAX; + int_mv cur_mv[2]; + if (!build_cur_mv(cur_mv, this_mode, cm, x)) { + continue; + } + if (have_newmv_in_inter_mode(this_mode)) { + if (comp_idx == 0) { + cur_mv[0] = backup_mv[0]; + cur_mv[1] = backup_mv[1]; + rate_mv = backup_rate_mv; } - if (best_rd_cur < best_rd_compound) { - best_comp_group_idx = mbmi->comp_group_idx; - best_compound_idx = mbmi->compound_idx; - best_rd_compound = best_rd_cur; - best_compound_data = mbmi->interinter_comp; - memcpy(tmp_best_mask_buf, xd->seg_mask, - 2 * num_pix * sizeof(uint8_t)); - best_compmode_interinter_cost = rs2; - if (have_newmv_in_inter_mode(this_mode)) { - if (use_masked_motion_search(cur_type)) { - best_tmp_rate_mv = tmp_rate_mv; - best_mv[0].as_int = mbmi->mv[0].as_int; - best_mv[1].as_int = mbmi->mv[1].as_int; - } else { - best_mv[0].as_int = cur_mv[0].as_int; - best_mv[1].as_int = cur_mv[1].as_int; - } - } + // when jnt_comp_skip_mv_search flag is on, new mv will be searched once + if (!(search_jnt_comp && cpi->sf.jnt_comp_skip_mv_search && + comp_idx == 0)) { + newmv_ret_val = handle_newmv(cpi, x, bsize, cur_mv, mi_row, mi_col, + &rate_mv, args); + + // Store cur_mv and rate_mv so that they can be restored in the next + // iteration of the loop + backup_mv[0] = cur_mv[0]; + backup_mv[1] = cur_mv[1]; + backup_rate_mv = rate_mv; } - // reset to original mvs for next iteration - mbmi->mv[0].as_int = cur_mv[0].as_int; - mbmi->mv[1].as_int = cur_mv[1].as_int; - } - mbmi->comp_group_idx = best_comp_group_idx; - mbmi->compound_idx = best_compound_idx; - mbmi->interinter_comp = best_compound_data; - assert(IMPLIES(mbmi->comp_group_idx == 1, - mbmi->interinter_comp.type != COMPOUND_AVERAGE)); - memcpy(xd->seg_mask, tmp_best_mask_buf, 2 * num_pix * sizeof(uint8_t)); - if (have_newmv_in_inter_mode(this_mode)) { - mbmi->mv[0].as_int = best_mv[0].as_int; - mbmi->mv[1].as_int = best_mv[1].as_int; - if (use_masked_motion_search(mbmi->interinter_comp.type)) { - rd_stats->rate += best_tmp_rate_mv - rate_mv; - rate_mv = best_tmp_rate_mv; + + if (newmv_ret_val != 0) { + continue; + } else { + rd_stats->rate += rate_mv; } } + for (i = 0; i < is_comp_pred + 1; ++i) { + mbmi->mv[i].as_int = cur_mv[i].as_int; + } - if (ref_best_rd < INT64_MAX && best_rd_compound / 3 > ref_best_rd) { - restore_dst_buf(xd, orig_dst, num_planes); - early_terminate = INT64_MAX; + // Initialise tmp_dst and orig_dst buffers to prevent "may be used + // uninitialized" warnings in GCC when the stream is monochrome. + memset(tmp_dst.plane, 0, sizeof(tmp_dst.plane)); + memset(tmp_dst.stride, 0, sizeof(tmp_dst.stride)); + memset(orig_dst.plane, 0, sizeof(tmp_dst.plane)); + memset(orig_dst.stride, 0, sizeof(tmp_dst.stride)); + + // do first prediction into the destination buffer. Do the next + // prediction into a temporary buffer. Then keep track of which one + // of these currently holds the best predictor, and use the other + // one for future predictions. In the end, copy from tmp_buf to + // dst if necessary. + for (i = 0; i < num_planes; i++) { + tmp_dst.plane[i] = tmp_buf + i * MAX_SB_SQUARE; + tmp_dst.stride[i] = MAX_SB_SIZE; + } + for (i = 0; i < num_planes; i++) { + orig_dst.plane[i] = xd->plane[i].dst.buf; + orig_dst.stride[i] = xd->plane[i].dst.stride; + } + + const int ref_mv_cost = cost_mv_ref(x, this_mode, mode_ctx); +#if USE_DISCOUNT_NEWMV_TEST + // We don't include the cost of the second reference here, because there + // are only three options: Last/Golden, ARF/Last or Golden/ARF, or in + // other words if you present them in that order, the second one is always + // known if the first is known. + // + // Under some circumstances we discount the cost of new mv mode to + // encourage initiation of a motion field. + if (discount_newmv_test(cpi, x, this_mode, mbmi->mv[0])) { + // discount_newmv_test only applies discount on NEWMV mode. + assert(this_mode == NEWMV); + rd_stats->rate += AOMMIN(cost_mv_ref(x, this_mode, mode_ctx), + cost_mv_ref(x, NEARESTMV, mode_ctx)); + } else { + rd_stats->rate += ref_mv_cost; + } +#else + rd_stats->rate += ref_mv_cost; +#endif + + if (RDCOST(x->rdmult, rd_stats->rate, 0) > ref_best_rd && + mbmi->mode != NEARESTMV && mbmi->mode != NEAREST_NEARESTMV) { continue; } - compmode_interinter_cost = best_compmode_interinter_cost; - } - if (is_comp_pred) { - int tmp_rate; - int64_t tmp_dist; - av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, &orig_dst, bsize); - model_rd_for_sb(cpi, bsize, x, xd, 0, num_planes - 1, &tmp_rate, - &tmp_dist, &skip_txfm_sb, &skip_sse_sb, plane_rate, - plane_sse, plane_dist); - rd = RDCOST(x->rdmult, rs + tmp_rate, tmp_dist); - } - - if (search_jnt_comp) { - // if 1/2 model rd is larger than best_rd in jnt_comp mode, - // use jnt_comp mode, save additional search - if ((rd >> 1) > best_rd) { + ret_val = interpolation_filter_search( + x, cpi, bsize, mi_row, mi_col, &tmp_dst, &orig_dst, + args->single_filter, &rd, &rs, &skip_txfm_sb, &skip_sse_sb); + if (ret_val != 0) { + restore_dst_buf(xd, orig_dst, num_planes); + continue; + } else if (cpi->sf.model_based_post_interp_filter_breakout && + ref_best_rd != INT64_MAX && (rd / 6 > ref_best_rd)) { restore_dst_buf(xd, orig_dst, num_planes); + if ((rd >> 4) > ref_best_rd) break; continue; } - } - if (!is_comp_pred) - args->single_filter[this_mode][refs[0]] = - av1_extract_interp_filter(mbmi->interp_filters, 0); + if (is_comp_pred && comp_idx) { + int64_t best_rd_compound; + compmode_interinter_cost = compound_type_rd( + cpi, x, bsize, mi_col, mi_row, cur_mv, masked_compound_used, + &orig_dst, &tmp_dst, &rate_mv, &best_rd_compound, rd_stats, + ref_best_rd); + if (ref_best_rd < INT64_MAX && best_rd_compound / 3 > ref_best_rd) { + restore_dst_buf(xd, orig_dst, num_planes); + continue; + } + if (mbmi->interinter_comp.type != COMPOUND_AVERAGE) { + int tmp_rate; + int64_t tmp_dist; + av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, &orig_dst, + bsize); + for (int plane = 0; plane < num_planes; ++plane) + av1_subtract_plane(x, bsize, plane); + model_rd_for_sb(cpi, bsize, x, xd, 0, num_planes - 1, &tmp_rate, + &tmp_dist, &skip_txfm_sb, &skip_sse_sb, plane_rate, + plane_sse, plane_dist); + rd = RDCOST(x->rdmult, rs + tmp_rate, tmp_dist); + } + } - if (args->modelled_rd != NULL) { - if (is_comp_pred) { - const int mode0 = compound_ref0_mode(this_mode); - const int mode1 = compound_ref1_mode(this_mode); - const int64_t mrd = AOMMIN(args->modelled_rd[mode0][refs[0]], - args->modelled_rd[mode1][refs[1]]); - if (rd / 4 * 3 > mrd && ref_best_rd < INT64_MAX) { + if (search_jnt_comp) { + // if 1/2 model rd is larger than best_rd in jnt_comp mode, + // use jnt_comp mode, save additional search + if ((rd >> 1) > best_rd) { restore_dst_buf(xd, orig_dst, num_planes); - early_terminate = INT64_MAX; continue; } - } else { - args->modelled_rd[this_mode][refs[0]] = rd; } - } - if (cpi->sf.use_rd_breakout && ref_best_rd < INT64_MAX) { - // if current pred_error modeled rd is substantially more than the best - // so far, do not bother doing full rd - if (rd / 2 > ref_best_rd) { - restore_dst_buf(xd, orig_dst, num_planes); - early_terminate = INT64_MAX; - continue; + if (!is_comp_pred) + args->single_filter[this_mode][refs[0]] = + av1_extract_interp_filter(mbmi->interp_filters, 0); + + if (args->modelled_rd != NULL) { + if (is_comp_pred) { + const int mode0 = compound_ref0_mode(this_mode); + const int mode1 = compound_ref1_mode(this_mode); + const int64_t mrd = AOMMIN(args->modelled_rd[mode0][refs[0]], + args->modelled_rd[mode1][refs[1]]); + if (rd / 4 * 3 > mrd && ref_best_rd < INT64_MAX) { + restore_dst_buf(xd, orig_dst, num_planes); + continue; + } + } else { + args->modelled_rd[this_mode][refs[0]] = rd; + } } - } - rd_stats->rate += compmode_interinter_cost; + if (cpi->sf.use_rd_breakout && ref_best_rd < INT64_MAX) { + // if current pred_error modeled rd is substantially more than the best + // so far, do not bother doing full rd + if (rd / 2 > ref_best_rd) { + restore_dst_buf(xd, orig_dst, num_planes); + continue; + } + } - if (search_jnt_comp && cpi->sf.jnt_comp_fast_tx_search && comp_idx == 0) { - // TODO(chengchen): this speed feature introduces big loss. - // Need better estimation of rate distortion. - rd_stats->rate += rs; - rd_stats->rate += plane_rate[0] + plane_rate[1] + plane_rate[2]; - rd_stats_y->rate = plane_rate[0]; - rd_stats_uv->rate = plane_rate[1] + plane_rate[2]; - rd_stats->sse = plane_sse[0] + plane_sse[1] + plane_sse[2]; - rd_stats_y->sse = plane_sse[0]; - rd_stats_uv->sse = plane_sse[1] + plane_sse[2]; - rd_stats->dist = plane_dist[0] + plane_dist[1] + plane_dist[2]; - rd_stats_y->dist = plane_dist[0]; - rd_stats_uv->dist = plane_dist[1] + plane_dist[2]; - } else { + rd_stats->rate += compmode_interinter_cost; + + if (search_jnt_comp && cpi->sf.jnt_comp_fast_tx_search && comp_idx == 0) { + // TODO(chengchen): this speed feature introduces big loss. + // Need better estimation of rate distortion. + rd_stats->rate += rs; + rd_stats->rate += plane_rate[0] + plane_rate[1] + plane_rate[2]; + rd_stats_y->rate = plane_rate[0]; + rd_stats_uv->rate = plane_rate[1] + plane_rate[2]; + rd_stats->sse = plane_sse[0] + plane_sse[1] + plane_sse[2]; + rd_stats_y->sse = plane_sse[0]; + rd_stats_uv->sse = plane_sse[1] + plane_sse[2]; + rd_stats->dist = plane_dist[0] + plane_dist[1] + plane_dist[2]; + rd_stats_y->dist = plane_dist[0]; + rd_stats_uv->dist = plane_dist[1] + plane_dist[2]; + } else { #if CONFIG_COLLECT_INTER_MODE_RD_STATS - ret_val = motion_mode_rd(cpi, x, bsize, rd_stats, rd_stats_y, rd_stats_uv, - disable_skip, mi_row, mi_col, args, ref_best_rd, - refs, rate_mv, &orig_dst, best_est_rd); + ret_val = + motion_mode_rd(cpi, x, bsize, rd_stats, rd_stats_y, rd_stats_uv, + disable_skip, mi_row, mi_col, args, ref_best_rd, + refs, rate_mv, &orig_dst, best_est_rd); #else - ret_val = motion_mode_rd(cpi, x, bsize, rd_stats, rd_stats_y, rd_stats_uv, - disable_skip, mi_row, mi_col, args, ref_best_rd, - refs, rate_mv, &orig_dst); + ret_val = motion_mode_rd(cpi, x, bsize, rd_stats, rd_stats_y, + rd_stats_uv, disable_skip, mi_row, mi_col, + args, ref_best_rd, refs, rate_mv, &orig_dst); #endif - } - if (ret_val != INT64_MAX) { - if (search_jnt_comp) { + } + if (ret_val != INT64_MAX) { int64_t tmp_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist); if (tmp_rd < best_rd) { best_rd_stats = *rd_stats; best_rd_stats_y = *rd_stats_y; best_rd_stats_uv = *rd_stats_uv; - best_ret_val = ret_val; best_rd = tmp_rd; best_mbmi = *mbmi; + best_disable_skip = *disable_skip; + best_xskip = x->skip; memcpy(best_blk_skip, x->blk_skip, sizeof(best_blk_skip[0]) * xd->n8_h * xd->n8_w); } + + if (tmp_rd < best_rd2) { + best_rd2 = tmp_rd; + } + if (tmp_rd < ref_best_rd) { ref_best_rd = tmp_rd; } } - } - if (!search_jnt_comp && ret_val != 0) { restore_dst_buf(xd, orig_dst, num_planes); - return ret_val; } - restore_dst_buf(xd, orig_dst, num_planes); + + args->modelled_rd = NULL; } + if (best_rd == INT64_MAX) return INT64_MAX; + // re-instate status of the best choice - if (is_comp_pred && best_ret_val != INT64_MAX) { - *rd_stats = best_rd_stats; - *rd_stats_y = best_rd_stats_y; - *rd_stats_uv = best_rd_stats_uv; - ret_val = best_ret_val; - *mbmi = best_mbmi; - assert(IMPLIES(mbmi->comp_group_idx == 1, - mbmi->interinter_comp.type != COMPOUND_AVERAGE)); - memcpy(x->blk_skip, best_blk_skip, - sizeof(best_blk_skip[0]) * xd->n8_h * xd->n8_w); - } - if (early_terminate == INT64_MAX) return INT64_MAX; - if (ret_val != 0) return ret_val; + *rd_stats = best_rd_stats; + *rd_stats_y = best_rd_stats_y; + *rd_stats_uv = best_rd_stats_uv; + *mbmi = best_mbmi; + *disable_skip = best_disable_skip; + x->skip = best_xskip; + assert(IMPLIES(mbmi->comp_group_idx == 1, + mbmi->interinter_comp.type != COMPOUND_AVERAGE)); + memcpy(x->blk_skip, best_blk_skip, + sizeof(best_blk_skip[0]) * xd->n8_h * xd->n8_w); + return RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist); } @@ -8822,6 +9110,13 @@ static int64_t rd_pick_intrabc_mode_sb(const AV1_COMP *cpi, MACROBLOCK *x, av1_find_best_ref_mvs_from_stack(0, mbmi_ext, ref_frame, &nearestmv, &nearmv, 0); + if (nearestmv.as_int == INVALID_MV) { + nearestmv.as_int = 0; + } + if (nearmv.as_int == INVALID_MV) { + nearmv.as_int = 0; + } + int_mv dv_ref = nearestmv.as_int == 0 ? nearmv : nearestmv; if (dv_ref.as_int == 0) av1_find_ref_dv(&dv_ref, tile, cm->seq_params.mib_size, mi_row, mi_col); @@ -9013,8 +9308,9 @@ void av1_rd_pick_intra_mode_sb(const AV1_COMP *cpi, MACROBLOCK *x, int mi_row, if (intra_yrd < best_rd) { // Only store reconstructed luma when there's chroma RDO. When there's no // chroma RDO, the reconstructed luma will be stored in encode_superblock(). - xd->cfl.is_chroma_reference = is_chroma_reference( - mi_row, mi_col, bsize, cm->subsampling_x, cm->subsampling_y); + xd->cfl.is_chroma_reference = + is_chroma_reference(mi_row, mi_col, bsize, cm->seq_params.subsampling_x, + cm->seq_params.subsampling_y); xd->cfl.store_y = store_cfl_required_rdo(cm, x); if (xd->cfl.store_y) { // Restore reconstructed luma values. @@ -9081,7 +9377,7 @@ static void restore_uv_color_map(const AV1_COMP *const cpi, MACROBLOCK *x) { for (r = 0; r < rows; ++r) { for (c = 0; c < cols; ++c) { - if (cpi->common.use_highbitdepth) { + if (cpi->common.seq_params.use_highbitdepth) { data[(r * cols + c) * 2] = src_u16[r * src_stride + c]; data[(r * cols + c) * 2 + 1] = src_v16[r * src_stride + c]; } else { @@ -9760,6 +10056,8 @@ static int inter_mode_search_order_independent_skip( if (comp_pred) { if (!cpi->allow_comp_inter_inter) return 1; + if (cm->reference_mode == SINGLE_REFERENCE) return 1; + // Skip compound inter modes if ARF is not available. if (!(cpi->ref_frame_flags & ref_frame_flag_list[ref_frame[1]])) return 1; @@ -9857,7 +10155,7 @@ static int handle_intra_mode(InterModeSearchState *search_state, av1_allow_palette(cm->allow_screen_content_tools, mbmi->sb_type); const int *const intra_mode_cost = x->mbmode_cost[size_group_lookup[bsize]]; const int intra_cost_penalty = av1_get_intra_cost_penalty( - cm->base_qindex, cm->y_dc_delta_q, cm->bit_depth); + cm->base_qindex, cm->y_dc_delta_q, cm->seq_params.bit_depth); const int rows = block_size_high[bsize]; const int cols = block_size_wide[bsize]; const int num_planes = av1_num_planes(cm); @@ -10050,7 +10348,6 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data, const int try_palette = av1_allow_palette(cm->allow_screen_content_tools, mbmi->sb_type); PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info; - MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext; const struct segmentation *const seg = &cm->seg; PREDICTION_MODE this_mode; MV_REFERENCE_FRAME ref_frame, second_ref_frame; @@ -10097,7 +10394,6 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data, int64_t distortion2 = 0; int skippable = 0; int this_skip2 = 0; - uint8_t ref_frame_type; this_mode = av1_mode_order[mode_index].mode; ref_frame = av1_mode_order[mode_index].ref_frame[0]; @@ -10195,7 +10491,6 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data, mbmi->angle_delta[PLANE_TYPE_UV] = 0; mbmi->filter_intra_mode_info.use_filter_intra = 0; mbmi->ref_mv_idx = 0; - ref_frame_type = av1_ref_frame_type(mbmi->ref_frame); int64_t ref_best_rd = search_state.best_rd; { RD_STATS rd_stats, rd_stats_y, rd_stats_uv; @@ -10203,9 +10498,9 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data, rd_stats.rate = rate2; // Point to variables that are maintained between loop iterations - args.single_newmv = search_state.single_newmv[0]; - args.single_newmv_rate = search_state.single_newmv_rate[0]; - args.single_newmv_valid = search_state.single_newmv_valid[0]; + args.single_newmv = search_state.single_newmv; + args.single_newmv_rate = search_state.single_newmv_rate; + args.single_newmv_valid = search_state.single_newmv_valid; args.modelled_rd = search_state.modelled_rd; args.single_comp_cost = real_compmode_cost; args.ref_frame_cost = ref_frame_cost; @@ -10218,10 +10513,6 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data, &rd_stats_uv, &disable_skip, mi_row, mi_col, &args, ref_best_rd); #endif - if (this_rd < ref_best_rd) { - ref_best_rd = this_rd; - } - rate2 = rd_stats.rate; skippable = rd_stats.skip; distortion2 = rd_stats.dist; @@ -10229,108 +10520,6 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data, rate_uv = rd_stats_uv.rate; } - // TODO(jingning): This needs some refactoring to improve code quality - // and reduce redundant steps. - if ((have_nearmv_in_inter_mode(mbmi->mode) && - mbmi_ext->ref_mv_count[ref_frame_type] > 2) || - ((mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV) && - mbmi_ext->ref_mv_count[ref_frame_type] > 1)) { - MB_MODE_INFO backup_mbmi = *mbmi; - int backup_skip = x->skip; - int64_t tmp_ref_rd = this_rd; - int ref_idx; - - // TODO(jingning): This should be deprecated shortly. - int idx_offset = have_nearmv_in_inter_mode(mbmi->mode) ? 1 : 0; - int ref_set = - AOMMIN(MAX_REF_MV_SERCH - 1, - mbmi_ext->ref_mv_count[ref_frame_type] - 1 - idx_offset); - memcpy(x->blk_skip_drl, x->blk_skip, - sizeof(x->blk_skip[0]) * ctx->num_4x4_blk); - - for (ref_idx = 0; ref_idx < ref_set; ++ref_idx) { - int64_t tmp_alt_rd = INT64_MAX; - int dummy_disable_skip = 0; - int_mv cur_mv; - RD_STATS tmp_rd_stats, tmp_rd_stats_y, tmp_rd_stats_uv; - - av1_invalid_rd_stats(&tmp_rd_stats); - - x->skip = 0; - - mbmi->ref_mv_idx = 1 + ref_idx; - - if (cpi->sf.reduce_inter_modes) { - if (mbmi->ref_frame[0] == LAST2_FRAME || - mbmi->ref_frame[0] == LAST3_FRAME || - mbmi->ref_frame[1] == LAST2_FRAME || - mbmi->ref_frame[1] == LAST3_FRAME) { - if (mbmi_ext - ->ref_mv_stack[ref_frame_type] - [mbmi->ref_mv_idx + idx_offset] - .weight < REF_CAT_LEVEL) { - *mbmi = backup_mbmi; - x->skip = backup_skip; - continue; - } - } - } - - cur_mv = - mbmi_ext->ref_mv_stack[ref_frame][mbmi->ref_mv_idx + idx_offset] - .this_mv; - clamp_mv2(&cur_mv.as_mv, xd); - - if (!mv_check_bounds(&x->mv_limits, &cur_mv.as_mv)) { - av1_init_rd_stats(&tmp_rd_stats); - - args.modelled_rd = NULL; - args.single_newmv = search_state.single_newmv[mbmi->ref_mv_idx]; - args.single_newmv_rate = - search_state.single_newmv_rate[mbmi->ref_mv_idx]; - args.single_newmv_valid = - search_state.single_newmv_valid[mbmi->ref_mv_idx]; - args.single_comp_cost = real_compmode_cost; - args.ref_frame_cost = ref_frame_cost; -#if CONFIG_COLLECT_INTER_MODE_RD_STATS - tmp_alt_rd = - handle_inter_mode(cpi, x, bsize, &tmp_rd_stats, &tmp_rd_stats_y, - &tmp_rd_stats_uv, &dummy_disable_skip, mi_row, - mi_col, &args, ref_best_rd, &best_est_rd); -#else - tmp_alt_rd = handle_inter_mode( - cpi, x, bsize, &tmp_rd_stats, &tmp_rd_stats_y, &tmp_rd_stats_uv, - &dummy_disable_skip, mi_row, mi_col, &args, ref_best_rd); -#endif - - // Prevent pointers from escaping local scope - args.single_newmv = search_state.single_newmv[0]; - args.single_newmv_rate = search_state.single_newmv_rate[0]; - args.single_newmv_valid = search_state.single_newmv_valid[0]; - } - - if (tmp_ref_rd > tmp_alt_rd) { - rate2 = tmp_rd_stats.rate; - disable_skip = dummy_disable_skip; - distortion2 = tmp_rd_stats.dist; - skippable = tmp_rd_stats.skip; - rate_y = tmp_rd_stats_y.rate; - rate_uv = tmp_rd_stats_uv.rate; - this_rd = tmp_alt_rd; - tmp_ref_rd = tmp_alt_rd; - backup_mbmi = *mbmi; - backup_skip = x->skip; - memcpy(x->blk_skip_drl, x->blk_skip, - sizeof(x->blk_skip[0]) * ctx->num_4x4_blk); - } else { - *mbmi = backup_mbmi; - x->skip = backup_skip; - } - } - - memcpy(x->blk_skip, x->blk_skip_drl, - sizeof(x->blk_skip[0]) * ctx->num_4x4_blk); - } if (this_rd == INT64_MAX) continue; this_skip2 = mbmi->skip; |