"""Basic movement metrics: speed/distance summary, net squared displacement, first passage time."""

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
from flask import current_app


def _haversine(lat1, lon1, lat2, lon2):
    R = 6371000
    phi1, phi2 = np.radians(lat1), np.radians(lat2)
    dphi = np.radians(lat2 - lat1)
    dlam = np.radians(lon2 - lon1)
    a = np.sin(dphi / 2) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(dlam / 2) ** 2
    return 2 * R * np.arcsin(np.sqrt(a))


def run_speed_distance(df, params, run_id=None):
    """Compute speed and distance summary statistics per animal.

    Args:
        df: Fix table with ``animal_id``, ``timestamp``, ``step_length`` and
            ``speed`` columns. NaNs in ``step_length``/``speed`` are ignored.
            Units are presumably metres and m/s, matching the output key
            names. NOTE(review): confirm against the preprocessing step.
        params: Not read by this analysis. It is accepted so the signature
            matches the other ``run_*`` analyses.
        run_id: When truthy, an overlaid per-animal speed histogram is written
            to ``speed_dist_<run_id>.png`` inside the Flask app's
            ``RESULTS_FOLDER``.

    Returns:
        ``{'summary': {'per_animal': {animal_id: stats}}, 'artifacts': {...}}``.
        Each ``stats`` dict holds ``total_distance_m``, ``total_distance_km``,
        ``mean_speed_ms``, ``max_speed_ms`` and ``mean_step_length_m``, plus
        ``num_fixes`` (rows for the animal, including NaN rows). Mean and max
        values are ``None`` when the animal has no valid values. ``artifacts``
        maps ``'speed_distribution'`` to the PNG file name when a plot was
        written.
    """
    results = {}
    # Reused by the plot so df does not have to be grouped a second time.
    speeds_by_animal = {}
    for animal_id, grp in df.groupby('animal_id'):
        grp = grp.sort_values('timestamp')
        steps = grp['step_length'].dropna()
        speeds = grp['speed'].dropna()
        speeds_by_animal[animal_id] = speeds

        total_dist = steps.sum()
        results[animal_id] = {
            'total_distance_m': round(float(total_dist), 2),
            'total_distance_km': round(float(total_dist / 1000), 3),
            'mean_speed_ms': round(float(speeds.mean()), 4) if not speeds.empty else None,
            'max_speed_ms': round(float(speeds.max()), 4) if not speeds.empty else None,
            'mean_step_length_m': round(float(steps.mean()), 2) if not steps.empty else None,
            'num_fixes': len(grp),
        }

    artifacts = {}
    if run_id:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            for animal_id, speeds in speeds_by_animal.items():
                if not speeds.empty:
                    ax.hist(speeds, bins=50, alpha=0.5, label=animal_id)
            ax.set_xlabel('Speed (m/s)')
            ax.set_ylabel('Count')
            ax.set_title('Speed Distribution')
            # Only add a legend when something was drawn. This avoids
            # matplotlib's "No artists with labels" warning and an empty box.
            if ax.get_legend_handles_labels()[0]:
                ax.legend()
            filename = f'speed_dist_{run_id}.png'
            path = os.path.join(current_app.config['RESULTS_FOLDER'], filename)
            fig.savefig(path, dpi=100, bbox_inches='tight')
        finally:
            # pyplot keeps every open figure alive until closed. Always close,
            # even if saving fails, so a long-running server does not leak them.
            plt.close(fig)
        artifacts['speed_distribution'] = filename

    return {'summary': {'per_animal': results}, 'artifacts': artifacts}


def run_nsd(df, params, run_id=None):
    """Net Squared Displacement from first fix per animal.

    For each animal, fixes are ordered by ``timestamp``. The great-circle
    distance (metres) from the first fix to every fix is squared to give NSD
    in m². Animals with fewer than two fixes are omitted from the results.

    Args:
        df: Fix table with ``animal_id``, ``timestamp`` (datetime-like, ``.dt``
            is used), ``lat`` and ``lon`` (decimal degrees) columns.
        params: Not read by this analysis. It is accepted so the signature
            matches the other ``run_*`` analyses.
        run_id: When truthy, an NSD-vs-time plot (km² against hours since the
            first fix) is written to ``nsd_<run_id>.png`` inside the Flask
            app's ``RESULTS_FOLDER``. No figure is created otherwise.

    Returns:
        ``{'summary': {'per_animal': {animal_id: {'max_nsd_m2', 'max_displacement_m'}}},
        'artifacts': {...}}``. ``artifacts`` maps ``'nsd_plot'`` to the PNG
        file name when a plot was written.
    """
    results = {}
    artifacts = {}
    # Collected first, plotted later, so a figure only exists when one is saved.
    curves = []

    for animal_id, grp in df.groupby('animal_id'):
        grp = grp.sort_values('timestamp')
        if len(grp) < 2:
            continue
        lat = grp['lat'].to_numpy()
        lon = grp['lon'].to_numpy()
        dist = _haversine(lat[0], lon[0], lat, lon)
        nsd = dist ** 2

        hours = (grp['timestamp'] - grp['timestamp'].iloc[0]).dt.total_seconds() / 3600
        max_nsd = nsd.max()
        results[animal_id] = {
            'max_nsd_m2': round(float(max_nsd), 2),
            'max_displacement_m': round(float(np.sqrt(max_nsd)), 2),
        }
        curves.append((animal_id, hours, nsd / 1e6))  # m² -> km² for the plot

    if run_id:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            for animal_id, hours, nsd_km2 in curves:
                ax.plot(hours, nsd_km2, label=animal_id, alpha=0.7)
            ax.set_xlabel('Time (hours)')
            ax.set_ylabel('NSD (km²)')
            ax.set_title('Net Squared Displacement')
            if curves:
                ax.legend()
            filename = f'nsd_{run_id}.png'
            path = os.path.join(current_app.config['RESULTS_FOLDER'], filename)
            fig.savefig(path, dpi=100, bbox_inches='tight')
        finally:
            # Close even on failure so the pyplot figure registry does not leak.
            plt.close(fig)
        artifacts['nsd_plot'] = filename

    return {'summary': {'per_animal': results}, 'artifacts': artifacts}


def _first_passage_times(grp, radii):
    """Return ``{radius: [seconds, ...]}`` of forward first-passage times.

    ``grp`` must already be sorted by ``timestamp``. For each fix ``i``, the
    distances to fixes ``i, i+1, ...`` are computed in one vectorised
    ``_haversine`` call. For every radius, the first fix strictly farther than
    the radius gives that fix's passage time. Fixes that never leave a circle
    contribute no value for that radius.
    """
    lat = grp['lat'].to_numpy(dtype=float)
    lon = grp['lon'].to_numpy(dtype=float)
    # Keep offsets as timedelta64 so each pairwise difference is exact before
    # it is converted to float seconds.
    offsets = (grp['timestamp'] - grp['timestamp'].iloc[0]).to_numpy()
    one_second = np.timedelta64(1, 's')
    fpts = {radius: [] for radius in radii}
    for i in range(len(lat)):
        # One distance row per fix, shared by every radius.
        dist = _haversine(lat[i], lon[i], lat[i:], lon[i:])
        for radius, values in fpts.items():
            beyond = np.flatnonzero(dist > radius)
            if beyond.size:
                j = i + beyond[0]
                values.append(float((offsets[j] - offsets[i]) / one_second))
    return fpts


def run_fpt(df, params, run_id=None):
    """First Passage Time analysis.

    For each animal with at least 10 fixes, and for each radius in
    ``params['radii']`` (metres, default ``[100, 500, 1000, 5000]``), every fix
    gets a passage time: the seconds until the first later fix lying more than
    ``radius`` metres (great-circle) from it. The mean and variance of those
    times, and the log of the variance, are reported per radius.

    NOTE(review): only the forward exit is measured. The classic Fauchald &
    Tveraa FPT also uses the backward exit and interpolates the crossing point.
    Confirm this simplification is intended.

    Args:
        df: Fix table with ``animal_id``, ``timestamp`` (datetime-like), ``lat``
            and ``lon`` (decimal degrees) columns.
        params: Mapping. Only the optional ``'radii'`` key is read.
        run_id: When truthy (and at least one animal was analysed), a plot of
            log(variance) against radius is written to ``fpt_<run_id>.png``
            inside the Flask app's ``RESULTS_FOLDER``.

    Returns:
        ``{'summary': {'per_animal': {animal_id: {'radius_<r>m': stats}}},
        'artifacts': {...}}``. ``stats`` holds ``mean_fpt_s``,
        ``variance_fpt_s`` and ``log_var_fpt``. Radii with no passage events
        are omitted.
    """
    radii = params.get('radii', [100, 500, 1000, 5000])  # meters
    results = {}
    artifacts = {}
    # animal_id -> {result key: (radius, log_var)}. The numeric radius is kept
    # here so the plot never parses it back out of the key string (which broke
    # for float radii such as 100.0).
    plot_points = {}

    for animal_id, grp in df.groupby('animal_id'):
        grp = grp.sort_values('timestamp')
        if len(grp) < 10:
            continue

        fpts = _first_passage_times(grp, radii)
        animal_results = {}
        points = {}
        for radius in radii:
            fpt_values = fpts[radius]
            if not fpt_values:
                continue
            var = float(np.var(fpt_values))
            # NOTE(review): 0 is used as a sentinel for zero variance, which is
            # indistinguishable from a true variance of 1 s².
            log_var = round(float(np.log(var)), 4) if var > 0 else 0
            key = f'radius_{radius}m'
            animal_results[key] = {
                'mean_fpt_s': round(float(np.mean(fpt_values)), 2),
                'variance_fpt_s': round(var, 2),
                'log_var_fpt': log_var,
            }
            points[key] = (radius, log_var)
        results[animal_id] = animal_results
        plot_points[animal_id] = points

    # Plot log(variance) vs radius
    if run_id and results:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            for animal_id, points in plot_points.items():
                rads = [radius for radius, _ in points.values()]
                logvars = [log_var for _, log_var in points.values()]
                ax.plot(rads, logvars, 'o-', label=animal_id)
            ax.set_xlabel('Radius (m)')
            ax.set_ylabel('log(Variance of FPT)')
            ax.set_title('First Passage Time Analysis')
            ax.legend()
            filename = f'fpt_{run_id}.png'
            path = os.path.join(current_app.config['RESULTS_FOLDER'], filename)
            fig.savefig(path, dpi=100, bbox_inches='tight')
        finally:
            # Close even on failure so the pyplot figure registry does not leak.
            plt.close(fig)
        artifacts['fpt_plot'] = filename

    return {'summary': {'per_animal': results}, 'artifacts': artifacts}
