#include "Packages/com.unity.postprocessing/PostProcessing/Shaders/StdLib.hlsl"
#include "Packages/com.unity.postprocessing/PostProcessing/Shaders/Colors.hlsl"
#include "Packages/com.unity.postprocessing/PostProcessing/Shaders/ACES.hlsl"

#pragma kernel KGenLut3D_NoTonemap      TONEMAPPING_NONE
#pragma kernel KGenLut3D_AcesTonemap    TONEMAPPING_ACES
#pragma kernel KGenLut3D_NeutralTonemap TONEMAPPING_NEUTRAL
#pragma kernel KGenLut3D_CustomTonemap  TONEMAPPING_CUSTOM

// Destination 3D lookup table. Eval() writes one texel per in-bounds thread id: the graded
// color clamped to >= 0 in rgb and 1.0 in alpha.
RWTexture3D<float4> _Output;

// Grading parameters read by the kernels in this file.
CBUFFER_START(Params)
    // Eval() uses x as the exclusive upper bound for thread ids and y as the scale that turns
    // a thread id into a lut-space coordinate.
    float4 _Size; // x: lut_size, y: 1 / (lut_size - 1), zw: unused

    float4 _ColorBalance; // rgb: passed to WhiteBalance()
    float4 _ColorFilter;  // rgb: multiplied into the linear color
    float4 _HueSatCon;    // x: offset added to hue, y: saturation factor, z: contrast argument

    // rgb of each vector is passed to ChannelMixer() as the red / green / blue argument
    float4 _ChannelMixerRed;
    float4 _ChannelMixerGreen;
    float4 _ChannelMixerBlue;

    // rgb of each vector is passed to LiftGammaGainHDR()
    float4 _Lift;
    float4 _InvGamma;
    float4 _Gain;

    // xyz is passed to CustomTonemap() (TONEMAPPING_CUSTOM variant only)
    float4 _CustomToneCurve;

    // Packing is currently borked, can't pass float arrays without it creating one vector4 per
    // float so we'll pack manually...
    // For each segment the full "A" vector and the xy of the "B" vector go to CustomTonemap().
    float4 _ToeSegmentA;
    float4 _ToeSegmentB;
    float4 _MidSegmentA;
    float4 _MidSegmentB;
    float4 _ShoSegmentA;
    float4 _ShoSegmentB;
CBUFFER_END

// Grading curves texture, sampled by LinearGrade() at v = 0.25, mip 0. Channel usage there:
// x: hue vs hue, y: hue vs sat, z: sat vs sat, w: lum vs sat.
Texture2D _Curves;
SamplerState sampler_Curves;

// Grading step that runs on log-encoded color. Only contrast lives here because it feels a
// lot more natural in log than in linear; it pivots around ACEScc_MIDGRAY using _HueSatCon.z.
float3 LogGrade(float3 colorLog)
{
    return Contrast(colorLog, ACEScc_MIDGRAY, _HueSatCon.z);
}

// Grading steps that run on linear color, in order: white balance, color filter, channel
// mixer, lift/gamma/gain, then the curve-driven hue shift and saturation scaling.
float3 LinearGrade(float3 colorLinear)
{
    float3 rgb = WhiteBalance(colorLinear, _ColorBalance.rgb);
    rgb *= _ColorFilter.rgb;
    rgb = ChannelMixer(rgb, _ChannelMixerRed.rgb, _ChannelMixerGreen.rgb, _ChannelMixerBlue.rgb);
    rgb = LiftGammaGainHDR(rgb, _Lift.rgb, _InvGamma.rgb, _Gain.rgb);

    // RgbToHsv must never be fed negative values or they'll wrap around
    rgb = max(0.0, rgb);

    float3 hsv = RgbToHsv(rgb);

    // Every curve is read at the same v coordinate of _Curves, one curve per channel
    const float curveV = 0.25;

    // Saturation multipliers; each curve sample is clamped to [0, 1] and then doubled
    float hueVsSat = saturate(_Curves.SampleLevel(sampler_Curves, float2(hsv.x, curveV), 0).y) * 2.0;
    float satVsSat = saturate(_Curves.SampleLevel(sampler_Curves, float2(hsv.y, curveV), 0).z) * 2.0;
    float lumVsSat = saturate(_Curves.SampleLevel(sampler_Curves, float2(Luminance(rgb), curveV), 0).w) * 2.0;
    float satMult = hueVsSat * satVsSat * lumVsSat;

    // Hue vs hue: the curve is looked up with the already shifted hue and stores its offset
    // biased by 0.5
    float shiftedHue = hsv.x + _HueSatCon.x;
    float hueOffset = saturate(_Curves.SampleLevel(sampler_Curves, float2(shiftedHue, curveV), 0).x) - 0.5;
    hsv.x = RotateHue(shiftedHue + hueOffset, 0.0, 1.0);

    // Back to rgb, then apply the global saturation scaled by the curve multipliers
    return Saturation(HsvToRgb(hsv), _HueSatCon.y * satMult);
}

#if TONEMAPPING_ACES

// TONEMAPPING_ACES variant: decodes the lut-space color, runs both grading stages inside the
// ACES color spaces and returns the result of the ACES tonemapper.
float3 ColorGrade(float3 colorLutSpace)
{
    // Decode the lut coordinate and move it to ACES
    float3 acesColor = unity_to_ACES(LUT_SPACE_DECODE(colorLutSpace));

    // Log grading happens in ACEScc (log) space
    float3 gradedLog = LogGrade(ACES_to_ACEScc(acesColor));
    acesColor = ACEScc_to_ACES(gradedLog);

    // Linear grading happens in ACEScg (linear) space
    float3 gradedLinear = LinearGrade(ACES_to_ACEScg(acesColor));

    // Tonemap ODT(RRT(aces))
    return AcesTonemap(ACEScg_to_ACES(gradedLinear));
}

#else

// Non-ACES variant: grades straight in lut space, then optionally applies the neutral or the
// custom tonemapper depending on the kernel's keyword. With no tonemapping keyword the
// graded linear color is returned as is.
float3 ColorGrade(float3 colorLutSpace)
{
    // The lut coordinate is already log encoded, so the log stage can run on it directly
    float3 gradedLog = LogGrade(colorLutSpace);

    // Decode to linear, run the linear stage and drop anything negative
    float3 result = max(0.0, LinearGrade(LUT_SPACE_DECODE(gradedLog)));

#if TONEMAPPING_NEUTRAL
    result = NeutralTonemap(result);
#elif TONEMAPPING_CUSTOM
    result = CustomTonemap(
        result, _CustomToneCurve.xyz,
        _ToeSegmentA, _ToeSegmentB.xy,
        _MidSegmentA, _MidSegmentB.xy,
        _ShoSegmentA, _ShoSegmentB.xy
    );
#endif

    return result;
}

#endif

// Shared kernel body: computes the graded color for the lut texel addressed by `id` and
// stores it in _Output. Ids that are not below _Size.x on every axis write nothing.
void Eval(uint3 id)
{
    float3 texel = float3(id);

    // Bail out unless all three coordinates are inside the lut
    if (!all(texel < _Size.xxx))
        return;

    // Lut space (log space) coordinate; _Size.y is 1 / (lut_size - 1)
    float3 lutCoord = texel * _Size.y;

    // Color grade & tonemap, never store negative values
    float3 color = max(ColorGrade(lutCoord), 0.0);
    _Output[id] = float4(color, 1.0);
}

// Thread-group size used on each axis by the kernels below
#define GROUP_SIZE 4

#ifndef DISABLE_COMPUTE_SHADERS

// One entry point per "#pragma kernel" line at the top of the file. They all forward to
// Eval(); the keyword attached to each pragma selects which ColorGrade() gets compiled in.

[numthreads(GROUP_SIZE, GROUP_SIZE, GROUP_SIZE)]
void KGenLut3D_NoTonemap(uint3 id : SV_DispatchThreadID)
{
    Eval(id);
}

[numthreads(GROUP_SIZE, GROUP_SIZE, GROUP_SIZE)]
void KGenLut3D_AcesTonemap(uint3 id : SV_DispatchThreadID)
{
    Eval(id);
}

[numthreads(GROUP_SIZE, GROUP_SIZE, GROUP_SIZE)]
void KGenLut3D_NeutralTonemap(uint3 id : SV_DispatchThreadID)
{
    Eval(id);
}

[numthreads(GROUP_SIZE, GROUP_SIZE, GROUP_SIZE)]
void KGenLut3D_CustomTonemap(uint3 id : SV_DispatchThreadID)
{
    Eval(id);
}

#else

// Compute shaders are disabled: TRIVIAL_COMPUTE_KERNEL is defined outside this file and
// presumably emits a stub for each kernel name so the pragmas above still resolve.
TRIVIAL_COMPUTE_KERNEL(KGenLut3D_NoTonemap)
TRIVIAL_COMPUTE_KERNEL(KGenLut3D_AcesTonemap)
TRIVIAL_COMPUTE_KERNEL(KGenLut3D_NeutralTonemap)
TRIVIAL_COMPUTE_KERNEL(KGenLut3D_CustomTonemap)

#endif // DISABLE_COMPUTE_SHADERS