from torch.utils.data import DataLoader
import os

from manage.files import FileHandler
from data.data import get_dataset, CurrentDatasetInfo, Modality, StateSpace
from manage.logger import Logger
from manage.generation import GenerationManager
from manage.training import TrainingManager
from evaluate.EvaluationManager import EvaluationManager
from manage.checkpoints import load_experiment, save_experiment
from manage.setup import _get_device, _optimize_gpu, _set_seed

from ddpm_init import init_method_ddpm, init_models_optmizers_ls, init_learning_schedule, init_optimizer

from script_utils import *
from torch.utils.data import Subset, Dataset
import torch
from torch import autocast
from torch.cuda.amp import GradScaler
from contextlib import nullcontext


# Directory passed to FileHandler.get_param_from_config to look up '<args.config>.yml'
# (a relative path, so it is resolved against the current working directory).
CONFIG_PATH = "./configs/"

def compute_gradient_norm(model, norm_type=2.0):
    """Return the global gradient norm of ``model`` as a Python float.

    The norm is taken over the gradients of every parameter that currently has
    one (``param.grad is not None``), as if they were concatenated into a single
    vector; parameters without a gradient are skipped. Returns ``0.0`` when no
    parameter has a gradient.

    Args:
        model: object exposing ``.parameters()`` (e.g. a ``torch.nn.Module``).
        norm_type: order of the vector norm. The default ``2.0`` is the
            Euclidean norm; ``float('inf')`` gives the largest absolute entry.

    Returns:
        The norm as a Python ``float``.
    """
    per_param_norms = [
        torch.linalg.vector_norm(param.grad.detach(), ord=norm_type)
        for param in model.parameters()
        if param.grad is not None
    ]
    if not per_param_norms:
        return 0.0
    # The p-norm of the per-parameter p-norms equals the p-norm of all entries.
    # Combining on one device and calling .item() once avoids a host-device
    # synchronisation per parameter (the old per-tensor .item() loop).
    device = per_param_norms[0].device
    stacked = torch.stack([norm.to(device) for norm in per_param_norms])
    return torch.linalg.vector_norm(stacked, ord=norm_type).item()
    
def _resolve_topo_save_path(p, args):
    """Return the directory where the topological-bound results of this run go.

    The base directory is ``args.topo_save_dir`` (``'topological_bounds'`` when
    it is None); the leaf name encodes ``n_samples``, learning rate and batch
    size read from the parameter dict ``p``.
    """
    if args.topo_save_dir is not None:
        print('Saving topological bounds in {}'.format(args.topo_save_dir))
        topo_save_dir = args.topo_save_dir
    else:
        print('Saving topological bounds in {} (default)'.format('topological_bounds'))
        topo_save_dir = 'topological_bounds'
    run_name = '_'.join([
        'nsamples_{}'.format(p['data']['n_samples']),
        'lr_{}'.format(p['optim']['lr']),
        'bs_{}'.format(p['training']['batch_size']),
    ])
    return os.path.join(topo_save_dir, run_name)


def _fill_subset_losses(trainer, subset_loader, given_z, given_t, fp16_ctx,
                        loss_row, score_loss_row):
    """Evaluate per-sample losses on the fixed subset and write them in place.

    All models are put in eval mode and evaluated under ``torch.inference_mode``.
    Sample ``i`` of the (non-shuffled) ``subset_loader`` always receives
    ``given_z[i]`` / ``given_t[i]``. ``loss_row`` and ``score_loss_row`` are 1-D
    tensors of length ``len(subset)`` that receive the ``'loss'`` and
    ``'score_loss'`` entries returned by ``trainer.method.training_losses``.
    """
    for model in trainer.models.values():
        model.eval()
    with torch.inference_mode():
        start = 0
        for Xbatch, y in subset_loader:
            end = start + len(Xbatch)
            kwargs = {'y': y}

            # Number of repeated evaluations averaged per batch (currently 1).
            mc = 1
            loss = torch.zeros(mc, len(Xbatch)).to(Xbatch.device)
            score_loss = torch.zeros(mc, len(Xbatch)).to(Xbatch.device)
            for m in range(mc):
                with fp16_ctx:
                    training_results = trainer.method.training_losses(
                        trainer.models,
                        Xbatch,
                        given_z=given_z[start:end],
                        given_t=given_t[start:end],
                        **kwargs,
                    )
                loss[m] = training_results['loss']
                score_loss[m] = training_results['score_loss']
            # ``loss`` is (mc, batch), so the mean over dim 0 is always 1-D.
            loss_row[start:end] = loss.mean(dim=0).cpu()
            score_loss_row[start:end] = score_loss.mean(dim=0).cpu()
            start = end


def _train_step(trainer, train_batch, fp16_ctx, scaler):
    """Run one optimisation step on ``train_batch`` and return the gradient norm.

    Zeroes gradients, computes the mean training loss (under ``fp16_ctx``),
    back-propagates (through ``scaler`` when it is not None), steps every
    optimizer and its learning schedule, then updates the EMA models.

    Returns:
        The gradient norm of ``trainer.models['default']`` as a float, measured
        on unscaled gradients so fp16 and fp32 runs are comparable.
    """
    Xbatch, y = train_batch
    for name in trainer.models:
        trainer.optimizers[name].zero_grad()

    with fp16_ctx:
        training_results = trainer.method.training_losses(trainer.models, Xbatch, y=y)
    loss = training_results['loss'].mean()

    if scaler is not None:
        scaler.scale(loss).backward()
        # Unscale now so the recorded gradient norm is in true units;
        # scaler.step() then skips its own unscale, so updates are unchanged.
        for name in trainer.models:
            scaler.unscale_(trainer.optimizers[name])
    else:
        loss.backward()

    for name in trainer.models:
        if scaler is not None:
            scaler.step(trainer.optimizers[name])
        else:
            trainer.optimizers[name].step()
        if trainer.exists_ls(name):
            trainer.learning_schedules[name].step()
    if scaler is not None:
        scaler.update()

    grad_norm = compute_gradient_norm(trainer.models['default'])

    # update ema models
    if trainer.ema_objects is not None:
        for e in trainer.ema_objects:
            for name in trainer.models:
                e[name].update(trainer.models[name])
    return grad_norm


def run_exp(config_path):
    """Train while tracking per-sample losses on a fixed data subset.

    Steps:
      1. Parse CLI args and load ``<args.config>.yml`` from ``config_path``.
      2. Initialise the experiment.
      3. Either load pretrained weights (``args.from_pretrained``, with fresh
         optimizers and learning schedules) or resume from
         ``checkpoints/<args.name>``.
      4. Either reload a previous topological run (``args.load_topo_run``) or
         train for ``args.topo_total_iterations`` steps.

    When training, each step first records the loss and score loss of each of
    the first (up to) 3000 training samples, using fixed noise and timesteps.
    It then takes one optimisation step and records the gradient norm of the
    ``'default'`` model. Intermediate results are saved every 100 steps.

    Finally the models are evaluated with and without EMA, and everything is
    saved under ``<topo_save_dir>/nsamples_<n>_lr_<lr>_bs_<bs>``.
    """
    args = parse_args()

    # open and get parameters from file
    p = FileHandler.get_param_from_config(config_path, args.config + '.yml')
    update_parameters_before_loading(p, args)

    (trainer, logger, file_handler, models, optimizers, learning_schedules,
     method, eval_manager, gen_manager) = initialize_experiment(p)

    # print the number of parameters in the models
    print('Number of parameters in the models:')
    for name, model in models.items():
        print(f'{name}: {sum(param.numel() for param in model.parameters())}')

    if args.from_pretrained is not None:
        # load pretrained weights
        print('Loading pretrained weights from {}'.format(args.from_pretrained))
        trainer.load(args.from_pretrained)
        # NOTE: trainer.total_steps / trainer.epochs are not reset here.
        # Fresh optimizers and schedules: assumes pre-training used a different optimizer.
        print('Reintializing optimizer (assuming pre-training was done with a different optimizer)')
        for name in optimizers:
            optimizers[name] = init_optimizer(p, models[name])
        trainer.optimizers = optimizers
        for name in learning_schedules:
            learning_schedules[name] = init_learning_schedule(p, optimizers[name])
        trainer.learning_schedules = learning_schedules
    else:
        # resume from checkpoints/<args.name>: latest checkpoint, or the one at args.steps
        save_dir = os.path.join('checkpoints', args.name)
        if args.latest_checkpoint:
            print('Loading latest checkpoint from {}'.format(save_dir))
            checkpoint_steps = None
        else:
            print('Loading checkpoint {} from {}'.format(args.steps, save_dir))
            checkpoint_steps = args.steps
        load_experiment(
            p=p,
            trainer=trainer,
            fh=file_handler,
            save_dir=save_dir,
            checkpoint_steps=checkpoint_steps,
        )

    # update parameters after loading, like new optim learning rate...
    update_experiment_after_loading(
        p,
        optimizers,
        learning_schedules,
        init_learning_schedule,
        args,
        update_optimizer_lr=True,
    )

    # log some additional information
    additional_logging(p, logger, trainer, file_handler, args)

    # print parameters to stdout
    print_dict(p)

    save_path = _resolve_topo_save_path(p, args)

    if args.load_topo_run:
        # reload model and recorded losses from a previous topological run;
        # save_path is the directory actually loaded (save_dir may be unbound here)
        print('Loading model from checkpoint', save_path)
        topological_losses = load_experiment(
            p=p,
            trainer=trainer,
            fh=file_handler,
            save_dir=save_path,
            checkpoint_steps=None,
            get_topological_losses=True,
        )
        assert topological_losses is not None, 'topological losses are None'
    else:
        total_iterations = args.topo_total_iterations  # e.g. 5000
        subset_dataset_size = 3000

        dataset = trainer.data.dataset
        subset_dataset = Subset(dataset, range(min(subset_dataset_size, len(dataset))))
        subset_data_test = DataLoader(
            subset_dataset,
            batch_size=p['eval']['batch_size'],  # faster for testing
            shuffle=False,  # keeps sample order aligned with given_z / given_t
            num_workers=p['training']['num_workers'],
        )
        n_subset = len(subset_dataset)

        # one row per iteration, one column per subset sample
        subset_data_losses = torch.zeros(total_iterations, n_subset)
        subset_data_score_losses = torch.zeros(total_iterations, n_subset)
        grad_norms = torch.zeros(total_iterations)
        # The tensors are filled in place, so this dict is always current and is
        # defined even when total_iterations == 0 (the final save needs it).
        topological_losses = {
            'losses': subset_data_losses,
            'score_losses': subset_data_score_losses,
            'grad_norms': grad_norms,
        }

        # set the random seed for reproducibility
        _set_seed(0, p['device'])

        # fixed timesteps and Gaussian noise per subset sample, reused at every iteration
        given_t = torch.randint(1, trainer.method.reverse_steps, size=(n_subset,)).to(trainer.method.device)
        given_z = torch.randn(n_subset, *dataset[0][0].shape).to(trainer.method.device)

        # fp16 autocast is only enabled on CUDA devices, and only when requested
        use_fp16 = method.device.type == 'cuda' and bool(args.fp16)
        fp16_ctx = (autocast(device_type=method.device.type, dtype=torch.float16)
                    if use_fp16 else nullcontext())
        # torch.cuda.amp.GradScaler's first positional argument is init_scale, so
        # GradScaler('cuda') would set init_scale='cuda' and fail on the first
        # scale(); the default constructor already targets CUDA.
        scaler = GradScaler() if use_fp16 else None

        train_loader_iter = iter(trainer.data)

        from tqdm import tqdm
        for total_steps in tqdm(range(total_iterations)):
            # first, record losses on the fixed subset with the current weights
            _fill_subset_losses(
                trainer,
                subset_data_test,
                given_z,
                given_t,
                fp16_ctx,
                subset_data_losses[total_steps],
                subset_data_score_losses[total_steps],
            )

            # second, one gradient step on the next training batch
            for model in trainer.models.values():
                model.train()
            try:
                train_batch = next(train_loader_iter)
            except StopIteration:
                # training loader exhausted: start a new pass over it
                train_loader_iter = iter(trainer.data)
                train_batch = next(train_loader_iter)

            grad_norms[total_steps] = _train_step(trainer, train_batch, fp16_ctx, scaler)

            # save intermediate eval/params and recorded losses every 100 steps
            if total_steps > 0 and total_steps % 100 == 0:
                save_experiment(
                    p=p,
                    trainer=trainer,
                    fh=file_handler,
                    save_dir=save_path,
                    files=['eval', 'param'],
                    new_eval_subdir=False,
                    checkpoint_steps=trainer.total_steps,
                    topological_losses=topological_losses,
                )

    # then, run evaluation
    trainer.evaluate(evaluate_emas=False)
    trainer.evaluate(evaluate_emas=True)

    paths = save_experiment(
        p=p,
        trainer=trainer,
        fh=file_handler,
        save_dir=save_path,
        files='all',
        new_eval_subdir=False,
        checkpoint_steps=trainer.total_steps,
        topological_losses=topological_losses,
    )

    print('Saved files in ', paths)



if __name__ == '__main__':
    # Script entry point: run the experiment with the default config directory.
    run_exp(config_path=CONFIG_PATH)

