/*
 * Decompiled with CFR 0.152.
 */
package cryptix.jce.provider.rsa;

import cryptix.jce.provider.rsa.RSAAlgorithm;
import cryptix.jce.provider.util.Util;
import java.math.BigInteger;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.SignatureSpi;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;

/**
 * Base class for RSASSA-PSS signatures (PKCS #1 v2.1, EMSA-PSS encoding).
 * <p>
 * The digest named in the constructor is used for the message hash, for
 * H = Hash(0x00*8 || mHash || salt) and for the MGF1 mask generation
 * function. The salt length equals the digest length and the trailer field is
 * the single byte 0xBC. Concrete subclasses presumably only select the digest
 * algorithm name -- confirm against the provider registration.
 * <p>
 * No synchronization is performed; the running hash and key material live in
 * mutable fields, so an instance must not be shared between threads.
 */
public abstract class RSASignature_PSS
extends SignatureSpi {
    /** MASK[k] keeps the low (8 - k) bits of a byte, i.e. 0xFF >>> k. */
    private static final byte[] MASK = new byte[]{-1, 127, 63, 31, 15, 7, 3, 1};
    /** EMSA-PSS trailer field (0xBC). */
    private static final byte TRAILER = (byte)0xBC;

    /** Digest for mHash, H and MGF1; also accumulates the message via engineUpdate. */
    private final MessageDigest md;
    /** Digest output length in bytes. */
    private final int hLen;
    /** Salt length in bytes; always equal to hLen in this implementation. */
    private final int sLen;
    /** Debug-only fixed salt consumed by the next sign(); null normally. */
    private byte[] presetSalt;
    /** Encoded message length in bytes: ceil(emBits / 8). */
    private int emLen;
    /** Encoded message length in bits: modulus bit length - 1. */
    private int emBits;
    /** Private exponent when signing, public exponent when verifying. */
    private BigInteger exp;
    /** RSA modulus of the current key. */
    private BigInteger n;
    /** CRT parameters; all null unless signing with an RSAPrivateCrtKey. */
    private BigInteger p;
    private BigInteger q;
    private BigInteger u;
    /** Salt source while in sign mode; null in verify mode. */
    private SecureRandom rng;

    /**
     * Creates a PSS engine that uses the given message digest throughout.
     *
     * @param digestName JCA name of the MessageDigest to instantiate
     * @throws InternalError if that digest is not available (cause attached)
     */
    public RSASignature_PSS(String digestName) {
        try {
            this.md = MessageDigest.getInstance(digestName);
            this.hLen = this.sLen = this.md.getDigestLength();
        }
        catch (NoSuchAlgorithmException noSuchAlgorithmException) {
            InternalError error = new InternalError("MessageDigest not found! (" + digestName + "): " + noSuchAlgorithmException.toString());
            error.initCause(noSuchAlgorithmException);
            throw error;
        }
    }

    /** Returns a || b as a new array. */
    private byte[] concat(byte[] a, byte[] b) {
        byte[] out = new byte[a.length + b.length];
        System.arraycopy(a, 0, out, 0, a.length);
        System.arraycopy(b, 0, out, a.length, b.length);
        return out;
    }

    /** Returns a || b || c as a new array. */
    private byte[] concat(byte[] a, byte[] b, byte[] c) {
        return this.concat(a, this.concat(b, c));
    }

    /**
     * Parameters are not readable from this engine.
     *
     * @throws UnsupportedOperationException always
     */
    protected Object engineGetParameter(String string) {
        throw new UnsupportedOperationException("NYI");
    }

    /** Initializes for signing with a freshly created SecureRandom as salt source. */
    protected void engineInitSign(PrivateKey privateKey) throws InvalidKeyException {
        this.engineInitSign(privateKey, new SecureRandom());
    }

    /**
     * Initializes for signing.
     *
     * @param privateKey an RSAPrivateKey; CRT parameters are used when it is an RSAPrivateCrtKey
     * @param secureRandom salt source; null selects a default SecureRandom
     * @throws InvalidKeyException if the key is not RSA or its modulus is too short for PSS
     */
    protected void engineInitSign(PrivateKey privateKey, SecureRandom secureRandom) throws InvalidKeyException {
        if (!(privateKey instanceof RSAPrivateKey)) {
            throw new InvalidKeyException("Not an RSA private key");
        }
        RSAPrivateKey rsaKey = (RSAPrivateKey)privateKey;
        // Validate (and commit n/emLen/emBits) before touching other key fields,
        // so a rejected key cannot leave a half-updated, inconsistent state.
        this.initCommon(rsaKey.getModulus());
        this.exp = rsaKey.getPrivateExponent();
        if (privateKey instanceof RSAPrivateCrtKey) {
            RSAPrivateCrtKey crtKey = (RSAPrivateCrtKey)privateKey;
            this.p = crtKey.getPrimeP();
            this.q = crtKey.getPrimeQ();
            this.u = crtKey.getCrtCoefficient();
        } else {
            this.u = null;
            this.q = null;
            this.p = null;
        }
        // Signature.initSign(key, null) is legal; avoid an NPE later in sign().
        this.rng = secureRandom != null ? secureRandom : new SecureRandom();
    }

    /**
     * Initializes for verification.
     *
     * @throws InvalidKeyException if the key is not RSA or its modulus is too short for PSS
     */
    protected void engineInitVerify(PublicKey publicKey) throws InvalidKeyException {
        if (!(publicKey instanceof RSAPublicKey)) {
            throw new InvalidKeyException("Not an RSA public key");
        }
        RSAPublicKey rsaKey = (RSAPublicKey)publicKey;
        // Validate first so a rejected key leaves the previous state intact.
        this.initCommon(rsaKey.getModulus());
        this.exp = rsaKey.getPublicExponent();
        this.u = null;
        this.q = null;
        this.p = null;
        this.rng = null;
    }

    /**
     * Accepts only the debug parameter "CryptixDebugFixedSalt" (case-insensitive)
     * with a byte[] value, which fixes the salt for the next sign() call.
     * Any other name or value type is silently ignored.
     */
    protected void engineSetParameter(String string, Object object) {
        if (string.equalsIgnoreCase("CryptixDebugFixedSalt") && object instanceof byte[]) {
            // Copy so later changes to the caller's array cannot alter the salt.
            this.presetSalt = (byte[])((byte[])object).clone();
        }
    }

    /**
     * Produces the PSS signature of the data accumulated so far and resets the digest.
     *
     * @return the signature, left-padded to the modulus byte length
     */
    protected byte[] engineSign() {
        // mHash = Hash(M); digest() also resets md for the steps below.
        byte[] mHash = this.md.digest();
        byte[] salt;
        if (this.presetSalt == null) {
            salt = new byte[this.sLen];
            this.rng.nextBytes(salt);
        } else {
            if (this.sLen != this.presetSalt.length) {
                throw new Error("Invalid presetSalt, size mismatch!");
            }
            salt = this.presetSalt;
            this.presetSalt = null;
            System.err.println("Using preset salt: " + cryptix.jce.util.Util.toString((byte[])salt) + "!");
        }
        // H = Hash(0x00 * 8 || mHash || salt)
        this.md.update(new byte[8]);
        this.md.update(mHash);
        byte[] h = this.md.digest(salt);
        // DB = PS || 0x01 || salt, masked with MGF1(H).
        int dbLen = this.emLen - this.hLen - 1;
        byte[] dbMask = this.mgf1(h, dbLen);
        byte[] ps = new byte[this.emLen - this.sLen - this.hLen - 2];
        byte[] db = this.concat(ps, new byte[]{1}, salt);
        byte[] maskedDb = RSASignature_PSS.xor(db, dbMask);
        // Clear the leftmost 8*emLen - emBits bits, so EM < 2^emBits <= n.
        maskedDb[0] = (byte)(maskedDb[0] & MASK[8 * this.emLen - this.emBits]);
        byte[] em = this.concat(maskedDb, h, new byte[]{TRAILER});
        BigInteger m = new BigInteger(1, em);
        if (m.compareTo(this.n) >= 0) {
            throw new InternalError("message > modulus!");
        }
        BigInteger s = RSAAlgorithm.rsa(m, this.n, this.exp, this.p, this.q, this.u);
        return Util.toFixedLenByteArray(s, this.getModulusLen());
    }

    protected void engineUpdate(byte by) {
        this.md.update(by);
    }

    protected void engineUpdate(byte[] byArray, int n, int n2) {
        this.md.update(byArray, n, n2);
    }

    /**
     * Verifies a PSS signature over the data accumulated so far.
     * The digest is finished (and thereby reset) on every return path.
     *
     * @param sigBytes the signature, expected to be exactly the modulus byte length
     * @return true iff the signature is valid
     */
    protected boolean engineVerify(byte[] sigBytes) {
        // Finish the hash first: returning early with a dirty digest would leak
        // this message into the next verification.
        byte[] mHash = this.md.digest();
        if (sigBytes.length != this.getModulusLen()) {
            return false;
        }
        BigInteger s = new BigInteger(1, sigBytes);
        if (s.signum() < 0 || s.compareTo(this.n) >= 0) {
            return false;
        }
        BigInteger m = RSAAlgorithm.rsa(s, this.n, this.exp, this.p, this.q, this.u);
        if (m.bitLength() > this.emLen * 8) {
            return false;
        }
        byte[] em = Util.toFixedLenByteArray(m, this.emLen);
        return this.pssVerify(mHash, em, this.emBits);
    }

    /** Bit length of the current modulus. */
    private int getModulusBitLen() {
        return this.n.bitLength();
    }

    /** Byte length of the current modulus. */
    private int getModulusLen() {
        return (this.n.bitLength() + 7) / 8;
    }

    /**
     * Checks that the modulus is large enough for PSS with this digest/salt size,
     * then commits it together with the derived emBits/emLen and resets the digest.
     * No field is modified when the check fails.
     *
     * @throws InvalidKeyException if emBits < 8*hLen + 8*sLen + 9
     */
    private void initCommon(BigInteger modulus) throws InvalidKeyException {
        int bits = modulus.bitLength() - 1;
        if (bits < 8 * this.hLen + 8 * this.sLen + 9) {
            throw new InvalidKeyException("Signer's key modulus too short.");
        }
        this.n = modulus;
        this.emBits = bits;
        this.emLen = (bits + 7) / 8;
        this.md.reset();
    }

    /**
     * MGF1 mask generation: T = Hash(seed || C(0)) || Hash(seed || C(1)) || ...,
     * truncated to maskLen bytes, where C(i) is i as a 4-byte big-endian integer.
     */
    private byte[] mgf1(byte[] seed, int maskLen) {
        byte[] mask = new byte[maskLen];
        int offset = 0;
        for (int counter = 0; offset < maskLen; counter++) {
            byte[] block = this.mgf1Hash(seed, counter);
            int chunk = Math.min(block.length, maskLen - offset);
            System.arraycopy(block, 0, mask, offset, chunk);
            offset += chunk;
        }
        return mask;
    }

    /** Returns Hash(seed || I2OSP(counter, 4)) using (and resetting) md. */
    private byte[] mgf1Hash(byte[] seed, int counter) {
        this.md.update(seed);
        this.md.update(new byte[]{
            (byte)(counter >>> 24),
            (byte)(counter >>> 16),
            (byte)(counter >>> 8),
            (byte)counter
        });
        return this.md.digest();
    }

    /**
     * EMSA-PSS-VERIFY.
     *
     * @param mHash hash of the signed message
     * @param em encoded message of length emLen
     * @param bits emBits for the current key
     * @return true iff em is a consistent PSS encoding of mHash
     */
    private boolean pssVerify(byte[] mHash, byte[] em, int bits) {
        if (bits < 8 * this.hLen + 8 * this.sLen + 9) {
            return false;
        }
        if (em[em.length - 1] != TRAILER) {
            return false;
        }
        int dbLen = this.emLen - this.hLen - 1;
        byte[] maskedDb = new byte[dbLen];
        System.arraycopy(em, 0, maskedDb, 0, dbLen);
        byte[] h = new byte[this.hLen];
        System.arraycopy(em, dbLen, h, 0, this.hLen);
        int unusedBits = 8 * this.emLen - bits;
        // The leftmost unusedBits bits of maskedDB must be zero.
        if ((maskedDb[0] & ~MASK[unusedBits]) != 0) {
            return false;
        }
        byte[] db = RSASignature_PSS.xor(maskedDb, this.mgf1(h, dbLen));
        db[0] = (byte)(db[0] & MASK[unusedBits]);
        // DB must be PS (all zero) || 0x01 || salt.
        int psLen = this.emLen - this.hLen - this.sLen - 2;
        for (int i = 0; i < psLen; i++) {
            if (db[i] != 0) {
                return false;
            }
        }
        if (db[psLen] != 1) {
            return false;
        }
        byte[] salt = new byte[this.sLen];
        System.arraycopy(db, db.length - this.sLen, salt, 0, this.sLen);
        // H' = Hash(0x00 * 8 || mHash || salt); valid iff H' == H.
        this.md.reset();
        this.md.update(new byte[8]);
        this.md.update(mHash);
        byte[] hPrime = this.md.digest(salt);
        return MessageDigest.isEqual(hPrime, h);
    }

    /**
     * Returns a XOR b.
     *
     * @throws InternalError if the arrays differ in length
     */
    private static byte[] xor(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new InternalError("a.len != b.len");
        }
        byte[] out = new byte[a.length];
        for (int i = 0; i < out.length; i++) {
            out[i] = (byte)(a[i] ^ b[i]);
        }
        return out;
    }
}

