You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
190 lines
7.2 KiB
Python
190 lines
7.2 KiB
Python
import marimo
|
|
|
|
# marimo version recorded for this notebook file.
__generated_with = "0.13.7"

# Notebook app object; every function decorated with `@app.cell` below is one cell.
app = marimo.App(width="medium")
|
|
|
|
|
|
@app.cell
def _():
    # Modules consumed by the cells below (each receives them as parameters).
    # Only what is actually used is imported: the projection is done with UMAP
    # and the plot with plotly, so no t-SNE / matplotlib dependency is needed.
    import json

    import numpy as np
    import pandas as pd
    import plotly.express as px
    import requests
    import umap

    return json, np, pd, px, requests, umap
|
|
|
|
|
|
@app.cell
def _():
    # --- Configuration ---
    # Embedding model requested from Ollama; it must already be pulled locally
    # (e.g. `ollama pull nomic-embed-text`).
    MODEL_NAME = "nomic-embed-text"

    # Ollama embeddings endpoint on localhost:11434.
    _ollama_base = "http://localhost:11434"
    OLLAMA_URL = f"{_ollama_base}/api/embeddings"

    return MODEL_NAME, OLLAMA_URL
|
|
|
|
|
|
@app.cell
def _():
    # --- Data: Sentences to Analyze (Focus on Sentiment & Topics) ---
    # Three sentiment groups of five sentences each. Every group lists its
    # topics in the same order (travel, food, tech, emotion/work, weather), so
    # positions 1-5, 6-10 and 11-15 line up topic-for-topic across sentiments.
    _positive = [
        "I had an absolutely wonderful time on vacation!",  # Travel
        "This is the best pizza I've ever tasted, truly amazing.",  # Food
        "The new software update significantly improved performance.",  # Tech
        "She was overjoyed to receive the award.",  # Emotion
        "What a beautiful, sunny day for a picnic!",  # Weather/Activity
    ]
    _negative = [
        "The airline lost my luggage, ruining the start of my trip.",  # Travel
        "I found the meal to be bland and overpriced.",  # Food
        "Debugging this legacy code is incredibly frustrating.",  # Tech
        "He felt heartbroken and betrayed after the argument.",  # Emotion
        "The constant rain made the whole weekend gloomy.",  # Weather/Activity
    ]
    _neutral = [  # objective statements
        "The train is scheduled to depart at 3:00 PM.",  # Travel
        "The ingredients listed include flour, water, and salt.",  # Food
        "The system requires 8GB of RAM to operate.",  # Tech
        "Please file the report by the end of the day.",  # Work/Instruction
        "The weather report indicates a chance of showers tomorrow.",  # Weather
    ]

    sentences = [*_positive, *_negative, *_neutral]
    return (sentences,)
|
|
|
|
|
|
@app.cell
def _(MODEL_NAME, OLLAMA_URL, json, np, requests, sentences):
    # --- 1. Get Embeddings from Ollama ---
    # POSTs each sentence to OLLAMA_URL and stacks the returned vectors into
    # `embeddings_array`, one row per sentence, in `sentences` order.
    #
    # Failures raise instead of calling exit(): exit() is an interactive-shell
    # helper, whereas an exception stops this cell (and the cells depending on
    # `embeddings_array`) and surfaces the message and traceback.

    def _fetch_embedding(text, timeout=120):
        # Return the embedding list for one sentence, or raise RuntimeError.
        # `timeout` (seconds) matters because requests waits forever by
        # default; it is generous since the first call may wait for Ollama to
        # load the model.
        payload = {"model": MODEL_NAME, "prompt": text}
        try:
            response = requests.post(OLLAMA_URL, json=payload, timeout=timeout)
            response.raise_for_status()  # 4xx/5xx -> requests.HTTPError
        except requests.exceptions.RequestException as err:
            raise RuntimeError(
                f"Error connecting to Ollama or during API request for "
                f"sentence '{text}': {err}\n"
                "Ensure Ollama is running and the model name is correct."
            ) from err

        # Decoded outside the request `try`: in recent requests releases,
        # response.json() raises requests.exceptions.JSONDecodeError, which
        # also subclasses RequestException, so a shared `try` would report bad
        # JSON as a connection error. All JSON decode errors (stdlib json,
        # simplejson, requests' wrapper) are ValueError subclasses.
        try:
            response_data = response.json()
        except ValueError as err:
            raise RuntimeError(
                f"Error decoding JSON response: {err}\n"
                f"Received text: {response.text}"
            ) from err

        embedding = (
            response_data.get("embedding")
            if isinstance(response_data, dict)
            else None
        )
        if not embedding:
            raise RuntimeError(
                f"Error: No embedding returned for sentence '{text}'. "
                f"Response: {response_data!r}"
            )
        return embedding

    embeddings = []
    print(f"Getting embeddings using model: {MODEL_NAME}...")
    for i, sentence in enumerate(sentences):
        print(f" Processing sentence {i+1}/{len(sentences)}: '{sentence[:30]}...'")
        embeddings.append(_fetch_embedding(sentence))

    if not embeddings:
        raise RuntimeError("Error: No sentences were provided, so no embeddings were retrieved.")

    # np.array over vectors of different lengths fails (or yields an object
    # array), so confirm a single shared dimension before stacking.
    _dimensions = {len(vector) for vector in embeddings}
    if len(_dimensions) != 1:
        raise RuntimeError(
            f"Error: Embeddings have inconsistent dimensions: {sorted(_dimensions)}."
        )

    embeddings_array = np.array(embeddings)
    print(f"\nSuccessfully got {embeddings_array.shape[0]} embeddings with dimension {embeddings_array.shape[1]}.")
    return (embeddings_array,)
|
|
|
|
|
|
@app.cell
def _(embeddings_array, umap):
    # --- 2. Dimensionality Reduction (UMAP) ---
    # Projects the embedding matrix (one row per sentence) down to 2-D for
    # plotting.
    print("Reducing dimensionality using UMAP...")

    n_samples = embeddings_array.shape[0]

    # UMAP needs n_neighbors >= 2, and n_neighbors must stay below n_samples,
    # so fewer than 3 points cannot satisfy both constraints.
    if n_samples < 3:
        raise ValueError(
            f"Error: Need at least 3 data points for UMAP (got {n_samples}); "
            "n_neighbors must be at least 2 and less than the number of samples."
        )

    # Common defaults are 5-15; lower values focus more on local structure.
    # Capping at n_samples - 1 keeps small datasets valid (always >= 2 here).
    n_neighbors_value = min(15, n_samples - 1)
    print(f" Using UMAP with n_neighbors={n_neighbors_value}")

    # metric='cosine' is often recommended for high-dimensional text embeddings.
    reducer = umap.UMAP(
        n_neighbors=n_neighbors_value,
        n_components=2,  # Target dimension
        min_dist=0.1,  # Controls how tightly points are packed
        metric="cosine",  # Distance metric suitable for embeddings
    )

    reduced_embeddings = reducer.fit_transform(embeddings_array)
    print("Dimensionality reduction complete.")
    return (reduced_embeddings,)
|
|
|
|
|
|
@app.cell
def _(MODEL_NAME, pd, px, reduced_embeddings, sentences):
    # --- 3. Interactive plot ---
    # One marker per sentence at its 2-D reduced coordinates, labelled with its
    # 1-based position in `sentences`; hovering shows the full sentence.
    _count = len(sentences)
    df = pd.DataFrame(
        {
            "x": reduced_embeddings[:, 0],
            "y": reduced_embeddings[:, 1],
            "sentence": sentences,
            "index": list(range(1, _count + 1)),  # 1-based index for display
        }
    )

    # Tooltip shows only the sentence: x/y are hidden, and `index` is already
    # the hover title via hover_name.
    _hover_columns = {"sentence": True, "x": False, "y": False, "index": False}
    _axis_labels = {"x": "Component 1", "y": "Component 2"}

    fig = px.scatter(
        df,
        x="x",
        y="y",
        text="index",
        hover_name="index",
        hover_data=_hover_columns,
        title=f"Interactive 2D Visualization of sentence embeddings ({MODEL_NAME})",
        labels=_axis_labels,
    )

    # Index labels above each marker in a small font; fixed 900x700 px figure,
    # centred title, nearest-point hovering. Both calls return the figure.
    fig.update_traces(textposition="top center", textfont_size=10).update_layout(
        hovermode="closest",
        width=900,
        height=700,
        title_x=0.5,
    )
    # To hide the axes: fig.update_layout(xaxis_visible=False, yaxis_visible=False)

    # Typically opens in the default web browser, or renders inline in
    # notebook environments such as Jupyter.
    fig.show()
    print("Plot display initiated (check browser or output).")
    print("Done.")
    return
|
|
|
|
|
|
# Running this file directly (`python <this file>`) runs the notebook's cells via app.run().
if __name__ == "__main__":
    app.run()
|