/**
 * @author Tres Finocchiaro
 *
 * Copyright (C) 2019 Tres Finocchiaro, QZ Industries, LLC
 *
 * LGPL 2.1 This is free software.  This software and source code are released under
 * the "LGPL 2.1 License".  A copy of this license should be distributed with
 * this software. http://www.gnu.org/licenses/lgpl-2.1.html
 */

package qz.installer.certificate;

import org.bouncycastle.asn1.x500.X500Name;
import org.bouncycastle.asn1.x500.style.BCStyle;
import org.bouncycastle.asn1.x509.GeneralName;
import org.bouncycastle.cert.jcajce.JcaX509CertificateHolder;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import qz.common.Constants;
import qz.utils.ShellUtilities;
import qz.utils.SystemUtilities;

import javax.naming.InvalidNameException;
import javax.naming.ldap.LdapName;
import javax.naming.ldap.Rdn;
import java.io.File;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.*;

import static qz.utils.FileUtilities.*;

/**
 * {@link TimerTask} that periodically checks the expiration state of the SSL certificate held by a
 * {@link CertificateManager} and, when it is expired or about to expire, attempts to renew it and reload
 * the SSL context.
 * <p>
 * The renewal strategy depends on the certificate issuer ({@link CertProvider}): internally generated
 * certificates are re-issued through {@link CertificateManager}, Let's Encrypt certificates are renewed by
 * invoking {@code certbot}, and any other issuer is only reported.
 */
public class ExpiryTask extends TimerTask {
    private static final Logger log = LogManager.getLogger(ExpiryTask.class);
    public static final int DEFAULT_INITIAL_DELAY = 60 * 1000; // 1 minute
    public static final int DEFAULT_CHECK_FREQUENCY = 3600 * 1000; // 1 hour
    private static final int DEFAULT_GRACE_PERIOD_DAYS = 5;

    /**
     * Result of an expiration check. Public because it is returned by the public {@code getExpiry} methods.
     * NOTE(review): {@code MANAGED} is never produced by this class; confirm whether it is used elsewhere.
     */
    public enum ExpiryState {VALID, EXPIRING, EXPIRED, MANAGED}

    /**
     * Known certificate issuers. Each provider is identified by regular expressions matched against the
     * issuer's common name (CN); {@link #UNKNOWN} has no patterns and serves as the fallback.
     */
    public enum CertProvider {
        INTERNAL(Constants.ABOUT_COMPANY + ".*"),
        LETS_ENCRYPT("Let's Encrypt.*"),
        CA_CERT_ORG("CA Cert Signing.*"),
        UNKNOWN;
        String[] patterns;
        CertProvider(String ... regexPattern) {
            this.patterns = regexPattern;
        }
    }

    private Timer timer;
    private final CertificateManager certificateManager;
    private final String[] hostNames;
    private final CertProvider certProvider;

    /**
     * Creates a task bound to the given manager. The hostnames (from the SAN extension) and the issuer of
     * the current SSL certificate are captured once, here, and reused for every later renewal.
     *
     * @param certificateManager source of the SSL/CA key pairs and the target of renewals
     */
    public ExpiryTask(CertificateManager certificateManager) {
        super();
        this.certificateManager = certificateManager;
        this.hostNames = parseHostNames();
        this.certProvider = findCertProvider();
    }

    /**
     * Timer callback: checks the current SSL certificate and, if it is expired or expiring, attempts a
     * renewal appropriate for its {@link CertProvider}.
     */
    @Override
    public void run() {
        try {
            ExpiryState state = getExpiry();
            switch(state) {
                case EXPIRING:
                case EXPIRED:
                    log.info("Certificate ExpiryState {}, renewing/reloading...", state);
                    switch(certProvider) {
                        case INTERNAL:
                            if(renewInternalCert()) {
                                getExpiry(); // re-check only to log the state of the newly loaded cert
                            }
                            break;
                        case CA_CERT_ORG:
                        case LETS_ENCRYPT:
                            if(renewExternalCert(certProvider)) {
                                getExpiry(); // re-check only to log the state of the newly loaded cert
                            }
                            break;
                        case UNKNOWN:
                        default:
                            log.warn("Certificate can't be renewed/reloaded; ExpiryState: {}, CertProvider: {}", state, certProvider);
                    }
                    break;
                case VALID:
                default:
                    break;
            }
        } catch(RuntimeException e) {
            // An exception escaping run() would terminate the Timer's thread, silently stopping all future checks
            log.error("Unexpected error while checking SSL certificate expiration", e);
        }
    }

    /**
     * Re-issues the SSL certificate chain from the internal CA for the cached hostnames and reloads the
     * SSL context.
     *
     * @return {@code true} if renewal and reload both succeeded; {@code false} if any exception was logged
     */
    public boolean renewInternalCert() {
        try {
            log.info("Requesting a new SSL certificate from {} ...", certificateManager.getCaKeyPair().getAlias());
            certificateManager.renewCertChain(hostNames);
            log.info("New SSL certificate created.  Reloading SslContextFactory...");
            certificateManager.reloadSslContextFactory();
            log.info("Reloaded SSL successfully.");
            return true;
        }
        catch(Exception e) {
            log.error("Could not reload SSL certificate", e);
        }
        return false;
    }

    /**
     * @return the expiration state of the manager's current SSL certificate
     */
    public ExpiryState getExpiry() {
        return getExpiry(certificateManager.getSslKeyPair().getCert());
    }

    /**
     * Classifies a certificate's expiration state.
     * <p>
     * The grace period lets renewals be scheduled in advance, for example during non-peak hours.
     *
     * @param cert the certificate to check; may be {@code null}
     * @return {@link ExpiryState#EXPIRED} if {@code cert} is {@code null} or its notAfter date has passed,
     *         {@link ExpiryState#EXPIRING} if it expires within {@code DEFAULT_GRACE_PERIOD_DAYS} days,
     *         otherwise {@link ExpiryState#VALID}
     */
    public static ExpiryState getExpiry(X509Certificate cert) {
        // Invalid
        if (cert == null) {
            log.error("Can't check for expiration, certificate is missing.");
            return ExpiryState.EXPIRED;
        }

        Date expireDate = cert.getNotAfter();
        Calendar now = Calendar.getInstance(Locale.ENGLISH);
        Calendar expires = Calendar.getInstance(Locale.ENGLISH);
        expires.setTime(expireDate);

        // Expired
        if (now.after(expires)) {
            log.info("SSL certificate has expired {}.  It must be renewed immediately.", SystemUtilities.toISO(expireDate));
            return ExpiryState.EXPIRED;
        }

        // Expiring: shift the expiry back by the grace period and compare again
        expires.add(Calendar.DAY_OF_YEAR, -DEFAULT_GRACE_PERIOD_DAYS);
        if (now.after(expires)) {
            log.info("SSL certificate will expire in less than {} days: {}", DEFAULT_GRACE_PERIOD_DAYS, SystemUtilities.toISO(expireDate));
            return ExpiryState.EXPIRING;
        }

        // Valid
        int days = (int)Math.round((expireDate.getTime() - now.getTimeInMillis()) / (double)86400000);
        log.info("SSL certificate is still valid for {} more days: {}.  We'll make a new one automatically when needed.", days, SystemUtilities.toISO(expireDate));
        return ExpiryState.VALID;
    }

    /**
     * Schedules this task with {@link #DEFAULT_INITIAL_DELAY} and {@link #DEFAULT_CHECK_FREQUENCY}.
     */
    public void schedule() {
        schedule(DEFAULT_INITIAL_DELAY, DEFAULT_CHECK_FREQUENCY);
    }

    /**
     * Schedules this task at a fixed rate on a new {@link Timer}, cancelling any timer created by a
     * previous call.
     * <p>
     * Note: a {@link TimerTask} instance can only be scheduled once. Per
     * {@link Timer#scheduleAtFixedRate(TimerTask, long, long)}, calling this method again on the same
     * instance throws {@link IllegalStateException}.
     *
     * @param delayMillis delay before the first check, in milliseconds
     * @param freqMillis period between checks, in milliseconds
     */
    public void schedule(int delayMillis, int freqMillis) {
        if(timer != null) {
            timer.cancel();
            timer.purge();
        }
        timer = new Timer();
        timer.scheduleAtFixedRate(this, delayMillis, freqMillis);
    }

    /**
     * @return DNS/IP SAN entries of the manager's current SSL certificate
     */
    public String[] parseHostNames() {
        return parseHostNames(certificateManager.getSslKeyPair().getCert());
    }

    /**
     * @return the issuer classification of the manager's current SSL certificate
     */
    public CertProvider findCertProvider() {
        return findCertProvider(certificateManager.getSslKeyPair().getCert());
    }

    /**
     * Classifies the certificate's issuer.
     * <p>
     * Internal certificates are recognized by {@link CertificateManager#emailMatches}. Otherwise the
     * issuer CN is matched against each provider's patterns, in declaration order.
     *
     * @param cert the certificate to classify; may be {@code null}
     * @return the matching provider, or {@link CertProvider#UNKNOWN} if none matches or {@code cert} is null
     */
    public static CertProvider findCertProvider(X509Certificate cert) {
        if (cert == null) {
            log.warn("Can't determine the certificate issuer, certificate is missing.");
            return CertProvider.UNKNOWN;
        }

        // Internal certs use CN=localhost, trust email instead
        if (CertificateManager.emailMatches(cert)) {
            return CertProvider.INTERNAL;
        }

        // check registered patterns to classify certificate
        String cn = findIssuerCommonName(cert);
        if(cn != null) {
            for(CertProvider provider : CertProvider.values()) {
                for(String pattern : provider.patterns) {
                    if (cn.matches(pattern)) {
                        log.warn("Cert issuer detected as {}", provider.name());
                        return provider;
                    }
                }
            }
        }

        log.warn("A valid issuer couldn't be found, we won't know how to renew this cert when it expires");
        return CertProvider.UNKNOWN;
    }

    /**
     * Returns the first CN value of the certificate issuer's distinguished name.
     * <p>
     * The DN is taken from {@code getIssuerX500Principal().getName()} (RFC 2253), the format
     * {@link LdapName} parses.
     *
     * @return the CN string, or {@code null} if absent, non-string, or unparseable
     */
    private static String findIssuerCommonName(X509Certificate cert) {
        if (cert.getIssuerX500Principal() == null) {
            return null;
        }
        String issuerDN = cert.getIssuerX500Principal().getName();
        try {
            for(Rdn rdn : new LdapName(issuerDN).getRdns()) {
                if(rdn.getType().equalsIgnoreCase("CN")) {
                    Object value = rdn.getValue();
                    // Hex-encoded ("#...") values are returned as byte[], not String
                    return value instanceof String ? (String)value : null;
                }
            }
        } catch(InvalidNameException e) {
            log.debug("Unable to parse issuer DN \"{}\": {}", issuerDN, e.getMessage());
        }
        return null;
    }

    /**
     * Collects the DNS-name and IP-address entries from the certificate's Subject Alternative Name extension.
     * <p>
     * The result is cached so renewals can recreate the same hostnames.
     *
     * @param cert the certificate to inspect; may be {@code null}
     * @return the hostnames found, possibly empty, never {@code null}
     */
    public static String[] parseHostNames(X509Certificate cert) {
        List<String> hostNameList = new ArrayList<>();
        if (cert == null) {
            log.warn("Can't parse hostNames, certificate is missing.  Cert renewals will contain default values instead");
            return new String[0];
        }
        try {
            Collection<List<?>> altNames = cert.getSubjectAlternativeNames();
            if (altNames != null) {
                for(List<?> altName : altNames) {
                    // Each entry is [Integer type, value]; both elements are read below
                    if(altName.size() < 2) continue;
                    switch((Integer)altName.get(0)) {
                        case GeneralName.dNSName:
                        case GeneralName.iPAddress:
                            Object data = altName.get(1);
                            if (data instanceof String) {
                                hostNameList.add((String)data);
                            }
                            break;
                        default:
                    }
                }
            } else {
                log.error("getSubjectAlternativeNames is null?");
            }
            log.debug("Parsed hostNames: {}", String.join(", ", hostNameList));
        } catch(CertificateException e) {
            log.warn("Can't parse hostNames from this cert.  Cert renewals will contain default values instead", e);
        }
        return hostNameList.toArray(new String[0]);
    }

    /**
     * Renews a certificate issued by an external provider.
     * <p>
     * Only {@link CertProvider#LETS_ENCRYPT} is implemented.
     *
     * @param externalProvider the issuer of the current certificate
     * @return {@code true} if the certificate was renewed and reloaded
     */
    public boolean renewExternalCert(CertProvider externalProvider) {
        switch(externalProvider) {
            case LETS_ENCRYPT:
                return renewLetsEncryptCert(externalProvider);
            case CA_CERT_ORG:
            default:
                log.error("Cert renewal for {} is not implemented", externalProvider);
        }

        return false;
    }

    /**
     * Runs {@code certbot certonly --standalone --force-renewal} for the cached hostnames.
     * <p>
     * On success, imports the resulting {@code privkey.pem}/{@code fullchain.pem} for the first hostname
     * into a trusted keystore and reloads the SSL context.
     *
     * @return {@code true} on success; {@code false} if there are no hostnames, certbot fails, or an
     *         exception is logged
     */
    private boolean renewLetsEncryptCert(CertProvider externalProvider) {
        if (hostNames.length == 0) {
            log.error("Can't renew {} certificate, no hostNames could be parsed from the existing certificate", externalProvider);
            return false;
        }
        try {
            // NOTE(review): result unused; call kept in case it prepares/validates the "ssl" location - confirm
            CertificateManager.getWritableLocation("ssl");

            // cerbot is much simpler than acme, let's use it
            Path letsEncryptDir = Paths.get(SHARED_DIR.toString(), "ssl", "letsencrypt");
            String config = letsEncryptDir.resolve("config").toString();
            log.info("Attempting to renew {}.  Assuming certs are installed in {}...", externalProvider, config);

            List<String> cmds = new ArrayList<>(Arrays.asList("certbot", "--force-renewal", "certonly"));
            cmds.add("--standalone");

            cmds.add("--config-dir");
            cmds.add(config);

            cmds.add("--logs-dir");
            cmds.add(letsEncryptDir.resolve("logs").toString());

            cmds.add("--work-dir");
            cmds.add(letsEncryptDir.toString());

            // append dns names
            for(String hostName : hostNames) {
                cmds.add("-d");
                cmds.add(hostName);
            }

            if (ShellUtilities.execute(cmds.toArray(new String[0]))) {
                // Assume the cert is stored in a folder called "letsencrypt/config/live/<domain>"
                Path keyPath = Paths.get(config, "live", hostNames[0], "privkey.pem");
                Path certPath = Paths.get(config, "live", hostNames[0], "fullchain.pem"); // fullchain required
                certificateManager.createTrustedKeystore(keyPath.toFile(), certPath.toFile());
                log.info("Files imported, converted and saved.  Reloading SslContextFactory...");
                certificateManager.reloadSslContextFactory();
                log.info("Reloaded SSL successfully.");
                return true;
            } else {
                log.warn("Something went wrong renewing the LetsEncrypt certificate.  Please run the certbot command manually to learn more.");
            }
        } catch(Exception e) {
            log.error("Error renewing/reloading LetsEncrypt cert", e);
        }
        return false;
    }

}
