import copy
import warnings
from typing import List, Optional
import torch
import matplotlib.pyplot as plt
from omegaconf import DictConfig, OmegaConf
from tgm.eval.metrics import mmd_metric, sinkhorn_dist
from tgm.train.callbacks import Callback
import math

class Trainer:
    """Run the epoch/step training loop with per-epoch validation.

    After every epoch the model is validated; the weights achieving the
    lowest validation MMD are remembered and restored when training ends.
    Progress is reported through optional callbacks, which are looked up
    by hook name (``on_train_start``, ``on_step_end``, ``on_epoch_end``,
    ``on_validation_end``, ``on_train_end``).
    """

    def __init__(self, train_cfg: DictConfig, model, optimizer, train_loader, val_sub = None, val_full = None,
                 callbacks: Optional[List[Callback]] = None):
        """Store the training collaborators.

        Args:
            train_cfg: Configuration; this class reads ``no_epochs``,
                ``regular_val_set_provided``, ``x_start``, ``no_bridges``,
                ``t_start``, ``t_end``, ``stepsize`` and
                ``val_trajectory_length`` from it.
            model: Object providing ``loss(batch)``, ``sample_unif(...)``
                and the usual ``train``/``eval``/``parameters``/
                ``state_dict``/``load_state_dict`` methods.
            optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
            train_loader: Iterable yielding training batches.
            val_sub: Stored on the instance; not used by this class.
            val_full: Mapping whose ``"x"`` entry holds the validation
                trajectories used in ``_validate``.
            callbacks: Optional callback objects; ``None`` means none.
        """
        self.cfg = train_cfg
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_sub = val_sub
        self.val_full = val_full
        self._callbacks: List[Callback] = callbacks or []

    def train(self):
        """Train for ``cfg.no_epochs`` epochs and restore the best weights.

        Returns:
            The lowest validation MMD observed, or ``float("inf")`` if no
            epoch improved on it (e.g. ``no_epochs == 0``).
        """
        global_step = 0
        best_mmd = float("inf")
        state_dict_best_mmd = None

        self._cb("on_train_start", no_epochs=self.cfg.no_epochs)

        for epoch in range(self.cfg.no_epochs):
            epoch_loss = 0.0
            finite_steps = 0
            steps_in_epoch = 0

            for batch in self.train_loader:
                global_step += 1
                steps_in_epoch += 1
                loss = self._step(batch)

                # Steps with a non-finite loss are skipped by ``_step``; keep
                # them out of the running mean as well, otherwise a single
                # NaN would make the epoch loss NaN.
                if math.isfinite(loss):
                    epoch_loss = (
                        (finite_steps / (finite_steps + 1)) * epoch_loss
                        + (1 / (finite_steps + 1)) * loss
                    )
                    finite_steps += 1

                # Callbacks still see the raw (possibly non-finite) value.
                self._cb("on_step_end", step=global_step, loss=loss)

            if steps_in_epoch > 0 and finite_steps == 0:
                # Every step was skipped: there is no meaningful mean.
                epoch_loss = float("nan")

            self._cb("on_epoch_end", step=global_step, epoch=epoch, loss=epoch_loss)

            mmd, sinkhorn, trajectories, times = self._validate()
            new_best = mmd < best_mmd
            if new_best:
                best_mmd = mmd
                # Deep copy so later optimizer steps cannot mutate the snapshot.
                state_dict_best_mmd = copy.deepcopy(self.model.state_dict())

            self._cb("on_validation_end", step=global_step, mmd=mmd, sinkhorn=sinkhorn,
                     trajectories=trajectories, times=times, new_best=new_best)

        self._cb("on_train_end")

        if state_dict_best_mmd is not None:
            self.model.load_state_dict(state_dict_best_mmd)

        return best_mmd

    @torch.no_grad()
    def _validate(self):
        """Sample trajectories and compare them with the validation set.

        Trajectories are sampled with ``model.sample_unif`` and subsampled
        with a constant stride so that they contain
        ``cfg.val_trajectory_length`` time points before computing the
        metrics against ``val_full["x"]``.

        Returns:
            Tuple ``(mmd, sinkhorn, samples_on_val_grid, times_on_val_grid)``.

        Raises:
            NotImplementedError: If ``cfg.regular_val_set_provided`` is
                false; no validation path exists for that case.
            ValueError: If no validation set was given, or the sampled
                grid cannot be aligned with the validation grid.
        """
        self.model.eval()

        if not self.cfg.regular_val_set_provided:
            raise NotImplementedError(
                "Validation is only implemented for cfg.regular_val_set_provided=True."
            )
        if self.val_full is None:
            raise ValueError("val_full must be provided to run validation.")

        x0 = torch.tensor(OmegaConf.to_container(self.cfg.x_start), dtype=torch.float32)
        val_samples = self.val_full["x"]
        samples, times, _ = self.model.sample_unif(
            x0,
            self.cfg.no_bridges,
            self.cfg.t_start,
            self.cfg.t_end,
            self.cfg.stepsize,
            no_samples=val_samples.shape[0],
        )

        trajectory_length = samples.shape[1]
        expected_length = round((self.cfg.t_end - self.cfg.t_start) / self.cfg.stepsize + 1)
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if trajectory_length != expected_length:
            raise ValueError(
                f"Sampled trajectory has {trajectory_length} time points, "
                f"expected {expected_length} from t_start, t_end and stepsize."
            )

        equidistant_steps = self.cfg.val_trajectory_length - 1
        if equidistant_steps < 1:
            raise ValueError("cfg.val_trajectory_length must be at least 2.")
        if (trajectory_length - 1) % equidistant_steps != 0:
            raise ValueError(
                f"Cannot align {trajectory_length} sampled time points with a "
                f"validation grid of {self.cfg.val_trajectory_length} points."
            )
        stepsize_subsampling = (trajectory_length - 1) // equidistant_steps
        if stepsize_subsampling < 1:
            raise ValueError("Sampled trajectory is too short to subsample.")

        # Potentially just interpolate if the grids cannot be aligned.
        grid_idx = torch.arange(0, trajectory_length, stepsize_subsampling)
        samples_val_grid = samples[:, grid_idx, :]

        mmd = mmd_metric(samples_val_grid, val_samples)
        sinkhorn = sinkhorn_dist(samples_val_grid, val_samples)

        # TODO: only MMD. The complete trajectory/marginal MMD works only for
        # a uniform time grid, but the MSE is correlated and works for
        # arbitrary time grids, so consider using the MSE for early stopping.
        return mmd, sinkhorn, samples_val_grid, times[:, grid_idx]

    def _step(self, batch):
        """Run one optimization step and return the loss as a Python float.

        A non-finite loss is reported on stdout and the parameter update is
        skipped; the non-finite value is still returned.
        """
        self.model.train()
        self.optimizer.zero_grad()
        loss = self.model.loss(batch)
        if not torch.isfinite(loss):
            print("Non-finite loss, skipping step.")
        else:
            loss.backward()
            # Clip the total gradient norm to 1.0 before updating.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

        return loss.item()

    def _cb(self, name: str, **kw):
        """Call hook ``name`` on every callback that defines it.

        Callbacks are best-effort: an exception raised by one is turned
        into a warning so that it cannot abort training.
        """
        for cb in self._callbacks:
            fn = getattr(cb, name, None)
            if fn is None:
                continue
            try:
                fn(trainer=self, **kw)
            except Exception as e:
                warnings.warn(f"Callback {cb.__class__.__name__}.{name} failed: {e}")
