from torch.utils.data import DataLoader
import os
import torch

from manage.files import FileHandler
from data.data import get_dataset, CurrentDatasetInfo, Modality, StateSpace
from manage.logger import Logger
from manage.generation import GenerationManager
from manage.training import TrainingManager
from evaluate.EvaluationManager import EvaluationManager
from manage.checkpoints import load_experiment, save_experiment
from manage.setup import _get_device, _optimize_gpu, _set_seed

from ddpm_init import init_method_ddpm, init_models_optmizers_ls, init_learning_schedule

# import argparse
import argparse
import pathlib

import matplotlib.pyplot as plt
import numpy as np

def print_dict(d, indent=0):
    """Recursively print the key/value pairs of *d*, one entry per line.

    Every line starts with ``indent`` tab characters. A nested dict prints its
    key followed by a tab and a colon, then its own entries one tab deeper.
    """
    prefix = '\t' * indent
    for key, value in d.items():
        if not isinstance(value, dict):
            print(prefix, key, ':', value)
            continue
        print(prefix, key, '\t:')
        print_dict(value, indent + 1)

# Metrics plotted for every group; each must be a key of a run's eval dict.
_PLOT_METRICS = ('losses_batch', 'score_losses_batch', 'grad_norm')


def _record_lr(rec):
    """Return the learning rate of a run record."""
    return rec['optim']['lr']


def _record_bs(rec):
    """Return the training batch size of a run record."""
    return rec['training']['batch_size']


def _build_parser():
    """Build the command-line parser used by ``read_results``."""
    parser = argparse.ArgumentParser(
        description="Read multiple eval/param files, then group‐plot metrics by lr and bs"
    )
    parser.add_argument(
        '--folders',
        type=str,
        nargs='+',
        required=True,
        help="One or more folder globs, each containing exactly one eval* and one parameters* file"
    )
    parser.add_argument(
        '--plot',
        action='store_true',
        help="Generate and save grouped plots"
    )
    parser.add_argument(
        '--window',
        type=int,
        default=None,
        help="If set, apply a moving‐average of this window size"
    )
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help="Directory in which to save plots (required if --plot)"
    )
    parser.add_argument(
        '--suffix',
        type=str,
        default='',
        help="Suffix to append to each output filename"
    )
    # NOTE: accepted for CLI compatibility but currently not used by read_results.
    parser.add_argument('--topo_path', type=str, default=None, help='Path to topo file')
    return parser


def _glob_dirs(pattern):
    """Return the paths matching *pattern*, sorted so run order is deterministic.

    Relative patterns are resolved against the current working directory.
    ``Path.glob`` rejects absolute patterns, so those are globbed from their
    anchor (e.g. ``/``) instead.
    """
    path = pathlib.Path(pattern)
    if path.is_absolute():
        root = pathlib.Path(path.anchor)
        return sorted(root.glob(str(path.relative_to(root))))
    return sorted(pathlib.Path('.').glob(pattern))


def _expand_folders(patterns):
    """Expand the ``--folders`` globs into a de-duplicated list of run directories.

    Raises:
        FileNotFoundError: if a pattern matches nothing, or no match is a directory.
    """
    dirs = []
    for pat in patterns:
        matched = _glob_dirs(pat)
        if not matched:
            raise FileNotFoundError(f"No folders matched pattern {pat!r}")
        print(f"Matched {len(matched)} folders for pattern {pat!r}")
        dirs.extend(matched)
    # Overlapping patterns can match the same folder twice; keep the first occurrence.
    dirs = [d for d in dict.fromkeys(dirs) if d.is_dir()]
    if not dirs:
        raise FileNotFoundError("No directories found in --folders globs")
    print(f"Found {len(dirs)} run folders.")
    return dirs


def _load_records(dirs):
    """Load the single eval file and single parameters file of every run folder.

    Raises:
        ValueError: if a folder does not contain exactly one of each.
    """
    records = []
    for d in dirs:
        eval_fs = list(d.glob('*eval*.*'))
        param_fs = list(d.glob('*parameter*.*'))
        if len(eval_fs) != 1 or len(param_fs) != 1:
            raise ValueError(
                f"In folder {d!r}, expected exactly 1 eval* and 1 parameters* file, "
                f"found {len(eval_fs)} eval, {len(param_fs)} parameter."
            )
        # SECURITY: weights_only=False unpickles arbitrary objects -- only point
        # this script at result folders you trust.
        ev = torch.load(eval_fs[0], weights_only=False)
        pm = torch.load(param_fs[0], weights_only=False)
        print('training batch size in {}: {}'.format(d, pm['training']['batch_size']))
        records.append({
            'eval': ev['eval'],
            'optim': pm['optim'],
            'training': pm['training'],
            'folder': d,
        })
    return records


def _dedupe_records(records):
    """Keep only the first record for each (lr, batch_size) pair, reporting dropped ones."""
    # TODO: optionally average repeated runs per (lr, bs) instead of dropping them.
    first_folder = {}
    unique = []
    for rec in records:
        key = (_record_lr(rec), _record_bs(rec))
        if key in first_folder:
            print(f"Skipping {rec['folder']}: (lr, bs)={key} already loaded from {first_folder[key]}")
            continue
        first_folder[key] = rec['folder']
        unique.append(rec)
    return unique


def _smooth(values, window):
    """Return *values* as a float array, moving-averaged over *window* points if requested.

    A falsy window or a window of 1 leaves the values unchanged.
    """
    arr = np.asarray(values, dtype=float)
    if not window or window <= 1 or arr.size == 0:
        return arr
    # np.convolve(mode='valid') swaps its inputs when the kernel is longer than
    # the signal, so clamp the window to keep the result meaningful.
    w = min(window, arr.size)
    return np.convolve(arr, np.ones(w) / w, mode='valid')


def _plot_grouped(records, *, group_by, group_tag, group_title, curve_by, curve_tag,
                  metrics, window, out_dir, suffix):
    """Save one log-log figure per (group value, metric) with one curve per record.

    ``group_by``/``curve_by`` extract the grouping and per-curve hyper-parameter
    from a record. ``group_tag``/``curve_tag`` are the short names used in file
    names and legends, and ``group_title`` is the name shown in the plot title.
    """
    groups = {}
    for rec in records:
        groups.setdefault(group_by(rec), []).append(rec)

    tail = f"_{suffix}" if suffix else ''
    for value, recs in sorted(groups.items()):
        # Sort curves so legend order is stable and increasing.
        recs = sorted(recs, key=curve_by)
        for m in metrics:
            fig = plt.figure()
            try:
                longest = 0
                for rec in recs:
                    arr = _smooth(rec['eval'][m], window)
                    longest = max(longest, len(arr))
                    plt.plot(arr, label=f'{curve_tag}={curve_by(rec)}')
                plt.xscale('log'); plt.yscale('log')
                plt.xlabel('Training steps')
                plt.ylabel(m)
                # x=0 cannot be shown on a log axis; start at 1/1000 of the
                # longest curve (not whichever curve happened to be drawn last).
                if longest:
                    plt.xlim(left=longest / 1000)
                plt.title(f"Metric={m} @ {group_title}={value}")
                plt.legend()
                fn = f"{group_tag}_{value}_{m}_all_{curve_tag}{tail}.png"
                plt.savefig(out_dir / fn, bbox_inches='tight')
            finally:
                plt.close(fig)


def read_results(argv=None):
    """Load eval/parameter files from several run folders and optionally plot them.

    Each folder matched by ``--folders`` must contain exactly one ``*eval*.*``
    file and one ``*parameter*.*`` file, both loaded with ``torch.load``. Runs
    are de-duplicated on (learning rate, batch size), keeping the first folder
    in sorted order. With ``--plot``, one figure per (lr, metric) and one per
    (batch size, metric) is written to ``--output``.

    Args:
        argv: command-line arguments to parse. ``None`` (the default) uses
            ``sys.argv[1:]``, so script usage is unchanged.

    Returns:
        The list of de-duplicated run records: dicts with keys ``'eval'``,
        ``'optim'``, ``'training'`` and ``'folder'``.

    Raises:
        ValueError: if ``--plot`` is given without ``--output``, ``--window`` is
            negative, or a folder does not hold exactly one eval and one
            parameters file.
        FileNotFoundError: if the ``--folders`` globs match no directory.
    """
    args = _build_parser().parse_args(argv)
    # Validate option combinations before loading any (possibly large) files.
    if args.plot and args.output is None:
        raise ValueError("When using --plot, you must also set --output")
    if args.window is not None and args.window < 0:
        raise ValueError("--window must be a non-negative integer")

    records = _dedupe_records(_load_records(_expand_folders(args.folders)))

    print(f"Found {len(records)} unique (lr, bs) records.")
    print(f"Unique learning rates: {sorted({_record_lr(r) for r in records})}")
    print(f"Unique batch sizes: {sorted({_record_bs(r) for r in records})}")

    if args.plot:
        out_dir = pathlib.Path(args.output)
        out_dir.mkdir(parents=True, exist_ok=True)
        common = dict(metrics=_PLOT_METRICS, window=args.window,
                      out_dir=out_dir, suffix=args.suffix)
        # For each LR & metric -> compare all batch sizes.
        _plot_grouped(records, group_by=_record_lr, group_tag='lr', group_title='lr',
                      curve_by=_record_bs, curve_tag='bs', **common)
        # For each batch size & metric -> compare all LRs.
        _plot_grouped(records, group_by=_record_bs, group_tag='bs', group_title='batch_size',
                      curve_by=_record_lr, curve_tag='lr', **common)
        print(f"Saved grouped plots into {out_dir}")

    return records

    


if __name__ == "__main__":
    # Script entry point: read_results parses its options from sys.argv.
    read_results()