import os
import sys
import torch
import numpy as np

from pathlib import Path
from tqdm.auto import tqdm
from ema_pytorch import EMA
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_

# Put the parent of this file's directory on sys.path *before* importing the
# project-local ``Engine`` package. When this ran after that import (as it did
# previously), it could not help the import resolve.
sys.path.append(os.path.join(os.path.dirname(__file__), '../'))

from Engine.lr_sch import ReduceLROnPlateauWithWarmup


def cycle(dl):
    """Yield items from ``dl`` endlessly, restarting iteration after each pass.

    ``dl`` is re-iterated with a fresh ``for`` loop on every pass, so it must be
    re-iterable (e.g. a DataLoader or a list) to keep producing data.

    Raises:
        ValueError: if a complete pass over ``dl`` yields nothing. Without this
            check, an empty (or exhausted one-shot) iterable would make
            ``next()`` spin forever in a busy loop.
    """
    while True:
        produced = False
        for data in dl:
            produced = True
            yield data
        if not produced:
            raise ValueError('cycle() received an iterable that yielded no items')


class Trainer(object):
    """Train a diffusion model and draw samples / infills from its EMA copy.

    From how it is used below, ``model`` must expose a ``betas`` tensor (whose
    device becomes ``self.device``), a ``num_timesteps`` attribute, and a
    forward call ``model(data, target=data)`` that returns a scalar loss tensor.
    All sampling goes through ``self.ema.ema_model``, which must provide
    ``generate_mts``, ``sample_infill`` and ``fast_sample_infill``.
    """

    def __init__(
        self,
        model,
        data_loader,
        results_folder='./Checkpoints',
        train_lr=1e-5,
        warmup_lr=1e-4,
        train_num_steps=100000,
        adam_betas=(0.9, 0.96),
        gradient_accumulate_every=2,
        ema_update_every=10,
        ema_decay=0.995,
        patience=1000,
        min_lr=1e-5,
        threshold=1e-1,
        warmup=500,
        factor=0.5,
    ):
        """Set up the optimizer, LR scheduler, EMA wrapper and checkpoint folder.

        Args:
            model: the module to train (see the class docstring for requirements).
            data_loader: iterable of training batches that also exposes
                ``.dataset``; it is consumed endlessly through ``cycle``.
            results_folder: directory for ``checkpoint-{milestone}.pt`` files;
                created if it does not exist.
            train_lr: initial Adam learning rate.
            warmup_lr, warmup, patience, min_lr, threshold, factor: forwarded to
                ``ReduceLROnPlateauWithWarmup`` (with ``threshold_mode='rel'``).
            train_num_steps: optimizer steps run by each call to :meth:`train`.
                A checkpoint is written every ``train_num_steps // 10`` global
                steps (never less often than... never with a period below 1).
            adam_betas: Adam ``betas``.
            gradient_accumulate_every: micro-batches accumulated per optimizer step.
            ema_update_every, ema_decay: forwarded to ``EMA`` as
                ``update_every`` / ``beta``.
        """
        super().__init__()
        self.model = model
        self.device = self.model.betas.device
        self.train_num_steps = train_num_steps
        self.gradient_accumulate_every = gradient_accumulate_every
        # Save roughly ten checkpoints per run. Clamp to >= 1: for runs shorter
        # than 10 steps the plain quotient is 0, and ``step % 0`` in train()
        # would raise ZeroDivisionError.
        self.save_cycle = max(1, int(train_num_steps // 10))
        self.dataset = data_loader.dataset
        self.dl = cycle(data_loader)
        self.step = 0        # global optimizer-step count; persisted in checkpoints
        self.milestone = 0   # index of the most recently saved/loaded checkpoint

        self.results_folder = Path(results_folder)
        os.makedirs(self.results_folder, exist_ok=True)

        # Only parameters that require gradients are optimized.
        self.opt = Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=train_lr, betas=adam_betas)
        self.sch = ReduceLROnPlateauWithWarmup(optimizer=self.opt, factor=factor, patience=patience, min_lr=min_lr, threshold=threshold,
                                               threshold_mode='rel', warmup_lr=warmup_lr, warmup=warmup, verbose=False)

        self.ema = EMA(self.model, beta=ema_decay, update_every=ema_update_every).to(self.device)

    def save(self, milestone):
        """Write step, model, EMA and optimizer state to ``checkpoint-{milestone}.pt``.

        NOTE(review): the LR scheduler (``self.sch``) state is not included, so a
        resumed run keeps whatever scheduler state it had at construction.
        """
        data = {
            'step': self.step,
            'model': self.model.state_dict(),
            'ema': self.ema.state_dict(),
            'opt': self.opt.state_dict(),
        }
        torch.save(data, str(self.results_folder / f'checkpoint-{milestone}.pt'))

    def load(self, milestone):
        """Restore state written by :meth:`save` for ``milestone``.

        Tensors are mapped onto ``self.device``. ``self.milestone`` is set to
        ``milestone`` so that later checkpoints continue numbering after it.
        """
        device = self.device
        data = torch.load(str(self.results_folder / f'checkpoint-{milestone}.pt'), map_location=device)
        self.model.load_state_dict(data['model'])
        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        self.ema.load_state_dict(data['ema'])
        self.milestone = milestone

    def train(self):
        """Run ``train_num_steps`` optimizer steps with gradient accumulation.

        Each step sums ``gradient_accumulate_every`` micro-batch losses (each
        scaled by ``1 / gradient_accumulate_every``), clips the total gradient
        norm to 1.0, steps the optimizer, feeds the summed loss to the plateau
        scheduler and updates the EMA. A checkpoint is saved whenever the global
        ``self.step`` is a multiple of ``self.save_cycle``.

        The loop length is counted by a local counter that starts at 0. After
        :meth:`load`, this therefore runs ``train_num_steps`` *additional* steps,
        while ``self.step`` continues from the loaded value.
        """
        device = self.device
        step = 0

        with tqdm(initial=step, total=self.train_num_steps) as pbar:
            while step < self.train_num_steps:
                total_loss = 0.
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)
                    loss = self.model(data, target=data)
                    # Scale so the accumulated gradient is the average over micro-batches.
                    loss = loss / self.gradient_accumulate_every
                    loss.backward()
                    total_loss += loss.item()

                pbar.set_description(f'loss: {total_loss:.6f}')

                clip_grad_norm_(self.model.parameters(), 1.0)
                self.opt.step()
                self.sch.step(total_loss)
                self.opt.zero_grad()
                self.step += 1
                step += 1
                self.ema.update()

                with torch.no_grad():
                    if self.step != 0 and self.step % self.save_cycle == 0:
                        self.milestone += 1
                        self.save(self.milestone)

                pbar.update(1)

        print('training complete')

    @staticmethod
    def _stack_batches(batches, shape):
        """Concatenate per-batch NumPy arrays along axis 0 into one float64 array.

        The empty float64 seed both fixes the result dtype and gives a correctly
        shaped empty result when ``batches`` is empty. ``shape`` gives the two
        trailing dimensions; when None, they are taken from the first batch.
        """
        if shape is not None:
            trailing = (shape[0], shape[1])
        elif batches:
            trailing = tuple(batches[0].shape[1:])
        else:
            raise ValueError('shape must be given when there are no batches to infer it from')
        seed = np.empty((0, *trailing))
        return np.concatenate([seed, *batches], axis=0)

    def _infill(self, x, t_m, model_kwargs, sampling_steps):
        """Infill one batch with the EMA model.

        Uses ``sample_infill`` when ``sampling_steps`` equals the model's full
        ``num_timesteps``, and ``fast_sample_infill`` otherwise.
        """
        ema_model = self.ema.ema_model
        if sampling_steps == self.model.num_timesteps:
            return ema_model.sample_infill(shape=x.shape, target=x, partial_mask=t_m,
                                           model_kwargs=model_kwargs)
        return ema_model.fast_sample_infill(shape=x.shape, target=x, partial_mask=t_m, model_kwargs=model_kwargs,
                                            sampling_timesteps=sampling_steps)

    def sample(self, num, size_every, shape=None):
        """Generate exactly ``num`` unconditional samples from the EMA model.

        Args:
            num: total number of samples to return.
            size_every: maximum batch size per ``generate_mts`` call; the final
                batch only requests the remainder.
            shape: the two trailing dimensions of each sample; inferred from the
                generated batches when None.

        Returns:
            float64 ndarray of shape ``(num, *shape)``.

        Raises:
            ValueError: if ``size_every`` is not positive.
        """
        if size_every <= 0:
            raise ValueError(f'size_every must be positive, got {size_every}')

        batches = []
        remaining = int(num)
        while remaining > 0:
            batch_size = min(size_every, remaining)
            sample = self.ema.ema_model.generate_mts(batch_size=batch_size)
            batches.append(sample.detach().cpu().numpy())
            remaining -= batch_size
            torch.cuda.empty_cache()

        return self._stack_batches(batches, shape)

    def restore(self, raw_dataloader, shape=None, coef=1e-1, stepsize=1e-1, sampling_steps=50):
        """Infill the masked entries of every batch from ``raw_dataloader``.

        Args:
            raw_dataloader: iterable of ``(x, t_m)`` pairs, where ``t_m`` is the
                ``partial_mask`` passed to the EMA model's infill method.
            shape: the two trailing dimensions of each sample; inferred from the
                outputs when None.
            coef, stepsize: forwarded as ``model_kwargs['coef']`` and
                ``model_kwargs['learning_rate']``.
            sampling_steps: full sampling if equal to ``model.num_timesteps``,
                otherwise fast sampling with this many timesteps.

        Returns:
            float64 ndarray of all infilled batches stacked along axis 0.
        """
        model_kwargs = {'coef': coef, 'learning_rate': stepsize}
        batches = []

        for x, t_m in raw_dataloader:
            x, t_m = x.to(self.device), t_m.to(self.device)
            sample = self._infill(x, t_m, model_kwargs, sampling_steps)
            batches.append(sample.detach().cpu().numpy())

        return self._stack_batches(batches, shape)

    def forecasting(self, raw_dataloader, window, shape=None, coef=1e-1, stepsize=1e-1, sampling_steps=50):
        """Forecast the last ``window`` steps of every batch by infilling.

        For each 3-D batch ``x`` (axis 1 indexed as the step axis below), a
        boolean mask is built that is True everywhere except the final
        ``window`` positions of axis 1. These are the positions the model fills
        in. ``window == 0`` masks nothing; a ``window`` longer than the sequence
        masks all of it.

        Args:
            raw_dataloader: iterable of tensors ``x``.
            window: number of trailing positions to forecast (must be >= 0).
            shape, coef, stepsize, sampling_steps: as in :meth:`restore`.

        Returns:
            float64 ndarray of all forecast batches stacked along axis 0.

        Raises:
            ValueError: if ``window`` is negative.
        """
        if window < 0:
            raise ValueError(f'window must be non-negative, got {window}')

        model_kwargs = {'coef': coef, 'learning_rate': stepsize}
        batches = []

        for x in raw_dataloader:
            x = x.to(self.device)
            t_m = torch.ones(x.shape, dtype=torch.bool, device=x.device)
            # Explicit, clamped start index: the slice ``-window:`` would hide
            # the *whole* sequence when window == 0, because -0 == 0.
            start = max(x.shape[1] - window, 0)
            t_m[:, start:, :] = False
            sample = self._infill(x, t_m, model_kwargs, sampling_steps)
            batches.append(sample.detach().cpu().numpy())

        return self._stack_batches(batches, shape)
