import math
import numpy as np
import pandas as pd


def haversine(lat1, lon1, lat2, lon2):
    """Compute the haversine (great-circle) distance in meters.

    All four arguments are in decimal degrees and may be scalars, sequences,
    NumPy arrays or pandas Series; they are combined element-wise. Series
    inputs yield a Series, other array-likes yield a NumPy array, and scalars
    yield a NumPy float. NaN inputs propagate to NaN outputs.

    Args:
        lat1: Latitude of the first point(s).
        lon1: Longitude of the first point(s).
        lat2: Latitude of the second point(s).
        lon2: Longitude of the second point(s).

    Returns:
        Distance(s) in meters on a sphere of radius 6,371,000 m.
    """
    earth_radius_m = 6371000
    phi1, phi2 = np.radians(lat1), np.radians(lat2)
    # np.subtract (rather than `-`) so plain Python lists are accepted too.
    dphi = np.radians(np.subtract(lat2, lat1))
    dlam = np.radians(np.subtract(lon2, lon1))
    a = np.sin(dphi / 2) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(dlam / 2) ** 2
    # For (near-)antipodal points rounding can push `a` a hair above 1, which
    # would make arcsin return NaN. Clamp the upper bound only: a negative `a`
    # can arise solely from out-of-range latitudes and should stay NaN.
    a = np.minimum(a, 1.0)
    return 2 * earth_radius_m * np.arcsin(np.sqrt(a))


def compute_derived_fields(df):
    """Add per-animal step_length, speed and turning_angle columns.

    The frame is sorted by (animal_id, timestamp) on a copy, so the caller's
    frame is left untouched and does not need to be pre-sorted.

    Args:
        df: DataFrame with columns animal_id, timestamp, lat and lon.
            timestamp must be a datetime column (the ``.dt`` accessor is
            used); lat/lon are passed to ``haversine`` as degrees.

    Returns:
        A sorted copy of ``df`` with three columns that are always present,
        even when every value is NaN:

        * step_length: haversine distance in meters from the animal's
          previous fix; NaN for each animal's first fix.
        * speed: step_length divided by the elapsed seconds since the
          previous fix; NaN when there is no previous fix or the elapsed
          time is not positive (e.g. duplicate timestamps).
        * turning_angle: change in heading between consecutive steps, in
          radians wrapped to [-pi, pi); NaN for each animal's first two
          fixes. Headings are arctan2(dlat, dlon), so positive values are
          counter-clockwise turns.
    """
    df = df.sort_values(['animal_id', 'timestamp']).copy()

    # Previous fix of the same animal. These are kept as local Series rather
    # than temporary columns so same-named columns in the input are neither
    # overwritten nor dropped.
    by_animal = df.groupby('animal_id')
    prev_lat = by_animal['lat'].shift(1)
    prev_lon = by_animal['lon'].shift(1)
    prev_time = by_animal['timestamp'].shift(1)

    # NaN in the shifted columns propagates, leaving first fixes as NaN.
    step_length = haversine(prev_lat, prev_lon, df['lat'], df['lon'])

    # Speed (m/s); only defined for a strictly positive time step.
    elapsed = (df['timestamp'] - prev_time).dt.total_seconds()
    speed = (step_length / elapsed).where(elapsed > 0)

    # Heading of each step, then the difference from the previous step.
    # NOTE(review): headings use raw degree differences (longitude is not
    # scaled by cos(lat)), so angles are distorted away from the equator.
    # Left as-is to keep existing values stable; confirm whether intended.
    heading = np.arctan2(df['lat'] - prev_lat, df['lon'] - prev_lon)
    prev_heading = heading.groupby(df['animal_id']).shift(1)
    turning_angle = (heading - prev_heading + math.pi) % (2 * math.pi) - math.pi

    df['step_length'] = step_length
    df['speed'] = speed
    df['turning_angle'] = turning_angle

    return df


def parse_csv_preview(file_path, nrows=20):
    """Peek at the top of a CSV so its columns can be mapped.

    Args:
        file_path: Path (or anything ``pd.read_csv`` accepts) of the CSV.
        nrows: Maximum number of data rows to read.

    Returns:
        A dict with three keys: 'columns' (list of column names), 'dtypes'
        (column name -> dtype name as a string, inferred from the rows read)
        and 'sample' (up to five leading rows as a list of record dicts).
    """
    preview = pd.read_csv(file_path, nrows=nrows)

    dtype_names = {}
    for name, dtype in preview.dtypes.items():
        dtype_names[name] = str(dtype)

    sample_rows = preview.head(5).to_dict(orient='records')

    summary = {}
    summary['columns'] = list(preview.columns)
    summary['dtypes'] = dtype_names
    summary['sample'] = sample_rows
    return summary


def ingest_csv(file_path, column_mapping):
    """Parse a full CSV into a DataFrame ready for insertion into data_points.

    Args:
        file_path: Path (or anything ``pd.read_csv`` accepts) of the CSV.
        column_mapping: dict with keys 'timestamp', 'lat', 'lon' and
            'animal_id', each mapped to the actual CSV column name.

    Returns:
        DataFrame with columns timestamp, lat, lon, animal_id (as str),
        sensor_data, plus the columns added by ``compute_derived_fields``.
        Rows whose timestamp or lat/lon cannot be parsed, or whose animal_id
        is missing, are dropped. sensor_data holds one dict per row with the
        non-null values of all unmapped CSV columns, or None when the CSV has
        no extra columns.

    Raises:
        KeyError: if ``column_mapping`` lacks a required key, or a mapped
            column is not present in the CSV.
    """
    core_cols = ['timestamp', 'lat', 'lon', 'animal_id']

    df = pd.read_csv(file_path)

    # Rename mapped columns to their canonical names.
    df = df.rename(columns={column_mapping[field]: field for field in core_cols})

    missing = [col for col in core_cols if col not in df.columns]
    if missing:
        raise KeyError(f"CSV is missing mapped column(s): {missing}")

    # Coerce unparseable values to NaT/NaN so bad rows are dropped below
    # instead of aborting the whole import (same policy for all three fields).
    df['timestamp'] = pd.to_datetime(df['timestamp'], errors='coerce')
    df['lat'] = pd.to_numeric(df['lat'], errors='coerce')
    df['lon'] = pd.to_numeric(df['lon'], errors='coerce')

    # astype(str) alone would turn a missing id into the literal string 'nan',
    # which dropna cannot see; keep missing ids missing.
    has_id = df['animal_id'].notna()
    df['animal_id'] = df['animal_id'].astype(str).where(has_id)

    # Drop rows with missing required fields.
    df.dropna(subset=core_cols, inplace=True)

    # Collect extra columns as a per-row sensor_data dict.
    extra_cols = [c for c in df.columns if c not in core_cols]
    if extra_cols:
        records = df[extra_cols].to_dict(orient='records')
        sensor_data = pd.Series(
            [{k: v for k, v in rec.items() if pd.notna(v)} for rec in records],
            index=df.index,
            dtype=object,
        )
        # Drop before assigning so a CSV column that happens to be called
        # 'sensor_data' does not take the freshly built payload with it.
        df.drop(columns=extra_cols, inplace=True)
        df['sensor_data'] = sensor_data
    else:
        df['sensor_data'] = None

    # Compute derived fields (step length, speed, turning angle).
    df = compute_derived_fields(df)

    return df
