BITKit/Src/Core/Net/NetProviderService.cs

472 lines
17 KiB
C#

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using Cysharp.Threading.Tasks;
using MemoryPack;
using Microsoft.CodeAnalysis.Scripting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
namespace BITKit
{
/// <summary>
/// Dispatches incoming network packets (heartbeat, RPC request, RPC return value, log message)
/// and builds runtime-compiled client proxies for remote interfaces.
/// </summary>
public class NetProviderService
{
    /// <summary>Pre-built single-byte heartbeat payload.</summary>
    public static readonly byte[] Heartbeat = new byte[] { 2 };

    private readonly ILogger<NetProviderService> _logger;
    private readonly IServiceProvider _serviceProvider;

    /// <summary>Pending outgoing RPCs keyed by rpc index; completed when the matching ReturnValue packet arrives.</summary>
    public readonly Dictionary<int, UniTaskCompletionSource<byte[]>> TaskCompletionSources = new();

    private readonly ValidHandle _heartBeating = new();
    private readonly IntervalUpdate _heatBeatTimeout = new(1);

    // Proxy types compiled by GetRemoteInterface, keyed by interface type. Compilation is expensive and
    // every compiled assembly stays loaded, so each interface is compiled at most once per process.
    private static readonly Dictionary<Type, Type> RemoteInterfaceTypes = new();
    private static readonly object RemoteInterfaceTypesGate = new();

    public NetProviderService(IServiceProvider serviceProvider, ILogger<NetProviderService> logger)
    {
        _serviceProvider = serviceProvider;
        _logger = logger;
        _heartBeating.AddListener(x => { _logger.LogInformation(x ? "开始心跳" : "停止心跳"); });
    }

    /// <summary>
    /// Re-evaluates the heartbeat flag from the heartbeat timeout tracker.
    /// NOTE(review): presumably called once per frame/tick by the owner — confirm.
    /// </summary>
    public void Tick()
    {
        _heartBeating.SetElements(this, !_heatBeatTimeout.AllowUpdateWithoutReset);
    }

    /// <summary>
    /// Handles one inbound packet. The first byte is the <see cref="NetCommandType"/>; the rest is command-specific.
    /// </summary>
    /// <param name="id">Connection id; RPC replies are sent back to it.</param>
    /// <param name="bytes">Raw packet bytes.</param>
    public async void OnData(int id, byte[] bytes)
    {
        // async void: an exception escaping here cannot be observed by any caller,
        // so every failure is caught and logged at this boundary.
        try
        {
            using var ms = new MemoryStream(bytes);
            using var br = new BinaryReader(ms);
            var command = (NetCommandType)br.ReadByte();
            _serviceProvider.QueryComponents(out INetProvider netProvider);
            switch (command)
            {
                case NetCommandType.Heartbeat:
                    _heatBeatTimeout.Reset();
                    _heartBeating.AddElement(this);
                    break;
                case NetCommandType.ReturnValue:
                    HandleReturnValue(br);
                    break;
                case NetCommandType.WaitTask:
                case NetCommandType.Rpc:
                    await HandleRpcAsync(id, br, netProvider);
                    break;
                case NetCommandType.Message:
                    _logger.LogInformation(br.ReadString());
                    break;
            }
        }
        catch (Exception e)
        {
            _logger.LogError(e, "Failed to handle network data from connection {Id}", id);
        }
    }

    /// <summary>
    /// Completes the pending RPC identified by the packet's rpc index.
    /// Layout after the command byte: Int32 index, Int32 status (0 = success, otherwise error);
    /// on success the remaining bytes are the result payload, on error a string message follows.
    /// </summary>
    private void HandleReturnValue(BinaryReader br)
    {
        var index = br.ReadInt32();
        if (TaskCompletionSources.Remove(index, out var cs) is false)
        {
            _logger.LogWarning(
                $"无法找到索引为 {index} 的任务完成源,可能是因为任务已经完成或不存在。,目前有:"+string.Join( ",",TaskCompletionSources.Keys));
            return;
        }

        try
        {
            if (br.ReadInt32() is 0)
            {
                var remainingBytes = br.ReadBytes((int)(br.BaseStream.Length - br.BaseStream.Position));
                cs.TrySetResult(remainingBytes);
            }
            else
            {
                cs.TrySetException(new NetRemoteInternalException(br.ReadString()));
            }
        }
        catch (Exception e)
        {
            // The source was already removed from the table; fail it so the awaiting caller does not hang.
            cs.TrySetException(e);
        }
    }

    /// <summary>
    /// Executes an inbound RPC on the locally registered service and, for UniTask / UniTask&lt;T&gt;
    /// methods, replies with a ReturnValue packet. Any failure after the rpc index has been read
    /// (unknown type or method, bad arguments, exception in the target) is sent back as an error reply.
    /// </summary>
    private async UniTask HandleRpcAsync(int id, BinaryReader br, INetProvider netProvider)
    {
        var rpcIndex = br.ReadInt32();
        try
        {
            GetService(br, out var service, out var name);
            var serviceType = service.GetType();
            var methodInfo = serviceType.GetMethod(name, ReflectionHelper.Flags)
                             ?? throw new MissingMethodException(serviceType.FullName, name);
            var args = ReadArguments(methodInfo, br);
            var resultValue = await InvokeMethodAsync(service, methodInfo, args);

            var returnType = methodInfo.ReturnType;
            if (returnType == typeof(UniTask))
            {
                netProvider.Invoke(id, BuildSuccessReply(rpcIndex, false, null));
            }
            else if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(UniTask<>))
            {
                netProvider.Invoke(id, BuildSuccessReply(rpcIndex, true, resultValue));
            }
            // Other return types get no success reply: the generated proxies only await a
            // ReturnValue for UniTask / UniTask<T> methods.
        }
        catch (Exception e)
        {
            // MethodBase.Invoke wraps exceptions thrown by the target; report the real cause.
            var error = e is TargetInvocationException { InnerException: { } inner } ? inner : e;
            netProvider.Invoke(id, BuildErrorReply(rpcIndex, error.Message));
        }
    }

    /// <summary>Deserializes one argument per parameter of <paramref name="methodInfo"/>, in declaration order.</summary>
    private static object[] ReadArguments(MethodInfo methodInfo, BinaryReader br)
    {
        var parameters = methodInfo.GetParameters();
        if (parameters.Length is 0)
        {
            return Array.Empty<object>();
        }

        var args = new object[parameters.Length];
        for (var i = 0; i < parameters.Length; i++)
        {
            var type = parameters[i].ParameterType;
            // ref/out parameters have a ByRef type (e.g. Int32&); the wire carries the underlying value.
            if (type.IsByRef)
            {
                type = type.GetElementType()!;
            }
            args[i] = RemoteInterfaceGenerator.Deserialize(type, br);
        }
        return args;
    }

    /// <summary>
    /// Invokes <paramref name="methodInfo"/> and awaits the result when the return type is awaitable.
    /// Returns the produced value, or null for void-like awaitables (awaiter's GetResult returns void).
    /// </summary>
    private static async UniTask<object> InvokeMethodAsync(object service, MethodInfo methodInfo, object[] args)
    {
        var getAwaiter = methodInfo.ReturnType.GetMethod(nameof(Task.GetAwaiter));
        if (getAwaiter is null)
        {
            return methodInfo.Invoke(service, args);
        }

        dynamic awaitable = methodInfo.Invoke(service, args)!;
        var resultType = getAwaiter.ReturnType.GetMethod("GetResult")?.ReturnType;
        if (resultType is null || resultType == typeof(void))
        {
            // Awaiting a void-result awaitable through `dynamic` must be a statement, not an assignment.
            await awaitable;
            return null;
        }
        return await awaitable;
    }

    /// <summary>
    /// Builds a success ReturnValue packet: command byte, Int32 rpc index, Int32 status 0, and
    /// (when <paramref name="includeResult"/>) the serialized result.
    /// </summary>
    private static byte[] BuildSuccessReply(int rpcIndex, bool includeResult, object resultValue)
    {
        using var ms = new MemoryStream();
        using var bw = new BinaryWriter(ms);
        bw.Write((byte)NetCommandType.ReturnValue);
        bw.Write(rpcIndex);
        bw.Write(0);
        if (includeResult)
        {
            // Always write a payload, even for null: the caller deserializes the bytes after the
            // status word and would hit end-of-stream on an empty tail.
            RemoteInterfaceGenerator.Serialize(bw, resultValue);
        }
        bw.Flush();
        return ms.ToArray();
    }

    /// <summary>Builds an error ReturnValue packet: command byte, Int32 rpc index, Int32 status 1, message string.</summary>
    private static byte[] BuildErrorReply(int rpcIndex, string message)
    {
        using var ms = new MemoryStream();
        using var bw = new BinaryWriter(ms);
        bw.Write((byte)NetCommandType.ReturnValue);
        bw.Write(rpcIndex);
        bw.Write(1);
        bw.Write(message ?? string.Empty);
        bw.Flush();
        return ms.ToArray();
    }

    /// <summary>
    /// Reads the target service type name and member name, then resolves the service from DI.
    /// </summary>
    /// <exception cref="TypeLoadException">The type name cannot be resolved.</exception>
    private void GetService(BinaryReader binaryReader, out object service, out string name)
    {
        var typeName = binaryReader.ReadString();
        name = binaryReader.ReadString();
        var type = BITSharp.GetTypeFromFullName(typeName)
                   ?? throw new TypeLoadException($"Could not resolve remote service type '{typeName}'.");
        service = _serviceProvider.GetRequiredService(type);
    }

    /// <summary>
    /// Creates a client proxy implementing <typeparamref name="T"/> whose calls are forwarded over the
    /// resolved <see cref="INetProvider"/>. The proxy type is compiled once per interface and then reused.
    /// </summary>
    public T GetRemoteInterface<T>()
    {
        var instance = Activator.CreateInstance(GetOrCompileRemoteType(typeof(T)))!;
        _serviceProvider.QueryComponents(out INetProvider netProvider);
        instance.GetType().GetField("NetProvider")!.SetValue(instance, netProvider);
        return (T)instance;
    }

    /// <summary>Returns the cached proxy type for <paramref name="interfaceType"/>, compiling it on first use.</summary>
    private static Type GetOrCompileRemoteType(Type interfaceType)
    {
        lock (RemoteInterfaceTypesGate)
        {
            if (RemoteInterfaceTypes.TryGetValue(interfaceType, out var cached))
            {
                return cached;
            }

            var code = RemoteInterfaceGenerator.Shared.Generate(interfaceType);
            var options = ScriptOptions.Default;
            var assemblies = BITSharp.GetReferencedAssemblies(interfaceType);
            options = options.AddReferences(typeof(MemoryPackSerializer).Assembly);
            foreach (var referencedAssembly in assemblies)
            {
                options = options.AddReferences(referencedAssembly);
            }
            options = options.AddReferences(typeof(BITApp).Assembly);

            var assembly = BITSharp.Compile(code, options);
            var type = assembly.GetExportedTypes().First(x => interfaceType.IsAssignableFrom(x));
            RemoteInterfaceTypes[interfaceType] = type;
            return type;
        }
    }
}
/// <summary>DI registration helpers for network-backed interface proxies.</summary>
public static class RemoteInterfaceExtensions
{
    /// <summary>
    /// Registers <typeparamref name="T"/> as a singleton whose instance is the remote proxy obtained
    /// from the container's <see cref="INetProvider"/> on first resolution.
    /// </summary>
    public static void AddRemoteInterface<T>(this IServiceCollection self) where T : class
        => self.AddSingleton<T>(provider =>
        {
            var netProvider = provider.GetRequiredService<INetProvider>();
            return netProvider.GetRemoteInterface<T>();
        });
}
/// <summary>
/// Generates C# source for client proxies of remote interfaces and provides the payload
/// (de)serialization used on both ends of an RPC (MemoryPack when supported, JSON otherwise).
/// </summary>
public class RemoteInterfaceGenerator : BITSharp.CodeGenerator
{
#if UNITY_5_3_OR_NEWER
    /// <summary>
    /// Replaces <see cref="Shared"/> with a fresh generator when Unity runs its runtime-initialize hook,
    /// which also discards the cached MemoryPack support results.
    /// </summary>
    [UnityEngine.RuntimeInitializeOnLoadMethod]
    private static void Reload()
    {
        Shared = new RemoteInterfaceGenerator();
    }
#endif

    /// <summary>Process-wide generator; also owns the MemoryPack support cache used by <see cref="Serialize{T}"/>.</summary>
    public static RemoteInterfaceGenerator Shared { get; private set; } = new();

    public INetProvider NetProvider { get; set; }

    // Per static type: whether MemoryPack can serialize it (false => JSON fallback).
    private readonly Dictionary<Type, bool> _memoryPackSupported = new();

    /// <summary>
    /// Writes <paramref name="data"/> as a tagged payload: a bool (true = MemoryPack, false = JSON),
    /// followed by an Int32 length and MemoryPack bytes, or by a length-prefixed JSON string.
    /// </summary>
    public static void Serialize<T>(BinaryWriter writer, T data)
    {
        var type = typeof(T);
        var supportCache = Shared._memoryPackSupported;
        if (supportCache.TryGetValue(type, out var isSupported) is false)
        {
            // Probe on the static type T rather than the runtime value, so a null argument can neither
            // make the probe throw nor cause the type to be cached as unsupported.
            isSupported = CanUseMemoryPack<T>(type);
            supportCache[type] = isSupported;
        }

        if (isSupported)
        {
            writer.Write(true);
            var bytes = MemoryPackSerializer.Serialize<T>(data);
            writer.Write(bytes.Length);
            writer.Write(bytes);
        }
        else
        {
            writer.Write(false);
            var json = JsonConvert.SerializeObject(data);
            writer.Write(json);
        }
    }

    /// <summary>
    /// Best-effort capability probe: any exception thrown by MemoryPack means "use JSON instead".
    /// </summary>
    private static bool CanUseMemoryPack<T>(Type type)
    {
        try
        {
            if (type.IsArray)
            {
                // For arrays, probe the element type with a default sample
                // (boxed default for value types, null for reference types).
                var elementType = type.GetElementType()!;
                var sample = elementType.IsValueType ? Activator.CreateInstance(elementType) : null;
                MemoryPackSerializer.Serialize(elementType, sample);
            }
            else
            {
                MemoryPackSerializer.Serialize<T>(default(T));
            }
            return true;
        }
        catch (Exception)
        {
            return false;
        }
    }

    /// <summary>Reads one payload written by <see cref="Serialize{T}"/> and materializes it as <paramref name="type"/>.</summary>
    public static object Deserialize(Type type, BinaryReader br)
    {
        if (br.ReadBoolean())
        {
            var length = br.ReadInt32();
            return MemoryPackSerializer.Deserialize(type, br.ReadBytes(length));
        }
        var json = br.ReadString();
        return JsonConvert.DeserializeObject(json, type);
    }

    /// <summary>
    /// Deserializes a complete payload buffer. An empty (or null) buffer yields <c>default</c>:
    /// it carries no tag byte, and reading it would otherwise fail with end-of-stream.
    /// </summary>
    public static T Deserialize<T>(byte[] bytes)
    {
        if (bytes is null || bytes.Length is 0)
        {
            return default;
        }
        using var ms = new MemoryStream(bytes);
        using var br = new BinaryReader(ms);
        return (T)Deserialize(typeof(T), br);
    }

    /// <summary>Emits the public field through which the proxy reaches its <see cref="INetProvider"/>.</summary>
    public override string BeforeGenerate(Type type)
    {
        return $"public {nameof(INetProvider)} {nameof(NetProvider)};";
    }

    /// <summary>Namespaces the generated proxy source needs.</summary>
    public override IReadOnlyList<string> GenerateNamespaces(Type type)
    {
        return new string[]
        {
            "System.IO",
            typeof(INetProvider).Namespace,
            typeof(MemoryPackSerializer).Namespace,
        };
    }

    /// <summary>
    /// Emits a proxy method body that serializes an Rpc request
    /// (command byte, rpc index, declaring type name, method name, arguments) and sends it.
    /// UniTask / UniTask&lt;T&gt; methods await the reply; all other methods send it without awaiting.
    /// </summary>
    public override string GenerateMethodContext(MethodInfo methodInfo)
    {
        var codeBuilder = new StringBuilder();
        var parameterInfos = methodInfo.GetParameters();
        var returnType = methodInfo.ReturnType;
        var isAwaitable = returnType.GetMethod(nameof(Task.GetAwaiter)) != null;
        var isUniTask = returnType == typeof(UniTask);
        var isUniTaskWithResult = returnType.IsGenericType &&
                                  returnType.GetGenericTypeDefinition() == typeof(UniTask<>);

        // Parameter names are emitted with '@' so keyword-named parameters (e.g. `event`) stay valid identifiers.
        foreach (var parameterInfo in parameterInfos)
        {
            if (parameterInfo.IsOut)
            {
                codeBuilder.AppendLine($"@{parameterInfo.Name} = default;");
            }
        }

        codeBuilder.AppendLine(" using var ms = new MemoryStream();");
        codeBuilder.AppendLine(" using var writer = new BinaryWriter(ms);");
        codeBuilder.AppendLine(" writer.Write((byte)NetCommandType.Rpc);");
        codeBuilder.AppendLine(" writer.Write(++NetProvider.RpcCount);");
        codeBuilder.AppendLine($" writer.Write(\"{methodInfo.DeclaringType!.FullName}\");");
        codeBuilder.AppendLine($" writer.Write(\"{methodInfo.Name}\");");
        foreach (var parameterInfo in parameterInfos)
        {
            codeBuilder.AppendLine($"RemoteInterfaceGenerator.Serialize(writer,@{parameterInfo.Name});");
        }

        if (isUniTask)
        {
            codeBuilder.AppendLine(@"await NetProvider.InvokeAsync(0, ms.ToArray());");
            codeBuilder.AppendLine("return;");
        }
        else if (isUniTaskWithResult)
        {
            codeBuilder.AppendLine(@"var bytes =await NetProvider.InvokeAsync(0, ms.ToArray());");
            codeBuilder.AppendLine(
                $"return RemoteInterfaceGenerator.Deserialize<{returnType.GetGenericArguments()[0].CSharpName()}>(bytes);");
        }
        else
        {
            // The send is not awaited here.
            codeBuilder.AppendLine(@" NetProvider.Invoke(0,ms.ToArray());");
            // Other awaitables (Task, ValueTask, ...) still need a return statement and get `default`;
            // non-awaitable methods use the base generator's body.
            codeBuilder.AppendLine(isAwaitable ? "return default;" : base.GenerateMethodContext(methodInfo));
        }
        return codeBuilder.ToString();
    }

    /// <summary>
    /// Emits a proxy property backed by a local field. The getter reads the field. The setter stores
    /// the value locally and sends it as an ordinary Rpc to the property's set accessor.
    /// </summary>
    public override string GenerateProperty(PropertyInfo propertyInfo)
    {
        var codeBuilder = new StringBuilder();
        var backingField = $"_{propertyInfo.Name}";
        // The setter uses the same layout NetProviderService parses for Rpc: command byte, Int32 rpc index,
        // type name, member name, arguments. A SetPropertyValue byte after the Rpc byte would be misread,
        // because that handler reads an Int32 index next.
        var setterName = propertyInfo.SetMethod?.Name ?? $"set_{propertyInfo.Name}";

        var source = base.GenerateProperty(propertyInfo);
        source = source.Replace("get;", $"\nget=>{backingField};");
        source = source.Replace("set;",
            "set\n" +
            "{\n" +
            $"    {backingField} = value;\n\n" +
            "    using var ms = new MemoryStream();\n" +
            "    using var writer = new BinaryWriter(ms);\n\n" +
            "    writer.Write((byte)NetCommandType.Rpc);\n" +
            "    writer.Write(++NetProvider.RpcCount);\n" +
            $"    writer.Write(\"{propertyInfo.DeclaringType!.FullName}\");\n" +
            $"    writer.Write(\"{setterName}\");\n\n" +
            "    RemoteInterfaceGenerator.Serialize(writer, value);\n\n" +
            "    NetProvider.Invoke(0,ms.ToArray());\n" +
            "}"
        );
        codeBuilder.AppendLine(source);
        codeBuilder.AppendLine($"public {propertyInfo.PropertyType.CSharpName()} {backingField};");
        return codeBuilder.ToString();
    }
}
}