# Combined trace files are expected to contain two loss terms per split (see TraceWithTwoLosses).
# Usage: python plot_trace.py [--dataset_path PATH] [--use_dropout true|false] [--dropout_rate RATE]

import sys
import os
import matplotlib.pyplot as plt
import argparse



class TraceIndividual():
    """Metrics for one epoch of a single-network (non-combined) trace file.

    Every metric starts at -1.0 and is filled in by ``line_to_trace``.
    """

    def __init__(self, epoch):
        self.epoch = epoch
        # -1.0 is the value a metric keeps until line_to_trace assigns it.
        self.train_loss = self.train_top1_acc = self.train_top5_acc = -1.0
        self.valid_loss = self.valid_top1_acc = self.valid_top5_acc = -1.0

    def line_to_trace(self, line, number_losses=1):
        """Parse one comma-separated trace line into this object's metrics.

        Every field except the first and the last is stripped of surrounding
        single quotes and whitespace and converted to float. Fields 1-3 are
        read as the train loss / top-1 / top-5 accuracy, and fields 4-6 as
        the same metrics for validation.

        ``number_losses`` is accepted so the call signature matches
        ``Trace.line_to_trace``; it is ignored here.

        NOTE(review): the last field is never converted, so lines presumably
        carry a trailing extra field (e.g. a trailing comma). Confirm against
        the trace writer.

        Returns:
            The list of split fields, with the interior ones converted to float.
        """
        fields = line.split(',')
        for pos in range(1, len(fields) - 1):
            fields[pos] = float(fields[pos].strip('\'').strip().strip('\''))

        # Each train column i (1-3) has its validation counterpart at i + 3.
        loss_pair = (fields[1], fields[4])
        top1_pair = (fields[2], fields[5])
        top5_pair = (fields[3], fields[6])
        self.train_loss, self.valid_loss = loss_pair
        self.train_top1_acc, self.valid_top1_acc = top1_pair
        self.train_top5_acc, self.valid_top5_acc = top5_pair

        return fields

    def __repr__(self):
        """Tab-separated metrics: train loss/top1/top5, then valid loss/top1/top5."""
        columns = [
            self.train_loss, self.train_top1_acc, self.train_top5_acc,
            self.valid_loss, self.valid_top1_acc, self.valid_top5_acc,
        ]
        return '\t'.join(map('{0:.8f}'.format, columns))

class Trace():
    """Metrics for one epoch of a combined trace (full and small accuracies).

    Every metric starts at -1.0 and is filled in by ``line_to_trace``.
    """

    def __init__(self, epoch):
        self.epoch = epoch
        # -1.0 is the value a metric keeps until line_to_trace assigns it.
        self.train_loss = self.valid_loss = -1.0
        self.train_full_top1_acc = self.valid_full_top1_acc = -1.0
        self.train_full_top5_acc = self.valid_full_top5_acc = -1.0
        self.train_small_top1_acc = self.valid_small_top1_acc = -1.0
        self.train_small_top5_acc = self.valid_small_top5_acc = -1.0

    def line_to_trace(self, line, number_losses=1):
        """Parse one comma-separated trace line into this object's metrics.

        Every field except the first and the last is stripped of surrounding
        single quotes and whitespace and converted to float.

        Expected layout (n = ``number_losses``):
          * field 1: train loss; fields 1..n: the train loss terms
          * fields n+1 .. n+4: train full top1, full top5, small top1, small top5
          * field n+5: valid loss; fields n+5 .. 2n+4: the valid loss terms
          * fields 2n+5 .. 2n+8: valid full top1, full top5, small top1, small top5

        Only the first loss term of each split is stored here.

        Returns:
            The list of split fields, with the interior ones converted to float.
        """
        fields = line.split(',')
        for pos in range(1, len(fields) - 1):
            fields[pos] = float(fields[pos].strip('\'').strip().strip('\''))

        train_acc_start = number_losses + 1      # first train accuracy column
        valid_loss_col = number_losses + 5       # first valid loss column
        valid_acc_start = 2 * number_losses + 5  # first valid accuracy column

        self.train_loss, self.valid_loss = fields[1], fields[valid_loss_col]
        self.train_full_top1_acc, self.valid_full_top1_acc = fields[train_acc_start], fields[valid_acc_start]
        self.train_full_top5_acc, self.valid_full_top5_acc = fields[train_acc_start + 1], fields[valid_acc_start + 1]
        self.train_small_top1_acc, self.valid_small_top1_acc = fields[train_acc_start + 2], fields[valid_acc_start + 2]
        self.train_small_top5_acc, self.valid_small_top5_acc = fields[train_acc_start + 3], fields[valid_acc_start + 3]

        return fields

    def __repr__(self):
        """Tab-separated metrics: the train block, then the valid block."""
        columns = (
            self.train_loss,
            self.train_full_top1_acc, self.train_full_top5_acc,
            self.train_small_top1_acc, self.train_small_top5_acc,
            self.valid_loss,
            self.valid_full_top1_acc, self.valid_full_top5_acc,
            self.valid_small_top1_acc, self.valid_small_top5_acc,
        )
        return '\t'.join('{0:.8f}'.format(value) for value in columns)


class TraceWithTwoLosses(Trace):
    """A ``Trace`` whose lines carry two loss terms per split.

    The second term is stored in ``train_loss_ce`` / ``valid_loss_ce``. It is
    the quantity ``plot_traces`` draws under the "CE loss" axis label.
    """

    def __init__(self, epoch):
        super().__init__(epoch)
        # Like the inherited metrics, -1.0 until line_to_trace assigns it.
        self.train_loss_ce = -1.0
        self.valid_loss_ce = -1.0

    def line_to_trace(self, line):
        """Parse a two-loss line; see ``Trace.line_to_trace`` for the layout.

        With two losses, field 2 is the second train loss term and field 8
        the second valid loss term.
        """
        parsed = super().line_to_trace(line, number_losses=2)
        self.train_loss_ce, self.valid_loss_ce = parsed[2], parsed[8]
        return parsed

    def __repr__(self):
        """Tab-separated metrics: train (both losses, 4 accs), then valid."""
        train_part = [
            self.train_loss, self.train_loss_ce,
            self.train_full_top1_acc, self.train_full_top5_acc,
            self.train_small_top1_acc, self.train_small_top5_acc,
        ]
        valid_part = [
            self.valid_loss, self.valid_loss_ce,
            self.valid_full_top1_acc, self.valid_full_top5_acc,
            self.valid_small_top1_acc, self.valid_small_top5_acc,
        ]
        return '\t'.join('{0:.8f}'.format(value) for value in train_part + valid_part)



def get_traces_from_file(file_name, is_combined=True):
    """Read a trace file and return one trace object per non-blank line.

    All spaces are removed from each line before it is parsed. Epoch numbers
    start at 1 and count only the non-blank lines.

    Args:
        file_name: path of the trace text file.
        is_combined: parse lines with ``TraceWithTwoLosses`` when True,
            otherwise with ``TraceIndividual``.

    Returns:
        List of parsed trace objects, in file order.
    """
    trace_cls = TraceWithTwoLosses if is_combined else TraceIndividual
    traces = []
    with open(file_name, 'r') as handle:
        for raw_line in handle:
            cleaned = raw_line.replace(" ", "").strip()
            if not cleaned:
                continue
            record = trace_cls(len(traces) + 1)
            record.line_to_trace(cleaned)
            traces.append(record)
    return traces


def plot_traces(mult_traces, plot_accuracy=True, alphas=None, filename="cifar_100_loss_profile.png", legends=None):
    """Plot validation curves for several trace series and save the figure.

    Args:
        mult_traces: list of series. Each series is a list of trace objects,
            such as the ones returned by ``get_traces_from_file``.
        plot_accuracy: if True, plot validation top-1 accuracy. Combined
            traces contribute two curves ("<legend> full" and
            "<legend> small"). Otherwise plot the validation loss, using
            ``valid_loss_ce`` when a trace has it.
        alphas: accepted for backward compatibility and not used.
            Defaults to ``[4]``.
        filename: path of the image written by ``savefig``.
        legends: one label per series, in the order of ``mult_traces``.
            Defaults to ``["Adjoint full", "Standard full"]``.
    """
    # None sentinels avoid sharing mutable default lists between calls.
    if alphas is None:
        alphas = [4]
    if legends is None:
        legends = ["Adjoint full", "Standard full"]

    fig = plt.figure()
    try:
        if plot_accuracy:
            labels = []
            for idx, traces in enumerate(mult_traces):
                name = legends[idx] if idx < len(legends) else "Series {}".format(idx + 1)
                # Combined traces carry both full and small accuracies;
                # individual traces only have valid_top1_acc.
                if traces and hasattr(traces[0], 'valid_full_top1_acc'):
                    plt.plot([trace.valid_full_top1_acc for trace in traces])
                    plt.plot([trace.valid_small_top1_acc for trace in traces])
                    labels.extend(["{} full".format(name), "{} small".format(name)])
                else:
                    plt.plot([trace.valid_top1_acc for trace in traces])
                    labels.append(name)
            plt.ylabel('Accuracy', fontsize=16)
            fig.suptitle('Valid accuracy vs #Epochs', fontsize=20)
            plt.gca().legend(tuple(labels))

        else:
            for traces in mult_traces:
                plt.plot([trace.valid_loss_ce if hasattr(trace, 'valid_loss_ce') else trace.valid_loss for trace in traces])

            plt.gca().legend(tuple(legends))
            plt.ylabel('CE loss', fontsize=16)
            fig.suptitle('Valid loss vs #Epochs', fontsize=20)

        plt.xlabel('Epoch', fontsize=18)
        plt.savefig(filename)
    finally:
        # Release the figure; callers render many plots in a loop and
        # pyplot keeps every open figure alive until it is closed.
        plt.close(fig)


def generate_plots(dataset_path, use_dropout, dropout_rate):
    """Plot each ``combined<alpha>.txt`` trace against the individual baseline.

    For every alpha in (4, 8, 16), this reads ``<dataset_path>/combined<alpha>.txt``
    and the baseline ``individual.txt``, or ``individual_dropout<rate>.txt``
    when ``use_dropout`` is set. The validation-loss comparison is saved under
    ``<dataset_path>/plots/``.

    Args:
        dataset_path: directory containing the trace files.
        use_dropout: pick the dropout baseline file and file-name suffix.
        dropout_rate: value inserted into the dropout file names.
    """
    alphas = [4, 8, 16]
    individual_file = os.path.join(dataset_path, "individual.txt")
    if use_dropout:
        individual_file = os.path.join(dataset_path, "individual_dropout{}.txt".format(dropout_rate))

    indv_traces = get_traces_from_file(individual_file, is_combined=False)

    # matplotlib's savefig does not create missing directories.
    plots_dir = os.path.join(dataset_path, "plots")
    os.makedirs(plots_dir, exist_ok=True)

    for alpha in alphas:
        combined_file = os.path.join(dataset_path, "combined{}.txt".format(alpha))
        comb_traces = get_traces_from_file(combined_file)

        legends = ["Adjoint-{} full".format(alpha)]
        if use_dropout:
            save_file_name = os.path.join(plots_dir, "combined{}_dropout{}.png".format(alpha, dropout_rate))
            legends.append("Standard full + dropout")
        else:
            save_file_name = os.path.join(plots_dir, "combined{}.png".format(alpha))
            legends.append("Standard full")
        # Series are passed as a single list, in the same order as `legends`
        # (combined first, then the baseline).
        plot_traces([comb_traces, indv_traces], filename=save_file_name, alphas=[alpha], plot_accuracy=False, legends=legends)



def get_plot(architecture, dataset, alphas, is_mask, mask_values, dropoutvalue):
    """Build one plot-table entry: [trace_dir, png_name, (file, is_combined, legend), ...].

    Args:
        architecture: sub-directory name under ``Adjoint-trace``.
        dataset: sub-directory name under the architecture directory.
        alphas: alpha values. With ``is_mask`` only ``alphas[0]`` is used.
        is_mask: when True, compare random-mask runs of ``alphas[0]``, one per
            value in ``mask_values``. Otherwise compare one combined run per alpha.
        mask_values: integer mask levels; shown in legends divided by 100.
        dropoutvalue: when not None, also add the
            ``individual_dropout<value>.txt`` baseline.
    """
    trace_dir = "Adjoint-trace/{}/{}".format(architecture, dataset)

    if is_mask:
        base_alpha = alphas[0]
        png_name = "adjoint{}_rand.png".format(base_alpha)
        runs = [
            ("combined{}_random{}.txt".format(base_alpha, mask), True,
             "Adj-{} + Rand-{:.1f}".format(base_alpha, mask / 100.))
            for mask in mask_values
        ]
    else:
        png_name = "adjoint.png"
        runs = [("combined{}.txt".format(alpha), True, "Adjoint-{}".format(alpha)) for alpha in alphas]

    # Baselines always come after the adjoint runs.
    runs.append(("individual.txt", False, "Standard"))
    if dropoutvalue is not None:
        runs.append(("individual_dropout{}.txt".format(dropoutvalue), False, "Standard + dropout"))

    return [trace_dir, png_name] + runs

def generate_multiplots(dataset_path):
    """Render every comparison plot listed in the hard-coded table below.

    Each entry of ``all_plots`` is a list of the form
    ``[trace_dir, png_name, (trace_file, is_combined, legend), ...]``.
    Every ``trace_file`` in ``trace_dir`` is parsed with
    ``get_traces_from_file(..., is_combined=is_combined)``. The validation
    losses of all of them are drawn on one figure, saved to
    ``<trace_dir>/plots/<png_name>``; the ``plots`` directory is created
    if missing.

    NOTE(review): ``dataset_path`` is not read by this function. The
    ``trace_dir`` values are relative paths, so they resolve against the
    current working directory. Confirm this is intended.
    """
    all_plots = []
    # Files under Adjoint-trace/loss-function, legended Cos/Exp/Linear/Quadratic.
    all_plots.append([ "Adjoint-trace/loss-function", "loss.png", ("cos.txt", True, "Cos"), ("exponential.txt", True, "Exp"), ("linear.txt", True, "Linear"), ("quadratic.txt", True, "Quadratic") ] )
   
    # get_plot(arch, dataset, alphas, is_mask, mask_values, dropoutvalue):
    #   is_mask=False -> one curve per alpha (+ dropout baseline if given);
    #   is_mask=True  -> random-mask runs of alphas[0], one per mask value.

    # Resnet 50, adjoint plots with individual and dropouts
    arch = "resnet50"
   
    # CIFAR-100
    dataset = "cifar100"
    all_plots.append(get_plot(arch, dataset, [4, 8, 16], False, None, 75))
    all_plots.append(get_plot(arch, dataset, [4], True, [50, 60, 70, 80, 90], None))
    all_plots.append(get_plot(arch, dataset, [8], True, [50, 60, 70, 80, 90], None))
    all_plots.append(get_plot(arch, dataset, [16], True, [50, 60, 70, 80, 90], None))

    # CIFAR-10
    dataset = "cifar10"
    all_plots.append(get_plot(arch, dataset, [4, 8, 16], False, None, 75))
    all_plots.append(get_plot(arch, dataset, [4], True, [50, 90], None))
    all_plots.append(get_plot(arch, dataset, [8], True, [50, 90], None))
    all_plots.append(get_plot(arch, dataset, [16], True, [50, 90], None))
    
    # Imagenet
    dataset = "imagenet"
    all_plots.append(get_plot(arch, dataset, [4, 8], False, None, 50))
   
    # Imagewoof
    dataset = "imagewoof"
    all_plots.append(get_plot(arch, dataset, [4, 8, 16], False, None, 50))
    all_plots.append(get_plot(arch, dataset, [4], True, [50, 60, 70, 80, 90], None))
    all_plots.append(get_plot(arch, dataset, [8], True, [50, 60, 70, 80, 90], None))
    all_plots.append(get_plot(arch, dataset, [16], True, [50, 80, 90], None))
   
    # Pets
    dataset = "oxford-pets"
    all_plots.append(get_plot(arch, dataset, [2, 4, 8], False, None, 50))
    all_plots.append(get_plot(arch, dataset, [2], True, [50], None))
    all_plots.append(get_plot(arch, dataset, [4], True, [50], None))
    all_plots.append(get_plot(arch, dataset, [8], True, [50, 90], None))


    # Resnet 18, adjoint plots with individual and dropouts 
    arch = "resnet18"
   
    # CIFAR-100
    dataset = "cifar100"
    all_plots.append(get_plot(arch, dataset, [4, 8, 16], False, None, 50))
    all_plots.append(get_plot(arch, dataset, [4], True, [50, 60, 70, 80, 90], None))
    all_plots.append(get_plot(arch, dataset, [8], True, [50, 60, 70, 80, 90], None))
    all_plots.append(get_plot(arch, dataset, [16], True, [50, 60, 70, 80, 90], None))

    # CIFAR-10
    dataset = "cifar10"
    all_plots.append(get_plot(arch, dataset, [2, 4, 8, 16], False, None, 50))
    all_plots.append(get_plot(arch, dataset, [4], True, [50, 90], None))
    all_plots.append(get_plot(arch, dataset, [8], True, [50, 90], None))
    all_plots.append(get_plot(arch, dataset, [16], True, [90], None))
    
    # Imagewoof
    dataset = "imagewoof"
    all_plots.append(get_plot(arch, dataset, [4, 8, 16], False, None, 50))
    all_plots.append(get_plot(arch, dataset, [4], True, [50, 90], None))
    all_plots.append(get_plot(arch, dataset, [8], True, [50, 90], None))
   
    # Pets
    dataset = "oxford-pets"
    all_plots.append(get_plot(arch, dataset, [2, 4, 8, 16], False, None, 50))
    all_plots.append(get_plot(arch, dataset, [2], True, [50, 90], None))
    all_plots.append(get_plot(arch, dataset, [4], True, [50, 90], None))
    all_plots.append(get_plot(arch, dataset, [8], True, [50, 90], None))
    all_plots.append(get_plot(arch, dataset, [16], True, [50, 90], None))

    for fnl in all_plots:
        dp = fnl[0]  # directory holding this entry's trace files
        pn = fnl[1]  # output PNG file name

        # Entries from index 2 on are (trace_file, is_combined, legend) tuples.
        mult_traces = []
        legends = []
        for i in range(2, len(fnl)):
            filename = os.path.join(dp, fnl[i][0])
            print(filename)
            mult_traces.append(get_traces_from_file(filename, is_combined=fnl[i][1]) )
            legends.append(fnl[i][2])

        # Figures go into a "plots" sub-directory of the trace directory.
        save_file_dir = os.path.join(dp, "plots")
        if not os.path.exists(save_file_dir): os.makedirs(save_file_dir)
        save_file_name = os.path.join(save_file_dir, pn) 

        plot_traces(mult_traces, filename=save_file_name, plot_accuracy=False, legends=legends)

if __name__ == "__main__":
    # Command-line entry point: parse options, echo them, then render every
    # comparison plot in generate_multiplots().
    cli = argparse.ArgumentParser(description='Adjoint Network')
    cli.add_argument('--dataset_path', type=str, default='Adjoint-trace/resnet50/cifar100', help='')
    cli.add_argument('--use_dropout', type=str, default=None, help='')
    cli.add_argument('--dropout_rate', type=str, default=None, help='')
    opts = cli.parse_args()

    dataset_path = os.path.join(os.getcwd(), opts.dataset_path)

    # --dropout_rate is only kept when --use_dropout is supplied at all.
    # use_dropout is True only for a case-insensitive "true".
    dropout_flag_given = opts.use_dropout is not None
    use_dropout = dropout_flag_given and opts.use_dropout.lower() == "true"
    dropout_rate = opts.dropout_rate if dropout_flag_given else None

    print(dataset_path, use_dropout, dropout_rate)
    # Currently disabled alternative: generate_plots(dataset_path, use_dropout, dropout_rate)
    generate_multiplots(dataset_path)


