#!/usr/bin/env python3
"""
Stable GPU Voice Chat Server - Simplified Version
"""
import asyncio
import logging
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for GPU model
gpu_model = None
gpu_tokenizer = None
device = None

async def initialize_gpu_model():
    """Initialize GPU model once at startup"""
    global gpu_model, gpu_tokenizer, device
    
    logger.info("🔥 Initializing GPU model...")
    
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        
        if torch.cuda.is_available():
            logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        
        # Load tokenizer
        gpu_tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
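        # DialoGPT's tokenizer ships without a pad token, so reuse EOS for padding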
        if gpu_tokenizer.pad_token is None:
            gpu_tokenizer.pad_token = gpu_tokenizer.eos_token
        
        # Load model (half precision on GPU, full precision on CPU)
        gpu_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/DialoGPT-small",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )
        
        gpu_model.eval()
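        # Note: device_map="auto" above requires the accelerate package; without it, drop
        # that argument and move the model to the target device with .to(device) instead.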
        logger.info("✅ GPU model initialized successfully!")
        
    except Exception as e:
        logger.error(f"❌ Failed to initialize GPU model: {e}")
        gpu_model = None
        gpu_tokenizer = None

async def generate_gpu_response(user_input: str) -> str:
    """Generate response using GPU model"""
    global gpu_model, gpu_tokenizer, device
    
    if not gpu_model or not gpu_tokenizer:
        return "Sorry, GPU model not available. Please try again."
    
    try:
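        # Note: generate() below runs synchronously and blocks the event loop while it
        # works; heavier deployments may want to offload it to a thread (e.g. asyncio.to_thread).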
        # Encode the user input and append the end-of-sequence token
        input_ids = gpu_tokenizer.encode(
            user_input + gpu_tokenizer.eos_token,
            return_tensors="pt"
        )

        if device and device.type == "cuda":
            input_ids = input_ids.to(device)

        # Generate a response (beam sampling, at most 30 new tokens)
        with torch.no_grad():
            output_ids = gpu_model.generate(
                input_ids,
                max_length=input_ids.shape[1] + 30,
                num_beams=3,
                do_sample=True,
                temperature=0.8,
                pad_token_id=gpu_tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens; slicing against the original prompt
        # length (not the output length) keeps the reply from coming back empty
        response = gpu_tokenizer.decode(
            output_ids[0, input_ids.shape[1]:],
            skip_special_tokens=True
        )
        
        # Clean up response
        response = response.strip()
        if not response:
            response = "I understand. Could you tell me more about that?"
        
        return response
        
    except Exception as e:
        logger.error(f"GPU generation error: {e}")
        return f"I'm having trouble processing that. Could you rephrase? (GPU Error: {str(e)[:50]})"

# Create FastAPI app
app = FastAPI(title="Stable GPU LiveTalker", version="1.0.0")

# Single-page voice chat client: WebSocket transport, Web Speech API for speech input, speechSynthesis for spoken replies
VOICE_CHAT_HTML = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>LiveTalker GPU Voice Chat</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body {
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            height: 100vh; display: flex; justify-content: center; align-items: center;
            color: white;
        }
        .chat-container {
            background: rgba(255,255,255,0.95); border-radius: 20px;
            box-shadow: 0 20px 60px rgba(0,0,0,0.3);
            width: 90%; max-width: 600px; height: 80vh;
            display: flex; flex-direction: column; overflow: hidden; color: #333;
        }
        .header {
            background: linear-gradient(135deg, #667eea, #764ba2);
            color: white; padding: 30px; text-align: center;
        }
        .header h1 { font-size: 2em; margin-bottom: 10px; }
        .status { padding: 8px 16px; border-radius: 20px; font-weight: bold;
            margin-top: 15px; display: inline-block; }
        .connected { background: rgba(76,175,80,0.9); }
        .disconnected { background: rgba(244,67,54,0.9); }
        .connecting { background: rgba(255,152,0,0.9); }
        .gpu-info {
            background: linear-gradient(135deg, #4CAF50, #45a049);
            border: none; border-radius: 10px; padding: 15px;
            margin: 20px; text-align: center; color: white; font-weight: bold;
        }
        .messages { flex: 1; overflow-y: auto; padding: 20px; background: #f8f9fa; }
        .message { margin-bottom: 20px; animation: fadeIn 0.4s ease; }
        @keyframes fadeIn { from { opacity: 0; transform: translateY(20px); } to { opacity: 1; transform: translateY(0); } }
        .user-msg { text-align: right; }
        .user-msg .bubble { background: linear-gradient(135deg, #667eea, #764ba2); color: white; }
        .assistant-msg .bubble { background: white; border: 2px solid #e1e8ed; color: #333; }
        .bubble {
            display: inline-block; max-width: 80%; padding: 15px 20px;
            border-radius: 18px; font-size: 16px; line-height: 1.4;
        }
        .input-area { padding: 25px; background: white; border-top: 1px solid #e1e8ed; }
        .input-container { display: flex; gap: 15px; align-items: center; }
        .message-input {
            flex: 1; padding: 15px 20px; border: 2px solid #e1e8ed;
            border-radius: 25px; font-size: 16px; outline: none; transition: all 0.3s;
        }
        .message-input:focus { border-color: #667eea; box-shadow: 0 0 0 3px rgba(102,126,234,0.1); }
        .send-btn, .voice-btn {
            padding: 15px; border: none; border-radius: 50%; cursor: pointer;
            font-size: 20px; transition: all 0.3s; width: 50px; height: 50px;
            display: flex; align-items: center; justify-content: center;
        }
        .send-btn { background: linear-gradient(135deg, #667eea, #764ba2); color: white; }
        .voice-btn { background: #4CAF50; color: white; }
        .voice-btn.recording { background: #f44336; animation: pulse 1s infinite; }
        @keyframes pulse { 0%, 100% { transform: scale(1); } 50% { transform: scale(1.1); } }
        .send-btn:hover, .voice-btn:hover { transform: translateY(-2px); box-shadow: 0 5px 15px rgba(0,0,0,0.2); }
        .send-btn:disabled, .voice-btn:disabled { background: #ccc; cursor: not-allowed; transform: none; }
    </style>
</head>
<body>
    <div class="chat-container">
        <div class="header">
            <h1>🎙️ LiveTalker GPU</h1>
            <p>RTX 3090 Powered AI Voice Assistant</p>
            <div id="status" class="status disconnected">Disconnected</div>
        </div>
        
        <div class="gpu-info">
            <strong>🚀 GPU ACCELERATED:</strong> Ultra-fast DialoGPT responses with CUDA 11.8
        </div>
        
        <div class="messages" id="messages">
            <div style="text-align: center; padding: 40px 20px; color: #666;">
                <h2 style="margin-bottom: 15px; color: #333;">GPU Voice Chat Ready!</h2>
                <p>Click the microphone to speak, or type your message below.</p>
                <p><small>Powered by RTX 3090 • Real-time GPU inference</small></p>
            </div>
        </div>
        
        <div class="input-area">
            <div class="input-container">
                <button id="voiceBtn" class="voice-btn" disabled title="Voice Input">🎤</button>
                <input type="text" id="messageInput" class="message-input" 
                       placeholder="Speak or type your message..." disabled />
                <button id="sendBtn" class="send-btn" disabled title="Send Message">➤</button>
            </div>
        </div>
    </div>

    <script>
        class VoiceChatInterface {
            constructor() {
                this.websocket = null;
                this.isConnected = false;
                this.isRecording = false;
                
                this.messagesContainer = document.getElementById('messages');
                this.messageInput = document.getElementById('messageInput');
                this.sendBtn = document.getElementById('sendBtn');
                this.voiceBtn = document.getElementById('voiceBtn');
                this.status = document.getElementById('status');
                
                this.setupEventListeners();
                this.connect();
                this.setupVoiceRecognition();
            }
            
            setupEventListeners() {
                this.sendBtn.addEventListener('click', () => this.sendMessage());
                this.voiceBtn.addEventListener('click', () => this.toggleVoiceInput());
                this.messageInput.addEventListener('keypress', (e) => {
                    if (e.key === 'Enter') {
                        e.preventDefault();
                        this.sendMessage();
                    }
                });
            }
            
            async setupVoiceRecognition() {
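                // Speech-to-text uses the Web Speech API, which not every browser exposes;
                // when it is unavailable, the mic button stays disabled.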
                if ('webkitSpeechRecognition' in window || 'SpeechRecognition' in window) {
                    const SpeechRecognition = window.SpeechRecognition || window.webkitSpeechRecognition;
                    this.recognition = new SpeechRecognition();
                    this.recognition.continuous = false;
                    this.recognition.interimResults = false;
                    this.recognition.lang = 'en-US';
                    
                    this.recognition.onstart = () => {
                        this.isRecording = true;
                        this.voiceBtn.classList.add('recording');
                        this.voiceBtn.innerHTML = '⏹️';
                        this.messageInput.placeholder = 'Listening...';
                    };
                    
                    this.recognition.onend = () => {
                        this.isRecording = false;
                        this.voiceBtn.classList.remove('recording');
                        this.voiceBtn.innerHTML = '🎤';
                        this.messageInput.placeholder = 'Speak or type your message...';
                    };
                    
                    this.recognition.onresult = (event) => {
                        const transcript = event.results[0][0].transcript;
                        this.messageInput.value = transcript;
                        this.sendMessage();
                    };
                    
                    this.recognition.onerror = (event) => {
                        console.error('Speech recognition error:', event.error);
                        this.isRecording = false;
                        this.voiceBtn.classList.remove('recording');
                        this.voiceBtn.innerHTML = '🎤';
                    };
                }
            }
            
            toggleVoiceInput() {
                if (!this.recognition) return;
                
                if (this.isRecording) {
                    this.recognition.stop();
                } else {
                    this.recognition.start();
                }
            }
            
            connect() {
                this.updateStatus('Connecting...', 'connecting');
                
                const wsProtocol = location.protocol === 'https:' ? 'wss:' : 'ws:';
                const wsUrl = wsProtocol + '//' + location.host + '/ws';
                
                console.log('Connecting to WebSocket:', wsUrl);
                
                try {
                    this.websocket = new WebSocket(wsUrl);
                    
                    this.websocket.onopen = () => {
                        this.isConnected = true;
                        this.updateStatus('Connected', 'connected');
                        this.enableInterface();
                        this.clearWelcome();
                        this.addMessage('LiveTalker GPU', 'Hello! GPU acceleration is active. How can I help you today?', false);
                    };
                    
                    this.websocket.onmessage = (event) => {
                        const data = JSON.parse(event.data);
                        this.handleMessage(data);
                    };
                    
                    this.websocket.onclose = () => {
                        this.isConnected = false;
                        this.updateStatus('Disconnected', 'disconnected');
                        this.disableInterface();
                        console.log('WebSocket disconnected, attempting reconnect in 3s...');
                        setTimeout(() => this.connect(), 3000);
                    };
                    
                    this.websocket.onerror = (error) => {
                        console.error('WebSocket error:', error);
                        this.updateStatus('Error', 'disconnected');
                    };
                    
                } catch (error) {
                    console.error('Failed to create WebSocket:', error);
                    this.updateStatus('Failed', 'disconnected');
                    setTimeout(() => this.connect(), 5000);
                }
            }
            
            enableInterface() {
                this.messageInput.disabled = false;
                this.sendBtn.disabled = false;
                // Only enable the mic button when speech recognition is actually available
                this.voiceBtn.disabled = !this.recognition;
            }
            
            disableInterface() {
                this.messageInput.disabled = true;
                this.sendBtn.disabled = true;
                this.voiceBtn.disabled = true;
            }
            
            updateStatus(message, className) {
                this.status.textContent = message;
                this.status.className = 'status ' + className;
            }
            
            clearWelcome() {
                const welcome = this.messagesContainer.querySelector('div[style]');
                if (welcome) {
                    welcome.style.display = 'none';
                }
            }
            
            sendMessage() {
                const message = this.messageInput.value.trim();
                if (!message || !this.isConnected) return;
                
                this.addMessage('You', message, true);
                
                const data = {
                    type: 'text',
                    text: message
                };
                
                this.websocket.send(JSON.stringify(data));
                this.messageInput.value = '';
            }
            
            handleMessage(data) {
                console.log('Received message:', data);
                
                if (data.type === 'text') {
                    this.addMessage('LiveTalker GPU', data.text, false);
                    
                    // Speak the response using browser TTS
                    if ('speechSynthesis' in window) {
                        const utterance = new SpeechSynthesisUtterance(data.text);
                        utterance.rate = 0.9;
                        utterance.pitch = 1.0;
                        speechSynthesis.speak(utterance);
                    }
                } else if (data.type === 'error') {
                    this.addMessage('System', 'Error: ' + data.error, false);
                }
            }
            
            addMessage(sender, text, isUser) {
                const messageDiv = document.createElement('div');
                messageDiv.className = 'message ' + (isUser ? 'user-msg' : 'assistant-msg');

                // Build the bubble with textContent so model output can't inject HTML
                const bubble = document.createElement('div');
                bubble.className = 'bubble';
                const label = document.createElement('strong');
                label.textContent = sender + ': ';
                bubble.appendChild(label);
                bubble.appendChild(document.createTextNode(text));
                messageDiv.appendChild(bubble);

                this.messagesContainer.appendChild(messageDiv);
                this.messagesContainer.scrollTop = this.messagesContainer.scrollHeight;
            }
        }
        
        document.addEventListener('DOMContentLoaded', () => {
            new VoiceChatInterface();
        });
    </script>
</body>
</html>
"""

@app.on_event("startup")
async def startup_event():
    """Initialize GPU model on startup"""
    await initialize_gpu_model()

@app.get("/", response_class=HTMLResponse)
async def voice_chat():
    """Main voice chat interface"""
    return VOICE_CHAT_HTML

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time voice chat"""
    await websocket.accept()
    logger.info('GPU Voice chat WebSocket connected')
    
    try:
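        # Message protocol: the client sends {"type": "text", "text": ...}; the server
        # replies with {"type": "text", "text": ..., "speaker": "assistant"}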
        while True:
            data = await websocket.receive_text()
            message = json.loads(data)
            logger.info(f'Received message: {message}')
            
            if message.get('type') == 'text':
                user_text = message.get('text', '')
                
                # Generate GPU response
                response = await generate_gpu_response(user_text)
                
                await websocket.send_text(json.dumps({
                    'type': 'text',
                    'text': response,
                    'speaker': 'assistant'
                }))
                
    except WebSocketDisconnect:
        logger.info('GPU Voice chat WebSocket disconnected')
    except Exception as e:
        logger.error(f'WebSocket error: {e}')

@app.get("/health")
async def health():
    """Health check endpoint"""
    global gpu_model, device
    
    return {
        'status': 'ok',
        'mode': 'stable_gpu',
        'model_loaded': gpu_model is not None,
        'gpu_available': torch.cuda.is_available(),
        'device': str(device) if device else 'unknown',
        'features': ['voice_input', 'voice_output', 'gpu_inference']
    }

def main():
    """Start the stable GPU voice chat server"""
    logger.info("🚀 Starting Stable GPU LiveTalker Server")
    uvicorn.run(app, host="0.0.0.0", port=8002, log_level="info")

if __name__ == "__main__":
    main()