from typing import Tuple, List, Any, Callable
from .constants import GRIP_OSC_LISTEN_PORTS, GRIP_OSC_RESPONSE_PORTS
from ..pythonosc.osc_message import OscMessage, ParseError
from ..pythonosc.osc_bundle import OscBundle
from ..pythonosc.osc_message_builder import OscMessageBuilder, BuildError

import re
import errno
import socket
import logging
import traceback

class OSCServer:
    """
    Non-blocking UDP OSC server that listens on several ports at once.

    Incoming datagrams are drained by calling process() periodically; each
    message is dispatched to the handler registered for its OSC address, and
    any tuple returned by the handler is sent back to every response port.
    """

    def __init__(self,
                 listen_ports: List[int] = None,
                 response_ports: List[int] = None,
                 bind_host: str = '127.0.0.1'):
        """
        OSC server with multi-port support for dual-daemon architecture.

        Listens on multiple ports simultaneously (one per daemon mode) and broadcasts
        responses to all response ports. This allows both production and development
        Grip daemons to communicate with Ableton at the same time.

        Architecture:
            - Creates one UDP socket per listen port
            - Polls all sockets in process() to receive from any daemon
            - Broadcasts responses to ALL response ports so both daemons stay in sync

        Args:
            listen_ports: List of ports to listen on. Defaults to GRIP_OSC_LISTEN_PORTS
                          (47140 for production, 47240 for development).
            response_ports: List of ports to send responses to. Defaults to GRIP_OSC_RESPONSE_PORTS
                            (47141 for production, 47241 for development).
            bind_host: Host address to bind to. Defaults to '127.0.0.1' (localhost only).
                       The same host is used as the destination for broadcast responses.

        Raises:
            OSError: If a socket cannot be created or bound. Any sockets opened
                     before the failure are closed before the error propagates.
        """
        self._listen_ports = listen_ports or GRIP_OSC_LISTEN_PORTS
        self._response_ports = response_ports or GRIP_OSC_RESPONSE_PORTS
        self._bind_host = bind_host
        self._callbacks = {}
        self.logger = logging.getLogger("grip")

        # Create a socket for each listen port
        self._sockets = []
        try:
            for port in self._listen_ports:
                sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                # Track the socket before configuring it so that the cleanup
                # below also covers a socket whose bind() fails.
                self._sockets.append(sock)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.setblocking(False)
                sock.bind((bind_host, port))
                self.logger.info("OSC listening on %s:%d", bind_host, port)
        except Exception:
            # Don't leak the sockets that were already opened when a later
            # port cannot be bound (e.g. it is already in use).
            for opened in self._sockets:
                opened.close()
            self._sockets = []
            raise

        self.logger.info("OSC responses will broadcast to ports: %s", self._response_ports)

    def add_handler(self, address: str, handler: Callable) -> None:
        """
        Add an OSC handler, replacing any handler already registered for the address.

        Args:
            address: The OSC address string
            handler: A handler function, with signature:
                     params: Tuple[Any, ...]
                     It may return None (no response) or a tuple of response params.
        """
        self._callbacks[address] = handler

    def clear_handlers(self) -> None:
        """
        Remove all existing OSC handlers.
        """
        self._callbacks = {}

    def send(self,
             address: str,
             params: Tuple = (),
             remote_addr: Tuple[str, int] = None) -> None:
        """
        Send an OSC message, broadcasting to all response ports by default.

        Messages are always sent from the first listen socket. A BuildError while
        encoding the message is logged and the message is dropped; socket errors
        raised by sendto() propagate to the caller.

        Args:
            address: The OSC address (e.g. /frequency)
            params: A tuple of zero or more OSC params
            remote_addr: Specific remote address to send to, as a 2-tuple (hostname, port).
                         If None, broadcasts to all configured response ports.
        """
        msg_builder = OscMessageBuilder(address)
        for param in params:
            msg_builder.add_arg(param)

        try:
            msg = msg_builder.build()
        except BuildError:
            self.logger.error("Grip: OSC build error: %s", traceback.format_exc())
            return

        out_sock = self._sockets[0]
        if remote_addr is not None:
            # Send to specific address
            out_sock.sendto(msg.dgram, remote_addr)
        else:
            # Broadcast to all response ports
            for port in self._response_ports:
                out_sock.sendto(msg.dgram, (self._bind_host, port))

    def process_message(self, message, remote_addr):
        """
        Process an incoming OSC message and broadcast response to all daemons.

        Responses are broadcast to all configured response ports so both production
        and development daemons receive the data.

        An address that is registered verbatim is dispatched to its handler.
        Otherwise, an address containing "*" is treated as a wildcard pattern and
        every registered handler whose address matches is invoked. Any other
        address is logged as unknown.

        Args:
            message: Parsed OSC message exposing .address and .params.
            remote_addr: Sender address; accepted for interface symmetry but unused,
                         because responses always go to the response ports.
        """
        if message.address in self._callbacks:
            callback = self._callbacks[message.address]
            rv = callback(message.params)

            if rv is not None:
                assert isinstance(rv, tuple)
                # Broadcast response to all response ports
                self.send(address=message.address, params=rv)
        elif "*" in message.address:
            # Only "*" is a wildcard (one or more non-slash characters). Every
            # other character is escaped so that it matches literally and a
            # malformed address coming off the wire cannot raise re.error.
            pattern = re.compile(
                "[^/]+".join(re.escape(part) for part in message.address.split("*"))
            )
            # Iterate over a snapshot so a handler may add or clear handlers
            # without invalidating the iteration.
            for callback_address, callback in list(self._callbacks.items()):
                # NOTE: match() anchors only at the start, so the pattern also
                # matches handler addresses that merely begin with it. This
                # long-standing behaviour is kept for backward compatibility.
                if not pattern.match(callback_address):
                    continue
                try:
                    rv = callback(message.params)
                except ValueError:
                    # Don't throw errors for queries that require more arguments
                    # (e.g. /live/track/get/send with no args)
                    continue
                except AttributeError:
                    # Don't throw errors when trying to create listeners for properties
                    # that can't be listened for (e.g. can_be_armed, is_foldable)
                    continue
                if rv is not None:
                    assert isinstance(rv, tuple)
                    # Broadcast response to all response ports
                    self.send(address=callback_address, params=rv)
        else:
            self.logger.error("Grip: Unknown OSC address: %s", message.address)

    def process_bundle(self, bundle, remote_addr):
        """
        Dispatch every element of an OSC bundle, recursing into nested bundles.

        Args:
            bundle: Iterable of bundle elements, each exposing .dgram.
            remote_addr: Sender address, passed through to process_message().
        """
        for element in bundle:
            if OscBundle.dgram_is_bundle(element.dgram):
                self.process_bundle(element, remote_addr)
            else:
                self.process_message(element, remote_addr)

    def parse_bundle(self, data, remote_addr):
        """
        Decode a raw datagram as either an OSC bundle or a single OSC message
        and process it. Parse errors are logged and the datagram is dropped.

        Args:
            data: Raw datagram bytes.
            remote_addr: Sender address, passed through to the processing methods.
        """
        if OscBundle.dgram_is_bundle(data):
            try:
                bundle = OscBundle(data)
                self.process_bundle(bundle, remote_addr)
            except ParseError:
                self.logger.error("Grip: Error parsing OSC bundle: %s", traceback.format_exc())
        else:
            try:
                message = OscMessage(data)
                self.process_message(message, remote_addr)
            except ParseError:
                self.logger.error("Grip: Error parsing OSC message: %s", traceback.format_exc())

    def process(self) -> None:
        """
        Poll all listen sockets and process any queued OSC messages.

        Iterates through all sockets (one per listen port) and drains any pending
        messages. This allows both production and development daemons to send
        commands to Ableton simultaneously.
        """
        for sock in self._sockets:
            self._process_socket(sock)

    def _process_socket(self, sock: socket.socket) -> None:
        """
        Process all pending messages on a single socket.

        Receive errors end the drain for this socket. An exception raised while
        handling one datagram is logged and the drain continues with the next
        datagram, so a single faulty message or handler cannot hold up the
        messages queued behind it.

        Args:
            sock: The UDP socket to drain messages from.
        """
        while True:
            # Keep the socket-error handling scoped to the receive call only:
            # an EAGAIN raised while *sending* a response must not be mistaken
            # for "no more data to read".
            try:
                data, remote_addr = sock.recvfrom(65536)
            except socket.error as e:
                if e.errno == errno.ECONNRESET:
                    # Benign error that occurs on startup on Windows
                    self.logger.warning("Grip: Non-fatal socket error: %s", traceback.format_exc())
                elif e.errno == errno.EAGAIN or e.errno == errno.EWOULDBLOCK:
                    # No data available on non-blocking socket (expected)
                    pass
                else:
                    self.logger.error("Grip: Socket error: %s", traceback.format_exc())
                return

            try:
                self.parse_bundle(data, remote_addr)
            except Exception as e:
                # Top-level boundary: log and move on to the next datagram.
                self.logger.error("Grip: Error handling OSC message: %s", e)
                self.logger.warning("Grip: %s", traceback.format_exc())

    def shutdown(self) -> None:
        """
        Shutdown all server network sockets.
        """
        for sock in self._sockets:
            sock.close()
