{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# E step, Kalman filter\n\nInner working of the E step :class:`GLE_analysisEM.GLE_Estimator`\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "from matplotlib import pyplot as plt\n",
        "from GLE_analysisEM import GLE_Estimator, GLE_BasisTransform, sufficient_stats, sufficient_stats_hidden\n",
        "\n",
        "# Printing options: do not truncate printed pandas objects\n",
        "pd.set_option(\"display.max_rows\", None)\n",
        "pd.set_option(\"display.max_columns\", None)\n",
        "pd.set_option(\"display.width\", None)\n",
        "# None removes the column-width limit; the old value -1 is deprecated and rejected by recent pandas\n",
        "pd.set_option(\"display.max_colwidth\", None)\n",
        "\n",
        "dim_x = 1  # number of visible dimensions\n",
        "dim_h = 1  # number of hidden dimensions\n",
        "random_state = None  # e.g. an int seed for a reproducible simulation (assumes sklearn-style random_state)\n",
        "force = -np.identity(dim_x)\n",
        "C = np.identity(dim_x + dim_h)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Simulate trajectories with known hidden variables"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "basis = GLE_BasisTransform(basis_type=\"linear\")\n",
        "generator = GLE_Estimator(verbose=1, dim_x=dim_x, dim_h=dim_h, basis=basis, C_init=C, force_init=force, init_params=\"random\", random_state=random_state)\n",
        "X, idx, Xh = generator.sample(n_samples=10000, n_trajs=10, x0=0.0, v0=0.0)\n",
        "\n",
        "# Time axis of the first trajectory (column 0 of X)\n",
        "time = np.split(X, idx)[0][:, 0]\n",
        "# Drop the last hidden state of each trajectory (the plot below draws it against time[:-1])\n",
        "traj_list_h = [traj_h[:-1, :] for traj_h in np.split(Xh, idx)]\n",
        "\n",
        "print(generator.get_coefficients())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Reference sufficient statistics\n",
        "\n",
        "The estimator is initialised with the generator's coefficients. Sufficient statistics computed from the simulated hidden trajectories (with zero covariance) serve as the reference for the E step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "est = GLE_Estimator(init_params=\"user\", dim_x=dim_x, dim_h=dim_h, basis=basis)\n",
        "est.set_init_coeffs(generator.get_coefficients())\n",
        "est.dt = time[1] - time[0]\n",
        "est._check_initial_parameters()\n",
        "\n",
        "# Keep the raw `idx` untouched so that this cell can be re-run safely\n",
        "Xproc, idx_proc = est.model_class.preprocessingTraj(est.basis, X, idx_trajs=idx)\n",
        "traj_list = np.split(Xproc, idx_proc)\n",
        "est.dim_coeffs_force = est.basis.nb_basis_elt_\n",
        "\n",
        "# Reference statistics: true hidden states with zero covariance, averaged over trajectories\n",
        "datas = 0.0\n",
        "for traj, traj_h in zip(traj_list, traj_list_h):\n",
        "    datas_visible = sufficient_stats(traj, est.dim_x)\n",
        "    zero_sig = np.zeros((len(traj), 2 * est.dim_h, 2 * est.dim_h))\n",
        "    # Row t holds [h(t+1), h(t)]; np.roll wraps the last row around\n",
        "    muh_true = np.hstack((np.roll(traj_h, -1, axis=0), traj_h))\n",
        "    datas += sufficient_stats_hidden(muh_true, zero_sig, traj, datas_visible, est.dim_x, est.dim_h, est.dim_coeffs_force) / len(traj_list)\n",
        "\n",
        "print(\"Real datas\")\n",
        "print(datas)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## E step\n",
        "\n",
        "The same statistics are recomputed from the hidden-variable mean and covariance returned by the E step and compared with the reference (relative difference)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "est._initialize_parameters(None)\n",
        "print(est.get_coefficients())\n",
        "\n",
        "new_stat = 0.0\n",
        "for traj in traj_list:\n",
        "    datas_visible = sufficient_stats(traj, est.dim_x)\n",
        "    muh, Sigh = est._e_step(traj)  # Compute hidden variable distribution\n",
        "    new_stat += sufficient_stats_hidden(muh, Sigh, traj, datas_visible, est.dim_x, est.dim_h, est.dim_coeffs_force) / len(traj_list)\n",
        "\n",
        "print(\"Estimated datas\")\n",
        "print(new_stat)\n",
        "print(\"Diff\")\n",
        "print((new_stat - datas) / np.abs(datas))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Friction and residual re-estimation\n",
        "\n",
        "The friction matrix `A` and the residual term are re-estimated from the E-step statistics; the residual is printed next to the generator's `diffusion_coeffs`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Embeds the force into the first dim_x rows; 5e-3 is presumably the time step (TODO: confirm against est.dt)\n",
        "Pf = np.zeros((dim_x + dim_h, dim_x))\n",
        "Pf[:dim_x, :dim_x] = 5e-3 * np.identity(dim_x)\n",
        "\n",
        "YX = new_stat[\"xdx\"].T - np.matmul(Pf, np.matmul(force, new_stat[\"bkx\"]))\n",
        "XX = new_stat[\"xx\"]\n",
        "A = -np.matmul(YX, np.linalg.inv(XX))\n",
        "# To evaluate the residuals with the true friction instead, use: A = generator.friction_coeffs\n",
        "print(A)\n",
        "\n",
        "bkbk = np.matmul(Pf, np.matmul(np.matmul(force, np.matmul(new_stat[\"bkbk\"], force.T)), Pf.T))\n",
        "bkdx = np.matmul(Pf, np.matmul(force, new_stat[\"bkdx\"]))\n",
        "bkx = np.matmul(Pf, np.matmul(force, new_stat[\"bkx\"]))\n",
        "\n",
        "residuals = new_stat[\"dxdx\"] + np.matmul(A, new_stat[\"xdx\"]) + np.matmul(A, new_stat[\"xdx\"]).T - bkdx.T - bkdx\n",
        "residuals += np.matmul(A, np.matmul(new_stat[\"xx\"], A.T)) - np.matmul(A, bkx.T) - np.matmul(A, bkx.T).T + bkbk\n",
        "print(residuals, generator.diffusion_coeffs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Hidden variable: E-step estimate vs. simulation\n",
        "\n",
        "Shown for the last trajectory processed by the E-step loop above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# `muh` / `Sigh` come from the last iteration of the E-step loop; `time` comes from the first\n",
        "# trajectory, so this assumes all trajectories share the same time grid\n",
        "traj_h_last = traj_list_h[-1]\n",
        "fig, axs = plt.subplots(1, dim_h, squeeze=False)  # squeeze=False keeps axs 2-D for any dim_h\n",
        "for k in range(dim_h):\n",
        "    ax = axs[0, k]\n",
        "    std_k = np.sqrt(Sigh[:, k, k])\n",
        "    ax.plot(time[:-1], muh[:, k], label=r\"Prediction (with $\\pm 2 \\sigma$ error lines)\", color=\"blue\")\n",
        "    ax.plot(time[:-1], muh[:, k] + 2 * std_k, \"--\", color=\"blue\", linewidth=0.1)\n",
        "    ax.plot(time[:-1], muh[:, k] - 2 * std_k, \"--\", color=\"blue\", linewidth=0.1)\n",
        "    ax.plot(time[:-1], traj_h_last[:, k], label=\"Real\", color=\"orange\")\n",
        "    ax.set(xlabel=\"time\", ylabel=f\"h[{k}]\")\n",
        "    ax.legend(loc=\"upper right\")\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}