Ejemplo n.º 1
0
/**
 * Downsampling residual unit.
 *
 * The residual branch runs `convDown` then `convNoRelu` on `x`; the shortcut
 * branch is a 2x2, stride-2, 'valid' average pool of `x`. The two branches are
 * summed element-wise and passed through ReLU.
 *
 * When the branches disagree in shape, the smaller one is zero-padded at the
 * end of each mismatched axis before the sum:
 * - axes 1 and 2 (spatial, assuming NHWC layout — TODO confirm): the residual
 *   branch is padded up to the shortcut's size;
 * - axis 3 (channels): the shortcut is padded up to the residual branch's depth.
 * Padding is by the actual per-axis difference. If the residual branch is
 * already the larger one along some axis, nothing is padded there and
 * `tf.add`'s broadcasting rules decide the outcome.
 *
 * Intermediate tensors are released via `tf.tidy`; only the result survives.
 *
 * @param x Input activations.
 * @param params Weights for the two convolutions (`conv1`, `conv2`).
 * @returns ReLU of the element-wise sum of both branches.
 */
export function residualDown(x: tf.Tensor4D, params: ResidualLayerParams): tf.Tensor4D {
  return tf.tidy(() => {
    let out = convDown(x, params.conv1)
    out = convNoRelu(out, params.conv2)

    let pooled = tf.avgPool(x, 2, 2, 'valid') as tf.Tensor4D

    // Pad by the real per-axis gap, not a fixed amount. Then a mismatch on
    // only one spatial axis (non-square input) or a channel increase other
    // than 2x still yields shapes that can be added.
    const padH = Math.max(0, pooled.shape[1] - out.shape[1])
    const padW = Math.max(0, pooled.shape[2] - out.shape[2])
    const padC = Math.max(0, out.shape[3] - pooled.shape[3])

    if (padH > 0 || padW > 0) {
      // Zero rows/columns are appended at the end (bottom/right) of the residual branch.
      out = tf.pad(out, [[0, 0], [0, padH], [0, padW], [0, 0]])
    }
    if (padC > 0) {
      // Extra zero channels are appended after the shortcut's existing channels.
      pooled = tf.pad(pooled, [[0, 0], [0, 0], [0, 0], [0, padC]])
    }

    out = tf.add(pooled, out) as tf.Tensor4D
    return tf.relu(out)
  })
}
Ejemplo n.º 2
0
    /**
     * Executes one convolution or pooling graph node and returns its output
     * wrapped in a single-element array.
     *
     * Supported `node.op` values: 'conv1d', 'conv2d', 'conv2dTranspose',
     * 'depthwiseConv2d', 'avgPool', 'maxPool'.
     *
     * @param node Graph node whose op is executed and whose parameters are read.
     * @param tensorMap Named tensors, passed through to `getParamValue`.
     * @param context Execution context, passed through to `getParamValue`.
     * @returns A one-element array holding the op's result tensor.
     * @throws TypeError if `node.op` is not one of the supported ops.
     */
    (node: Node, tensorMap: NamedTensorsMap,
     context: ExecutionContext): tfc.Tensor[] => {
      // Every parameter lookup below resolves against the same node, tensor
      // map and execution context.
      const param = (name: string) =>
          getParamValue(name, node, tensorMap, context);

      // Extracts the (height, width) pair from a per-dimension attribute.
      // TensorFlow encodes strides, ksize and dilations with one entry per
      // input dimension (e.g. [1, h, w, 1] for NHWC), so the spatial values
      // sit at indices 1 and 2. Taking indices 0 and 1 would pick up the batch
      // entry instead. A two-element list is treated as already being [h, w].
      const spatialPair = (values: number[]): [number, number] =>
          values.length === 4 ? [values[1], values[2]] :
                                [values[0], values[1]];

      switch (node.op) {
        case 'conv1d': {
          const stride = param('stride') as number;
          const pad = param('pad') as 'valid' | 'same';
          const dataFormat =
              (param('dataFormat') as string).toUpperCase() as 'NWC' | 'NCW';
          const dilation = param('dilation') as number;
          return [tfc.conv1d(
              param('x') as tfc.Tensor3D, param('filter') as tfc.Tensor3D,
              stride, pad, dataFormat, dilation)];
        }
        case 'conv2d': {
          const stride = param('strides') as number[];
          const pad = param('pad') as 'valid' | 'same';
          const dataFormat =
              (param('dataFormat') as string).toUpperCase() as 'NHWC' | 'NCHW';
          const dilations = param('dilations') as number[];
          return [tfc.conv2d(
              param('x') as tfc.Tensor3D | tfc.Tensor4D,
              param('filter') as tfc.Tensor4D, spatialPair(stride), pad,
              dataFormat, spatialPair(dilations))];
        }
        case 'conv2dTranspose': {
          const shape = param('outputShape') as [number, number, number] |
              [number, number, number, number];
          const stride = param('strides') as number[];
          const pad = param('pad') as 'valid' | 'same';
          return [tfc.conv2dTranspose(
              param('x') as tfc.Tensor3D | tfc.Tensor4D,
              param('filter') as tfc.Tensor4D, shape, spatialPair(stride),
              pad)];
        }
        case 'depthwiseConv2d': {
          const stride = param('strides') as number[];
          const pad = param('pad') as 'valid' | 'same';
          const dilations = param('dilations') as number[];
          const dataFormat =
              (param('dataFormat') as string).toUpperCase() as 'NHWC' | 'NCHW';
          // Note: this op looks its input tensor up under 'input', not 'x'.
          return [tfc.depthwiseConv2d(
              param('input') as tfc.Tensor3D | tfc.Tensor4D,
              param('filter') as tfc.Tensor4D, spatialPair(stride), pad,
              dataFormat, spatialPair(dilations))];
        }

        case 'avgPool': {
          const stride = param('strides') as number[];
          const pad = param('pad') as 'valid' | 'same';
          const kernelSize = param('kernelSize') as number[];
          return [tfc.avgPool(
              param('x') as tfc.Tensor3D | tfc.Tensor4D,
              spatialPair(kernelSize), spatialPair(stride), pad)];
        }

        case 'maxPool': {
          const stride = param('strides') as number[];
          const pad = param('pad') as 'valid' | 'same';
          const kernelSize = param('kernelSize') as number[];
          return [tfc.maxPool(
              param('x') as tfc.Tensor3D | tfc.Tensor4D,
              spatialPair(kernelSize), spatialPair(stride), pad)];
        }
        default:
          throw new TypeError(`Node type ${node.op} is not implemented`);
      }
    };