#!/usr/bin/env python3
"""
Simple GPU Test for LiveTalker
Minimal implementation to test GPU model loading
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def test_gpu():
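    """Check CUDA availability, load DialoGPT-small, and run a short generation."""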
    logger.info("🔥 Testing GPU availability...")
    print(f"CUDA Available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"CUDA Version: {torch.version.cuda}")
    
    logger.info("🤖 Loading model on GPU...")
    try:
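        # First run downloads DialoGPT-small from the Hugging Face Hub;
        # subsequent runs load it from the local cache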
        tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
        # Half precision on the GPU; fall back to full precision on CPU,
        # where float16 kernels are poorly supported and generation can fail
        model = AutoModelForCausalLM.from_pretrained(
            "microsoft/DialoGPT-small",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",  # requires the accelerate package
        )

        # DialoGPT defines no pad token; reuse EOS so generate() can pad
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            
        model.eval()  # inference mode: disables dropout
        logger.info("✅ Model loaded successfully!")
        
        # Test generation
        logger.info("🧠 Testing generation...")
        # DialoGPT marks the end of each dialogue turn with an EOS token;
        # calling the tokenizer directly also returns the attention mask,
        # which avoids the "attention mask is not set" warning in generate()
        inputs = tokenizer("Hello, how are you?" + tokenizer.eos_token, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=True,
                temperature=0.8,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the prompt
        prompt_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        logger.info(f"✅ Generated response: '{response}'")
        return True
        
    except Exception:
        # logger.exception records the full traceback along with the message
        logger.exception("❌ Model load or generation failed")
        return False

if __name__ == "__main__":
    success = test_gpu()
    print(f"\n{'✅ SUCCESS' if success else '❌ FAILED'}")