summaryrefslogtreecommitdiff
path: root/third_party/aom/av1/common/av1_txfm.h
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/aom/av1/common/av1_txfm.h')
-rw-r--r--third_party/aom/av1/common/av1_txfm.h197
1 files changed, 167 insertions, 30 deletions
diff --git a/third_party/aom/av1/common/av1_txfm.h b/third_party/aom/av1/common/av1_txfm.h
index 269ef5705a..bd365de59a 100644
--- a/third_party/aom/av1/common/av1_txfm.h
+++ b/third_party/aom/av1/common/av1_txfm.h
@@ -17,9 +17,16 @@
#include <stdio.h>
#include "av1/common/enums.h"
+#include "av1/common/blockd.h"
#include "aom/aom_integer.h"
#include "aom_dsp/aom_dsp_common.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define MAX_TXFM_STAGE_NUM 12
+
static const int cos_bit_min = 10;
static const int cos_bit_max = 16;
@@ -110,27 +117,6 @@ static INLINE int32_t half_btf(int32_t w0, int32_t in0, int32_t w1, int32_t in1,
return round_shift(result_32, bit);
}
-static INLINE int get_max_bit(int x) {
- int max_bit = -1;
- while (x) {
- x = x >> 1;
- max_bit++;
- }
- return max_bit;
-}
-
-// TODO(angiebird): implement SSE
-static INLINE void clamp_block(int16_t *block, int block_size_row,
- int block_size_col, int stride, int low,
- int high) {
- int i, j;
- for (i = 0; i < block_size_row; ++i) {
- for (j = 0; j < block_size_col; ++j) {
- block[i * stride + j] = clamp(block[i * stride + j], low, high);
- }
- }
-}
-
typedef void (*TxfmFunc)(const int32_t *input, int32_t *output,
const int8_t *cos_bit, const int8_t *stage_range);
@@ -148,6 +134,7 @@ typedef enum TXFM_TYPE {
TXFM_TYPE_IDENTITY8,
TXFM_TYPE_IDENTITY16,
TXFM_TYPE_IDENTITY32,
+ TXFM_TYPE_IDENTITY64,
} TXFM_TYPE;
typedef struct TXFM_1D_CFG {
@@ -167,7 +154,7 @@ typedef struct TXFM_2D_FLIP_CFG {
const TXFM_1D_CFG *row_cfg;
} TXFM_2D_FLIP_CFG;
-static INLINE void set_flip_cfg(int tx_type, TXFM_2D_FLIP_CFG *cfg) {
+static INLINE void set_flip_cfg(TX_TYPE tx_type, TXFM_2D_FLIP_CFG *cfg) {
switch (tx_type) {
case DCT_DCT:
case ADST_DCT:
@@ -209,21 +196,171 @@ static INLINE void set_flip_cfg(int tx_type, TXFM_2D_FLIP_CFG *cfg) {
}
}
+#if CONFIG_TXMG
+static INLINE TX_SIZE av1_rotate_tx_size(TX_SIZE tx_size) {
+ switch (tx_size) {
+#if CONFIG_CHROMA_2X2
+ case TX_2X2: return TX_2X2;
+#endif
+ case TX_4X4: return TX_4X4;
+ case TX_8X8: return TX_8X8;
+ case TX_16X16: return TX_16X16;
+ case TX_32X32: return TX_32X32;
+#if CONFIG_TX64X64
+ case TX_64X64: return TX_64X64;
+ case TX_32X64: return TX_64X32;
+ case TX_64X32: return TX_32X64;
+#endif
+ case TX_4X8: return TX_8X4;
+ case TX_8X4: return TX_4X8;
+ case TX_8X16: return TX_16X8;
+ case TX_16X8: return TX_8X16;
+ case TX_16X32: return TX_32X16;
+ case TX_32X16: return TX_16X32;
+ case TX_4X16: return TX_16X4;
+ case TX_16X4: return TX_4X16;
+ case TX_8X32: return TX_32X8;
+ case TX_32X8: return TX_8X32;
+ default: assert(0); return TX_INVALID;
+ }
+}
+
+static INLINE TX_TYPE av1_rotate_tx_type(TX_TYPE tx_type) {
+ switch (tx_type) {
+ case DCT_DCT: return DCT_DCT;
+ case ADST_DCT: return DCT_ADST;
+ case DCT_ADST: return ADST_DCT;
+ case ADST_ADST: return ADST_ADST;
+#if CONFIG_EXT_TX
+ case FLIPADST_DCT: return DCT_FLIPADST;
+ case DCT_FLIPADST: return FLIPADST_DCT;
+ case FLIPADST_FLIPADST: return FLIPADST_FLIPADST;
+ case ADST_FLIPADST: return FLIPADST_ADST;
+ case FLIPADST_ADST: return ADST_FLIPADST;
+ case IDTX: return IDTX;
+ case V_DCT: return H_DCT;
+ case H_DCT: return V_DCT;
+ case V_ADST: return H_ADST;
+ case H_ADST: return V_ADST;
+ case V_FLIPADST: return H_FLIPADST;
+ case H_FLIPADST: return V_FLIPADST;
+#endif // CONFIG_EXT_TX
+#if CONFIG_MRC_TX
+ case MRC_DCT: return MRC_DCT;
+#endif // CONFIG_MRC_TX
+ default: assert(0); return TX_TYPES;
+ }
+}
+#endif // CONFIG_TXMG
+
#if CONFIG_MRC_TX
-static INLINE void get_mrc_mask(const uint8_t *pred, int pred_stride, int *mask,
- int mask_stride, int width, int height) {
+static INLINE int get_mrc_diff_mask_inter(const int16_t *diff, int diff_stride,
+ uint8_t *mask, int mask_stride,
+ int width, int height) {
+ // placeholder mask generation function
+ assert(SIGNAL_MRC_MASK_INTER);
+ int n_masked_vals = 0;
for (int i = 0; i < height; ++i) {
- for (int j = 0; j < width; ++j)
+ for (int j = 0; j < width; ++j) {
+ mask[i * mask_stride + j] = diff[i * diff_stride + j] > 100 ? 1 : 0;
+ n_masked_vals += mask[i * mask_stride + j];
+ }
+ }
+ return n_masked_vals;
+}
+
+static INLINE int get_mrc_pred_mask_inter(const uint8_t *pred, int pred_stride,
+ uint8_t *mask, int mask_stride,
+ int width, int height) {
+ // placeholder mask generation function
+ int n_masked_vals = 0;
+ for (int i = 0; i < height; ++i) {
+ for (int j = 0; j < width; ++j) {
+ mask[i * mask_stride + j] = pred[i * pred_stride + j] > 100 ? 1 : 0;
+ n_masked_vals += mask[i * mask_stride + j];
+ }
+ }
+ return n_masked_vals;
+}
+
+static INLINE int get_mrc_diff_mask_intra(const int16_t *diff, int diff_stride,
+ uint8_t *mask, int mask_stride,
+ int width, int height) {
+ // placeholder mask generation function
+ assert(SIGNAL_MRC_MASK_INTRA);
+ int n_masked_vals = 0;
+ for (int i = 0; i < height; ++i) {
+ for (int j = 0; j < width; ++j) {
+ mask[i * mask_stride + j] = diff[i * diff_stride + j] > 100 ? 1 : 0;
+ n_masked_vals += mask[i * mask_stride + j];
+ }
+ }
+ return n_masked_vals;
+}
+
+static INLINE int get_mrc_pred_mask_intra(const uint8_t *pred, int pred_stride,
+ uint8_t *mask, int mask_stride,
+ int width, int height) {
+ // placeholder mask generation function
+ int n_masked_vals = 0;
+ for (int i = 0; i < height; ++i) {
+ for (int j = 0; j < width; ++j) {
mask[i * mask_stride + j] = pred[i * pred_stride + j] > 100 ? 1 : 0;
+ n_masked_vals += mask[i * mask_stride + j];
+ }
+ }
+ return n_masked_vals;
+}
+
+static INLINE int get_mrc_diff_mask(const int16_t *diff, int diff_stride,
+ uint8_t *mask, int mask_stride, int width,
+ int height, int is_inter) {
+ if (is_inter) {
+ assert(USE_MRC_INTER && "MRC invalid for inter blocks");
+ assert(SIGNAL_MRC_MASK_INTER);
+ return get_mrc_diff_mask_inter(diff, diff_stride, mask, mask_stride, width,
+ height);
+ } else {
+ assert(USE_MRC_INTRA && "MRC invalid for intra blocks");
+ assert(SIGNAL_MRC_MASK_INTRA);
+ return get_mrc_diff_mask_intra(diff, diff_stride, mask, mask_stride, width,
+ height);
+ }
+}
+
+static INLINE int get_mrc_pred_mask(const uint8_t *pred, int pred_stride,
+ uint8_t *mask, int mask_stride, int width,
+ int height, int is_inter) {
+ if (is_inter) {
+ assert(USE_MRC_INTER && "MRC invalid for inter blocks");
+ return get_mrc_pred_mask_inter(pred, pred_stride, mask, mask_stride, width,
+ height);
+ } else {
+ assert(USE_MRC_INTRA && "MRC invalid for intra blocks");
+ return get_mrc_pred_mask_intra(pred, pred_stride, mask, mask_stride, width,
+ height);
}
}
+
+static INLINE int is_valid_mrc_mask(int n_masked_vals, int width, int height) {
+ return !(n_masked_vals == 0 || n_masked_vals == (width * height));
+}
#endif // CONFIG_MRC_TX
-#ifdef __cplusplus
-extern "C" {
-#endif
-TXFM_2D_FLIP_CFG av1_get_fwd_txfm_cfg(int tx_type, int tx_size);
-TXFM_2D_FLIP_CFG av1_get_fwd_txfm_64x64_cfg(int tx_type);
+void av1_gen_fwd_stage_range(int8_t *stage_range_col, int8_t *stage_range_row,
+ const TXFM_2D_FLIP_CFG *cfg, int bd);
+
+void av1_gen_inv_stage_range(int8_t *stage_range_col, int8_t *stage_range_row,
+ const TXFM_2D_FLIP_CFG *cfg, int8_t fwd_shift,
+ int bd);
+
+TXFM_2D_FLIP_CFG av1_get_fwd_txfm_cfg(TX_TYPE tx_type, TX_SIZE tx_size);
+#if CONFIG_TX64X64
+TXFM_2D_FLIP_CFG av1_get_fwd_txfm_64x64_cfg(TX_TYPE tx_type);
+TXFM_2D_FLIP_CFG av1_get_fwd_txfm_64x32_cfg(TX_TYPE tx_type);
+TXFM_2D_FLIP_CFG av1_get_fwd_txfm_32x64_cfg(TX_TYPE tx_type);
+#endif // CONFIG_TX64X64
+TXFM_2D_FLIP_CFG av1_get_inv_txfm_cfg(TX_TYPE tx_type, TX_SIZE tx_size);
#ifdef __cplusplus
}
#endif // __cplusplus