"""Autocorrelation analysis: empirical variograms for movement data."""

import numpy as np
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from flask import current_app


def _haversine(lat1, lon1, lat2, lon2):
    """Return the great-circle distance between two points via the haversine formula.

    Inputs are latitudes/longitudes in degrees (they are converted with
    ``np.radians``) and may be scalars or NumPy arrays that broadcast
    against each other. The result is expressed in the same unit as the
    sphere radius used below (6,371,000 -- Earth's mean radius in metres).
    """
    R = 6371000
    phi1, phi2 = np.radians(lat1), np.radians(lat2)
    dphi = np.radians(lat2 - lat1)
    dlam = np.radians(lon2 - lon1)
    a = np.sin(dphi / 2) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(dlam / 2) ** 2
    # Rounding error can push `a` a few ulps above 1 for (near-)antipodal
    # points, which would make arcsin return NaN. Only the upper bound is
    # capped so that NaN / invalid inputs still surface as NaN.
    a = np.minimum(a, 1.0)
    return 2 * R * np.arcsin(np.sqrt(a))


def run_variogram(df, params, run_id=None):
    """Compute an empirical semi-variogram of positions over time lags.

    For every ``animal_id`` group in ``df`` with at least 50 rows, the
    track is sorted by ``timestamp`` and, for a set of index lags, the
    semi-variance ``mean(d ** 2) / 2`` is computed, where ``d`` is the
    haversine distance between fixes that are ``lag`` rows apart. Each
    animal's curve is drawn on a shared figure.

    Args:
        df: DataFrame with ``animal_id``, ``timestamp``, ``lat`` and
            ``lon`` columns. ``timestamp`` must be datetime-like so that
            differences can be converted to seconds.
        params: Mapping of options. ``n_lags`` (default 30) is the number
            of index lags requested; ``max_lag_fraction`` (default 0.5)
            is the largest lag as a fraction of the track length.
        run_id: When truthy, the figure is saved as
            ``variogram_<run_id>.png`` inside the Flask app's
            ``RESULTS_FOLDER`` and listed in the returned artifacts.

    Returns:
        ``{'summary': {'per_animal': {...}}, 'artifacts': {...}}`` where
        each per-animal entry holds ``max_semivariance_m2``,
        ``has_range_resident_behavior`` and ``n_lags_computed``.
    """
    n_lags = params.get('n_lags', 30)
    max_lag_frac = params.get('max_lag_fraction', 0.5)
    results = {}
    artifacts = {}

    fig, ax = plt.subplots(figsize=(10, 6))
    # try/finally guarantees the figure is released even when a group
    # raises; otherwise pyplot keeps every leaked figure alive.
    try:
        for animal_id, grp in df.groupby('animal_id'):
            grp = grp.sort_values('timestamp').reset_index(drop=True)
            n = len(grp)
            if n < 50:
                continue

            max_lag = int(n * max_lag_frac)
            lag_indices = np.unique(np.linspace(1, max_lag, n_lags, dtype=int))
            # Keep only usable lags: a lag of 0 (or less) would make the
            # `[:-lag]` slice empty, and a lag >= n leaves no pairs.
            lag_indices = lag_indices[(lag_indices >= 1) & (lag_indices < n)]
            if lag_indices.size == 0:
                continue

            # Pull the columns out once instead of on every lag iteration.
            lat = grp['lat'].values
            lon = grp['lon'].values
            timestamps = grp['timestamp']
            elapsed_s = (timestamps - timestamps.iloc[0]).dt.total_seconds().values

            lags_hours = []
            semivariances = []
            for lag in lag_indices:
                d = _haversine(lat[:-lag], lon[:-lag], lat[lag:], lon[lag:])
                semivariances.append(np.mean(d ** 2) / 2)
                # Use the mean time separation of all pairs at this index
                # lag; taking only the first pair misreports the lag when
                # sampling is irregular (identical for regular sampling).
                # nanmean keeps a stray missing timestamp from voiding it.
                mean_dt_s = np.nanmean(elapsed_s[lag:] - elapsed_s[:-lag])
                lags_hours.append(mean_dt_s / 3600)

            lags_hours = np.array(lags_hours)
            semivariances = np.array(semivariances)

            # Detect if variogram plateaus (home range behavior): the
            # second half of the curve must rise much more slowly than
            # the first half.
            half = len(semivariances) // 2
            if half > 2:
                first_half_slope = np.polyfit(lags_hours[:half], semivariances[:half], 1)[0]
                second_half_slope = np.polyfit(lags_hours[half:], semivariances[half:], 1)[0]
                # Cast to a plain bool: the comparison yields numpy.bool_,
                # which the standard JSON encoder rejects.
                has_plateau = bool(second_half_slope < first_half_slope * 0.3)
            else:
                has_plateau = False

            results[animal_id] = {
                'max_semivariance_m2': round(float(semivariances.max()), 2),
                'has_range_resident_behavior': has_plateau,
                'n_lags_computed': len(lags_hours),
            }

            # Semi-variances are divided by 1e6 to match the km² axis label.
            ax.plot(lags_hours, semivariances / 1e6, 'o-', alpha=0.7, label=animal_id, markersize=3)

        ax.set_xlabel('Time Lag (hours)')
        ax.set_ylabel('Semi-variance (km²)')
        ax.set_title('Empirical Variogram')
        # Calling legend() with no labelled artists only produces a warning.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()

        if run_id:
            path = os.path.join(current_app.config['RESULTS_FOLDER'], f'variogram_{run_id}.png')
            fig.savefig(path, dpi=100, bbox_inches='tight')
            artifacts['variogram_plot'] = f'variogram_{run_id}.png'
    finally:
        plt.close(fig)

    return {'summary': {'per_animal': results}, 'artifacts': artifacts}
