import os
import sys
import inspect

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import random
import copy
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import tyro
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from scheduling.environment import SchedulingEnvironment, generate_environment, make_env

from solvers.hetgat_solver_individual import HetGatSolverIndividual
from models.graph_scheduler import GraphSchedulerCritic

from solvers.edf import EarliestDeadlineFirstAgent
from solvers.improved_edf import ImprovedEarliestDeadlineFirstAgent
from solvers.milp_solver import MILP_Solver
from training.replay_buffer import ReplayBuffer

from utils.utils import set_seed

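# Read per-seed model results and baseline results from text files and generate boxplots of the data.
# Each non-empty line of a model result file is comma-separated: Makespan Greedy, Feasible Count Greedy, Makespan Sample 1, Feasible Count Sample 1, ..., Makespan Sample N, Feasible Count Sample N
# (the plots below label the count as "Infeasible Count (Lower Better)"). Example line: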
# -0.16253305097829104, 1, -0.45960973617376166, 4, -0.2593153821318272, 2, -0.45957008979456104, 4, -0.4589840112819287, 4, -0.4637361008452714, 4, -0.36265866508653744, 3, -0.3589840112819287, 3, -0.2596734000434074, 2

# Problem set whose results are analysed and the seeds whose policy evaluations
# are compared. Alternative configuration:
# environment_name = "data_problem_set_r5_t20_s10_f30_w50_euc_2000"
# seeds = [10, 11, 12]
environment_name = "data_problem_set_r10_t100_s10_f80_w25_euc_2000"
seeds = [10, 11, 12]

# Prefix of the per-seed result files; the seed, graph mode and zero-padded
# checkpoint are appended when the files are opened.
# parentdir is already absolute (os.path.abspath), so no extra leading "/" is
# added: doing so produced "//..." on POSIX and an invalid "/C:\..." on Windows.
filename_prefix = os.path.join(parentdir, f"results/evals/sample_cp/cp__{environment_name}__")
graph_mode = "hgt_edge_resnet"
checkpoint = 1000

# Prefix of the baseline result files; "<baseline>.txt" is appended per baseline.
baseline_path_prefix = os.path.join(parentdir, f"results/evals/evaluation__{environment_name}__10__1__")

baselines = ['edf', 'improved_edf', 'milp_solver']
baseline_data = []

# Load each baseline's results: every non-empty line is "makespan, feasible count"
# (extra columns are ignored). Results are stored as numpy arrays keyed by baseline name.
baseline_makespan = {}
baseline_feasible = {}
for baseline in baselines:
    makespans = []
    feasible_counts = []
    with open(f"{baseline_path_prefix}{baseline}.txt", "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Split on "," alone: float()/int() tolerate the surrounding
            # whitespace, so "a, b" and "a,b" both parse.
            fields = line.split(",")
            makespans.append(float(fields[0]))
            feasible_counts.append(int(fields[1]))
    baseline_makespan[baseline] = np.array(makespans)
    baseline_feasible[baseline] = np.array(feasible_counts)


# Per-row results gathered across seeds. Row r collects non-empty line r of every
# seed's file: one greedy value per seed (in `seeds` order) and all sampled values
# of every seed concatenated.
greedy_makespan = []
greedy_feasible = []
sampled_makespan = []
sampled_feasible = []

for seed in seeds:
    filename = f"{filename_prefix}{seed}__{graph_mode}__{str(checkpoint).zfill(5)}.txt"
    print(filename)
    with open(filename, "r") as f:
        # Count only non-empty lines: indexing by raw line number made a blank
        # line in the middle of a file raise IndexError (only one row is appended
        # per line) or misalign rows between seeds.
        row = 0
        for line in f:
            line = line.strip()
            if not line:
                continue
            fields = line.split(",")
            if len(fields) % 2:
                raise ValueError(
                    f"{filename}: expected (makespan, feasible count) pairs, "
                    f"got {len(fields)} values in line {line!r}"
                )
            # Even positions are makespans, odd positions are feasible counts;
            # the first pair is the greedy rollout, the rest are samples.
            makespan = [float(value) for value in fields[0::2]]
            feasible_count = [int(value) for value in fields[1::2]]
            print("Makespan:", makespan)
            print("Feasible:", feasible_count)

            if len(greedy_makespan) <= row:
                greedy_makespan.append([])
                greedy_feasible.append([])
                sampled_makespan.append([])
                sampled_feasible.append([])
            greedy_makespan[row].append(makespan[0])
            greedy_feasible[row].append(feasible_count[0])
            sampled_makespan[row].extend(makespan[1:])
            sampled_feasible[row].extend(feasible_count[1:])
            row += 1
            
# Turn the nested per-row lists into 2-D numpy arrays (one row per result line).
greedy_makespan, greedy_feasible, sampled_makespan, sampled_feasible = map(
    np.array, (greedy_makespan, greedy_feasible, sampled_makespan, sampled_feasible)
)

# Overall mean and standard deviation of each metric.
for metric_name, metric_values in (
    ("Greedy Makespan", greedy_makespan),
    ("Greedy Feasible", greedy_feasible),
    ("Sampled Makespan", sampled_makespan),
    ("Sampled Feasible", sampled_feasible),
):
    print(f"{metric_name} Mean:", np.mean(metric_values), "Std:", np.std(metric_values))

# Mean over rows of the per-row best/worst values (max then min), greedy vs sampled.
print(f"Dim: {greedy_makespan.shape}")
print(
    f"{greedy_makespan.max(axis=1).mean()} {greedy_makespan.min(axis=1).mean()} "
    f"{sampled_makespan.max(axis=1).mean()} {sampled_makespan.min(axis=1).mean()}"
)
print(
    f"{greedy_feasible.max(axis=1).mean()} {greedy_feasible.min(axis=1).mean()} "
    f"{sampled_feasible.max(axis=1).mean()} {sampled_feasible.min(axis=1).mean()}"
)

# Greedy and sampled values side by side per row, for the combined ensemble.
makespan = np.concatenate([greedy_makespan, sampled_makespan], axis=1)
feasible = np.concatenate([greedy_feasible, sampled_feasible], axis=1)
print(f"Dim: {makespan.shape}")
print(
    f"{makespan.max(axis=1).mean()} {makespan.min(axis=1).mean()} "
    f"{feasible.max(axis=1).mean()} {feasible.min(axis=1).mean()}"
)

# Create and save boxplots comparing per-seed policies, the greedy/sampled/combined ensembles and the baselines (x axis) on one metric (y axis).
import matplotlib
import matplotlib.pyplot as plt
def boxplot(
    data,
    labels,
    y_axis="Makespan(Higher Better)",
    env_name="Problem Set",
    title_prefix="Performance Comparison of Models on ",
    output_dir="figures",
):
    """Draw a notched, colour-filled boxplot of ``data`` and save it as a PNG.

    Args:
        data: Sequence of 1-D value arrays; one box is drawn per entry.
        labels: X tick label for each entry of ``data``.
        y_axis: Y-axis label; also embedded in the output file name.
        env_name: Name appended to the title and embedded in the file name.
        title_prefix: Text placed before ``env_name`` in the title.
        output_dir: Directory the PNG is written to; created if missing.

    The file name also embeds the module-level ``filename_prefix``,
    ``graph_mode`` and ``checkpoint`` settings. The figure is closed after
    saving so repeated calls do not accumulate open figures.
    """
    font = {'family': 'Times New Roman',
            'weight': 'normal',
            'size': 14}
    # NOTE: matplotlib.rc changes global rcParams, affecting later figures too.
    matplotlib.rc('font', **font)

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111)

    # notch must be a bool (the former string 'True' only worked by being truthy);
    # vertical orientation is the default, so the deprecated `vert` is omitted.
    bp = ax.boxplot(data, patch_artist=True, notch=True)

    # Colormap positions 0, 2/n, 3/n, ..., 1 (position 1/n is skipped).
    n_boxes = len(data)
    cmap = plt.get_cmap('gist_rainbow')
    colors = [cmap((i + 1) / n_boxes) if i >= 1 else cmap(i / n_boxes) for i in range(n_boxes)]
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)

    for whisker in bp['whiskers']:
        whisker.set(color='#8B008B', linewidth=1.5, linestyle=":")

    for cap in bp['caps']:
        cap.set(color='#8B008B', linewidth=2)

    for median in bp['medians']:
        median.set(color='black', linewidth=3)

    # `color` alone does not recolour flier markers (their edge/face colours come
    # from flierprops), so set the marker colours explicitly.
    for flier in bp['fliers']:
        flier.set(marker='D',
                  color='#e7298a',
                  markerfacecolor='#e7298a',
                  markeredgecolor='#e7298a',
                  alpha=0.5)

    ax.set_xticklabels(labels)
    # rstrip avoids a double space when the prefix already ends with one.
    ax.set_title(f"{title_prefix.rstrip()} {env_name}")

    # Ticks only on the bottom and left axes.
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()

    ax.set_ylabel(y_axis)

    os.makedirs(output_dir, exist_ok=True)
    prefix_name = filename_prefix.split('/')[-1]
    fig.savefig(os.path.join(output_dir, f"{prefix_name}_{graph_mode}_{checkpoint}_{env_name}_{y_axis}.png"))
    plt.close(fig)
    


print(greedy_feasible.shape, sampled_feasible.shape, feasible.shape)

# Shared tick labels for the three ensemble boxes, identical in both plots.
_ENSEMBLE_LABELS = ["Policy Max\nEnsemble", "Sampled\nEnsemble", "Combined\nEnsemble"]


def _baseline_label(name):
    """Turn a baseline id such as ``improved_edf`` into a multi-line tick label."""
    label = name.replace("_", "\n").title()
    # .title() lower-cases the acronyms; restore them.
    return label.replace("Edf", "EDF").replace("Milp", "MILP")


def _comparison_series(per_seed, ensembles, baseline_results):
    """Collect box data and tick labels for one metric.

    Args:
        per_seed: 2-D array whose column j holds the greedy values of ``seeds[j]``.
        ensembles: The three ensemble series, in ``_ENSEMBLE_LABELS`` order.
        baseline_results: Mapping from baseline id to its value array.

    Returns:
        ``(data, labels)`` lists ready for ``boxplot``.

    Raises:
        ValueError: If the number of columns does not match ``seeds``.
    """
    if per_seed.shape[1] != len(seeds):
        raise ValueError(f"expected {len(seeds)} seed columns, got {per_seed.shape[1]}")
    # Label each column with its actual seed instead of assuming seeds start at 10.
    data = list(per_seed.T)
    labels = [f"Policy\nMax {seed}" for seed in seeds]
    data.extend(ensembles)
    labels.extend(_ENSEMBLE_LABELS)
    for name in baselines:
        data.append(baseline_results[name])
        labels.append(_baseline_label(name))
    return data, labels


# Infeasible count: lower is better, so each ensemble keeps the per-row minimum.
data, labels = _comparison_series(
    greedy_feasible,
    [greedy_feasible.min(1), sampled_feasible.min(1), feasible.min(1)],
    baseline_feasible,
)
boxplot(data, labels, y_axis="Infeasible Count (Lower Better)", env_name=environment_name)

# Makespan: higher is better, so each ensemble keeps the per-row maximum.
data, labels = _comparison_series(
    greedy_makespan,
    [greedy_makespan.max(1), sampled_makespan.max(1), makespan.max(1)],
    baseline_makespan,
)
boxplot(data, labels, y_axis="Makespan (Higher Better)", env_name=environment_name)
