1+ # -*- coding: utf-8 -*-
2+ """Advanced_LitServe_Multi_Endpoint_API_Tutorial_Marktechpost.ipynb
3+
4+ Automatically generated by Colab.
5+
6+ Original file is located at
7+ https://colab.research.google.com/drive/1_1ItSNeswq8VMclnmb2B_-WXssmTRzsC
8+ """
9+
!pip install litserve torch transformers -q
11+
12+ import litserve as ls
13+ import torch
14+ from transformers import pipeline
15+ import time
16+ from typing import List
17+
class TextGeneratorAPI(ls.LitAPI):
    """Single-prompt text-generation endpoint backed by distilgpt2."""

    def setup(self, device):
        # GPU 0 only when CUDA was requested and is actually present;
        # transformers' pipeline treats device=-1 as CPU.
        use_gpu = device == "cuda" and torch.cuda.is_available()
        self.model = pipeline(
            "text-generation",
            model="distilgpt2",
            device=0 if use_gpu else -1,
        )
        self.device = device

    def decode_request(self, request):
        # Request body is expected to carry a "prompt" field.
        return request["prompt"]

    def predict(self, prompt):
        outputs = self.model(
            prompt,
            max_length=100,
            num_return_sequences=1,
            temperature=0.8,
            do_sample=True,
        )
        return outputs[0]["generated_text"]

    def encode_response(self, output):
        return {"generated_text": output, "model": "distilgpt2"}
29+
class BatchedSentimentAPI(ls.LitAPI):
    """Sentiment-analysis endpoint that lets LitServe batch concurrent requests."""

    def setup(self, device):
        on_gpu = device == "cuda" and torch.cuda.is_available()
        self.model = pipeline(
            "sentiment-analysis",
            model="distilbert-base-uncased-finetuned-sst-2-english",
            device=0 if on_gpu else -1,
        )

    def decode_request(self, request):
        return request["text"]

    def batch(self, inputs: List[str]) -> List[str]:
        # The HF pipeline accepts a plain list, so no extra collation is needed.
        return inputs

    def predict(self, batch: List[str]):
        # Single forward pass over the whole batch.
        return self.model(batch)

    def unbatch(self, output):
        # The pipeline already returns one dict per input.
        return output

    def encode_response(self, output):
        return {
            "label": output["label"],
            "score": float(output["score"]),
            "batched": True,
        }
44+
class StreamingTextAPI(ls.LitAPI):
    """Demonstrates token-by-token streaming responses.

    predict() deliberately yields a canned word sequence (with an artificial
    delay) instead of sampling from the loaded model, so the streaming
    mechanics are easy to observe.
    """

    def setup(self, device):
        gpu_ok = device == "cuda" and torch.cuda.is_available()
        self.model = pipeline(
            "text-generation",
            model="distilgpt2",
            device=0 if gpu_ok else -1,
        )

    def decode_request(self, request):
        return request["prompt"]

    def predict(self, prompt):
        story = ("Once", "upon", "a", "time", "in", "a", "digital", "world")
        for token in story:
            time.sleep(0.1)  # simulate per-token generation latency
            yield token + " "

    def encode_response(self, output):
        for piece in output:
            yield {"token": piece}
58+
class MultiTaskAPI(ls.LitAPI):
    """Routes a single endpoint to multiple models via a "task" request field."""

    def setup(self, device):
        # Both pipelines are pinned to CPU (device=-1) regardless of the
        # requested device.
        self.sentiment = pipeline("sentiment-analysis", device=-1)
        self.summarizer = pipeline(
            "summarization", model="sshleifer/distilbart-cnn-6-6", device=-1
        )
        self.device = device

    def decode_request(self, request):
        # Default to sentiment analysis when no task is specified.
        return {"task": request.get("task", "sentiment"), "text": request["text"]}

    def predict(self, inputs):
        task, text = inputs["task"], inputs["text"]
        if task == "sentiment":
            return {"task": "sentiment", "result": self.sentiment(text)[0]}
        if task == "summarize":
            # Very short inputs are echoed back instead of summarized.
            if len(text.split()) < 30:
                return {"task": "summarize", "result": {"summary_text": text}}
            summary = self.summarizer(text, max_length=50, min_length=10)[0]
            return {"task": "summarize", "result": summary}
        return {"task": "unknown", "error": "Unsupported task"}

    def encode_response(self, output):
        return output
81+
class CachedAPI(ls.LitAPI):
    """Sentiment endpoint with an in-memory exact-match response cache."""

    def setup(self, device):
        self.model = pipeline("sentiment-analysis", device=-1)
        self.cache = {}  # maps input text -> pipeline result dict
        self.hits = 0
        self.misses = 0

    def decode_request(self, request):
        return request["text"]

    def predict(self, text):
        # Returns (result, from_cache) so encode_response can report cache use.
        cached = self.cache.get(text)
        if cached is not None:
            self.hits += 1
            return cached, True
        self.misses += 1
        result = self.model(text)[0]
        self.cache[text] = result
        return result, False

    def encode_response(self, output):
        result, from_cache = output
        return {
            "label": result["label"],
            "score": float(result["score"]),
            "from_cache": from_cache,
            "cache_stats": {"hits": self.hits, "misses": self.misses},
        }
101+
def test_apis_locally():
    """Exercise each API's decode/predict/encode path without starting a server."""
    banner = "=" * 70

    print(banner)
    print("Testing APIs Locally (No Server)")
    print(banner)

    # Plain text generation.
    gen_api = TextGeneratorAPI()
    gen_api.setup("cpu")
    prompt = gen_api.decode_request({"prompt": "Artificial intelligence will"})
    payload = gen_api.encode_response(gen_api.predict(prompt))
    print(f"✓ Result: {payload['generated_text'][:100]} ...")

    # Batched sentiment analysis.
    sent_api = BatchedSentimentAPI()
    sent_api.setup("cpu")
    samples = ["I love Python!", "This is terrible.", "Neutral statement."]
    decoded = [sent_api.decode_request({"text": s}) for s in samples]
    predictions = sent_api.predict(sent_api.batch(decoded))
    for idx, item in enumerate(sent_api.unbatch(predictions)):
        enc = sent_api.encode_response(item)
        print(f"✓ '{samples[idx]} ' -> {enc['label']} ({enc['score']:.2f} )")

    # Multi-task routing (sentiment branch only).
    multi_api = MultiTaskAPI()
    multi_api.setup("cpu")
    req = multi_api.decode_request({"task": "sentiment", "text": "Amazing tutorial!"})
    outcome = multi_api.predict(req)
    print(f"✓ Sentiment: {outcome['result']} ")

    # Response caching: second and third calls should hit the cache.
    cache_api = CachedAPI()
    cache_api.setup("cpu")
    repeated = "LitServe is awesome!"
    for attempt in range(3):
        text = cache_api.decode_request({"text": repeated})
        enc = cache_api.encode_response(cache_api.predict(text))
        print(f"✓ Request {attempt + 1} : {enc['label']} (cached: {enc['from_cache']} )")

    print(banner)
    print("✅ All tests completed successfully!")
    print(banner)
139+
140+ test_apis_locally ()