diff --git a/Directory.Packages.props b/Directory.Packages.props index 301024cf8a..6dcf34f1d1 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -36,6 +36,7 @@ + @@ -49,4 +50,4 @@ - \ No newline at end of file + diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs index 7ecd6835db..b03fba375c 100644 --- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs +++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs @@ -103,7 +103,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd { lock (_lock) { - oldFile.Refcount++; + oldFile.RefCount++; return RegisterFileDescriptor(oldFile); } @@ -118,9 +118,9 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd if (file != null) { - file.Refcount--; + file.RefCount--; - if (file.Refcount <= 0) + if (file.RefCount <= 0) { file.Dispose(); } diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs index 21d48288ec..ac989be927 100644 --- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs +++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs @@ -1,6 +1,7 @@ using Ryujinx.Common; using Ryujinx.Common.Logging; using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl; +using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy; using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types; using Ryujinx.Memory; using System; @@ -20,6 +21,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd { EventFileDescriptorPollManager.Instance, ManagedSocketPollManager.Instance, + ManagedProxySocketPollManager.Instance, }; private BsdContext _context; @@ -95,10 +97,8 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd } } - ISocket newBsdSocket = new ManagedSocket(netDomain, (SocketType)type, protocol) - { - Blocking = !creationFlags.HasFlag(BsdSocketCreationFlags.NonBlocking), - }; + ISocket newBsdSocket = ProxyManager.GetSocket(netDomain, (SocketType)type, protocol); + newBsdSocket.Blocking = !creationFlags.HasFlag(BsdSocketCreationFlags.NonBlocking); LinuxError errno = LinuxError.SUCCESS; diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs index 6c00d5e118..4ad5806787 100644 --- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs +++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs @@ -6,7 +6,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd interface IFileDescriptor : IDisposable { bool Blocking { get; set; } - int Refcount { get; set; } + int RefCount { get; set; } LinuxError Read(out int readSize, Span buffer); diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs index 5b9e6811d3..7c599335f1 100644 --- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs +++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs @@ -32,7 +32,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl UpdateEventStates(); } - public int Refcount { get; set; } + public int RefCount { get; set; } public void Dispose() { diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocket.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocket.cs new file mode 100644 index 0000000000..df66b83f29 --- /dev/null +++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocket.cs @@ -0,0 +1,476 @@ +using Ryujinx.Common.Logging; +using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types; +using RyuSocks; +using RyuSocks.Auth; +using RyuSocks.Commands; +using RyuSocks.Types; +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Sockets; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl +{ + class ManagedProxySocket : ISocket + { + private static readonly Dictionary _authMethods = new() + { + { AuthMethod.NoAuth, new NoAuth() }, + }; + + private readonly bool _isUdpSocket; + private readonly bool _acceptedConnection; + + public SocksClient ProxyClient { get; } + + public bool Blocking { get => ProxyClient.Blocking; set => ProxyClient.Blocking = value; } + public int RefCount { get; set; } + + public IPEndPoint RemoteEndPoint => (IPEndPoint)ProxyClient.ProxiedRemoteEndPoint; + public IPEndPoint LocalEndPoint => (IPEndPoint)ProxyClient.ProxiedLocalEndPoint; + + public AddressFamily AddressFamily => ProxyClient.AddressFamily; + public SocketType SocketType => ProxyClient.SocketType; + public ProtocolType ProtocolType => ProxyClient.ProtocolType; + public IntPtr Handle => throw new NotSupportedException("Can't get the handle of a proxy socket."); + + public ManagedProxySocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType, EndPoint proxyEndpoint) + { + if (addressFamily != proxyEndpoint.AddressFamily && addressFamily != AddressFamily.Unspecified) + { + throw new ArgumentException( + $"Invalid {nameof(System.Net.Sockets.AddressFamily)}", nameof(addressFamily)); + } + + if (socketType != SocketType.Stream && socketType != SocketType.Dgram) + { + throw new ArgumentException( + $"Invalid {nameof(System.Net.Sockets.SocketType)}", nameof(socketType)); + } + + if (protocolType != ProtocolType.Tcp && protocolType != ProtocolType.Udp) + { + throw new ArgumentException( + $"Invalid {nameof(System.Net.Sockets.ProtocolType)}", nameof(protocolType)); + } + + _isUdpSocket = socketType == SocketType.Dgram && protocolType == ProtocolType.Udp; + + ProxyClient = proxyEndpoint switch + { + IPEndPoint ipEndPoint => new SocksClient(ipEndPoint) { OfferedAuthMethods = _authMethods }, + DnsEndPoint dnsEndPoint => new SocksClient(dnsEndPoint) { OfferedAuthMethods = _authMethods }, + _ => throw new ArgumentException($"Unsupported {nameof(EndPoint)} type", nameof(proxyEndpoint)) + }; + + ProxyClient.Authenticate(); + + RefCount = 1; + } + + private ManagedProxySocket(SocksClient proxyClient) + { + ProxyClient = proxyClient; + _acceptedConnection = true; + RefCount = 1; + } + + private static LinuxError ToLinuxError(ReplyField proxyReply) + { + return proxyReply switch + { + ReplyField.Succeeded => LinuxError.SUCCESS, + ReplyField.ServerFailure => LinuxError.ECONNRESET, + ReplyField.ConnectionNotAllowed => LinuxError.ECONNREFUSED, + ReplyField.NetworkUnreachable => LinuxError.ENETUNREACH, + ReplyField.HostUnreachable => LinuxError.EHOSTUNREACH, + ReplyField.ConnectionRefused => LinuxError.ECONNREFUSED, + ReplyField.TTLExpired => LinuxError.EHOSTUNREACH, + ReplyField.CommandNotSupported => LinuxError.EOPNOTSUPP, + ReplyField.AddressTypeNotSupported => LinuxError.EAFNOSUPPORT, + _ => throw new ArgumentOutOfRangeException(nameof(proxyReply)) + }; + } + + public void Dispose() + { + ProxyClient.Dispose(); + } + + public LinuxError Read(out int readSize, Span buffer) + { + return Receive(out readSize, buffer, BsdSocketFlags.None); + } + + public LinuxError Write(out int writeSize, ReadOnlySpan buffer) + { + return Send(out writeSize, buffer, BsdSocketFlags.None); + } + + public LinuxError Receive(out int receiveSize, Span buffer, BsdSocketFlags flags) + { + LinuxError result; + bool shouldBlockAfterOperation = false; + + if (Blocking && flags.HasFlag(BsdSocketFlags.DontWait)) + { + Blocking = false; + shouldBlockAfterOperation = true; + } + + byte[] proxyBuffer = new byte[buffer.Length + ProxyClient.GetRequiredWrapperSpace()]; + + try + { + receiveSize = ProxyClient.Receive( + proxyBuffer, + WinSockHelper.ConvertBsdSocketFlags(flags), + out SocketError errorCode + ); + + proxyBuffer[..receiveSize].CopyTo(buffer); + + result = WinSockHelper.ConvertError((WsaError)errorCode); + } + catch (ProxyException exception) + { + Logger.Error?.Print( + LogClass.ServiceBsd, + $"An error occured while trying to receive data: {exception}" + ); + + receiveSize = -1; + result = ToLinuxError(exception.ReplyCode); + } + catch (SocketException exception) + { + receiveSize = -1; + result = WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + + if (shouldBlockAfterOperation) + { + Blocking = true; + } + + return result; + } + + public LinuxError ReceiveFrom(out int receiveSize, Span buffer, int size, BsdSocketFlags flags, out IPEndPoint remoteEndPoint) + { + LinuxError result; + remoteEndPoint = new IPEndPoint(IPAddress.Any, 0); + bool shouldBlockAfterOperation = false; + + byte[] proxyBuffer = new byte[size + ProxyClient.GetRequiredWrapperSpace()]; + + if (Blocking && flags.HasFlag(BsdSocketFlags.DontWait)) + { + Blocking = false; + shouldBlockAfterOperation = true; + } + + try + { + EndPoint temp = new IPEndPoint(IPAddress.Any, 0); + + receiveSize = ProxyClient.ReceiveFrom(proxyBuffer, WinSockHelper.ConvertBsdSocketFlags(flags), ref temp); + + proxyBuffer[..receiveSize].CopyTo(buffer); + + remoteEndPoint = (IPEndPoint)temp; + result = LinuxError.SUCCESS; + } + catch (ProxyException exception) + { + Logger.Error?.Print( + LogClass.ServiceBsd, + $"An error occured while trying to receive data: {exception}" + ); + + receiveSize = -1; + result = ToLinuxError(exception.ReplyCode); + } + catch (SocketException exception) + { + receiveSize = -1; + + result = WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + + if (shouldBlockAfterOperation) + { + Blocking = true; + } + + return result; + } + + public LinuxError Send(out int sendSize, ReadOnlySpan buffer, BsdSocketFlags flags) + { + try + { + sendSize = ProxyClient.Send(buffer, WinSockHelper.ConvertBsdSocketFlags(flags)); + + return LinuxError.SUCCESS; + } + catch (ProxyException exception) + { + Logger.Error?.Print( + LogClass.ServiceBsd, + $"An error occured while trying to send data: {exception}" + ); + + sendSize = -1; + + return ToLinuxError(exception.ReplyCode); + } + catch (SocketException exception) + { + sendSize = -1; + + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public LinuxError SendTo(out int sendSize, ReadOnlySpan buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint) + { + try + { + // NOTE: sendSize might be larger than size and/or buffer.Length. + sendSize = ProxyClient.SendTo(buffer[..size], WinSockHelper.ConvertBsdSocketFlags(flags), remoteEndPoint); + + return LinuxError.SUCCESS; + } + catch (ProxyException exception) + { + Logger.Error?.Print( + LogClass.ServiceBsd, + $"An error occured while trying to send data: {exception}" + ); + + sendSize = -1; + + return ToLinuxError(exception.ReplyCode); + } + catch (SocketException exception) + { + sendSize = -1; + + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout) + { + throw new NotImplementedException(); + } + + public LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags) + { + throw new NotImplementedException(); + } + + /// + /// Adapted from + /// + public LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span optionValue) + { + try + { + LinuxError result = WinSockHelper.ValidateSocketOption(option, level, write: false); + + if (result != LinuxError.SUCCESS) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Invalid GetSockOpt Option: {option} Level: {level}"); + + return result; + } + + if (!WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName)) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported GetSockOpt Option: {option} Level: {level}"); + optionValue.Clear(); + + return LinuxError.SUCCESS; + } + + byte[] tempOptionValue = new byte[optionValue.Length]; + + ProxyClient.GetSocketOption(level, optionName, tempOptionValue); + + tempOptionValue.AsSpan().CopyTo(optionValue); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + /// + /// Adapted from + /// + public LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan optionValue) + { + try + { + LinuxError result = WinSockHelper.ValidateSocketOption(option, level, write: true); + + if (result != LinuxError.SUCCESS) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Invalid SetSockOpt Option: {option} Level: {level}"); + + return result; + } + + if (!WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName)) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported SetSockOpt Option: {option} Level: {level}"); + + return LinuxError.SUCCESS; + } + + byte[] value = optionValue.ToArray(); + + ProxyClient.SetSocketOption(level, optionName, value); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public bool Poll(int microSeconds, SelectMode mode) + { + return ProxyClient.Poll(microSeconds, mode); + } + + public LinuxError Bind(IPEndPoint localEndPoint) + { + ProxyClient.RequestCommand = _isUdpSocket ? ProxyCommand.UdpAssociate : ProxyCommand.Bind; + + try + { + ProxyClient.Bind(localEndPoint); + } + catch (ProxyException exception) + { + Logger.Error?.Print( + LogClass.ServiceBsd, + $"Request for {ProxyClient.RequestCommand} command failed: {exception}" + ); + + return ToLinuxError(exception.ReplyCode); + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + + return LinuxError.SUCCESS; + } + + public LinuxError Connect(IPEndPoint remoteEndPoint) + { + ProxyClient.RequestCommand = ProxyCommand.Connect; + + try + { + ProxyClient.Connect(remoteEndPoint.Address, remoteEndPoint.Port); + } + catch (ProxyException exception) + { + Logger.Error?.Print( + LogClass.ServiceBsd, + $"Request for {ProxyClient.RequestCommand} command failed: {exception}" + ); + + return ToLinuxError(exception.ReplyCode); + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + + return LinuxError.SUCCESS; + } + + public LinuxError Listen(int backlog) + { + // NOTE: Only one client can connect with the default SOCKS5 commands. + if (ProxyClient.RequestCommand != ProxyCommand.Bind) + { + return LinuxError.EOPNOTSUPP; + } + + return LinuxError.SUCCESS; + } + + public LinuxError Accept(out ISocket newSocket) + { + newSocket = null; + + if (ProxyClient.RequestCommand != ProxyCommand.Bind) + { + return LinuxError.EOPNOTSUPP; + } + + // NOTE: Only one client can connect with the default SOCKS5 commands. + if (_acceptedConnection) + { + return LinuxError.EOPNOTSUPP; + } + + try + { + SocksClient newProxyClient = ProxyClient.Accept(); + newSocket = new ManagedProxySocket(newProxyClient); + } + catch (ProxyException exception) + { + Logger.Error?.Print( + LogClass.ServiceBsd, + $"Failed to accept client connection: {exception}" + ); + + return ToLinuxError(exception.ReplyCode); + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + + return LinuxError.SUCCESS; + } + + public void Disconnect() + { + ProxyClient.Disconnect(); + ProxyClient.RequestCommand = 0; + } + + public LinuxError Shutdown(BsdSocketShutdownFlags how) + { + try + { + ProxyClient.Shutdown((SocketShutdown)how); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public void Close() + { + ProxyClient.Close(); + ProxyClient.RequestCommand = 0; + } + } +} diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocketPollManager.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocketPollManager.cs new file mode 100644 index 0000000000..02e6dad487 --- /dev/null +++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocketPollManager.cs @@ -0,0 +1,213 @@ +using Ryujinx.Common.Logging; +using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types; +using System.Collections.Generic; +using System.Net.Sockets; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl +{ + class ManagedProxySocketPollManager : IPollManager + { + private static ManagedProxySocketPollManager _instance; + + public static ManagedProxySocketPollManager Instance + { + get + { + _instance ??= new ManagedProxySocketPollManager(); + + return _instance; + } + } + + public bool IsCompatible(PollEvent evnt) + { + return evnt.FileDescriptor is ManagedProxySocket; + } + + public LinuxError Poll(List events, int timeoutMilliseconds, out int updatedCount) + { + Dictionary> eventDict = new() + { + { SelectMode.SelectRead, [] }, + { SelectMode.SelectWrite, [] }, + { SelectMode.SelectError, [] }, + }; + + updatedCount = 0; + + foreach (PollEvent evnt in events) + { + ManagedProxySocket socket = (ManagedProxySocket)evnt.FileDescriptor; + + bool isValidEvent = evnt.Data.InputEvents == 0; + + eventDict[SelectMode.SelectError].Add(socket); + + if ((evnt.Data.InputEvents & PollEventTypeMask.Input) != 0) + { + eventDict[SelectMode.SelectRead].Add(socket); + + isValidEvent = true; + } + + if ((evnt.Data.InputEvents & PollEventTypeMask.UrgentInput) != 0) + { + eventDict[SelectMode.SelectRead].Add(socket); + + isValidEvent = true; + } + + if ((evnt.Data.InputEvents & PollEventTypeMask.Output) != 0) + { + eventDict[SelectMode.SelectWrite].Add(socket); + + isValidEvent = true; + } + + if (!isValidEvent) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported Poll input event type: {evnt.Data.InputEvents}"); + return LinuxError.EINVAL; + } + } + + try + { + int actualTimeoutMicroseconds = timeoutMilliseconds == -1 ? -1 : timeoutMilliseconds * 1000; + int totalEvents = eventDict[SelectMode.SelectRead].Count + eventDict[SelectMode.SelectWrite].Count + eventDict[SelectMode.SelectError].Count; + // TODO: Maybe check all events first, wait for the timeout and then check the failed ones again? + int timeoutMicrosecondsPerEvent = actualTimeoutMicroseconds == -1 ? -1 : actualTimeoutMicroseconds / totalEvents; + + foreach ((SelectMode selectMode, List eventList) in eventDict) + { + List newEventList = []; + + foreach (ManagedProxySocket eventSocket in eventList) + { + if (eventSocket.Poll(timeoutMicrosecondsPerEvent, selectMode)) + { + newEventList.Add(eventSocket); + } + } + + eventDict[selectMode] = newEventList; + } + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + + foreach (PollEvent evnt in events) + { + ManagedProxySocket socket = ((ManagedProxySocket)evnt.FileDescriptor); + + PollEventTypeMask outputEvents = evnt.Data.OutputEvents & ~evnt.Data.InputEvents; + + if (eventDict[SelectMode.SelectError].Contains(socket)) + { + outputEvents |= PollEventTypeMask.Error; + + if (!socket.ProxyClient.Connected || !socket.ProxyClient.IsBound) + { + outputEvents |= PollEventTypeMask.Disconnected; + } + } + + if (eventDict[SelectMode.SelectRead].Contains(socket)) + { + if ((evnt.Data.InputEvents & PollEventTypeMask.Input) != 0) + { + outputEvents |= PollEventTypeMask.Input; + } + } + + if (eventDict[SelectMode.SelectWrite].Contains(socket)) + { + outputEvents |= PollEventTypeMask.Output; + } + + evnt.Data.OutputEvents = outputEvents; + } + + updatedCount = eventDict[SelectMode.SelectRead].Count + eventDict[SelectMode.SelectWrite].Count + eventDict[SelectMode.SelectError].Count; + + return LinuxError.SUCCESS; + } + + public LinuxError Select(List events, int timeout, out int updatedCount) + { + Dictionary> eventDict = new() + { + { SelectMode.SelectRead, [] }, + { SelectMode.SelectWrite, [] }, + { SelectMode.SelectError, [] }, + }; + + updatedCount = 0; + + foreach (PollEvent pollEvent in events) + { + ManagedProxySocket socket = (ManagedProxySocket)pollEvent.FileDescriptor; + + if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Input)) + { + eventDict[SelectMode.SelectRead].Add(socket); + } + + if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Output)) + { + eventDict[SelectMode.SelectWrite].Add(socket); + } + + if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Error)) + { + eventDict[SelectMode.SelectError].Add(socket); + } + } + + int totalEvents = eventDict[SelectMode.SelectRead].Count + eventDict[SelectMode.SelectWrite].Count + eventDict[SelectMode.SelectError].Count; + // TODO: Maybe check all events first, wait for the timeout and then check the failed ones again? + int timeoutMicrosecondsPerEvent = timeout == -1 ? -1 : timeout / totalEvents; + + foreach ((SelectMode selectMode, List eventList) in eventDict) + { + List newEventList = []; + + foreach (ManagedProxySocket eventSocket in eventList) + { + if (eventSocket.Poll(timeoutMicrosecondsPerEvent, selectMode)) + { + newEventList.Add(eventSocket); + } + } + + eventDict[selectMode] = newEventList; + } + + updatedCount = eventDict[SelectMode.SelectRead].Count + eventDict[SelectMode.SelectWrite].Count + eventDict[SelectMode.SelectError].Count; + + foreach (PollEvent pollEvent in events) + { + ManagedProxySocket socket = (ManagedProxySocket)pollEvent.FileDescriptor; + + if (eventDict[SelectMode.SelectRead].Contains(socket)) + { + pollEvent.Data.OutputEvents |= PollEventTypeMask.Input; + } + + if (eventDict[SelectMode.SelectWrite].Contains(socket)) + { + pollEvent.Data.OutputEvents |= PollEventTypeMask.Output; + } + + if (eventDict[SelectMode.SelectError].Contains(socket)) + { + pollEvent.Data.OutputEvents |= PollEventTypeMask.Error; + } + } + + return LinuxError.SUCCESS; + } + } +} diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs index c42b7201bf..15c4743dad 100644 --- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs +++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs @@ -11,7 +11,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl { class ManagedSocket : ISocket { - public int Refcount { get; set; } + public int RefCount { get; set; } public AddressFamily AddressFamily => Socket.AddressFamily; @@ -32,57 +32,13 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl public ManagedSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType) { Socket = new Socket(addressFamily, socketType, protocolType); - Refcount = 1; + RefCount = 1; } private ManagedSocket(Socket socket) { Socket = socket; - Refcount = 1; - } - - private static SocketFlags ConvertBsdSocketFlags(BsdSocketFlags bsdSocketFlags) - { - SocketFlags socketFlags = SocketFlags.None; - - if (bsdSocketFlags.HasFlag(BsdSocketFlags.Oob)) - { - socketFlags |= SocketFlags.OutOfBand; - } - - if (bsdSocketFlags.HasFlag(BsdSocketFlags.Peek)) - { - socketFlags |= SocketFlags.Peek; - } - - if (bsdSocketFlags.HasFlag(BsdSocketFlags.DontRoute)) - { - socketFlags |= SocketFlags.DontRoute; - } - - if (bsdSocketFlags.HasFlag(BsdSocketFlags.Trunc)) - { - socketFlags |= SocketFlags.Truncated; - } - - if (bsdSocketFlags.HasFlag(BsdSocketFlags.CTrunc)) - { - socketFlags |= SocketFlags.ControlDataTruncated; - } - - bsdSocketFlags &= ~(BsdSocketFlags.Oob | - BsdSocketFlags.Peek | - BsdSocketFlags.DontRoute | - BsdSocketFlags.DontWait | - BsdSocketFlags.Trunc | - BsdSocketFlags.CTrunc); - - if (bsdSocketFlags != BsdSocketFlags.None) - { - Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported socket flags: {bsdSocketFlags}"); - } - - return socketFlags; + RefCount = 1; } public LinuxError Accept(out ISocket newSocket) @@ -199,7 +155,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl shouldBlockAfterOperation = true; } - receiveSize = Socket.Receive(buffer, ConvertBsdSocketFlags(flags)); + receiveSize = Socket.Receive(buffer, WinSockHelper.ConvertBsdSocketFlags(flags)); result = LinuxError.SUCCESS; } @@ -243,7 +199,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl return LinuxError.EOPNOTSUPP; } - receiveSize = Socket.ReceiveFrom(buffer[..size], ConvertBsdSocketFlags(flags), ref temp); + receiveSize = Socket.ReceiveFrom(buffer[..size], WinSockHelper.ConvertBsdSocketFlags(flags), ref temp); remoteEndPoint = (IPEndPoint)temp; result = LinuxError.SUCCESS; @@ -267,7 +223,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl { try { - sendSize = Socket.Send(buffer, ConvertBsdSocketFlags(flags)); + sendSize = Socket.Send(buffer, WinSockHelper.ConvertBsdSocketFlags(flags)); return LinuxError.SUCCESS; } @@ -283,7 +239,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl { try { - sendSize = Socket.SendTo(buffer[..size], ConvertBsdSocketFlags(flags), remoteEndPoint); + sendSize = Socket.SendTo(buffer[..size], WinSockHelper.ConvertBsdSocketFlags(flags), remoteEndPoint); return LinuxError.SUCCESS; } @@ -493,7 +449,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl try { - int receiveSize = Socket.Receive(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError); + int receiveSize = Socket.Receive(ConvertMessagesToBuffer(message), WinSockHelper.ConvertBsdSocketFlags(flags), out SocketError socketError); if (receiveSize > 0) { @@ -531,7 +487,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl try { - int sendSize = Socket.Send(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError); + int sendSize = Socket.Send(ConvertMessagesToBuffer(message), WinSockHelper.ConvertBsdSocketFlags(flags), out SocketError socketError); if (sendSize > 0) { diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs index e2ef75f807..c41a565fd9 100644 --- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs +++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs @@ -1,3 +1,4 @@ +using Ryujinx.Common.Logging; using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types; using System; using System.Collections.Generic; @@ -343,5 +344,49 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl return LinuxError.SUCCESS; } + + public static SocketFlags ConvertBsdSocketFlags(BsdSocketFlags bsdSocketFlags) + { + SocketFlags socketFlags = SocketFlags.None; + + if (bsdSocketFlags.HasFlag(BsdSocketFlags.Oob)) + { + socketFlags |= SocketFlags.OutOfBand; + } + + if (bsdSocketFlags.HasFlag(BsdSocketFlags.Peek)) + { + socketFlags |= SocketFlags.Peek; + } + + if (bsdSocketFlags.HasFlag(BsdSocketFlags.DontRoute)) + { + socketFlags |= SocketFlags.DontRoute; + } + + if (bsdSocketFlags.HasFlag(BsdSocketFlags.Trunc)) + { + socketFlags |= SocketFlags.Truncated; + } + + if (bsdSocketFlags.HasFlag(BsdSocketFlags.CTrunc)) + { + socketFlags |= SocketFlags.ControlDataTruncated; + } + + bsdSocketFlags &= ~(BsdSocketFlags.Oob | + BsdSocketFlags.Peek | + BsdSocketFlags.DontRoute | + BsdSocketFlags.DontWait | + BsdSocketFlags.Trunc | + BsdSocketFlags.CTrunc); + + if (bsdSocketFlags != BsdSocketFlags.None) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported socket flags: {bsdSocketFlags}"); + } + + return socketFlags; + } } } diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Proxy/ProxyManager.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Proxy/ProxyManager.cs new file mode 100644 index 0000000000..db827500cc --- /dev/null +++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Proxy/ProxyManager.cs @@ -0,0 +1,53 @@ +using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Net; +using System.Net.Sockets; +using System.Runtime.CompilerServices; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy +{ + public static class ProxyManager + { + private static readonly ConcurrentDictionary _proxyEndpoints = new(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static string GetKey(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType) + { + return string.Join("-", new[] { (int)addressFamily, (int)socketType, (int)protocolType }); + } + + internal static ISocket GetSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType) + { + if (_proxyEndpoints.TryGetValue(GetKey(addressFamily, socketType, protocolType), out EndPoint endPoint)) + { + return new ManagedProxySocket(addressFamily, socketType, protocolType, endPoint); + } + + return new ManagedSocket(addressFamily, socketType, protocolType); + } + + public static void AddOrUpdate(EndPoint endPoint, + AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType) + { + _proxyEndpoints[GetKey(addressFamily, socketType, protocolType)] = endPoint; + } + + public static void AddOrUpdate(IPAddress address, int port, + AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType) + { + _proxyEndpoints[GetKey(addressFamily, socketType, protocolType)] = new IPEndPoint(address, port); + } + + public static void AddOrUpdate(string host, int port, + AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType) + { + _proxyEndpoints[GetKey(addressFamily, socketType, protocolType)] = new DnsEndPoint(host, port); + } + + public static bool Remove(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType) + { + return _proxyEndpoints.Remove(GetKey(addressFamily, socketType, protocolType), out _); + } + } +} diff --git a/src/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs b/src/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs index 8cc761baf5..d9d06fe96b 100644 --- a/src/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs +++ b/src/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs @@ -1,6 +1,7 @@ using Ryujinx.HLE.HOS.Services.Sockets.Bsd; using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl; using Ryujinx.HLE.HOS.Services.Ssl.Types; +using RyuSocks; using System; using System.IO; using System.Net; @@ -111,12 +112,25 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService { return hostName; } + // Thrown by ManagedProxySocket when accessing RemoteEndPoint before connecting to a remote. + catch (NullReferenceException) + { + return hostName; + } } public ResultCode Handshake(string hostName) { StartSslOperation(); - _stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null); + + Stream socketStream = Socket switch + { + ManagedSocket managedSocket => new NetworkStream(managedSocket.Socket, false), + ManagedProxySocket proxySocket => new SocksClientStream(proxySocket.ProxyClient, false), + _ => throw new NotSupportedException($"{typeof(Socket)} is not supported.") + }; + + _stream = new SslStream(socketStream, false, null, null); hostName = RetrieveHostName(hostName); _stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false); EndSslOperation(); diff --git a/src/Ryujinx.HLE/Ryujinx.HLE.csproj b/src/Ryujinx.HLE/Ryujinx.HLE.csproj index a7bb3cd7f6..d3f7db0381 100644 --- a/src/Ryujinx.HLE/Ryujinx.HLE.csproj +++ b/src/Ryujinx.HLE/Ryujinx.HLE.csproj @@ -29,6 +29,7 @@ +