//
// This is a modified version of the SSAO renderer from Microsoft's MiniEngine
// library. The copyright notice from the original version is included below.
//
// The original source code of MiniEngine is available on GitHub.
// https://github.com/Microsoft/DirectX-Graphics-Samples
//

//
// Copyright (c) Microsoft. All rights reserved.
// This code is licensed under the MIT License (MIT).
// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF
// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY
// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR
// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT.
//
// Developed by Minigraph
//
// Author:  James Stanard
//

#pragma warning(disable : 3568)
#pragma exclude_renderers gles gles3 d3d11_9x

#pragma kernel MultiScaleVORender                           MAIN=MultiScaleVORender
#pragma kernel MultiScaleVORender_interleaved               MAIN=MultiScaleVORender_interleaved         INTERLEAVE_RESULT
#pragma kernel MultiScaleVORender_MSAA                      MAIN=MultiScaleVORender_MSAA                MSAA
#pragma kernel MultiScaleVORender_MSAA_interleaved          MAIN=MultiScaleVORender_MSAA_interleaved    MSAA INTERLEAVE_RESULT

#include "Packages/com.unity.postprocessing/PostProcessing/Shaders/StdLib.hlsl"

// Wide sampling is switched on only for the kernels that do not define INTERLEAVE_RESULT.
// For the interleaved kernels WIDE_SAMPLING is left undefined, so the conditional below
// selects the smaller tile.
#ifndef INTERLEAVE_RESULT
#define WIDE_SAMPLING 1
#endif

// TILE_DIM is the edge length of the groupshared depth cache (DepthSamples).
// It is twice the thread-group edge length because every thread of MAIN stores a 2x2 quad
// of gathered depths into the cache.
#if WIDE_SAMPLING
// 32x32 cache size:  the 16x16 in the center forms the area of focus with the 8-pixel perimeter used for wide gathering.
#define TILE_DIM 32
#define THREAD_COUNT_X 16
#define THREAD_COUNT_Y 16
#else
// 16x16 cache size:  the 8x8 in the center forms the area of focus with the 4-pixel perimeter used for gathering.
#define TILE_DIM 16
#define THREAD_COUNT_X 8
#define THREAD_COUNT_Y 8
#endif


// Resource declarations. The MSAA kernels carry two depth values per texel (float2) through
// the input texture, the cache and the output; the non-MSAA kernels carry a single float.
// NOTE(review): what the two MSAA channels represent is not visible in this file - confirm
// against the pass that produces DepthTex.
#ifdef MSAA
    // Input Textures
    // Interleaved kernels read from a texture array; MAIN picks the slice with DTid.z.
    #ifdef INTERLEAVE_RESULT
    Texture2DArray<float2> DepthTex;
    #else
    Texture2D<float2> DepthTex;
    #endif

    // Output texture
    RWTexture2D<float2> Occlusion;

    // Shared memory
    // TILE_DIM x TILE_DIM depth cache, stored row by row (row pitch = TILE_DIM elements).
    groupshared float2 DepthSamples[TILE_DIM * TILE_DIM];
#else
    // Input Textures
    // Interleaved kernels read from a texture array; MAIN picks the slice with DTid.z.
    #ifdef INTERLEAVE_RESULT
    Texture2DArray<float> DepthTex;
    #else
    Texture2D<float> DepthTex;
    #endif

    // Output texture
    RWTexture2D<float> Occlusion;

    // Shared memory
    // TILE_DIM x TILE_DIM depth cache, stored row by row (row pitch = TILE_DIM elements).
    groupshared float DepthSamples[TILE_DIM * TILE_DIM];
#endif

// Sampler passed to the Gather calls that fetch from DepthTex in MAIN.
SamplerState samplerDepthTex;

// Constants consumed by the kernel. The two tables are indexed in parallel: MAIN pairs
// gSampleWeightTable[i].c with gInvThicknessTable[i].c for the same sample offset.
CBUFFER_START(CB1)
    // Passed to TestSamples as invThickness, one entry per sample offset.
    float4 gInvThicknessTable[3];
    // Weight multiplied with each TestSamples result while accumulating ao.
    float4 gSampleWeightTable[3];
    // .xy scales an integer texel coordinate into the UV handed to Gather;
    // .zw are not referenced by the code visible in this file.
    float4 gInvSliceDimension;
    // .x and .y are exposed through the gRejectFadeoff / gIntensity aliases below.
    float2 AdditionalParams;
CBUFFER_END

// Scale applied to a disocclusion value to build the "pseudo disocclusion" in TestSamplePair.
#define gRejectFadeoff AdditionalParams.x
// Blend factor of the final lerp(1, ao, gIntensity) written to Occlusion.
#define gIntensity AdditionalParams.y

#ifdef MSAA
// Evaluates one pair of cached depth samples mirrored about the center element
// (MSAA variant: both depth channels are processed together).
//   frontDepth, invRange : map a cached depth to its disocclusion (depth * invRange - frontDepth)
//   base                 : index of the center element in DepthSamples
//   offset               : element offset of the first sample; its mirror is read at base - offset
// Returns the combined result of the pair, saturated to [0, 1] per channel.
float2 TestSamplePair(float frontDepth, float2 invRange, uint base, int offset)
{
    // Disocclusion expresses how deep a depth sample penetrates the sphere:
    //   below 0 -> the sample is in front of the sphere (full occlusion)
    //   above 1 -> the sample is behind the sphere (no occlusion)
    const float2 disoccA = DepthSamples[base + offset] * invRange - frontDepth;
    const float2 disoccB = DepthSamples[base - offset] * invRange - frontDepth;

    // Fade terms: the disocclusion scaled by the reject fade-off and limited to [0, 1].
    const float2 fadeA = saturate(gRejectFadeoff * disoccA);
    const float2 fadeB = saturate(gRejectFadeoff * disoccB);

    // Each sample is bounded from below by the fade term of its mirror sample.
    const float2 boundedA = clamp(disoccA, fadeB, 1.0);
    const float2 boundedB = clamp(disoccB, fadeA, 1.0);

    return saturate(boundedA + boundedB - fadeA * fadeB);
}

// Averages every mirrored sample pair that shares the offset magnitude (x, y) around the
// center element (MSAA variant: both depth channels are processed together).
//   centerIdx    : index of the center element in DepthSamples
//   x, y         : sample offset in cache cells; doubled when WIDE_SAMPLING is on
//   invDepth     : reciprocal of the center depth
//   invThickness : inverse thickness for this offset
// Returns the mean of the TestSamplePair results for the offset's symmetric placements.
float2 TestSamples(uint centerIdx, uint x, uint y, float2 invDepth, float invThickness)
{
#if WIDE_SAMPLING
    const uint dx = x << 1;
    const uint dy = y << 1;
#else
    const uint dx = x;
    const uint dy = y;
#endif

    const float2 invRange = invThickness * invDepth;
    const float frontDepth = invThickness - 0.5;

    // Axial: one pair along the row, one pair along the column (row pitch is TILE_DIM).
    if (dy == 0)
    {
        const float2 alongRow = TestSamplePair(frontDepth, invRange, centerIdx, dx);
        const float2 alongColumn = TestSamplePair(frontDepth, invRange, centerIdx, dx * TILE_DIM);
        return 0.5 * (alongRow + alongColumn);
    }

    // Diagonal: one pair per diagonal direction.
    if (dx == dy)
    {
        const float2 antiDiagonal = TestSamplePair(frontDepth, invRange, centerIdx, dx * TILE_DIM - dx);
        const float2 mainDiagonal = TestSamplePair(frontDepth, invRange, centerIdx, dx * TILE_DIM + dx);
        return 0.5 * (antiDiagonal + mainDiagonal);
    }

    // L-Shaped: four pairs, covering (x, y) and its transposed / mirrored placements.
    const float2 pair0 = TestSamplePair(frontDepth, invRange, centerIdx, dy * TILE_DIM + dx);
    const float2 pair1 = TestSamplePair(frontDepth, invRange, centerIdx, dy * TILE_DIM - dx);
    const float2 pair2 = TestSamplePair(frontDepth, invRange, centerIdx, dx * TILE_DIM + dy);
    const float2 pair3 = TestSamplePair(frontDepth, invRange, centerIdx, dx * TILE_DIM - dy);
    return 0.25 * (pair0 + pair1 + pair2 + pair3);
}
#else
// Evaluates one pair of cached depth samples mirrored about the center element.
//   frontDepth, invRange : map a cached depth to its disocclusion (depth * invRange - frontDepth)
//   base                 : index of the center element in DepthSamples
//   offset               : element offset of the first sample; its mirror is read at base - offset
// Returns the combined result of the pair, saturated to [0, 1].
float TestSamplePair(float frontDepth, float invRange, uint base, int offset)
{
    // Disocclusion expresses how deep a depth sample penetrates the sphere:
    //   below 0 -> the sample is in front of the sphere (full occlusion)
    //   above 1 -> the sample is behind the sphere (no occlusion)
    const float disoccA = DepthSamples[base + offset] * invRange - frontDepth;
    const float disoccB = DepthSamples[base - offset] * invRange - frontDepth;

    // Fade terms: the disocclusion scaled by the reject fade-off and limited to [0, 1].
    const float fadeA = saturate(gRejectFadeoff * disoccA);
    const float fadeB = saturate(gRejectFadeoff * disoccB);

    // Each sample is bounded from below by the fade term of its mirror sample.
    const float boundedA = clamp(disoccA, fadeB, 1.0);
    const float boundedB = clamp(disoccB, fadeA, 1.0);

    return saturate(boundedA + boundedB - fadeA * fadeB);
}

// Averages every mirrored sample pair that shares the offset magnitude (x, y) around the
// center element.
//   centerIdx    : index of the center element in DepthSamples
//   x, y         : sample offset in cache cells; doubled when WIDE_SAMPLING is on
//   invDepth     : reciprocal of the center depth
//   invThickness : inverse thickness for this offset
// Returns the mean of the TestSamplePair results for the offset's symmetric placements.
float TestSamples(uint centerIdx, uint x, uint y, float invDepth, float invThickness)
{
#if WIDE_SAMPLING
    const uint dx = x << 1;
    const uint dy = y << 1;
#else
    const uint dx = x;
    const uint dy = y;
#endif

    const float invRange = invThickness * invDepth;
    const float frontDepth = invThickness - 0.5;

    // Axial: one pair along the row, one pair along the column (row pitch is TILE_DIM).
    if (dy == 0)
    {
        const float alongRow = TestSamplePair(frontDepth, invRange, centerIdx, dx);
        const float alongColumn = TestSamplePair(frontDepth, invRange, centerIdx, dx * TILE_DIM);
        return 0.5 * (alongRow + alongColumn);
    }

    // Diagonal: one pair per diagonal direction.
    if (dx == dy)
    {
        const float antiDiagonal = TestSamplePair(frontDepth, invRange, centerIdx, dx * TILE_DIM - dx);
        const float mainDiagonal = TestSamplePair(frontDepth, invRange, centerIdx, dx * TILE_DIM + dx);
        return 0.5 * (antiDiagonal + mainDiagonal);
    }

    // L-Shaped: four pairs, covering (x, y) and its transposed / mirrored placements.
    const float pair0 = TestSamplePair(frontDepth, invRange, centerIdx, dy * TILE_DIM + dx);
    const float pair1 = TestSamplePair(frontDepth, invRange, centerIdx, dy * TILE_DIM - dx);
    const float pair2 = TestSamplePair(frontDepth, invRange, centerIdx, dx * TILE_DIM + dy);
    const float pair3 = TestSamplePair(frontDepth, invRange, centerIdx, dx * TILE_DIM - dy);
    return 0.25 * (pair0 + pair1 + pair2 + pair3);
}
#endif

// Kernel entry point. MAIN is a macro: each "#pragma kernel" line at the top of the file
// defines it to that kernel's name, so this single body is compiled once per variant.
#ifdef DISABLE_COMPUTE_SHADERS

// Stub kernel for builds without compute support. The macro is not defined in this file
// (presumably it comes from the included StdLib.hlsl - confirm).
TRIVIAL_COMPUTE_KERNEL(MAIN)

#else

// Work done by each thread:
//   1. Gather a 2x2 quad of depths and store it in the groupshared cache, so that the
//      whole group fills the TILE_DIM x TILE_DIM tile.
//   2. After a group barrier, read the cached depth for its own output pixel and
//      accumulate the weighted TestSamples results around it into ao.
//   3. Write lerp(1, ao, gIntensity) to Occlusion.
[numthreads(THREAD_COUNT_X, THREAD_COUNT_Y, 1)]
void MAIN(uint3 Gid : SV_GroupID, uint GI : SV_GroupIndex, uint3 GTid : SV_GroupThreadID, uint3 DTid : SV_DispatchThreadID)
{
    // DTid.xy + GTid.xy equals (group origin + 2 * GTid.xy), i.e. threads step two texels
    // apart. The constant (7 or 3) shifts the footprint up/left so that the cache includes
    // the perimeter around the group's own pixels. The uint subtraction may wrap; the int2
    // cast turns it back into a negative coordinate.
    // NOTE(review): assumes gInvSliceDimension.xy is 1 / texture size, so the integer
    // coordinate lands on the corner shared by the four gathered texels - confirm.
#if WIDE_SAMPLING
    float2 QuadCenterUV = int2(DTid.xy + GTid.xy - 7) * gInvSliceDimension.xy;
#else
    float2 QuadCenterUV = int2(DTid.xy + GTid.xy - 3) * gInvSliceDimension.xy;
#endif

#ifdef MSAA
    // Fetch four depths and store them in LDS
    // Red and green channels are gathered separately and re-paired into float2 entries.
#ifdef INTERLEAVE_RESULT
    float4 depths0 = DepthTex.GatherRed(samplerDepthTex, float3(QuadCenterUV, DTid.z));
    float4 depths1 = DepthTex.GatherGreen(samplerDepthTex, float3(QuadCenterUV, DTid.z));
#else
    float4 depths0 = DepthTex.GatherRed(samplerDepthTex, QuadCenterUV);
    float4 depths1 = DepthTex.GatherGreen(samplerDepthTex, QuadCenterUV);
#endif
    // Top-left cache slot of this thread's 2x2 quad; .w/.z fill the first row and
    // .x/.y the row below (row pitch = TILE_DIM).
    int destIdx = GTid.x * 2 + GTid.y * 2 * TILE_DIM;
    DepthSamples[destIdx] = float2(depths0.w, depths1.w);
    DepthSamples[destIdx + 1] = float2(depths0.z, depths1.z);
    DepthSamples[destIdx + TILE_DIM] = float2(depths0.x, depths1.x);
    DepthSamples[destIdx + TILE_DIM + 1] = float2(depths0.y, depths1.y);
#else
    // Fetch four depths and store them in LDS
#ifdef INTERLEAVE_RESULT
    float4 depths = DepthTex.Gather(samplerDepthTex, float3(QuadCenterUV, DTid.z));
#else
    float4 depths = DepthTex.Gather(samplerDepthTex, QuadCenterUV);
#endif
    // Top-left cache slot of this thread's 2x2 quad; .w/.z fill the first row and
    // .x/.y the row below (row pitch = TILE_DIM).
    int destIdx = GTid.x * 2 + GTid.y * 2 * TILE_DIM;
    DepthSamples[destIdx] = depths.w;
    DepthSamples[destIdx + 1] = depths.z;
    DepthSamples[destIdx + TILE_DIM] = depths.x;
    DepthSamples[destIdx + TILE_DIM + 1] = depths.y;
#endif

    // Every thread reads cache entries written by other threads below, so the whole group
    // must finish its stores first.
    GroupMemoryBarrierWithGroupSync();

    // Cache index of this thread's own pixel: its position inside the group, offset by the
    // perimeter width (8 or 4 cells) in both x and y.
#if WIDE_SAMPLING
    uint thisIdx = GTid.x + GTid.y * TILE_DIM + 8 * TILE_DIM + 8;
#else
    uint thisIdx = GTid.x + GTid.y * TILE_DIM + 4 * TILE_DIM + 4;
#endif

    // Reciprocal of the center depth; TestSamples multiplies it into every neighbor depth.
#ifdef MSAA
    const float2 invThisDepth = float2(1.0 / DepthSamples[thisIdx].x, 1.0 / DepthSamples[thisIdx].y);
    float2 ao = 0.0;
#else
    const float invThisDepth = 1.0 / DepthSamples[thisIdx];
    float ao = 0.0;
#endif


//#define SAMPLE_EXHAUSTIVELY

    // Each line adds one offset class: the weight and the inverse thickness are taken from
    // the same table slot, and TestSamples(x, y) covers all symmetric placements of (x, y).
#ifdef SAMPLE_EXHAUSTIVELY
    // 68 samples:  sample all cells *within* a circular radius of 5
    ao += gSampleWeightTable[0].x * TestSamples(thisIdx, 1, 0, invThisDepth, gInvThicknessTable[0].x);
    ao += gSampleWeightTable[0].y * TestSamples(thisIdx, 2, 0, invThisDepth, gInvThicknessTable[0].y);
    ao += gSampleWeightTable[0].z * TestSamples(thisIdx, 3, 0, invThisDepth, gInvThicknessTable[0].z);
    ao += gSampleWeightTable[0].w * TestSamples(thisIdx, 4, 0, invThisDepth, gInvThicknessTable[0].w);
    ao += gSampleWeightTable[1].x * TestSamples(thisIdx, 1, 1, invThisDepth, gInvThicknessTable[1].x);
    ao += gSampleWeightTable[2].x * TestSamples(thisIdx, 2, 2, invThisDepth, gInvThicknessTable[2].x);
    ao += gSampleWeightTable[2].w * TestSamples(thisIdx, 3, 3, invThisDepth, gInvThicknessTable[2].w);
    ao += gSampleWeightTable[1].y * TestSamples(thisIdx, 1, 2, invThisDepth, gInvThicknessTable[1].y);
    ao += gSampleWeightTable[1].z * TestSamples(thisIdx, 1, 3, invThisDepth, gInvThicknessTable[1].z);
    ao += gSampleWeightTable[1].w * TestSamples(thisIdx, 1, 4, invThisDepth, gInvThicknessTable[1].w);
    ao += gSampleWeightTable[2].y * TestSamples(thisIdx, 2, 3, invThisDepth, gInvThicknessTable[2].y);
    ao += gSampleWeightTable[2].z * TestSamples(thisIdx, 2, 4, invThisDepth, gInvThicknessTable[2].z);
#else // SAMPLE_CHECKER
    // 36 samples:  sample every-other cell in a checker board pattern
    ao += gSampleWeightTable[0].y * TestSamples(thisIdx, 2, 0, invThisDepth, gInvThicknessTable[0].y);
    ao += gSampleWeightTable[0].w * TestSamples(thisIdx, 4, 0, invThisDepth, gInvThicknessTable[0].w);
    ao += gSampleWeightTable[1].x * TestSamples(thisIdx, 1, 1, invThisDepth, gInvThicknessTable[1].x);
    ao += gSampleWeightTable[2].x * TestSamples(thisIdx, 2, 2, invThisDepth, gInvThicknessTable[2].x);
    ao += gSampleWeightTable[2].w * TestSamples(thisIdx, 3, 3, invThisDepth, gInvThicknessTable[2].w);
    ao += gSampleWeightTable[1].z * TestSamples(thisIdx, 1, 3, invThisDepth, gInvThicknessTable[1].z);
    ao += gSampleWeightTable[2].z * TestSamples(thisIdx, 2, 4, invThisDepth, gInvThicknessTable[2].z);
#endif

    // Interleaved output: each input pixel owns a 4x4 block of output pixels, and the slice
    // index DTid.z selects the position inside that block (low two bits -> x, remaining bits -> y).
#ifdef INTERLEAVE_RESULT
    uint2 OutPixel = DTid.xy << 2 | uint2(DTid.z & 3, DTid.z >> 2);
#else
    uint2 OutPixel = DTid.xy;
#endif
    // gIntensity blends between 1 (at 0) and the accumulated ao (at 1).
    Occlusion[OutPixel] = lerp(1, ao, gIntensity);
}

#endif // DISABLE_COMPUTE_SHADERS