import numpy as np
from app.extensions import db
from app.models.dataset import DataPoint
from app.models.analysis import AnalysisType
from sqlalchemy import func


def _apply_subset_filters(query, animal_ids=None, time_start=None, time_end=None):
    """Apply the optional animal-id and time-window filters to a DataPoint query.

    Arguments left as None (or an empty ``animal_ids``) are skipped, so the
    same helper serves both the fully-filtered aggregate query and the
    animal-only sample query below.
    """
    if animal_ids:
        query = query.filter(DataPoint.animal_id.in_(animal_ids))
    if time_start is not None:
        query = query.filter(DataPoint.timestamp >= time_start)
    if time_end is not None:
        query = query.filter(DataPoint.timestamp <= time_end)
    return query


def compute_subset_stats(dataset_id, animal_ids=None, time_start=None, time_end=None):
    """Compute summary statistics for a data subset to determine analysis eligibility.

    Args:
        dataset_id: primary key of the dataset to summarize.
        animal_ids: optional iterable of animal ids restricting the subset.
        time_start: optional inclusive lower timestamp bound.
        time_end: optional inclusive upper timestamp bound.

    Returns:
        dict with ``num_fixes``, ``num_animals``, ``duration_hours``,
        ``has_regular_sampling`` and, when data exists, the median sampling
        rate in seconds and a lat/lon bounding box.
    """
    # Single aggregate query for counts, time range, and bounding box.
    # (The original code also built a separate row-level query here that was
    # never executed — removed as dead code.)
    agg = db.session.query(
        func.count(DataPoint.id).label('num_fixes'),
        func.count(func.distinct(DataPoint.animal_id)).label('num_animals'),
        func.min(DataPoint.timestamp).label('t_min'),
        func.max(DataPoint.timestamp).label('t_max'),
        func.min(DataPoint.lat).label('lat_min'),
        func.max(DataPoint.lat).label('lat_max'),
        func.min(DataPoint.lon).label('lon_min'),
        func.max(DataPoint.lon).label('lon_max'),
    ).filter(DataPoint.dataset_id == dataset_id)
    row = _apply_subset_filters(agg, animal_ids, time_start, time_end).first()

    if not row or row.num_fixes == 0:
        return {
            'num_fixes': 0, 'num_animals': 0, 'duration_hours': 0,
            'has_regular_sampling': False,
        }

    duration_hours = 0
    if row.t_min and row.t_max:
        duration_hours = (row.t_max - row.t_min).total_seconds() / 3600

    # Estimate sampling regularity from one animal's inter-fix time deltas.
    # NOTE(review): as in the original, only the animal filter is applied
    # when picking the sample animal — the time window is not.
    sample_animal = _apply_subset_filters(
        db.session.query(DataPoint.animal_id).filter_by(dataset_id=dataset_id),
        animal_ids=animal_ids,
    ).first()

    has_regular = False
    median_rate = None
    if sample_animal:
        times = db.session.query(DataPoint.timestamp)\
            .filter_by(dataset_id=dataset_id, animal_id=sample_animal[0])\
            .order_by(DataPoint.timestamp).limit(500).all()
        if len(times) > 1:
            deltas = [(times[i + 1][0] - times[i][0]).total_seconds()
                      for i in range(len(times) - 1)]
            # Drop non-positive deltas (duplicate timestamps).
            deltas = [d for d in deltas if d > 0]
            if deltas:
                median_rate = float(np.median(deltas))
                mean_delta = float(np.mean(deltas))
                # Coefficient of variation < 0.5 => roughly regular schedule.
                cv = float(np.std(deltas)) / mean_delta if mean_delta > 0 else 999
                has_regular = cv < 0.5

    return {
        'num_fixes': row.num_fixes,
        'num_animals': row.num_animals,
        'duration_hours': round(duration_hours, 2),
        'has_regular_sampling': has_regular,
        'median_sampling_rate_s': median_rate,
        'bbox': {
            'min_lat': row.lat_min, 'max_lat': row.lat_max,
            'min_lon': row.lon_min, 'max_lon': row.lon_max,
        } if row.lat_min is not None else None,
    }


def get_eligible_analyses(subset_stats):
    """Return the AnalysisType objects whose data requirements are met.

    ``subset_stats`` is the dict produced by ``compute_subset_stats``; an
    empty subset is eligible for nothing.
    """
    if subset_stats['num_fixes'] == 0:
        return []

    def _meets_requirements(atype):
        # NULL thresholds on an analysis type mean "no requirement"
        # (min_animals defaults to 1 rather than 0).
        return all((
            subset_stats['num_fixes'] >= (atype.min_fixes or 0),
            subset_stats['num_animals'] >= (atype.min_animals or 1),
            subset_stats['duration_hours'] >= (atype.min_duration_hours or 0),
            subset_stats['has_regular_sampling']
            or not atype.requires_regular_sampling,
        ))

    return [atype for atype in AnalysisType.query.all()
            if _meets_requirements(atype)]
