"""Mail List Shield Flask application factory and extensions.
This module provides the Flask application factory function and initializes
Flask extensions used throughout the application.
"""
from flask import Flask, request, redirect
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_login import LoginManager
from flask_bcrypt import Bcrypt
from flask_mail import Mail
from flask_sslify import SSLify
from flask_cors import CORS
from flask_wtf.csrf import CSRFProtect
from decouple import UndefinedValueError
from datetime import datetime
from oauthlib.oauth2 import WebApplicationClient
import stripe
import sys
from app.config import appTimezone
# Initialize the Flask extensions without attaching them to the app
[docs]
def create_app(config_class="app.config.Config", test_config=False):
"""Application factory to create a Flask application instance
with the specified configuration.
Args:
config_class: Reference to the configuration object to use.
test_config: If True, apply the test configuration.
Returns:
Flask: The Flask application instance.
"""
# Initialize the flask app
app = Flask(__name__)
# Load the Config object
try:
app.config.from_object(config_class)
except UndefinedValueError as e:
print(
"Please make sure you have all the required environment variables set.", e
)
sys.exit()
# If specified, load the test configuration to override the Config object
if test_config:
# Disable Postgres pooling when SQLite is used
app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {}
# Initialize the Flask extensions for the app instance
mail.init_app(app)
db.init_app(app)
bc.init_app(app)
lm.init_app(app)
csrf.init_app(app)
SSLify(app)
# Enable CORS for all routes
CORS(
app,
resources={r"/*": {"origins": [app.config["APP_ROOT_URL"]]}},
supports_credentials=False,
)
# Set up the database tables if they don't exist
try:
with app.app_context():
# Import the table models
from app.models import Users, Tiers, BatchJobs, APIKeys
# Create the database tables if they don't exist
db.create_all()
# If there are no tiers in the database, create the default tiers
if Tiers.query.count() == 0:
default_tier = Tiers(name="free", label="Free")
db.session.add(default_tier)
db.session.commit()
print(" * Default tiers created successfully.")
except Exception as e:
print("Error creating the database tables.")
print(e)
sys.exit()
# Import the Blueprints
from app.views import public_bp
from app.views_private import private_bp
from app.views_auth import auth_bp
from app.views_api import api_bp
# Register the Blueprints
app.register_blueprint(public_bp, url_prefix="/")
app.register_blueprint(private_bp, url_prefix="/app")
app.register_blueprint(auth_bp, url_prefix="/")
app.register_blueprint(api_bp, url_prefix="/api")
# Variables available in all templates
@app.context_processor
def inject_globals():
"""Inject global variables into all templates.
Returns:
dict: Dictionary of variables available in all templates.
"""
return {
"APP_NAME": app.config["APP_NAME"],
"COPYRIGHT": f"2021–{datetime.now().year} - {app.config['APP_NAME']}",
}
# Redirect pages with trailing slashes to versions without
# Applies to other Blueprints like app_private as well
@app.before_request
def remove_trailing_slash():
"""Remove trailing slashes from URLs and redirect.
Returns:
Response: Redirect response if URL has trailing slash, None otherwise.
"""
if (
request.path != "/"
and request.path != "/app/"
and request.path.endswith("/")
):
return redirect(request.path[:-1])
# Register a Jinja2 filter for date formatting
@app.template_filter("dateformat")
def dateformat_filter(value, format="%B %d, %Y"):
"""Format a Unix timestamp as a date string.
This is used to format the date in the Jinja templates.
Args:
value: Unix timestamp to format.
format: strftime format string.
Returns:
str: Formatted date string.
"""
return datetime.fromtimestamp(value).strftime(format)
# Register a Jinja2 filter for time formatting
@app.template_filter("timeformat")
def timeformat_filter(value, format="%I:%M %p"):
"""Format a Unix timestamp as a time string.
This is used to format the time in the Jinja templates.
Args:
value: Unix timestamp to format.
format: strftime format string.
Returns:
str: Formatted time string.
"""
return datetime.fromtimestamp(value).strftime(format)
# Register a Jinja2 filter for date formatting for database dates
@app.template_filter("dbDateformat")
def dbDateformat_filter(value, format="%B %d, %Y"):
"""Format a database datetime as a date string in the app timezone.
Args:
value: datetime object from the database.
format: strftime format string.
Returns:
str: Formatted date string.
"""
return value.astimezone(appTimezone).strftime(format)
# Register a Jinja2 filter for time formatting for database dates
@app.template_filter("dbTimeformat")
def dbTimeformat_filter(value, format="%I:%M %p"):
"""Format a database datetime as a time string in the app timezone.
Args:
value: datetime object from the database.
format: strftime format string.
Returns:
str: Formatted time string.
"""
return value.astimezone(appTimezone).strftime(format)
# Register a Jinja2 filter for formatting numbers with thousand separators
@app.template_filter("thousandSeparator")
def thousandSeparator_filter(value):
"""Format a number with thousand separators.
Args:
value: Number to format.
Returns:
str: Number formatted with commas as thousand separators.
"""
return "{:,}".format(value)
# Prettify batch validation job status
@app.template_filter("prettifyJobStatus")
def prettifyJobStatus_filter(job):
"""Convert a batch job status from the database into a nicer looking string.
Args:
job: BatchJob object with a status attribute.
Returns:
str: Human-readable status string.
"""
match job.status:
case "pending_start":
return "Pending Start"
case "file_accepted":
return "File Accepted"
case "file_queued":
return "File Queued"
case "file_validation_in_progress":
return "Processing"
case "file_completed":
return "Completed"
case s if s.startswith("error"):
return "Failed"
case _:
return job.status
# Google OAuth 2 client configuration
app.google_client = WebApplicationClient(app.config["GOOGLE_CLIENT_ID"])
# Stripe setup
stripe.api_key = app.config["STRIPE_SECRET_KEY"]
# Return the app instance created
return app