package org.jgroups.protocols;

import org.jgroups.*;
import org.jgroups.annotations.ManagedAttribute;
import org.jgroups.annotations.Property;
import org.jgroups.stack.Protocol;
import org.jgroups.util.*;

import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.function.BiConsumer;
import java.util.zip.Adler32;
import java.util.zip.CRC32;
import java.util.zip.Checksum;

/**
 * Super class of symmetric ({@link SYM_ENCRYPT}) and asymmetric ({@link ASYM_ENCRYPT}) encryption protocols.
 * <p>
 * Outgoing messages are encrypted in {@link #down(Event)}; incoming messages are decrypted in {@link #up(Event)} and
 * {@link #up(MessageBatch)}. Encryption and decryption use {@link Cipher} instances taken from two bounded pools
 * ({@code cipher_pool_size} ciphers each), so concurrent callers never share a pooled cipher.
 * @author Bela Ban
 */
public abstract class Encrypt extends Protocol {
    protected static final String DEFAULT_SYM_ALGO="AES";


    /* -----------------------------------------    Properties     -------------------------------------------------- */
    @Property(description="Cryptographic Service Provider")
    protected String                        provider;

    @Property(description="Cipher engine transformation for asymmetric algorithm. Default is RSA")
    protected String                        asym_algorithm="RSA";

    @Property(description="Cipher engine transformation for symmetric algorithm. Default is AES")
    protected String                        sym_algorithm=DEFAULT_SYM_ALGO;

    @Property(description="Initial public/private key length. Default is 512")
    protected int                           asym_keylength=512;

    @Property(description="Initial key length for matching symmetric algorithm. Default is 128")
    protected int                           sym_keylength=128;

    @Property(description="Number of ciphers in the pool to parallelize encrypt and decrypt requests",writable=false)
    protected int                           cipher_pool_size=8;

    @Property(description="If true, the entire message (including payload and headers) is encrypted, else only the payload")
    protected boolean                       encrypt_entire_message=true;

    @Property(description="If true, all messages are digitally signed by adding an encrypted checksum of the encrypted " +
      "message to the header. Ignored if encrypt_entire_message is false")
    protected boolean                       sign_msgs=true;

    @Property(description="When sign_msgs is true, by default CRC32 is used to create the checksum. If use_adler is " +
      "true, Adler32 will be used")
    protected boolean                       use_adler;

    protected volatile Address              local_addr;

    protected volatile View                 view;

    // Cipher pools used for encryption and decryption. Size is cipher_pool_size
    protected BlockingQueue<Cipher>         encoding_ciphers, decoding_ciphers;

    // version field for the secret key (MD5 hash of the encoded key, see initSymCiphers())
    protected volatile byte[]               sym_version;

    // shared secret key to encrypt/decrypt messages
    protected volatile SecretKey            secret_key;

    // map to hold previous keys (version -> cipher) so we can decrypt some earlier messages if we need to.
    // WeakHashMap is not thread-safe (even get() may expunge stale entries and mutate the table), and
    // decryptMessage() may be invoked concurrently, so all access goes through a synchronized wrapper
    protected final Map<AsciiString,Cipher> key_map=Collections.synchronizedMap(new WeakHashMap<>());



    // Fluent accessors; the setters return 'this' cast to the caller's subtype
    public int                      asymKeylength()                 {return asym_keylength;}
    public <T extends Encrypt> T    asymKeylength(int len)          {this.asym_keylength=len; return (T)this;}
    public int                      symKeylength()                  {return sym_keylength;}
    public <T extends Encrypt> T    symKeylength(int len)           {this.sym_keylength=len; return (T)this;}
    public SecretKey                secretKey()                     {return secret_key;}
    public <T extends Encrypt> T    secretKey(SecretKey key)        {this.secret_key=key; return (T)this;}
    public String                   symAlgorithm()                  {return sym_algorithm;}
    public <T extends Encrypt> T    symAlgorithm(String alg)        {this.sym_algorithm=alg; return (T)this;}
    public String                   asymAlgorithm()                 {return asym_algorithm;}
    public <T extends Encrypt> T    asymAlgorithm(String alg)       {this.asym_algorithm=alg; return (T)this;}
    public byte[]                   symVersion()                    {return sym_version;}
    public <T extends Encrypt> T    symVersion(byte[] v)            {this.sym_version=Arrays.copyOf(v, v.length); return (T)this;}
    public <T extends Encrypt> T    localAddress(Address addr)      {this.local_addr=addr; return (T)this;}
    public boolean                  encryptEntireMessage()          {return encrypt_entire_message;}
    public <T extends Encrypt> T    encryptEntireMessage(boolean b) {this.encrypt_entire_message=b; return (T)this;}
    public boolean                  signMessages()                  {return this.sign_msgs;}
    public <T extends Encrypt> T    signMessages(boolean flag)      {this.sign_msgs=flag; return (T)this;}
    public boolean                  adler()                         {return use_adler;}
    public <T extends Encrypt> T    adler(boolean flag)             {this.use_adler=flag; return (T)this;}
    @ManagedAttribute public String version()                       {return Util.byteArrayToHexString(sym_version);}

    /**
     * Adjusts {@code cipher_pool_size} to a power of two (via {@link Util#getNextHigherPowerOfTwo}), creates the
     * encoding/decoding cipher pools and fills them if a secret key is already set.
     * @throws Exception if the ciphers cannot be created
     */
    public void init() throws Exception {
        int tmp=Util.getNextHigherPowerOfTwo(cipher_pool_size);
        if(tmp != cipher_pool_size) {
            log.warn("%s: setting cipher_pool_size (%d) to %d (power of 2) for faster modulo operation", local_addr, cipher_pool_size, tmp);
            cipher_pool_size=tmp;
        }
        encoding_ciphers=new ArrayBlockingQueue<>(cipher_pool_size);
        decoding_ciphers=new ArrayBlockingQueue<>(cipher_pool_size);
        initSymCiphers(sym_algorithm, secret_key);
    }


    /**
     * Encrypts and sends messages. A message is discarded (trace-logged) while no secret key is available.
     * View changes and the local address are recorded. These events, and all other events, are passed down.
     */
    public Object down(Event evt) {
        switch(evt.getType()) {
            case Event.MSG:
                Message msg=evt.getArg();
                try {
                    if(secret_key == null) {
                        log.trace("%s: discarded %s message to %s as secret key is null, hdrs: %s",
                                  local_addr, msg.dest() == null? "mcast" : "unicast", msg.dest(), msg.printHeaders());
                        return null;
                    }
                    encryptAndSend(msg);
                }
                catch(Exception e) {
                    log.warn("%s: unable to send message down", local_addr, e);
                }
                return null;

            case Event.VIEW_CHANGE:
                handleView(evt.getArg());
                break;

            case Event.SET_LOCAL_ADDRESS:
                local_addr=evt.getArg();
                break;
        }
        return down_prot.down(evt);
    }


    /**
     * Decrypts received messages and passes them up. A message that fails to decrypt is logged and dropped.
     * View changes are recorded and passed up, as are all other events.
     */
    public Object up(Event evt) {
        switch(evt.getType()) {
            case Event.VIEW_CHANGE:
                handleView(evt.getArg());
                break;
            case Event.MSG:
                Message msg=evt.getArg();
                try {
                    return handleUpMessage(msg);
                }
                catch(Exception e) {
                    log.warn("%s: exception occurred decrypting message", local_addr, e);
                }
                return null;
        }
        return up_prot.up(evt);
    }


    /**
     * Decrypts every message in the batch in place, using a single cipher from the decoding pool, and passes the
     * batch up if any messages remain. The whole batch is dropped if no secret key is set or if the thread is
     * interrupted while waiting for a cipher.
     */
    public void up(MessageBatch batch) {
        Cipher cipher=null;
        try {
            if(secret_key == null) {
                log.trace("%s: discarded %s batch from %s as secret key is null",
                          local_addr, batch.dest() == null? "mcast" : "unicast", batch.sender());
                return;
            }
            cipher=decoding_ciphers.take();
            BiConsumer<Message,MessageBatch> decrypter=new Decrypter(cipher);
            batch.forEach(decrypter);
        }
        catch(InterruptedException e) {
            Thread.currentThread().interrupt(); // restore the interrupt status so callers up the chain can see it
            log.error("%s: failed processing batch; discarding batch", local_addr, e);
            // we need to drop the batch if we for example have a failure fetching a cipher, or else other messages
            // in the batch might make it up the stack, bypassing decryption! This is not an issue because encryption
            // is below NAKACK2 or UNICAST3, so messages will get retransmitted
            return;
        }
        finally {
            if(cipher != null)
                decoding_ciphers.offer(cipher);
        }
        if(!batch.isEmpty())
            up_prot.up(batch);
    }



    /**
     * Initialises the ciphers for both encryption and decryption using the generated or supplied secret key.
     * Also sets {@code sym_version} to the MD5 hash of the encoded key. Does nothing if {@code secret} is null.
     * <p>
     * All new ciphers are created before the pools are touched. If cipher creation fails, the existing pools stay
     * intact and are not left cleared or only partly refilled, which would make later {@code take()} calls block.
     * @param algorithm the cipher transformation, passed to {@link Cipher#getInstance}
     * @param secret    the secret key; may be null
     */
    protected synchronized void initSymCiphers(String algorithm, SecretKey secret) throws Exception {
        if(secret == null)
            return;

        Cipher[] new_encoders=new Cipher[cipher_pool_size], new_decoders=new Cipher[cipher_pool_size];
        for(int i=0; i < cipher_pool_size; i++) {
            new_encoders[i]=createCipher(Cipher.ENCRYPT_MODE, secret, algorithm);
            new_decoders[i]=createCipher(Cipher.DECRYPT_MODE, secret, algorithm);
        }

        // compute the version before swapping, so a failure here also leaves the pools untouched
        byte[] encoded_key=secret.getEncoded();
        byte[] new_version=MessageDigest.getInstance("MD5").digest(encoded_key);

        encoding_ciphers.clear();
        decoding_ciphers.clear();
        for(int i=0; i < cipher_pool_size; i++) {
            encoding_ciphers.add(new_encoders[i]);
            decoding_ciphers.add(new_decoders[i]);
        }

        //set the version
        sym_version=new_version;
        log.debug("%s: created %d symmetric ciphers with secret key (%d bytes)", local_addr, cipher_pool_size, encoded_key.length);
    }


    /**
     * Creates a cipher for the given mode and key, using {@code provider} if it is set (non-blank).
     * @param mode      {@link Cipher#ENCRYPT_MODE} or {@link Cipher#DECRYPT_MODE}
     * @param key       the key to initialize the cipher with
     * @param algorithm the cipher transformation
     */
    protected Cipher createCipher(int mode, SecretKey key, String algorithm) throws Exception {
        Cipher cipher=provider != null && !provider.trim().isEmpty()?
          Cipher.getInstance(algorithm, provider) : Cipher.getInstance(algorithm);
        cipher.init(mode, key);
        return cipher;
    }


    /**
     * Dispatches a received message on its {@link EncryptHeader} type. Messages without the header are logged and
     * dropped. ENCRYPT messages are decrypted; all other types go to {@link #handleUpEvent}.
     */
    protected Object handleUpMessage(Message msg) throws Exception {
        EncryptHeader hdr=msg.getHeader(this.id);
        if(hdr == null) {
            log.error("%s: received message without encrypt header from %s; dropping it", local_addr, msg.src());
            return null;
        }
        switch(hdr.type()) {
            case EncryptHeader.ENCRYPT:
                return handleEncryptedMessage(msg);
            default:
                return handleUpEvent(msg,hdr);
        }
    }


    /** Decrypts a copy of {@code msg} (the original may be retransmitted) and passes the result up, or drops it */
    protected Object handleEncryptedMessage(Message msg) throws Exception {
        if(!process(msg))
            return null;

        // try and decrypt the message - we need to copy msg as we modify its
        // buffer (http://jira.jboss.com/jira/browse/JGRP-538)
        Message tmpMsg=decryptMessage(null, msg.copy()); // need to copy for possible xmits
        if(tmpMsg != null)
            return up_prot.up(new Event(Event.MSG, tmpMsg));
        log.warn("%s: unrecognized cipher; discarding message from %s", local_addr, msg.src());
        return null;
    }

    /** Handles a message whose header type is not ENCRYPT; the default implementation ignores it and returns null */
    protected Object handleUpEvent(Message msg, EncryptHeader hdr) {
        return null;
    }

    /** Whether or not to process this received message; returns true by default */
    protected boolean process(Message msg) {return true;}

    /** Records the current view */
    protected void handleView(View view) {
        this.view=view;
    }

    /**
     * Returns true if no view is known yet or {@code sender} is a member of the current view; otherwise logs
     * {@code error_msg} (formatted with the sender and the view) and returns false.
     */
    protected boolean inView(Address sender, String error_msg) {
        View curr_view=this.view;
        if(curr_view == null || curr_view.containsMember(sender))
            return true;
        log.error(error_msg, sender, curr_view);
        return false;
    }

    /** Creates a new checksum instance: Adler32 if use_adler is set, otherwise CRC32 */
    protected Checksum createChecksummer() {return use_adler? new Adler32() : new CRC32();}


    /**
     * Does the actual work for decrypting. If the header's version does not match the current one, the cipher
     * registered for that version in {@code key_map} is used. If no such cipher exists,
     * {@link #handleUnknownVersion()} is called and null is returned.
     * @param cipher the cipher to use for the current version, or null to use one from the decoding pool
     * @param msg    the message to decrypt; its buffer may be modified
     * @return the decrypted message, or null if it was dropped
     */
    protected Message decryptMessage(Cipher cipher, Message msg) throws Exception {
        EncryptHeader hdr=msg.getHeader(this.id);
        if(Arrays.equals(hdr.version(), sym_version))
            return _decrypt(cipher, msg, hdr);

        Cipher prev_cipher=key_map.get(new AsciiString(hdr.version()));
        if(prev_cipher == null) {
            handleUnknownVersion();
            return null;
        }
        log.trace("%s: decrypting msg from %s using previous cipher version", local_addr, msg.src());
        // the cached cipher is shared by all threads decrypting this version and Cipher objects are stateful,
        // so serialize its use (covers both the checksum decryption and the payload decryption)
        synchronized(prev_cipher) {
            return _decrypt(prev_cipher, msg, hdr);
        }
    }

    /**
     * Decrypts {@code msg}. An empty payload-only message is returned unchanged. When the entire message is
     * encrypted and signed, the checksum is verified first. The message is dropped (null is returned) if the
     * signature is missing or the checksums don't match.
     * @param cipher the cipher to use, or null to use one from the decoding pool
     * @return the decrypted message (the unmarshalled inner message if the entire message was encrypted), or null
     */
    protected Message _decrypt(final Cipher cipher, Message msg, EncryptHeader hdr) throws Exception {
        byte[] decrypted_msg;

        if(!encrypt_entire_message && msg.getLength() == 0)
            return msg;

        if(encrypt_entire_message && sign_msgs) {
            byte[] signature=hdr.signature();
            if(signature == null) {
                log.error("%s: dropped message from %s as the header did not have a checksum", local_addr, msg.src());
                return null;
            }

            long msg_checksum=decryptChecksum(cipher, signature, 0, signature.length);
            long actual_checksum=computeChecksum(msg.getRawBuffer(), msg.getOffset(), msg.getLength());
            if(actual_checksum != msg_checksum) {
                log.error("%s: dropped message from %s as the message's checksum (%d) did not match the computed checksum (%d)",
                          local_addr, msg.src(), msg_checksum, actual_checksum);
                return null;
            }
        }

        if(cipher == null)
            decrypted_msg=code(msg.getRawBuffer(), msg.getOffset(), msg.getLength(), true);
        else
            decrypted_msg=cipher.doFinal(msg.getRawBuffer(), msg.getOffset(), msg.getLength());

        if(!encrypt_entire_message) {
            msg.setBuffer(decrypted_msg);
            return msg;
        }

        // the inner message may lack dest/src; inherit them from the outer (envelope) message
        Message ret=Util.streamableFromBuffer(Message.class,decrypted_msg,0,decrypted_msg.length);
        if(ret.getDest() == null)
            ret.setDest(msg.getDest());
        if(ret.getSrc() == null)
            ret.setSrc(msg.getSrc());
        return ret;
    }


    /**
     * Encrypts {@code msg} and sends the encrypted copy down, tagged with an ENCRYPT header carrying the current
     * version. If encrypt_entire_message is set, the whole serialized message becomes the payload of a header-less
     * copy, and with sign_msgs an encrypted checksum is added to the header. Otherwise only the payload is
     * encrypted. The original message is never modified except for filling in a null src.
     */
    protected void encryptAndSend(Message msg) throws Exception {
        EncryptHeader hdr=new EncryptHeader(EncryptHeader.ENCRYPT, symVersion());
        if(encrypt_entire_message) {
            if(msg.getSrc() == null)
                msg.setSrc(local_addr);

            Buffer serialized_msg=Util.streamableToBuffer(msg);
            byte[] encrypted_msg=code(serialized_msg.getBuf(),serialized_msg.getOffset(),serialized_msg.getLength(),false);

            if(sign_msgs) {
                long checksum=computeChecksum(encrypted_msg, 0, encrypted_msg.length);
                byte[] checksum_array=encryptChecksum(checksum);
                hdr.signature(checksum_array);
            }

            // exclude existing headers, they will be seen again when we decrypt and unmarshal the msg at the receiver
            Message tmp=msg.copy(false, false).setBuffer(encrypted_msg).putHeader(this.id,hdr);
            down_prot.down(new Event(Event.MSG, tmp));
            return;
        }

        // copy needed because same message (object) may be retransmitted -> prevent double encryption
        Message msgEncrypted=msg.copy(false).putHeader(this.id, hdr);
        if(msg.getLength() > 0)
            msgEncrypted.setBuffer(code(msg.getRawBuffer(),msg.getOffset(),msg.getLength(),false));
        down_prot.down(new Event(Event.MSG,msgEncrypted));
    }


    /**
     * Encrypts or decrypts a buffer range with a cipher borrowed from the matching pool. Blocks until a cipher is
     * available and always returns the cipher to the pool.
     */
    protected byte[] code(byte[] buf, int offset, int length, boolean decode) throws Exception {
        BlockingQueue<Cipher> queue=decode? decoding_ciphers : encoding_ciphers;
        Cipher cipher=queue.take();
        try {
            return cipher.doFinal(buf, offset, length);
        }
        finally {
            queue.offer(cipher);
        }
    }

    /** Computes the checksum (see {@link #createChecksummer()}) of the given buffer range */
    protected long computeChecksum(byte[] input, int offset, int length) {
        Checksum checksummer=createChecksummer();
        checksummer.update(input, offset, length);
        return checksummer.getValue();
    }

    /** Serializes {@code checksum} as a long and encrypts it with a pooled encoding cipher */
    protected byte[] encryptChecksum(long checksum) throws Exception {
        byte[] checksum_array=new byte[Global.LONG_SIZE];
        Bits.writeLong(checksum, checksum_array, 0);
        return code(checksum_array, 0, checksum_array.length, false);
    }

    /** Decrypts an encrypted checksum, with {@code cipher} or (if null) a pooled decoding cipher, and reads it as a long */
    protected long decryptChecksum(final Cipher cipher, byte[] input, int offset, int length) throws Exception {
        byte[] decrypted_checksum;
        if(cipher == null)
            decrypted_checksum=code(input, offset, length, true);
        else
            decrypted_checksum=cipher.doFinal(input, offset, length);
        return Bits.readLong(decrypted_checksum, 0);
    }


    /* Get the algorithm name from "algorithm/mode/padding"  taken from original ENCRYPT */
    protected static String getAlgorithm(String s) {
        int index=s.indexOf('/');
        return index == -1? s : s.substring(0, index);
    }


    /** Called when the version shipped in the header can't be found; does nothing by default */
    protected void handleUnknownVersion() {}


    /** Decrypts all messages in a batch, replacing encrypted messages in-place with their decrypted versions */
    protected class Decrypter implements BiConsumer<Message,MessageBatch> {
        protected final Cipher cipher;

        /** @param cipher the cipher to use for messages encrypted with the current version */
        public Decrypter(Cipher cipher) {
            this.cipher=cipher;
        }

        public void accept(Message msg, MessageBatch batch) {
            EncryptHeader hdr;
            if((hdr=msg.getHeader(id)) == null) {
                log.error("%s: received message without encrypt header from %s; dropping it", local_addr, batch.sender());
                batch.remove(msg); // remove from batch to prevent passing the message further up as part of the batch
                return;
            }

            if(hdr.type() == EncryptHeader.ENCRYPT) {
                try {
                    if(!process(msg)) {
                        batch.remove(msg);
                        return;
                    }
                    Message tmpMsg=decryptMessage(cipher, msg.copy()); // need to copy for possible xmits
                    if(tmpMsg != null)
                        batch.replace(msg, tmpMsg);
                    else
                        batch.remove(msg);
                }
                catch(Exception e) {
                    // the raw buffer can be null (e.g. a message without payload); don't let logging throw an NPE,
                    // which would escape forEach() and abort processing of the whole batch
                    byte[] raw_buf=msg.getRawBuffer();
                    log.error("%s: failed decrypting message from %s (offset=%d, length=%d, buf.length=%d): %s, headers are %s",
                              local_addr, msg.getSrc(), msg.getOffset(), msg.getLength(), raw_buf != null? raw_buf.length : 0,
                              e, msg.printHeaders());
                    batch.remove(msg);
                }
            }
            else {
                batch.remove(msg); // a control message will get handled by ENCRYPT and should not be passed up
                handleUpEvent(msg, hdr);
            }
        }
    }

}