import os
import argparse

import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import json
from tqdm import trange

from fig_atscore_curve import read_all_data, \
    color_palette, model_list, value_mapper, seg_indices
    
# Module-level plot settings. PLOT_IDX, LAYER_IDX and IS_REPLACE appear to be
# unused in this script; the command-line flags in __main__ drive the plot.
PLOT_IDX = -1
LAYER_IDX = 1
IS_REPLACE = False

# Result directories, indexed by ``is_replace`` in read_batch_data:
# index 0 for is_replace=False, index 1 for is_replace=True.
BASE_DIR = [
    './output/attenscore/res_noGT+Sys',
    './output/attenscore/res_noGT+User',
]

# Saved figure format: PDF when True, PNG otherwise.
is_pdf = True

# Overrides the model_list imported from fig_atscore_curve. The position of a
# model in this list selects its colour from color_palette in the plots.
model_list = ['glm', 'llama31', 'qwen']

# Per-model [middle, final] layer indices, selected by ``--layer 0`` / ``--layer 1``.
layer_ids = dict(
    glm=[20, 40],
    llama31=[15, 31],
    qwen=[39, 79],
)


def is_round_related(json_file='./datas/system_benchmark_eval_datas.json'):
    """Return the ``rounds_related`` flag of every entry in *json_file*.

    The file must hold a JSON array of objects, each carrying a
    ``rounds_related`` key. Flags are returned in file order. read_batch_data
    looks up system ID ``sid`` at position ``sid - 1``.

    Args:
        json_file: Path to the UTF-8 encoded benchmark JSON file.

    Returns:
        list: The ``rounds_related`` value of each entry, in file order.
    """
    # Context manager so the file handle is closed even if parsing fails.
    with open(json_file, encoding="utf-8") as f:
        json_data = json.load(f)
    return [data['rounds_related'] for data in json_data]


# Per-process cache of loaded result files, keyed by file path.
memories = {}


def load_batch_data(file_name, layer_id):
    """Return one layer's rows from a result file, trimmed past ``split_indices[1]``.

    Each file is loaded once with ``np.load(..., allow_pickle=True).item()``
    and kept in the module-level ``memories`` dict. The loaded object is
    indexed by *layer_id*. Each layer entry must provide ``'data'``
    (row-indexable; presumably one row per token, TODO confirm) and
    ``'split_indices'`` (segment boundaries).

    The rows kept are ``[split_indices[1] + 1, split_indices[1] + 1 +
    (split_indices[-1] - split_indices[1]))``. The original author describes
    this as filtering out the system message.

    The cached entry is never modified, so repeated calls for the same
    (file, layer) pair return the same rows.

    Args:
        file_name: Path to a ``.npy`` result file.
        layer_id: Key of the layer to extract.

    Returns:
        tuple: ``(rows, split_indices)``, or ``(None, None)`` if *layer_id*
        is not present in the file.

    Raises:
        AssertionError: If the file holds fewer rows than the split indices
            require.
    """
    if file_name not in memories:
        # NOTE: allow_pickle=True unpickles arbitrary objects; only use this on
        # trusted, locally generated result files.
        memories[file_name] = np.load(file_name, allow_pickle=True).item()
    file_data = memories[file_name]

    if layer_id not in file_data:
        return None, None
    layer_data = file_data[layer_id]

    split_indices = layer_data['split_indices']
    rows = layer_data['data']

    # Filter out the system message (leading segment).
    valid_length = split_indices[-1] - split_indices[1]
    start_row = split_indices[1] + 1

    assert start_row + valid_length <= len(rows), f'{start_row} + {valid_length} > {len(rows)}'

    # Return a new slice instead of writing it back into the cached dict;
    # writing back would make a second call slice already-sliced data.
    return rows[start_row:start_row + valid_length], split_indices


def read_batch_data(is_replace, plot_sid, layer_idx, round_related):
    """Collect per-sample attention rows for every model in ``model_list``.

    Args:
        is_replace: Index into ``BASE_DIR`` selecting the result directory.
            False/0 selects './output/attenscore/res_noGT+Sys' and True/1
            selects './output/attenscore/res_noGT+User'.
        plot_sid: System ID to load, or -1 to load IDs 1..500.
        layer_idx: Index into ``layer_ids[model]`` (0 or 1). Use -1 to load
            every layer from 0 (1 for 'glm') through ``layer_ids[model][1]``.
        round_related: -1 keeps every ID. Any other value keeps only IDs whose
            ``rounds_related`` flag (from is_round_related) equals it.

    Returns:
        tuple: ``(data_full, split_indices_full)``. Both are dicts keyed by
        model name, holding parallel lists with one entry per loaded
        (system ID, layer) pair.
    """
    dir_path = BASE_DIR[is_replace]
    # The benchmark JSON is only needed when filtering by round relation.
    round_related_list = is_round_related() if round_related != -1 else None

    data_full = {}
    split_indices_full = {}

    def _build_data_inner(sub_dir_path, fn, sid, layer_id):
        # Load one (sid, layer) sample for model `fn` if it exists and passes the filter.
        file_path = os.path.join(sub_dir_path, f'sid{sid}.npy')
        if not os.path.exists(file_path):
            print(f"WARNING: System ID {sid} not found.")
            return
        if round_related != -1 and round_related != round_related_list[sid - 1]:
            return
        data, split_indices = load_batch_data(file_path, layer_id)
        if data is None:
            print(f"WARNING: Layer {layer_id} not found in {file_path}")
            return
        data_full[fn].append(data)
        split_indices_full[fn].append(split_indices)

    def _build_data_outer(sub_dir_path, fn, sid):
        # Expand layer_idx into the concrete layer ids to load for model `fn`.
        if layer_idx == -1:
            first_layer = 1 if fn == 'glm' else 0
            for lid in range(first_layer, layer_ids[fn][1] + 1):
                _build_data_inner(sub_dir_path, fn, sid, lid)
        else:
            _build_data_inner(sub_dir_path, fn, sid, layer_ids[fn][layer_idx])

    sids = range(1, 501) if plot_sid == -1 else [plot_sid]
    for fn in model_list:
        memories.clear()  # drop the previous model's cached files

        sub_dir_path = os.path.join(dir_path, fn)
        data_full[fn] = []
        split_indices_full[fn] = []
        for sid in sids:
            _build_data_outer(sub_dir_path, fn, sid)
        if plot_sid == -1:
            print(f"Model {fn} loaded {len(data_full[fn])} data files.")

    return data_full, split_indices_full


def calc_turn(model_name, idx,  data, split_indices, seg_length_list):
    """Average the mapped value of *data* rows over consecutive segments.

    Each row is first mapped with ``value_mapper(row, split_indices)``. The
    mapped values are then cut into back-to-back segments whose lengths come
    from *seg_length_list*, and each segment is averaged.

    Args:
        model_name: Model name; only used in the assertion message.
        idx: Sample index; only used in the assertion message.
        data: Rows to map.
        split_indices: Passed through to ``value_mapper``.
        seg_length_list: Segment lengths. Each must satisfy ``.is_integer()``.

    Returns:
        list: One mean per segment, in order.

    Raises:
        AssertionError: If a length is non-integral or the segments run past
            the end of *data*.
    """
    values = [value_mapper(row, split_indices) for row in data]
    total = len(values)

    turn_means = []
    cursor = 0
    for seg_length in seg_length_list:
        assert seg_length.is_integer()
        stop = cursor + int(seg_length)
        assert stop <= total, f'{stop} > {total} for model {model_name} system id {idx}'
        turn_means.append(np.mean(values[cursor:stop]))
        cursor = stop
    return turn_means

def process_data(is_replace=False, **kwargs):
    """Load samples and return each model's mean per-turn value.

    Samples come from ``read_batch_data(is_replace=..., **kwargs)``. For each
    sample, segment lengths are ``np.diff(split_indices[seg_indices])`` and
    calc_turn averages each segment. The per-sample results are then averaged
    over all samples of a model. Each model's result is printed.

    Args:
        is_replace: Forwarded to read_batch_data to pick the result directory.
        **kwargs: ``plot_sid``, ``layer_idx`` and ``round_related``, forwarded
            to read_batch_data.

    Returns:
        list: One length-5 array per model, ordered like ``model_list``.
        A model with no samples yields NumPy's mean of an empty array.
    """
    data_full, split_indices_full = read_batch_data(is_replace=is_replace, **kwargs)

    per_model_means = []
    for name in model_list:
        samples = data_full[name]
        boundaries = split_indices_full[name]
        count = len(samples)

        turn_matrix = np.empty((count, 5))
        for row in trange(count, desc=f'Processing {name}'):
            seg_lengths = np.diff(boundaries[row][seg_indices])
            turn_matrix[row] = calc_turn(name, row, samples[row], boundaries[row], seg_lengths)

        averaged = np.mean(turn_matrix, axis=0)
        per_model_means.append(averaged)
        print(name, averaged)

    return per_model_means

def plot_with_replace(ax, res_org, res_rep):
    """Overlay original and replaced per-turn curves and annotate average shifts.

    *res_org* is drawn as solid lines (width 1.2) and *res_rep* as dashed
    lines (width 0.5), one line per model. Inside a framed box, each model
    gets an up or down arrow and a signed percentage for
    ``mean(res_rep[i]) - mean(res_org[i])``. Intermediate statistics are
    printed to stdout.

    Args:
        ax: Matplotlib Axes to draw on.
        res_org: Per-model sequences of 5 per-turn values, ordered like
            ``model_list``.
        res_rep: Same layout as *res_org*, for the replaced setting.
    """
    turns = np.arange(5) + 1
    for variant, series_list in enumerate([res_org, res_rep]):
        is_original = variant == 0
        for k, series in enumerate(series_list):
            ax.plot(turns, series, label=model_list[k],
                    color=color_palette[k],
                    linestyle='-' if is_original else '--',
                    linewidth=1.2 if is_original else 0.5)

    # Layout of the "Avg Diff." annotation box (data coordinates).
    dy = 0.035
    base_x, basey = 3.1, 0.19
    head_length_ratio = 1 / 8
    arrow_head = dy * head_length_ratio

    for i, name in enumerate(model_list):
        colour = color_palette[i]
        org_series, rep_series = res_org[i], res_rep[i]
        avg_org = np.mean(org_series)
        avg_rep = np.mean(rep_series)

        print(f"---------------{name}-----------------")
        print(org_series, avg_org)
        print(rep_series, avg_rep)

        differences = [b - a for a, b in zip(org_series, rep_series)]
        total_diff = sum(differences)
        print(differences)
        print(total_diff)
        print(f"diff_avg: {total_diff / 5 * 100:.2f}%")

        shift = avg_rep - avg_org
        print(f"avg_rep - avg_org: {shift *100:.2f}%")

        # Upward arrow starts at the row baseline; downward starts 0.7*dy higher.
        if np.sign(shift) > 0:
            arrow_y, arrow_dy = basey + dy * i, dy * 0.7
        else:
            arrow_y, arrow_dy = basey + dy * (i + 0.7), -dy * 0.7
        ax.arrow(base_x, arrow_y, 0, arrow_dy, color=colour,
                 head_width=0.05, head_length=arrow_head,
                 length_includes_head=True,
                 fc=colour, ec=colour)
        ax.text(base_x + 0.1, basey + 0.003 + dy * i,
                f'{shift*100:+.2f}%',
                color=colour,
                fontdict={'fontsize': 8})

    frame = patches.Rectangle((base_x - 0.15, basey - 0.008), 1.35, dy * 3 + 0.04,
                              edgecolor='black', facecolor='none',
                              linewidth=0.8)
    ax.add_patch(frame)
    ax.text(base_x - 0.02, basey + dy * 3 + 0.004, 'Avg Diff.', fontsize=7.5, weight='bold')

    for y_pos, caption in ((0.42, 'Treat System Message'), (0.385, 'as User Instruction')):
        ax.text(3, y_pos, caption, ha='center', va='center', fontsize=9)

    ax.set_xticks(turns)
    ax.set_xticklabels([f'T{t+1}' for t in range(5)])

def plot_without_replace(ax, res_org, round_related, marker='o'):
    """Plot one solid, marked line per model and label the multi-turn setting.

    Args:
        ax: Matplotlib Axes to draw on.
        res_org: Per-model sequences of 5 per-turn values, ordered like
            ``model_list``.
        round_related: Truthy labels the plot "Dependent", falsy "Parallel".
            NOTE(review): -1 (meaning "all" on the CLI) is truthy and is
            therefore labelled "Dependent"; confirm this is intended.
        marker: Matplotlib marker style for the data points.
    """
    turns = np.arange(5) + 1
    for k, series in enumerate(res_org):
        ax.plot(turns, series, label=model_list[k], color=color_palette[k],
                linewidth=1.2, marker=marker, markersize=3)

    ax.set_xticks(turns)
    ax.set_xticklabels([f'T{t+1}' for t in range(5)])

    centre_y, gap = 0.28, 0.015
    setting = 'Dependent' if round_related else 'Parallel'
    for y_pos, caption in ((centre_y + gap, 'Multi-turn'), (centre_y - gap, setting)):
        ax.text(3.5, y_pos, caption,
                ha='center', va='center', fontsize=9, weight='bold')

def do_plot(ax, plot_replace=True, **kwargs):
    """Draw the attention-score figure on *ax*, reusing an on-disk cache.

    Results for both settings (original and replaced) are cached together in
    ``plot/cache/attenscore_id{plot_sid}_layer{layer_idx}_round{round_related}.npy``.
    Any failure to read the cache (missing, corrupt, wrong layout) falls back
    to recomputing with process_data and rewriting the cache.

    Args:
        ax: Matplotlib Axes to draw on.
        plot_replace: If True, draw original vs. replaced curves with
            plot_with_replace. Otherwise draw only the original curves with
            plot_without_replace.
        **kwargs: ``plot_sid``, ``layer_idx`` and ``round_related``, forwarded
            to process_data. Optional ``ignore_cache`` (default False) forces
            recomputation.
    """
    # Popped so it is never forwarded to process_data / read_batch_data.
    ignore_cache = kwargs.pop('ignore_cache', False)

    cached_file_name = f"attenscore_id{kwargs['plot_sid']}_layer{kwargs['layer_idx']}_round{kwargs['round_related']}.npy"
    cached_file_name = f"plot/cache/{cached_file_name}"

    cached = None
    if not ignore_cache:
        print(f"Loading from cache: {cached_file_name}")
        try:
            data = np.load(cached_file_name, allow_pickle=True).item()
            cached = (data['res_org'], data['res_rep'])
        except FileNotFoundError:
            pass
        except Exception as err:
            # Best effort: an unreadable cache is rebuilt. Exception rather
            # than a bare except so KeyboardInterrupt/SystemExit still propagate.
            print(f"WARNING: could not read cache ({err!r})")

    if cached is None:
        print("Cache not found, processing data...")
        res_org = process_data(is_replace=False, **kwargs)
        res_rep = process_data(is_replace=True, **kwargs)
        os.makedirs(os.path.dirname(cached_file_name), exist_ok=True)
        np.save(cached_file_name, {'res_org': res_org, 'res_rep': res_rep})
        print(f"Saved to cache: {cached_file_name}")
    else:
        res_org, res_rep = cached

    if plot_replace:
        plot_with_replace(ax, res_org, res_rep)
    else:
        plot_without_replace(ax, res_org, kwargs['round_related'], marker='o')

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--id', '-i', type=int, default=-1, help='System Message ID for the plot, -1 for all IDs (1-500)')
    parser.add_argument('--layer', '-l', type=int, default=-1, help='Middle(0) / final(1) Layer for the plot, -1 for all layers')
    parser.add_argument('--related', '-r', type=int, default=-1, help='Whether to plot rounds related, -1 for All, 0 for False, 1 for True')
    parser.add_argument('--ignore_cache', '-c', action='store_true', help='Ignore cache and reprocess data')

    args = parser.parse_args()

    kwargs = {
        'plot_sid': args.id,
        'layer_idx': args.layer,
        'round_related': args.related,
        'ignore_cache': args.ignore_cache
    }

    plt.rcParams["font.family"] = "Calibri"
    mpl.rcParams.update({'font.size': 14})

    fig, ax = plt.subplots(1, 1, figsize=(4, 3), dpi=300, tight_layout=True)

    do_plot(ax, **kwargs)

    # plot_with_replace adds one line per model for the original results
    # first, then one per model for the replaced results. Keep only the first
    # len(model_list) handles so each model appears once in the legend.
    n_models = len(model_list)
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles[:n_models], labels[:n_models], loc='upper center', ncol=n_models, fontsize=12)

    file_name = f'figures/attenscore_id{args.id}_layer{args.layer}_round{args.related}' + ('.pdf' if is_pdf else '.png')
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(file_name), exist_ok=True)
    plt.savefig(file_name, bbox_inches='tight', pad_inches=0.1)
    print(f'Figure saved to {file_name}.')