From 442485adb32626b3931cd15f284d0e686b0021fc Mon Sep 17 00:00:00 2001
From: gdk <gab.dark.100@gmail.com>
Date: Wed, 27 Nov 2019 00:38:56 -0300
Subject: [PATCH] Partial support for branch with CC, and fix a edge case of
 branch out of loop on shaders

---
 .../Decoders/OpCodeBranch.cs                  |  4 +++
 .../Instructions/InstEmitAlu.cs               |  4 ++-
 .../Instructions/InstEmitFlow.cs              | 32 ++++++++++++++++---
 .../StructuredIr/StructuredProgramContext.cs  | 27 +++++++++++++---
 .../Translation/Translator.cs                 |  7 ++--
 5 files changed, 60 insertions(+), 14 deletions(-)

diff --git a/Ryujinx.Graphics.Shader/Decoders/OpCodeBranch.cs b/Ryujinx.Graphics.Shader/Decoders/OpCodeBranch.cs
index f51c39966d..c4fa921265 100644
--- a/Ryujinx.Graphics.Shader/Decoders/OpCodeBranch.cs
+++ b/Ryujinx.Graphics.Shader/Decoders/OpCodeBranch.cs
@@ -4,12 +4,16 @@ namespace Ryujinx.Graphics.Shader.Decoders
 {
     class OpCodeBranch : OpCode
     {
+        public Condition Condition { get; }
+
         public int Offset { get; }
 
         public bool PushTarget { get; protected set; }
 
         public OpCodeBranch(InstEmitter emitter, ulong address, long opCode) : base(emitter, address, opCode)
         {
+            Condition = (Condition)(opCode & 0x1f);
+
             Offset = ((int)(opCode >> 20) << 8) >> 8;
 
             PushTarget = false;
diff --git a/Ryujinx.Graphics.Shader/Instructions/InstEmitAlu.cs b/Ryujinx.Graphics.Shader/Instructions/InstEmitAlu.cs
index 1f6f389d5b..1d3a1101cb 100644
--- a/Ryujinx.Graphics.Shader/Instructions/InstEmitAlu.cs
+++ b/Ryujinx.Graphics.Shader/Instructions/InstEmitAlu.cs
@@ -288,7 +288,9 @@ namespace Ryujinx.Graphics.Shader.Instructions
                 context.Copy(dest, res);
             }
 
-            // TODO: CC, X
+            SetZnFlags(context, res, op.SetCondCode, op.Extended);
+
+            // TODO: X
         }
 
         public static void Isetp(EmitterContext context)
diff --git a/Ryujinx.Graphics.Shader/Instructions/InstEmitFlow.cs b/Ryujinx.Graphics.Shader/Instructions/InstEmitFlow.cs
index 4a9f5f7fca..c024fe0212 100644
--- a/Ryujinx.Graphics.Shader/Instructions/InstEmitFlow.cs
+++ b/Ryujinx.Graphics.Shader/Instructions/InstEmitFlow.cs
@@ -4,6 +4,7 @@ using Ryujinx.Graphics.Shader.Translation;
 using System.Collections.Generic;
 using System.Linq;
 
+using static Ryujinx.Graphics.Shader.Instructions.InstEmitHelper;
 using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
 
 namespace Ryujinx.Graphics.Shader.Instructions
@@ -129,22 +130,29 @@ namespace Ryujinx.Graphics.Shader.Instructions
 
         private static void EmitBranch(EmitterContext context, ulong address)
         {
+            OpCode op = context.CurrOp;
+
             // If we're branching to the next instruction, then the branch
             // is useless and we can ignore it.
-            if (address == context.CurrOp.Address + 8)
+            if (address == op.Address + 8)
             {
                 return;
             }
 
             Operand label = context.GetLabel(address);
 
-            Operand pred = Register(context.CurrOp.Predicate);
+            Operand pred = Register(op.Predicate);
 
-            if (context.CurrOp.Predicate.IsPT)
+            if (op is OpCodeBranch opBranch && opBranch.Condition != Condition.Always)
+            {
+                pred = context.BitwiseAnd(pred, GetCondition(context, opBranch.Condition));
+            }
+
+            if (op.Predicate.IsPT)
             {
                 context.Branch(label);
             }
-            else if (context.CurrOp.InvertPredicate)
+            else if (op.InvertPredicate)
             {
                 context.BranchIfFalse(label, pred);
             }
@@ -153,5 +161,21 @@ namespace Ryujinx.Graphics.Shader.Instructions
                 context.BranchIfTrue(label, pred);
             }
         }
+
+        private static Operand GetCondition(EmitterContext context, Condition cond)
+        {
+            // TODO: More condition codes, figure out how they work.
+            switch (cond)
+            {
+                case Condition.Equal:
+                case Condition.EqualUnordered:
+                    return GetZF(context);
+                case Condition.NotEqual:
+                case Condition.NotEqualUnordered:
+                    return context.BitwiseNot(GetZF(context));
+            }
+
+            return Const(IrConsts.True);
+        }
     }
 }
\ No newline at end of file
diff --git a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs
index 03ff881835..55958a12ee 100644
--- a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs
+++ b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs
@@ -11,7 +11,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
     {
         private HashSet<BasicBlock> _loopTails;
 
-        private Stack<(AstBlock Block, int EndIndex)> _blockStack;
+        private Stack<(AstBlock Block, int CurrEndIndex, int LoopEndIndex)> _blockStack;
 
         private Dictionary<Operand, AstOperand> _localsMap;
 
@@ -22,6 +22,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
         private AstBlock _currBlock;
 
         private int _currEndIndex;
+        private int _loopEndIndex;
 
         public StructuredProgramInfo Info { get; }
 
@@ -31,7 +32,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
         {
             _loopTails = new HashSet<BasicBlock>();
 
-            _blockStack = new Stack<(AstBlock, int)>();
+            _blockStack = new Stack<(AstBlock, int, int)>();
 
             _localsMap = new Dictionary<Operand, AstOperand>();
 
@@ -42,6 +43,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
             _currBlock = new AstBlock(AstBlockType.Main);
 
             _currEndIndex = blocksCount;
+            _loopEndIndex = blocksCount;
 
             Info = new StructuredProgramInfo(_currBlock);
 
@@ -52,7 +54,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
         {
             while (_currEndIndex == block.Index)
             {
-                (_currBlock, _currEndIndex) = _blockStack.Pop();
+                (_currBlock, _currEndIndex, _loopEndIndex) = _blockStack.Pop();
             }
 
             if (_gotoTempAsgs.TryGetValue(block.Index, out AstAssignment gotoTempAsg))
@@ -107,9 +109,19 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
                 return;
             }
 
+            // We can only enclose the "if" when the branch lands before
+            // the end of the current block. If the current enclosing block
+            // is not a loop, then we can also do so if the branch lands
+            // right at the end of the current block. When it is a loop,
+            // this is not valid as the loop condition would be evaluated,
+            // and it could erroneously jump back to the start of the loop.
+            bool inRange =
+                block.Branch.Index <  _currEndIndex ||
+               (block.Branch.Index == _currEndIndex && block.Branch.Index < _loopEndIndex);
+
             bool isLoop = block.Branch.Index <= block.Index;
 
-            if (block.Branch.Index <= _currEndIndex && !isLoop)
+            if (inRange && !isLoop)
             {
                 NewBlock(AstBlockType.If, branchOp, block.Branch.Index);
             }
@@ -171,10 +183,15 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
 
             AddNode(childBlock);
 
-            _blockStack.Push((_currBlock, _currEndIndex));
+            _blockStack.Push((_currBlock, _currEndIndex, _loopEndIndex));
 
             _currBlock    = childBlock;
             _currEndIndex = endIndex;
+
+            if (type == AstBlockType.DoWhile)
+            {
+                _loopEndIndex = endIndex;
+            }
         }
 
         private IAstNode GetBranchCond(AstBlockType type, Operation branchOp)
diff --git a/Ryujinx.Graphics.Shader/Translation/Translator.cs b/Ryujinx.Graphics.Shader/Translation/Translator.cs
index 9c1eb08e9b..1c37fa70ed 100644
--- a/Ryujinx.Graphics.Shader/Translation/Translator.cs
+++ b/Ryujinx.Graphics.Shader/Translation/Translator.cs
@@ -1,6 +1,5 @@
 using Ryujinx.Graphics.Shader.CodeGen.Glsl;
 using Ryujinx.Graphics.Shader.Decoders;
-using Ryujinx.Graphics.Shader.Instructions;
 using Ryujinx.Graphics.Shader.IntermediateRepresentation;
 using Ryujinx.Graphics.Shader.StructuredIr;
 using Ryujinx.Graphics.Shader.Translation.Optimizations;
@@ -219,15 +218,15 @@ namespace Ryujinx.Graphics.Shader.Translation
 
                     Operand predSkipLbl = null;
 
-                    bool skipPredicateCheck = op.Emitter == InstEmit.Bra;
+                    bool skipPredicateCheck = op is OpCodeBranch opBranch && !opBranch.PushTarget;
 
                     if (op is OpCodeBranchPop opBranchPop)
                     {
-                        // If the instruction is a SYNC instruction with only one
+                        // If the instruction is a SYNC or BRK instruction with only one
                         // possible target address, then the instruction is basically
                         // just a simple branch, we can generate code similar to branch
                         // instructions, with the condition check on the branch itself.
-                        skipPredicateCheck |= opBranchPop.Targets.Count < 2;
+                        skipPredicateCheck = opBranchPop.Targets.Count < 2;
                     }
 
                     if (!(op.Predicate.IsPT || skipPredicateCheck))