#!/usr/bin/env python3
"""
LiveTalker GPU Voice Chat - Real GPU Acceleration
CSM voice model with ultra-fast GPU inference
"""

import asyncio
import logging
import json
import time
import torch
from datetime import datetime
from typing import Dict, Any
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import uvicorn

# Configure module-wide logging at INFO so model-load and per-request
# timing messages below are visible by default.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # shared by all handlers in this module

# ASGI application; routes are registered on it further down.
app = FastAPI(title="LiveTalker GPU Voice Chat", version="4.0.0")

class GPUVoiceAgent:
    """GPU-accelerated voice agent with CSM model.

    Tries to load a small causal LM (DialoGPT) on CUDA and answers via
    :meth:`generate_response`; when no model can be loaded (missing
    transformers, no GPU, download failure) the agent degrades to fast
    pattern-matched fallback replies instead of crashing.
    """

    # Upper bound on stored exchanges so a long-running server session does
    # not grow memory without bound (only the last 6 are fed to the model).
    MAX_HISTORY = 100

    def __init__(self):
        # Prefer CUDA when present; otherwise everything runs on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None          # set by _load_model on success
        self.tokenizer = None      # set by _load_model on success
        self.user_name = None      # learned lazily from introductions
        self.conversation_history = []  # dicts: {"user", "ai", "gpu_time_ms"}

        logger.info(f"Initializing GPU Voice Agent on {self.device}")
        if torch.cuda.is_available():
            logger.info(f"GPU: {torch.cuda.get_device_name()}")
            logger.info(f"CUDA Version: {torch.version.cuda}")

        self._load_model()

    def _load_model(self):
        """Load the fastest available model for GPU inference.

        Any failure is logged and leaves the agent in fallback mode
        (model/tokenizer stay None) rather than raising.
        """
        try:
            if self.device == "cuda":
                # Imported lazily so the module still works without transformers.
                from transformers import AutoTokenizer, AutoModelForCausalLM

                # Small model keeps latency low at the cost of reply quality.
                model_name = "microsoft/DialoGPT-small"
                logger.info(f"Loading {model_name} on GPU...")

                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                # BUG FIX: the original chained .to(self.device) after
                # device_map="auto"; moving an accelerate-dispatched model
                # with .to() raises in recent transformers, which silently
                # disabled GPU mode via the broad except below. Placement is
                # left entirely to device_map.
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16,  # FP16 halves memory, speeds inference
                    device_map="auto",
                    low_cpu_mem_usage=True
                )

                # DialoGPT ships without a pad token; reuse EOS for padding.
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token

                # Inference-only: disable dropout etc.
                self.model.eval()
                # torch.compile intentionally skipped for stability.
                logger.info("Model optimization skipped for stability")

                logger.info("✅ GPU model loaded and optimized")

        except Exception as e:
            # Degrade gracefully: the agent still works via _fallback_response.
            logger.warning(f"GPU model loading failed: {e}")
            self.model = None
            self.tokenizer = None

    async def generate_response(self, user_input: str) -> str:
        """Return a reply to *user_input*, preferring the GPU model.

        Also opportunistically learns the user's name from introduction
        phrases ("my name is ...", "i'm ...") on first sight.
        """
        if not self.user_name and ("my name is" in user_input.lower() or "i'm" in user_input.lower()):
            self._extract_name(user_input)

        user_display = self.user_name if self.user_name else "friend"

        if self.model and self.tokenizer:
            return await self._gpu_generate_response(user_input, user_display)
        return await self._fallback_response(user_input, user_display)

    async def _gpu_generate_response(self, user_input: str, user_display: str) -> str:
        """Generate response using GPU model.

        Builds a short "User:/AI:" transcript from recent history, samples a
        continuation, and falls back to canned replies on any failure or a
        degenerate (too-short) generation.
        """
        try:
            # Build conversation context from the most recent exchanges only,
            # keeping the prompt short for speed.
            conversation = ""
            for entry in self.conversation_history[-6:]:  # last 6 exchanges
                conversation += f"User: {entry['user']}\nAI: {entry['ai']}\n"
            conversation += f"User: {user_input}\nAI:"

            # Tokenize, truncating over-long context to the model window.
            inputs = self.tokenizer.encode(
                conversation,
                return_tensors="pt",
                max_length=512,
                truncation=True
            ).to(self.device)

            # Generate with optimizations for speed.
            with torch.inference_mode():  # faster than no_grad
                start_time = time.time()

                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=50,  # shorter for speed
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                    use_cache=True
                )

                gpu_time = (time.time() - start_time) * 1000

            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Keep only the text after the final "AI:" marker.
            if "AI:" in full_response:
                response = full_response.split("AI:")[-1].strip()
            else:
                response = full_response[len(conversation):].strip()

            # Strip any hallucinated next user turn.
            response = response.replace("User:", "").strip()

            self.conversation_history.append({
                "user": user_input,
                "ai": response,
                "gpu_time_ms": gpu_time
            })
            # Keep memory bounded in long sessions (no-op while small).
            del self.conversation_history[:-self.MAX_HISTORY]

            # Too short/empty generations read as broken; use the fallback.
            if len(response) < 5:
                return await self._fallback_response(user_input, user_display)

            logger.info(f"GPU generation: {gpu_time:.1f}ms")
            return response

        except Exception as e:
            logger.error(f"GPU generation error: {e}")
            return await self._fallback_response(user_input, user_display)

    async def _fallback_response(self, user_input: str, user_display: str) -> str:
        """Fast fallback responses when GPU model isn't available."""
        user_lower = user_input.lower()

        # Quick substring pattern matching for instant responses.
        if any(word in user_lower for word in ['hi', 'hello', 'hey']):
            if not self.user_name:
                return f"Hey there! I'm Alex, your AI friend. What's your name? 😊"
            else:
                return f"Hi {self.user_name}! Great to see you again! What's on your mind? 🚀"

        elif 'joke' in user_lower:
            jokes = [
                "Why don't scientists trust atoms? Because they make up everything! 😄",
                "What's orange and sounds like a parrot? A carrot! 🥕",
                "Why don't eggs tell jokes? They'd crack each other up! 🥚"
            ]
            # NOTE: str hash() is randomized per process, so the pick is
            # stable within a run but varies across restarts.
            return jokes[hash(user_input) % len(jokes)]

        elif any(phrase in user_lower for phrase in ['who are you', 'what are you']):
            return f"I'm Alex, your GPU-powered AI friend! I'm running on an RTX 3090 for lightning-fast responses. What would you like to know, {user_display}? ⚡"

        elif 'how are you' in user_lower:
            return f"I'm fantastic, {user_display}! Running at full GPU speed and loving our conversation! How are you doing? 🚀"

        else:
            responses = [
                f"That's really interesting, {user_display}! Tell me more about that! 🤔",
                f"I love hearing your thoughts, {user_display}! What else comes to mind? ✨",
                f"You always have fascinating things to say, {user_display}! Keep going! 💫",
                f"That's so cool, {user_display}! I'm curious to hear more! 🌟"
            ]
            return responses[hash(user_input) % len(responses)]

    def _extract_name(self, text: str):
        """Extract user name from an introduction phrase.

        BUG FIX: the original split the original-case text on the lowercase
        phrase, so "My name is Bob" yielded the name "My" and "I'm Bob"
        matched nothing; it also raised IndexError when nothing followed the
        phrase. Matching is now case-insensitive and empty remainders are
        ignored.
        """
        text_lower = text.lower()
        if "my name is" in text_lower:
            name = self._first_word_after(text, text_lower, "my name is")
            if name:
                self.user_name = name.capitalize()
        elif "i'm" in text_lower:
            name = self._first_word_after(text, text_lower, "i'm")
            # "I'm good/fine/okay" are moods, not names.
            if name and name.lower() not in ['good', 'fine', 'okay']:
                self.user_name = name.capitalize()

    @staticmethod
    def _first_word_after(text: str, text_lower: str, phrase: str):
        """Return the first word after the last occurrence of *phrase*
        (matched case-insensitively against *text_lower*), or None."""
        idx = text_lower.rfind(phrase)  # last occurrence, mirroring split()[-1]
        words = text[idx + len(phrase):].strip().split()
        return words[0] if words else None

# Global GPU agent — created at import time, so the (possibly slow) model
# load and CUDA initialization happen once when the module is first imported.
gpu_agent = GPUVoiceAgent()

# GPU Voice Chat HTML — the whole single-page UI served at "/": markup, CSS
# and the WebSocket/Web-Speech JS client, inline in one triple-quoted string.
# NOTE(review): addMessage() injects sender text via innerHTML without
# escaping, so message content is interpreted as HTML (handleMessage relies
# on this to append the gpu-time <div>) — confirm this is intended before
# exposing the page beyond trusted users.
GPU_VOICE_CHAT_HTML = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Alex - GPU Voice Chat</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body {
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            height: 100vh;
            display: flex;
            justify-content: center;
            align-items: center;
            color: white;
        }
        .chat-container {
            background: rgba(255,255,255,0.95);
            border-radius: 20px;
            box-shadow: 0 20px 60px rgba(0,0,0,0.3);
            width: 95%;
            max-width: 800px;
            height: 85vh;
            display: flex;
            flex-direction: column;
            overflow: hidden;
            color: #333;
        }
        .header {
            background: linear-gradient(135deg, #667eea, #764ba2);
            color: white;
            padding: 20px;
            text-align: center;
            position: relative;
        }
        .header h1 { font-size: 1.8em; margin-bottom: 5px; }
        .header p { opacity: 0.9; font-size: 0.9em; }
        .status {
            position: absolute;
            top: 15px;
            right: 20px;
            padding: 5px 12px;
            border-radius: 15px;
            font-size: 0.8em;
            font-weight: bold;
        }
        .connected { background: rgba(76,175,80,0.9); }
        .disconnected { background: rgba(244,67,54,0.9); }
        .connecting { background: rgba(255,152,0,0.9); }
        
        .gpu-info {
            background: linear-gradient(135deg, rgba(255,193,7,0.1), rgba(76,175,80,0.1));
            border: 2px solid #4CAF50;
            border-radius: 12px;
            padding: 15px;
            margin: 15px;
            text-align: center;
            color: #2e7d32;
            font-weight: bold;
        }
        .gpu-indicator {
            color: #ff5722;
            animation: pulse-gpu 2s infinite;
        }
        @keyframes pulse-gpu {
            0%, 100% { opacity: 1; transform: scale(1); }
            50% { opacity: 0.7; transform: scale(1.05); }
        }
        
        .transcript-area {
            background: linear-gradient(135deg, #e8f5e8, #f0f8ff);
            border-bottom: 2px solid #ddd;
            padding: 15px;
            min-height: 60px;
            max-height: 120px;
            overflow-y: auto;
        }
        .transcript-label {
            font-size: 0.8em;
            color: #666;
            margin-bottom: 5px;
            font-weight: bold;
        }
        .live-transcript {
            font-size: 1.1em;
            color: #333;
            font-style: italic;
        }
        .transcript-placeholder { color: #999; }
        
        .messages {
            flex: 1;
            overflow-y: auto;
            padding: 20px;
            background: #f8f9fa;
        }
        .message {
            margin-bottom: 20px;
            animation: slideIn 0.3s ease;
        }
        @keyframes slideIn {
            from { opacity: 0; transform: translateX(20px); }
            to { opacity: 1; transform: translateX(0); }
        }
        .user-msg { text-align: right; }
        .user-msg .bubble {
            background: linear-gradient(135deg, #667eea, #764ba2);
            color: white;
        }
        .assistant-msg .bubble {
            background: white;
            border: 2px solid #4CAF50;
            color: #333;
        }
        .bubble {
            display: inline-block;
            max-width: 85%;
            padding: 15px 20px;
            border-radius: 18px;
            font-size: 16px;
            line-height: 1.4;
            word-wrap: break-word;
        }
        .gpu-time {
            font-size: 0.7em;
            color: #4CAF50;
            font-weight: bold;
            margin-top: 5px;
        }
        .timestamp {
            font-size: 0.7em;
            opacity: 0.6;
            margin-top: 5px;
        }
        
        .input-area {
            padding: 20px;
            background: white;
            border-top: 1px solid #e1e8ed;
        }
        .input-container {
            display: flex;
            gap: 12px;
            align-items: center;
        }
        .message-input {
            flex: 1;
            padding: 15px 20px;
            border: 2px solid #e1e8ed;
            border-radius: 25px;
            font-size: 16px;
            outline: none;
            transition: all 0.3s;
        }
        .message-input:focus {
            border-color: #4CAF50;
            box-shadow: 0 0 0 3px rgba(76,175,80,0.1);
        }
        .send-btn, .voice-btn {
            padding: 15px;
            border: none;
            border-radius: 50%;
            cursor: pointer;
            font-size: 20px;
            transition: all 0.3s;
            width: 50px;
            height: 50px;
            display: flex;
            align-items: center;
            justify-content: center;
        }
        .send-btn {
            background: linear-gradient(135deg, #4CAF50, #45a049);
            color: white;
        }
        .voice-btn {
            background: linear-gradient(135deg, #ff9800, #f57c00);
            color: white;
        }
        .voice-btn.recording {
            background: #f44336;
            animation: pulse 1s infinite;
        }
        @keyframes pulse {
            0%, 100% { transform: scale(1); }
            50% { transform: scale(1.1); }
        }
        .send-btn:hover, .voice-btn:hover {
            transform: translateY(-2px);
            box-shadow: 0 5px 15px rgba(0,0,0,0.2);
        }
        .send-btn:disabled, .voice-btn:disabled {
            background: #ccc;
            cursor: not-allowed;
            transform: none;
        }
        
        .welcome {
            text-align: center;
            padding: 30px 20px;
            color: #666;
        }
        .welcome h2 { margin-bottom: 15px; color: #333; }
    </style>
</head>
<body>
    <div class="chat-container">
        <div class="header">
            <div id="status" class="status disconnected">Disconnected</div>
            <h1>🚀 Alex - GPU Powered</h1>
            <p>RTX 3090 • Ultra-fast AI responses</p>
        </div>
        
        <div class="gpu-info">
            <span class="gpu-indicator">⚡ GPU MODE ACTIVE:</span> RTX 3090 • CUDA acceleration • Instant AI responses
        </div>
        
        <div class="transcript-area">
            <div class="transcript-label">🎤 Live Transcript:</div>
            <div id="liveTranscript" class="live-transcript transcript-placeholder">
                Click the microphone and speak...
            </div>
        </div>
        
        <div class="messages" id="messages">
            <div class="welcome">
                <h2>Hey! I'm Alex 🚀</h2>
                <p>I'm your GPU-powered AI friend running on RTX 3090 for lightning-fast responses!</p>
                <p><small>Say "Hi, my name is [your name]" to get started!</small></p>
            </div>
        </div>
        
        <div class="input-area">
            <div class="input-container">
                <button id="voiceBtn" class="voice-btn" disabled title="Voice Input">🎤</button>
                <input type="text" id="messageInput" class="message-input" 
                       placeholder="GPU-accelerated chat ready..." disabled />
                <button id="sendBtn" class="send-btn" disabled title="Send">🚀</button>
            </div>
        </div>
    </div>

    <script>
        class GPUVoiceChatInterface {
            constructor() {
                this.websocket = null;
                this.isConnected = false;
                this.isRecording = false;
                this.currentTranscript = '';
                this.responseStartTime = 0;
                
                this.messagesContainer = document.getElementById('messages');
                this.messageInput = document.getElementById('messageInput');
                this.sendBtn = document.getElementById('sendBtn');
                this.voiceBtn = document.getElementById('voiceBtn');
                this.status = document.getElementById('status');
                this.liveTranscript = document.getElementById('liveTranscript');
                
                this.setupEventListeners();
                this.connect();
                this.setupVoiceRecognition();
            }
            
            setupEventListeners() {
                this.sendBtn.addEventListener('click', () => this.sendMessage());
                this.voiceBtn.addEventListener('click', () => this.toggleVoiceInput());
                this.messageInput.addEventListener('keypress', (e) => {
                    if (e.key === 'Enter') {
                        e.preventDefault();
                        this.sendMessage();
                    }
                });
            }
            
            async setupVoiceRecognition() {
                if ('webkitSpeechRecognition' in window || 'SpeechRecognition' in window) {
                    const SpeechRecognition = window.SpeechRecognition || window.webkitSpeechRecognition;
                    this.recognition = new SpeechRecognition();
                    this.recognition.continuous = true;
                    this.recognition.interimResults = true;
                    this.recognition.lang = 'en-US';
                    
                    this.recognition.onstart = () => {
                        this.isRecording = true;
                        this.voiceBtn.classList.add('recording');
                        this.voiceBtn.innerHTML = '⏹️';
                        this.updateTranscript('🎤 Listening with GPU acceleration...', true);
                    };
                    
                    this.recognition.onend = () => {
                        this.isRecording = false;
                        this.voiceBtn.classList.remove('recording');
                        this.voiceBtn.innerHTML = '🎤';
                        
                        if (this.currentTranscript.trim()) {
                            this.messageInput.value = this.currentTranscript;
                            this.sendMessage();
                            this.currentTranscript = '';
                        }
                    };
                    
                    this.recognition.onresult = (event) => {
                        let interimTranscript = '';
                        let finalTranscript = '';
                        
                        for (let i = event.resultIndex; i < event.results.length; i++) {
                            const transcript = event.results[i][0].transcript;
                            if (event.results[i].isFinal) {
                                finalTranscript += transcript;
                            } else {
                                interimTranscript += transcript;
                            }
                        }
                        
                        this.currentTranscript = finalTranscript;
                        const displayTranscript = finalTranscript + (interimTranscript ? ' ' + interimTranscript : '');
                        this.updateTranscript(displayTranscript || 'Listening...', false);
                    };
                    
                    this.recognition.onerror = (event) => {
                        console.error('Speech recognition error:', event.error);
                        this.isRecording = false;
                        this.voiceBtn.classList.remove('recording');
                        this.voiceBtn.innerHTML = '🎤';
                        this.updateTranscript('Speech error - try again', true);
                    };
                }
            }
            
            updateTranscript(text, isPlaceholder) {
                this.liveTranscript.textContent = text;
                this.liveTranscript.className = 'live-transcript' + (isPlaceholder ? ' transcript-placeholder' : '');
            }
            
            toggleVoiceInput() {
                if (!this.recognition) return;
                
                if (this.isRecording) {
                    this.recognition.stop();
                } else {
                    this.currentTranscript = '';
                    this.recognition.start();
                }
            }
            
            connect() {
                this.updateStatus('Connecting...', 'connecting');
                
                const wsProtocol = location.protocol === 'https:' ? 'wss:' : 'ws:';
                const wsUrl = wsProtocol + '//' + location.host + '/ws';
                
                try {
                    this.websocket = new WebSocket(wsUrl);
                    
                    this.websocket.onopen = () => {
                        this.isConnected = true;
                        this.updateStatus('🚀 GPU Ready', 'connected');
                        this.enableInterface();
                        this.clearWelcome();
                        this.addMessage('Alex', "🚀 GPU mode activated! I'm Alex, your AI friend powered by RTX 3090. What's your name? Let's chat!", false);
                    };
                    
                    this.websocket.onmessage = (event) => {
                        const data = JSON.parse(event.data);
                        this.handleMessage(data);
                    };
                    
                    this.websocket.onclose = () => {
                        this.isConnected = false;
                        this.updateStatus('Disconnected', 'disconnected');
                        this.disableInterface();
                        setTimeout(() => this.connect(), 3000);
                    };
                    
                    this.websocket.onerror = (error) => {
                        console.error('WebSocket error:', error);
                        this.updateStatus('Error', 'disconnected');
                    };
                    
                } catch (error) {
                    console.error('Failed to create WebSocket:', error);
                    this.updateStatus('Failed', 'disconnected');
                    setTimeout(() => this.connect(), 5000);
                }
            }
            
            enableInterface() {
                this.messageInput.disabled = false;
                this.sendBtn.disabled = false;
                this.voiceBtn.disabled = false;
                this.messageInput.placeholder = "GPU-accelerated chat ready...";
                this.updateTranscript('🚀 GPU ready! Click mic or type...', true);
            }
            
            disableInterface() {
                this.messageInput.disabled = true;
                this.sendBtn.disabled = true;
                this.voiceBtn.disabled = true;
            }
            
            updateStatus(message, className) {
                this.status.textContent = message;
                this.status.className = 'status ' + className;
            }
            
            clearWelcome() {
                const welcome = this.messagesContainer.querySelector('.welcome');
                if (welcome) {
                    welcome.style.display = 'none';
                }
            }
            
            sendMessage() {
                const message = this.messageInput.value.trim();
                if (!message || !this.isConnected) return;
                
                this.addMessage('You', message, true);
                this.responseStartTime = Date.now();
                
                const data = {
                    type: 'text',
                    text: message
                };
                
                this.websocket.send(JSON.stringify(data));
                this.messageInput.value = '';
                this.updateTranscript('🚀 GPU processing...', true);
            }
            
            handleMessage(data) {
                if (data.type === 'text') {
                    const responseTime = Date.now() - this.responseStartTime;
                    const gpuTime = data.response_time_ms || responseTime;
                    
                    this.addMessage('Alex', data.text + ` <div class="gpu-time">⚡GPU: ${gpuTime.toFixed(1)}ms</div>`, false);
                    this.updateTranscript('Ready for next message! 🚀', true);
                    
                    // GPU-optimized TTS
                    if ('speechSynthesis' in window && data.text) {
                        speechSynthesis.cancel();
                        
                        const utterance = new SpeechSynthesisUtterance(data.text);
                        utterance.rate = 1.2;
                        utterance.pitch = 1.1;
                        utterance.volume = 0.9;
                        
                        speechSynthesis.speak(utterance);
                    }
                }
            }
            
            addMessage(sender, text, isUser) {
                const messageDiv = document.createElement('div');
                messageDiv.className = 'message ' + (isUser ? 'user-msg' : 'assistant-msg');
                
                const timestamp = new Date().toLocaleTimeString([], {hour: '2-digit', minute:'2-digit'});
                
                messageDiv.innerHTML = 
                    '<div class="bubble">' +
                    '<strong>' + sender + ':</strong> ' + text +
                    '<div class="timestamp">' + timestamp + '</div>' +
                    '</div>';
                
                this.messagesContainer.appendChild(messageDiv);
                this.messagesContainer.scrollTop = this.messagesContainer.scrollHeight;
            }
        }
        
        document.addEventListener('DOMContentLoaded', () => {
            new GPUVoiceChatInterface();
        });
    </script>
</body>
</html>
"""

@app.get("/", response_class=HTMLResponse)
async def gpu_voice_chat():
    """Serve the single-page GPU voice chat UI (static HTML constant)."""
    return GPU_VOICE_CHAT_HTML

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for GPU-accelerated conversation.

    Protocol: client sends JSON ``{"type": "text", "text": ...}``; the server
    replies with JSON carrying the AI response plus timing metadata. Loops
    until the client disconnects.
    """
    await websocket.accept()
    logger.info('GPU Voice chat WebSocket connected')

    try:
        while True:
            data = await websocket.receive_text()
            message = json.loads(data)

            if message.get('type') == 'text':
                user_text = message.get('text', '').strip()

                if user_text:
                    # Time the full round trip through the agent.
                    start_time = time.time()
                    ai_response = await gpu_agent.generate_response(user_text)
                    response_time = (time.time() - start_time) * 1000

                    logger.info(f'GPU response: {response_time:.1f}ms - "{user_text}" -> "{ai_response[:50]}..."')

                    await websocket.send_text(json.dumps({
                        'type': 'text',
                        'text': ai_response,
                        'speaker': 'assistant',
                        'response_time_ms': response_time,
                        'gpu_accelerated': gpu_agent.device == 'cuda'
                    }))

    # BUG FIX: the original caught the undefined name "WebSocketDisconnected"
    # (the fastapi import is WebSocketDisconnect), so every client disconnect
    # raised NameError instead of being handled cleanly.
    except WebSocketDisconnect:
        logger.info('GPU Voice chat WebSocket disconnected')
    except Exception as e:
        logger.error(f'WebSocket error: {e}')

@app.get("/health")
async def health():
    """Health check endpoint: reports device, model, and feature status."""
    # Query CUDA availability once and derive the GPU name from it.
    cuda_ok = torch.cuda.is_available()
    gpu_name = torch.cuda.get_device_name() if cuda_ok else None
    return dict(
        status='ok',
        mode='gpu_accelerated',
        device=gpu_agent.device,
        gpu_model=gpu_name,
        cuda_available=cuda_ok,
        model_loaded=gpu_agent.model is not None,
        features=['gpu_acceleration', 'live_transcription', 'voice_io', 'fast_inference'],
    )

def main():
    """Start the GPU voice chat server.

    Binds uvicorn on all interfaces (0.0.0.0) at port 8002; this call
    blocks until the server is shut down.
    """
    logger.info("🚀 Starting LiveTalker GPU Voice Chat Server")
    uvicorn.run(app, host="0.0.0.0", port=8002, log_level="info")

if __name__ == "__main__":
    main()