from models.model import TPNNS
from models.tpnn import ATPNN, TPNN
import csv
import os
from collections import Counter
import numpy as np

def str_to_csv(str_list : list[str], file_name: str, header_str: str):
    """Write a header line and pre-joined row strings to a CSV file.

    ``header_str`` and every entry of ``str_list`` are split on ``','`` and
    each resulting list is emitted as one CSV row. ``file_name`` is opened in
    write mode, so an existing file is overwritten.

    Args:
        str_list: Rows, each already joined with commas.
        file_name: Path of the CSV file to (over)write.
        header_str: Column names joined with commas.
    """
    # newline='' lets the csv module control line endings itself.
    with open(file_name, 'w', newline='') as out_file:
        csv_out = csv.writer(out_file)
        csv_out.writerow(header_str.split(','))
        csv_out.writerows(line.split(',') for line in str_list)

def tpnn_info(tpnn: TPNN):
    """Collect the loggable attributes of one TPNN.

    Args:
        tpnn: Object exposing ``height``, ``z`` and an iterable ``structure``
            whose items each expose ``variable``, ``b`` and ``gamma``.

    Returns:
        Tuple ``(height, z, variables, bs, gammas)``; the last three are lists
        holding one entry per item of ``tpnn.structure``, in order.
    """
    summary_head = (tpnn.height, tpnn.z)
    # Materialise once so the three columns below come from the same pass.
    nodes = list(tpnn.structure)
    variables = [node.variable for node in nodes]
    bs = [node.b for node in nodes]
    gammas = [node.gamma for node in nodes]
    return (*summary_head, variables, bs, gammas)

def _truncate_number(value, width: int) -> str:
    """Shorten ``str(value)`` to roughly ``width`` characters without changing its magnitude.

    Plain slicing of a float's text can drop a scientific-notation exponent
    (``1.2345678e-05`` -> ``1.23456``) or integer digits (``123456.7`` ->
    ``12345``), which silently logs a different number. For Python/NumPy floats
    only fractional digits are cut and any exponent is kept, so the result may
    exceed ``width``. Every other type is sliced exactly as ``str(value)[:width]``.
    """
    text = str(value)
    if len(text) <= width or not isinstance(value, (float, np.floating)):
        return text[:width]
    mantissa, e_sep, exponent = text.partition('e')
    # Never cut through the digits left of the decimal point.
    keep = max(width, len(mantissa.partition('.')[0]))
    return mantissa[:keep] + e_sep + exponent


def make_tpnns_log(tpnns: TPNNS, config: dict, log_max_depth: int=5) -> str:
    """Build one comma-separated log row describing a TPNNS sample.

    Row layout:
        * one nuisance field: ``tpnns.nui['sigma2']`` when
          ``config['y_dist'] == 'normal'``, otherwise ``tpnns.nui``;
        * for every entry of ``tpnns.model``: ``adc_log[t]``, whether
          ``height_log[t] == 'H'``, height, z, then a ``variable, b, gamma``
          triple per structure item, padded with empty fields up to
          ``log_max_depth`` triples.

    Numeric fields are shortened with ``_truncate_number`` (width 7 for the
    nuisance value, 5 for height / b / gamma).

    Args:
        tpnns: Sample exposing ``model``, ``adc_log``, ``height_log`` and ``nui``.
        config: Must contain ``'K_max'`` and ``'y_dist'``.
        log_max_depth: Number of triples each tree is padded to. A tree with
            more structure items than this is written in full, without padding.

    Returns:
        The row as a single string without a trailing comma.

    Raises:
        AssertionError: If ``len(tpnns.model) != config['K_max']``.
    """
    assert len(tpnns.model) == config['K_max']
    # NOTE(review): the original comments gave 'gppgs' as an example adc_log,
    # while update_log_to_csv counts 'a'/'d'/'c'/'s' -- confirm the alphabet.
    adc_log = tpnns.adc_log
    height_log = tpnns.height_log                   # e.g. 'H_H__H'

    if config['y_dist'] == 'normal':
        nuisance = tpnns.nui['sigma2']
    else:
        nuisance = tpnns.nui
    fields = [_truncate_number(nuisance, 7)]

    for t, tpnn in enumerate(tpnns.model):
        height_t, z_t, variables, bs, gammas = tpnn_info(tpnn)
        height_updated = (height_log[t] == 'H')
        fields += [adc_log[t], str(height_updated), _truncate_number(height_t, 5), str(z_t)]

        depth_t = 0
        for v, b, g in zip(variables, bs, gammas):
            fields += [str(v), _truncate_number(b, 5), _truncate_number(g, 5)]
            depth_t += 1
        # Pad shallow trees so every tree occupies the same number of columns.
        fields += [''] * (3 * max(0, log_max_depth - depth_t))

    return ','.join(fields)

def tpnns_log_to_csv(samples: list[TPNNS], config: dict):
    """Write the structure log of every sample to ``<experiment_name>/tpnns_log.csv``.

    Each sample becomes one row produced by ``make_tpnns_log``. All rows are
    padded to the largest ``len(tpnn.structure)`` found across ``samples``, and
    the header repeats one column group per tree ``config['K_max']`` times.

    Args:
        samples: TPNNS objects; each exposes ``model``, whose items expose
            ``structure``.
        config: Must contain ``'experiment_name'``, ``'K_max'`` and ``'y_dist'``.
    """
    # The deepest structure over all samples fixes the per-tree column count.
    deepest = max(
        (len(tpnn.structure) for tpnns in samples for tpnn in tpnns.model),
        default=0,
    )

    rows = [make_tpnns_log(tpnns, config, deepest) for tpnns in samples]

    out_path = os.path.join(config['experiment_name'], 'tpnns_log.csv')

    # header: one column group per tree, repeated K_max times
    tree_columns = ['ADCS', 'H', 'height', 'z']
    for level in range(1, deepest + 1):
        tree_columns += [f'var_{level}', f'b_{level}', f'gam_{level}']
    repeated = (','.join(tree_columns) + ',') * config['K_max']

    # The first column names the nuisance value written by make_tpnns_log.
    lead = 'error_var,' if config['y_dist'] == 'normal' else 'beta_0,'

    str_to_csv(rows, out_path, lead + repeated[:-1])
        
def _format_sigma2(value, width: int = 7) -> str:
    """Shorten ``str(value)`` to roughly ``width`` characters without changing its magnitude.

    Slicing a float's text can drop a scientific-notation exponent
    (``1.2345678e-05`` -> ``1.23456``) or integer digits. For Python/NumPy
    floats only fractional digits are cut and any exponent is kept; every
    other type is sliced exactly as ``str(value)[:width]``.
    """
    text = str(value)
    if len(text) <= width or not isinstance(value, (float, np.floating)):
        return text[:width]
    mantissa, e_sep, exponent = text.partition('e')
    # Never cut through the digits left of the decimal point.
    keep = max(width, len(mantissa.partition('.')[0]))
    return mantissa[:keep] + e_sep + exponent


def update_log_to_csv(samples: list[TPNNS], config: dict):
    """Write per-sample update statistics to ``<experiment_name>/update_log.csv``.

    One row per sample with, in order: the counts of ``'a'``, ``'d'``, ``'c'``
    and ``'s'`` in ``adc_log``; the counts of ``'Y'`` and ``'_'`` in ``bg_log``;
    the count of ``'H'`` in ``height_log``; the integer sum of ``zs``;
    ``z_updated``; ``train_metric``; ``test_metric``. When
    ``config['y_dist'] == 'normal'`` the row ends with ``nui['sigma2']``.

    Args:
        samples: TPNNS objects exposing the attributes listed above.
        config: Must contain ``'experiment_name'`` and ``'y_dist'``.
    """
    is_normal = config['y_dist'] == 'normal'

    rows = []
    for tpnns in samples:
        # NOTE(review): make_tpnns_log's original comments describe adc_log
        # with 'g'/'p' letters; confirm 'a'/'d' is the alphabet actually logged.
        move_counts = Counter(tpnns.adc_log)
        bg_counts = Counter(tpnns.bg_log)

        fields = [move_counts[move] for move in ('a', 'd', 'c', 's')]
        fields += [bg_counts[flag] for flag in ('Y', '_')]
        fields += [
            Counter(tpnns.height_log)['H'],
            np.sum(tpnns.zs).astype(int).item(),
            tpnns.z_updated,
            tpnns.train_metric,
            tpnns.test_metric,
        ]
        row = ','.join(str(field) for field in fields)

        if is_normal:
            row += ',' + _format_sigma2(tpnns.nui['sigma2'])

        rows.append(row)

    # file path
    file_path = os.path.join(config['experiment_name'], 'update_log.csv')

    # header
    header = '#A,#D,#C,#S,#bg,#bg-reject,#H,|z|,z upd,train,test'
    if is_normal:
        header += ',var_error'

    str_to_csv(rows, file_path, header)
            

