//inputs (read-only textures)
// NOTE(review): PrecipMask, Collision and WaterVelPrev are not referenced by any code visible in this
// file, and Hardness is only referenced from a commented-out expression in simulateErosion.
// Rainfall is currently applied uniformly in SimulateWaterFlow rather than through PrecipMask.
Texture2D<float4> PrecipMask;         //defines where rainfall will occur
Texture2D<float>  Collision;          //defines texels we will "lock" for the simulation
Texture2D<float>  TerrainHeightPrev;  //previous frame
Texture2D<float2> WaterVelPrev;       //water vel prev
Texture2D<float>  Hardness;           //mineral hardness

//Texture2D<float>   HeightToScale;     //curve mapping height -> scale

//outputs / read-write state
// The *Prev resources below are declared read-write but are only read by the code visible here.
RWTexture2D<float>  TerrainHeight;      //terrain heightmap (written by HydraulicErosion)

RWTexture2D<float>  Water;       //water height (written by SimulateWaterFlow)
RWTexture2D<float>  WaterPrev;   //water height prev

RWTexture2D<float2> WaterVel;    //water velocity (written by SimulateWaterFlow, read by HydraulicErosion)

//Flux channels use the same lane order as getSafeNeighbors: x = right, y = left, z = bottom, w = top
RWTexture2D<float4> Flux;        //current water flux (how much water is transferring to neighboring cells)
RWTexture2D<float4> FluxPrev;    //flux from the previous step

RWTexture2D<float>  Sediment;    //suspended sediment concentration (being transported by the water)
RWTexture2D<float>  SedimentPrev; //suspended sediment from the previous step

//The only write to Eroded in the visible code is commented out (see HydraulicErosion).
RWTexture2D<float>  Eroded;      //how much sediment was eroded this frame

//Scales the per-step sediment and height change applied in simulateErosion.
float EffectScalar;
//Simulation time step; multiplies every rate in the kernels below.
float DT;

//Cell size and its reciprocal. DX * DY is used as the cell area.
float4 dxdy; //<dx, dy, 1 / dx, 1 / dy>
#define DX	dxdy.x
#define DY	dxdy.y
#define INV_DX dxdy.z
#define INV_DY dxdy.w

//Water values are divided by WATER_LEVEL_SCALE when read in simulateWaterVelocity / SimulateWaterFlow
//and multiplied by it again when stored.
float4 WaterTransportScalars; //(WaterLevelScale, precipRate, flowRate, evapRate)
#define WATER_LEVEL_SCALE	WaterTransportScalars.x
#define PRECIP_RATE			WaterTransportScalars.y
#define FLOW_RATE			WaterTransportScalars.z
#define EVAP_RATE			WaterTransportScalars.w

//SEDIMENT_SCALE is only referenced from commented-out code in this file.
float4 SedimentScalars; //(SedimentScale, Capacity, DissolveRate, DepositRate)
#define SEDIMENT_SCALE				SedimentScalars.x
#define SEDIMENT_CAP				SedimentScalars.y
#define SEDIMENT_DISSOLVE_RATE		SedimentScalars.z
#define SEDIMENT_DEPOSIT_RATE		SedimentScalars.w

//simulateErosion blends between the "bank" and "bed" values using the y component of the surface normal.
float4 RiverBedScalars; //(bedDissolveRate, bedDepositRate, bankDissolveRate, bankDepositRate)
#define RIVER_BED_DISSOLVE_RATE		RiverBedScalars.x
#define RIVER_BED_DEPOSIT_RATE		RiverBedScalars.y
#define RIVER_BANK_DISSOLVE_RATE	RiverBedScalars.z
#define RIVER_BANK_DEPOSIT_RATE		RiverBedScalars.w

//Only texDim[0] (width) and texDim[1] (height) are read by the code in this file.
float4 texDim;     //the dimensions of all of our textures (they must all be the same dimensions)
//Only terrainDim.y is read here (as the height multiplier in simulateFlux).
float3 terrainDim; //the dimensions of the terrain in world units

//float4 HeightToScaleRange; //(min, max, ?, ?)

// Returns the indices of the four axis-aligned neighbors of coord, packed as
// (x + 1, x - 1, y + 1, y - 1) = (right, left, bottom, top).
// Whenever stepping would leave the texture (as described by texDim), the lane keeps
// coord's own index instead, so edge texels treat themselves as their missing neighbor.
uint4 getSafeNeighbors(uint2 coord) {
	uint lastColumn = (uint)(texDim[0] - 1);
	uint lastRow = (uint)(texDim[1] - 1);

	//start from "every neighbor is myself" and step outward only where it is safe
	uint4 neighbors = uint4(coord.x, coord.x, coord.y, coord.y);

	if (coord.x < lastColumn) {
		neighbors.x = coord.x + 1;         //right index
	}
	if (coord.x > 0) {
		neighbors.y = (uint)(coord.x - 1); //left index
	}
	if (coord.y < lastRow) {
		neighbors.z = coord.y + 1;         //bottom index
	}
	if (coord.y > 0) {
		neighbors.w = (uint)(coord.y - 1); //top index
	}

	return neighbors;
}

//Lane accessors for 4-component values that hold one entry per neighbor direction.
//They match the packing returned by getSafeNeighbors (x = x + 1, y = x - 1, z = y + 1, w = y - 1)
//and are also applied to the channels of Flux / FluxPrev.
//Note that BOTTOM is the +y neighbor and TOP is the -y neighbor.
#define RIGHT(c)	(c.x)
#define LEFT(c)		(c.y)
#define BOTTOM(c)	(c.z)
#define TOP(c)		(c.w)

// Estimates a surface direction at coord from TerrainHeightPrev using a 3x3 weighted
// difference (1-2-1 weights across each side, divided by 8) over the 8 surrounding texels.
// Neighbor indices come from getSafeNeighbors, so edge texels reuse their own row/column.
//
// A cheaper option would be to decode a precomputed SurfaceNormal texture, but that would
// not be recomputed every iteration; sampling the 8 neighbors is slower, more accurate, and
// probably better at removing rectilinear artifacts.
//
// NOTE(review): the returned y component is the length of (dzdx, dzdy). On perfectly flat
// terrain the vector passed to normalize is therefore all zeros, and on any non-flat texel
// the normalized y component works out to the same value (1 / sqrt(2)). This looks
// unintended for a slope measure - confirm the intended normal construction before relying on it.
float3 computeNormal(uint2 coord) {
	uint4 nidx = getSafeNeighbors(coord);

	uint xRight = RIGHT(nidx);
	uint xLeft = LEFT(nidx);
	uint yBottom = BOTTOM(nidx);
	uint yTop = TOP(nidx);

	//the eight surrounding height samples (the center texel itself is not needed)
	float hLeftTop      = TerrainHeightPrev[uint2(xLeft,   yTop)];
	float hMidTop       = TerrainHeightPrev[uint2(coord.x, yTop)];
	float hRightTop     = TerrainHeightPrev[uint2(xRight,  yTop)];
	float hLeftMid      = TerrainHeightPrev[uint2(xLeft,   coord.y)];
	float hRightMid     = TerrainHeightPrev[uint2(xRight,  coord.y)];
	float hLeftBottom   = TerrainHeightPrev[uint2(xLeft,   yBottom)];
	float hMidBottom    = TerrainHeightPrev[uint2(coord.x, yBottom)];
	float hRightBottom  = TerrainHeightPrev[uint2(xRight,  yBottom)];

	//right column minus left column
	float dzdx = ((hRightBottom + 2 * hRightMid + hRightTop) -
		          (hLeftBottom + 2 * hLeftMid + hLeftTop)) / 8.0f;

	//top row minus bottom row
	float dzdy = ((hLeftTop + 2 * hMidTop + hRightTop) -
		          (hLeftBottom + 2 * hMidBottom + hRightBottom)) / 8.0f;

	float gradientLength = length(float2(dzdx, dzdy));
	return normalize(float3(dzdx, gradientLength, dzdy));
}

// Computes the outgoing water flux of the cell at coord toward each of its four neighbors.
//
// For every direction the previous flux (FluxPrev) is advanced by the difference between this
// cell's surface (TerrainHeightPrev + WaterPrev) and the neighbor's surface, scaled by
// terrainDim.y and by DT * FLOW_RATE / cellArea. Negative results are clamped to zero, so a
// cell never reports flux "uphill".
//
// coord    - texel being simulated
// cellArea - area of one cell (the caller passes DX * DY)
// returns  - flux per direction in (right, left, bottom, top) lane order
float4 simulateFlux(uint2 coord, float cellArea) {
	uint4 nidx = getSafeNeighbors(coord); //indices of neighbor cells, packed to a uint4

	uint2 rightCoord  = uint2(RIGHT(nidx), coord.y);
	uint2 leftCoord   = uint2(LEFT(nidx), coord.y);
	uint2 bottomCoord = uint2(coord.x, BOTTOM(nidx));
	uint2 topCoord    = uint2(coord.x, TOP(nidx));

	//surface height (terrain + water) here and at each neighbor, in the same lane order as Flux
	float surfaceHere = TerrainHeightPrev[coord] + WaterPrev[coord];
	float4 surfaceThere = float4(
		TerrainHeightPrev[rightCoord]  + WaterPrev[rightCoord],
		TerrainHeightPrev[leftCoord]   + WaterPrev[leftCoord],
		TerrainHeightPrev[bottomCoord] + WaterPrev[bottomCoord],
		TerrainHeightPrev[topCoord]    + WaterPrev[topCoord]);

	//height drop toward each neighbor, multiplied by the terrain's height extent
	float4 drop = terrainDim.y * (surfaceHere - surfaceThere);

	//how strongly a height drop accelerates the flux during this step
	float fluxGain = (DT * FLOW_RATE) / cellArea;

	return max(0.0f, FluxPrev[coord] + drop * fluxGain);
}

// Updates the water level and water velocity of the cell at coord from the current Flux texture.
//
// The level changes by DT * (inflow - outflow) / cellArea, where inflow is the flux each
// neighbor sends toward this cell and outflow is the sum of this cell's own four flux lanes.
// The velocity is the net flux passing through the cell along each axis, divided by the
// cell size and by the average water depth over the step.
//
// Changes from the previous revision:
//  * The x velocity now uses the same lanes as inFlow and as the y velocity: flux arriving
//    from the left neighbor's RIGHT lane and leaving through our RIGHT lane counts as +x,
//    our LEFT lane and the right neighbor's LEFT lane count as -x. Previously it read the
//    left neighbor's LEFT lane and the right neighbor's RIGHT lane (flux that never crosses
//    this cell) and had the sign of this cell's own lanes flipped.
//  * Dividing by a zero (or negative) average depth is avoided; such cells get zero velocity
//    instead of Inf/NaN.
//
// NOTE(review): this reads neighbor texels of Flux that other threads write during the same
// dispatch (see SimulateWaterFlow); nothing visible in this file orders those writes before
// these reads - confirm whether flux and velocity should run as separate dispatches.
//
// coord    - texel being simulated
// cellArea - area of one cell (the caller passes DX * DY)
// Writes Water[coord] and WaterVel[coord].
void simulateWaterVelocity(uint2 coord, float cellArea) {
	uint4 nidx = getSafeNeighbors(coord);

	//use flux (instead of FluxPrev) here because we've already computed this for this frame
	float4 fluxHere   = Flux[coord];
	float4 fluxLeft   = Flux[uint2(LEFT(nidx), coord.y)];
	float4 fluxRight  = Flux[uint2(RIGHT(nidx), coord.y)];
	float4 fluxTop    = Flux[uint2(coord.x, TOP(nidx))];
	float4 fluxBottom = Flux[uint2(coord.x, BOTTOM(nidx))];

	//each neighbor contributes the lane that points back at this cell
	float inFlow = RIGHT(fluxLeft) + LEFT(fluxRight) + TOP(fluxBottom) + BOTTOM(fluxTop);
	float outFlow = LEFT(fluxHere) + RIGHT(fluxHere) + TOP(fluxHere) + BOTTOM(fluxHere);

	float dV = DT * (inFlow - outFlow);

	float waterOld = WaterPrev[coord] / WATER_LEVEL_SCALE;
	float waterNew = waterOld + dV / cellArea;

	//net flux through the cell along +x (toward RIGHT) and +y (toward BOTTOM)
	float throughX = RIGHT(fluxLeft) - LEFT(fluxHere) - LEFT(fluxRight) + RIGHT(fluxHere);
	float throughY = BOTTOM(fluxTop) - TOP(fluxHere) - TOP(fluxBottom) + BOTTOM(fluxHere);

	//update velocities; a dry cell has no meaningful velocity, so leave it at zero
	float waterAvg = 0.5f * (waterOld + waterNew);
	float2 newWaterVel = float2(0.0f, 0.0f);
	if (waterAvg > 0.0f) {
		float invDepth = 1.0f / waterAvg;
		newWaterVel.x = invDepth * 0.5f * throughX / DX;
		newWaterVel.y = invDepth * 0.5f * throughY / DY;
	}

	Water[coord] = WATER_LEVEL_SCALE * max(0.0f, waterNew);
	WaterVel[coord] = newWaterVel;
}

#pragma kernel SimulateWaterFlow

// Kernel: one water-flow step for the texel at id.xy.
//  1. Computes this cell's outgoing flux and stores it in Flux.
//  2. Lets simulateWaterVelocity derive the new water level and velocity from Flux
//     (it writes Water[coord] and WaterVel[coord]).
//  3. Applies uniform rainfall (PRECIP_RATE * DT) and evaporation (EVAP_RATE * DT) on top of
//     that flow-updated level, clamping the result at zero.
//
// Step 3 previously recomputed the level from WaterPrev, which overwrote - and therefore
// discarded - the level simulateWaterVelocity had just stored, so flux never moved any water.
//
// NOTE(review): there is no bounds check on id; presumably the dispatch size matches texDim - confirm.
[numthreads(8, 8, 1)]
void SimulateWaterFlow(uint3 id : SV_DispatchThreadID)
{
	uint2 coord = id.xy;

	//rainfall
	float rainfall = PRECIP_RATE * DT;
	float evaporated = EVAP_RATE * DT;

	float cellArea = (DX * DY);

	Flux[coord] = simulateFlux(coord, cellArea);
	simulateWaterVelocity(coord, cellArea);

	//start from the level this thread just wrote, so the flow result is kept
	float flowedWater = Water[coord] / WATER_LEVEL_SCALE;
	Water[coord] = WATER_LEVEL_SCALE * max(flowedWater + rainfall - evaporated, 0.0f);
}


//
// Hydraulic Erosion
//
// Computes how much sediment is picked up and dropped at coord during this step.
//
// coord       - texel being simulated
// newSediment - out: SedimentPrev[coord] plus EffectScalar * (dissolved - deposited), clamped at zero
// newHeight   - out: TerrainHeightPrev[coord] plus EffectScalar * (deposited - dissolved), clamped at zero
//
// The dissolve and deposit rates are each blended between their "bank" and "bed" values
// (RiverBedScalars) by the clamped y component of computeNormal(coord).
void simulateErosion(uint2 coord, out float newSediment, out float newHeight) {
	float terrainHeight = TerrainHeightPrev[coord];
	float suspended = SedimentPrev[coord];// / SEDIMENT_SCALE;

	//sediment capacity of one cell
	float capacity = (dxdy.x * dxdy.y) * SEDIMENT_CAP;

	//0 -> use the bank rates, 1 -> use the bed rates
	float3 surfaceNormal = computeNormal(coord);
	float bedWeight = saturate(dot(surfaceNormal, float3(0.0f, 1.0f, 0.0f)));

	float dissolveBlend = lerp(RIVER_BANK_DISSOLVE_RATE, RIVER_BED_DISSOLVE_RATE, bedWeight);
	float depositBlend = lerp(RIVER_BANK_DEPOSIT_RATE, RIVER_BED_DEPOSIT_RATE, bedWeight);

	float2 velocity = WaterVel[coord];
	float flowSpeed = sqrt(velocity.x * velocity.x + velocity.y * velocity.y);

	// Dissolving: the cosine term fades the rate from its maximum (nothing suspended yet)
	// down to zero (suspended amount equals the capacity).
	// hardness is hard-coded to zero for now, which makes the dissolve rate zero as well.
	float hardness = 0.0f;// Hardness[coord]; //TODO
	float capacityFalloff = cos(0.5f * 3.14159f * (suspended / capacity));
	float dissolveRate = saturate(hardness) * dissolveBlend * flowSpeed * SEDIMENT_DISSOLVE_RATE * capacityFalloff;
	float dissolved = clamp(DT * dissolveRate, 0.0f, capacity);

	// Depositing: slower water drops more; the speed is floored at 0.05 to bound the reciprocal.
	float depositRate = depositBlend * (1.0f / max(flowSpeed, 0.05f)) *  SEDIMENT_DEPOSIT_RATE;
	float deposited = DT * depositRate;

	newSediment = max(suspended + EffectScalar * (dissolved - deposited), 0.0f);
	newHeight = max(terrainHeight + EffectScalar * (deposited - dissolved), 0.0f);
}

//TODO: general-purpose advection kernel? (use Advection compute shader instead, for consistency)
// Moves suspended sediment along the water velocity field by backtracing: the texel at coord
// looks "upstream" by DT * velocity and takes a bilinear blend of SedimentPrev around the
// position it lands on.
//
// coord               - texel being simulated
// transportedSediment - out: the bilinearly interpolated SedimentPrev value at the backtraced position
//
// Changes from the previous revision (all in the bilinear blend):
//  * the sample at (x1, y0) used uv0.x as its row index instead of uv0.y;
//  * that same sample was weighted by t1 instead of t0, so the four weights did not sum to 1;
//  * the fractional offsets were divided by the texture size, which shrank the far-side
//    weights to almost nothing; the offset inside the texel is already the blend weight;
//  * uv1 could index one texel past the last row/column; it is now clamped to the texture.
//
// NOTE(review): the velocity is multiplied by (dx, dy) to obtain a texel offset - confirm this
// is the intended unit conversion.
void simulateSedimentTransport(uint2 coord, out float transportedSediment) {
	float WaterSpeed = 1.0f; //user param?
	float speed = WaterSpeed * DT;

	float2 vel = WaterVel[coord].xy * dxdy.xy;

	//backtraced position in texel coordinates, kept inside the texture
	float2 lastTexel = float2(texDim[0] - 1, texDim[1] - 1);
	float2 source = clamp(float2(coord) - speed * vel, float2(0.0f, 0.0f), lastTexel);

	//the 2x2 block of texels surrounding the backtraced position
	uint2 uv0 = uint2((uint)source.x, (uint)source.y);
	uint2 uv1 = min(uv0 + uint2(1, 1), uint2((uint)lastTexel.x, (uint)lastTexel.y));

	// remainder values, used for blending between texels
	// TODO: optimize by using free pixel interpolation in pixel shader?
	float s1 = source.x - (float)uv0.x;
	float s0 = 1.0f - s1;
	float t1 = source.y - (float)uv0.y;
	float t0 = 1.0f - t1;

	// sample from 4 backtraced texels, in proportions based on where the exact backtraced position landed
	// (bilinear interpolation: blend along y within each column, then blend the two columns along x)
	float columnLow  = t0 * SedimentPrev[uv0] + t1 * SedimentPrev[uint2(uv0.x, uv1.y)];
	float columnHigh = t0 * SedimentPrev[uint2(uv1.x, uv0.y)] + t1 * SedimentPrev[uv1];

	transportedSediment = (s0 * columnLow + s1 * columnHigh);// / SEDIMENT_SCALE;
}

#pragma kernel HydraulicErosion

// Kernel: one erosion step for the texel at id.xy.
// Combines the locally eroded/deposited sediment (simulateErosion) with the sediment advected
// in from upstream (simulateSedimentTransport) and stores the sum in Sediment; the terrain
// height returned by simulateErosion is stored in TerrainHeight.
[numthreads(8, 8, 1)]
void HydraulicErosion(uint3 id : SV_DispatchThreadID)
{
	uint2 coord = id.xy;

	//sediment left in this cell after dissolving/depositing, and the resulting terrain height
	float localSediment = 0.0f;
	float erodedHeight = 0.0f;
	simulateErosion(coord, localSediment, erodedHeight);

	//sediment carried in by the water
	float advectedSediment = 0.0f;
	simulateSedimentTransport(coord, advectedSediment); //doesn't appear to have a visual result

	//TODO: Figure out why these values are so screwy
	//Eroded[id.xy] = Eroded[id.xy] + 100.0f * max(TerrainHeightPrev[id.xy] - newHeight, 0.0f);
	Sediment[coord] = /*SEDIMENT_SCALE * */(localSediment + advectedSediment);
	TerrainHeight[coord] = erodedHeight;
}