#pragma warning(disable : 3568)
#pragma exclude_renderers gles gles3 d3d11_9x

#pragma kernel KAutoExposureAvgLuminance_fixed MAIN=KAutoExposureAvgLuminance_fixed
#pragma kernel KAutoExposureAvgLuminance_progressive MAIN=KAutoExposureAvgLuminance_progressive PROGRESSIVE

#include "Packages/com.unity.postprocessing/PostProcessing/Shaders/StdLib.hlsl"
#include "Packages/com.unity.postprocessing/PostProcessing/Shaders/Builtins/ExposureHistogram.hlsl"

StructuredBuffer<uint> _HistogramBuffer;
Texture2D<float> _Source;
RWTexture2D<float> _Destination;

CBUFFER_START(Params)
    float4 _Params1; // x: lowPercent, y: highPercent, z: minBrightness, w: maxBrightness
    float4 _Params2; // x: speed down, y: speed up, z: exposure compensation, w: delta time
    float4 _ScaleOffsetRes; // x: scale, y: offset, z: histogram pass width, w: histogram pass height
CBUFFER_END

groupshared uint gs_pyramid[HISTOGRAM_REDUCTION_BINS];

// Converts an average scene luminance into an exposure multiplier.
//
// avgLuminance: average luminance; clamped from below to EPSILON (declared
//               outside this file) so the division can never be by zero.
// Returns _Params2.z (exposure compensation, used here as the key value)
// divided by the clamped luminance.
//
// A luminance-dependent key value was considered in the past and left disabled:
//     1.03 - (2.0 / (2.0 + log2(avgLuminance + 1.0)))
float GetExposureMultiplier(float avgLuminance)
{
    return _Params2.z / max(EPSILON, avgLuminance);
}

// Moves the previous exposure towards the newly computed one using an
// exponential, frame-time-dependent blend.
//
// newExposure: target exposure for the current frame.
// oldExposure: exposure that was in effect previously.
// Returns the blended exposure.
float InterpolateExposure(float newExposure, float oldExposure)
{
    float diff = newExposure - oldExposure;

    // A rising exposure uses _Params2.x (speed down); otherwise _Params2.y (speed up)
    float rate;
    if (diff > 0.0)
        rate = _Params2.x;
    else
        rate = _Params2.y;

    // _Params2.w is the delta time; the blend factor is 0 when delta time * rate is 0
    float blend = 1.0 - exp2(-_Params2.w * rate);
    return oldExposure + diff * blend;
}

#ifdef DISABLE_COMPUTE_SHADERS

TRIVIAL_COMPUTE_KERNEL(MAIN)

#else

// Auto-exposure kernel.
//
// The thread group first reduces the histogram to its largest bin count in
// gs_pyramid[0]. Thread 0 then uses the reciprocal of that count as the scale
// passed to GetAverageLuminance, converts the resulting average luminance into an
// exposure multiplier and stores it in texel (0, 0) of _Destination. When
// PROGRESSIVE is defined, the new exposure is first blended with the value read
// from texel (0, 0) of _Source.
[numthreads(HISTOGRAM_REDUCTION_THREAD_X, HISTOGRAM_REDUCTION_THREAD_Y, 1)]
void MAIN(uint2 groupThreadId : SV_GroupThreadID)
{
    // Flatten the 2D group-thread index; each thread owns one slot of gs_pyramid
    const uint thread_id = groupThreadId.y * HISTOGRAM_REDUCTION_THREAD_X + groupThreadId.x;

#if HISTOGRAM_REDUCTION_ALT_PATH
    // Each thread pre-folds two histogram entries, HISTOGRAM_REDUCTION_BINS apart
    gs_pyramid[thread_id] = max(_HistogramBuffer[thread_id], _HistogramBuffer[thread_id + HISTOGRAM_REDUCTION_BINS]);
#else
    gs_pyramid[thread_id] = _HistogramBuffer[thread_id];
#endif

    GroupMemoryBarrierWithGroupSync();

    // Parallel reduction to find the max value. Every iteration ends with a group
    // sync, so gs_pyramid[0] is final and visible once the loop is done.
    UNITY_UNROLL
    for (uint i = HISTOGRAM_REDUCTION_BINS >> 1u; i > 0u; i >>= 1u)
    {
        if (thread_id < i)
            gs_pyramid[thread_id] = max(gs_pyramid[thread_id], gs_pyramid[thread_id + i]);

        GroupMemoryBarrierWithGroupSync();
    }

    if (thread_id == 0u)
    {
        // Reciprocal of the largest bin count. The count is clamped to 1 so that an
        // all-zero histogram cannot cause a division by zero and push a non-finite
        // value into the exposure texture.
        float maxValue = 1.0 / float(max(gs_pyramid[0], 1u));

        float avgLuminance = GetAverageLuminance(_HistogramBuffer, _Params1, maxValue, _ScaleOffsetRes.xy);
        float exposure = GetExposureMultiplier(avgLuminance);

#if PROGRESSIVE
        // Ease from the previously stored exposure towards the new target
        float prevExposure = _Source[uint2(0u, 0u)].x;
        exposure = InterpolateExposure(exposure, prevExposure);
#endif

        _Destination[uint2(0u, 0u)].x = exposure.x;
    }
}

#endif // DISABLE_COMPUTE_SHADERS