#!/usr/bin/env python3
"""
uncloseai - Python client library for OpenAI-compatible APIs
Supports streaming and non-streaming chat, model discovery, and TTS
Compatible with vLLM, Ollama, and OpenAI-compatible endpoints
"""

import requests
import json
import os
from typing import List, Dict, Optional, Iterator, Union


class uncloseai:
    """Client for OpenAI-compatible API endpoints with streaming support.

    Models are discovered at construction time by requesting
    ``GET {endpoint}/models`` on every configured model endpoint. Chat requests
    are then sent to the endpoint that advertised the selected model.
    """

    def __init__(
        self,
        model_endpoints: Optional[List[str]] = None,
        tts_endpoints: Optional[List[str]] = None,
        api_key: Optional[str] = None,
        timeout: int = 30
    ):
        """
        Initialize the client and discover models from each model endpoint.

        Args:
            model_endpoints: List of model endpoint base URLs (defaults to the
                MODEL_ENDPOINT_1, MODEL_ENDPOINT_2, ... env vars)
            tts_endpoints: List of TTS endpoint base URLs (defaults to the
                TTS_ENDPOINT_1, TTS_ENDPOINT_2, ... env vars)
            api_key: Optional API key, sent as ``Authorization: Bearer <key>``
            timeout: Request timeout in seconds

        Endpoints that fail model discovery are skipped with a printed warning
        rather than raising.
        """
        self.timeout = timeout
        self.api_key = api_key
        self.models: List[Dict] = []

        # Fall back to numbered environment variables when nothing is passed.
        if model_endpoints is None:
            model_endpoints = self._discover_env_endpoints("MODEL_ENDPOINT")
        if tts_endpoints is None:
            tts_endpoints = self._discover_env_endpoints("TTS_ENDPOINT")

        for endpoint in model_endpoints:
            self._discover_models_from_endpoint(endpoint)

        # Copy so later mutation of the caller's list cannot change the client.
        self.tts_endpoints: List[str] = list(tts_endpoints)

    def _discover_env_endpoints(self, prefix: str) -> List[str]:
        """Collect PREFIX_1, PREFIX_2, ... from the environment.

        Scanning stops at the first missing or empty variable, so the
        numbering must be contiguous starting at 1.
        """
        endpoints = []
        for i in range(1, 10000):
            endpoint = os.getenv(f"{prefix}_{i}")
            if not endpoint:
                break
            endpoints.append(endpoint)
        return endpoints

    def _discover_models_from_endpoint(self, endpoint: str) -> None:
        """Query ``{endpoint}/models`` and register every advertised model.

        Best-effort: network errors, non-200 responses, and malformed payloads
        print a warning instead of raising. This way one bad endpoint does not
        stop the client from using the others. Models from a single endpoint
        are registered all-or-nothing.
        """
        try:
            response = requests.get(
                self._url(endpoint, "models"),
                headers=self._headers(json_body=False),
                timeout=self.timeout
            )

            if response.status_code != 200:
                print(
                    f"Warning: Failed to discover models from {endpoint}: "
                    f"HTTP {response.status_code}"
                )
                return

            data = response.json()
            # Build the full list first so a malformed entry does not leave
            # this endpoint half-registered.
            discovered = [
                {
                    "id": model["id"],
                    "endpoint": endpoint,
                    # max_model_len is presumably the server-reported context
                    # length; servers that omit it fall back to 8192.
                    "max_tokens": model.get("max_model_len", 8192)
                }
                for model in data.get("data", [])
            ]
            self.models.extend(discovered)
        except Exception as e:
            # Deliberate boundary catch: discovery must never abort __init__.
            print(f"Warning: Failed to discover models from {endpoint}: {e}")

    def list_models(self) -> List[Dict]:
        """Return a shallow copy of the discovered models and their metadata."""
        return self.models.copy()

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 100,
        temperature: float = 0.7,
        **kwargs
    ) -> Dict:
        """
        Non-streaming chat completion.

        Args:
            messages: List of message dicts with 'role' and 'content'
            model: Model ID (defaults to first available model)
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            **kwargs: Additional parameters merged into the request payload
                (they override the keys above if they collide)

        Returns:
            Parsed JSON response dict with 'choices' containing the completion

        Raises:
            ValueError: If no models are available or `model` is unknown
            requests.HTTPError: On a non-2xx response
        """
        model_info = self._get_model_info(model)

        payload = {
            "model": model_info["id"],
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,
            **kwargs
        }

        response = requests.post(
            self._url(model_info["endpoint"], "chat/completions"),
            headers=self._headers(),
            json=payload,
            timeout=self.timeout
        )
        response.raise_for_status()

        return response.json()

    def chat_stream(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 500,
        temperature: float = 0.7,
        **kwargs
    ) -> Iterator[Dict]:
        """
        Streaming chat completion using Server-Sent Events.

        This is a generator. No request is made, and no ValueError for an
        unknown model is raised, until the first item is requested.

        Args:
            messages: List of message dicts with 'role' and 'content'
            model: Model ID (defaults to first available model)
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            **kwargs: Additional parameters merged into the request payload

        Yields:
            Parsed chunk dicts with 'choices' containing delta content

        Raises:
            ValueError: If no models are available or `model` is unknown
            requests.HTTPError: On a non-2xx response
        """
        model_info = self._get_model_info(model)

        payload = {
            "model": model_info["id"],
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True,
            **kwargs
        }

        # The context manager releases the streamed connection back to the
        # pool when the stream ends, errors, or the caller stops iterating
        # early (generator close).
        with requests.post(
            self._url(model_info["endpoint"], "chat/completions"),
            headers=self._headers(),
            json=payload,
            timeout=self.timeout,
            stream=True
        ) as response:
            response.raise_for_status()
            yield from self._iter_sse_chunks(response.iter_lines())

    def tts(
        self,
        text: str,
        voice: str = "alloy",
        model: str = "tts-1",
        response_format: str = "mp3"
    ) -> bytes:
        """
        Generate speech from text using the first configured TTS endpoint.

        Args:
            text: Input text to convert to speech
            voice: Voice name (alloy, echo, fable, onyx, nova, shimmer)
            model: TTS model (tts-1 or tts-1-hd)
            response_format: Audio format (mp3, opus, aac, flac)

        Returns:
            Raw audio data as bytes

        Raises:
            ValueError: If no TTS endpoints are configured
            requests.HTTPError: On a non-2xx response
        """
        if not self.tts_endpoints:
            raise ValueError("No TTS endpoints available")

        endpoint = self.tts_endpoints[0]

        payload = {
            "model": model,
            "voice": voice,
            "input": text,
            "response_format": response_format
        }

        response = requests.post(
            self._url(endpoint, "audio/speech"),
            headers=self._headers(),
            json=payload,
            timeout=self.timeout
        )
        response.raise_for_status()

        return response.content

    def _get_model_info(self, model: Optional[str] = None) -> Dict:
        """Return the model entry with the given ID, or the first one if None.

        Raises:
            ValueError: If no models were discovered or the ID is unknown
        """
        if not self.models:
            raise ValueError("No models available. Check endpoint configuration.")

        if model is None:
            return self.models[0]

        for m in self.models:
            if m["id"] == model:
                return m

        raise ValueError(f"Model '{model}' not found in discovered models")

    def _headers(self, json_body: bool = True) -> Dict[str, str]:
        """Build request headers: optional JSON content type plus Bearer auth."""
        headers: Dict[str, str] = {}
        if json_body:
            headers["Content-Type"] = "application/json"
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @staticmethod
    def _url(base: str, path: str) -> str:
        """Join an endpoint base URL and a path without doubling the slash."""
        return f"{base.rstrip('/')}/{path}"

    @staticmethod
    def _iter_sse_chunks(lines: Iterator[Union[bytes, str]]) -> Iterator[Dict]:
        """Parse Server-Sent Event lines into JSON chunks.

        Only ``data:`` fields are used. Per the SSE format, one optional space
        after the colon is dropped. Blank lines, comments, and other fields
        are ignored. Iteration stops at ``[DONE]``, and payloads that are not
        valid JSON are skipped.
        """
        for raw in lines:
            if not raw:
                continue

            line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
            if not line.startswith("data:"):
                continue

            data = line[len("data:"):]
            if data.startswith(" "):
                data = data[1:]

            # Stream termination sentinel used by OpenAI-compatible servers.
            if data.strip() == "[DONE]":
                return

            try:
                chunk = json.loads(data)
            except json.JSONDecodeError:
                continue
            yield chunk


# Demo usage when run as script
if __name__ == "__main__":
    # Local import keeps the library import surface unchanged; sys.exit is used
    # instead of the site-module `exit` helper, which is absent under `python -S`.
    import sys

    print("=== uncloseai. Python Client (with Streaming) ===\n")

    # Initialize client (auto-discovers endpoints from MODEL_ENDPOINT_* /
    # TTS_ENDPOINT_* environment variables)
    client = uncloseai()

    if not client.models:
        print("ERROR: No models discovered. Set environment variables:", file=sys.stderr)
        print("  MODEL_ENDPOINT_1, MODEL_ENDPOINT_2, etc.", file=sys.stderr)
        sys.exit(1)

    print(f"Discovered {len(client.models)} model(s)")
    for model in client.models:
        print(f"  - {model['id']} (max_tokens: {model['max_tokens']})")
    print()

    # Non-streaming chat example (uses the first discovered model)
    print("=== Non-Streaming Chat ===")
    response = client.chat(
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": "Explain quantum computing in one sentence."}
        ],
        max_tokens=100
    )
    print(f"Model: {response['model']}")
    print(f"Response: {response['choices'][0]['message']['content']}\n")

    # Streaming chat example: prefer a second model when one exists so the
    # demo exercises routing to a non-default model.
    print("=== Streaming Chat ===")
    model_id = client.models[1]["id"] if len(client.models) > 1 else None

    print(f"Model: {model_id or client.models[0]['id']}")
    print("Response: ", end="", flush=True)

    for chunk in client.chat_stream(
        messages=[
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write a Python function to check if a number is prime"}
        ],
        model=model_id,
        max_tokens=200
    ):
        choices = chunk.get("choices")
        if choices:
            content = choices[0].get("delta", {}).get("content", "")
            if content:
                print(content, end="", flush=True)

    print("\n")

    # TTS example (only when a TTS endpoint is configured)
    if client.tts_endpoints:
        print("=== TTS Speech Generation ===")
        audio_data = client.tts(
            text="Hello from uncloseai. Python client! This demonstrates text to speech with streaming support.",
            voice="alloy"
        )

        with open("speech.mp3", "wb") as f:
            f.write(audio_data)

        print(f"[OK] Speech file created: speech.mp3 ({len(audio_data)} bytes)\n")

    print("=== Examples Complete ===")
