using System; using System.Collections.Generic; using System.IO; using System.Linq; using Burst.Compiler.IL.Syntax; using Mono.Cecil; using Mono.Cecil.Cil; using Mono.Cecil.Rocks; namespace zzzUnity.Burst.CodeGen { /// /// Transforms a direct invoke on a burst function pointer into an calli, avoiding the need to marshal the delegate back. /// internal class FunctionPointerInvokeTransform { private struct CaptureInformation { public MethodReference Operand; public List Captured; } private Dictionary _needsNativeFunctionPointer; private Dictionary _needsIl2cppInvoke; private Dictionary> _capturedSets; private MethodDefinition _monoPInvokeAttributeCtorDef; private MethodDefinition _nativePInvokeAttributeCtorDef; private MethodDefinition _unmanagedFunctionPointerAttributeCtorDef; private TypeReference _burstFunctionPointerType; private TypeReference _burstCompilerType; private TypeReference _systemType; private TypeReference _callingConventionType; private LogDelegate _debugLog; private int _logLevel; private AssemblyResolver _loader; private ErrorDiagnosticDelegate _errorReport; public readonly static bool enableInvokeAttribute = true; #if UNITY_DOTSPLAYER public readonly static bool enableCalliOptimisation = true; public readonly static bool enableUnmangedFunctionPointerInject = false; #else public readonly static bool enableCalliOptimisation = false; // For now only run the pass on dots player/tiny public readonly static bool enableUnmangedFunctionPointerInject = true; #endif public FunctionPointerInvokeTransform(AssemblyResolver loader,ErrorDiagnosticDelegate error, LogDelegate log = null, int logLevel = 0) { _loader = loader; _needsNativeFunctionPointer = new Dictionary(); _needsIl2cppInvoke = new Dictionary(); _capturedSets = new Dictionary>(); _monoPInvokeAttributeCtorDef = null; _unmanagedFunctionPointerAttributeCtorDef = null; _nativePInvokeAttributeCtorDef = null; // Only present on DOTS_PLAYER _burstFunctionPointerType = null; _burstCompilerType = null; _systemType = null; _callingConventionType = null; _debugLog = log; _logLevel = logLevel; _errorReport = error; } private AssemblyDefinition GetAsmDefinitionFromFile(AssemblyResolver loader, string assemblyName) { if (loader.TryResolve(AssemblyNameReference.Parse(assemblyName), out var result)) { return result; } return null; } public void Initialize(AssemblyResolver loader, AssemblyDefinition assemblyDefinition, TypeSystem typeSystem) { if (_monoPInvokeAttributeCtorDef == null) { var burstAssembly = GetAsmDefinitionFromFile(loader, "Unity.Burst"); _burstFunctionPointerType = burstAssembly.MainModule.GetType("Unity.Burst.FunctionPointer`1"); _burstCompilerType = burstAssembly.MainModule.GetType("Unity.Burst.BurstCompiler"); var corLibrary = loader.Resolve(typeSystem.CoreLibrary as AssemblyNameReference); _systemType = corLibrary.MainModule.Types.FirstOrDefault(x => x.FullName == "System.Type"); // Only needed for MonoPInvokeCallback constructor in Unity if (enableUnmangedFunctionPointerInject) { var unmanagedFunctionPointerAttribute = corLibrary.MainModule.GetType("System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute"); _callingConventionType = corLibrary.MainModule.GetType("System.Runtime.InteropServices.CallingConvention"); _unmanagedFunctionPointerAttributeCtorDef = unmanagedFunctionPointerAttribute.GetConstructors().Single(c => c.Parameters.Count == 1 && c.Parameters[0].ParameterType.MetadataType == _callingConventionType.MetadataType); } #if UNITY_DOTSPLAYER var asmDef = assemblyDefinition; if (asmDef.Name.Name != "Unity.Runtime") asmDef = GetAsmDefinitionFromFile(loader, "Unity.Runtime"); if (asmDef == null) return; var monoPInvokeAttribute = asmDef.MainModule.GetType("Unity.Jobs.MonoPInvokeCallbackAttribute"); _monoPInvokeAttributeCtorDef = monoPInvokeAttribute.GetConstructors().First(); var nativePInvokeAttribute = asmDef.MainModule.GetType("NativePInvokeCallbackAttribute"); _nativePInvokeAttributeCtorDef = nativePInvokeAttribute.GetConstructors().First(); #else var asmDef = GetAsmDefinitionFromFile(loader, "UnityEngine.CoreModule"); // bail if we can't find a reference, handled gracefully later if (asmDef == null) return; var monoPInvokeAttribute = asmDef.MainModule.GetType("AOT.MonoPInvokeCallbackAttribute"); _monoPInvokeAttributeCtorDef = monoPInvokeAttribute.GetConstructors().First(); #endif } } public bool Run(AssemblyDefinition assemblyDefinition) { Initialize(_loader, assemblyDefinition, assemblyDefinition.MainModule.TypeSystem); var types = assemblyDefinition.MainModule.GetTypes().ToArray(); foreach (var type in types) { CollectDelegateInvokesFromType(type); } return Finish(); } public void CollectDelegateInvokesFromType(TypeDefinition type) { foreach (var m in type.Methods) { if (m.HasBody) { CollectDelegateInvokes(m); } } } private bool ProcessUnmanagedAttributeFixups() { if (_unmanagedFunctionPointerAttributeCtorDef == null) return false; bool modified = false; foreach (var kp in _needsNativeFunctionPointer) { var delegateType = kp.Key; var instruction = kp.Value.instruction; var method = kp.Value.method; var delegateDef = delegateType.Resolve(); var hasAttributeAlready = delegateDef.CustomAttributes.FirstOrDefault(x => x.AttributeType.FullName == _unmanagedFunctionPointerAttributeCtorDef.DeclaringType.FullName); // If there is already an an attribute present if (hasAttributeAlready!=null) { if (hasAttributeAlready.ConstructorArguments.Count==1) { var cc = (System.Runtime.InteropServices.CallingConvention)hasAttributeAlready.ConstructorArguments[0].Value; if (cc == System.Runtime.InteropServices.CallingConvention.Cdecl) { if (_logLevel > 2) _debugLog?.Invoke($"UnmanagedAttributeFixups Skipping appending unmanagedFunctionPointerAttribute as already present aand calling convention matches"); } else { // constructor with non cdecl calling convention _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer(CallingConvention.{ Enum.GetName(typeof(System.Runtime.InteropServices.CallingConvention), cc) })]` please remove the attribute if you wish to use this function with Burst."); } } else { // Empty constructor which defaults to Winapi which is incompatable _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer]` please remove the attribute if you wish to use this function with Burst."); } continue; } var attribute = new CustomAttribute(delegateType.Module.ImportReference(_unmanagedFunctionPointerAttributeCtorDef)); attribute.ConstructorArguments.Add(new CustomAttributeArgument(delegateType.Module.ImportReference(_callingConventionType), System.Runtime.InteropServices.CallingConvention.Cdecl)); delegateDef.CustomAttributes.Add(attribute); modified = true; } return modified; } private bool ProcessIl2cppInvokeFixups() { if (_monoPInvokeAttributeCtorDef == null) return false; bool modified = false; foreach (var invokeNeeded in _needsIl2cppInvoke) { var declaringType = invokeNeeded.Value; var implementationMethod = invokeNeeded.Key; #if UNITY_DOTSPLAYER // At present always uses monoPInvokeAttribute because we don't currently know if the burst is enabled var attribute = new CustomAttribute(implementationMethod.Module.ImportReference(_monoPInvokeAttributeCtorDef)); implementationMethod.CustomAttributes.Add(attribute); modified = true; #else // Unity requires a type parameter for the attributecallback if (declaringType == null) { _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing declaringType for {implementationMethod}"); continue; } var attribute = new CustomAttribute(implementationMethod.Module.ImportReference(_monoPInvokeAttributeCtorDef)); attribute.ConstructorArguments.Add(new CustomAttributeArgument(implementationMethod.Module.ImportReference(_systemType), implementationMethod.Module.ImportReference(declaringType))); implementationMethod.CustomAttributes.Add(attribute); modified = true; if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Added InvokeCallbackAttribute to {implementationMethod}"); #endif } return modified; } private bool ProcessFunctionPointerInvokes() { var madeChange = false; foreach (var capturedData in _capturedSets) { var latePatchMethod = capturedData.Key; var capturedList = capturedData.Value; latePatchMethod.Body.SimplifyMacros(); // De-optimise short branches, since we will end up inserting instructions foreach(var capturedInfo in capturedList) { var captured = capturedInfo.Captured; var operand = capturedInfo.Operand; if (captured.Count!=2) { _debugLog?.Invoke($"FunctionPtrInvoke.Finish: expected 2 instructions - Unable to optimise this reference"); continue; } if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish:{Environment.NewLine} latePatchMethod:{latePatchMethod}{Environment.NewLine} captureList:{capturedList}{Environment.NewLine} capture0:{captured[0]}{Environment.NewLine} operand:{operand}"); var processor = latePatchMethod.Body.GetILProcessor(); var genericContext = GenericContext.From(operand, operand.DeclaringType); CallSite callsite; try { callsite = new CallSite(genericContext.Resolve(operand.ReturnType)) { CallingConvention = MethodCallingConvention.C }; for (int oo = 0; oo < operand.Parameters.Count; oo++) { var param = operand.Parameters[oo]; var ty = genericContext.Resolve(param.ParameterType); callsite.Parameters.Add(new ParameterDefinition(param.Name, param.Attributes, ty)); } } catch (NullReferenceException) { _debugLog?.Invoke($"FunctionPtrInvoke.Finish: Failed to resolve the generic context of `{operand}`"); continue; } // Make sure everything is in order before we make a change var originalGetInvoke = captured[0]; if (originalGetInvoke.Operand is MethodReference mmr) { var genericMethodDef = mmr.Resolve(); var genericInstanceType = mmr.DeclaringType as GenericInstanceType; var genericInstanceDef = genericInstanceType.Resolve(); // Locate the correct instance method - we know already at this point we have an instance of Function MethodReference mr = default; bool failed = true; foreach (var m in genericInstanceDef.Methods) { if (m.FullName.Contains("get_Value")) { mr = m; failed = false; break; } } if (failed) { _debugLog?.Invoke($"FunctionPtrInvoke.Finish: failed to locate get_Value method on {genericInstanceDef} - Unable to optimise this reference"); continue; } var newGenericRef = new MethodReference(mr.Name, mr.ReturnType, genericInstanceType) { HasThis = mr.HasThis, ExplicitThis = mr.ExplicitThis, CallingConvention = mr.CallingConvention }; foreach (var param in mr.Parameters) newGenericRef.Parameters.Add(new ParameterDefinition(param.ParameterType)); foreach (var gparam in mr.GenericParameters) newGenericRef.GenericParameters.Add(new GenericParameter(gparam.Name, newGenericRef)); var importRef = latePatchMethod.Module.ImportReference(newGenericRef); var newMethodCall = processor.Create(OpCodes.Call, importRef); // Replace get_invoke with get_Value - Don't use replace though as if the original call is target of a branch //the branch doesn't get updated. originalGetInvoke.OpCode = newMethodCall.OpCode; originalGetInvoke.Operand = newMethodCall.Operand; // Add local to capture result var newLocal = new VariableDefinition(mr.ReturnType); latePatchMethod.Body.Variables.Add(newLocal); // Store result of get_Value var storeInst = processor.Create(OpCodes.Stloc, newLocal); processor.InsertAfter(originalGetInvoke, storeInst); // Swap invoke with calli var calli = processor.Create(OpCodes.Calli, callsite); // We can use replace here, since we already checked this is in the same Basic Block, and thus can't be target of a branch processor.Replace(captured[1], calli); // Insert load local prior to calli var loadValue = processor.Create(OpCodes.Ldloc, newLocal); processor.InsertBefore(calli, loadValue); if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish: Optimised {originalGetInvoke} with {newMethodCall}"); madeChange = true; } } latePatchMethod.Body.OptimizeMacros(); // Re-optimise branches } return madeChange; } public bool Finish() { bool madeChange = false; if (enableInvokeAttribute) { madeChange |= ProcessIl2cppInvokeFixups(); } if (enableUnmangedFunctionPointerInject) { madeChange |= ProcessUnmanagedAttributeFixups(); } if (enableCalliOptimisation) { madeChange |= ProcessFunctionPointerInvokes(); } return madeChange; } private bool IsBurstFunctionPointerMethod(MethodReference methodRef, string method, out GenericInstanceType methodInstance) { methodInstance = methodRef?.DeclaringType as GenericInstanceType; return (methodInstance != null && methodInstance.ElementType.FullName == _burstFunctionPointerType.FullName && methodRef.Name == method); } private bool IsBurstCompilerMethod(GenericInstanceMethod methodRef, string method) { var methodInstance = methodRef?.DeclaringType as TypeReference; return (methodInstance != null && methodInstance.FullName == _burstCompilerType.FullName && methodRef.Name == method); } private void LocateFunctionPointerTCreation(MethodDefinition m, Instruction i) { if (i.OpCode == OpCodes.Call) { var genInstMethod = i.Operand as GenericInstanceMethod; var isBurstCompilerCompileFunctionPointer = IsBurstCompilerMethod(genInstMethod, "CompileFunctionPointer"); var isBurstFunctionPointerGetInvoke = IsBurstFunctionPointerMethod(i.Operand as MethodReference, "get_Invoke", out var instanceType); if (!(isBurstCompilerCompileFunctionPointer || isBurstFunctionPointerGetInvoke)) return; if (enableUnmangedFunctionPointerInject) { var delegateType = isBurstCompilerCompileFunctionPointer ? genInstMethod.GenericArguments[0].Resolve() : instanceType.GenericArguments[0].Resolve(); // We check for null, since unfortunately it is possible that the call is wrapped inside //another open delegate and we cannot determine the delegate type if (delegateType != null && !_needsNativeFunctionPointer.ContainsKey(delegateType)) { _needsNativeFunctionPointer.Add(delegateType, (m, i)); } } // No need to process further if its not a CompileFunctionPointer method if (!isBurstCompilerCompileFunctionPointer) return; if (enableInvokeAttribute) { // Currently only handles the following pre-pattern (which should cover most common uses) // ldftn ... // newobj ... if (i.Previous?.OpCode != OpCodes.Newobj) { _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding NewObj {i.Previous}"); return; } var newObj = i.Previous; if (newObj.Previous?.OpCode != OpCodes.Ldftn) { _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding LdFtn {newObj.Previous}"); return; } var ldFtn = newObj.Previous; // Determine the delegate type var methodDefinition = newObj.Operand as MethodDefinition; var declaringType = methodDefinition?.DeclaringType; // Fetch the implementation method var implementationMethod = ldFtn.Operand as MethodDefinition; var hasInvokeAlready = implementationMethod?.CustomAttributes.FirstOrDefault(x => (x.AttributeType.FullName == _monoPInvokeAttributeCtorDef.DeclaringType.FullName) || (_nativePInvokeAttributeCtorDef != null && x.AttributeType.FullName == _nativePInvokeAttributeCtorDef.DeclaringType.FullName)); if (hasInvokeAlready != null) { if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Skipping appending Callback Attribute as already present {hasInvokeAlready}"); return; } if (implementationMethod == null) { _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing method from {ldFtn} {ldFtn.Operand}"); return; } if (implementationMethod.CustomAttributes.FirstOrDefault(x => x.Constructor.DeclaringType.Name == "BurstCompileAttribute") == null) { _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing burst attribute from {implementationMethod}"); return; } // Need to add the custom attribute if (!_needsIl2cppInvoke.ContainsKey(implementationMethod)) { _needsIl2cppInvoke.Add(implementationMethod, declaringType); } } } } [Obsolete("Will be removed in a future Burst verison")] public bool IsInstructionForFunctionPointerInvoke(MethodDefinition m, Instruction i) { throw new NotImplementedException(); } private void CollectDelegateInvokes(MethodDefinition m) { if (!(enableCalliOptimisation || enableInvokeAttribute || enableUnmangedFunctionPointerInject)) return; bool hitGetInvoke = false; TypeDefinition delegateType = null; List captured = null; foreach (var inst in m.Body.Instructions) { if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: CurrentInstruction {inst} {inst.Operand}"); // Check for a FunctionPointerT creation if (enableUnmangedFunctionPointerInject || enableInvokeAttribute) { LocateFunctionPointerTCreation(m, inst); } if (enableCalliOptimisation) { if (!hitGetInvoke) { if (inst.OpCode != OpCodes.Call) continue; if (!IsBurstFunctionPointerMethod(inst.Operand as MethodReference, "get_Invoke", out var methodInstance)) continue; // At this point we have a call to a FunctionPointer.Invoke hitGetInvoke = true; delegateType = methodInstance.GenericArguments[0].Resolve(); captured = new List(); captured.Add(inst); // Capture the get_invoke, we will swap this for get_value and a store to local } else { if (!(inst.OpCode.FlowControl == FlowControl.Next || inst.OpCode.FlowControl == FlowControl.Call)) { // Don't perform transform across blocks hitGetInvoke = false; } else { if (inst.OpCode == OpCodes.Callvirt) { if (inst.Operand is MethodReference mref) { var method = mref.Resolve(); if (method.DeclaringType == delegateType) { hitGetInvoke = false; List storage = null; if (!_capturedSets.TryGetValue(m, out storage)) { storage = new List(); _capturedSets.Add(m, storage); } // Capture the invoke - which we will swap for a load local (stored from the get_value) and a calli captured.Add(inst); var captureInfo = new CaptureInformation { Captured = captured, Operand = mref }; if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: captureInfo:{captureInfo}{Environment.NewLine}capture0{captured[0]}"); storage.Add(captureInfo); } } else { hitGetInvoke = false; } } } } } } } } }