Some initial work on inlining

This commit is contained in:
Adam
2015-08-08 00:37:12 -04:00
parent 42e4223e83
commit d09eb21a0d
12 changed files with 324 additions and 147 deletions

View File

@@ -8,6 +8,7 @@ import info.sigterm.deob.deobfuscators.UnusedFields;
import info.sigterm.deob.deobfuscators.UnusedMethods;
import info.sigterm.deob.deobfuscators.UnusedParameters;
import info.sigterm.deob.deobfuscators.ConstantParameter;
import info.sigterm.deob.deobfuscators.MethodInliner;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
@@ -35,55 +36,57 @@ public class Deob
// bdur = System.currentTimeMillis() - bstart;
// System.out.println("rename unique took " + bdur/1000L + " seconds");
// remove except RuntimeException
bstart = System.currentTimeMillis();
new RuntimeExceptions().run(group);
// the blocks of runtime exceptions may contain interesting things like other obfuscations we identify later, but now that
// it can't be reached by the execution phase, those things become confused. so remove blocks here.
new UnusedBlocks().run(group);
bdur = System.currentTimeMillis() - bstart;
System.out.println("runtime exception took " + bdur/1000L + " seconds");
// remove unused methods
bstart = System.currentTimeMillis();
new UnusedMethods().run(group);
bdur = System.currentTimeMillis() - bstart;
System.out.println("unused methods took " + bdur/1000L + " seconds");
// remove illegal state exceptions, frees up some parameters
bstart = System.currentTimeMillis();
new IllegalStateExceptions().run(group);
bdur = System.currentTimeMillis() - bstart;
System.out.println("illegal state exception took " + bdur/1000L + " seconds");
// remove constant logically dead parameters
bstart = System.currentTimeMillis();
new ConstantParameter().run(group);
bdur = System.currentTimeMillis() - bstart;
System.out.println("constant param took " + bdur/1000L + " seconds");
// remove unhit blocks
bstart = System.currentTimeMillis();
new UnusedBlocks().run(group);
bdur = System.currentTimeMillis() - bstart;
System.out.println("unused blocks took " + bdur/1000L + " seconds");
// remove unused parameters
bstart = System.currentTimeMillis();
new UnusedParameters().run(group);
bdur = System.currentTimeMillis() - bstart;
System.out.println("unused params took " + bdur/1000L + " seconds");
// remove jump obfuscation
//new Jumps().run(group);
// remove unused fields
bstart = System.currentTimeMillis();
new UnusedFields().run(group);
bdur = System.currentTimeMillis() - bstart;
System.out.println("unused fields took " + bdur/1000L + " seconds");
// // remove except RuntimeException
// bstart = System.currentTimeMillis();
// new RuntimeExceptions().run(group);
// // the blocks of runtime exceptions may contain interesting things like other obfuscations we identify later, but now that
// // it can't be reached by the execution phase, those things become confused. so remove blocks here.
// new UnusedBlocks().run(group);
// bdur = System.currentTimeMillis() - bstart;
// System.out.println("runtime exception took " + bdur/1000L + " seconds");
//
// // remove unused methods
// bstart = System.currentTimeMillis();
// new UnusedMethods().run(group);
// bdur = System.currentTimeMillis() - bstart;
// System.out.println("unused methods took " + bdur/1000L + " seconds");
//
// // remove illegal state exceptions, frees up some parameters
// bstart = System.currentTimeMillis();
// new IllegalStateExceptions().run(group);
// bdur = System.currentTimeMillis() - bstart;
// System.out.println("illegal state exception took " + bdur/1000L + " seconds");
//
// // remove constant logically dead parameters
// bstart = System.currentTimeMillis();
// new ConstantParameter().run(group);
// bdur = System.currentTimeMillis() - bstart;
// System.out.println("constant param took " + bdur/1000L + " seconds");
//
// // remove unhit blocks
// bstart = System.currentTimeMillis();
// new UnusedBlocks().run(group);
// bdur = System.currentTimeMillis() - bstart;
// System.out.println("unused blocks took " + bdur/1000L + " seconds");
//
// // remove unused parameters
// bstart = System.currentTimeMillis();
// new UnusedParameters().run(group);
// bdur = System.currentTimeMillis() - bstart;
// System.out.println("unused params took " + bdur/1000L + " seconds");
//
// // remove jump obfuscation
// //new Jumps().run(group);
//
// // remove unused fields
// bstart = System.currentTimeMillis();
// new UnusedFields().run(group);
// bdur = System.currentTimeMillis() - bstart;
// System.out.println("unused fields took " + bdur/1000L + " seconds");
//new ModularArithmeticDeobfuscation().run(group);
new MethodInliner().run(group);
saveJar(group, args[1]);

View File

@@ -1,7 +1,10 @@
package info.sigterm.deob.attributes;
import info.sigterm.deob.Method;
import info.sigterm.deob.attributes.code.Exceptions;
import info.sigterm.deob.attributes.code.Instruction;
import info.sigterm.deob.attributes.code.Instructions;
import info.sigterm.deob.attributes.code.instruction.types.LVTInstruction;
import java.io.DataInputStream;
import java.io.DataOutputStream;
@@ -10,7 +13,6 @@ import java.io.IOException;
public class Code extends Attribute
{
private int maxStack;
private int maxLocals;
private Instructions instructions;
private Exceptions exceptions;
private Attributes attributes;
@@ -22,7 +24,7 @@ public class Code extends Attribute
DataInputStream is = attributes.getStream();
maxStack = is.readUnsignedShort();
maxLocals = is.readUnsignedShort();
is.skip(2); // max locals
instructions = new Instructions(this);
@@ -37,7 +39,7 @@ public class Code extends Attribute
public void writeAttr(DataOutputStream out) throws IOException
{
out.writeShort(maxStack);
out.writeShort(maxLocals);
out.writeShort(getMaxLocals());
instructions.write(out);
exceptions.write(out);
@@ -48,10 +50,35 @@ public class Code extends Attribute
{
return maxStack;
}
/**
 * Derives a local-variable count from the owning method's signature alone:
 * the size reported by the method descriptor, plus one extra slot when the
 * method is not static (presumably the receiver - confirm).
 *
 * @return descriptor size, incremented by one for non-static methods
 */
private int getMaxLocalsFromSig()
{
	Method owner = super.getAttributes().getMethod();

	int receiverSlots;
	if (owner.isStatic())
		receiverSlots = 0;
	else
		receiverSlots = 1;

	return receiverSlots + owner.getDescriptor().size();
}
/**
 * Reports the number of local variable slots for this code attribute.
 * The computed path scans every instruction for LVTInstructions, takes the
 * highest variable index they reference, raises it to the signature-derived
 * value from getMaxLocalsFromSig() if that is larger, and returns the
 * result plus one.
 *
 * @return the max locals value written out for this attribute
 */
public int getMaxLocals()
{
// NOTE(review): this statement looks like the pre-change body (a stored field); if it is really
// present alongside the code below, everything after it is unreachable - confirm which is intended.
return maxLocals;
// highest local variable index seen so far; -1 means no LVT instruction was found
int max = -1;
for (Instruction ins : instructions.getInstructions())
{
if (ins instanceof LVTInstruction)
{
LVTInstruction lvt = (LVTInstruction) ins;
if (lvt.getVariableIndex() > max)
max = lvt.getVariableIndex();
}
}
// NOTE(review): fromSig is built as a count while max is an index, so when fromSig wins the
// result below is fromSig + 1 - confirm the extra slot is intended.
int fromSig = getMaxLocalsFromSig();
if (fromSig > max)
max = fromSig;
// convert the highest index into a count
return max + 1;
}
public Exceptions getExceptions()

View File

@@ -153,11 +153,21 @@ public abstract class Instruction
return instructions;
}
/**
 * Re-points this instruction at a different owning instruction list.
 *
 * @param instructions the instruction list this instruction should reference from now on
 */
public void setInstructions(Instructions instructions)
{
this.instructions = instructions;
}
/**
 * @return the instruction type currently assigned to this instruction
 */
public InstructionType getType()
{
return type;
}
/**
 * Changes the instruction type of this instruction. Only accessible to subclasses
 * and classes in the same package.
 *
 * @param type the new instruction type
 */
protected void setType(InstructionType type)
{
this.type = type;
}
public ConstantPool getPool()
{
return instructions.getCode().getAttributes().getClassFile().getPool();

View File

@@ -22,8 +22,8 @@ public enum InstructionType
DCONST_1(0x0f, "dconst_1", DConst_1.class),
BIPUSH(0x10, "bipush", BiPush.class),
SIPUSH(0x11, "sipush", SiPush.class),
LDC(0x12, "ldc", LDC.class),
LDC_W(0x13, "lcd_w", LDC_W.class),
// opcode 0x12 is "ldc"; it is handled by the LDC_W class but keeps its own mnemonic
LDC(0x12, "ldc", LDC_W.class),
LDC_W(0x13, "ldc_w", LDC_W.class),
LDC2_W(0x14, "ldc2_w", LDC2_W.class),
ILOAD(0x15, "iload", ILoad.class),
LLOAD(0x16, "lload", LLoad.class),

View File

@@ -0,0 +1,6 @@
package info.sigterm.deob.attributes.code.instruction.types;
/**
 * Marker interface with no members. It is implemented by instruction classes
 * such as Return and VReturn so that code can recognise them with instanceof.
 */
public interface ReturnInstruction
{
}

View File

@@ -1,72 +0,0 @@
package info.sigterm.deob.attributes.code.instructions;
import info.sigterm.deob.attributes.code.Instruction;
import info.sigterm.deob.attributes.code.InstructionType;
import info.sigterm.deob.attributes.code.Instructions;
import info.sigterm.deob.attributes.code.instruction.types.PushConstantInstruction;
import info.sigterm.deob.execution.Frame;
import info.sigterm.deob.execution.InstructionContext;
import info.sigterm.deob.execution.Stack;
import info.sigterm.deob.execution.StackContext;
import info.sigterm.deob.pool.PoolEntry;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
/**
 * Instruction that pushes a constant pool entry onto the stack, where the entry
 * is addressed in the code stream by a single unsigned byte.
 */
public class LDC extends Instruction implements PushConstantInstruction
{
	/** The constant pool entry pushed by this instruction. */
	private PoolEntry value;

	/**
	 * Reads the one-byte pool index following the opcode and resolves it to a pool entry.
	 *
	 * @param instructions owning instruction list, also the source of the code stream
	 * @param type         instruction type handed to the superclass
	 * @param pc           program counter handed to the superclass
	 * @throws IOException if the index byte cannot be read
	 */
	public LDC(Instructions instructions, InstructionType type, int pc) throws IOException
	{
		super(instructions, type, pc);

		DataInputStream stream = instructions.getCode().getAttributes().getStream();
		value = this.getPool().getEntry(stream.readUnsignedByte());
		++length; // account for the index byte
	}

	/**
	 * Makes sure the constant has a pool index; if that index no longer fits in
	 * one byte, this instruction replaces itself with an LDC_W carrying the same value.
	 */
	@Override
	public void prime()
	{
		if (this.getPool().make(value) <= 0xFF)
			return; // still addressable with a single byte

		// new index might require changing this to an ldc_w
		this.replace(new LDC_W(this.getInstructions(), value));
	}

	/**
	 * Writes the superclass output followed by the constant's pool index as one byte.
	 */
	@Override
	public void write(DataOutputStream out) throws IOException
	{
		super.write(out);
		out.writeByte(this.getPool().make(value));
	}

	/**
	 * Pushes a stack context typed after the constant and records the
	 * instruction context on the frame.
	 */
	@Override
	public void execute(Frame frame)
	{
		InstructionContext context = new InstructionContext(this, frame);
		frame.getStack().push(new StackContext(context, value.getTypeClass()));
		frame.addInstructionContext(context);
	}

	/** @return the pool entry this instruction pushes */
	@Override
	public PoolEntry getConstant()
	{
		return value;
	}

	/** @param entry the pool entry this instruction should push from now on */
	@Override
	public void setConstant(PoolEntry entry)
	{
		value = entry;
	}
}

View File

@@ -23,8 +23,18 @@ public class LDC_W extends Instruction implements PushConstantInstruction
super(instructions, type, pc);
DataInputStream is = instructions.getCode().getAttributes().getStream();
value = this.getPool().getEntry(is.readUnsignedShort());
length += 2;
assert type == InstructionType.LDC_W || type == InstructionType.LDC;
if (type == InstructionType.LDC_W)
{
value = this.getPool().getEntry(is.readUnsignedShort());
length += 2;
}
else if (type == InstructionType.LDC)
{
value = this.getPool().getEntry(is.readUnsignedByte());
length += 1;
}
}
public LDC_W(Instructions instructions, PoolEntry value)
@@ -35,11 +45,36 @@ public class LDC_W extends Instruction implements PushConstantInstruction
length += 2;
}
/**
 * Ensures the constant has a pool index and widens a one-byte LDC to LDC_W
 * (growing the instruction length by one) when that index exceeds 0xFF.
 */
@Override
public void prime()
{
	int poolIndex = this.getPool().make(value);
	assert poolIndex >= 0 && poolIndex <= 0xFFFF;

	if (poolIndex <= 0xFF)
		return; // fits either encoding, nothing to change

	if (this.getType() != InstructionType.LDC)
		return; // already the wide form

	// the index needs two bytes now, so switch encodings and grow by the extra byte
	this.setType(InstructionType.LDC_W);
	length += 1;
}
/**
 * Writes the superclass output followed by the constant's pool index, using a
 * single byte when the type is LDC and two bytes when it is LDC_W.
 *
 * @param out stream the instruction is written to
 * @throws IOException if writing fails
 */
@Override
public void write(DataOutputStream out) throws IOException
{
super.write(out);
// NOTE(review): this unconditional two-byte write looks like the pre-change body; together with
// the type-dependent write below the index would be emitted twice - confirm which is intended.
out.writeShort(this.getPool().make(value));
int index = this.getPool().make(value);
assert this.getType() == InstructionType.LDC || this.getType() == InstructionType.LDC_W;
if (this.getType() == InstructionType.LDC)
{
// narrow form: the index must fit in one unsigned byte
assert index >= 0 && index <= 0xFF;
out.writeByte(index);
}
else if (this.getType() == InstructionType.LDC_W)
{
// wide form: the index must fit in two unsigned bytes
assert index >= 0 && index <= 0xFFFF;
out.writeShort(index);
}
}
@Override

View File

@@ -13,6 +13,11 @@ public class NOP extends Instruction
{
super(instructions, type, pc);
}
/**
 * Creates a NOP for the given instruction list without an explicit type or pc;
 * the type is fixed to InstructionType.NOP and the pc passed to the superclass is 0.
 *
 * @param instructions the instruction list handed to the superclass
 */
public NOP(Instructions instructions)
{
super(instructions, InstructionType.NOP, 0);
}
@Override
public void execute(Frame frame)

View File

@@ -3,6 +3,7 @@ package info.sigterm.deob.attributes.code.instructions;
import info.sigterm.deob.attributes.code.Instruction;
import info.sigterm.deob.attributes.code.InstructionType;
import info.sigterm.deob.attributes.code.Instructions;
import info.sigterm.deob.attributes.code.instruction.types.ReturnInstruction;
import info.sigterm.deob.execution.Frame;
import info.sigterm.deob.execution.InstructionContext;
import info.sigterm.deob.execution.Stack;
@@ -10,7 +11,7 @@ import info.sigterm.deob.execution.StackContext;
import java.io.IOException;
public class Return extends Instruction
public class Return extends Instruction implements ReturnInstruction
{
public Return(Instructions instructions, InstructionType type, int pc) throws IOException
{

View File

@@ -3,11 +3,12 @@ package info.sigterm.deob.attributes.code.instructions;
import info.sigterm.deob.attributes.code.Instruction;
import info.sigterm.deob.attributes.code.InstructionType;
import info.sigterm.deob.attributes.code.Instructions;
import info.sigterm.deob.attributes.code.instruction.types.ReturnInstruction;
import info.sigterm.deob.execution.Frame;
import java.io.IOException;
public class VReturn extends Instruction
public class VReturn extends Instruction implements ReturnInstruction
{
public VReturn(Instructions instructions, InstructionType type, int pc) throws IOException
{

View File

@@ -1,17 +0,0 @@
package info.sigterm.deob.callgraph;
import info.sigterm.deob.Method;
import info.sigterm.deob.attributes.code.Instruction;
/**
 * Plain holder tying two methods to an instruction. Presumably {@code from} is
 * the calling method, {@code to} the called method and {@code ins} the invoking
 * instruction - confirm against the call graph builder.
 */
public class Node
{
	/** First method of the pair. */
	public Method from;
	/** Second method of the pair. */
	public Method to;
	/** Instruction associated with the pair. */
	public Instruction ins;

	/**
	 * Stores the three references as given; no validation is performed.
	 */
	public Node(Method from, Method to, Instruction ins)
	{
		this.ins = ins;
		this.to = to;
		this.from = from;
	}
}

View File

@@ -0,0 +1,178 @@
package info.sigterm.deob.deobfuscators;
import info.sigterm.deob.ClassFile;
import info.sigterm.deob.ClassGroup;
import info.sigterm.deob.Deobfuscator;
import info.sigterm.deob.Method;
import info.sigterm.deob.attributes.Code;
import info.sigterm.deob.attributes.code.Instruction;
import info.sigterm.deob.attributes.code.Instructions;
import info.sigterm.deob.attributes.code.instruction.types.InvokeInstruction;
import info.sigterm.deob.attributes.code.instruction.types.LVTInstruction;
import info.sigterm.deob.attributes.code.instruction.types.ReturnInstruction;
import info.sigterm.deob.attributes.code.instructions.Goto;
import info.sigterm.deob.attributes.code.instructions.InvokeStatic;
import info.sigterm.deob.attributes.code.instructions.NOP;
import info.sigterm.deob.attributes.code.instructions.VReturn;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Deobfuscator that inlines static methods which have exactly one static call site.
 * <p>
 * Call sites are counted over the whole class group first. Each method is then scanned
 * and at most one qualifying invocation per method is inlined per run. For now only
 * targets that return void and take no arguments are inlined.
 */
public class MethodInliner implements Deobfuscator
{
	/** Number of invokestatic call sites seen for each resolved target method. */
	private final Map<Method, Integer> calls = new HashMap<>();

	/**
	 * Resolves the single method targeted by a static invocation.
	 *
	 * @param i instruction to inspect
	 * @return the invoked method, or null if {@code i} is not an InvokeStatic or no
	 *         method is resolved for it
	 */
	private static Method getStaticTarget(Instruction i)
	{
		// can only inline static method calls
		if (!(i instanceof InvokeStatic))
			return null;

		List<Method> invokedMethods = ((InvokeInstruction) i).getMethods();
		if (invokedMethods.isEmpty())
			return null; // not our method

		assert invokedMethods.size() == 1;
		return invokedMethods.get(0);
	}

	/**
	 * Adds every resolvable static invocation found in {@code m} to the call counts.
	 *
	 * @param m method whose code is scanned; methods without code are ignored
	 */
	private void countCalls(Method m)
	{
		Code code = m.getCode();
		if (code == null)
			return;

		for (Instruction i : code.getInstructions().getInstructions())
		{
			Method invokedMethod = getStaticTarget(i);
			if (invokedMethod == null)
				continue;

			Integer count = calls.get(invokedMethod);
			calls.put(invokedMethod, count == null ? 1 : count + 1);
		}
	}

	/**
	 * Inlines the first qualifying static invocation found in {@code m}.
	 *
	 * @param m method whose code is scanned
	 * @return 1 if an invocation was inlined, 0 otherwise
	 */
	private int processMethod(Method m)
	{
		Code code = m.getCode();
		if (code == null)
			return 0;

		for (Instruction i : code.getInstructions().getInstructions())
		{
			Method invokedMethod = getStaticTarget(i);
			if (invokedMethod == null)
				continue;

			Integer count = calls.get(invokedMethod);
			if (count == null || count != 1)
				continue; // only inline methods called once

			// XXX do this later
			System.out.println(invokedMethod.getDescriptor().getReturnValue().getType() + " " + invokedMethod.getDescriptor().size());
			if (!invokedMethod.getDescriptor().getReturnValue().getType().equals("V")
				|| invokedMethod.getDescriptor().size() != 0)
				continue;

			// a method whose only call site is inside itself cannot be inlined: the copy
			// loop would insert into the very list it is iterating over
			if (invokedMethod == m)
				continue;

			// nothing to copy from a method that has no code
			if (invokedMethod.getCode() == null)
				continue;

			inline(m, i, invokedMethod);
			// the instruction list was just modified, so stop iterating over it
			return 1;
		}

		return 0;
	}

	/**
	 * Replaces {@code invokeIns} in {@code method} with the instructions of {@code invokeMethod}.
	 * <p>
	 * A NOP takes the place of the invoke so that instructions which jumped to the invoke
	 * have a target. Return instructions of the callee become gotos to the instruction that
	 * followed the invoke, and local variable indexes of the callee are shifted past the
	 * caller's current max locals. The callee's instruction list is emptied afterwards.
	 *
	 * @param method       caller that receives the instructions
	 * @param invokeIns    the invoke instruction inside {@code method} to replace
	 * @param invokeMethod callee whose instructions are moved; must have code and must not
	 *                     be {@code method} itself
	 */
	private void inline(Method method, Instruction invokeIns, Method invokeMethod)
	{
		Code methodCode = method.getCode(),
			invokeMethodCode = invokeMethod.getCode();
		Instructions methodInstructions = methodCode.getInstructions(),
			invokeMethodInstructions = invokeMethodCode.getInstructions();

		int maxLocals = methodCode.getMaxLocals(); // max locals currently

		int idx = methodInstructions.getInstructions().indexOf(invokeIns); // index of invoke ins, before removal
		assert idx != -1;

		Instruction nextInstruction = methodInstructions.getInstructions().get(idx + 1);

		// move stuff which jumps to invokeIns to nop
		Instruction nop = new NOP(methodInstructions);
		methodInstructions.getInstructions().add(idx + 1, nop);
		++idx;

		for (Instruction fromI : invokeIns.from)
		{
			assert fromI.jump.contains(invokeIns);

			fromI.jump.remove(invokeIns);
			fromI.replace(invokeIns, nop);
		}
		invokeIns.from.clear();

		// removing the invoke shifts the nop down by one, so idx now addresses nextInstruction
		methodInstructions.remove(invokeIns);

		for (Instruction i : invokeMethodInstructions.getInstructions())
		{
			// move instructions over.

			if (i instanceof ReturnInstruction)
			{
				assert i instanceof VReturn; // only support void atm
				// XXX I am assuming that this function leaves the stack in a clean state?

				// instead of return, jump to next instruction after the invoke
				i = new Goto(methodInstructions, nextInstruction);
			}

			if (i instanceof LVTInstruction)
			{
				LVTInstruction lvt = (LVTInstruction) i;
				// offset lvt index
				int newIndex = maxLocals + lvt.getVariableIndex();

				i = lvt.setVariableIndex(newIndex);
			}

			// the instruction now lives in the caller; without this it would keep resolving
			// its constant pool through the callee's class
			i.setInstructions(methodInstructions);

			methodInstructions.getInstructions().add(idx++, i);
		}

		// old instructions go away
		invokeMethodInstructions.getInstructions().clear();
	}

	/**
	 * Counts static call sites across the group, then inlines qualifying single-use methods
	 * and prints how many were inlined.
	 *
	 * @param group class group to process
	 */
	@Override
	public void run(ClassGroup group)
	{
		group.buildClassGraph();

		// start from a clean slate so a second run does not add to stale counts
		calls.clear();

		for (ClassFile cf : group.getClasses())
		{
			for (Method m : cf.getMethods().getMethods())
			{
				countCalls(m);
			}
		}

		int count = 0;
		for (ClassFile cf : group.getClasses())
		{
			for (Method m : cf.getMethods().getMethods())
			{
				count += processMethod(m);
			}
		}

		System.out.println("Inlined " + count + " methods");
	}
}