"""Home range estimation: MCP, KDE, LoCoH."""

import numpy as np
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull
from scipy.stats import gaussian_kde
from flask import current_app


def _latlon_to_meters(lat, lon, ref_lat, ref_lon):
    """Simple equirectangular projection to meters from a reference point."""
    R = 6371000
    x = np.radians(lon - ref_lon) * R * np.cos(np.radians(ref_lat))
    y = np.radians(lat - ref_lat) * R
    return x, y


def run_mcp(df, params, run_id=None):
    """Minimum Convex Polygon (MCP) home range.

    Each animal with at least 5 valid fixes is projected into its own local
    planar frame (meters, origin at the animal's mean lat/lon).  Its home range
    is the convex hull of those fixes.  With ``params['percent'] < 100``, only
    fixes whose distance to the centroid is at or below that percentile of
    distances are hulled (the usual "x% MCP").

    Args:
        df: DataFrame with ``animal_id``, ``lat`` and ``lon`` columns.
        params: Options dict; ``percent`` (default 100) selects the MCP level.
        run_id: When truthy, an overlay plot of all hulls is saved to
            ``current_app.config['RESULTS_FOLDER']`` as ``mcp_<run_id>.png``.

    Returns:
        ``{'summary': {'per_animal': {animal_id: {...}}}, 'artifacts': {...}}``.
        Per-animal entries hold the area in m², km² and ha, the hull vertex
        count and the percent used.  Animals with too few fixes, or with
        degenerate fixes (all collinear or coincident), are omitted.
    """
    pct = params.get('percent', 100)
    results = {}
    artifacts = {}
    plot_items = []  # (label, x, y, closed hull ring), already in the shared plot frame

    # The overlay plot needs one common origin so animals keep their relative
    # positions.  Areas still use per-animal origins to keep distortion small.
    plot_ref_lat = df['lat'].mean()
    plot_ref_lon = df['lon'].mean()

    for animal_id, grp in df.groupby('animal_id'):
        grp = grp.dropna(subset=['lat', 'lon'])
        if len(grp) < 5:
            continue

        ref_lat = grp['lat'].mean()
        ref_lon = grp['lon'].mean()
        x, y = _latlon_to_meters(grp['lat'].values, grp['lon'].values, ref_lat, ref_lon)

        if pct < 100:
            # Keep only the pct% of fixes closest to the centroid.
            cx, cy = np.mean(x), np.mean(y)
            dists = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
            threshold = np.percentile(dists, pct)
            mask = dists <= threshold
            x, y = x[mask], y[mask]

        points = np.column_stack([x, y])
        if len(points) < 3:
            continue
        try:
            hull = ConvexHull(points)
        except (RuntimeError, ValueError):
            # scipy raises QhullError (a RuntimeError subclass) for collinear or
            # coincident fixes; skip that animal instead of aborting the run.
            continue
        area_m2 = hull.volume  # for 2-D input, ConvexHull.volume is the area
        area_km2 = area_m2 / 1e6

        results[animal_id] = {
            'area_m2': round(float(area_m2), 2),
            'area_km2': round(float(area_km2), 6),
            'area_ha': round(float(area_m2 / 10000), 4),
            'num_vertices': len(hull.vertices),
            'percent': pct,
        }

        if run_id:
            # Shift this animal's frame into the shared plot frame.
            off_x, off_y = _latlon_to_meters(ref_lat, ref_lon, plot_ref_lat, plot_ref_lon)
            ring = points[np.append(hull.vertices, hull.vertices[0])]  # closed polygon
            plot_items.append((
                f'{animal_id} ({area_km2:.3f} km²)',
                x + off_x,
                y + off_y,
                ring + [off_x, off_y],
            ))

    if run_id:
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            for label, px, py, ring in plot_items:
                (fixes,) = ax.plot(px, py, '.', alpha=0.3, markersize=2)
                ax.plot(ring[:, 0], ring[:, 1], '-', linewidth=2,
                        color=fixes.get_color(), label=label)
            ax.set_xlabel('X (m)')
            ax.set_ylabel('Y (m)')
            ax.set_title(f'Minimum Convex Polygon ({pct}%)')
            if plot_items:  # legend() warns when nothing is labelled
                ax.legend()
            ax.set_aspect('equal')
            filename = f'mcp_{run_id}.png'
            path = os.path.join(current_app.config['RESULTS_FOLDER'], filename)
            fig.savefig(path, dpi=100, bbox_inches='tight')
            artifacts['mcp_plot'] = filename
        finally:
            plt.close(fig)

    return {'summary': {'per_animal': results}, 'artifacts': artifacts}


def _ud_density_thresholds(density, levels):
    """Return the density cut-off for each utilization-distribution isopleth.

    Cells are ranked from densest to sparsest and their densities accumulated.
    For a percentage ``level``, the cut-off is the density of the cell at which
    the running total first reaches ``level`` percent of the grid's total.  The
    region ``density >= cut-off`` therefore holds (at least) that share of the
    mass.  All cells have equal area, so the cell area cancels and raw density
    values can be summed directly.

    Args:
        density: Array of non-negative density values on a regular grid.
        levels: Iterable of percentages (0-100).

    Returns:
        List of float cut-offs, one per entry of *levels*, in the same order.
    """
    ranked = np.sort(density, axis=None)[::-1]  # densest cell first
    cumulative = np.cumsum(ranked)
    total = cumulative[-1]
    last = ranked.size - 1
    return [
        float(ranked[min(int(np.searchsorted(cumulative, total * level / 100.0)), last)])
        for level in levels
    ]


def run_kde(df, params, run_id=None):
    """Kernel Density Estimation (KDE) home range.

    For each animal with at least 30 valid fixes, this fits a Gaussian KDE
    (``scipy.stats.gaussian_kde`` with its default bandwidth rule) in a local
    planar frame (meters, origin at the animal's mean lat/lon).  It evaluates
    the KDE on a regular grid and reports the area of each requested isopleth:
    the highest-density region holding ``level`` percent of the probability
    mass on the grid.

    Args:
        df: DataFrame with ``animal_id``, ``lat`` and ``lon`` columns.
        params: Options dict:
            ``contour_levels`` — isopleth percentages (default ``[50, 95]``);
            ``grid_size`` — evaluation points per axis (default 100, min 2);
            ``buffer_m`` — padding in meters around the fixes.  The default is
            ``max(1000, 3 * largest kernel standard deviation)`` so the grid
            encloses nearly all of the density.
        run_id: When truthy, one contour plot per animal is saved to
            ``current_app.config['RESULTS_FOLDER']``.

    Returns:
        ``{'summary': {'per_animal': {...}}, 'artifacts': {...}}``.  Each animal
        gets ``'<level>pct_area_km2'`` and ``'<level>pct_area_ha'`` for every
        level, plus ``'bandwidth'`` (the KDE bandwidth factor).  Animals with
        too few or degenerate fixes are omitted.
    """
    contour_levels = sorted(params.get('contour_levels', [50, 95]))
    # mgrid needs at least two samples per axis to define a spacing.
    grid_size = max(int(params.get('grid_size', 100)), 2)
    buffer_m = params.get('buffer_m')
    results = {}
    artifacts = {}

    for animal_id, grp in df.groupby('animal_id'):
        grp = grp.dropna(subset=['lat', 'lon'])
        if len(grp) < 30:
            continue

        ref_lat = grp['lat'].mean()
        ref_lon = grp['lon'].mean()
        x, y = _latlon_to_meters(grp['lat'].values, grp['lon'].values, ref_lat, ref_lon)

        try:
            kernel = gaussian_kde(np.vstack([x, y]))
        except (np.linalg.LinAlgError, ValueError):
            # Singular covariance, e.g. all fixes identical or collinear.
            continue

        if buffer_m is None:
            # Pad by ~3 kernel standard deviations so little mass falls off-grid.
            kernel_sd = float(np.sqrt(np.max(np.diag(kernel.covariance))))
            pad = max(1000.0, 3.0 * kernel_sd)
        else:
            pad = float(buffer_m)

        # Evaluation grid: grid_size points per axis, both endpoints included.
        xmin, xmax = x.min() - pad, x.max() + pad
        ymin, ymax = y.min() - pad, y.max() + pad
        xi, yi = np.mgrid[xmin:xmax:complex(grid_size), ymin:ymax:complex(grid_size)]
        zi = kernel(np.vstack([xi.ravel(), yi.ravel()])).reshape(xi.shape)
        if not np.all(np.isfinite(zi)) or zi.sum() <= 0:
            continue

        # mgrid includes both endpoints, so the sample spacing is span / (n - 1).
        cell_area = (xmax - xmin) / (grid_size - 1) * (ymax - ymin) / (grid_size - 1)
        thresholds = _ud_density_thresholds(zi, contour_levels)

        animal_results = {}
        for level, threshold in zip(contour_levels, thresholds):
            area_m2 = float(np.count_nonzero(zi >= threshold) * cell_area)
            animal_results[f'{level}pct_area_km2'] = round(area_m2 / 1e6, 6)
            animal_results[f'{level}pct_area_ha'] = round(area_m2 / 10000, 4)

        animal_results['bandwidth'] = float(kernel.factor)
        results[animal_id] = animal_results

        if not run_id:
            continue

        # matplotlib requires strictly increasing contour levels.  Wider
        # isopleths have lower density cut-offs, so order by cut-off; keying
        # the dict by cut-off also drops duplicate (tied) levels.
        label_by_cutoff = {t: f'{level}%' for level, t in zip(contour_levels, thresholds)}
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.plot(x, y, '.', alpha=0.2, markersize=2, color='black')
            if label_by_cutoff:
                cs = ax.contour(xi, yi, zi, levels=sorted(label_by_cutoff))
                ax.clabel(cs, fmt={
                    lev: label_by_cutoff.get(float(lev), f'{lev:.3g}') for lev in cs.levels
                })
            ax.set_xlabel('X (m)')
            ax.set_ylabel('Y (m)')
            ax.set_title(f'KDE Home Range - {animal_id}')
            ax.set_aspect('equal')
            filename = f'kde_{run_id}_{animal_id}.png'
            path = os.path.join(current_app.config['RESULTS_FOLDER'], filename)
            fig.savefig(path, dpi=100, bbox_inches='tight')
            artifacts[f'kde_plot_{animal_id}'] = filename
        finally:
            plt.close(fig)

    return {'summary': {'per_animal': results}, 'artifacts': artifacts}


def _raster_union_area(polygons, grid_size):
    """Approximate the area (m²) covered by the union of 2-D *polygons*.

    Square cells are laid over the polygons' joint bounding box, ``grid_size``
    of them along its longer side.  An extra always-empty cell is added on
    each side so the coverage outline closes.  A cell counts as covered when
    its centre lies inside at least one polygon.  The union area is the
    covered-cell count times the cell area, so precision improves with
    ``grid_size``, and polygons smaller than a cell may be missed.

    Args:
        polygons: Sequence of (m, 2) vertex arrays in meters (m >= 3).
        grid_size: Cells along the longer side of the bounding box (>= 1).

    Returns:
        ``(area_m2, raster)``.  ``raster`` is ``(cx, cy, covered)``: the cell-centre
        coordinates and a ``(len(cy), len(cx))`` boolean coverage mask.  It is
        ``None`` when the polygons have zero extent.
    """
    from matplotlib.path import Path

    corners = np.vstack(polygons)
    xmin, ymin = corners.min(axis=0)
    xmax, ymax = corners.max(axis=0)
    span = max(xmax - xmin, ymax - ymin)
    if span <= 0:
        return 0.0, None

    cell = span / grid_size
    # +2 adds the empty margin cell on each side.
    nx = int(np.ceil((xmax - xmin) / cell)) + 2
    ny = int(np.ceil((ymax - ymin) / cell)) + 2
    cx = xmin + (np.arange(nx) - 0.5) * cell
    cy = ymin + (np.arange(ny) - 0.5) * cell
    covered = np.zeros((ny, nx), dtype=bool)

    for poly in polygons:
        (pxmin, pymin), (pxmax, pymax) = poly.min(axis=0), poly.max(axis=0)
        # Only test cell centres inside this polygon's own bounding box.
        i0 = np.searchsorted(cx, pxmin, side='left')
        i1 = np.searchsorted(cx, pxmax, side='right')
        j0 = np.searchsorted(cy, pymin, side='left')
        j1 = np.searchsorted(cy, pymax, side='right')
        if i0 >= i1 or j0 >= j1:
            continue
        gx, gy = np.meshgrid(cx[i0:i1], cy[j0:j1])
        inside = Path(poly).contains_points(np.column_stack([gx.ravel(), gy.ravel()]))
        covered[j0:j1, i0:i1] |= inside.reshape(gx.shape)

    return float(np.count_nonzero(covered)) * cell * cell, (cx, cy, covered)


def run_locoh(df, params, run_id=None):
    """Local Convex Hull (LoCoH) home range - k-LoCoH method.

    Each animal needs at least 20 valid fixes.  For every fix, the convex hull
    of that fix and its ``k`` nearest neighbours is built in a local planar
    frame (meters, origin at the animal's mean lat/lon).  The home range is the
    *union* of these local hulls.  Unlike an MCP, the union can be non-convex
    and exclude unused gaps.  Its area is approximated by rasterisation
    (``_raster_union_area``).

    Args:
        df: DataFrame with ``animal_id``, ``lat`` and ``lon`` columns.
        params: Options dict:
            ``k`` — neighbours per local hull.  The default ``None`` uses
            ``min(max(int(sqrt(n)), 5), n - 1)`` per animal.
            ``grid_size`` — raster cells along the longer side of each
            animal's extent for the union area (default 200).
        run_id: When truthy, an overlay plot of the fixes and the hull-union
            outlines is saved to ``current_app.config['RESULTS_FOLDER']``.

    Returns:
        ``{'summary': {'per_animal': {...}}, 'artifacts': {...}}``.  Per animal:
        ``area_km2``, ``area_ha``, ``k``, ``num_local_hulls`` and
        ``mean_local_hull_area_m2``.
    """
    from scipy.spatial import cKDTree

    k = params.get('k', None)  # None -> automatic sqrt(n) rule per animal
    grid_size = max(int(params.get('grid_size', 200)), 1)
    results = {}
    artifacts = {}
    footprints = {}  # animal_id -> (plot x, plot y, raster, offset x, offset y)

    # Shared origin for the overlay plot so animals keep their relative positions.
    plot_ref_lat = df['lat'].mean()
    plot_ref_lon = df['lon'].mean()

    for animal_id, grp in df.groupby('animal_id'):
        grp = grp.dropna(subset=['lat', 'lon'])
        if len(grp) < 20:
            continue

        ref_lat = grp['lat'].mean()
        ref_lon = grp['lon'].mean()
        x, y = _latlon_to_meters(grp['lat'].values, grp['lon'].values, ref_lat, ref_lon)
        points = np.column_stack([x, y])
        n = len(points)

        if k is None:
            k_use = min(max(int(np.sqrt(n)), 5), n - 1)
        else:
            k_use = min(int(k), n - 1)
        if k_use < 2:
            # A hull needs at least 3 points: the fix itself plus 2 neighbours.
            continue

        # One batched query; each row is the fix itself followed by its
        # k_use nearest neighbours.
        _, neighbours = cKDTree(points).query(points, k=k_use + 1)

        local_areas = []
        local_polygons = []
        for idx in neighbours:
            local_pts = points[idx]
            try:
                local_hull = ConvexHull(local_pts)
            except (RuntimeError, ValueError):
                # QhullError (a RuntimeError subclass): collinear/coincident neighbourhood.
                continue
            local_areas.append(local_hull.volume)  # 2-D volume == area
            local_polygons.append(local_pts[local_hull.vertices])

        if not local_areas:
            continue

        # The home range is the union of the local hulls.  The convex hull of
        # all their vertices would simply reproduce the 100% MCP.
        total_area, raster = _raster_union_area(local_polygons, grid_size)

        results[animal_id] = {
            'area_km2': round(float(total_area / 1e6), 6),
            'area_ha': round(float(total_area / 10000), 4),
            'k': k_use,
            'num_local_hulls': len(local_areas),
            'mean_local_hull_area_m2': round(float(np.mean(local_areas)), 2),
        }

        off_x, off_y = _latlon_to_meters(ref_lat, ref_lon, plot_ref_lat, plot_ref_lon)
        footprints[animal_id] = (x + off_x, y + off_y, raster, off_x, off_y)

    if run_id and results:
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            for animal_id, (px, py, raster, off_x, off_y) in footprints.items():
                (fixes,) = ax.plot(px, py, '.', alpha=0.3, markersize=2, label=animal_id)
                if raster is not None:
                    cx, cy, covered = raster
                    # Outline of the rasterised union of local hulls.
                    ax.contour(cx + off_x, cy + off_y, covered.astype(float),
                               levels=[0.5], colors=[fixes.get_color()])
            ax.set_xlabel('X (m)')
            ax.set_ylabel('Y (m)')
            # k may differ per animal when chosen automatically.
            k_label = 'auto' if k is None else k
            ax.set_title(f'LoCoH Home Range (k={k_label})')
            ax.legend()
            ax.set_aspect('equal')
            filename = f'locoh_{run_id}.png'
            path = os.path.join(current_app.config['RESULTS_FOLDER'], filename)
            fig.savefig(path, dpi=100, bbox_inches='tight')
            artifacts['locoh_plot'] = filename
        finally:
            plt.close(fig)

    return {'summary': {'per_animal': results}, 'artifacts': artifacts}
