import argparse
import os

import numpy as np
from matplotlib import pyplot as plt
from torch import nn

# from datasetconfig import uni_run_configs, draw_result
# from utils.metrics import metric_gw

def draw_yq(model, inputs, target, preds, path, color):
    """Plot an input window followed by its ground truth and a model's forecast.

    The ground-truth curve (red, labelled ``'Target'``) covers the input
    window concatenated with the target horizon. The forecast curve is drawn
    in ``color`` and starts at the last input point, so it visually connects
    to the history.

    Args:
        model: Model name, used as the legend label of the forecast curve.
        inputs: History values; singleton dimensions are squeezed away.
        target: Ground-truth values for the horizon; singleton dims squeezed.
        preds: Forecast values for the horizon; singleton dims squeezed.
        path: Output image path. If ``None``, the figure is shown
            interactively instead of being saved.
        color: Matplotlib color spec for the forecast curve.

    The legend font size (25) is applied to this legend only, instead of
    mutating global ``rcParams``, so other plots in the process are unaffected.
    """
    # np.squeeze turns a length-1 series into a 0-d array, which has no
    # ``shape[0]`` and cannot be concatenated; atleast_1d keeps it 1-D.
    inputs = np.atleast_1d(np.squeeze(inputs))
    target = np.atleast_1d(np.squeeze(target))
    pred = np.atleast_1d(np.squeeze(preds))

    input_len = inputs.shape[0]
    output_len = target.shape[0]

    # Last observed point, used to anchor the forecast to the history.
    # This is empty when there is no history.
    anchor = inputs[-1:]
    pred_start = input_len - anchor.shape[0]

    # Use a dedicated figure rather than pyplot's implicit "current" figure,
    # so an unrelated figure the caller has open is never drawn over.
    fig, ax = plt.subplots()
    try:
        target_line, = ax.plot(
            range(0, input_len + output_len),
            np.concatenate([inputs, target]),
            label='Target', linewidth=5, color='red',
        )
        # The x-range follows the forecast's own length, so it always matches
        # the number of plotted y-values.
        pred_line, = ax.plot(
            range(pred_start, input_len + pred.shape[0]),
            np.concatenate([anchor, pred]),
            label=f'{model}', linewidth=5, color=color,
        )
        ax.legend(
            handles=[target_line, pred_line],
            labels=['Target', f'{model}'],
            loc='upper left',
            fontsize=25,
        )

        if path is not None:
            fig.savefig(path, bbox_inches='tight', dpi=300)
        else:
            plt.show()
    finally:
        # Always release the figure, even if saving fails, so long loops over
        # many samples do not accumulate open figures.
        plt.close(fig)


parser = argparse.ArgumentParser(
    description='Plot saved forecasts (x.npy / trues.npy / preds.npy) per model, dataset, length and seed.'
)

# NOTE(review): -size and -device are parsed but not read anywhere in this script.
parser.add_argument('-size', type=int, default=1, help='currently unused')
parser.add_argument('-datasets', default=['CESM2'], type=str, nargs='+',
                    help='dataset directory names under ./draw_results')
parser.add_argument('-models', default=['IDOL', 'TDRL', 'CARD', 'FITS', 'MICN', 'iTransformer', 'TimesNet', 'Autoformer'],
                    type=str, nargs='+', help='model directory names; also used as legend labels')
parser.add_argument('-lens', default=[96, 192, 336], type=int, nargs='+',
                    help='prediction lengths (part of the result directory name)')
parser.add_argument('-seed', default=[2024], type=int, nargs='+',
                    help='seeds (part of the result directory name)')
parser.add_argument('-color', default=['green', '#A5AEB7', '#925EB0', '#CC7C71', '#9BBBE1', '#F09BA0', '#EAB883', '#7E99F4'],
                    type=str, nargs='+', help='forecast curve color per model, in the same order as -models')
parser.add_argument('-device', default=0, type=int, help='currently unused')
parser.add_argument('-k', default=1, type=int, help='feature (last-axis) index to plot')

args = parser.parse_args()

# zip() stops at the shorter list, so surplus models would otherwise be
# skipped without any warning. Fail loudly instead.
if len(args.color) < len(args.models):
    parser.error(
        f'-color needs at least as many entries as -models '
        f'(got {len(args.color)} colors for {len(args.models)} models)'
    )

# Feature index to plot; it is constant for the whole run.
k = args.k

for model, color in zip(args.models, args.color):
    for dataset in args.datasets:
        for pred_len in args.lens:
            for seed in args.seed:
                data_dir = os.path.join('./draw_results', dataset, model, f'{pred_len}_{seed}')
                # NOTE(review): allow_pickle=True lets np.load run code embedded in
                # object arrays. Only point this script at trusted result files.
                inputs = np.load(os.path.join(data_dir, 'x.npy'), allow_pickle=True)
                targets = np.load(os.path.join(data_dir, 'trues.npy'), allow_pickle=True)
                preds = np.load(os.path.join(data_dir, 'preds.npy'), allow_pickle=True)

                print('start draw')
                root_path = f"./IDOL_draw_pictures/{dataset}_{k}/{model}/{pred_len}_{seed}"
                os.makedirs(root_path, exist_ok=True)
                # One image per sample along the first axis; `[:, k]` selects
                # feature k of that sample (so each sample is at least 2-D).
                for i in range(inputs.shape[0]):
                    print(i)
                    path = os.path.join(root_path, f'{i}_{k}.png')
                    draw_yq(model, inputs[i][:, k], targets[i][:, k], preds[i][:, k], path, color)

