"""Movement models: Hidden Markov Models."""

import numpy as np
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from flask import current_app


def run_hmm(df, params, run_id=None):
    """Fit a Hidden Markov Model to step lengths and turning angles."""
    from hmmlearn.hmm import GaussianHMM

    n_states = params.get('n_states', 2)
    results = {}
    artifacts = {}

    for animal_id, grp in df.groupby('animal_id'):
        grp = grp.sort_values('timestamp').dropna(subset=['step_length', 'turning_angle'])
        if len(grp) < 100:
            continue

        # Feature matrix: log(step_length + 1) and turning_angle
        sl = np.log1p(grp['step_length'].values)
        ta = grp['turning_angle'].values
        X = np.column_stack([sl, ta])

        model = GaussianHMM(n_components=n_states, covariance_type='full',
                            n_iter=200, random_state=42)
        model.fit(X)
        states = model.predict(X)

        # Summarize state characteristics
        state_info = {}
        for s in range(n_states):
            mask = states == s
            state_info[f'state_{s}'] = {
                'proportion': round(float(mask.mean()), 4),
                'mean_step_length_m': round(float(np.expm1(sl[mask]).mean()), 2) if mask.any() else 0,
                'mean_turning_angle_rad': round(float(ta[mask].mean()), 4) if mask.any() else 0,
                'sd_step_length': round(float(np.expm1(sl[mask]).std()), 2) if mask.any() else 0,
            }

        results[animal_id] = {
            'n_states': n_states,
            'log_likelihood': round(float(model.score(X)), 2),
            'aic': round(float(-2 * model.score(X) + 2 * model._get_n_fit_scalars_per_param().sum()), 2),
            'states': state_info,
            'transition_matrix': model.transmat_.round(4).tolist(),
        }

        # Plot state assignments over time
        fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)
        times = np.arange(len(grp))

        colors = plt.cm.Set1(np.linspace(0, 1, n_states))
        for s in range(n_states):
            mask = states == s
            axes[0].scatter(times[mask], np.expm1(sl[mask]), c=[colors[s]], alpha=0.3,
                           s=5, label=f'State {s}')
        axes[0].set_ylabel('Step Length (m)')
        axes[0].legend()

        for s in range(n_states):
            mask = states == s
            axes[1].scatter(times[mask], ta[mask], c=[colors[s]], alpha=0.3, s=5)
        axes[1].set_ylabel('Turning Angle (rad)')

        axes[2].plot(times, states, drawstyle='steps-post')
        axes[2].set_ylabel('State')
        axes[2].set_xlabel('Fix Index')
        axes[2].set_yticks(range(n_states))

        fig.suptitle(f'HMM States - {animal_id} ({n_states} states)')

        if run_id:
            path = os.path.join(current_app.config['RESULTS_FOLDER'], f'hmm_{run_id}_{animal_id}.png')
            fig.savefig(path, dpi=100, bbox_inches='tight')
            artifacts[f'hmm_plot_{animal_id}'] = f'hmm_{run_id}_{animal_id}.png'
        plt.close(fig)

    return {'summary': {'per_animal': results}, 'artifacts': artifacts}
