// Single-channel mask read at the output texel and multiplied into the result.
Texture2D<float> In_BaseMaskTex;
// Single-channel height field sampled by gradient().
Texture2D<float> In_HeightTex;
// Lookup table; only row 0 is read, indexed by the saturated curvature value.
Texture2D<float> RemapTex;
// Destination texture written by the ConcavityMultiply kernel.
RWTexture2D<float> OutputTex;

// Lerp factor between 1.0 (no effect) and the remapped curvature value.
float EffectStrength;
// Number of entries in RemapTex along X; (RemapTexRes - 1) is used as the largest index.
int RemapTexRes;
// Packed parameters:
//   x, y : used as the upper bounds when clamping sample coordinates
//          (presumably the width/height of In_HeightTex -- confirm with the caller)
//   z    : texel distance between the centre and its neighbouring samples
//   w    : scale applied to the curvature value before it is saturated
float4 TextureResolution;

// Returns the normalized Sobel-style gradient of In_HeightTex around texel `id`.
//
//   id     : texel coordinate of the centre sample (clamped to the texture here)
//   offset : per-axis distance, in texels, between the centre and its neighbours
//
// Sign convention (relied upon by ConcavityMultiply, do not change one without
// the other): the x component is (low-x column) - (high-x column), while the
// y component is (high-y row) - (low-y row).
//
// Returns float2(0, 0) when the neighbourhood is flat, instead of the NaN that
// a division by a zero magnitude would produce.
float2 gradient(uint2 id, uint2 offset) {

	// Work in signed integers: `id - offset` on uints wraps around to a huge
	// value near the low border, which max(0, ...) cannot catch.
	int2 maxCoord = max(int2(TextureResolution.xy) - 1, int2(0, 0));
	int2 center = clamp(int2(id), int2(0, 0), maxCoord);
	uint2 lo = (uint2)clamp(center - int2(offset), int2(0, 0), maxCoord);
	uint2 hi = (uint2)clamp(center + int2(offset), int2(0, 0), maxCoord);
	uint2 mid = (uint2)center;

	// The 3x3 neighbourhood minus its centre, named h<X><Y>.
	float hLoLo = In_HeightTex[uint2(lo.x, lo.y)];
	float hLoMid = In_HeightTex[uint2(lo.x, mid.y)];
	float hLoHi = In_HeightTex[uint2(lo.x, hi.y)];
	float hMidLo = In_HeightTex[uint2(mid.x, lo.y)];
	float hMidHi = In_HeightTex[uint2(mid.x, hi.y)];
	float hHiLo = In_HeightTex[uint2(hi.x, lo.y)];
	float hHiMid = In_HeightTex[uint2(hi.x, mid.y)];
	float hHiHi = In_HeightTex[uint2(hi.x, hi.y)];

	// The usual 1/8 Sobel weight is omitted: the result is normalized below,
	// so a uniform scale factor cancels out.
	float dzdx = (hLoLo + 2 * hLoMid + hLoHi) - (hHiLo + 2 * hHiMid + hHiHi);
	float dzdy = (hHiHi + 2 * hMidHi + hLoHi) - (hHiLo + 2 * hMidLo + hLoLo);

	float2 g = float2(dzdx, dzdy);
	float mag = length(g);
	if (mag <= 0.0f) {
		// Flat neighbourhood: there is no direction to report.
		return float2(0.0f, 0.0f);
	}
	return g / mag;
}


// ConcavityMultiply: for each texel, estimates a curvature value as the
// Sobel-weighted divergence of the normalized gradient field returned by
// gradient(), scales it by TextureResolution.w, saturates it to [0, 1]
// (negative values are discarded), remaps it through row 0 of RemapTex and
// multiplies the blended result into In_BaseMaskTex.
//
// NOTE(review): the group size is 1x1x1, so the caller presumably dispatches
// one group per texel. It is left unchanged because a different group size
// would require matching changes to the dispatch call.
#pragma kernel ConcavityMultiply
[numthreads(1, 1, 1)]
void ConcavityMultiply(uint3 id : SV_DispatchThreadID)
{
	// The same neighbour distance (TextureResolution.z) is used on both axes.
	uint2 offset = uint2(TextureResolution.zz);

	// Clamp in signed integers: `id - offset` on uints wraps around near the
	// low border, and the last valid texel index is resolution - 1.
	int2 maxCoord = max(int2(TextureResolution.xy) - 1, int2(0, 0));
	int2 center = int2(id.xy);
	uint2 lo = (uint2)clamp(center - int2(offset), int2(0, 0), maxCoord);
	uint2 hi = (uint2)clamp(center + int2(offset), int2(0, 0), maxCoord);

	// Evaluate each of the eight neighbours once; the corners feed both the
	// x and the y derivative. Names are g<X><Y>.
	float2 gLoLo = gradient(uint2(lo.x, lo.y), offset);
	float2 gLoMid = gradient(uint2(lo.x, id.y), offset);
	float2 gLoHi = gradient(uint2(lo.x, hi.y), offset);
	float2 gMidLo = gradient(uint2(id.x, lo.y), offset);
	float2 gMidHi = gradient(uint2(id.x, hi.y), offset);
	float2 gHiLo = gradient(uint2(hi.x, lo.y), offset);
	float2 gHiMid = gradient(uint2(hi.x, id.y), offset);
	float2 gHiHi = gradient(uint2(hi.x, hi.y), offset);

	// x: low-x column minus high-x column; y: high-y row minus low-y row.
	// This matches the sign convention of the components gradient() returns.
	float dgdx = ((gLoLo.x + 2 * gLoMid.x + gLoHi.x) -
		(gHiLo.x + 2 * gHiMid.x + gHiHi.x)) / 8.0f;
	float dgdy = ((gHiHi.y + 2 * gMidHi.y + gLoHi.y) -
		(gHiLo.y + 2 * gMidLo.y + gLoLo.y)) / 8.0f;

	float laplace = saturate(TextureResolution.w * (dgdx + dgdy) / 2.0f);

	// Guard against RemapTexRes < 1, which would make the index scale negative.
	uint remapIdx = (uint)(laplace * (float)max(RemapTexRes - 1, 0));
	float remappedLaplace = RemapTex[uint2(remapIdx, 0)];

	// EffectStrength == 0 leaves the base mask untouched.
	OutputTex[id.xy] = lerp(1.0f, remappedLaplace, EffectStrength) * In_BaseMaskTex[id.xy];
}