"""Space use analysis: overlap indices and Brownian Bridge Movement Model."""

import numpy as np
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from flask import current_app


def _latlon_to_meters(lat, lon, ref_lat, ref_lon):
    R = 6371000
    x = np.radians(lon - ref_lon) * R * np.cos(np.radians(ref_lat))
    y = np.radians(lat - ref_lat) * R
    return x, y


def run_overlap(df, params, run_id=None):
    """Compute pairwise space-use overlap between animals from kernel UDs.

    Each animal with enough fixes gets a Gaussian KDE fitted on its locations,
    projected to one shared planar frame (metres, origin at the mean of all
    fixes). Every KDE is evaluated on a common grid and normalised to sum to 1
    over that grid. For each pair, Bhattacharyya's affinity (sum of
    sqrt(p1 * p2)) and volume of intersection (sum of min(p1, p2)) are then
    reported.

    Args:
        df: DataFrame with ``animal_id``, ``lat`` and ``lon`` columns.
        params: Options dict. Recognised keys:
            ``grid_size`` (default 100): grid points per axis.
            ``min_points`` (default 10): animals with fewer fixes are skipped.
            ``grid_pad`` (default 1000): metres added around the data extent.
        run_id: When truthy, a scatter plot is written to
            ``RESULTS_FOLDER/overlap_<run_id>.png``. This reads
            ``flask.current_app``, so it needs an application context.

    Returns:
        ``{'summary': ..., 'artifacts': ...}``. ``summary`` holds either
        ``overlap_indices`` keyed ``"<a1>_vs_<a2>"`` or an ``error`` message
        when fewer than two animals are usable.
    """
    from itertools import combinations

    grid_size = params.get('grid_size', 100)
    min_points = params.get('min_points', 10)
    grid_pad = params.get('grid_pad', 1000)
    animals = df['animal_id'].unique()

    if len(animals) < 2:
        return {'summary': {'error': 'Need at least 2 animals for overlap analysis'}, 'artifacts': {}}

    # One shared projection origin so all animals live in the same planar frame.
    ref_lat = df['lat'].mean()
    ref_lon = df['lon'].mean()

    # Fit a KDE per animal; keep the projected fixes for reuse in the plot.
    kernels = {}
    projected = {}
    for animal_id in animals:
        grp = df[df['animal_id'] == animal_id]
        if len(grp) < min_points:
            continue
        x, y = _latlon_to_meters(grp['lat'].values, grp['lon'].values, ref_lat, ref_lon)
        try:
            kernels[animal_id] = gaussian_kde(np.vstack([x, y]))
        except np.linalg.LinAlgError:
            # Singular covariance (e.g. identical or collinear fixes): no usable KDE.
            continue
        projected[animal_id] = (x, y)

    if len(kernels) < 2:
        return {'summary': {'error': 'Not enough animals with valid KDE'}, 'artifacts': {}}

    # Common evaluation grid over all fixes plus a margin.
    all_x, all_y = _latlon_to_meters(df['lat'].values, df['lon'].values, ref_lat, ref_lon)
    xmin, xmax = all_x.min() - grid_pad, all_x.max() + grid_pad
    ymin, ymax = all_y.min() - grid_pad, all_y.max() + grid_pad
    # A complex step in mgrid means "this many points, endpoints included".
    xi, yi = np.mgrid[xmin:xmax:complex(grid_size), ymin:ymax:complex(grid_size)]
    coords = np.vstack([xi.ravel(), yi.ravel()])

    # Discrete UDs: per-cell probability mass summing to 1 over the grid.
    densities = {}
    for aid, kernel in kernels.items():
        z = kernel(coords)
        total = z.sum()
        if not (np.isfinite(total) and total > 0):
            # Every grid value underflowed; normalising would yield NaN indices.
            continue
        densities[aid] = z / total

    if len(densities) < 2:
        return {'summary': {'error': 'Not enough animals with valid KDE'}, 'artifacts': {}}

    # Bhattacharyya's affinity (BA) and volume of intersection (VI) per pair.
    overlap_matrix = {}
    animal_list = list(densities)
    for a1, a2 in combinations(animal_list, 2):
        p1, p2 = densities[a1], densities[a2]
        ba = float(np.sum(np.sqrt(p1 * p2)))
        vi = float(np.sum(np.minimum(p1, p2)))
        overlap_matrix[f'{a1}_vs_{a2}'] = {
            'bhattacharyya_affinity': round(ba, 4),
            'volume_intersection': round(vi, 4),
        }

    artifacts = {}
    if run_id:
        filename = f'overlap_{run_id}.png'
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            for aid in animal_list:
                x, y = projected[aid]
                ax.plot(x, y, '.', alpha=0.2, markersize=2, label=aid)
            ax.set_xlabel('X (m)')
            ax.set_ylabel('Y (m)')
            ax.set_title('Space Use Overlap')
            ax.legend()
            ax.set_aspect('equal')
            path = os.path.join(current_app.config['RESULTS_FOLDER'], filename)
            fig.savefig(path, dpi=100, bbox_inches='tight')
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
        artifacts['overlap_plot'] = filename

    return {'summary': {'overlap_indices': overlap_matrix}, 'artifacts': artifacts}


def _bbmm_isopleth_areas(density, cell_area, levels):
    """Return ``{level: area_m2}`` for UD volume isopleths on a gridded density.

    Cells are ranked by density (highest first) and accumulated until their
    share of the total grid mass reaches ``level / 100``. The area is that
    cell count times ``cell_area``. This is the smallest region holding
    level% of the utilisation distribution, which is not the same as a
    percentile of cell values.
    """
    ranked = np.sort(density, axis=None)[::-1]
    cumulative = np.cumsum(ranked)
    cumulative /= cumulative[-1]
    areas = {}
    for level in levels:
        # First index reaching the target share; +1 turns the index into a cell count.
        n_cells = int(np.searchsorted(cumulative, level / 100.0, side='left')) + 1
        areas[level] = min(n_cells, ranked.size) * cell_area
    return areas


def run_bbmm(df, params, run_id=None):
    """Brownian Bridge Movement Model home range, per animal.

    For each animal with at least 30 fixes, fixes are ordered by
    ``timestamp``. Each consecutive pair with a positive time gap contributes
    Gaussian kernels along the straight line between the two fixes. The
    kernel variance at fraction ``a`` of a gap ``dt`` is
    ``sig1**2 * dt * a * (1 - a) + sig2**2``. The averaged density is
    evaluated on a grid, and contour areas are UD volume isopleths.

    Args:
        df: DataFrame with ``animal_id``, ``lat``, ``lon`` and datetime-like
            ``timestamp`` columns.
        params: Options dict. Recognised keys:
            ``grid_size`` (default 100): grid points per axis.
            ``sig1`` (default None): Brownian motion parameter, used squared
                as variance per second (m^2/s). When None it is estimated
                per animal.
            ``sig2`` (default 20): location-error standard deviation in metres.
            ``contour_levels`` (default [50, 95]): isopleth percentages.
            ``bridge_steps`` (default 20): kernel positions per segment.
        run_id: When truthy, one plot per animal is written to
            ``RESULTS_FOLDER/bbmm_<run_id>_<animal_id>.png``. This reads
            ``flask.current_app``, so it needs an application context.

    Returns:
        ``{'summary': {'per_animal': {...}}, 'artifacts': {...}}``. Animals
        are skipped if they have under 30 fixes, no positive time step, or a
        degenerate (all-zero) density.
    """
    grid_size = params.get('grid_size', 100)
    sig1 = params.get('sig1', None)  # Brownian motion sd-like parameter (auto-estimate if None)
    sig2 = params.get('sig2', 20)    # Location error sd (meters)
    contour_levels = params.get('contour_levels', [50, 95])
    n_steps = params.get('bridge_steps', 20)
    results = {}
    artifacts = {}

    for animal_id, grp in df.groupby('animal_id'):
        grp = grp.sort_values('timestamp').reset_index(drop=True)
        if len(grp) < 30:
            continue

        t = (grp['timestamp'] - grp['timestamp'].iloc[0]).dt.total_seconds().values
        dt = np.diff(t)
        valid = dt > 0
        if not valid.any():
            # All fixes share one timestamp: no segment can carry a bridge,
            # and the sig1 estimate would be the mean of an empty array (NaN).
            continue

        ref_lat = grp['lat'].mean()
        ref_lon = grp['lon'].mean()
        x, y = _latlon_to_meters(grp['lat'].values, grp['lon'].values, ref_lat, ref_lon)

        if sig1 is None:
            # Simple moment estimator: in 2-D Brownian motion E[dx^2 + dy^2] = 2*sig1^2*dt.
            # Not the leave-one-out MLE of Horne et al.; location error is ignored here.
            dx = np.diff(x)
            dy = np.diff(y)
            sigma1_est = np.sqrt(np.mean((dx[valid] ** 2 + dy[valid] ** 2) / dt[valid]) / 2)
        else:
            sigma1_est = sig1

        # Pad 3 sd of sig1^2 * max_dt (4x the peak bridge variance) plus location
        # error, so the grid keeps essentially all UD mass even when sig1 is ~0.
        pad = 3 * np.sqrt(sigma1_est ** 2 * dt.max() + sig2 ** 2)
        xmin, xmax = x.min() - pad, x.max() + pad
        ymin, ymax = y.min() - pad, y.max() + pad
        xi, yi = np.mgrid[xmin:xmax:complex(grid_size), ymin:ymax:complex(grid_size)]

        # Brownian bridge density: average of Gaussian kernels along each segment.
        density = np.zeros(xi.shape)
        n_terms = 0
        alphas = np.linspace(0, 1, n_steps)
        for i in np.flatnonzero(valid):
            seg_dx = x[i + 1] - x[i]
            seg_dy = y[i + 1] - y[i]
            for alpha in alphas:
                var = sigma1_est ** 2 * dt[i] * alpha * (1 - alpha) + sig2 ** 2
                if var <= 0:
                    # Zero variance (sig2 == 0 at segment endpoints) would give 0/0 = NaN.
                    continue
                mu_x = x[i] + alpha * seg_dx
                mu_y = y[i] + alpha * seg_dy
                density += np.exp(-((xi - mu_x) ** 2 + (yi - mu_y) ** 2) / (2 * var)) / (2 * np.pi * var)
                n_terms += 1

        if n_terms == 0 or not density.sum() > 0:
            continue
        density /= n_terms  # average over kernels actually added

        # mgrid with complex(grid_size) places grid_size nodes, i.e. grid_size - 1 intervals.
        n_intervals = max(grid_size - 1, 1)
        cell_area = (xmax - xmin) / n_intervals * (ymax - ymin) / n_intervals
        animal_results = {'sig1_estimated': round(float(sigma1_est), 2), 'sig2': sig2}
        levels = sorted(contour_levels)
        areas = _bbmm_isopleth_areas(density, cell_area, levels)
        for level in levels:
            area_m2 = float(areas[level])
            animal_results[f'{level}pct_area_km2'] = round(area_m2 / 1e6, 6)
            animal_results[f'{level}pct_area_ha'] = round(area_m2 / 10000, 4)

        results[animal_id] = animal_results

        if run_id:
            filename = f'bbmm_{run_id}_{animal_id}.png'
            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                ax.contourf(xi, yi, density, levels=20, cmap='YlOrRd')
                ax.plot(x, y, 'k-', alpha=0.3, linewidth=0.5)
                ax.plot(x, y, 'k.', alpha=0.5, markersize=1)
                ax.set_xlabel('X (m)')
                ax.set_ylabel('Y (m)')
                ax.set_title(f'BBMM - {animal_id}')
                ax.set_aspect('equal')
                path = os.path.join(current_app.config['RESULTS_FOLDER'], filename)
                fig.savefig(path, dpi=100, bbox_inches='tight')
            finally:
                # Always release the figure, even if saving fails.
                plt.close(fig)
            artifacts[f'bbmm_plot_{animal_id}'] = filename

    return {'summary': {'per_animal': results}, 'artifacts': artifacts}
