from neuron import h
import numpy as np
import matplotlib.pyplot as plt
from neuron.units import ms, mV
import pandas as pd
from Cell_setup import cell

# ================================================================
# Simulation parameters: edit the values in this section to change
# the simulation. The code after this section reads only these names.
# ================================================================

# --- Initialization run ---
V0 = -75  # mV, initial membrane potential

timestep_init = 0.025  # ms
RunTime_init = 3000  # ms

# --- Main run ---
timestep_main = 0.025  # ms
RunTime_main = 1000  # ms; total runtime is RunTime_init + RunTime_main

# --- Synaptic input ---
starttime_syn = 10  # ms, measured from the start of the main run
interval_syn = 35  # ms
inputsite = 7  # the synaptic input array is inserted into this dendrite

# Normalized positions (0..1) along the input dendrite, 11 evenly spaced
# points. Pick one ordering:
# inputloc = [1/22 + 1/11*i for i in range(0,11)]  # ordered to distal
inputloc = [1 - (1/22 + 1/11*pos) for pos in range(11)]  # ordered to soma

gmax_syn = 3.264e-05  # uS, peak synaptic conductance of a single input
Erev_syn = 0  # mV, synaptic reversal potential
tau1_syn = 8  # ms, rise time constant
tau2_syn = 25  # ms, decay time constant

# --- Recording ---
Recordtimestep = 0.5  # ms; preferably a multiple of the timestep
RecordStartTime = 2800  # ms; requires 0 <= RecordStartTime < RecordEndTime
RecordEndTime = 3900  # ms; requires RecordEndTime < RunTime_init + RunTime_main

# Recording switches, one per quantity: 0 -> off, 1 -> on.
# The order must match the Rec_ref name list in the recording setup.
Rec = [
    1,  # voltage
    0,  # [Na]i
    0,  # [Na]o
    0,  # [K]i
    0,  # [K]o
    0,  # [Cl]i
    0,  # [Cl]o
    0,  # INa leak
    0,  # IK leak
    0,  # ICl leak
    0,  # pump flux; x3 gives INa pump, x-2 gives IK pump
    0,  # y (state of pump)
    0,  # KCC2 rate; x1 gives IK KCC2, x-1 gives ICl KCC2
    0,  # ICl of ClC2
    0,  # m (open probability) ClC2
    0,  # minf (steady-state open probability) ClC2
    0,  # INa of TTX-R Nav
    0,  # m (activation-gate open probability) TTX-R Nav
    0,  # h (inactivation-gate open probability) TTX-R Nav
]

# CSV output switch: 0 -> no CSV files, 1 -> write CSV files.
mkcsv = 1


# ================================================================
# NEURON model construction and simulation. No need to edit below;
# every setting comes from the parameter section above.
# ================================================================

cell1 = cell()
cell1.setup()

h.load_file("stdrun.hoc")

# --- Stimulation setup ---
# Each input is a single-spike NetStim -> NetCon -> Exp2Syn chain placed on
# cell1.dend[inputsite] at position inputloc[i]. Input i fires at
# RunTime_init + starttime_syn + i * interval_syn (ms).
# NOTE(review): inputloc holds 11 positions, but only the first NofInput
# are used. Confirm that this is intended.
NofInput = 10

stims = []
syns = []
ncstims = []

for i in range(NofInput):
    stim = h.NetStim()
    syn = h.Exp2Syn(cell1.dend[inputsite](inputloc[i]))
    syn.e = Erev_syn * mV
    syn.tau1 = tau1_syn * ms
    syn.tau2 = tau2_syn * ms

    stim.number = 1
    # Spacing between inputs comes from the interval_syn parameter.
    stim.start = (RunTime_init + starttime_syn + interval_syn * i) * ms

    nc = h.NetCon(stim, syn)
    nc.delay = 0
    nc.weight[0] = gmax_syn  # uS, peak conductance per input event

    # Keep Python references so the NEURON objects stay alive for the run.
    stims.append(stim)
    syns.append(syn)
    ncstims.append(nc)


# --- Recording setup ---

# Range-variable names to record, one per entry of Rec and in the same order.
Rec_ref = ["v",  # voltage
           "nai",  # [Na]i
           "nao",  # [Na]o
           "ki",  # [K]i
           "ko",  # [K]o
           "cli",  # [Cl]i
           "clo",  # [Cl]o
           "ina_leak",  # INa leak
           "ik_leak",  # IK leak
           "icl_leak",  # ICl leak
           "flux_NaKpump",  # pump flux; x3 gives INa pump, x-2 gives IK pump
           "yy_NaKpump",  # y (state of pump)
           "ik_KCC2",  # KCC2 rate; x1 gives IK KCC2, x-1 gives ICl KCC2
           "icl_ClC2",  # ICl of ClC2
           "m_ClC2",  # m (open probability) ClC2
           "minf_ClC2",  # minf (steady-state open probability) ClC2
           "ina_TTXRNa",  # INa of TTX-R Nav
           "mn_TTXRNa",  # m (activation-gate open probability) TTX-R Nav
           "hn_TTXRNa",  # h (inactivation-gate open probability) TTX-R Nav
           ]

if len(Rec_ref) != len(Rec):
    raise ValueError(
        f"Rec has {len(Rec)} switches but Rec_ref has {len(Rec_ref)} names; "
        "they must correspond one-to-one"
    )

# Simulation time, sampled every Recordtimestep ms.
tt = h.Vector()
tt.record(h._ref_t, Recordtimestep)

# One row per recorded section: the first cell1.NofSection rows are
# cell1.dend[0..] (index 0 is labelled "soma"), followed by the
# cell1.NofSecinBranches sections of cell1.branch[0]. Each row holds one
# Vector per enabled quantity and 0 for disabled ones.
records = [0] * (cell1.NofSection + cell1.NofSecinBranches)
secname = [0] * len(records)

for i in range(len(records)):
    records[i] = [0] * len(Rec)
    if i < cell1.NofSection:
        sec = cell1.dend[i]
        secname[i] = "soma" if i == 0 else "dend" + str(i)
    else:
        sec = cell1.branch[0][i - cell1.NofSection]
        secname[i] = "branch 0-" + str(i - cell1.NofSection)

    seg = sec(0.5)  # record at the section midpoint
    for j, flag in enumerate(Rec):
        if flag == 1:
            # getattr resolves e.g. seg._ref_v by name without building
            # source text for exec().
            records[i][j] = h.Vector()
            records[i][j].record(getattr(seg, "_ref_" + Rec_ref[j]), Recordtimestep)

"""simultation setup""" 
# Two-phase run: an initialization phase integrated with timestep_init,
# then the main phase integrated with timestep_main.
# tstop and v_init are plain Python variables. They are passed explicitly
# to h.continuerun / h.finitialize below and are not set on h.
h.dt = timestep_init
tstop = RunTime_init *ms
v_init = V0 * mV

"""Initialize"""
# Global NEURON temperature (degrees Celsius).
h.celsius = 27.0

# Cell-specific ion setup defined in Cell_setup. It is called before
# finitialize. NOTE(review): presumably this puts initial ion state in place
# before finitialize runs; confirm against Cell_setup.
cell1.initialize_ion()

# Initialize the model with membrane potential v_init (mV).
h.finitialize(v_init)

"""Run"""
# Phase 1: integrate up to RunTime_init ms. The first synaptic input is
# scheduled at RunTime_init + starttime_syn, so with starttime_syn > 0
# no input fires during this phase.
h.continuerun(tstop)

# Phase 2: switch to timestep_main and continue to the total end time,
# RunTime_init + RunTime_main ms. The synaptic inputs fire in this phase.
h.dt = timestep_main
tstop = (RunTime_init+RunTime_main) *ms

h.continuerun(tstop)
print("main run end")

# --- Show graphs and export CSV ---
tt_nparr = np.zeros(tt.size())
tt_nparr = tt.to_python(tt_nparr)

# Sample window [start_index, end_index): the first sample at or after
# RecordStartTime up to (excluding) the first sample at or after
# RecordEndTime. tt is recorded on a fixed interval, so it is ascending,
# which searchsorted requires. If RecordEndTime lies past the last sample,
# the window is clipped to the end of the recording instead of raising
# IndexError.
start_index = int(np.searchsorted(tt_nparr, RecordStartTime, side="left"))
end_index = int(np.searchsorted(tt_nparr, RecordEndTime, side="left"))
if start_index >= end_index:
    raise ValueError(
        f"no recorded samples in window [{RecordStartTime}, {RecordEndTime}) ms; "
        "check that RecordStartTime < RecordEndTime and that RecordStartTime "
        "is within RunTime_init + RunTime_main"
    )

t_window = tt_nparr[start_index:end_index]
# Indices (into Rec / Rec_ref / each records row) of the enabled quantities.
enabled = [j for j, flag in enumerate(Rec) if flag == 1]

if mkcsv == 1:
    # One table per enabled quantity. Column 0 is time; column 1 + k holds
    # section k, in secname order.
    savedata = [np.zeros((t_window.size, 1 + len(records))) for _ in enabled]
    for table in savedata:
        table[:, 0] = t_window
    savedata_attrname = [Rec_ref[j] for j in enabled]

# One figure per enabled quantity, with one trace per section.
for it, i in enumerate(enabled):
    for j in range(len(records)):
        rec_nparr = records[j][i].to_python()
        plt.plot(t_window, rec_nparr[start_index:end_index], label=secname[j])
        if mkcsv == 1:
            savedata[it][:, 1 + j] = rec_nparr[start_index:end_index]
    plt.title(Rec_ref[i])
    plt.legend(fontsize=9, loc="upper left")
    plt.show()

if mkcsv == 1:
    for attrname, table in zip(savedata_attrname, savedata):
        pd.DataFrame(data=table, columns=["time"] + secname).to_csv(attrname + ".csv", index=False)

# --- EPSP amplitude per section (needs the voltage recording, Rec[0]) ---
if Rec[0] == 1:  # voltage
    # Baseline sample: the first time point at or after the first synaptic
    # input (RunTime_init + starttime_syn), less one timestep_main. The
    # margin is presumably slack for floating-point drift in recorded t.
    base_index = np.where(tt_nparr >= RunTime_init + starttime_syn - timestep_main)[0][0]

    # Amplitude = peak voltage from the baseline sample to the end of the
    # recording, minus the baseline voltage.
    v_amp = np.zeros(len(records))
    for idx, row in enumerate(records):
        trace = row[0].to_python()
        baseline = trace[base_index]
        v_amp[idx] = np.amax(trace[base_index:]) - baseline
        print("EPSP amp " + secname[idx] + "= " + str(v_amp[idx]))

    if mkcsv == 1:
        pd.DataFrame(data=v_amp, index=secname).to_csv("amp.csv")

# Drop the script's reference to the cell object.
cell1 = None
