| | import gradio as gr |
| | import requests |
| | from PIL import Image |
| | import io |
| | from typing import Any, Tuple |
| | import os |
| |
|
| |
|
class Client:
    """Minimal HTTP client for the prediction backend.

    The backend is expected to answer a JSON POST with a JSON object whose
    ``"image"`` field holds the rendered plot as a hex-encoded byte string.
    """

    # Value placed in the image slot of the result when no image is available.
    _ERROR_PLACEHOLDER = "Error, please retry"

    def __init__(self, server_url: str):
        """Remember the full URL that requests are posted to."""
        self.server_url = server_url

    def send_request(self, task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
        """Post one analysis request and decode the returned image.

        Args:
            task_name: Analysis mode to run.
            model_name: Identifier of the model to analyse.
            text: Input text.
            normalization_type: Normalization applied to the plot.

        Returns:
            ``(image, "OK")`` on success, where ``image`` is a PIL image.
            On any failure, ``("Error, please retry", message)`` where
            ``message`` starts with ``"Error:"``. This method does not raise
            for network errors or malformed responses.
        """
        payload = {
            "task_name": task_name,
            "model_name": model_name,
            "text": text,
            "normalization_type": normalization_type,
        }

        # Timeouts, refused connections, DNS failures etc. all derive from
        # RequestException; report them instead of crashing the caller.
        try:
            response = requests.post(self.server_url, json=payload, timeout=60)
        except requests.RequestException as exc:
            return self._ERROR_PLACEHOLDER, f"Error: Could not reach server ({type(exc).__name__})"

        if response.status_code != 200:
            return self._ERROR_PLACEHOLDER, "Error: Could not get response from server"

        # A 200 response can still carry an unusable body: invalid JSON or hex
        # (ValueError), a missing "image" key (KeyError), a non-dict body or
        # non-string field (TypeError), or bytes that are not an image (OSError).
        try:
            img_data = bytes.fromhex(response.json()["image"])
            img = Image.open(io.BytesIO(img_data))
            # Image.open is lazy; decode now so corrupt data is caught here.
            img.load()
        except (KeyError, TypeError, ValueError, OSError) as exc:
            return self._ERROR_PLACEHOLDER, f"Error: Invalid response from server ({type(exc).__name__})"

        return img, "OK"
| |
|
# Shared client used by every UI callback. The backend host is taken from the
# SERVER environment variable; a KeyError is raised at import time if it is unset.
_predict_endpoint = f"http://{os.environ['SERVER']}/predict"
client = Client(_predict_endpoint)
| |
|
def get_layerwise_nonlinearity(task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
    """Forward one analysis request to the module-level ``client``.

    Returns the ``(image, status)`` pair produced by ``client.send_request``
    unchanged.
    """
    result = client.send_request(
        task_name=task_name,
        model_name=model_name,
        text=text,
        normalization_type=normalization_type,
    )
    return result
| |
|
def update_output(task_name: str, model_name: str, text: str, normalization_type: str) -> Any:
    """Run the selected analysis and return the image for the plot component.

    Args:
        task_name: Analysis mode to run.
        model_name: Identifier of the model to analyse.
        text: Input text.
        normalization_type: Normalization applied to the plot.

    Returns:
        The image produced by the backend.

    Raises:
        gr.Error: If the backend call did not report ``"OK"``; the status
            message is shown to the user.
    """
    img, status = get_layerwise_nonlinearity(task_name, model_name, text, normalization_type)
    if status != "OK":
        # On failure the first element is a placeholder message string, not an
        # image, so it must not be passed on to the image output component.
        raise gr.Error(status)
    return img
| |
|
def set_default(task_name: str) -> str:
    """Return the default normalization for an analysis mode.

    Three modes default to ``"token-wise"``; every other value, including
    unknown ones, defaults to ``"global"``.
    """
    token_wise_tasks = (
        "Layer wise non-linearity",
        "Next-token prediction from intermediate representations",
        "Tokenwise loss without i-th layer",
    )
    return "token-wise" if task_name in token_wise_tasks else "global"
| |
|
def check_normalization(task_name: str, normalization_name: str) -> str:
    """Return the normalization to use for the given mode.

    The combination of ``"Contextualization measurement"`` with
    ``"token-wise"`` is replaced by ``"global"``; any other combination is
    returned unchanged.

    Args:
        task_name: Currently selected analysis mode.
        normalization_name: Normalization the user picked.

    Returns:
        The normalization name to keep in the selector.
    """
    overridden = (
        task_name == "Contextualization measurement"
        and normalization_name == "token-wise"
    )
    return "global" if overridden else normalization_name
| |
|
def update_description(task_name: str) -> str:
    """Return the one-line explanation shown for an analysis mode.

    Unknown modes get a generic "no description" message.
    """
    fallback = "ℹ️ No description available."
    blurbs = {
        "Layer wise non-linearity": "Non-linearity per layer: shows how complex each layer's transformation is. Red = more nonlinear.",
        "Next-token prediction from intermediate representations": "Layerwise token prediction: when does the model start guessing correctly?",
        "Contextualization measurement": "Context stored in each token: how well can the model reconstruct the previous context?",
        "Layerwise predictions (logit lens)": "Logit lens: what does each layer believe the next token should be?",
        "Tokenwise loss without i-th layer": "Layer ablation: how much does performance drop if a layer is removed?",
    }
    if task_name in blurbs:
        return blurbs[task_name]
    return fallback
| |
|
with gr.Blocks() as demo:
    # ---- Page header -------------------------------------------------------
    gr.Markdown("# 🔬 LLM-Microscope — A Look Inside the Black Box")
    gr.Markdown("Select a model, analysis mode, and input — then peek inside the black box of an LLM to see which layers matter most, which tokens carry the most memory, and how predictions evolve.")

    # ---- Option lists for the three selectors ------------------------------
    model_choices = [
        "facebook/opt-1.3b",
        "TheBloke/Llama-2-7B-fp16",
        "Qwen/Qwen3-8B",
    ]
    task_choices = [
        "Layer wise non-linearity",
        "Next-token prediction from intermediate representations",
        "Contextualization measurement",
        "Layerwise predictions (logit lens)",
        "Tokenwise loss without i-th layer",
    ]
    normalization_choices = ["global", "token-wise"]

    # ---- Selectors, laid out side by side ----------------------------------
    with gr.Row():
        model_selector = gr.Dropdown(
            choices=model_choices,
            value="facebook/opt-1.3b",
            label="Select Model",
        )
        task_selector = gr.Dropdown(
            choices=task_choices,
            value="Layer wise non-linearity",
            label="Select Mode",
        )
        normalization_selector = gr.Dropdown(
            choices=normalization_choices,
            value="token-wise",
            label="Select Normalization",
        )

    # Placeholder text; replaced by update_description when the mode changes.
    task_description = gr.Markdown("ℹ️ Choose a mode to see what it does.")

    # ---- Input text, submit button and resulting plot ----------------------
    with gr.Column():
        text_message = gr.Textbox(label="Enter your input text:", value="I love to live my life")
        submit = gr.Button("Submit")
        box_for_plot = gr.Image(label="Visualization", type="pil")

    # ---- Collapsible help text ---------------------------------------------
    with gr.Accordion("📘 More Info and Explanation", open=False):
        gr.Markdown("""
        This heatmap shows **how each token is processed** across layers of a language model. Here's how to read it:

        - **Rows**: layers of the model (bottom = deeper)
        - **Columns**: input tokens
        - **Colors**: intensity of effect (depends on the selected metric)

        ---

        ### Metrics explained:

        - `Layer wise non-linearity`: how nonlinear the transformation is at each layer (red = more nonlinear).
        - `Next-token prediction from intermediate representations`: shows which layers begin to make good predictions.
        - `Contextualization measurement`: tokens with more context info get lower scores (green = more context).
        - `Layerwise predictions (logit lens)`: tracks how the model’s guesses evolve at each layer.
        - `Tokenwise loss without i-th layer`: shows how much each token depends on a specific layer. Red means performance drops if we skip this layer.

        Use this tool to **peek inside the black box** — it reveals which layers matter most, which tokens carry the most memory, and how LLMs evolve their predictions.

        ---

        You can also use `llm-microscope` as a Python library to run these analyses on **your own models and data**.

        Just install it with: `pip install llm-microscope`

        More details provided in [GitHub repo](https://github.com/AIRI-Institute/LLM-Microscope).
        """)

    # ---- Event wiring ------------------------------------------------------
    # Mode changed: refresh the short description under the selectors.
    task_selector.change(
        fn=update_description,
        inputs=[task_selector],
        outputs=[task_description],
    )
    # Mode picked by the user: reset normalization to that mode's default.
    task_selector.select(
        fn=set_default,
        inputs=[task_selector],
        outputs=[normalization_selector],
    )
    # Normalization picked by the user: let check_normalization override it
    # for the current mode if needed.
    normalization_selector.select(
        fn=check_normalization,
        inputs=[task_selector, normalization_selector],
        outputs=[normalization_selector],
    )
    # Submit: run the analysis and show the returned image.
    submit.click(
        fn=update_output,
        inputs=[task_selector, model_selector, text_message, normalization_selector],
        outputs=[box_for_plot],
    )
| |
|
if __name__ == "__main__":
    # Start the app only when this file is run as a script.
    launch_options = {
        "share": True,
        "server_port": 7860,
        "server_name": "0.0.0.0",
    }
    demo.launch(**launch_options)
| |
|