# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """ Web interface for OpenEnv environments. This module provides a web-based interface for interacting with OpenEnv environments, including a two-pane layout for HumanAgent interaction and state observation. """ from __future__ import annotations import json import time from dataclasses import asdict, dataclass from typing import Any, Dict, List, Optional, Type from datetime import datetime from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request from fastapi.responses import HTMLResponse, FileResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from .interfaces import Environment from .types import Action, Observation, State @dataclass class ActionLog: """Log entry for an action taken.""" timestamp: str action: Dict[str, Any] observation: Dict[str, Any] reward: Optional[float] done: bool step_count: int @dataclass class EpisodeState: """Current episode state for the web interface.""" episode_id: Optional[str] step_count: int current_observation: Optional[Dict[str, Any]] action_logs: List[ActionLog] is_reset: bool = True class WebInterfaceManager: """Manages the web interface for an environment.""" def __init__( self, env: Environment, action_cls: Type[Action], observation_cls: Type[Observation], ): self.env = env self.action_cls = action_cls self.observation_cls = observation_cls self.episode_state = EpisodeState( episode_id=None, step_count=0, current_observation=None, action_logs=[] ) self.connected_clients: List[WebSocket] = [] async def connect_websocket(self, websocket: WebSocket): """Connect a new WebSocket client.""" await websocket.accept() self.connected_clients.append(websocket) # Send current state to the new client await self._send_state_update() async def disconnect_websocket(self, websocket: WebSocket): """Disconnect a WebSocket client.""" if websocket in self.connected_clients: self.connected_clients.remove(websocket) async def _send_state_update(self): """Send current state to all connected clients.""" if not self.connected_clients: return state_data = { "type": "state_update", "episode_state": asdict(self.episode_state) } # Send to all connected clients disconnected_clients = [] for client in self.connected_clients: try: await client.send_text(json.dumps(state_data)) except: disconnected_clients.append(client) # Remove disconnected clients for client in disconnected_clients: self.connected_clients.remove(client) async def reset_environment(self) -> Dict[str, Any]: """Reset the environment and update state.""" observation = self.env.reset() state = self.env.state # Update episode state self.episode_state.episode_id = state.episode_id self.episode_state.step_count = 0 self.episode_state.current_observation = asdict(observation) self.episode_state.action_logs = [] self.episode_state.is_reset = True # Send state update await self._send_state_update() return { "observation": asdict(observation), "reward": observation.reward, "done": observation.done, } async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: """Execute a step in the environment and update state.""" # Deserialize action action = self._deserialize_action(action_data) # Execute step observation = self.env.step(action) state = self.env.state # Create action log action_log = ActionLog( timestamp=datetime.now().isoformat(), action=asdict(action), observation=asdict(observation), reward=observation.reward, done=observation.done, step_count=state.step_count ) # Update episode state self.episode_state.episode_id = state.episode_id self.episode_state.step_count = state.step_count self.episode_state.current_observation = asdict(observation) self.episode_state.action_logs.append(action_log) self.episode_state.is_reset = False # Send state update await self._send_state_update() return { "observation": asdict(observation), "reward": observation.reward, "done": observation.done, } def get_state(self) -> Dict[str, Any]: """Get current environment state.""" state = self.env.state return asdict(state) def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: """Convert JSON dict to Action instance.""" metadata = action_data.pop("metadata", {}) action = self.action_cls(**action_data) action.metadata = metadata return action def create_web_interface_app( env: Environment, action_cls: Type[Action], observation_cls: Type[Observation], ) -> FastAPI: """ Create a FastAPI application with web interface for the given environment. Args: env: The Environment instance to serve action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns Returns: FastAPI application instance with web interface """ from .http_server import create_fastapi_app # Create the base environment app app = create_fastapi_app(env, action_cls, observation_cls) # Create web interface manager web_manager = WebInterfaceManager(env, action_cls, observation_cls) # Add web interface routes @app.get("/web", response_class=HTMLResponse) async def web_interface(): """Serve the web interface.""" return get_web_interface_html(action_cls) @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): """WebSocket endpoint for real-time updates.""" await web_manager.connect_websocket(websocket) try: while True: # Keep connection alive await websocket.receive_text() except WebSocketDisconnect: await web_manager.disconnect_websocket(websocket) @app.post("/web/reset") async def web_reset(): """Reset endpoint for web interface.""" return await web_manager.reset_environment() @app.post("/web/step") async def web_step(request: Dict[str, Any]): """Step endpoint for web interface.""" action_data = request.get("action", {}) return await web_manager.step_environment(action_data) @app.get("/web/state") async def web_state(): """State endpoint for web interface.""" return web_manager.get_state() return app def get_web_interface_html(action_cls: Type[Action]) -> str: """Generate the HTML for the web interface.""" # Get action fields for dynamic form generation action_fields = [] if hasattr(action_cls, '__dataclass_fields__'): for field_name, field_info in action_cls.__dataclass_fields__.items(): if field_name != 'metadata': field_type = field_info.type if field_type == str: input_type = "text" elif field_type == int: input_type = "number" elif field_type == float: input_type = "number" elif field_type == bool: input_type = "checkbox" else: input_type = "text" action_fields.append({ 'name': field_name, 'type': input_type, 'required': field_info.default is field_info.default_factory }) return f""" OpenEnv Web Interface
HumanAgent Interface

Take Action

{_generate_action_form_fields(action_fields)}

Current State

Status: Not initialized
Episode ID: -
Step Count: 0
State Observer

Current Observation

No observation yet

Action History

No actions taken yet
""".replace('{_generate_action_form_fields(action_fields)}', _generate_action_form_fields(action_fields)) def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str: """Generate HTML form fields for action input.""" if not action_fields: return '

No action fields available

' fields_html = [] for field in action_fields: if field['type'] == 'checkbox': fields_html.append(f'''
''') elif field['type'] == 'text' and 'message' in field['name'].lower(): fields_html.append(f'''
''') else: fields_html.append(f'''
''') return '\n'.join(fields_html)