#!/usr/bin/env python3
"""
LiveTalker Working Voice Server
Real-time voice input with actual VAD processing
"""

import asyncio
import json
import logging
import time
import base64
import numpy as np
from typing import Dict, Any
import torch
from silero_vad import load_silero_vad, get_speech_timestamps

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="LiveTalker Working Voice Server")

# Enable CORS for all origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global VAD model
vad_model = None
active_connections: Dict[str, Dict] = {}

@app.on_event("startup")
async def startup():
    """Initialize VAD model on startup"""
    global vad_model
    logger.info("Loading Silero VAD model...")
    try:
        vad_model = load_silero_vad(onnx=False)
        logger.info("✅ Silero VAD model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load VAD model: {e}")
        vad_model = None

@app.get("/")
async def root():
    """Main interface with working voice input"""
    html_content = """
<!DOCTYPE html>
<html>
<head>
    <title>LiveTalker - Working Voice Input</title>
    <style>
        body { 
            font-family: Arial, sans-serif; 
            margin: 20px; 
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            min-height: 100vh;
        }
        .container { 
            max-width: 900px; 
            margin: 0 auto; 
            background: rgba(255,255,255,0.1);
            padding: 30px;
            border-radius: 15px;
            backdrop-filter: blur(10px);
        }
        h1 { text-align: center; margin-bottom: 30px; }
        .status { 
            padding: 15px; 
            margin: 20px 0; 
            border-radius: 8px; 
            background: rgba(255,255,255,0.2);
            border: 2px solid transparent;
        }
        .status.active { border-color: #4CAF50; background: rgba(76,175,80,0.3); }
        .status.listening { border-color: #FF9800; background: rgba(255,152,0,0.3); }
        .status.speaking { border-color: #2196F3; background: rgba(33,150,243,0.3); }
        .controls { 
            text-align: center; 
            margin: 30px 0; 
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
            gap: 15px;
        }
        button { 
            padding: 15px 25px; 
            border: none; 
            border-radius: 8px; 
            background: #4CAF50; 
            color: white; 
            cursor: pointer; 
            font-size: 16px;
            transition: all 0.3s ease;
        }
        button:hover { background: #45a049; transform: translateY(-2px); }
        button:disabled { background: #666; cursor: not-allowed; transform: none; }
        .mic-button { 
            background: #f44336; 
            font-size: 18px; 
            padding: 20px 30px;
            border-radius: 50%;
            width: 80px;
            height: 80px;
            margin: 20px auto;
            display: flex;
            align-items: center;
            justify-content: center;
        }
        .mic-button.recording { 
            background: #ff1744; 
            animation: pulse 1s infinite;
        }
        @keyframes pulse {
            0%, 100% { opacity: 1; transform: scale(1); }
            50% { opacity: 0.7; transform: scale(1.05); }
        }
        .log { 
            background: rgba(0,0,0,0.4); 
            padding: 15px; 
            border-radius: 8px; 
            height: 300px; 
            overflow-y: auto; 
            font-family: monospace; 
            font-size: 13px;
            white-space: pre-wrap;
        }
        .vad-indicator {
            text-align: center;
            margin: 20px 0;
        }
        .vad-bar {
            width: 100%;
            height: 30px;
            background: rgba(255,255,255,0.2);
            border-radius: 15px;
            overflow: hidden;
            position: relative;
        }
        .vad-level {
            height: 100%;
            background: linear-gradient(90deg, #4CAF50, #8BC34A, #FFC107, #FF5722);
            width: 0%;
            transition: width 0.1s ease;
        }
        .feature-grid {
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
            gap: 20px;
            margin: 20px 0;
        }
        .feature-card {
            background: rgba(255,255,255,0.15);
            padding: 20px;
            border-radius: 10px;
            text-align: center;
        }
        .conversation {
            background: rgba(0,0,0,0.3);
            border-radius: 10px;
            padding: 20px;
            margin: 20px 0;
            max-height: 400px;
            overflow-y: auto;
        }
        .message {
            margin: 10px 0;
            padding: 10px;
            border-radius: 8px;
        }
        .message.user {
            background: rgba(33,150,243,0.3);
            text-align: right;
        }
        .message.assistant {
            background: rgba(76,175,80,0.3);
            text-align: left;
        }
        .message.system {
            background: rgba(158,158,158,0.3);
            text-align: center;
            font-style: italic;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>🎙️ LiveTalker - Real Voice Processing</h1>
        
        <div class="status" id="connectionStatus">
            <h3>Connection Status:</h3>
            <div id="statusText">Not connected</div>
        </div>
        
        <div class="controls">
            <button onclick="connectWebSocket()">🔗 Connect</button>
            <button onclick="requestMicrophone()" id="micPermBtn">🎤 Request Microphone</button>
            <button onclick="startListening()" id="startBtn" disabled>🎧 Start Listening</button>
            <button onclick="stopListening()" id="stopBtn" disabled>🛑 Stop</button>
        </div>
        
        <div class="vad-indicator">
            <h3>🎵 Voice Activity Detection</h3>
            <div class="vad-bar">
                <div class="vad-level" id="vadLevel"></div>
            </div>
            <div id="vadStatus">Waiting for audio...</div>
        </div>
        
        <div class="status" id="micStatus">
            <h3>🎤 Microphone Status:</h3>
            <div id="micStatusText">Permission required</div>
        </div>
        
        <div class="conversation" id="conversation">
            <div class="message system">Ready for voice conversation...</div>
        </div>
        
        <div class="status">
            <h3>📊 Activity Log:</h3>
            <div id="log" class="log">Waiting to connect...</div>
        </div>
        
        <div class="feature-grid">
            <div class="feature-card">
                <h4>🎯 Real VAD</h4>
                <p>Silero VAD processing your actual voice input</p>
            </div>
            <div class="feature-card">
                <h4>⚡ Live Processing</h4>
                <p>Real-time audio analysis and speech detection</p>
            </div>
            <div class="feature-card">
                <h4>🔄 Turn Detection</h4>
                <p>Intelligent conversation flow management</p>
            </div>
            <div class="feature-card">
                <h4>🧠 Smart Responses</h4>
                <p>Context-aware conversation handling</p>
            </div>
        </div>
    </div>

    <script>
        let ws = null;
        let mediaStream = null;
        let audioContext = null;
        let processor = null;
        let isRecording = false;
        let connected = false;
        
        function log(message) {
            const logDiv = document.getElementById('log');
            const timestamp = new Date().toLocaleTimeString();
            logDiv.textContent += `[${timestamp}] ${message}\\n`;
            logDiv.scrollTop = logDiv.scrollHeight;
        }
        
        function updateStatus(elementId, message, className = '') {
            // Status containers keep their text in a dedicated child element
            const textTargets = { connectionStatus: 'statusText', micStatus: 'micStatusText' };
            const element = document.getElementById(elementId);
            if (element) {
                const textElement = textTargets[elementId] ?
                    document.getElementById(textTargets[elementId]) : element;
                textElement.textContent = message;
                element.className = 'status ' + className;
            }
        }
        
        function addMessage(type, content) {
            const conversation = document.getElementById('conversation');
            const message = document.createElement('div');
            message.className = `message ${type}`;
            message.textContent = content;
            conversation.appendChild(message);
            conversation.scrollTop = conversation.scrollHeight;
        }
        
        function updateVAD(level, isActive) {
            const vadLevel = document.getElementById('vadLevel');
            const vadStatus = document.getElementById('vadStatus');
            
            vadLevel.style.width = `${level * 100}%`;
            vadStatus.textContent = isActive ? 
                `🎵 Speech detected (${(level * 100).toFixed(1)}%)` : 
                `🔇 Silence (${(level * 100).toFixed(1)}%)`;
        }
        
        async function connectWebSocket() {
            const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
            const wsUrl = `${protocol}//${window.location.host}/media-stream`;
            
            log('Connecting to WebSocket...');
            updateStatus('connectionStatus', 'Connecting...', '');
            
            try {
                ws = new WebSocket(wsUrl);
                
                ws.onopen = function() {
                    connected = true;
                    log('✅ WebSocket connected');
                    updateStatus('connectionStatus', '✅ Connected', 'active');
                };
                
                ws.onmessage = function(event) {
                    try {
                        const data = JSON.parse(event.data);
                        handleServerMessage(data);
                    } catch (e) {
                        log(`📨 Raw message: ${event.data.substring(0, 100)}...`);
                    }
                };
                
                ws.onclose = function() {
                    connected = false;
                    log('❌ WebSocket disconnected');
                    updateStatus('connectionStatus', '❌ Disconnected', '');
                };
                
                ws.onerror = function(error) {
                    console.error('WebSocket error:', error);
                    log('❌ WebSocket error (see browser console)');
                    updateStatus('connectionStatus', '❌ Error', '');
                };
                
            } catch (error) {
                log(`❌ Connection failed: ${error}`);
                updateStatus('connectionStatus', '❌ Failed', '');
            }
        }
        
        function handleServerMessage(data) {
            log(`📨 ${data.type}: ${JSON.stringify(data).substring(0, 150)}...`);
            
            switch(data.type) {
                case 'config':
                    log('Server configuration received');
                    break;
                    
                case 'vad_result':
                    updateVAD(data.confidence || 0, data.is_speech || false);
                    if (data.is_speech) {
                        updateStatus('connectionStatus', '🎵 Voice detected!', 'listening');
                    }
                    break;
                    
                case 'turn_detected':
                    log(`🔄 Turn detected: ${data.state}`);
                    break;
                    
                case 'speech_to_text':
                    if (data.text && data.text.trim()) {
                        addMessage('user', data.text);
                        log(`🗣️ Transcribed: "${data.text}"`);
                    }
                    break;
                    
                case 'ai_response':
                    if (data.text) {
                        addMessage('assistant', data.text);
                        log(`🤖 AI Response: "${data.text}"`);
                    }
                    break;
                    
                case 'error':
                    log(`❌ Server error: ${data.error}`);
                    break;
            }
        }
        
        async function requestMicrophone() {
            try {
                log('Requesting microphone permission...');
                updateStatus('micStatus', 'Requesting permission...', '');
                
                mediaStream = await navigator.mediaDevices.getUserMedia({
                    audio: {
                        sampleRate: 16000,
                        channelCount: 1,
                        echoCancellation: true,
                        noiseSuppression: true,
                        autoGainControl: true
                    }
                });
                
                log('✅ Microphone permission granted');
                updateStatus('micStatus', '✅ Microphone ready', 'active');
                
                // Setup audio processing
                audioContext = new (window.AudioContext || window.webkitAudioContext)({
                    sampleRate: 16000
                });
                
                const source = audioContext.createMediaStreamSource(mediaStream);
                
                // Create audio worklet for processing
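                // The data: URL below contains a URL-encoded AudioWorkletProcessor
                // that accumulates 1024 input samples into a Float32Array and posts
                // each full buffer to the main thread as {type: 'audio', audio: ...}.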
                await audioContext.audioWorklet.addModule('data:text/javascript,class%20AudioProcessor%20extends%20AudioWorkletProcessor%20%7B%0A%20%20%20%20constructor()%20%7B%0A%20%20%20%20%20%20%20%20super();%0A%20%20%20%20%20%20%20%20this.bufferSize%20%3D%201024;%0A%20%20%20%20%20%20%20%20this.buffer%20%3D%20new%20Float32Array(this.bufferSize);%0A%20%20%20%20%20%20%20%20this.bufferIndex%20%3D%200;%0A%20%20%20%20%7D%0A%0A%20%20%20%20process(inputs%2C%20outputs%2C%20parameters)%20%7B%0A%20%20%20%20%20%20%20%20const%20input%20%3D%20inputs%5B0%5D;%0A%20%20%20%20%20%20%20%20if%20(input.length%20%3E%200)%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20const%20inputChannel%20%3D%20input%5B0%5D;%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20(let%20i%20%3D%200;%20i%20%3C%20inputChannel.length;%20i%2B%2B)%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20this.buffer%5Bthis.bufferIndex%5D%20%3D%20inputChannel%5Bi%5D;%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20this.bufferIndex%2B%2B;%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20(this.bufferIndex%20%3E%3D%20this.bufferSize)%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20this.port.postMessage(%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20type%3A%20'audio'%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20audio%3A%20new%20Float32Array(this.buffer)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%7D);%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20this.bufferIndex%20%3D%200;%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%7D%0A%20%20%20%20%20%20%20%20%20%20%20%20%7D%0A%20%20%20%20%20%20%20%20%7D%0A%0A%20%20%20%20%20%20%20%20return%20true;%0A%20%20%20%20%7D%0A%7D%0A%0AregisterProcessor('audio-processor'%2C%20AudioProcessor);');
                
                processor = new AudioWorkletNode(audioContext, 'audio-processor');
                source.connect(processor);
                // Keep the node in the rendering graph so its process() keeps firing;
                // the worklet writes no output samples, so nothing is audible.
                processor.connect(audioContext.destination);
                
                processor.port.onmessage = (event) => {
                    if (event.data.type === 'audio' && isRecording && connected) {
                        sendAudioData(event.data.audio);
                    }
                };
                
                document.getElementById('startBtn').disabled = false;
                
            } catch (error) {
                log(`❌ Microphone error: ${error.message}`);
                updateStatus('micStatus', '❌ Permission denied', '');
            }
        }
        
        function sendAudioData(audioData) {
            if (!ws || ws.readyState !== WebSocket.OPEN) return;
            
            // Convert Float32Array to Int16Array and then to base64
            const int16Array = new Int16Array(audioData.length);
            for (let i = 0; i < audioData.length; i++) {
                int16Array[i] = Math.max(-1, Math.min(1, audioData[i])) * 0x7FFF;
            }
            
            const base64Audio = btoa(String.fromCharCode(...new Uint8Array(int16Array.buffer)));
            
            ws.send(JSON.stringify({
                type: 'audio',
                data: base64Audio,
                format: 'pcm_s16le',
                sample_rate: 16000
            }));
        }
        
        function startListening() {
            if (!connected) {
                alert('Please connect to WebSocket first');
                return;
            }
            if (!mediaStream) {
                alert('Please request microphone permission first');
                return;
            }
            
            isRecording = true;
            log('🎧 Started listening...');
            updateStatus('connectionStatus', '🎧 Listening...', 'listening');
            
            document.getElementById('startBtn').disabled = true;
            document.getElementById('stopBtn').disabled = false;
            
            // Resume audio context if suspended
            if (audioContext.state === 'suspended') {
                audioContext.resume();
            }
            
            // Send start message
            if (ws && ws.readyState === WebSocket.OPEN) {
                ws.send(JSON.stringify({
                    type: 'start_conversation',
                    config: { personality: 'luna' }
                }));
            }
        }
        
        function stopListening() {
            isRecording = false;
            log('🛑 Stopped listening');
            updateStatus('connectionStatus', '✅ Connected', 'active');
            updateVAD(0, false);
            
            document.getElementById('startBtn').disabled = false;
            document.getElementById('stopBtn').disabled = true;
        }
        
        // Show usage instructions on page load
        document.addEventListener('DOMContentLoaded', function() {
            log('LiveTalker Voice Interface loaded');
            log('Click Connect, then Request Microphone, then Start Listening');
        });
    </script>
</body>
</html>
    """
    return HTMLResponse(content=html_content)

@app.get("/health")
async def health_check():
    """Health check"""
    return {
        "status": "healthy",
        "vad_model": "loaded" if vad_model else "not_loaded",
        "timestamp": time.time(),
        "features": {
            "real_vad": vad_model is not None,
            "microphone_input": True,
            "real_time_processing": True,
            "speech_detection": True
        }
    }

@app.get("/stats")
async def get_stats():
    """System statistics"""
    return {
        "active_connections": len(active_connections),
        "vad_model_loaded": vad_model is not None,
        "gpu_available": torch.cuda.is_available(),
        "processing_mode": "real_time",
        "features": {
            "silero_vad": "active" if vad_model else "failed",
            "microphone_capture": "supported",
            "real_time_analysis": "enabled"
        }
    }

@app.websocket("/media-stream")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real voice processing"""
    await websocket.accept()
    session_id = f"session_{int(time.time() * 1000)}"
    
    session = {
        "id": session_id,
        "websocket": websocket,
        "audio_buffer": b"",
        "conversation": [],
        "last_speech": "",
        "is_listening": False
    }
    active_connections[session_id] = session
    
    logger.info(f"New voice session: {session_id}")
    
    try:
        # Send initial config
        await websocket.send_json({
            "type": "config",
            "session_id": session_id,
            "vad_ready": vad_model is not None,
            "message": "Real voice processing ready"
        })
        
        async for message in websocket.iter_json():
            await handle_voice_message(session, message)
            
    except WebSocketDisconnect:
        logger.info(f"Voice session disconnected: {session_id}")
    except Exception as e:
        logger.error(f"Voice session error: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "error": str(e)
            })
        except Exception:
            # The socket may already be closed; nothing more to send
            pass
    finally:
        if session_id in active_connections:
            del active_connections[session_id]

async def handle_voice_message(session: Dict, message: Dict[str, Any]):
    """Handle voice-related WebSocket messages"""
    msg_type = message.get("type")
    
    if msg_type == "start_conversation":
        session["is_listening"] = True
        await session["websocket"].send_json({
            "type": "conversation_started",
            "message": "Voice conversation started - speak now!"
        })
        logger.info(f"Started voice conversation for {session['id']}")
        
    elif msg_type == "audio":
        if session["is_listening"] and vad_model:
            await process_audio_chunk(session, message)
        
    elif msg_type == "stop_listening":
        session["is_listening"] = False
        await session["websocket"].send_json({
            "type": "stopped",
            "message": "Stopped listening"
        })

async def process_audio_chunk(session: Dict, message: Dict[str, Any]):
    """Process incoming audio with real VAD"""
    try:
        # Decode audio data
        audio_data = base64.b64decode(message["data"])
        
        # Convert to numpy array (PCM 16-bit)
        audio_np = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
        
        if len(audio_np) == 0:
            return
        
        # Run Silero VAD on the audio chunk. The browser client captures 16 kHz
        # mono audio, which is the rate Silero VAD expects; the packaged model
        # scores fixed 512-sample frames, so the chunk is evaluated frame by
        # frame and the highest speech probability is kept.
        audio_tensor = torch.from_numpy(audio_np).float()

        if len(audio_tensor) >= 512:
            speech_prob = 0.0
            with torch.no_grad():
                for start in range(0, len(audio_tensor) - 511, 512):
                    frame = audio_tensor[start:start + 512]
                    speech_prob = max(speech_prob, vad_model(frame, 16000).item())

            is_speech = speech_prob > 0.5
            
            # Send VAD result
            await session["websocket"].send_json({
                "type": "vad_result",
                "is_speech": is_speech,
                "confidence": speech_prob,
                "timestamp": time.time()
            })
            
            # Accumulate audio if speech detected
            if is_speech:
                session["audio_buffer"] += audio_data
                
                # Process accumulated audio for speech-to-text when we have enough
                if len(session["audio_buffer"]) > 16000 * 2:  # ~1 second of audio
                    await process_speech_segment(session)
            
            # Detect end of speech (silence after speech)
            elif len(session["audio_buffer"]) > 0:
                # Process final segment
                await process_speech_segment(session)
                session["audio_buffer"] = b""
    
    except Exception as e:
        logger.error(f"Error processing audio: {e}")
        await session["websocket"].send_json({
            "type": "error",
            "error": f"Audio processing error: {str(e)}"
        })

async def process_speech_segment(session: Dict):
    """Process accumulated speech segment"""
    try:
        if len(session["audio_buffer"]) == 0:
            return
        
        # Convert audio buffer to numpy
        audio_np = np.frombuffer(session["audio_buffer"], dtype=np.int16).astype(np.float32) / 32768.0
        
        # Confirm the buffered segment actually contains speech using Silero VAD
        audio_tensor = torch.from_numpy(audio_np).float()
        if not get_speech_timestamps(audio_tensor, vad_model, sampling_rate=16000):
            return
        
        # For demo, we'll simulate speech-to-text
        # In a real implementation, you'd use a STT model here
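        # A minimal sketch of what that could look like, assuming the optional
        # faster-whisper package were installed (not a dependency of this demo):
        #
        #     from faster_whisper import WhisperModel
        #     stt = WhisperModel("base", device="cpu")
        #     segments, _ = stt.transcribe(audio_np, language="en")
        #     text = " ".join(seg.text.strip() for seg in segments)
        #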
        simulated_text = f"[Speech detected: {len(audio_np)/16000:.1f}s audio segment]"
        
        if len(audio_np) > 16000:  # Only process if > 1 second
            # Send transcription result
            await session["websocket"].send_json({
                "type": "speech_to_text",
                "text": simulated_text,
                "confidence": 0.85,
                "duration": len(audio_np) / 16000
            })
            
            session["conversation"].append({
                "role": "user",
                "content": simulated_text,
                "timestamp": time.time()
            })
            
            # Generate AI response (simulated)
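            # A hypothetical sketch of where a real response pipeline would be
            # invoked (generate_voice_reply is a placeholder, not an actual API):
            #
            #     ai_response = await generate_voice_reply(
            #         history=session["conversation"], latest=simulated_text)
            #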
            ai_response = f"I heard your speech segment of {len(audio_np)/16000:.1f} seconds. In a full implementation, this would be processed by CSM/Sesame for ultra-low latency voice response."
            
            await session["websocket"].send_json({
                "type": "ai_response", 
                "text": ai_response,
                "processing_time": "simulated"
            })
            
            session["conversation"].append({
                "role": "assistant",
                "content": ai_response,
                "timestamp": time.time()
            })
            
            logger.info(f"Processed speech segment: {len(audio_np)/16000:.1f}s")
    
    except Exception as e:
        logger.error(f"Error processing speech segment: {e}")

if __name__ == "__main__":
    print("🎙️ Starting LiveTalker Working Voice Server...")
    print("Features:")
    print("  ✅ Real microphone input")
    print("  ✅ Actual Silero VAD processing")
    print("  ✅ Real-time audio analysis")
    print("  ✅ WebSocket communication")
    print("  ✅ Speech segment detection")
    print("")
    print("📍 Access URL: http://localhost:8000")
    print("🎯 Click 'Request Microphone' then 'Start Listening' to test!")
    
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info"
    )