import java.io.*;
import java.net.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.util.*;
import java.util.function.Consumer;

/**
 * One model available on a chat endpoint.
 *
 * <p>Holds the model identifier, the base URL of the endpoint that serves it, and the
 * maximum token count recorded for it. Fields are intentionally left mutable and
 * package-visible.
 */
class ModelInfo {
    ModelInfo(String id, String endpoint, int maxTokens) {
        this.maxTokens = maxTokens;
        this.endpoint = endpoint;
        this.id = id;
    }

    /** Identifier used to select the model in requests. */
    String id;

    /** Base URL of the endpoint serving this model. */
    String endpoint;

    /** Maximum token count recorded for this model. */
    int maxTokens;
}

public class UncloseAI {
    /** Connect/read timeout for the {@code /models} discovery requests, in milliseconds. */
    private static final int DISCOVERY_TIMEOUT_MS = 10000;
    /** Token limit recorded for a model whose listing entry has no usable {@code max_model_len}. */
    private static final int DEFAULT_MAX_TOKENS = 8192;

    private final List<ModelInfo> models = new ArrayList<>();
    private final List<String> ttsEndpoints;
    private final String apiKey;
    private final int timeout;
    private final boolean debug;

    /** Creates a client configured from environment variables, with a 30000 ms timeout and no API key. */
    public UncloseAI() {
        this(null, null, null, 30000, false);
    }

    /**
     * Creates a client and immediately queries each chat endpoint's {@code /models} route.
     *
     * @param endpoints    chat endpoint base URLs, or {@code null} to read MODEL_ENDPOINT_1, MODEL_ENDPOINT_2, ...
     * @param ttsEndpoints TTS endpoint base URLs, or {@code null} to read TTS_ENDPOINT_1, TTS_ENDPOINT_2, ...
     * @param apiKey       token sent as {@code Authorization: Bearer <apiKey>}; {@code null} or empty sends none
     * @param timeout      connect and read timeout for chat and TTS requests, in milliseconds
     * @param debug        when true, prints discovery progress and errors to standard output
     */
    public UncloseAI(List<String> endpoints, List<String> ttsEndpoints, String apiKey, int timeout, boolean debug) {
        this.apiKey = apiKey;
        this.timeout = timeout;
        this.debug = debug;

        List<String> chatEndpoints = endpoints != null ? endpoints : discoverEndpointsFromEnv("MODEL_ENDPOINT");
        List<String> speechEndpoints = ttsEndpoints != null ? ttsEndpoints : discoverEndpointsFromEnv("TTS_ENDPOINT");

        if (debug) {
            System.out.println("[DEBUG] Initialized with " + chatEndpoints.size() + " endpoint(s)");
        }

        discoverModels(chatEndpoints);
        // Copy so later changes to the caller's list do not affect this client.
        this.ttsEndpoints = new ArrayList<>(speechEndpoints);
    }

    /** Returns a copy of the models discovered at construction time, in discovery order. */
    public List<ModelInfo> listModels() {
        return new ArrayList<>(models);
    }

    /**
     * Sends a non-streaming chat completion request.
     *
     * @param messages  messages with {@code "role"} and {@code "content"} entries, in conversation order
     * @param model     model id, or {@code null}/empty to use the first discovered model
     * @param maxTokens value sent as {@code max_tokens}
     * @return the first {@code "content"} string in the response, unescaped; if none is found, the raw
     *         response body (lines trimmed and concatenated)
     * @throws IOException if no models are available, the model is unknown, the request fails, or the
     *                     server answers with status 400 or above (the message includes the error body)
     */
    public String chat(List<Map<String, String>> messages, String model, int maxTokens) throws IOException {
        ModelInfo modelInfo = resolveModel(model);
        String jsonRequest = buildChatRequest(modelInfo.id, messages, maxTokens, false);
        String response = postJSON(joinUrl(modelInfo.endpoint, "/chat/completions"), jsonRequest);
        return extractContent(response);
    }

    /**
     * Sends a streaming chat completion request and passes each non-empty content chunk to
     * {@code callback} as it arrives.
     *
     * <p>Lines starting with {@code data:} are processed (with or without a following space); reading
     * stops at {@code [DONE]} or end of stream. Other lines are ignored.
     *
     * @param messages  messages with {@code "role"} and {@code "content"} entries
     * @param model     model id, or {@code null}/empty to use the first discovered model
     * @param maxTokens value sent as {@code max_tokens}
     * @param callback  receives each decoded content chunk; must not be {@code null}
     * @throws IOException on the same conditions as {@link #chat}
     */
    public void chatStream(List<Map<String, String>> messages, String model, int maxTokens, Consumer<String> callback) throws IOException {
        ModelInfo modelInfo = resolveModel(model);
        String jsonRequest = buildChatRequest(modelInfo.id, messages, maxTokens, true);
        // Fail before sending rather than after the first chunk arrives.
        Objects.requireNonNull(callback, "callback");

        HttpURLConnection conn = sendJSON(joinUrl(modelInfo.endpoint, "/chat/completions"), jsonRequest);
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(openResponse(conn), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.startsWith("data:")) {
                    continue;
                }
                String data = line.substring(5).trim();
                if ("[DONE]".equals(data)) {
                    break;
                }
                String content = extractStreamContent(data);
                if (content != null && !content.isEmpty()) {
                    callback.accept(content);
                }
            }
        }
    }

    /**
     * Requests speech audio from the first configured TTS endpoint.
     *
     * @param text  text to speak; must not be {@code null}
     * @param voice value sent as {@code "voice"}
     * @param model value sent as {@code "model"}
     * @return the raw response body bytes
     * @throws IOException if no TTS endpoint is configured, the request fails, or the server answers
     *                     with status 400 or above
     */
    public byte[] tts(String text, String voice, String model) throws IOException {
        if (ttsEndpoints.isEmpty()) {
            throw new IOException("No TTS endpoints available");
        }
        Objects.requireNonNull(text, "text");

        String jsonRequest = "{\"model\":" + jsonString(model)
            + ",\"voice\":" + jsonString(voice)
            + ",\"input\":" + jsonString(text) + "}";

        return postJSONBinary(joinUrl(ttsEndpoints.get(0), "/audio/speech"), jsonRequest);
    }

    /**
     * Reads PREFIX_1, PREFIX_2, ... from the environment until the first missing or blank variable.
     * At most 9999 variables are read.
     */
    private static List<String> discoverEndpointsFromEnv(String prefix) {
        List<String> endpoints = new ArrayList<>();
        for (int i = 1; i < 10000; i++) {
            String endpoint = System.getenv(prefix + "_" + i);
            if (endpoint == null || endpoint.trim().isEmpty()) {
                break;
            }
            endpoints.add(endpoint.trim());
        }
        return endpoints;
    }

    /** Queries {@code /models} on every endpoint and records the models found; failures are skipped. */
    private void discoverModels(List<String> endpoints) {
        for (String endpoint : endpoints) {
            if (debug) {
                System.out.println("[DEBUG] Discovering from: " + endpoint);
            }

            try {
                String response = getJSON(joinUrl(endpoint, "/models"));
                parseModels(response, endpoint);
            } catch (Exception e) {
                // Best effort: one unreachable or malformed endpoint must not block the others.
                if (debug) {
                    System.out.println("[DEBUG] Error: " + e);
                }
            }
        }
    }

    /**
     * Adds every {@code "id"} found after the {@code "data": [} array start to {@link #models}.
     * Ids starting with {@code "modelperm-"} are skipped (presumably nested permission records).
     */
    private void parseModels(String jsonResponse, String endpoint) {
        int pos = indexAfterKey(jsonResponse, "data", '[', 0);
        if (pos == -1) {
            return;
        }

        while (true) {
            int idStart = indexAfterKey(jsonResponse, "id", '"', pos);
            if (idStart == -1) {
                break;
            }
            int idEnd = findStringEnd(jsonResponse, idStart);
            if (idEnd == -1) {
                break;
            }
            String modelId = unescapeJson(jsonResponse.substring(idStart, idEnd));
            pos = idEnd + 1;

            if (modelId.startsWith("modelperm-")) {
                continue;
            }

            int maxTokens = parseMaxModelLen(jsonResponse, pos);
            models.add(new ModelInfo(modelId, endpoint, maxTokens));
            if (debug) {
                System.out.println("[DEBUG] Discovered: " + modelId);
            }
        }
    }

    /**
     * Returns the integer {@code "max_model_len"} that appears after {@code from} and before the next
     * {@code '}'}, or {@link #DEFAULT_MAX_TOKENS} when it is absent, non-numeric, or out of int range.
     */
    private static int parseMaxModelLen(String json, int from) {
        String key = "\"max_model_len\"";
        int objectEnd = json.indexOf('}', from);
        int keyIndex = json.indexOf(key, from);
        if (keyIndex == -1 || objectEnd == -1 || keyIndex > objectEnd) {
            return DEFAULT_MAX_TOKENS;
        }
        int p = skipWhitespace(json, keyIndex + key.length());
        if (p >= json.length() || json.charAt(p) != ':') {
            return DEFAULT_MAX_TOKENS;
        }
        p = skipWhitespace(json, p + 1);
        int digitsStart = p;
        while (p < json.length() && json.charAt(p) >= '0' && json.charAt(p) <= '9') {
            p++;
        }
        if (p == digitsStart) {
            return DEFAULT_MAX_TOKENS;
        }
        try {
            return Integer.parseInt(json.substring(digitsStart, p));
        } catch (NumberFormatException tooLarge) {
            return DEFAULT_MAX_TOKENS;
        }
    }

    /** Returns the named model, or the first model when {@code model} is null/empty. */
    private ModelInfo resolveModel(String model) throws IOException {
        if (models.isEmpty()) {
            throw new IOException("No models available");
        }
        if (model == null || model.isEmpty()) {
            return models.get(0);
        }
        for (ModelInfo m : models) {
            if (m.id.equals(model)) {
                return m;
            }
        }
        throw new IOException("Model '" + model + "' not found");
    }

    /** Builds the chat completion request body; every string value is JSON-escaped. */
    private String buildChatRequest(String modelId, List<Map<String, String>> messages, int maxTokens, boolean stream) {
        StringBuilder sb = new StringBuilder();
        sb.append("{\"model\":").append(jsonString(modelId));
        sb.append(",\"messages\":[");
        for (int i = 0; i < messages.size(); i++) {
            if (i > 0) {
                sb.append(',');
            }
            Map<String, String> msg = messages.get(i);
            sb.append("{\"role\":").append(jsonString(msg.get("role")))
              .append(",\"content\":").append(jsonString(msg.get("content")))
              .append('}');
        }
        sb.append("],\"max_tokens\":").append(maxTokens);
        if (stream) {
            sb.append(",\"stream\":true");
        }
        sb.append('}');
        return sb.toString();
    }

    /** Performs a GET with the fixed discovery timeout and returns the body as trimmed, joined lines. */
    private String getJSON(String urlString) throws IOException {
        return readText(openConnection(urlString, "GET", DISCOVERY_TIMEOUT_MS));
    }

    /** POSTs {@code jsonRequest} and returns the body as trimmed, joined lines. */
    private String postJSON(String urlString, String jsonRequest) throws IOException {
        return readText(sendJSON(urlString, jsonRequest));
    }

    /** POSTs {@code jsonRequest} and returns the raw response bytes. */
    private byte[] postJSONBinary(String urlString, String jsonRequest) throws IOException {
        HttpURLConnection conn = sendJSON(urlString, jsonRequest);
        try (InputStream in = openResponse(conn)) {
            return readFully(in);
        }
    }

    /** Opens a connection with the given method and timeout, attaching the bearer token if configured. */
    private HttpURLConnection openConnection(String urlString, String method, int timeoutMs) throws IOException {
        HttpURLConnection conn = (HttpURLConnection) new URL(urlString).openConnection();
        conn.setRequestMethod(method);
        conn.setConnectTimeout(timeoutMs);
        conn.setReadTimeout(timeoutMs);
        if (apiKey != null && !apiKey.isEmpty()) {
            conn.setRequestProperty("Authorization", "Bearer " + apiKey);
        }
        return conn;
    }

    /** Opens a POST connection using {@link #timeout} and writes {@code jsonRequest} as the UTF-8 body. */
    private HttpURLConnection sendJSON(String urlString, String jsonRequest) throws IOException {
        HttpURLConnection conn = openConnection(urlString, "POST", timeout);
        conn.setRequestProperty("Content-Type", "application/json");
        conn.setDoOutput(true);
        try (OutputStream os = conn.getOutputStream()) {
            os.write(jsonRequest.getBytes(StandardCharsets.UTF_8));
        }
        return conn;
    }

    /**
     * Returns the response body stream, or throws an IOException carrying the status code and the
     * server's error body when the status is 400 or above.
     */
    private static InputStream openResponse(HttpURLConnection conn) throws IOException {
        int status = conn.getResponseCode();
        if (status < 400) {
            return conn.getInputStream();
        }
        String detail = "";
        InputStream errorStream = conn.getErrorStream();
        if (errorStream != null) {
            try (InputStream in = errorStream) {
                detail = new String(readFully(in), StandardCharsets.UTF_8).trim();
            }
        }
        throw new IOException("HTTP " + status + " from " + conn.getURL()
            + (detail.isEmpty() ? "" : ": " + detail));
    }

    /** Reads the response as UTF-8 lines, trimming each and concatenating them without separators. */
    private static String readText(HttpURLConnection conn) throws IOException {
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(openResponse(conn), StandardCharsets.UTF_8))) {
            StringBuilder response = new StringBuilder();
            String line;
            while ((line = br.readLine()) != null) {
                // Safe for JSON: string literals cannot contain raw line breaks, so trimming only drops layout.
                response.append(line.trim());
            }
            return response.toString();
        }
    }

    /** Reads {@code in} to end of stream. */
    private static byte[] readFully(InputStream in) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        byte[] chunk = new byte[8192];
        int n;
        while ((n = in.read(chunk)) != -1) {
            buffer.write(chunk, 0, n);
        }
        return buffer.toByteArray();
    }

    /** Appends {@code path} to {@code base}, dropping trailing slashes from {@code base} to avoid "//". */
    private static String joinUrl(String base, String path) {
        String trimmed = base;
        while (trimmed.endsWith("/")) {
            trimmed = trimmed.substring(0, trimmed.length() - 1);
        }
        return trimmed + path;
    }

    /** Returns the first {@code "content"} string value, unescaped; the whole input if none exists. */
    private static String extractContent(String jsonResponse) {
        int start = indexAfterKey(jsonResponse, "content", '"', 0);
        if (start == -1) {
            return jsonResponse;
        }
        int end = findStringEnd(jsonResponse, start);
        if (end == -1) {
            return jsonResponse.substring(start);
        }
        return unescapeJson(jsonResponse.substring(start, end));
    }

    /** Returns the first {@code "content"} string value in a stream chunk, unescaped, or null if absent. */
    private static String extractStreamContent(String jsonChunk) {
        int start = indexAfterKey(jsonChunk, "content", '"', 0);
        if (start == -1) {
            return null;
        }
        int end = findStringEnd(jsonChunk, start);
        if (end == -1) {
            return null;
        }
        return unescapeJson(jsonChunk.substring(start, end));
    }

    /**
     * Finds {@code "key"} at or after {@code from} followed by optional whitespace, ':', optional
     * whitespace and {@code opener}. Returns the index just past {@code opener}, or -1.
     */
    private static int indexAfterKey(String json, String key, char opener, int from) {
        String quotedKey = "\"" + key + "\"";
        for (int i = json.indexOf(quotedKey, from); i != -1; i = json.indexOf(quotedKey, i + 1)) {
            int p = skipWhitespace(json, i + quotedKey.length());
            if (p < json.length() && json.charAt(p) == ':') {
                p = skipWhitespace(json, p + 1);
                if (p < json.length() && json.charAt(p) == opener) {
                    return p + 1;
                }
            }
        }
        return -1;
    }

    /** Returns the first index at or after {@code from} that is not whitespace. */
    private static int skipWhitespace(String s, int from) {
        int i = from;
        while (i < s.length() && Character.isWhitespace(s.charAt(i))) {
            i++;
        }
        return i;
    }

    /** Returns the index of the unescaped closing quote of a string starting at {@code start}, or -1. */
    private static int findStringEnd(String json, int start) {
        for (int i = start; i < json.length(); i++) {
            char c = json.charAt(i);
            if (c == '\\') {
                i++; // skip the escaped character, which may itself be a quote or backslash
            } else if (c == '"') {
                return i;
            }
        }
        return -1;
    }

    /** Decodes JSON string escapes; malformed {@code \\u} escapes and a trailing backslash are kept verbatim. */
    private static String unescapeJson(String raw) {
        StringBuilder out = new StringBuilder(raw.length());
        for (int i = 0; i < raw.length(); i++) {
            char c = raw.charAt(i);
            if (c != '\\' || i + 1 == raw.length()) {
                out.append(c);
                continue;
            }
            char escaped = raw.charAt(++i);
            switch (escaped) {
                case 'n': out.append('\n'); break;
                case 'r': out.append('\r'); break;
                case 't': out.append('\t'); break;
                case 'b': out.append('\b'); break;
                case 'f': out.append('\f'); break;
                case 'u': {
                    int code = parseHex4(raw, i + 1);
                    if (code >= 0) {
                        // Surrogate pairs arrive as two escapes and recombine naturally in UTF-16.
                        out.append((char) code);
                        i += 4;
                    } else {
                        out.append("\\u");
                    }
                    break;
                }
                default:
                    // Covers \" \\ \/ and any non-standard escape: keep the escaped character itself.
                    out.append(escaped);
            }
        }
        return out.toString();
    }

    /** Parses four hex digits starting at {@code from}; returns -1 if fewer remain or any is not hex. */
    private static int parseHex4(String s, int from) {
        if (from + 4 > s.length()) {
            return -1;
        }
        int value = 0;
        for (int k = from; k < from + 4; k++) {
            int digit = Character.digit(s.charAt(k), 16);
            if (digit < 0) {
                return -1;
            }
            value = value * 16 + digit;
        }
        return value;
    }

    /** Encodes {@code value} as a quoted JSON string literal (RFC 8259 escaping), or {@code null}. */
    private static String jsonString(String value) {
        if (value == null) {
            return "null";
        }
        StringBuilder sb = new StringBuilder(value.length() + 2).append('"');
        for (int i = 0; i < value.length(); i++) {
            char c = value.charAt(i);
            switch (c) {
                case '"': sb.append("\\\""); break;
                case '\\': sb.append("\\\\"); break;
                case '\n': sb.append("\\n"); break;
                case '\r': sb.append("\\r"); break;
                case '\t': sb.append("\\t"); break;
                case '\b': sb.append("\\b"); break;
                case '\f': sb.append("\\f"); break;
                default:
                    if (c < 0x20) {
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
            }
        }
        return sb.append('"').toString();
    }

    /** Builds a mutable chat message map with the given role and content. */
    private static Map<String, String> message(String role, String content) {
        Map<String, String> msg = new HashMap<>();
        msg.put("role", role);
        msg.put("content", content);
        return msg;
    }

    // Demo when run as application
    public static void main(String[] args) {
        System.out.println("=== UncloseAI Java Client (with Streaming) ===\n");

        UncloseAI client = new UncloseAI(null, null, null, 30000, true);

        if (client.models.isEmpty()) {
            System.out.println("ERROR: No models discovered. Set environment variables:");
            System.out.println("  MODEL_ENDPOINT_1, MODEL_ENDPOINT_2, etc.");
            System.exit(1);
        }

        System.out.println("\nDiscovered " + client.models.size() + " model(s):");
        for (ModelInfo m : client.models) {
            System.out.println("  - " + m.id + " (max_tokens: " + m.maxTokens + ")");
        }
        System.out.println();

        // Non-streaming chat
        System.out.println("=== Non-Streaming Chat ===");
        try {
            List<Map<String, String>> messages = Arrays.asList(
                message("system", "You are a helpful AI assistant."),
                message("user", "Explain quantum computing in one sentence."));
            String response = client.chat(messages, null, 100);
            System.out.println("Response: " + response + "\n");
        } catch (IOException e) {
            System.out.println("Error: " + e.getMessage() + "\n");
        }

        // Streaming chat: uses the second discovered model when there is one (null selects the first).
        System.out.println("=== Streaming Chat ===");
        String modelId = client.models.size() > 1 ? client.models.get(1).id : null;
        System.out.println("Model: " + (modelId != null ? modelId : client.models.get(0).id));
        System.out.print("Response: ");
        try {
            List<Map<String, String>> messages = Arrays.asList(
                message("system", "You are a coding assistant."),
                message("user", "Write a Java function to check if a number is prime"));
            client.chatStream(messages, modelId, 200, chunk -> {
                System.out.print(chunk);
                System.out.flush(); // show each chunk as soon as it arrives
            });
            System.out.println("\n");
        } catch (IOException e) {
            System.out.println("\nError: " + e.getMessage() + "\n");
        }

        // TTS, only when a TTS endpoint is configured
        if (!client.ttsEndpoints.isEmpty()) {
            System.out.println("=== TTS Speech Generation ===");
            try {
                byte[] audio = client.tts("Hello from UncloseAI Java client! This demonstrates streaming support.", "alloy", "tts-1");
                Files.write(Paths.get("speech.mp3"), audio);
                System.out.println("[OK] Speech file created: speech.mp3 (" + audio.length + " bytes)\n");
            } catch (IOException e) {
                System.out.println("[ERROR] TTS Error: " + e.getMessage() + "\n");
            }
        }

        System.out.println("=== Examples Complete ===");
    }
}
