import numpy as np
import torch
import matplotlib.pyplot as plt
import copy
import torch.nn as nn
from tqdm import tqdm
from .utils_ema import EMAHelper
import torch
from torch import autocast
from torch.amp import GradScaler
from contextlib import nullcontext

#Wrapper around training and evaluation functions

def compute_layer_norms(model):
    """Return the L2 (Frobenius) norm of every trainable parameter of ``model``.

    Parameters are visited in ``model.parameters()`` order and only those with
    ``requires_grad=True`` are included.

    Returns:
        A detached 1-D tensor holding one norm per trainable parameter. It is
        empty (shape ``(0,)``) when the model has no trainable parameter.
    """
    # Detach first: this is a pure diagnostic, no autograd graph should be recorded.
    # vector_norm with dim=None flattens the tensor, i.e. the Frobenius norm for matrices.
    norms = [torch.linalg.vector_norm(param.detach())
             for param in model.parameters() if param.requires_grad]
    if not norms:
        # torch.stack rejects an empty list
        return torch.empty(0)
    return torch.stack(norms)


class TrainingManager:
    
    def __init__(self, 
                 models, # dictionnary with model name as key. Must contain 'default' model
                 data,
                 method, 
                 optimizers,
                 learning_schedules,
                 eval,
                 logger = None,
                 ema_rates = None,
                 reset_models = None,
                 p = None,
                 fp16 = False,
                 **kwargs):
        self.epochs = 0
        self.total_steps = 0
        self.models = models
        self.data = data
        self.method = method
        self.optimizers = optimizers
        self.learning_schedules = learning_schedules
        self.eval = eval
        self.reset_models = reset_models
        self.p = p
        self.fp16 = fp16
        
        if (ema_rates is None):
            self.ema_objects = None
        else:
            #self.ema_models = [ema.EMAHelper(model, mu = mu) for mu in ema_rates]
            logger = eval.logger
            # need to set logger to None for the eval deepcopy
            eval.logger = None
            self.ema_objects = []
            for mu in ema_rates:
                ema_dict = {name: EMAHelper(model, mu = mu) for name, model in self.models.items()}
                ema_dict['eval'] = copy.deepcopy(eval)
                self.ema_objects.append(ema_dict)
            eval.logger = logger 
            for ema_object in self.ema_objects:
                ema_object['eval'].logger = logger
        
        self.kwargs = kwargs
        self.logger = logger
    
    def exists_ls(self, name = 'default'):
        return (name in self.learning_schedules) and (self.learning_schedules[name] is not None)
    
    def train(self, **kwargs):
        tmp_kwargs = copy.deepcopy(self.kwargs)
        tmp_kwargs.update(kwargs)

        def epoch_callback(epoch_loss, score_epoch_loss, models):
            self.eval.register_epoch_loss(epoch_loss, score_epoch_loss)
            self.eval.register_grad_norm(models)
            if self.logger is not None:
                self.logger.log('current_epoch', self.epochs)
        
        def batch_callback(batch_loss, score_batch_loss, models):
            # self.eval.register_grad_norm(models)
            # batch_loss is nan, reintialize models
            if np.isnan(batch_loss): 
                print('nan in loss detected, reinitializing models...')
                models, optimizers, learning_schedules = self.reset_models(self.p)
                self.models = models
                self.optimizers = optimizers
                self.learning_schedules = learning_schedules
            
            self.eval.register_batch_loss(batch_loss, score_batch_loss)
            if self.logger is not None:
                self.logger.log('current_batch', self.total_steps)
            
            # layer_norms = compute_layer_norms(self.models['default'])
            # print('layer norms: min, max, mean, std', torch.min(layer_norms).item(), torch.max(layer_norms).item(), torch.mean(layer_norms).item(), torch.std(layer_norms).item())
        
        self._train_epochs(epoch_callback=epoch_callback, 
                           batch_callback=batch_callback,
                           **tmp_kwargs)
    
    def _train_epochs(
                self,
                total_steps,
                eval_freq = None,
                checkpoint_freq= None,
                checkpoint_callback=None,
                no_ema_eval = True,
                grad_clip = None,
                batch_callback = None,
                epoch_callback = None,
                max_batch_per_epoch = None,
                progress = False,
                refresh_data = False,
                dataset_with_labels = False,
                stop_lower_loss_threshold = None,
                **kwargs):
        
        for name, model in self.models.items():
            model.train()
        
        print('training model to step {} from step {}'.format(total_steps, self.total_steps), '...')
        
        
        num_iter_until_loss_threshold_break = 200
        val_losses = [1000]*num_iter_until_loss_threshold_break
        # count_iter_until_loss_threshold_break = 0
        
        # create tqdm iterator going from self.total_steps to total_steps
        tqdm_iterator = tqdm(range(self.total_steps, total_steps), desc='Training', unit='step')
        
        
        # --- before your loop: decide if we use FP16 ---
        use_fp16 = self.method.device.type in {'cuda'}  #, 'mps'}
        use_fp16 = use_fp16 and self.fp16
        # autocast context for cuda or mps, else a no‐op
        fp16_ctx = (autocast(device_type=self.method.device.type, dtype=torch.float16)
                    if use_fp16 else nullcontext())
        
        print('using fp16: {}'.format(use_fp16))
        # setup GradScaler
        scaler = GradScaler('cuda') if use_fp16 else None
        
        while self.total_steps < total_steps:
            epoch_loss     = torch.zeros((), device=self.method.device)
            epoch_score_loss    = torch.zeros_like(epoch_loss)
            steps = torch.zeros_like(epoch_loss)
            # epoch_loss = steps = 0
            # epoch_score_loss = 0
            # set up a list in the device to hold batch information
            batch_losses = torch.zeros((len(self.data),)).to(self.method.device)
            batch_score_losses = torch.zeros((len(self.data),)).to(self.method.device)
            for i, Xbatch in enumerate(tqdm(self.data) if progress else self.data):
                if max_batch_per_epoch is not None:
                    if i >= max_batch_per_epoch:
                        break
                
                if dataset_with_labels:
                    Xbatch, y = Xbatch
                    kwargs['y'] = y
                
                for name in self.models:
                    self.optimizers[name].zero_grad()
                
                # 4. forward + backward in mixed precision
                with fp16_ctx:
                    training_results = self.method.training_losses(self.models, Xbatch, **kwargs)
                    loss = training_results['loss'].mean()
                    score_loss = training_results['score_loss'].mean()
                
                # check if nan in loss
                if torch.isnan(loss):
                    print('nan in loss detected. Saving training_results and breaking loop...')
                    torch.save(training_results, 'nan_training_results.pth')
                    return
                    # print('nan in loss detected, resetting models...')
                    # self.models, self.optimizers, self.learning_schedules = self.reset_models(self.p)
                    # print('models reset.')    
                    # print('nan in loss detected, skipping batch...')
                    # for name, model in self.models.items():
                    #     model.zero_grad()
                    continue
                
                
                # and finally gradient descent
                    
                if scaler is not None:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                for name in self.models:
                    if grad_clip is not None:
                        if scaler is not None:
                            scaler.unscale_(self.optimizers[name]) # has been recorded; won't be called again in scaler.step
                        nn.utils.clip_grad_norm_(self.models[name].parameters(), grad_clip)
                    if scaler is not None:
                        scaler.step(self.optimizers[name])
                    else:
                        self.optimizers[name].step()
                    # self.optimizers[name].step()
                    if self.exists_ls(name):
                        self.learning_schedules[name].step()
                
                if scaler is not None:
                    scaler.update()
                
                # update ema models
                if self.ema_objects is not None:
                    for e in self.ema_objects:
                        for name in self.models:
                            e[name].update(self.models[name])
                
                epoch_loss += loss
                epoch_score_loss += score_loss
                steps += 1
                self.total_steps += 1
                # update tqdm iterator every 10 steps
                if self.total_steps % 10 == 0:
                    tqdm_iterator.update(10)
                
                if self.total_steps % 1000 == 0:
                    print('step {}/{}'.format(self.total_steps, total_steps), end =' ; ')
                    print('batch loss', loss.item(), 'score loss', score_loss.item())
                
                batch_losses[i] = loss.detach().clone()
                batch_score_losses[i] = score_loss.detach().clone()
                
                
                # now potentially checkpoint
                if (checkpoint_freq is not None) and (self.total_steps % checkpoint_freq) == 0:
                    checkpoint_callback(self.total_steps)
                    #print(self.save(curr_epoch=self.manager.training_epochs()))
                
                # now potentially eval
                if (eval_freq is not None) and  (self.total_steps % eval_freq) == 0:
                    self.evaluate(evaluate_emas= not no_ema_eval)
                
                if self.total_steps >= total_steps:
                    # rather break than return, to compute some final epoch statistics
                    break
                
                # if lower loss than threshold, stop training
                if stop_lower_loss_threshold is not None:
                    val_losses[self.total_steps % num_iter_until_loss_threshold_break] = loss.item()
                    avg_val_losses = np.mean(val_losses)
                    if avg_val_losses < stop_lower_loss_threshold:
                        print('loss lower than threshold for {} iterations, stopping training'.format(num_iter_until_loss_threshold_break))
                        return
                        
            
            # register batch values
            if batch_callback is not None:
                for i in range(len(self.data)):
                    batch_callback(batch_losses[i].item(), 
                                   batch_score_losses[i].item(), 
                                   models=self.models)
            
            epoch_loss = epoch_loss / steps
            epoch_score_loss = epoch_score_loss / steps
            
            tqdm_iterator.set_postfix({'loss': epoch_loss.item(), 
                                       'score_loss': epoch_score_loss.item(), 
                                       'grad_norm': self.eval.evals['grad_norm'][-1].item() \
                                           if len(self.eval.evals['grad_norm']) > 0 \
                                               else None})
            
            
            self.epochs += 1
            if epoch_callback is not None:
                epoch_callback(epoch_loss.item(), epoch_score_loss.item(), models=self.models)

            
            if refresh_data:
                self.data.dataset.refresh_data()
            


    def evaluate(self, evaluate_emas = False, **kwargs):
        def ema_callback_on_logging(logger, key, value):
            if not (key in ['losses', 'losses_batch']):
                logger.log('_'.join(('ema', str(ema_obj['default'].mu), str(key))), value)
        
        if not evaluate_emas:
            print('evaluating non-ema model')
            for name, model in self.models.items():
                model.eval()
            with torch.inference_mode():
                self.eval.evaluate_model(self.models, **kwargs)
        elif self.ema_objects is not None:
            for ema_obj in self.ema_objects:
                models = {name: ema_obj[name].get_ema_model() for name in self.models}
                for name in models:
                    models[name].eval()
                with torch.inference_mode():
                    print('evaluating ema model with mu={}'.format(ema_obj['default'].mu))
                    ema_obj['eval'].evaluate_model(models, callback_on_logging = ema_callback_on_logging, **kwargs)
    

    def load(self, filepath):
        # provide key rather than src[key] in case we load an old run that did not contain any vae key
        def safe_load_state_dict(dest, src):
            if dest is not None:
                dest.load_state_dict(src)
        
        checkpoint = torch.load(filepath, map_location=torch.device(self.method.device))
        self.total_steps = checkpoint['steps']
        self.epochs = checkpoint['epoch']
        for name in self.models:
            chckpt_model_name = 'model_{}_parameters'.format(name) if name != 'default' else 'model_parameters' # retro-compatibility: 'default' becomes ''
            chckpt_optim_name = 'optimizer_{}'.format(name) if name != 'default' else 'optimizer'
            chckpt_ls_name = 'learnin_schedule_{}'.format(name) if name != 'default' else 'learning_schedule'
            chckpt_ema_name = 'ema_models_{}'.format(name) if name != 'default' else 'ema_models'
            print('loading model {} from checkpoint'.format(name))
            safe_load_state_dict(self.models[name], checkpoint[chckpt_model_name])
            print('loading optimizer {} from checkpoint'.format(name))
            safe_load_state_dict(self.optimizers[name], checkpoint[chckpt_optim_name])
            print('loading learning schedule {} from checkpoint'.format(name))
            safe_load_state_dict(self.learning_schedules[name], checkpoint[chckpt_ls_name])
            if self.ema_objects is not None:
                assert chckpt_ema_name in checkpoint, 'no ema model in checkpoint'
                for ema_obj, ema_state in zip(self.ema_objects, checkpoint[chckpt_ema_name]):
                    print('loading ema model {} with mu={} from checkpoint'.format(name, ema_obj[name].mu))
                    safe_load_state_dict(ema_obj[name], ema_state)

            
    def save(self, filepath):
        def safe_save_state_dict(src):
            return src.state_dict() if src is not None else None
        checkpoint = {
            'epoch': self.epochs,
            'steps': self.total_steps,
        }
        for name in self.models:
            chckpt_model_name = 'model_{}_parameters'.format(name) if name != 'default' else 'model_parameters' # retro-compatibility: 'default' becomes ''
            chckpt_optim_name = 'optimizer_{}'.format(name) if name != 'default' else 'optimizer'
            chckpt_ls_name = 'learnin_schedule_{}'.format(name) if name != 'default' else 'learning_schedule'
            chckpt_ema_name = 'ema_models_{}'.format(name) if name != 'default' else 'ema_models'
            checkpoint[chckpt_model_name] = safe_save_state_dict(self.models[name])
            checkpoint[chckpt_optim_name] = safe_save_state_dict(self.optimizers[name])
            checkpoint[chckpt_ls_name] = safe_save_state_dict(self.learning_schedules[name])
            if self.ema_objects is not None:
                checkpoint[chckpt_ema_name] = [safe_save_state_dict(ema_obj[name]) for ema_obj in self.ema_objects]

        torch.save(checkpoint, filepath)
    
    def save_eval_metrics(self, eval_path, topological_losses=None):
        eval_save = {'eval': self.eval.evals}
        if self.ema_objects is not None:
            eval_save.update({'ema_evals': [(ema_obj['eval'].evals, ema_obj['default'].mu) for ema_obj in self.ema_objects]})
        if topological_losses is not None:
            eval_save.update({'topological': topological_losses})
        
        
        torch.save(eval_save, eval_path)
    
    def load_eval_metrics(self, eval_path, get_topological_losses=False):
        print('loading eval metrics from {}'.format(eval_path))
        eval_save = torch.load(eval_path, weights_only=False)
        assert 'eval' in eval_save, 'no eval subdict in eval file'
        # load eval metrics
        self.eval.evals.update(eval_save['eval'])
        self.eval.log_existing_eval_values(folder='eval')

        # load ema eval metrics
        if not 'ema_evals' in eval_save:
            return
        assert self.ema_objects is not None
        # saved ema evaluation, in order
        saved_ema_evals = [ema_eval_save for ema_eval_save, mu_save in eval_save['ema_evals']]
        # saved ema mu , in order
        saved_mus = [mu_save for ema_eval_save, mu_save in eval_save['ema_evals']]
        
        for ema_obj in self.ema_objects:
            # if mu has not been run previously, no loading
            if ema_obj['default'].mu not in saved_mus:
                continue
            # find index of our mu of interest
            idx = saved_mus.index(ema_obj['default'].mu)
            # load the saved evaluation
            ema_obj['eval'].evals.update(saved_ema_evals[idx])
            # log the saved evaluation
            ema_obj['eval'].log_existing_eval_values(folder='eval_ema_{}'.format(ema_obj['default'].mu))
        
        if get_topological_losses:
            # load topological losses
            assert 'topological' in eval_save.keys(), 'no topological losses in eval file'
            print('loading topological losses from eval file')
            return eval_save['topological']        
        
        return None
    