#pragma kernel GPUFrustumCulling


// Per-instance input record, read from SourceShaderDataBuffer by the culling kernel.
struct IndirectInstanceData
{
    // Instance transform. The kernel reads the instance position from _m03_m13_m23.
    float4x4 PositionMatrix;
    // Only .x is read in this file: it scales _CullFarStart to get the per-item cull distance.
    float4 ControlData;
};

// Per-instance output record appended to the Visible*/Shadow* buffers.
struct IndirectShaderData
{
    // Instance transform after _FloatingOriginOffset has been added to its translation.
    float4x4 PositionMatrix;
    // Inverse of PositionMatrix, computed in the kernel with inverse().
    float4x4 InversePositionMatrix;
    // Written by the kernel as (lodFade, lodFadeQuantified, distanceFade, w);
    // w is the LOD index on the visible paths and always 0 on the shadow path.
    float4 ControlData;
};

// Number of valid entries in SourceShaderDataBuffer; threads with id.x >= this do nothing.
uint _InstanceCount;
// When true, instances are distributed over the LOD0..LOD3 buffers by camera distance.
bool UseLODs;
// When true, the frustum plane test is skipped and only distance culling is applied.
bool NoFrustumCulling;
// When true, instances outside the frustum are still tested for a visible shadow.
bool ShadowCulling;

// Used for frustum culling.
// Each plane is evaluated as dot(plane.xyz, position) + plane.w.
float4 _VS_CameraFrustumPlane0;
float4 _VS_CameraFrustumPlane1;
float4 _VS_CameraFrustumPlane2;
float4 _VS_CameraFrustumPlane3;
float4 _VS_CameraFrustumPlane4;
float4 _VS_CameraFrustumPlane5;
// Only .xyz is read (camera position used for distance computations).
float4 _WorldSpaceCameraPos;

// .xyz is added to every instance's translation before any test.
float4 _FloatingOriginOffset;

// Inputs of the shadow visibility test (see IsShadowVisible).
float3 _LightDirection;
float3 _PlaneOrigin;
// Assigned directly to Bounds.extents in the kernel.
// NOTE(review): despite the name this is used as half-size extents — confirm with the host code.
float3 _BoundsSize;


// Used for vegetation distance culling
float _CullFarStart;
// Not referenced anywhere in this file.
float _CullFarDistance;
// Radius used for the sphere-vs-plane frustum test; half of it is also added to the test position's y.
float _BoundingSphereRadius;

// Used for LODs
float _LOD1Distance;
float _LOD2Distance;
float _LOD3Distance;

// Each _LODnDistance is multiplied by _LODFactor * _LODBias and clamped to [0, _CullFarStart].
float _LODFactor;
float _LODBias;
// Width of the distance band over which LOD and distance fades are computed.
float _LODFadeDistance;
// Selects which LOD distance is pushed out to the item cull distance (values 1..3 are handled).
int _LODCount;

// Input instances.
StructuredBuffer<IndirectInstanceData> SourceShaderDataBuffer;
// Output: instances that passed culling, one buffer per LOD.
AppendStructuredBuffer<IndirectShaderData> VisibleBufferLOD0;
AppendStructuredBuffer<IndirectShaderData> VisibleBufferLOD1;
AppendStructuredBuffer<IndirectShaderData> VisibleBufferLOD2;
AppendStructuredBuffer<IndirectShaderData> VisibleBufferLOD3;

// Output: instances outside the frustum whose shadow bounds still intersect it, one buffer per LOD.
AppendStructuredBuffer<IndirectShaderData> ShadowBufferLOD0;
AppendStructuredBuffer<IndirectShaderData> ShadowBufferLOD1;
AppendStructuredBuffer<IndirectShaderData> ShadowBufferLOD2;
AppendStructuredBuffer<IndirectShaderData> ShadowBufferLOD3;

// Not referenced anywhere in this file.
SamplerState _LinearClamp;

// float CalculateLODFade(float cameraDistance, float nextLODDistance)
// {
//     float distance = nextLODDistance - cameraDistance;
//     if (distance <= _LODFadeDistance)
//     {
//         return clamp(1 - distance / _LODFadeDistance, 0, 1);
//     }
//     return 0;
// }

// Fade value for a LOD that has no fade-in, only a fade-out towards the next LOD.
//   cameraDistance  - distance from the camera to the instance
//   nextLODDistance - distance at which the next LOD (or the cull distance) starts
// Returns 1 while cameraDistance is below nextLODDistance. From nextLODDistance
// up to nextLODDistance + _LODFadeDistance it returns the remaining fraction of
// the band, clamped to [0, 1] and doubled, i.e. a value falling from 2 to 0.
// NOTE(review): the result therefore jumps from 1 to 2 at nextLODDistance —
// confirm the consuming shader expects the 0..2 range.
float CalculateLODFadeFirst(float cameraDistance, float nextLODDistance)
{
    // Distance left until the end of the fade band.
    float remaining = nextLODDistance + _LODFadeDistance - cameraDistance;
    if (remaining <= _LODFadeDistance)
    {
        // An unreachable second return (a quadratic variant) used to follow this line.
        return clamp(remaining / _LODFadeDistance, 0, 1) * 2;
    }

    return 1;
}

// Fade value for a LOD that fades in after thisLODDistance and fades out
// towards nextLODDistance.
//   thisLODDistance - distance at which this LOD starts
//   cameraDistance  - distance from the camera to the instance
//   nextLODDistance - distance at which the next LOD (or the cull distance) starts
// Fade-in band  (less than _LODFadeDistance past thisLODDistance): rises 0 -> 2.
// Fade-out band (at or beyond nextLODDistance): falls 2 -> 0.
// Anywhere else the result is 1. The fade-in band is checked first, so it wins
// when the two bands overlap.
float CalculateLODFadeMiddle(float thisLODDistance, float cameraDistance, float nextLODDistance)
{
    float sinceStart = cameraDistance - thisLODDistance;
    if (sinceStart < _LODFadeDistance)
    {
        return clamp(sinceStart / _LODFadeDistance, 0, 1) * 2;
    }

    float untilEnd = nextLODDistance + _LODFadeDistance - cameraDistance;
    if (untilEnd <= _LODFadeDistance)
    {
        return clamp(untilEnd / _LODFadeDistance, 0, 1) * 2;
    }

    return 1;
}

// Distance fade value near the cull distance.
//   cameraDistance - distance from the camera to the instance
//   cullDistance   - distance at which the instance is culled
// Within _LODFadeDistance of cullDistance the result rises from 0 (start of the
// band) to 1 (at or beyond cullDistance). Outside the band the result is 1.
// NOTE(review): this makes the value jump from 1 to 0 when the instance enters
// the fade band — confirm whether the out-of-band value was meant to be 0.
float CalculateDistanceFade(float cameraDistance, float cullDistance)
{
    float remaining = cullDistance - cameraDistance;
    bool insideFadeBand = remaining <= _LODFadeDistance;
    return insideFadeBand ? clamp(1 - remaining / _LODFadeDistance, 0, 1) : 1;
}

// General 4x4 matrix inverse: transpose of the cofactor matrix (the adjugate)
// divided by the determinant. No check is made for a zero determinant.
float4x4 inverse(float4x4 input)
{
    // Determinant of the 3x3 sub-matrix built from three row swizzles of `input`.
    #define COFACTOR_MINOR(r0, r1, r2) determinant(float3x3(input.r0, input.r1, input.r2))

    // Cofactors for row 1 (rows 2,3,4 remain).
    float4 cofactorRow0 = float4(
        COFACTOR_MINOR(_22_23_24, _32_33_34, _42_43_44),
        -COFACTOR_MINOR(_21_23_24, _31_33_34, _41_43_44),
        COFACTOR_MINOR(_21_22_24, _31_32_34, _41_42_44),
        -COFACTOR_MINOR(_21_22_23, _31_32_33, _41_42_43));

    // Cofactors for row 2 (rows 1,3,4 remain).
    float4 cofactorRow1 = float4(
        -COFACTOR_MINOR(_12_13_14, _32_33_34, _42_43_44),
        COFACTOR_MINOR(_11_13_14, _31_33_34, _41_43_44),
        -COFACTOR_MINOR(_11_12_14, _31_32_34, _41_42_44),
        COFACTOR_MINOR(_11_12_13, _31_32_33, _41_42_43));

    // Cofactors for row 3 (rows 1,2,4 remain).
    float4 cofactorRow2 = float4(
        COFACTOR_MINOR(_12_13_14, _22_23_24, _42_43_44),
        -COFACTOR_MINOR(_11_13_14, _21_23_24, _41_43_44),
        COFACTOR_MINOR(_11_12_14, _21_22_24, _41_42_44),
        -COFACTOR_MINOR(_11_12_13, _21_22_23, _41_42_43));

    // Cofactors for row 4 (rows 1,2,3 remain).
    float4 cofactorRow3 = float4(
        -COFACTOR_MINOR(_12_13_14, _22_23_24, _32_33_34),
        COFACTOR_MINOR(_11_13_14, _21_23_24, _31_33_34),
        -COFACTOR_MINOR(_11_12_14, _21_22_24, _31_32_34),
        COFACTOR_MINOR(_11_12_13, _21_22_23, _31_32_33));

    #undef COFACTOR_MINOR

    float4x4 cofactorMatrix = float4x4(cofactorRow0, cofactorRow1, cofactorRow2, cofactorRow3);
    return transpose(cofactorMatrix) / determinant(input);
}


// Ray described by a start point and a direction.
// Nothing in this file normalizes the direction.
struct Ray
{
    float3 origin;
    float3 direction;
};

// Builds a Ray from its two components; the values are stored unchanged.
Ray CreateRay(float3 origin, float3 direction)
{
    Ray result = { origin, direction };
    return result;
}

// Axis-aligned box stored as a center point plus half-size extents.
struct Bounds
{
    float3 center;
    float3 extents;

    // Lower corner of the box.
    float3 GetMin()
    {
        return center - extents;
    }

    // Upper corner of the box.
    float3 GetMax()
    {
        return center + extents;
    }

    // Rebuilds center and extents from two opposite corners.
    void SetMinMax(float3 min, float3 max)
    {
        float3 halfSize = (max - min) * 0.5f;
        extents = halfSize;
        center = min + halfSize;
    }

    // Grows the box just enough to contain targetPoint.
    void Encapsulate(float3 targetPoint)
    {
        float3 grownMin = min(GetMin(), targetPoint);
        float3 grownMax = max(GetMax(), targetPoint);
        SetMinMax(grownMin, grownMax);
    }
};

// Intersects a ray with the horizontal plane through planeOrigin.
// The plane normal is fixed to (0, -1, 0), so only rays with a downward
// component (dot(direction, normal) > 0.00001) can hit.
//   hitPoint - intersection point on a hit, (0, 0, 0) otherwise
// Returns true on a hit. The ray parameter is not checked for being positive,
// so a plane "behind" the ray origin is still reported as a hit.
bool IntersectPlane(Ray ray, float3 planeOrigin, out float3 hitPoint)
{
    const float3 downNormal = -float3(0, 1, 0);
    const float facing = dot(ray.direction, downNormal);

    // Written as a negated '>' so the miss case is taken for exactly the same
    // inputs as a failed '>' test.
    if (!(facing > 0.00001f))
    {
        hitPoint = float3(0, 0, 0);
        return false;
    }

    const float rayLength = dot(planeOrigin - ray.origin, downNormal) / facing;
    hitPoint = ray.origin + ray.direction * rayLength;
    return true;
}

// Extends an object's bounds so they also contain the shadow it casts onto the
// horizontal plane through planeOrigin.
//   objectBounds   - bounds of the object (received by value; the grown copy is returned)
//   lightDirection - direction the light travels; used as the direction of every ray
//   planeOrigin    - a point on the plane that receives the shadow
//   hitPlane       - set to true if at least one projected corner reached the plane
// The four corners of the top face (y = max.y) are cast along lightDirection and
// each hit point is merged into the bounds.
Bounds GetShadowBounds(Bounds objectBounds, float3 lightDirection, float3 planeOrigin, out bool hitPlane)
{
    float3 objectBoundsMin = objectBounds.GetMin();
    float3 objectBoundsMax = objectBounds.GetMax();

    // Top-face corners: (min,min), (min,max), (max,min), (max,max) in x/z.
    // p1 previously used a literal 0 for z instead of objectBoundsMax.z, which
    // projected a point that is not a corner of the box.
    Ray p0 = CreateRay(float3(objectBoundsMin.x, objectBoundsMax.y, objectBoundsMin.z), lightDirection);
    Ray p1 = CreateRay(float3(objectBoundsMin.x, objectBoundsMax.y, objectBoundsMax.z), lightDirection);
    Ray p2 = CreateRay(float3(objectBoundsMax.x, objectBoundsMax.y, objectBoundsMin.z), lightDirection);
    Ray p3 = CreateRay(objectBoundsMax, lightDirection);

    float3 hitPoint;
    hitPlane = false;

    if (IntersectPlane(p0, planeOrigin, hitPoint))
    {
        objectBounds.Encapsulate(hitPoint);
        hitPlane = true;
    }

    if (IntersectPlane(p1, planeOrigin, hitPoint))
    {
        objectBounds.Encapsulate(hitPoint);
        hitPlane = true;
    }

    if (IntersectPlane(p2, planeOrigin, hitPoint))
    {
        objectBounds.Encapsulate(hitPoint);
        hitPlane = true;
    }

    if (IntersectPlane(p3, planeOrigin, hitPoint))
    {
        objectBounds.Encapsulate(hitPoint);
        hitPlane = true;
    }

    return objectBounds;
}

// Box-vs-plane test. plane.xyz is the plane normal and plane.w its distance term,
// so a point p has signed distance dot(plane.xyz, p) + plane.w.
// Returns false only when the whole box lies on the negative side of the plane,
// true otherwise.
bool TestPlaneIntersection(Bounds bounds, float4 plane)
{
    float3 n = plane.xyz;
    float3 absN = abs(n);
    float3 c = bounds.center;
    float3 e = bounds.extents;

    // Half-size of the box projected onto the plane normal.
    float projectedRadius = e.x * absN.x + e.y * absN.y + e.z * absN.z;
    // Center of the box projected onto the plane normal.
    float centerDistance = n.x * c.x + n.y * c.y + n.z * c.z;

    return !(centerDistance + projectedRadius < -plane.w);
}

// Returns true when the box is not fully on the negative side of any of the six
// camera frustum planes (_VS_CameraFrustumPlane0..5).
bool BoundsIntersectsFrustum(Bounds bounds)
{
    bool insidePlanes012 = TestPlaneIntersection(bounds, _VS_CameraFrustumPlane0)
        && TestPlaneIntersection(bounds, _VS_CameraFrustumPlane1)
        && TestPlaneIntersection(bounds, _VS_CameraFrustumPlane2);

    bool insidePlanes345 = TestPlaneIntersection(bounds, _VS_CameraFrustumPlane3)
        && TestPlaneIntersection(bounds, _VS_CameraFrustumPlane4)
        && TestPlaneIntersection(bounds, _VS_CameraFrustumPlane5);

    return insidePlanes012 && insidePlanes345;
}

// Returns true when the object's shadow reaches the ground plane and the bounds
// grown to include that shadow intersect the camera frustum.
bool IsShadowVisible(Bounds objectBounds, float3 lightDirection, float3 planeOrigin)
{
    bool reachedPlane;
    Bounds grownBounds = GetShadowBounds(objectBounds, lightDirection, planeOrigin, reachedPlane);

    if (!reachedPlane)
    {
        return false;
    }
    return BoundsIntersectsFrustum(grownBounds);
}

// Culling kernel: one thread per instance (32 threads per group along x).
// For each instance it applies distance culling, optionally frustum culling and
// shadow culling, and appends the surviving instance to the matching LOD buffer.
//
// Three code paths, each with the same LOD distribution logic:
//   1. NoFrustumCulling      -> distance test only, output to VisibleBufferLOD*.
//   2. outside the frustum   -> only if ShadowCulling is set and the shadow bounds
//                               are visible, output to ShadowBufferLOD*.
//   3. inside the frustum    -> output to VisibleBufferLOD*.
//
// Within a LOD distribution, an instance whose distance falls in the band
// [lodN, lodN + _LODFadeDistance] satisfies two consecutive conditions and is
// appended to both neighbouring LOD buffers, each time with its own ControlData.
//
// NOTE(review): the three paths are near-duplicates but differ in small ways
// (cull threshold, lodFadeQuantified overrides, ControlData.w). The differences
// are called out inline; confirm which ones are intentional.
[numthreads(32, 1, 1)]
void GPUFrustumCulling(uint3 id : SV_DispatchThreadID)
{
    uint instanceId = id.x;
    // Threads beyond the instance count (last, partially filled group) do nothing.
    if (instanceId < _InstanceCount)
    {
        IndirectShaderData instanceData;
        instanceData.PositionMatrix = SourceShaderDataBuffer[id.x].PositionMatrix;

        // Shift the instance translation by the floating origin offset.
        instanceData.PositionMatrix._m03_m13_m23 += _FloatingOriginOffset.xyz;

        // Per-instance multiplier for the cull distance.
        float DistanceFalloff = SourceShaderDataBuffer[id.x].ControlData.x;

        float itemCullDistance = _CullFarStart * DistanceFalloff;
        // LOD switch distances, scaled and limited to [0, _CullFarStart].
        float lod1Distance = clamp(_LOD1Distance * _LODFactor * _LODBias, 0, _CullFarStart);
        float lod2Distance = clamp(_LOD2Distance * _LODFactor * _LODBias, 0, _CullFarStart);
        float lod3Distance = clamp(_LOD3Distance * _LODFactor * _LODBias, 0, _CullFarStart);

        // Center used for the sphere-vs-frustum test: the instance origin raised
        // by half the bounding sphere radius.
        // NOTE(review): this macro is never #undef'd, so it stays defined for any
        // code that follows in the file.
        #define transformPosition mul(instanceData.PositionMatrix, float4(0,0,0,1)).xyz
        float3 position = transformPosition + float3(0.0f, _BoundingSphereRadius * 0.5f, 0.0f);

        // ---- Path 1: frustum culling disabled, distance culling only ----
        if (NoFrustumCulling)
        {
            // Always true; the branches guarded by it are unconditionally taken.
            bool useLODFade = true;
            float3 itempos = instanceData.PositionMatrix._m03_m13_m23;
            float dist = distance(itempos, _WorldSpaceCameraPos.xyz);
            // NOTE(review): path 3 uses itemCullDistance + _LODFadeDistance here.
            if (dist < itemCullDistance)
            {
                instanceData.InversePositionMatrix = inverse(instanceData.PositionMatrix);
                float distanceFade = CalculateDistanceFade(dist, itemCullDistance);

                // Default ControlData; .z (distance fade) is preserved by every later write.
                instanceData.ControlData = float4(0, 0, distanceFade, 0);

                if (UseLODs)
                {
                    // Push the switch distance of the last available LOD out to at
                    // least the cull distance.
                    // NOTE(review): the distances of higher, non-existent LODs are not
                    // adjusted, so the later conditions can still match — presumably
                    // the host sets them so this cannot happen; confirm.
                    if (_LODCount == 1)
                    {
                        lod1Distance = max(lod1Distance, itemCullDistance);
                    }
                    else if (_LODCount == 2)
                    {
                        lod2Distance = max(lod2Distance, itemCullDistance);
                    }
                    else if (_LODCount == 3)
                    {
                        lod3Distance = max(lod3Distance, itemCullDistance);
                    }


                    // LOD0: up to lod1Distance plus the fade band.
                    if (dist <= lod1Distance + _LODFadeDistance)
                    {
                        if (useLODFade)
                        {
                            float lodFade = CalculateLODFadeFirst(dist, lod1Distance);
                            // Fade rounded to 1/16 steps, limited to [1/16, 1], then inverted.
                            float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                            instanceData.ControlData =
                                float4(lodFade, lodFadeQuantified, instanceData.ControlData.z, 0);
                        }

                        VisibleBufferLOD0.Append(instanceData);
                    }
                    // LOD1: from lod1Distance up to lod2Distance plus the fade band.
                    if (dist <= lod2Distance + _LODFadeDistance && dist > lod1Distance)
                    {
                        if (useLODFade)
                        {
                            float lodFade = CalculateLODFadeMiddle(lod1Distance, dist, lod2Distance);
                            float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                            // NOTE(review): the computed value is discarded and replaced by a
                            // constant — looks like a leftover override; confirm.
                            lodFadeQuantified = 0;
                            instanceData.ControlData =
                                float4(lodFade, lodFadeQuantified, instanceData.ControlData.z, 1);
                        }
                        VisibleBufferLOD1.Append(instanceData);
                    }
                    // LOD2: from lod2Distance up to lod3Distance plus the fade band.
                    if (dist <= lod3Distance + _LODFadeDistance && dist > lod2Distance)
                    {
                        if (useLODFade)
                        {
                            float lodFade = CalculateLODFadeMiddle(lod2Distance, dist, lod3Distance);
                            float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                            // NOTE(review): computed value replaced by the constant 0.5; path 3
                            // has no such override for LOD2 — confirm which is intended.
                            lodFadeQuantified = 0.5;
                            instanceData.ControlData =
                                float4(lodFade, lodFadeQuantified, instanceData.ControlData.z, 2);
                        }
                        VisibleBufferLOD2.Append(instanceData);
                    }
                    // LOD3: everything beyond lod3Distance (already inside the cull distance).
                    if (dist > lod3Distance)
                    {
                        if (useLODFade)
                        {
                            float lodFade = CalculateLODFadeMiddle(lod3Distance, dist, itemCullDistance);
                            float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                            instanceData.ControlData =
                                float4(lodFade, lodFadeQuantified, instanceData.ControlData.z, 3);
                        }
                        VisibleBufferLOD3.Append(instanceData);
                    }
                }
                else
                {
                    // No LODs: everything goes to LOD0, fading towards the cull distance.
                    if (useLODFade)
                    {
                        float lodFade = CalculateLODFadeFirst(dist, itemCullDistance);
                        float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                        instanceData.ControlData = float4(lodFade, lodFadeQuantified, instanceData.ControlData.z, 0);
                    }
                    VisibleBufferLOD0.Append(instanceData);
                }
            }
            return;
        }

        // Signed distances of the test position to the six frustum planes.
        float4 CameraDistances0 = float4(
            dot(_VS_CameraFrustumPlane0.xyz, position) + _VS_CameraFrustumPlane0.w,
            dot(_VS_CameraFrustumPlane1.xyz, position) + _VS_CameraFrustumPlane1.w,
            dot(_VS_CameraFrustumPlane2.xyz, position) + _VS_CameraFrustumPlane2.w,
            dot(_VS_CameraFrustumPlane3.xyz, position) + _VS_CameraFrustumPlane3.w
        );

        // Last two planes; the unused components are padded with 0.
        float4 CameraDistances1 = float4(
            dot(_VS_CameraFrustumPlane4.xyz, position) + _VS_CameraFrustumPlane4.w,
            dot(_VS_CameraFrustumPlane5.xyz, position) + _VS_CameraFrustumPlane5.w,
            0.0f,
            0.0f
        );

        // The bounding sphere is outside the frustum when any plane distance is
        // below -_BoundingSphereRadius.
        // ---- Path 2: outside the frustum, keep only instances with a visible shadow ----
        if (!(all(CameraDistances0 >= -_BoundingSphereRadius)
            && all(CameraDistances1 >= -_BoundingSphereRadius)))
        {
            if (!ShadowCulling)
            {
                return;
            }
            // Always true; the branches guarded by it are unconditionally taken.
            bool useLODFade = true;
            float3 itempos = instanceData.PositionMatrix._m03_m13_m23;
            float dist = distance(itempos, _WorldSpaceCameraPos.xyz);

            if (dist < itemCullDistance)
            {
                // Box centered on the instance origin with _BoundsSize as extents.
                Bounds itemBounds;
                itemBounds.center = itempos;
                itemBounds.extents = _BoundsSize;

                if (IsShadowVisible(itemBounds, _LightDirection, _PlaneOrigin))
                {
                    instanceData.InversePositionMatrix = inverse(instanceData.PositionMatrix);
                    float distanceFade = CalculateDistanceFade(dist, itemCullDistance);
                    instanceData.ControlData = float4(0, 0, distanceFade, 0);

                    if (UseLODs)
                    {
                        // Same last-LOD extension as in path 1.
                        if (_LODCount == 1)
                        {
                            lod1Distance = max(lod1Distance, itemCullDistance);
                        }
                        else if (_LODCount == 2)
                        {
                            lod2Distance = max(lod2Distance, itemCullDistance);
                        }
                        else if (_LODCount == 3)
                        {
                            lod3Distance = max(lod3Distance, itemCullDistance);
                        }

                        // NOTE(review): in this path ControlData.w is written as 0 for
                        // every LOD, unlike the visible paths which store the LOD index.
                        if (dist <= lod1Distance + _LODFadeDistance)
                        {
                            if (useLODFade)
                            {
                                float lodFade = CalculateLODFadeFirst(dist, lod1Distance);
                                float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                                instanceData.ControlData = float4(lodFade, lodFadeQuantified,
                                                                  instanceData.ControlData.z, 0);
                            }

                            ShadowBufferLOD0.Append(instanceData);
                        }

                        if (dist <= lod2Distance + _LODFadeDistance && dist > lod1Distance)
                        {
                            if (useLODFade)
                            {
                                float lodFade = CalculateLODFadeMiddle(lod1Distance, dist, lod2Distance);
                                float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                                instanceData.ControlData = float4(lodFade, lodFadeQuantified,
                                                                  instanceData.ControlData.z, 0);
                            }
                            ShadowBufferLOD1.Append(instanceData);
                        }

                        if (dist <= lod3Distance + _LODFadeDistance && dist > lod2Distance)
                        {
                            if (useLODFade)
                            {
                                float lodFade = CalculateLODFadeMiddle(lod2Distance, dist, lod3Distance);
                                float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                                instanceData.ControlData = float4(lodFade, lodFadeQuantified,
                                                                  instanceData.ControlData.z, 0);
                            }
                            ShadowBufferLOD2.Append(instanceData);
                        }

                        if (dist > lod3Distance)
                        {
                            if (useLODFade)
                            {
                                float lodFade = CalculateLODFadeMiddle(lod3Distance, dist, itemCullDistance);
                                float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                                instanceData.ControlData = float4(lodFade, lodFadeQuantified,
                                                                  instanceData.ControlData.z, 0);
                            }
                            ShadowBufferLOD3.Append(instanceData);
                        }
                    }
                    else
                    {
                        // No LODs: the shadow caster goes to ShadowBufferLOD0.
                        if (useLODFade)
                        {
                            float lodFade = CalculateLODFadeFirst(dist, itemCullDistance);
                            float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                            instanceData.ControlData =
                                float4(lodFade, lodFadeQuantified, instanceData.ControlData.z, 0);
                        }
                        ShadowBufferLOD0.Append(instanceData);
                    }
                }
            }
        }
        // ---- Path 3: inside the frustum ----
        else
        {
            // Always true; the branches guarded by it are unconditionally taken.
            bool useLODFade = true;
            float3 itempos = instanceData.PositionMatrix._m03_m13_m23;
            float dist = distance(itempos, _WorldSpaceCameraPos.xyz);

            // NOTE(review): unlike paths 1 and 2, the threshold here includes
            // _LODFadeDistance, so instances up to one fade band beyond the cull
            // distance are still emitted.
            if (dist < itemCullDistance + _LODFadeDistance)
            {
                instanceData.InversePositionMatrix = inverse(instanceData.PositionMatrix);
                float distanceFade = CalculateDistanceFade(dist, itemCullDistance);
                instanceData.ControlData = float4(0, 0, distanceFade, 0);

                if (UseLODs)
                {
                    // Same last-LOD extension as in path 1.
                    if (_LODCount == 1)
                    {
                        lod1Distance = max(lod1Distance, itemCullDistance);
                    }
                    else if (_LODCount == 2)
                    {
                        lod2Distance = max(lod2Distance, itemCullDistance);
                    }
                    else if (_LODCount == 3)
                    {
                        lod3Distance = max(lod3Distance, itemCullDistance);
                    }

                    if (dist <= lod1Distance + _LODFadeDistance)
                    {
                        if (useLODFade)
                        {
                            float lodFade = CalculateLODFadeFirst(dist, lod1Distance);
                            float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                            instanceData.ControlData =
                                float4(lodFade, lodFadeQuantified, instanceData.ControlData.z, 0);
                        }

                        VisibleBufferLOD0.Append(instanceData);
                    }
                    if (dist <= lod2Distance + _LODFadeDistance && dist > lod1Distance)
                    {
                        if (useLODFade)
                        {
                            float lodFade = CalculateLODFadeMiddle(lod1Distance, dist, lod2Distance);
                            float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                            // NOTE(review): the computed value is discarded and replaced by a
                            // constant — looks like a leftover override; confirm.
                            lodFadeQuantified = 0;
                            instanceData.ControlData =
                                float4(lodFade, lodFadeQuantified, instanceData.ControlData.z, 1);
                        }
                        VisibleBufferLOD1.Append(instanceData);
                    }
                    if (dist <= lod3Distance + _LODFadeDistance && dist > lod2Distance)
                    {
                        if (useLODFade)
                        {
                            float lodFade = CalculateLODFadeMiddle(lod2Distance, dist, lod3Distance);
                            float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                            instanceData.ControlData =
                                float4(lodFade, lodFadeQuantified, instanceData.ControlData.z, 2);
                        }
                        VisibleBufferLOD2.Append(instanceData);
                    }
                    if (dist > lod3Distance)
                    {
                        if (useLODFade)
                        {
                            float lodFade = CalculateLODFadeMiddle(lod3Distance, dist, itemCullDistance);
                            float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                            instanceData.ControlData =
                                float4(lodFade, lodFadeQuantified, instanceData.ControlData.z, 3);
                        }
                        VisibleBufferLOD3.Append(instanceData);
                    }
                }
                else
                {
                    // No LODs: everything goes to LOD0, fading towards the cull distance.
                    if (useLODFade)
                    {
                        float lodFade = CalculateLODFadeFirst(dist, itemCullDistance);
                        float lodFadeQuantified = 1 - clamp(round(lodFade * 16) / 16, 0.0625, 1);
                        instanceData.ControlData = float4(lodFade, lodFadeQuantified, instanceData.ControlData.z, 0);
                    }
                    VisibleBufferLOD0.Append(instanceData);
                }
            }
        }
    }
}