AI21 Tutorial: How to add memory to an AI21 Labs model using LangChain
Introduction
You have probably already worked with models from AI21 Labs. I enjoy working with them, but the big pain point has always been having to manually record the history of my interactions with the model. That has changed: LangChain makes it possible to add this kind of memory in the blink of an eye!
In this tutorial I will explain how to implement this quickly, so you can then play with it yourself!
Implementation
Dependencies
First, we need to create a project directory and a virtual environment, and install some dependencies. Let's do it!
mkdir ai21-memory
cd ai21-memory
Now we need to create a virtual environment and activate it:
python3 -m venv venv
# Linux/MacOS
source venv/bin/activate
# Windows
venv\Scripts\activate.bat
The last step of this stage is installing the dependencies:
pip install langchain ai21 python-dotenv
Coding time!
Now we can start coding! First, we need to create a .env file with our API key. You can get it from AI21 Labs Studio. I will use AI21_API_KEY as the variable name.
AI21_API_KEY=<YOUR_API_KEY>
Now we can create the main.py file and start coding!
I will import all the necessary modules and load the API key.
import os

from dotenv import load_dotenv
from langchain import LLMChain, PromptTemplate
from langchain.llms import AI21
from langchain.memory import ConversationBufferMemory

# Read variables from the .env file into the environment
load_dotenv()
AI21_API_KEY = os.getenv("AI21_API_KEY")
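Since os.getenv returns None when a variable is missing, a small guard can make a forgotten .env entry obvious before any request is sent. The sketch below is only one way to do this; the helper name require_env is hypothetical and not part of LangChain or python-dotenv.

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable or raise a clear error."""
    value = os.getenv(name)
    if not value:
        # Covers both a missing variable and an empty string
        raise RuntimeError(f"{name} is not set; add it to your .env file")
    return value


# Possible usage in main.py, after load_dotenv():
# AI21_API_KEY = require_env("AI21_API_KEY")
```

With this in place, a missing key fails at startup with a readable message instead of surfacing later as an authentication error.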
I decided to build a simple chatbot for this tutorial that runs in a terminal. Feel free to create any other application; I just want to show you how it can be done. Let's get started!
I will create a prompt template that makes it easier for the model to understand our task. This is necessary because the model I will be using is, by default, a regular LLM rather than a chat model. I will also create a memory object that stores our conversation history.
template = """Conduct conversations, respond based on history.\n\nHistory:\n---\n{chat_history}\n---\nCurrent input:\n---\nHuman: {input}\n\nResponse\n---\nAI: """
prompt = PromptTemplate(
    input_variables=["input", "chat_history"],
    template=template,
)
memory = ConversationBufferMemory(memory_key="chat_history")
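To see roughly what the model receives on each turn, the sketch below fills the same template with plain str.format, which should be close to the substitution the prompt template performs. The sample history is made up for illustration, and its "Human: ... / AI: ..." layout is an assumption about how the buffer memory renders past turns.

```python
template = """Conduct conversations, respond based on history.\n\nHistory:\n---\n{chat_history}\n---\nCurrent input:\n---\nHuman: {input}\n\nResponse\n---\nAI: """

# Made-up history, assumed to follow the "Human: ... / AI: ..." layout
chat_history = "Human: Hi, my name is Jakub.\nAI: Nice to meet you, Jakub!"

# Fill both placeholders the same way the chain would on the second turn
rendered = template.format(chat_history=chat_history, input="What is my name?")
print(rendered)
```

The prompt ends with "AI: " so that the model continues the text with its own reply, which is why a regular LLM can act like a chat partner here.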
I will use ConversationBufferMemory, but you can use another type of memory. The other types are explained well in this article about conversation memory types. Check it out!
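To illustrate how another memory type might differ, here is a minimal sketch of a sliding-window memory that keeps only the last k exchanges. LangChain ships its own ConversationBufferWindowMemory for this purpose; the WindowMemory class below is a hypothetical plain-Python stand-in and not LangChain's implementation.

```python
from collections import deque


class WindowMemory:
    """Hypothetical stand-in: remembers only the last k exchanges."""

    def __init__(self, k: int = 2):
        # A bounded deque drops the oldest exchange once k is exceeded
        self.turns = deque(maxlen=k)

    def save(self, human: str, ai: str) -> None:
        self.turns.append((human, ai))

    def as_text(self) -> str:
        # Render the kept exchanges in a "Human: ... / AI: ..." layout
        return "\n".join(f"Human: {h}\nAI: {a}" for h, a in self.turns)


memory_demo = WindowMemory(k=2)
memory_demo.save("Hi", "Hello!")
memory_demo.save("I like tea", "Noted.")
memory_demo.save("And coffee", "Also noted.")
print(memory_demo.as_text())  # the first exchange is no longer included
```

A window like this keeps the prompt short in long conversations, at the cost of forgetting anything said more than k exchanges ago.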
Now we can create an LLMChain that will take care of our conversation.
llm_chain = LLMChain(
    llm=AI21(model="j2-ultra", ai21_api_key=AI21_API_KEY),
    prompt=prompt,
    verbose=True,
    memory=memory,
)
I use verbose=True to inspect the input that is passed from the memory to the chain.
Great, now we can create a main loop that iterates through the turns of our conversation.
while True:
    inp = input("You: ")
    res = llm_chain.predict(input=inp)
    print(f"AI: {res}")
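The loop above runs until the process is interrupted. As a possible refinement, the sketch below wraps the same loop in a function that stops on "exit" or "quit" and skips empty lines. The chat_loop name and its read and write parameters are hypothetical; they exist only so the loop can be tried without a real model.

```python
from typing import Callable


def chat_loop(
    predict: Callable[[str], str],
    read: Callable[[str], str] = input,
    write: Callable[[str], None] = print,
) -> None:
    """Run the chat until the user types 'exit' or 'quit', or input ends."""
    while True:
        try:
            inp = read("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            break  # Ctrl+D or Ctrl+C ends the chat quietly
        if inp.lower() in {"exit", "quit"}:
            break
        if not inp:
            continue  # skip empty lines instead of sending them to the model
        write(f"AI: {predict(inp)}")


# Possible usage with the chain from this tutorial:
# chat_loop(lambda text: llm_chain.predict(input=text))
```

Passing the model call in as a function keeps the loop independent of LangChain, so it can be exercised with a simple stand-in such as a lambda.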
Let's check out the results!
Results
Let's run our application!
python main.py
Conversation
I ran a short conversation to test our program. As you can see, the history is already working in the second turn, as it should be!
As you can see, implementing this is not that difficult and does not take long, and with the combination of AI21 and LangChain we are able to build really powerful applications. I encourage you to check it out during our Plug into AI with AI21 Hackathon! Feel free to join and build with us!
Thank you for your time! - Jakub Misiło @newnative