Skip to content

Commit 5376857

Browse files
authored
Add files via upload
1 parent 772cd80 commit 5376857

1 file changed

Lines changed: 140 additions & 0 deletions

File tree

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# -*- coding: utf-8 -*-
2+
"""Advanced_LitServe_Multi_Endpoint_API_Tutorial_Marktechpost.ipynb
3+
4+
Automatically generated by Colab.
5+
6+
Original file is located at
7+
https://colab.research.google.com/drive/1_1ItSNeswq8VMclnmb2B_-WXssmTRzsC
8+
"""
9+
10+
!pip install litserve torch transformers -q
11+
12+
import litserve as ls
13+
import torch
14+
from transformers import pipeline
15+
import time
16+
from typing import List
17+
18+
class TextGeneratorAPI(ls.LitAPI):
    """Single-request text-generation endpoint backed by distilgpt2.

    Fix vs. original: generation used ``max_length=100``, which counts the
    prompt tokens as part of the budget — a prompt near or over 100 tokens
    would yield little or no new text (and a warning from transformers).
    ``max_new_tokens`` bounds only the generated continuation, which is what
    was intended here.
    """

    def setup(self, device):
        # Pipeline device index: 0 = first CUDA GPU, -1 = CPU. Only use the
        # GPU when the caller asked for "cuda" AND one is actually available.
        use_gpu = device == "cuda" and torch.cuda.is_available()
        self.model = pipeline(
            "text-generation",
            model="distilgpt2",
            device=0 if use_gpu else -1,
        )
        self.device = device

    def decode_request(self, request):
        # Request body is expected to be {"prompt": <str>}; a missing key
        # raises KeyError, surfacing as a client error.
        return request["prompt"]

    def predict(self, prompt):
        # max_new_tokens bounds only the continuation (see class docstring);
        # sampling at temperature 0.8 gives varied but coherent output.
        result = self.model(
            prompt,
            max_new_tokens=100,
            num_return_sequences=1,
            temperature=0.8,
            do_sample=True,
        )
        return result[0]["generated_text"]

    def encode_response(self, output):
        return {"generated_text": output, "model": "distilgpt2"}
class BatchedSentimentAPI(ls.LitAPI):
    """Sentiment endpoint that scores several requests in one model call."""

    def setup(self, device):
        # 0 selects the first CUDA device, -1 forces CPU inference.
        gpu_ok = device == "cuda" and torch.cuda.is_available()
        self.model = pipeline(
            "sentiment-analysis",
            model="distilbert-base-uncased-finetuned-sst-2-english",
            device=0 if gpu_ok else -1,
        )

    def decode_request(self, request):
        return request["text"]

    def batch(self, inputs: List[str]) -> List[str]:
        # The pipeline accepts a plain list of strings, so batching is a
        # pass-through.
        return inputs

    def predict(self, batch: List[str]):
        # One forward pass over the whole batch.
        return self.model(batch)

    def unbatch(self, output):
        # The pipeline already returns one result dict per input, in order.
        return output

    def encode_response(self, output):
        return {
            "label": output["label"],
            "score": float(output["score"]),
            "batched": True,
        }
class StreamingTextAPI(ls.LitAPI):
    """Demo streaming endpoint: emits a fixed word sequence token by token."""

    def setup(self, device):
        on_gpu = device == "cuda" and torch.cuda.is_available()
        self.model = pipeline(
            "text-generation",
            model="distilgpt2",
            device=0 if on_gpu else -1,
        )

    def decode_request(self, request):
        return request["prompt"]

    def predict(self, prompt):
        # NOTE: the loaded model is not used here; the canned word list
        # stands in for real streamed generation in this demo.
        canned = ("Once", "upon", "a", "time", "in", "a", "digital", "world")
        for word in canned:
            time.sleep(0.1)  # simulate per-token generation latency
            yield word + " "

    def encode_response(self, output):
        # Wrap each streamed token in a JSON-serializable dict.
        for piece in output:
            yield {"token": piece}
class MultiTaskAPI(ls.LitAPI):
    """Single endpoint that routes between sentiment and summarization."""

    def setup(self, device):
        # Both pipelines are pinned to CPU (-1) regardless of the requested
        # device; the requested device is recorded but not otherwise used.
        self.sentiment = pipeline("sentiment-analysis", device=-1)
        self.summarizer = pipeline(
            "summarization", model="sshleifer/distilbart-cnn-6-6", device=-1
        )
        self.device = device

    def decode_request(self, request):
        # "task" defaults to sentiment; "text" is required (KeyError if absent).
        return {"task": request.get("task", "sentiment"), "text": request["text"]}

    def predict(self, inputs):
        task, text = inputs["task"], inputs["text"]
        if task == "sentiment":
            return {"task": "sentiment", "result": self.sentiment(text)[0]}
        if task == "summarize":
            # Inputs shorter than 30 words are echoed back verbatim instead
            # of being run through the summarizer.
            if len(text.split()) < 30:
                return {"task": "summarize", "result": {"summary_text": text}}
            summary = self.summarizer(text, max_length=50, min_length=10)[0]
            return {"task": "summarize", "result": summary}
        return {"task": "unknown", "error": "Unsupported task"}

    def encode_response(self, output):
        # predict() already builds a JSON-serializable dict.
        return output
class CachedAPI(ls.LitAPI):
    """Sentiment endpoint with an in-memory, per-process response cache."""

    def setup(self, device):
        self.model = pipeline("sentiment-analysis", device=-1)
        self.cache = {}   # exact request text -> pipeline result dict
        self.hits = 0     # number of cache hits served
        self.misses = 0   # number of cache misses (model invocations)

    def decode_request(self, request):
        return request["text"]

    def predict(self, text):
        # Returns (result, from_cache). NOTE(review): the cache is
        # unbounded — fine for a demo, but a real service should cap it.
        cached = self.cache.get(text)
        if cached is not None:
            self.hits += 1
            return cached, True
        self.misses += 1
        result = self.model(text)[0]
        self.cache[text] = result
        return result, False

    def encode_response(self, output):
        result, from_cache = output
        return {
            "label": result["label"],
            "score": float(result["score"]),
            "from_cache": from_cache,
            "cache_stats": {"hits": self.hits, "misses": self.misses},
        }
def test_apis_locally():
    """Exercise each API class directly (decode -> predict -> encode),
    without starting a LitServe server."""
    banner = "=" * 70
    print(banner)
    print("Testing APIs Locally (No Server)")
    print(banner)

    # 1. Plain text generation.
    gen_api = TextGeneratorAPI()
    gen_api.setup("cpu")
    prompt = gen_api.decode_request({"prompt": "Artificial intelligence will"})
    generated = gen_api.predict(prompt)
    payload = gen_api.encode_response(generated)
    print(f"✓ Result: {payload['generated_text'][:100]}...")

    # 2. Batched sentiment over three texts in one model call.
    sent_api = BatchedSentimentAPI()
    sent_api.setup("cpu")
    texts = ["I love Python!", "This is terrible.", "Neutral statement."]
    decoded = [sent_api.decode_request({"text": t}) for t in texts]
    outputs = sent_api.unbatch(sent_api.predict(sent_api.batch(decoded)))
    for text, raw in zip(texts, outputs):
        resp = sent_api.encode_response(raw)
        print(f"✓ '{text}' -> {resp['label']} ({resp['score']:.2f})")

    # 3. Multi-task routing (sentiment branch only).
    multi_api = MultiTaskAPI()
    multi_api.setup("cpu")
    req = multi_api.decode_request({"task": "sentiment", "text": "Amazing tutorial!"})
    result = multi_api.predict(req)
    print(f"✓ Sentiment: {result['result']}")

    # 4. Cached sentiment: first call misses, the next two hit the cache.
    cached_api = CachedAPI()
    cached_api.setup("cpu")
    test_text = "LitServe is awesome!"
    for attempt in range(3):
        decoded_text = cached_api.decode_request({"text": test_text})
        resp = cached_api.encode_response(cached_api.predict(decoded_text))
        print(f"✓ Request {attempt+1}: {resp['label']} (cached: {resp['from_cache']})")

    print(banner)
    print("✅ All tests completed successfully!")
    print(banner)
# Run the smoke tests when executed as a script (or top-to-bottom in a
# notebook, where __name__ is "__main__"); guarded so that importing this
# module does not trigger model downloads and inference as a side effect.
if __name__ == "__main__":
    test_apis_locally()

0 commit comments

Comments
 (0)