

import numpy as np
import pickle
class ContinualMetricCalculator:
    """Summarise per-phase evaluation logs of a continual learner.

    ``data`` (assigned after construction) is a list with one entry per training
    phase.  An entry is either a dict ``{'train': [task names trained in this
    phase], 'eval': {task name: score}}`` (see ``data_templets``) or a list of
    such dicts, of which only the last one is used.

    :meth:`calculate_fgt` builds ``fgt_dict`` with one record per task that was
    trained in a phase and evaluated in that same phase, then prints the means
    of three per-task quantities:

    * ``learned_score_init`` (printed as FWT_mean) -- the task's score in the
      phase it was trained in;
    * ``fgt_score`` (printed as FGT_mean) -- mean of ``later score -
      learned_score_init`` over the later phases that evaluated the task, or
      0.0 if there are none; negative values mean the score dropped;
    * ``learned_score`` (printed as AUC_mean) -- mean score from the learning
      phase through the last phase that evaluated the task.
    """

    def __init__(
            self,
            file_path='./data/test.pkl',
            mode='mmworld',
        ):
        """Create an empty calculator; assign ``data`` before calling calculate_fgt.

        Args:
            file_path: stored as ``self.file_path``; not read by this class.
            mode: stored as ``self.mode``; not read by this class.
        """
        self.file_path = file_path
        self.data = None

        # Shape of one phase entry in self.data (documentation only).
        self.data_templets = {
            'train' : ['task_names'],
            'eval' : {
                'task_name' : 1.0,
            },
        }
        self.mode = mode
        # Kept for callers that set it; calculate_fgt derives the phase count
        # from self.data instead of trusting this value.
        self.phases = len(self.data) if self.data is not None else 0

    @staticmethod
    def _last_entry(phase_data):
        """Return the eval dict of a phase; a phase holding a list uses its last item."""
        return phase_data[-1] if isinstance(phase_data, list) else phase_data

    def _task_record(self, task, phase):
        """Build the forgetting record of ``task`` learned in ``phase``.

        Scans ``self.fgt_data`` from ``phase`` to the last phase.  Phases whose
        eval dict lacks ``task`` are skipped rather than raising KeyError.
        ``fgt_scores`` / ``learned_scores`` are (k, 2) float arrays of
        ``(phase, value)`` rows.
        """
        init_score = self.fgt_data[phase]['eval'][task]
        fgt_scores = []
        learned_scores = []
        for f_phase in range(phase, len(self.fgt_data)):
            f_eval = self.fgt_data[f_phase]['eval']
            if task not in f_eval:
                continue
            f_score = f_eval[task]
            if f_phase != phase:
                fgt_scores.append((f_phase, f_score - init_score))
            learned_scores.append((f_phase, f_score))

        # reshape keeps an empty list as a (0, 2) array so column indexing is safe.
        fgt_arr = np.array(fgt_scores, dtype=float).reshape(-1, 2)
        learned_arr = np.array(learned_scores, dtype=float).reshape(-1, 2)
        return {
            # No later evaluation (e.g. the task was learned in the last phase) -> 0.
            'fgt_score': float(np.mean(fgt_arr[:, 1])) if len(fgt_arr) else 0.0,
            'learned_phase': phase,
            'learned_score_init': init_score,
            'fgt_scores': fgt_arr,
            # Never empty: the learning phase itself evaluated the task.
            'learned_score': float(np.mean(learned_arr[:, 1])),
            'learned_scores': learned_arr,
        }

    def calculate_fgt(
            self,
        ) :
        """Compute per-task forgetting records and print their means.

        Sets ``self.fgt_data`` (one eval dict per phase) and ``self.fgt_dict``
        (task key -> record, see the class docstring).  A task trained again in
        a later phase is stored under ``'<task>_<phase>'`` so the earlier record
        is kept.

        Returns:
            ``{'FWT_mean': ..., 'FGT_mean': ..., 'AUC_mean': ...}``, or None
            (after printing a notice) when no trained task was evaluated.
        """
        phases_data = self.data if self.data is not None else []
        self.fgt_data = [self._last_entry(phase_data) for phase_data in phases_data]
        self.fgt_dict = {}

        for phase, phase_data in enumerate(self.fgt_data):
            for task in phase_data['train']:
                # Only tasks evaluated in their own training phase have a baseline score.
                if task not in phase_data['eval']:
                    continue
                key = task
                if key in self.fgt_dict:
                    key = task + '_' + str(phase)
                self.fgt_dict[key] = self._task_record(task, phase)

        if not self.fgt_dict:
            print('No trained task was evaluated; nothing to report.')
            return None

        records = list(self.fgt_dict.values())
        n_tasks = len(records)
        lsi_mean = sum(r['learned_score_init'] for r in records) / n_tasks
        fgt_mean = sum(r['fgt_score'] for r in records) / n_tasks
        ls_mean = sum(r['learned_score'] for r in records) / n_tasks

        print(f'FWT_mean : {lsi_mean:.3f}')
        print(f'FGT_mean : {fgt_mean:.3f}')
        print(f'AUC_mean : {ls_mean:.3f}')
        return {'FWT_mean': lsi_mean, 'FGT_mean': fgt_mean, 'AUC_mean': ls_mean}

    
def from_logging(logging_file_path) :
    """Parse a training log into the per-phase list consumed by ContinualMetricCalculator.

    Every line is stripped and split on single spaces; lines with fewer than two
    tokens are ignored.

    * A line whose first token is ``{'data_name':`` opens a new phase.  The
      tokens after it, up to ``'data_paths':``, are joined with '-', single
      quotes are removed, and the text is split on ',' into the phase's
      ``'train'`` task list.
    * A line whose first token contains ``skill`` and whose second token is
      ``is`` is an evaluation result for the current phase: tokens from index 3
      up to ``rew`` are stripped of ``,'[]`` and joined with '-' to form the
      task name, and the last token is parsed as the score.  Several results
      for the same task in one phase are summed.

    Phases that received no evaluation result are dropped.

    Args:
        logging_file_path: path of the text log to read.

    Returns:
        list of ``{'train': [task names], 'eval': {task name: score}}`` dicts,
        one per phase, in file order.
    """
    with open(logging_file_path, 'r') as f:
        tokenized = [raw.strip().split(' ') for raw in f]
    tokenized = [line for line in tokenized if len(line) > 1]

    phase_metric_list = []
    phase_metric_dict = None  # the phase currently receiving evaluation results

    for line in tokenized:
        if line[0] == "{'data_name':":
            # The task list ends at the 'data_paths' key; without that key,
            # take the rest of the line.
            try:
                dp_idx = line.index("'data_paths':")
            except ValueError:
                dp_idx = len(line)
            task_line = '-'.join(line[1:dp_idx]).replace("'", '')
            phase_metric_dict = {
                'train': [item for item in task_line.split(',') if item != ''],
                'eval': {},
            }
            phase_metric_list.append(phase_metric_dict)
        if 'skill' in line[0] and line[1] == 'is':
            # Ignore results that precede any phase header or carry no 'rew' marker.
            if phase_metric_dict is None or 'rew' not in line:
                continue
            sidx = line.index('rew')
            task_tokens = [item.strip(",'[]") for item in line[3:sidx] if item != '']
            joined_task = '-'.join(task_tokens)
            evals = phase_metric_dict['eval']
            evals[joined_task] = evals.get(joined_task, 0.0) + float(line[-1])

    # Build a new list: deleting entries while enumerating the original skipped
    # the element right after each deletion.
    return [phase for phase in phase_metric_list if phase['eval']]

def kitchen_possible_tasks(evaluation_sequences=None) :
    """Count the distinct prefixes of length 1-4 of the evaluation sequences.

    For each prefix length, prints the sorted distinct prefixes (sorted so the
    output does not depend on set/hash ordering), then prints the total.

    Args:
        evaluation_sequences: iterable of sequence strings; defaults to the 24
            kitchen sequences below.

    Returns:
        int, the total number of distinct prefixes over lengths 1-4 (51 for the
        default sequences).
    """
    if evaluation_sequences is None:
        evaluation_sequences = [
            'mtlh','mlsh','mktl','mkth','mksh',
            'mkls','mklh','mkbs','mkbh','mbts',
            'mbtl','mbth','mbsh','mbls','ktls',
            'klsh','kbts','kbtl','kbth','kbsh',
            'kbls','kblh','btsh','btls',
        ]

    total_var = 0
    for prefix_len in range(1, 5):
        prefixes = sorted({seq[:prefix_len] for seq in evaluation_sequences})
        print(prefixes)
        total_var += len(prefixes)
    print(total_var)
    return total_var


import argparse

parser = argparse.ArgumentParser(description='L2M based continual learner trianing function.')

parser.add_argument('-al', '--algo', type=str, help='algorithm', default='cilu')
parser.add_argument('-e', '--env', type=str, help='env', default='kitchen')
parser.add_argument('-p', '--path', type=str, help='path', default='cilu')
parser.add_argument('-g', '--grep', type=str, help='grep', default='')
# Read sys.argv only when run as a script.  When this module is imported
# (e.g. to reuse from_logging), use the defaults instead of parsing, and
# possibly exiting on, the importer's own command-line arguments.
args = parser.parse_args() if __name__ == '__main__' else parser.parse_args([])

import os 

def experiment_type(algo, path):
    """Return the experiment directory name (under data/) for an algorithm.

    ``path == 'seq'``, an algorithm name containing 'seq' or 'ewc', or one of
    'seq', 'ewc', 'er', 'GA', 'ERGA', 'ERSA', 'ER' selects 'seq_expriments';
    everything else selects 'CiLu_expriments'.  Both names are used verbatim
    as directory names.
    """
    sequential_algos = ('seq', 'ewc', 'er', 'GA', 'ERGA', 'ERSA', 'ER')
    if path == 'seq' or 'seq' in algo or 'ewc' in algo or algo in sequential_algos:
        return 'seq_expriments'
    return 'CiLu_expriments'


if __name__ == '__main__' :
    algo = args.algo
    env = args.env
    base_path = f'data/{experiment_type(algo, args.path)}/{algo}/{env}'

    # One training_log.log per run directory whose name contains --grep.
    eval_paths = sorted(
        os.path.join(base_path, run_dir, 'training_log.log')
        for run_dir in os.listdir(base_path)
        if args.grep in run_dir
    )
    print(eval_paths)

    cal = ContinualMetricCalculator()
    for path in eval_paths:
        print(path)
        try:
            cal.data = from_logging(path)
            cal.phases = len(cal.data)
            cal.calculate_fgt()
        except Exception as err:
            # Keep evaluating the remaining runs, but say why this one produced
            # no metrics instead of silently swallowing the failure.
            print(f'skipped {path}: {type(err).__name__}: {err}')
    

