#!/usr/bin/env python3
"""
uncloseai. - Async Python Client (httpx)
A Python async client for OpenAI-compatible APIs with streaming support
"""

import httpx
import json
import os
import asyncio
from typing import List, Dict, Optional, AsyncIterator


class uncloseai:
    """Async client for OpenAI-compatible API endpoints with streaming support.

    Models are discovered lazily, on the first call that needs them, by querying
    ``GET {endpoint}/models`` on every configured model endpoint. When endpoints are
    not passed to the constructor they are read from numbered environment variables
    ``MODEL_ENDPOINT_1``, ``MODEL_ENDPOINT_2``, ... (and ``TTS_ENDPOINT_1``, ... for
    text-to-speech).
    """

    def __init__(
        self,
        model_endpoints: Optional[List[str]] = None,
        tts_endpoints: Optional[List[str]] = None,
        api_key: Optional[str] = None,
        timeout: float = 30.0
    ):
        """Configure the client. No network I/O happens here.

        Args:
            model_endpoints: Base URLs of chat-completion servers. Falls back to the
                ``MODEL_ENDPOINT_<n>`` environment variables when None or empty.
            tts_endpoints: Base URLs of speech servers. Falls back to
                ``TTS_ENDPOINT_<n>`` when None or empty.
            api_key: Optional token sent as ``Authorization: Bearer <key>``.
            timeout: Timeout passed to every ``httpx.AsyncClient`` this client creates.
        """
        self.timeout = timeout
        self.api_key = api_key
        self.models: List[Dict] = []
        self.tts_endpoints: List[str] = []
        self._initialized = False
        # Created lazily inside the running event loop (see _ensure_initialized);
        # constructing an asyncio.Lock outside a loop is problematic before 3.10.
        self._init_lock: Optional[asyncio.Lock] = None
        self._model_endpoints = model_endpoints or self._discover_env_endpoints("MODEL_ENDPOINT")
        self._tts_endpoints = tts_endpoints or self._discover_env_endpoints("TTS_ENDPOINT")

    def _discover_env_endpoints(self, prefix: str) -> List[str]:
        """Collect the values of ``{prefix}_1``, ``{prefix}_2``, ... from the environment.

        Scanning stops at the first missing or empty variable, so the numbering must be
        contiguous and start at 1.
        """
        endpoints = []
        for i in range(1, 10000):
            endpoint = os.getenv(f"{prefix}_{i}")
            if not endpoint:
                break
            endpoints.append(endpoint)
        return endpoints

    @staticmethod
    def _join_url(endpoint: str, path: str) -> str:
        """Join a base endpoint and a relative path with exactly one ``/`` between them."""
        return f"{endpoint.rstrip('/')}/{path.lstrip('/')}"

    def _headers(self, json_body: bool = False) -> Dict[str, str]:
        """Build request headers: optional JSON content type plus bearer auth if configured."""
        headers: Dict[str, str] = {}
        if json_body:
            headers["Content-Type"] = "application/json"
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @staticmethod
    def _chat_payload(
        model_id: str,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        stream: bool,
        extra: Dict,
    ) -> Dict:
        """Build a ``/chat/completions`` request body.

        ``extra`` carries caller kwargs (e.g. ``top_p``). ``stream`` is applied last so a
        stray ``stream`` kwarg cannot contradict how the response will be parsed.
        """
        return {
            "model": model_id,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            **extra,
            "stream": stream,
        }

    @staticmethod
    def _sse_data(line: str) -> Optional[str]:
        """Return the value of an SSE ``data:`` line, or None for any other line.

        A single space after the colon is optional in the SSE format and is removed,
        so both ``data: {...}`` and ``data:{...}`` yield ``{...}``. Blank lines,
        comments (``: ...``) and other fields (``event:``, ``id:``) return None.
        """
        if not line.startswith("data:"):
            return None
        value = line[len("data:"):]
        return value[1:] if value.startswith(" ") else value

    @staticmethod
    def _parse_models(data: Dict, endpoint: str) -> List[Dict]:
        """Convert a ``/models`` JSON response into this client's model records.

        Entries without an ``"id"`` are skipped instead of aborting the whole endpoint.
        ``max_tokens`` comes from ``max_model_len`` when present, else 8192.
        """
        models = []
        for model in data.get("data", []):
            if not isinstance(model, dict) or "id" not in model:
                continue
            models.append({
                "id": model["id"],
                "endpoint": endpoint,
                "max_tokens": model.get("max_model_len", 8192)
            })
        return models

    async def _ensure_initialized(self):
        """Run model discovery once. Concurrent first callers wait for a single discovery.

        Discovery failures are best-effort: they produce warnings, not exceptions, and
        the client is still marked initialized.
        """
        if self._initialized:
            return
        # No await between the check and the assignment, so this is race-free in asyncio.
        if self._init_lock is None:
            self._init_lock = asyncio.Lock()
        async with self._init_lock:
            # Another coroutine may have finished discovery while we waited.
            if self._initialized:
                return

            discovered: List[Dict] = []
            if self._model_endpoints:
                async with httpx.AsyncClient(timeout=self.timeout) as client:
                    for endpoint in self._model_endpoints:
                        discovered.extend(await self._discover_models_from_endpoint(client, endpoint))

            self.models.extend(discovered)
            self.tts_endpoints = self._tts_endpoints
            self._initialized = True

    async def _discover_models_from_endpoint(self, client: "httpx.AsyncClient", endpoint: str) -> List[Dict]:
        """Query ``{endpoint}/models`` and return the parsed model records.

        Never raises. Network errors, non-200 statuses and malformed bodies print a
        warning and return an empty list, so one bad endpoint does not block the others.
        """
        try:
            response = await client.get(self._join_url(endpoint, "models"), headers=self._headers())
            if response.status_code != 200:
                print(f"Warning: Failed to discover models from {endpoint}: HTTP {response.status_code}")
                return []
            return self._parse_models(response.json(), endpoint)
        except Exception as e:
            print(f"Warning: Failed to discover models from {endpoint}: {e}")
            return []

    async def list_models(self) -> List[Dict]:
        """Return a shallow copy of the discovered models (``id``, ``endpoint``, ``max_tokens``)."""
        await self._ensure_initialized()
        return self.models.copy()

    async def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 100,
        temperature: float = 0.7,
        **kwargs
    ) -> Dict:
        """Send a non-streaming chat completion and return the decoded JSON response.

        Args:
            messages: Chat messages in OpenAI ``{"role", "content"}`` form.
            model: Model id; None selects the first discovered model.
            max_tokens, temperature: Forwarded in the request body.
            **kwargs: Extra request-body fields. ``stream`` is always forced to False.

        Raises:
            ValueError: No models were discovered, or ``model`` is unknown.
            httpx.HTTPStatusError: The server returned an error status.
        """
        await self._ensure_initialized()
        model_info = self._get_model_info(model)
        payload = self._chat_payload(model_info["id"], messages, max_tokens, temperature, False, kwargs)

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                self._join_url(model_info["endpoint"], "chat/completions"),
                headers=self._headers(json_body=True),
                json=payload
            )
            response.raise_for_status()
            return response.json()

    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 500,
        temperature: float = 0.7,
        **kwargs
    ) -> AsyncIterator[Dict]:
        """Stream a chat completion and yield each decoded SSE ``data:`` chunk.

        Iteration ends at the ``[DONE]`` sentinel or when the server closes the stream.
        Payloads that are not valid JSON are skipped. ``stream`` is always forced to True.

        Raises:
            ValueError: No models were discovered, or ``model`` is unknown.
            httpx.HTTPStatusError: The server returned an error status.
        """
        await self._ensure_initialized()
        model_info = self._get_model_info(model)
        payload = self._chat_payload(model_info["id"], messages, max_tokens, temperature, True, kwargs)

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            async with client.stream(
                "POST",
                self._join_url(model_info["endpoint"], "chat/completions"),
                headers=self._headers(json_body=True),
                json=payload
            ) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    data = self._sse_data(line)
                    if data is None:
                        continue
                    if data.strip() == '[DONE]':
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        continue
                    # Yield outside the try so exceptions thrown into the generator are not swallowed.
                    yield chunk

    async def tts(
        self,
        text: str,
        voice: str = "alloy",
        model: str = "tts-1",
        response_format: str = "mp3"
    ) -> bytes:
        """Synthesize ``text`` using the first configured TTS endpoint and return the raw audio bytes.

        Raises:
            ValueError: No TTS endpoints are configured.
            httpx.HTTPStatusError: The server returned an error status.
        """
        await self._ensure_initialized()

        if not self.tts_endpoints:
            raise ValueError("No TTS endpoints available")

        endpoint = self.tts_endpoints[0]

        payload = {
            "model": model,
            "voice": voice,
            "input": text,
            "response_format": response_format
        }

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                self._join_url(endpoint, "audio/speech"),
                headers=self._headers(json_body=True),
                json=payload
            )
            response.raise_for_status()
            return response.content

    def _get_model_info(self, model: Optional[str] = None) -> Dict:
        """Return the discovered model record for ``model``, or the first one when None.

        Raises:
            ValueError: Nothing was discovered, or no record has the requested id.
        """
        if not self.models:
            raise ValueError("No models available. Check endpoint configuration.")

        if model is None:
            return self.models[0]

        for m in self.models:
            if m["id"] == model:
                return m

        raise ValueError(f"Model '{model}' not found in discovered models")


async def main():
    """Run the end-to-end demo: model discovery, plain chat, streaming chat, optional TTS.

    Endpoints come from the MODEL_ENDPOINT_<n> / TTS_ENDPOINT_<n> environment
    variables (via the client's defaults). When TTS endpoints exist, the synthesized
    audio is written to ``speech.mp3`` in the current directory.
    """
    print("=== uncloseai. Python Async Client (httpx) ===\n")

    api = uncloseai()

    available = await api.list_models()
    if not available:
        # Nothing to demo without at least one model; explain how to configure one.
        print("ERROR: No models discovered. Set environment variables:")
        print("  MODEL_ENDPOINT_1, MODEL_ENDPOINT_2, etc.")
        return

    print(f"Discovered {len(available)} model(s)")
    for entry in available:
        print(f"  - {entry['id']} (max_tokens: {entry['max_tokens']})")
    print()

    # --- Non-streaming chat: one request, whole completion returned at once.
    print("=== Non-Streaming Chat ===")
    qa_messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Explain quantum computing in one sentence."},
    ]
    reply = await api.chat(messages=qa_messages, max_tokens=100)
    print(f"Model: {reply['model']}")
    print(f"Response: {reply['choices'][0]['message']['content']}\n")

    # --- Streaming chat: print content deltas as they arrive.
    print("=== Streaming Chat ===")
    stream_model = None
    if len(available) > 1:
        # Use the second discovered model when there is one; None lets the
        # client fall back to the first model.
        stream_model = available[1]["id"]
    print(f"Model: {stream_model or available[0]['id']}")
    print("Response: ", end="", flush=True)

    code_messages = [
        {"role": "system", "content": "You are a coding assistant."},
        {"role": "user", "content": "Write a Python async function to fetch multiple URLs"},
    ]
    async for piece in api.chat_stream(messages=code_messages, model=stream_model, max_tokens=200):
        choices = piece.get("choices")
        if not choices:
            continue
        text = choices[0].get("delta", {}).get("content", "")
        if text:
            print(text, end="", flush=True)

    print("\n")

    # --- Text-to-speech: only when a TTS endpoint is configured.
    if api.tts_endpoints:
        print("=== TTS Speech Generation ===")
        speech = await api.tts(
            text="Hello from uncloseai. Python async client! This demonstrates text to speech with streaming support.",
            voice="alloy",
        )

        with open("speech.mp3", "wb") as out_file:
            out_file.write(speech)

        print(f"[OK] Speech file created: speech.mp3 ({len(speech)} bytes)\n")

    print("=== Examples Complete ===")


if __name__ == "__main__":
    asyncio.run(main())
