diff --git a/Directory.Packages.props b/Directory.Packages.props
index 301024cf8a..6dcf34f1d1 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -36,6 +36,7 @@
+
@@ -49,4 +50,4 @@
-
\ No newline at end of file
+
diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs
index 7ecd6835db..b03fba375c 100644
--- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs
+++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs
@@ -103,7 +103,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{
lock (_lock)
{
- oldFile.Refcount++;
+ oldFile.RefCount++;
return RegisterFileDescriptor(oldFile);
}
@@ -118,9 +118,9 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
if (file != null)
{
- file.Refcount--;
+ file.RefCount--;
- if (file.Refcount <= 0)
+ if (file.RefCount <= 0)
{
file.Dispose();
}
diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
index 21d48288ec..ac989be927 100644
--- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
+++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
@@ -1,6 +1,7 @@
using Ryujinx.Common;
using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
+using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using Ryujinx.Memory;
using System;
@@ -20,6 +21,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{
EventFileDescriptorPollManager.Instance,
ManagedSocketPollManager.Instance,
+ ManagedProxySocketPollManager.Instance,
};
private BsdContext _context;
@@ -95,10 +97,8 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
}
}
- ISocket newBsdSocket = new ManagedSocket(netDomain, (SocketType)type, protocol)
- {
- Blocking = !creationFlags.HasFlag(BsdSocketCreationFlags.NonBlocking),
- };
+ ISocket newBsdSocket = ProxyManager.GetSocket(netDomain, (SocketType)type, protocol);
+ newBsdSocket.Blocking = !creationFlags.HasFlag(BsdSocketCreationFlags.NonBlocking);
LinuxError errno = LinuxError.SUCCESS;
diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs
index 6c00d5e118..4ad5806787 100644
--- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs
+++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs
@@ -6,7 +6,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
interface IFileDescriptor : IDisposable
{
bool Blocking { get; set; }
- int Refcount { get; set; }
+ int RefCount { get; set; }
LinuxError Read(out int readSize, Span buffer);
diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs
index 5b9e6811d3..7c599335f1 100644
--- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs
+++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs
@@ -32,7 +32,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
UpdateEventStates();
}
- public int Refcount { get; set; }
+ public int RefCount { get; set; }
public void Dispose()
{
diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocket.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocket.cs
new file mode 100644
index 0000000000..df66b83f29
--- /dev/null
+++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocket.cs
@@ -0,0 +1,476 @@
+using Ryujinx.Common.Logging;
+using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
+using RyuSocks;
+using RyuSocks.Auth;
+using RyuSocks.Commands;
+using RyuSocks.Types;
+using System;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.Sockets;
+
+namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
+{
+ class ManagedProxySocket : ISocket
+ {
+ private static readonly Dictionary _authMethods = new()
+ {
+ { AuthMethod.NoAuth, new NoAuth() },
+ };
+
+ private readonly bool _isUdpSocket;
+ private readonly bool _acceptedConnection;
+
+ public SocksClient ProxyClient { get; }
+
+ public bool Blocking { get => ProxyClient.Blocking; set => ProxyClient.Blocking = value; }
+ public int RefCount { get; set; }
+
+ public IPEndPoint RemoteEndPoint => (IPEndPoint)ProxyClient.ProxiedRemoteEndPoint;
+ public IPEndPoint LocalEndPoint => (IPEndPoint)ProxyClient.ProxiedLocalEndPoint;
+
+ public AddressFamily AddressFamily => ProxyClient.AddressFamily;
+ public SocketType SocketType => ProxyClient.SocketType;
+ public ProtocolType ProtocolType => ProxyClient.ProtocolType;
+ public IntPtr Handle => throw new NotSupportedException("Can't get the handle of a proxy socket.");
+
+ public ManagedProxySocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType, EndPoint proxyEndpoint)
+ {
+ if (addressFamily != proxyEndpoint.AddressFamily && addressFamily != AddressFamily.Unspecified)
+ {
+ throw new ArgumentException(
+ $"Invalid {nameof(System.Net.Sockets.AddressFamily)}", nameof(addressFamily));
+ }
+
+ if (socketType != SocketType.Stream && socketType != SocketType.Dgram)
+ {
+ throw new ArgumentException(
+ $"Invalid {nameof(System.Net.Sockets.SocketType)}", nameof(socketType));
+ }
+
+ if (protocolType != ProtocolType.Tcp && protocolType != ProtocolType.Udp)
+ {
+ throw new ArgumentException(
+ $"Invalid {nameof(System.Net.Sockets.ProtocolType)}", nameof(protocolType));
+ }
+
+ _isUdpSocket = socketType == SocketType.Dgram && protocolType == ProtocolType.Udp;
+
+ ProxyClient = proxyEndpoint switch
+ {
+ IPEndPoint ipEndPoint => new SocksClient(ipEndPoint) { OfferedAuthMethods = _authMethods },
+ DnsEndPoint dnsEndPoint => new SocksClient(dnsEndPoint) { OfferedAuthMethods = _authMethods },
+ _ => throw new ArgumentException($"Unsupported {nameof(EndPoint)} type", nameof(proxyEndpoint))
+ };
+
+ ProxyClient.Authenticate();
+
+ RefCount = 1;
+ }
+
+ private ManagedProxySocket(SocksClient proxyClient)
+ {
+ ProxyClient = proxyClient;
+ _acceptedConnection = true;
+ RefCount = 1;
+ }
+
+ private static LinuxError ToLinuxError(ReplyField proxyReply)
+ {
+ return proxyReply switch
+ {
+ ReplyField.Succeeded => LinuxError.SUCCESS,
+ ReplyField.ServerFailure => LinuxError.ECONNRESET,
+ ReplyField.ConnectionNotAllowed => LinuxError.ECONNREFUSED,
+ ReplyField.NetworkUnreachable => LinuxError.ENETUNREACH,
+ ReplyField.HostUnreachable => LinuxError.EHOSTUNREACH,
+ ReplyField.ConnectionRefused => LinuxError.ECONNREFUSED,
+ ReplyField.TTLExpired => LinuxError.EHOSTUNREACH,
+ ReplyField.CommandNotSupported => LinuxError.EOPNOTSUPP,
+ ReplyField.AddressTypeNotSupported => LinuxError.EAFNOSUPPORT,
+ _ => throw new ArgumentOutOfRangeException(nameof(proxyReply))
+ };
+ }
+
+ public void Dispose()
+ {
+ ProxyClient.Dispose();
+ }
+
+ public LinuxError Read(out int readSize, Span buffer)
+ {
+ return Receive(out readSize, buffer, BsdSocketFlags.None);
+ }
+
+ public LinuxError Write(out int writeSize, ReadOnlySpan buffer)
+ {
+ return Send(out writeSize, buffer, BsdSocketFlags.None);
+ }
+
+ public LinuxError Receive(out int receiveSize, Span buffer, BsdSocketFlags flags)
+ {
+ LinuxError result;
+ bool shouldBlockAfterOperation = false;
+
+ if (Blocking && flags.HasFlag(BsdSocketFlags.DontWait))
+ {
+ Blocking = false;
+ shouldBlockAfterOperation = true;
+ }
+
+ byte[] proxyBuffer = new byte[buffer.Length + ProxyClient.GetRequiredWrapperSpace()];
+
+ try
+ {
+ receiveSize = ProxyClient.Receive(
+ proxyBuffer,
+ WinSockHelper.ConvertBsdSocketFlags(flags),
+ out SocketError errorCode
+ );
+
+ proxyBuffer[..receiveSize].CopyTo(buffer);
+
+ result = WinSockHelper.ConvertError((WsaError)errorCode);
+ }
+ catch (ProxyException exception)
+ {
+ Logger.Error?.Print(
+ LogClass.ServiceBsd,
+ $"An error occured while trying to receive data: {exception}"
+ );
+
+ receiveSize = -1;
+ result = ToLinuxError(exception.ReplyCode);
+ }
+ catch (SocketException exception)
+ {
+ receiveSize = -1;
+ result = WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+
+ if (shouldBlockAfterOperation)
+ {
+ Blocking = true;
+ }
+
+ return result;
+ }
+
+ public LinuxError ReceiveFrom(out int receiveSize, Span buffer, int size, BsdSocketFlags flags, out IPEndPoint remoteEndPoint)
+ {
+ LinuxError result;
+ remoteEndPoint = new IPEndPoint(IPAddress.Any, 0);
+ bool shouldBlockAfterOperation = false;
+
+ byte[] proxyBuffer = new byte[size + ProxyClient.GetRequiredWrapperSpace()];
+
+ if (Blocking && flags.HasFlag(BsdSocketFlags.DontWait))
+ {
+ Blocking = false;
+ shouldBlockAfterOperation = true;
+ }
+
+ try
+ {
+ EndPoint temp = new IPEndPoint(IPAddress.Any, 0);
+
+ receiveSize = ProxyClient.ReceiveFrom(proxyBuffer, WinSockHelper.ConvertBsdSocketFlags(flags), ref temp);
+
+ proxyBuffer[..receiveSize].CopyTo(buffer);
+
+ remoteEndPoint = (IPEndPoint)temp;
+ result = LinuxError.SUCCESS;
+ }
+ catch (ProxyException exception)
+ {
+ Logger.Error?.Print(
+ LogClass.ServiceBsd,
+ $"An error occured while trying to receive data: {exception}"
+ );
+
+ receiveSize = -1;
+ result = ToLinuxError(exception.ReplyCode);
+ }
+ catch (SocketException exception)
+ {
+ receiveSize = -1;
+
+ result = WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+
+ if (shouldBlockAfterOperation)
+ {
+ Blocking = true;
+ }
+
+ return result;
+ }
+
+ public LinuxError Send(out int sendSize, ReadOnlySpan buffer, BsdSocketFlags flags)
+ {
+ try
+ {
+ sendSize = ProxyClient.Send(buffer, WinSockHelper.ConvertBsdSocketFlags(flags));
+
+ return LinuxError.SUCCESS;
+ }
+ catch (ProxyException exception)
+ {
+ Logger.Error?.Print(
+ LogClass.ServiceBsd,
+ $"An error occured while trying to send data: {exception}"
+ );
+
+ sendSize = -1;
+
+ return ToLinuxError(exception.ReplyCode);
+ }
+ catch (SocketException exception)
+ {
+ sendSize = -1;
+
+ return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+ }
+
+ public LinuxError SendTo(out int sendSize, ReadOnlySpan buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint)
+ {
+ try
+ {
+ // NOTE: sendSize might be larger than size and/or buffer.Length.
+ sendSize = ProxyClient.SendTo(buffer[..size], WinSockHelper.ConvertBsdSocketFlags(flags), remoteEndPoint);
+
+ return LinuxError.SUCCESS;
+ }
+ catch (ProxyException exception)
+ {
+ Logger.Error?.Print(
+ LogClass.ServiceBsd,
+ $"An error occured while trying to send data: {exception}"
+ );
+
+ sendSize = -1;
+
+ return ToLinuxError(exception.ReplyCode);
+ }
+ catch (SocketException exception)
+ {
+ sendSize = -1;
+
+ return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+ }
+
+ public LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout)
+ {
+ throw new NotImplementedException();
+ }
+
+ public LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags)
+ {
+ throw new NotImplementedException();
+ }
+
+ ///
+ /// Adapted from
+ ///
+ public LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span optionValue)
+ {
+ try
+ {
+ LinuxError result = WinSockHelper.ValidateSocketOption(option, level, write: false);
+
+ if (result != LinuxError.SUCCESS)
+ {
+ Logger.Warning?.Print(LogClass.ServiceBsd, $"Invalid GetSockOpt Option: {option} Level: {level}");
+
+ return result;
+ }
+
+ if (!WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName))
+ {
+ Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported GetSockOpt Option: {option} Level: {level}");
+ optionValue.Clear();
+
+ return LinuxError.SUCCESS;
+ }
+
+ byte[] tempOptionValue = new byte[optionValue.Length];
+
+ ProxyClient.GetSocketOption(level, optionName, tempOptionValue);
+
+ tempOptionValue.AsSpan().CopyTo(optionValue);
+
+ return LinuxError.SUCCESS;
+ }
+ catch (SocketException exception)
+ {
+ return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+ }
+
+ ///
+ /// Adapted from
+ ///
+ public LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan optionValue)
+ {
+ try
+ {
+ LinuxError result = WinSockHelper.ValidateSocketOption(option, level, write: true);
+
+ if (result != LinuxError.SUCCESS)
+ {
+ Logger.Warning?.Print(LogClass.ServiceBsd, $"Invalid SetSockOpt Option: {option} Level: {level}");
+
+ return result;
+ }
+
+ if (!WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName))
+ {
+ Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported SetSockOpt Option: {option} Level: {level}");
+
+ return LinuxError.SUCCESS;
+ }
+
+ byte[] value = optionValue.ToArray();
+
+ ProxyClient.SetSocketOption(level, optionName, value);
+
+ return LinuxError.SUCCESS;
+ }
+ catch (SocketException exception)
+ {
+ return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+ }
+
+ public bool Poll(int microSeconds, SelectMode mode)
+ {
+ return ProxyClient.Poll(microSeconds, mode);
+ }
+
+ public LinuxError Bind(IPEndPoint localEndPoint)
+ {
+ ProxyClient.RequestCommand = _isUdpSocket ? ProxyCommand.UdpAssociate : ProxyCommand.Bind;
+
+ try
+ {
+ ProxyClient.Bind(localEndPoint);
+ }
+ catch (ProxyException exception)
+ {
+ Logger.Error?.Print(
+ LogClass.ServiceBsd,
+ $"Request for {ProxyClient.RequestCommand} command failed: {exception}"
+ );
+
+ return ToLinuxError(exception.ReplyCode);
+ }
+ catch (SocketException exception)
+ {
+ return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+
+ return LinuxError.SUCCESS;
+ }
+
+ public LinuxError Connect(IPEndPoint remoteEndPoint)
+ {
+ ProxyClient.RequestCommand = ProxyCommand.Connect;
+
+ try
+ {
+ ProxyClient.Connect(remoteEndPoint.Address, remoteEndPoint.Port);
+ }
+ catch (ProxyException exception)
+ {
+ Logger.Error?.Print(
+ LogClass.ServiceBsd,
+ $"Request for {ProxyClient.RequestCommand} command failed: {exception}"
+ );
+
+ return ToLinuxError(exception.ReplyCode);
+ }
+ catch (SocketException exception)
+ {
+ return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+
+ return LinuxError.SUCCESS;
+ }
+
+ public LinuxError Listen(int backlog)
+ {
+ // NOTE: Only one client can connect with the default SOCKS5 commands.
+ if (ProxyClient.RequestCommand != ProxyCommand.Bind)
+ {
+ return LinuxError.EOPNOTSUPP;
+ }
+
+ return LinuxError.SUCCESS;
+ }
+
+ public LinuxError Accept(out ISocket newSocket)
+ {
+ newSocket = null;
+
+ if (ProxyClient.RequestCommand != ProxyCommand.Bind)
+ {
+ return LinuxError.EOPNOTSUPP;
+ }
+
+ // NOTE: Only one client can connect with the default SOCKS5 commands.
+ if (_acceptedConnection)
+ {
+ return LinuxError.EOPNOTSUPP;
+ }
+
+ try
+ {
+ SocksClient newProxyClient = ProxyClient.Accept();
+ newSocket = new ManagedProxySocket(newProxyClient);
+ }
+ catch (ProxyException exception)
+ {
+ Logger.Error?.Print(
+ LogClass.ServiceBsd,
+ $"Failed to accept client connection: {exception}"
+ );
+
+ return ToLinuxError(exception.ReplyCode);
+ }
+ catch (SocketException exception)
+ {
+ return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+
+ return LinuxError.SUCCESS;
+ }
+
+ public void Disconnect()
+ {
+ ProxyClient.Disconnect();
+ ProxyClient.RequestCommand = 0;
+ }
+
+ public LinuxError Shutdown(BsdSocketShutdownFlags how)
+ {
+ try
+ {
+ ProxyClient.Shutdown((SocketShutdown)how);
+
+ return LinuxError.SUCCESS;
+ }
+ catch (SocketException exception)
+ {
+ return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+ }
+
+ public void Close()
+ {
+ ProxyClient.Close();
+ ProxyClient.RequestCommand = 0;
+ }
+ }
+}
diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocketPollManager.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocketPollManager.cs
new file mode 100644
index 0000000000..02e6dad487
--- /dev/null
+++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedProxySocketPollManager.cs
@@ -0,0 +1,213 @@
+using Ryujinx.Common.Logging;
+using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
+using System.Collections.Generic;
+using System.Net.Sockets;
+
+namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
+{
+ class ManagedProxySocketPollManager : IPollManager
+ {
+ private static ManagedProxySocketPollManager _instance;
+
+ public static ManagedProxySocketPollManager Instance
+ {
+ get
+ {
+ _instance ??= new ManagedProxySocketPollManager();
+
+ return _instance;
+ }
+ }
+
+ public bool IsCompatible(PollEvent evnt)
+ {
+ return evnt.FileDescriptor is ManagedProxySocket;
+ }
+
+ public LinuxError Poll(List events, int timeoutMilliseconds, out int updatedCount)
+ {
+ Dictionary> eventDict = new()
+ {
+ { SelectMode.SelectRead, [] },
+ { SelectMode.SelectWrite, [] },
+ { SelectMode.SelectError, [] },
+ };
+
+ updatedCount = 0;
+
+ foreach (PollEvent evnt in events)
+ {
+ ManagedProxySocket socket = (ManagedProxySocket)evnt.FileDescriptor;
+
+ bool isValidEvent = evnt.Data.InputEvents == 0;
+
+ eventDict[SelectMode.SelectError].Add(socket);
+
+ if ((evnt.Data.InputEvents & PollEventTypeMask.Input) != 0)
+ {
+ eventDict[SelectMode.SelectRead].Add(socket);
+
+ isValidEvent = true;
+ }
+
+ if ((evnt.Data.InputEvents & PollEventTypeMask.UrgentInput) != 0)
+ {
+ eventDict[SelectMode.SelectRead].Add(socket);
+
+ isValidEvent = true;
+ }
+
+ if ((evnt.Data.InputEvents & PollEventTypeMask.Output) != 0)
+ {
+ eventDict[SelectMode.SelectWrite].Add(socket);
+
+ isValidEvent = true;
+ }
+
+ if (!isValidEvent)
+ {
+ Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported Poll input event type: {evnt.Data.InputEvents}");
+ return LinuxError.EINVAL;
+ }
+ }
+
+ try
+ {
+ int actualTimeoutMicroseconds = timeoutMilliseconds == -1 ? -1 : timeoutMilliseconds * 1000;
+ int totalEvents = eventDict[SelectMode.SelectRead].Count + eventDict[SelectMode.SelectWrite].Count + eventDict[SelectMode.SelectError].Count;
+ // TODO: Maybe check all events first, wait for the timeout and then check the failed ones again?
+ int timeoutMicrosecondsPerEvent = actualTimeoutMicroseconds == -1 ? -1 : actualTimeoutMicroseconds / totalEvents;
+
+ foreach ((SelectMode selectMode, List eventList) in eventDict)
+ {
+ List newEventList = [];
+
+ foreach (ManagedProxySocket eventSocket in eventList)
+ {
+ if (eventSocket.Poll(timeoutMicrosecondsPerEvent, selectMode))
+ {
+ newEventList.Add(eventSocket);
+ }
+ }
+
+ eventDict[selectMode] = newEventList;
+ }
+ }
+ catch (SocketException exception)
+ {
+ return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
+ }
+
+ foreach (PollEvent evnt in events)
+ {
+ ManagedProxySocket socket = ((ManagedProxySocket)evnt.FileDescriptor);
+
+ PollEventTypeMask outputEvents = evnt.Data.OutputEvents & ~evnt.Data.InputEvents;
+
+ if (eventDict[SelectMode.SelectError].Contains(socket))
+ {
+ outputEvents |= PollEventTypeMask.Error;
+
+ if (!socket.ProxyClient.Connected || !socket.ProxyClient.IsBound)
+ {
+ outputEvents |= PollEventTypeMask.Disconnected;
+ }
+ }
+
+ if (eventDict[SelectMode.SelectRead].Contains(socket))
+ {
+ if ((evnt.Data.InputEvents & PollEventTypeMask.Input) != 0)
+ {
+ outputEvents |= PollEventTypeMask.Input;
+ }
+ }
+
+ if (eventDict[SelectMode.SelectWrite].Contains(socket))
+ {
+ outputEvents |= PollEventTypeMask.Output;
+ }
+
+ evnt.Data.OutputEvents = outputEvents;
+ }
+
+ updatedCount = eventDict[SelectMode.SelectRead].Count + eventDict[SelectMode.SelectWrite].Count + eventDict[SelectMode.SelectError].Count;
+
+ return LinuxError.SUCCESS;
+ }
+
+ public LinuxError Select(List events, int timeout, out int updatedCount)
+ {
+ Dictionary> eventDict = new()
+ {
+ { SelectMode.SelectRead, [] },
+ { SelectMode.SelectWrite, [] },
+ { SelectMode.SelectError, [] },
+ };
+
+ updatedCount = 0;
+
+ foreach (PollEvent pollEvent in events)
+ {
+ ManagedProxySocket socket = (ManagedProxySocket)pollEvent.FileDescriptor;
+
+ if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Input))
+ {
+ eventDict[SelectMode.SelectRead].Add(socket);
+ }
+
+ if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Output))
+ {
+ eventDict[SelectMode.SelectWrite].Add(socket);
+ }
+
+ if (pollEvent.Data.InputEvents.HasFlag(PollEventTypeMask.Error))
+ {
+ eventDict[SelectMode.SelectError].Add(socket);
+ }
+ }
+
+ int totalEvents = eventDict[SelectMode.SelectRead].Count + eventDict[SelectMode.SelectWrite].Count + eventDict[SelectMode.SelectError].Count;
+ // TODO: Maybe check all events first, wait for the timeout and then check the failed ones again?
+ int timeoutMicrosecondsPerEvent = timeout == -1 ? -1 : timeout / totalEvents;
+
+ foreach ((SelectMode selectMode, List eventList) in eventDict)
+ {
+ List newEventList = [];
+
+ foreach (ManagedProxySocket eventSocket in eventList)
+ {
+ if (eventSocket.Poll(timeoutMicrosecondsPerEvent, selectMode))
+ {
+ newEventList.Add(eventSocket);
+ }
+ }
+
+ eventDict[selectMode] = newEventList;
+ }
+
+ updatedCount = eventDict[SelectMode.SelectRead].Count + eventDict[SelectMode.SelectWrite].Count + eventDict[SelectMode.SelectError].Count;
+
+ foreach (PollEvent pollEvent in events)
+ {
+ ManagedProxySocket socket = (ManagedProxySocket)pollEvent.FileDescriptor;
+
+ if (eventDict[SelectMode.SelectRead].Contains(socket))
+ {
+ pollEvent.Data.OutputEvents |= PollEventTypeMask.Input;
+ }
+
+ if (eventDict[SelectMode.SelectWrite].Contains(socket))
+ {
+ pollEvent.Data.OutputEvents |= PollEventTypeMask.Output;
+ }
+
+ if (eventDict[SelectMode.SelectError].Contains(socket))
+ {
+ pollEvent.Data.OutputEvents |= PollEventTypeMask.Error;
+ }
+ }
+
+ return LinuxError.SUCCESS;
+ }
+ }
+}
diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs
index c42b7201bf..15c4743dad 100644
--- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs
+++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs
@@ -11,7 +11,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{
class ManagedSocket : ISocket
{
- public int Refcount { get; set; }
+ public int RefCount { get; set; }
public AddressFamily AddressFamily => Socket.AddressFamily;
@@ -32,57 +32,13 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
public ManagedSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
Socket = new Socket(addressFamily, socketType, protocolType);
- Refcount = 1;
+ RefCount = 1;
}
private ManagedSocket(Socket socket)
{
Socket = socket;
- Refcount = 1;
- }
-
- private static SocketFlags ConvertBsdSocketFlags(BsdSocketFlags bsdSocketFlags)
- {
- SocketFlags socketFlags = SocketFlags.None;
-
- if (bsdSocketFlags.HasFlag(BsdSocketFlags.Oob))
- {
- socketFlags |= SocketFlags.OutOfBand;
- }
-
- if (bsdSocketFlags.HasFlag(BsdSocketFlags.Peek))
- {
- socketFlags |= SocketFlags.Peek;
- }
-
- if (bsdSocketFlags.HasFlag(BsdSocketFlags.DontRoute))
- {
- socketFlags |= SocketFlags.DontRoute;
- }
-
- if (bsdSocketFlags.HasFlag(BsdSocketFlags.Trunc))
- {
- socketFlags |= SocketFlags.Truncated;
- }
-
- if (bsdSocketFlags.HasFlag(BsdSocketFlags.CTrunc))
- {
- socketFlags |= SocketFlags.ControlDataTruncated;
- }
-
- bsdSocketFlags &= ~(BsdSocketFlags.Oob |
- BsdSocketFlags.Peek |
- BsdSocketFlags.DontRoute |
- BsdSocketFlags.DontWait |
- BsdSocketFlags.Trunc |
- BsdSocketFlags.CTrunc);
-
- if (bsdSocketFlags != BsdSocketFlags.None)
- {
- Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported socket flags: {bsdSocketFlags}");
- }
-
- return socketFlags;
+ RefCount = 1;
}
public LinuxError Accept(out ISocket newSocket)
@@ -199,7 +155,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
shouldBlockAfterOperation = true;
}
- receiveSize = Socket.Receive(buffer, ConvertBsdSocketFlags(flags));
+ receiveSize = Socket.Receive(buffer, WinSockHelper.ConvertBsdSocketFlags(flags));
result = LinuxError.SUCCESS;
}
@@ -243,7 +199,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
return LinuxError.EOPNOTSUPP;
}
- receiveSize = Socket.ReceiveFrom(buffer[..size], ConvertBsdSocketFlags(flags), ref temp);
+ receiveSize = Socket.ReceiveFrom(buffer[..size], WinSockHelper.ConvertBsdSocketFlags(flags), ref temp);
remoteEndPoint = (IPEndPoint)temp;
result = LinuxError.SUCCESS;
@@ -267,7 +223,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{
try
{
- sendSize = Socket.Send(buffer, ConvertBsdSocketFlags(flags));
+ sendSize = Socket.Send(buffer, WinSockHelper.ConvertBsdSocketFlags(flags));
return LinuxError.SUCCESS;
}
@@ -283,7 +239,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{
try
{
- sendSize = Socket.SendTo(buffer[..size], ConvertBsdSocketFlags(flags), remoteEndPoint);
+ sendSize = Socket.SendTo(buffer[..size], WinSockHelper.ConvertBsdSocketFlags(flags), remoteEndPoint);
return LinuxError.SUCCESS;
}
@@ -493,7 +449,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
try
{
- int receiveSize = Socket.Receive(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError);
+ int receiveSize = Socket.Receive(ConvertMessagesToBuffer(message), WinSockHelper.ConvertBsdSocketFlags(flags), out SocketError socketError);
if (receiveSize > 0)
{
@@ -531,7 +487,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
try
{
- int sendSize = Socket.Send(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError);
+ int sendSize = Socket.Send(ConvertMessagesToBuffer(message), WinSockHelper.ConvertBsdSocketFlags(flags), out SocketError socketError);
if (sendSize > 0)
{
diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs
index e2ef75f807..c41a565fd9 100644
--- a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs
+++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs
@@ -1,3 +1,4 @@
+using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System;
using System.Collections.Generic;
@@ -343,5 +344,49 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
return LinuxError.SUCCESS;
}
+
+ public static SocketFlags ConvertBsdSocketFlags(BsdSocketFlags bsdSocketFlags)
+ {
+ SocketFlags socketFlags = SocketFlags.None;
+
+ if (bsdSocketFlags.HasFlag(BsdSocketFlags.Oob))
+ {
+ socketFlags |= SocketFlags.OutOfBand;
+ }
+
+ if (bsdSocketFlags.HasFlag(BsdSocketFlags.Peek))
+ {
+ socketFlags |= SocketFlags.Peek;
+ }
+
+ if (bsdSocketFlags.HasFlag(BsdSocketFlags.DontRoute))
+ {
+ socketFlags |= SocketFlags.DontRoute;
+ }
+
+ if (bsdSocketFlags.HasFlag(BsdSocketFlags.Trunc))
+ {
+ socketFlags |= SocketFlags.Truncated;
+ }
+
+ if (bsdSocketFlags.HasFlag(BsdSocketFlags.CTrunc))
+ {
+ socketFlags |= SocketFlags.ControlDataTruncated;
+ }
+
+ bsdSocketFlags &= ~(BsdSocketFlags.Oob |
+ BsdSocketFlags.Peek |
+ BsdSocketFlags.DontRoute |
+ BsdSocketFlags.DontWait |
+ BsdSocketFlags.Trunc |
+ BsdSocketFlags.CTrunc);
+
+ if (bsdSocketFlags != BsdSocketFlags.None)
+ {
+ Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported socket flags: {bsdSocketFlags}");
+ }
+
+ return socketFlags;
+ }
}
}
diff --git a/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Proxy/ProxyManager.cs b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Proxy/ProxyManager.cs
new file mode 100644
index 0000000000..db827500cc
--- /dev/null
+++ b/src/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Proxy/ProxyManager.cs
@@ -0,0 +1,53 @@
+using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.Sockets;
+using System.Runtime.CompilerServices;
+
+namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy
+{
+ public static class ProxyManager
+ {
+ private static readonly ConcurrentDictionary _proxyEndpoints = new();
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static string GetKey(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
+ {
+ return string.Join("-", new[] { (int)addressFamily, (int)socketType, (int)protocolType });
+ }
+
+ internal static ISocket GetSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
+ {
+ if (_proxyEndpoints.TryGetValue(GetKey(addressFamily, socketType, protocolType), out EndPoint endPoint))
+ {
+ return new ManagedProxySocket(addressFamily, socketType, protocolType, endPoint);
+ }
+
+ return new ManagedSocket(addressFamily, socketType, protocolType);
+ }
+
+ public static void AddOrUpdate(EndPoint endPoint,
+ AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
+ {
+ _proxyEndpoints[GetKey(addressFamily, socketType, protocolType)] = endPoint;
+ }
+
+ public static void AddOrUpdate(IPAddress address, int port,
+ AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
+ {
+ _proxyEndpoints[GetKey(addressFamily, socketType, protocolType)] = new IPEndPoint(address, port);
+ }
+
+ public static void AddOrUpdate(string host, int port,
+ AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
+ {
+ _proxyEndpoints[GetKey(addressFamily, socketType, protocolType)] = new DnsEndPoint(host, port);
+ }
+
+ public static bool Remove(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
+ {
+ return _proxyEndpoints.Remove(GetKey(addressFamily, socketType, protocolType), out _);
+ }
+ }
+}
diff --git a/src/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs b/src/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs
index 8cc761baf5..d9d06fe96b 100644
--- a/src/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs
+++ b/src/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs
@@ -1,6 +1,7 @@
using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
+using RyuSocks;
using System;
using System.IO;
using System.Net;
@@ -111,12 +112,25 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
{
return hostName;
}
+ // Thrown by ManagedProxySocket when accessing RemoteEndPoint before connecting to a remote.
+ catch (NullReferenceException)
+ {
+ return hostName;
+ }
}
public ResultCode Handshake(string hostName)
{
StartSslOperation();
- _stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null);
+
+ Stream socketStream = Socket switch
+ {
+ ManagedSocket managedSocket => new NetworkStream(managedSocket.Socket, false),
+ ManagedProxySocket proxySocket => new SocksClientStream(proxySocket.ProxyClient, false),
+ _ => throw new NotSupportedException($"{typeof(Socket)} is not supported.")
+ };
+
+ _stream = new SslStream(socketStream, false, null, null);
hostName = RetrieveHostName(hostName);
_stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false);
EndSslOperation();
diff --git a/src/Ryujinx.HLE/Ryujinx.HLE.csproj b/src/Ryujinx.HLE/Ryujinx.HLE.csproj
index a7bb3cd7f6..d3f7db0381 100644
--- a/src/Ryujinx.HLE/Ryujinx.HLE.csproj
+++ b/src/Ryujinx.HLE/Ryujinx.HLE.csproj
@@ -29,6 +29,7 @@
+