From 9142aca48ff09ed32954eceb3456a255d61945b7 Mon Sep 17 00:00:00 2001
From: Thomas Guillemard <me@thog.eu>
Date: Fri, 11 Oct 2019 17:22:24 +0200
Subject: [PATCH] Fix hwopus DecodeInterleaved implementation (#786)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Fix hwopus DecodeInterleaved implementation

Also implement new variants of this api.

This should fix #763

* Sample rate shouldn't be hardcoded

This fix issues while opening Pokémon Let's Go pause menu.

* Apply Ac_K's suggestion about EndianSwap

* Address gdkchan's comment

* Address Ac_k's comment
---
 .../Utilities/EndianSwap.cs                   |  18 +-
 Ryujinx.HLE/HOS/Font/SharedFontManager.cs     |   2 +-
 .../IHardwareOpusDecoder.cs                   | 243 ++++++++++++++----
 .../Services/Audio/Types/OpusPacketHeader.cs  |  24 ++
 .../HOS/Services/Sockets/Bsd/IClient.cs       |   3 +-
 5 files changed, 240 insertions(+), 50 deletions(-)
 rename {Ryujinx.HLE => Ryujinx.Common}/Utilities/EndianSwap.cs (55%)
 create mode 100644 Ryujinx.HLE/HOS/Services/Audio/Types/OpusPacketHeader.cs

diff --git a/Ryujinx.HLE/Utilities/EndianSwap.cs b/Ryujinx.Common/Utilities/EndianSwap.cs
similarity index 55%
rename from Ryujinx.HLE/Utilities/EndianSwap.cs
rename to Ryujinx.Common/Utilities/EndianSwap.cs
index df08191ac6..049570e32a 100644
--- a/Ryujinx.HLE/Utilities/EndianSwap.cs
+++ b/Ryujinx.Common/Utilities/EndianSwap.cs
@@ -1,6 +1,8 @@
-namespace Ryujinx.HLE.Utilities
+using System;
+
+namespace Ryujinx.Common
 {
-    static class EndianSwap
+    public static class EndianSwap
     {
         public static ushort Swap16(ushort value) => (ushort)(((value >> 8) & 0xff) | (value << 8));
 
@@ -13,5 +15,17 @@
                          ((uintVal <<  8) & 0x00ff0000) |
                          ((uintVal << 24) & 0xff000000));
         }
+
+        public static uint FromBigEndianToPlatformEndian(uint value)
+        {
+            uint result = value;
+
+            if (BitConverter.IsLittleEndian)
+            {
+                result = (uint)EndianSwap.Swap32((int)result);
+            }
+
+            return result;
+        }
     }
 }
diff --git a/Ryujinx.HLE/HOS/Font/SharedFontManager.cs b/Ryujinx.HLE/HOS/Font/SharedFontManager.cs
index dfb87f3c94..8a936dbf55 100644
--- a/Ryujinx.HLE/HOS/Font/SharedFontManager.cs
+++ b/Ryujinx.HLE/HOS/Font/SharedFontManager.cs
@@ -1,9 +1,9 @@
 using LibHac.Fs;
 using LibHac.Fs.NcaUtils;
+using Ryujinx.Common;
 using Ryujinx.HLE.FileSystem;
 using Ryujinx.HLE.FileSystem.Content;
 using Ryujinx.HLE.Resource;
-using Ryujinx.HLE.Utilities;
 using System.Collections.Generic;
 using System.IO;
 using static Ryujinx.HLE.Utilities.FontUtils;
diff --git a/Ryujinx.HLE/HOS/Services/Audio/HardwareOpusDecoderManager/IHardwareOpusDecoder.cs b/Ryujinx.HLE/HOS/Services/Audio/HardwareOpusDecoderManager/IHardwareOpusDecoder.cs
index e23398dfe8..079f2ae75e 100644
--- a/Ryujinx.HLE/HOS/Services/Audio/HardwareOpusDecoderManager/IHardwareOpusDecoder.cs
+++ b/Ryujinx.HLE/HOS/Services/Audio/HardwareOpusDecoderManager/IHardwareOpusDecoder.cs
@@ -1,13 +1,18 @@
+using Concentus;
+using Concentus.Enums;
 using Concentus.Structs;
+using Ryujinx.HLE.HOS.Services.Audio.Types;
+using System;
+using System.IO;
+using System.Runtime.InteropServices;
 
 namespace Ryujinx.HLE.HOS.Services.Audio.HardwareOpusDecoderManager
 {
     class IHardwareOpusDecoder : IpcService
     {
-        private const int FixedSampleRate = 48000;
-
-        private int _sampleRate;
-        private int _channelsCount;
+        private int  _sampleRate;
+        private int  _channelsCount;
+        private bool _reset;
 
         private OpusDecoder _decoder;
 
@@ -15,65 +20,211 @@ namespace Ryujinx.HLE.HOS.Services.Audio.HardwareOpusDecoderManager
         {
             _sampleRate    = sampleRate;
             _channelsCount = channelsCount;
+            _reset         = false;
 
-            _decoder = new OpusDecoder(FixedSampleRate, channelsCount);
+            _decoder = new OpusDecoder(sampleRate, channelsCount);
         }
 
-        [Command(0)]
-        // DecodeInterleaved(buffer<unknown, 5>) -> (u32, u32, buffer<unknown, 6>)
-        public ResultCode DecodeInterleaved(ServiceCtx context)
+        private ResultCode GetPacketNumSamples(out int numSamples, byte[] packet)
         {
-            long inPosition = context.Request.SendBuff[0].Position;
-            long inSize     = context.Request.SendBuff[0].Size;
+            int result = OpusPacketInfo.GetNumSamples(_decoder, packet, 0, packet.Length);
 
-            if (inSize < 8)
+            numSamples = result;
+
+            if (result == OpusError.OPUS_INVALID_PACKET)
             {
                 return ResultCode.OpusInvalidInput;
             }
-
-            long outPosition = context.Request.ReceiveBuff[0].Position;
-            long outSize     = context.Request.ReceiveBuff[0].Size;
-
-            byte[] opusData = context.Memory.ReadBytes(inPosition, inSize);
-
-            int processed = ((opusData[0] << 24) |
-                             (opusData[1] << 16) |
-                             (opusData[2] << 8)  |
-                             (opusData[3] << 0)) + 8;
-
-            if ((uint)processed > (ulong)inSize)
+            else if (result == OpusError.OPUS_BAD_ARG)
             {
                 return ResultCode.OpusInvalidInput;
             }
 
-            short[] pcm = new short[outSize / 2];
-
-            int frameSize = pcm.Length / (_channelsCount * 2);
-
-            int samples = _decoder.Decode(opusData, 0, opusData.Length, pcm, 0, frameSize);
-
-            foreach (short sample in pcm)
-            {
-                context.Memory.WriteInt16(outPosition, sample);
-
-                outPosition += 2;
-            }
-
-            context.ResponseData.Write(processed);
-            context.ResponseData.Write(samples);
-
             return ResultCode.Success;
         }
 
-        [Command(4)]
-        // DecodeInterleavedWithPerf(buffer<unknown, 5>) -> (u32, u32, u64, buffer<unknown, 0x46>)
-        public ResultCode DecodeInterleavedWithPerf(ServiceCtx context)
+        private ResultCode DecodeInterleavedInternal(BinaryReader input, out short[] outPcmData, long outputSize, out uint outConsumed, out int outSamples)
         {
-            ResultCode result = DecodeInterleaved(context);
+            outPcmData  = null;
+            outConsumed = 0;
+            outSamples  = 0;
 
-            // TODO: Figure out what this value is.
-            // According to switchbrew, it is now used.
-            context.ResponseData.Write(0L);
+            long streamSize = input.BaseStream.Length;
+
+            if (streamSize < Marshal.SizeOf<OpusPacketHeader>())
+            {
+                return ResultCode.OpusInvalidInput;
+            }
+
+            OpusPacketHeader header = OpusPacketHeader.FromStream(input);
+
+            uint totalSize = header.length + (uint)Marshal.SizeOf<OpusPacketHeader>();
+
+            if (totalSize > streamSize)
+            {
+                return ResultCode.OpusInvalidInput;
+            }
+
+            byte[] opusData = input.ReadBytes((int)header.length);
+
+            ResultCode result = GetPacketNumSamples(out int numSamples, opusData);
+
+            if (result == ResultCode.Success)
+            {
+                if ((uint)numSamples * (uint)_channelsCount * sizeof(short) > outputSize)
+                {
+                    return ResultCode.OpusInvalidInput;
+                }
+
+                outPcmData = new short[numSamples * _channelsCount];
+
+                if (_reset)
+                {
+                    _reset = false;
+
+                    _decoder.ResetState();
+                }
+
+                try
+                {
+                    outSamples  = _decoder.Decode(opusData, 0, opusData.Length, outPcmData, 0, outPcmData.Length / _channelsCount);
+                    outConsumed = totalSize;
+                }
+                catch (OpusException)
+                {
+                    // TODO: as OpusException doesn't provide us the exact error code, this is kind of inaccurate in some cases...
+                    return ResultCode.OpusInvalidInput;
+                }
+            }
+
+            return ResultCode.Success;
+        }
+
+        [Command(0)]
+        // DecodeInterleaved(buffer<unknown, 5>) -> (u32, u32, buffer<unknown, 6>)
+        public ResultCode DecodeInterleavedOriginal(ServiceCtx context)
+        {
+            ResultCode result;
+
+            long inPosition     = context.Request.SendBuff[0].Position;
+            long inSize         = context.Request.SendBuff[0].Size;
+            long outputPosition = context.Request.ReceiveBuff[0].Position;
+            long outputSize     = context.Request.ReceiveBuff[0].Size;
+
+            using (BinaryReader inputStream = new BinaryReader(new MemoryStream(context.Memory.ReadBytes(inPosition, inSize))))
+            {
+                result = DecodeInterleavedInternal(inputStream, out short[] outPcmData, outputSize, out uint outConsumed, out int outSamples);
+
+                if (result == ResultCode.Success)
+                {
+                    byte[] pcmDataBytes = new byte[outPcmData.Length * sizeof(short)];
+                    Buffer.BlockCopy(outPcmData, 0, pcmDataBytes, 0, pcmDataBytes.Length);
+                    context.Memory.WriteBytes(outputPosition, pcmDataBytes);
+
+                    context.ResponseData.Write(outConsumed);
+                    context.ResponseData.Write(outSamples);
+                }
+            }
+
+            return result;
+        }
+
+        [Command(4)] // 6.0.0+
+        // DecodeInterleavedWithPerfOld(buffer<unknown, 5>) -> (u32, u32, u64, buffer<unknown, 0x46>)
+        public ResultCode DecodeInterleavedWithPerfOld(ServiceCtx context)
+        {
+            ResultCode result;
+
+            long inPosition     = context.Request.SendBuff[0].Position;
+            long inSize         = context.Request.SendBuff[0].Size;
+            long outputPosition = context.Request.ReceiveBuff[0].Position;
+            long outputSize     = context.Request.ReceiveBuff[0].Size;
+
+            using (BinaryReader inputStream = new BinaryReader(new MemoryStream(context.Memory.ReadBytes(inPosition, inSize))))
+            {
+                result = DecodeInterleavedInternal(inputStream, out short[] outPcmData, outputSize, out uint outConsumed, out int outSamples);
+
+                if (result == ResultCode.Success)
+                {
+                    byte[] pcmDataBytes = new byte[outPcmData.Length * sizeof(short)];
+                    Buffer.BlockCopy(outPcmData, 0, pcmDataBytes, 0, pcmDataBytes.Length);
+                    context.Memory.WriteBytes(outputPosition, pcmDataBytes);
+
+                    context.ResponseData.Write(outConsumed);
+                    context.ResponseData.Write(outSamples);
+
+                    // This is the time the DSP took to process the request, TODO: fill this.
+                    context.ResponseData.Write(0);
+                }
+            }
+
+            return result;
+        }
+
+        [Command(6)] // 6.0.0+
+        // DecodeInterleavedOld(bool reset, buffer<unknown, 5>) -> (u32, u32, u64, buffer<unknown, 0x46>)
+        public ResultCode DecodeInterleavedOld(ServiceCtx context)
+        {
+            ResultCode result;
+
+            _reset = context.RequestData.ReadBoolean();
+
+            long inPosition     = context.Request.SendBuff[0].Position;
+            long inSize         = context.Request.SendBuff[0].Size;
+            long outputPosition = context.Request.ReceiveBuff[0].Position;
+            long outputSize     = context.Request.ReceiveBuff[0].Size;
+
+            using (BinaryReader inputStream = new BinaryReader(new MemoryStream(context.Memory.ReadBytes(inPosition, inSize))))
+            {
+                result = DecodeInterleavedInternal(inputStream, out short[] outPcmData, outputSize, out uint outConsumed, out int outSamples);
+
+                if (result == ResultCode.Success)
+                {
+                    byte[] pcmDataBytes = new byte[outPcmData.Length * sizeof(short)];
+                    Buffer.BlockCopy(outPcmData, 0, pcmDataBytes, 0, pcmDataBytes.Length);
+                    context.Memory.WriteBytes(outputPosition, pcmDataBytes);
+
+                    context.ResponseData.Write(outConsumed);
+                    context.ResponseData.Write(outSamples);
+
+                    // This is the time the DSP took to process the request, TODO: fill this.
+                    context.ResponseData.Write(0);
+                }
+            }
+
+            return result;
+        }
+
+        [Command(8)] // 7.0.0+
+        // DecodeInterleaved(bool reset, buffer<unknown, 0x45>) -> (u32, u32, u64, buffer<unknown, 0x46>)
+        public ResultCode DecodeInterleaved(ServiceCtx context)
+        {
+            ResultCode result;
+
+            _reset = context.RequestData.ReadBoolean();
+
+            long inPosition     = context.Request.SendBuff[0].Position;
+            long inSize         = context.Request.SendBuff[0].Size;
+            long outputPosition = context.Request.ReceiveBuff[0].Position;
+            long outputSize     = context.Request.ReceiveBuff[0].Size;
+
+            using (BinaryReader inputStream = new BinaryReader(new MemoryStream(context.Memory.ReadBytes(inPosition, inSize))))
+            {
+                result = DecodeInterleavedInternal(inputStream, out short[] outPcmData, outputSize, out uint outConsumed, out int outSamples);
+
+                if (result == ResultCode.Success)
+                {
+                    byte[] pcmDataBytes = new byte[outPcmData.Length * sizeof(short)];
+                    Buffer.BlockCopy(outPcmData, 0, pcmDataBytes, 0, pcmDataBytes.Length);
+                    context.Memory.WriteBytes(outputPosition, pcmDataBytes);
+
+                    context.ResponseData.Write(outConsumed);
+                    context.ResponseData.Write(outSamples);
+
+                    // This is the time the DSP took to process the request, TODO: fill this.
+                    context.ResponseData.Write(0);
+                }
+            }
 
             return result;
         }
diff --git a/Ryujinx.HLE/HOS/Services/Audio/Types/OpusPacketHeader.cs b/Ryujinx.HLE/HOS/Services/Audio/Types/OpusPacketHeader.cs
new file mode 100644
index 0000000000..bb4b6d16c9
--- /dev/null
+++ b/Ryujinx.HLE/HOS/Services/Audio/Types/OpusPacketHeader.cs
@@ -0,0 +1,24 @@
+using Ryujinx.Common;
+using System;
+using System.IO;
+using System.Runtime.InteropServices;
+
+namespace Ryujinx.HLE.HOS.Services.Audio.Types
+{
+    [StructLayout(LayoutKind.Sequential)]
+    struct OpusPacketHeader
+    {
+        public uint length;
+        public uint finalRange;
+
+        public static OpusPacketHeader FromStream(BinaryReader reader)
+        {
+            OpusPacketHeader header = reader.ReadStruct<OpusPacketHeader>();
+
+            header.length     = EndianSwap.FromBigEndianToPlatformEndian(header.length);
+            header.finalRange = EndianSwap.FromBigEndianToPlatformEndian(header.finalRange);
+
+            return header;
+        }
+    }
+}
diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
index 3a02e06c05..7db8066a6b 100644
--- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
+++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs
@@ -1,4 +1,5 @@
-using Ryujinx.Common.Logging;
+using Ryujinx.Common;
+using Ryujinx.Common.Logging;
 using Ryujinx.HLE.Utilities;
 using System.Collections.Generic;
 using System.Net;