from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional

from django.db import connections


# Use the 'uksi' database connection for all queries in this module
def _get_connection():
    """Return the Django database connection registered under the ``uksi`` alias.

    The query helpers in this module obtain their connection from here.
    """
    alias = "uksi"
    return connections[alias]


def _dictfetchall(cursor) -> List[Dict[str, Any]]:
    """Return every remaining row of ``cursor`` as a dict keyed by column name.

    Column names are taken from the first element of each entry in
    ``cursor.description`` (DB-API 2.0). Returns an empty list when the
    result set has no remaining rows.
    """
    names = tuple(entry[0] for entry in cursor.description)
    return [dict(zip(names, values)) for values in cursor.fetchall()]


def _dictfetchone(cursor) -> Optional[Dict[str, Any]]:
    """Return the next row of ``cursor`` as a column-name -> value dict.

    Returns None once the cursor has no more rows.
    """
    values = cursor.fetchone()
    if values is not None:
        names = [entry[0] for entry in cursor.description]
        return dict(zip(names, values))
    return None


def get_taxon_by_id(taxon_id: str) -> Optional[Dict[str, Any]]:
    """Fetch the ``taxa`` row (all columns) whose ``taxon_id`` equals ``taxon_id``.

    Returns None when no row matches. If several rows matched, only the first
    one fetched would be returned (presumably ``taxon_id`` is unique - verify
    against the schema).
    """
    sql = "SELECT * FROM taxa WHERE taxon_id = %s"
    connection = _get_connection()
    with connection.cursor() as cur:
        cur.execute(sql, (taxon_id,))
        row = _dictfetchone(cur)
    return row


def get_children(
    parent_taxon_id: Optional[str],
    *,
    order_by: str = "rank",
    limit: int = 500,
) -> List[Dict[str, Any]]:
    """
    Get accepted child taxa for a given parent.

    Only rows with ``taxonomic_status = 'accepted'`` are returned.

    Args:
        parent_taxon_id: ID of parent taxon, or None for root taxa (rows whose
            ``parent_taxon_id`` IS NULL, e.g. kingdoms).
        order_by: Ordering mode - "rank" (default) or "alpha". Any other value
            falls back to "rank".
        limit: Maximum number of rows to return (default 500).

    Returns:
        One dict per child taxon with keys taxon_id, scientific_name,
        authorship, taxon_rank, taxonomic_status, rank_order, children_count
        and descendant_count. Empty list when nothing matches.
    """
    # ORDER BY cannot be bound as a parameter; both variants are fixed
    # literals chosen here (never caller text), so interpolating them is safe.
    if order_by == "alpha":
        order_clause = "ORDER BY scientific_name COLLATE NOCASE"
    else:  # default to "rank"
        order_clause = "ORDER BY rank_order, scientific_name COLLATE NOCASE"

    # "= NULL" never matches in SQL, so root taxa need an explicit IS NULL
    # test instead of binding None as a parameter.
    if parent_taxon_id is None:
        parent_filter = "parent_taxon_id IS NULL"
        params: Iterable[Any] = (limit,)
    else:
        parent_filter = "parent_taxon_id = %s"
        params = (parent_taxon_id, limit)

    query = f"""
        SELECT
            taxon_id,
            scientific_name,
            authorship,
            taxon_rank,
            taxonomic_status,
            rank_order,
            children_count,
            descendant_count
        FROM taxa
        WHERE {parent_filter} AND taxonomic_status = 'accepted'
        {order_clause}
        LIMIT %s
    """

    with _get_connection().cursor() as cursor:
        cursor.execute(query, params)
        return _dictfetchall(cursor)


def get_lineage_up(taxon_id: str, *, max_depth: int = 100) -> List[Dict[str, Any]]:
    """
    Return the ancestor chain of a taxon, ordered from the top down.

    A recursive CTE starts at the ``taxa`` row for ``taxon_id`` (depth 0) and
    repeatedly joins to the row whose ``taxon_id`` equals the current row's
    ``parent_taxon_id``, incrementing ``depth`` at each step.

    Args:
        taxon_id: ID of the taxon whose lineage is wanted.
        max_depth: Maximum number of parent steps to follow. This bounds the
            recursion so that a cycle in the parent links (e.g. a taxon that
            lists itself as its own parent) cannot make the query run forever.
            If the cap is reached, the result is truncated at the top.

    Returns:
        Dicts with keys taxon_id, parent_taxon_id, scientific_name,
        taxon_rank, taxonomic_status, rank_order and depth, ordered by depth
        descending, so the starting taxon (depth 0) is last. Empty list when
        ``taxon_id`` is not present in ``taxa``.
    """
    query = """
        WITH RECURSIVE lineage AS (
            SELECT
                t.taxon_id,
                t.parent_taxon_id,
                t.scientific_name,
                t.taxon_rank,
                t.taxonomic_status,
                t.rank_order,
                0 AS depth
            FROM taxa t
            WHERE t.taxon_id = %s
            UNION ALL
            SELECT
                parent.taxon_id,
                parent.parent_taxon_id,
                parent.scientific_name,
                parent.taxon_rank,
                parent.taxonomic_status,
                parent.rank_order,
                lineage.depth + 1
            FROM taxa parent
            JOIN lineage ON parent.taxon_id = lineage.parent_taxon_id
            WHERE lineage.depth < %s
        )
        SELECT *
        FROM lineage
        ORDER BY depth DESC, rank_order
    """
    # The second parameter caps the recursive step (see max_depth above).
    with _get_connection().cursor() as cursor:
        cursor.execute(query, (taxon_id, max_depth))
        return _dictfetchall(cursor)


def get_synonyms_for(accepted_taxon_id: str) -> List[Dict[str, Any]]:
    """
    Return the ``synonyms`` rows attached to an accepted taxon.

    Rows are ordered by ``rank_order`` and then case-insensitively by
    ``scientific_name``. Returns an empty list when there are none.
    """
    sql = """
        SELECT synonym_id, scientific_name, authorship, taxon_rank, nomenclatural_status, rank_order
        FROM synonyms
        WHERE accepted_taxon_id = %s
        ORDER BY rank_order, scientific_name COLLATE NOCASE
    """
    params = (accepted_taxon_id,)
    connection = _get_connection()
    with connection.cursor() as cur:
        cur.execute(sql, params)
        synonyms = _dictfetchall(cur)
    return synonyms


def get_vernaculars_for(taxon_id: str) -> List[Dict[str, Any]]:
    """
    Return the ``vernaculars`` rows (common names) recorded for a taxon.

    Rows are sorted by ``preferred`` descending, then case-insensitively by
    ``vernacular_name``. Returns an empty list when there are none.
    """
    sql = """
        SELECT vernacular_name, language, locality, country, preferred, source
        FROM vernaculars
        WHERE taxon_id = %s
        ORDER BY preferred DESC, vernacular_name COLLATE NOCASE
    """
    params = (taxon_id,)
    connection = _get_connection()
    with connection.cursor() as cur:
        cur.execute(sql, params)
        names = _dictfetchall(cur)
    return names


def _like_contains_pattern(text: str) -> str:
    """
    Build a LIKE pattern that matches ``text`` as a literal substring.

    Backslash, ``%`` and ``_`` are prefixed with a backslash so that, under
    the backslash ESCAPE clause used by the search queries below, user input
    is matched literally instead of acting as wildcards. The backslash is
    escaped first so the escapes added afterwards are not doubled. An empty
    ``text`` yields ``"%%"``, which matches every non-NULL value.
    """
    escaped = text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    return f"%{escaped}%"


def search_taxa(query_text: str, *, limit: int = 50, offset: int = 0, order_by: str = "rank") -> List[Dict[str, Any]]:
    """
    Search taxa by scientific name.

    ``query_text`` is matched literally as a substring of ``scientific_name``
    (``%``, ``_`` and backslash in it are not wildcards). No filter on
    ``taxonomic_status`` is applied.

    Args:
        query_text: Search term
        limit: Maximum results
        offset: Results offset for pagination
        order_by: Ordering mode - "rank" (default) or "alpha"; any other
            value falls back to "rank".
    """
    pattern = _like_contains_pattern(query_text)

    # ORDER BY cannot be bound as a parameter; both variants are fixed
    # literals chosen here, so interpolating them is safe.
    if order_by == "alpha":
        order_clause = "ORDER BY scientific_name COLLATE NOCASE"
    else:  # default to "rank"
        order_clause = "ORDER BY rank_order, scientific_name COLLATE NOCASE"

    query = f"""
        SELECT taxon_id, scientific_name, authorship, taxon_rank, taxonomic_status, rank_order
        FROM taxa
        WHERE scientific_name LIKE %s ESCAPE '\\'
        {order_clause}
        LIMIT %s OFFSET %s
    """
    with _get_connection().cursor() as cursor:
        cursor.execute(query, (pattern, limit, offset))
        return _dictfetchall(cursor)


def search_by_vernacular(query_text: str, *, limit: int = 50, offset: int = 0) -> List[Dict[str, Any]]:
    """
    Search vernacular (common) names and join each hit to its taxon.

    ``query_text`` is matched literally as a substring of ``vernacular_name``.
    Results are ordered by ``preferred`` descending, then ``rank_order``, then
    ``vernacular_name`` case-insensitively.

    Args:
        query_text: Search term
        limit: Maximum results
        offset: Results offset for pagination
    """
    pattern = _like_contains_pattern(query_text)
    query = """
        SELECT
            v.taxon_id,
            v.vernacular_name,
            v.language,
            t.scientific_name,
            t.taxon_rank,
            t.taxonomic_status,
            t.rank_order
        FROM vernaculars v
        JOIN taxa t ON t.taxon_id = v.taxon_id
        WHERE v.vernacular_name LIKE %s ESCAPE '\\'
        ORDER BY v.preferred DESC, t.rank_order, v.vernacular_name COLLATE NOCASE
        LIMIT %s OFFSET %s
    """
    with _get_connection().cursor() as cursor:
        cursor.execute(query, (pattern, limit, offset))
        return _dictfetchall(cursor)


def fuzzy_candidates(query_text: str, *, limit: int = 20) -> List[Dict[str, Any]]:
    """
    Return taxa whose ``scientific_name`` or ``taxon_concept_id`` contains
    ``query_text`` as a literal substring.

    Results are ordered by ``rank_order`` then ``scientific_name``
    case-insensitively, capped at ``limit`` rows.
    """
    pattern = _like_contains_pattern(query_text)
    query = """
        SELECT taxon_id, scientific_name, authorship, taxon_rank, rank_order
        FROM taxa
        WHERE scientific_name LIKE %s ESCAPE '\\'
            OR taxon_concept_id LIKE %s ESCAPE '\\'
        ORDER BY rank_order, scientific_name COLLATE NOCASE
        LIMIT %s
    """
    # The same pattern is bound once for each LIKE comparison.
    with _get_connection().cursor() as cursor:
        cursor.execute(query, (pattern, pattern, limit))
        return _dictfetchall(cursor)
