AI21 Labs + Stable Diffusion Tutorial: A beginner-friendly guide to building an app with AI21 Labs and adding Stable Diffusion integration.
Introduction
Stable Diffusion is a latent text-to-image diffusion model that generates high-resolution images from text prompts by iteratively denoising random noise in a compressed latent space. AI21 Studio is a platform that provides developers and businesses with top-tier natural language processing (NLP) solutions, powered by AI21 Labs' state-of-the-art language models.
What are we going to do?
In this tutorial, we will build a simple, fun app that writes engaging tweets with AI21 Labs and creates cool cover images for them with Stable Diffusion. We will use Streamlit to build the app. Streamlit is an open-source Python library that makes it easy to build beautiful custom web apps for machine learning and data science.
Prerequisites
Go to the Visual Studio Code website and download the version for your operating system, or use any other code editor you like, such as IntelliJ IDEA or PyCharm.
To use AI21 Studio, we need an API key. Go to AI21 Studio and sign up for an account.
Go to the Dashboard once you have created an account.
Then, click your profile picture in the top-right corner.
Select API Key from the dropdown menu.
Click the copy icon to copy your API key. Don't forget to save it somewhere safe.
To use Stable Diffusion, we also need an API key. Go to DreamStudio and sign up for an account.
Go to the Dashboard, then click the + button to create a new API key.
Click the copy icon to copy your API key. Don't forget to save it somewhere safe.
To deploy our app, we need a Streamlit account. Go to Streamlit and click the Sign up button.
I recommend signing up with GitHub, because later it lets us deploy the app in seconds. Click the Sign up with GitHub button.
Getting Started
Create a brand new project
Open Visual Studio Code and, in its integrated terminal, create a new folder named ai21-sd-tutorial:
mkdir ai21-sd-tutorial
cd ai21-sd-tutorial
Create a virtual environment and activate it
Next, we need to create a virtual environment and activate it. To do this, run the following commands in your terminal:
python3 -m venv venv
# on MacOS and Linux:
source venv/bin/activate
# on Windows:
venv\Scripts\activate
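To confirm that the virtual environment is really active before installing anything, you can compare sys.prefix with sys.base_prefix: inside a venv they point to different directories. A minimal check (the function name in_virtualenv is just for illustration):

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the running interpreter belongs to a virtual environment."""
    # A venv sets sys.prefix to the venv directory, while sys.base_prefix
    # still points at the base Python installation it was created from.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("virtualenv active:", in_virtualenv())
```

Run it with the python from your activated environment; it should print True.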
Install all dependencies
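Now, install all dependencies by running the following commands in your terminal. The ai21 package is pinned to 1.x here because the code below uses the module-level ai21.Completion.execute interface, which the 2.x SDK replaced with a client object:
pip install streamlit
pip install "ai21<2"
pip install stability-sdk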
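To check that the three packages actually landed in the virtual environment, a small sketch using the standard library's importlib.metadata (Python 3.8+) can report each installed version. The helper name installed_versions is hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version string, or None if missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in the current environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(["streamlit", "ai21", "stability-sdk"]).items():
        print(f"{name}: {ver or 'NOT INSTALLED'}")
```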
Implementing the app
First, let's create a file named stable_diffusion.py and implement a function that generates an image from text using the Stable Diffusion model. You can read more in the documentation for the text-to-image gRPC API Python client.
import os
import io
import re
import warnings
from PIL import Image
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

# Our Host URL should not be prepended with "https" nor should it have a trailing slash.
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'

# Sign up for an account at the following link to get an API Key.
# https://dreamstudio.ai/
# Click on the following link once you have created an account to be taken to your API Key.
# https://dreamstudio.ai/account
# The key itself is passed to imagine() as an argument (the app reads it from the sidebar).

def imagine(prompt, stable_diffusion_api_key):
    try:
        # Set up our connection to the API.
        stability_api = client.StabilityInference(
            # key=os.environ.get("STABILITY_KEY"),  # Alternative: read the API key from an environment variable.
            key=stable_diffusion_api_key,  # API Key reference.
            verbose=True,  # Print debug messages.
            engine="stable-diffusion-xl-beta-v2-2-2",  # Set the engine to use for generation.
            # Available engines: stable-diffusion-v1 stable-diffusion-v1-5 stable-diffusion-512-v2-0 stable-diffusion-768-v2-0
            # stable-diffusion-512-v2-1 stable-diffusion-768-v2-1 stable-diffusion-xl-beta-v2-2-2 stable-inpainting-v1-0 stable-inpainting-512-v2-0
        )
        # Set up our initial generation parameters.
        answers = stability_api.generate(
            prompt=prompt,  # The prompt to generate from.
            seed=992446758,  # If a seed is provided, the resulting generated image will be deterministic.
            # What this means is that as long as all generation parameters remain the same, you can always recall the same image simply by generating it again.
            # Note: This isn't quite the case for CLIP Guided generations, which we tackle in the CLIP Guidance documentation.
            steps=30,  # Amount of inference steps performed on image generation. Defaults to 30.
            cfg_scale=8.0,  # Influences how strongly your generation is guided to match your prompt.
            # Setting this value higher increases the strength in which it tries to match your prompt.
            # Defaults to 7.0 if not specified.
            width=512,  # Generation width, defaults to 512 if not included.
            height=512,  # Generation height, defaults to 512 if not included.
            samples=1,  # Number of images to generate, defaults to 1 if not included.
            sampler=generation.SAMPLER_K_DPMPP_2M  # Choose which sampler we want to denoise our generation with.
            # Defaults to k_dpmpp_2m if not specified. Clip Guidance only supports ancestral samplers.
            # (Available Samplers: ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_dpmpp_2s_ancestral, k_lms, k_dpmpp_2m, k_dpmpp_sde)
        )
        # Warn in the console if the adult content classifier is tripped and skip that artifact.
        # Otherwise, save the generated image and return its path.
        for resp in answers:
            for artifact in resp.artifacts:
                if artifact.finish_reason == generation.FILTER:
                    warnings.warn(
                        "Your request activated the API's safety filters and could not be processed. "
                        "Please modify the prompt and try again.")
                    continue
                if artifact.type == generation.ARTIFACT_IMAGE:
                    img = Image.open(io.BytesIO(artifact.binary))
                    # Name the file after the first 10 characters of the prompt, replacing
                    # characters that are unsafe in filenames (spaces, dots, slashes, ...).
                    img_path = (re.sub(r"[^A-Za-z0-9_-]", "_", prompt[:10]) or "image") + '.png'
                    img.save(img_path)
                    return img_path
    except Exception as e:
        print(e)
    # No image was produced (the safety filter was tripped or an error occurred).
    return ""
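imagine() names the saved file after the first 10 characters of the prompt. Prompts can easily contain characters such as / or : that are not valid in filenames, so it helps to sanitize them first. Here is a standalone sketch of that step; the helper name prompt_to_filename and its max_len and suffix parameters are hypothetical:

```python
import re


def prompt_to_filename(prompt: str, max_len: int = 10, suffix: str = ".png") -> str:
    """Turn the start of a prompt into a filesystem-safe image filename."""
    # Keep letters, digits, underscores and hyphens from the first max_len characters;
    # every other character (spaces, dots, slashes, ...) becomes an underscore.
    stem = re.sub(r"[^A-Za-z0-9_-]", "_", prompt[:max_len])
    # Fall back to a fixed name when the prompt is empty.
    return (stem or "image") + suffix


if __name__ == "__main__":
    print(prompt_to_filename("a cat/dog. in space"))  # a_cat_dog_.png
```

Truncating before substituting keeps the name short and predictable; note that two prompts with the same first 10 characters will overwrite each other's image.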
Next, let's create a file named ai21_studio.py and implement a function that generates cool ideas for engaging tweets to promote a brand or product:
import ai21

def generate(prompt, ai21_api_key):
    ai21.api_key = ai21_api_key
    # Few-shot prompt: example events and tweets separated by "##".
    # The user's event is appended after the final "Event:" and followed by an open "Tweet:" line.
    _prompt = """
Write a tweet to promote an event for Shmolt, a food and more delivery company.
##
Event: Christmas.
Tweet: Oven broke in the last minute and you have no turkey? No need to panic, with Shmolt delivery it will be at your door step by Christmas eve!
##
Event: Valentine's Day.
Tweet: Order her roses using Shmolt, so you can have time to stop and smell the roses #Valentines
##
Event: Summer Sale.
Tweet: In this heat, why get out of the AC when you have 10% off of all Shmolt deliveries?
##
Event: """
    response = ai21.Completion.execute(
        model="j2-mid",  # The model to use for generation.
        prompt=_prompt + prompt + "\nTweet:",
        numResults=1,
        maxTokens=64,
        temperature=0.84,
        topKReturn=0,
        topP=1,
        countPenalty={
            "scale": 0.1,
            "applyToNumbers": False,
            "applyToPunctuations": False,
            "applyToStopwords": False,
            "applyToWhitespaces": False,
            "applyToEmojis": False
        },
        frequencyPenalty={
            "scale": 0,
            "applyToNumbers": False,
            "applyToPunctuations": False,
            "applyToStopwords": False,
            "applyToWhitespaces": False,
            "applyToEmojis": False
        },
        presencePenalty={
            "scale": 0,
            "applyToNumbers": False,
            "applyToPunctuations": False,
            "applyToStopwords": False,
            "applyToWhitespaces": False,
            "applyToEmojis": False
        },
        stopSequences=["##"]  # Stop before the model starts writing another example.
    )
    return response['completions'][0]['data']['text'].strip()
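generate() relies on a few-shot prompt: an instruction, then example Event:/Tweet: pairs separated by ##, and stopSequences=["##"] ends generation before the model starts inventing another example. The sketch below builds that same structure from a list of pairs and pulls the text out of a response shaped like response['completions'][0]['data']['text']. The helper names and the sample data are illustrative only; they are not part of the AI21 SDK:

```python
def build_few_shot_prompt(instruction, examples, event):
    """Join an instruction, (event, tweet) examples and a new event into one prompt."""
    parts = [instruction]
    for example_event, tweet in examples:
        parts.append(f"Event: {example_event}\nTweet: {tweet}")
    # The last block leaves "Tweet:" open so the model completes it.
    parts.append(f"Event: {event}\nTweet:")
    # "##" separates the blocks; it doubles as the stop sequence.
    return "\n##\n".join(parts)


def completion_text(response):
    """Extract the first completion's text from a dict shaped like AI21's response."""
    return response["completions"][0]["data"]["text"].strip()


if __name__ == "__main__":
    print(build_few_shot_prompt(
        "Write a tweet to promote an event for Shmolt, a food and more delivery company.",
        [("Summer Sale.", "In this heat, why get out of the AC when you have 10% off of all Shmolt deliveries?")],
        "Mother's Day.",
    ))
```

Keeping the examples in a list makes it easy to swap in a different brand without editing a long string literal.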
Perfect! Now we have all the functions we need to generate engaging ideas and cool cover images.
Finally, create a file named app.py. Here we will implement the logic of our Streamlit app. Streamlit lets us build web apps quickly in pure Python and share them with everyone.
Let's import all the necessary libraries and functions:
# Import from standard library
import os
import logging
# Import from 3rd party libraries
import streamlit as st
# Import modules from the local package
import stable_diffusion, ai21_studio
Basic configuration for the app:
# Configure logger
logging.basicConfig(format="\n%(asctime)s\n%(message)s", level=logging.INFO, force=True)
# Configure Streamlit page and state
st.set_page_config(page_title="AI21 Studio Tutorial")
Initialize the app states:
# Store the initial value of widgets in session state
if "imagine" not in st.session_state:
    st.session_state.imagine = ""
if "img_path" not in st.session_state:
    st.session_state.img_path = ""
if "query" not in st.session_state:
    st.session_state.query = ""
if "im_query" not in st.session_state:
    st.session_state.im_query = ""
if "prompt_generate" not in st.session_state:
    st.session_state.prompt_generate = ""
if "text_error" not in st.session_state:
    st.session_state.text_error = ""
if "stable_diffusion_api_key" not in st.session_state:
    st.session_state.stable_diffusion_api_key = ""
if "ai21_api_key" not in st.session_state:
    st.session_state.ai21_api_key = ""
if "visibility" not in st.session_state:
    st.session_state.visibility = "visible"
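The chain of if statements above follows one pattern: set a key only if it is missing. A more compact sketch keeps the defaults in a dict and loops over them. DEFAULTS and init_defaults are hypothetical names; the function accepts any mutable mapping, so it can be tested with a plain dict:

```python
from collections.abc import MutableMapping

# Default values for every session-state key used by the app (copied from the code above).
DEFAULTS = {
    "imagine": "",
    "img_path": "",
    "query": "",
    "im_query": "",
    "prompt_generate": "",
    "text_error": "",
    "stable_diffusion_api_key": "",
    "ai21_api_key": "",
    "visibility": "visible",
}


def init_defaults(state: MutableMapping, defaults: dict) -> None:
    """Set each default only if the key is not already present, preserving existing values."""
    for key, value in defaults.items():
        if key not in state:
            state[key] = value
```

In app.py this would be called as init_defaults(st.session_state, DEFAULTS), since st.session_state supports the dict-style `in` check and item assignment used here.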
Work on the app layout:
# Force responsive layout for columns also on mobile
st.write(
"""
<style>
[data-testid="column"] {
width: calc(50% - 1rem);
flex: 1 1 calc(50% - 1rem);
min-width: calc(50% - 1rem);
}
</style>
""",
unsafe_allow_html=True,
)
Let's create a sidebar where users can enter their API keys. Why not simply use a .env file? Because we want anyone to be able to use the app with their own keys, without any hassle. Later, we can also deploy it to Streamlit Sharing Cloud.
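If you run the app for yourself, you may still want an environment-variable fallback so you don't have to paste the keys on every visit. A minimal sketch, assuming hypothetical variable names AI21_API_KEY and STABILITY_KEY (the latter mirrors the commented-out line in stable_diffusion.py): prefer whatever the user typed in the sidebar, otherwise read the environment.

```python
import os


def resolve_api_key(user_value, env_var: str) -> str:
    """Prefer the key typed into the sidebar; otherwise fall back to an environment variable."""
    user_value = (user_value or "").strip()
    if user_value:
        return user_value
    # os.environ.get returns the default ("") when the variable is not set.
    return os.environ.get(env_var, "").strip()


# Hypothetical usage inside the sidebar:
# st.session_state.ai21_api_key = resolve_api_key(
#     st.text_input('AI21 Studio API Key', type="password"), "AI21_API_KEY")
```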
# Render Streamlit page
with st.sidebar:
    # type="password" masks the keys while they are typed.
    st.session_state.ai21_api_key = st.text_input('AI21 Studio API Key', type="password")
    st.session_state.stable_diffusion_api_key = st.text_input('Stable Diffusion API Key', type="password")
Quickly add a cool title and description for the app:
# title of the app
st.title("AI21 Studio + Stable Diffusion Tutorial")
st.markdown(
    "This is a demo of the AI21 Studio + Stable Diffusion Tutorial app."
)
Now, let's add a st.text_area() widget to get the prompt from the user and a st.button() widget to generate the ideas based on the prompt.
# textarea
st.session_state.query = st.text_area(
    label="Generate engaging ideas for tweets",
    placeholder="Elon Musk wants to fight with Mu", height=100)
# button
st.button(
    label="Generate ideas",
    help="Click to generate ideas",
    key="generate_prompt",
    type="primary",
    on_click=generate_ideas,
)
Now, let's add another st.text_area() widget for image generation with Stable Diffusion. You may ask why we don't generate the image right away, together with the ideas, when the Generate ideas button is clicked. The reason is that we want to make the app more interactive and fun.
# textarea
st.session_state.im_query = st.text_area(label="Cool description for the tweet cover", placeholder="Elon Musk wants to fight with Mu", height=100)
# button
st.button(
    label="Generate image",
    help="Click to generate image",
    key="generate_image",
    type="primary",
    on_click=imagine,
)
Output the generated ideas and image:
text_spinner_placeholder = st.empty()
if st.session_state.text_error:
    st.error(st.session_state.text_error)
if st.session_state.prompt_generate:
    st.markdown("""---""")
    st.text_area(label="Cool ideas", value=st.session_state.prompt_generate)
if st.session_state.img_path:
    st.markdown("""---""")
    st.subheader("Cool image")
    st.image(st.session_state.img_path, use_column_width=True, caption="Image generated by Stable Diffusion", output_format="PNG")
And last but not least, it's time to implement the callback functions, starting with the one that generates ideas. Put these functions above the st.button() calls in app.py: on_click receives the function object itself, so the function must already be defined when the button is created, or Python raises a NameError.
def generate_ideas():
    st.session_state.text_error = ""
    if st.session_state.ai21_api_key == "":
        st.session_state.text_error = "Missing API key."
        return
    if st.session_state.query == "":
        st.session_state.text_error = "Missing a query."
        return
    st.session_state.prompt_generate = ""
    with text_spinner_placeholder:
        with st.spinner("Please wait while we process your query..."):
            prompt = ai21_studio.generate(prompt=st.session_state.query, ai21_api_key=st.session_state.ai21_api_key)
            if not prompt:
                st.session_state.text_error = "The model returned an empty response. Please modify the prompt and try again."
                logging.info(f"Text Error: {st.session_state.text_error}")
                return
            st.session_state.prompt_generate = prompt
And here is the function that generates the image:
def imagine():
    st.session_state.text_error = ""
    if st.session_state.stable_diffusion_api_key == "":
        st.session_state.text_error = "Missing API key."
        return
    if st.session_state.im_query == "":
        st.session_state.text_error = "Missing a query."
        return
    st.session_state.imagine = ""
    st.session_state.img_path = ""
    with text_spinner_placeholder:
        with st.spinner("Please wait while we generate your image..."):
            im_path = stable_diffusion.imagine(prompt=st.session_state.im_query, stable_diffusion_api_key=st.session_state.stable_diffusion_api_key)
            if not im_path:
                st.session_state.text_error = "Image generation failed. Your request may have activated the API's safety filters; please modify the prompt and try again."
                logging.info(f"Text Error: {st.session_state.text_error}")
                return
            st.session_state.img_path = im_path
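Both callbacks start with the same checks: is there an API key, and is there a query? Pulling those checks into a plain function (validate_request is a hypothetical name) keeps the callbacks short and lets you test the logic without running Streamlit:

```python
def validate_request(api_key: str, query: str) -> str:
    """Return an error message for missing inputs, or an empty string if everything is present."""
    # Treat whitespace-only input as missing, too.
    if not api_key.strip():
        return "Missing API key."
    if not query.strip():
        return "Missing a query."
    return ""


# Possible use at the top of a callback:
# error = validate_request(st.session_state.ai21_api_key, st.session_state.query)
# if error:
#     st.session_state.text_error = error
#     return
```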
Perfect! Now, we can run our app and see the result.
streamlit run app.py
Deploy to Streamlit Sharing Cloud
Check out the ElevenLabs tutorial for detailed instructions on how to deploy your app to Streamlit Sharing Cloud.
Conclusion
We built something cool today! We learned how to use the AI21 Studio API to generate ideas for tweets and the Stable Diffusion API to generate cover images for them, and we saw where to find instructions for deploying the app to Streamlit Sharing Cloud. Don't forget to join upcoming hackathons on lablab.ai and win cool prizes.
Thank you for following along with this tutorial.
If you have any questions, feel free to reach out to me on LinkedIn or Twitter. I'd love to hear from you!
Made by abdibrokhim for lablab.ai tutorials.