import os.path
import warnings

import torch
from matplotlib import pyplot as plt
from tqdm import tqdm
from traits.trait_types import self

from model import ModelWrapper
from sde_lib import *


class Flow(object):
    """Base class for flow-based generative models.

    Holds the network, the trajectory name and the configuration. Subclasses
    implement training-pair construction, sampling and the training loop.
    """

    def __init__(self, model, trajectory, config, repara=None):
        """Store the model, trajectory name and configuration.

        Args:
            model: the underlying network.
            trajectory: name of the interpolation path; appended to ``self.name``.
            config: configuration object consumed by subclasses.
            repara: accepted for interface compatibility; not stored or used here.
        """
        self.trajectory = trajectory
        self.name = f'{self.__class__.__name__}_{self.trajectory}'
        self.config = config
        self.model = model

    def get_random_samples(self, x_0=None, batch_size=None):
        """Return ``batch_size`` rows drawn uniformly, with replacement, from ``x_0``.

        Args:
            x_0: tensor to sample from along its first dimension.
            batch_size: number of rows to draw. ``None`` (or ``len(x_0)``)
                returns ``x_0`` itself without resampling.

        Returns:
            ``x_0`` itself, or a tensor of ``batch_size`` rows of ``x_0``.
        """
        length_x_0 = len(x_0)
        # Requesting the whole set (or not specifying a size) returns it as-is
        # instead of a resample with replacement.
        if batch_size is None or batch_size == length_x_0:
            return x_0
        indices = torch.randint(0, length_x_0, (batch_size,), device=x_0.device)
        return x_0[indices]

    def get_train_tuple(self, x_0=None, t=None):
        """Build a training example; implemented by subclasses."""
        raise NotImplementedError

    @torch.no_grad()
    def sample_ode(self, model, z0, N):
        """Generate samples; implemented by subclasses."""
        raise NotImplementedError

    def model_prediction(self, x, c, t):
        """Forward to ``self.model.model(x, c, t)``.

        NOTE(review): the argument order here is ``(x, c, t)`` while the rest of
        the file calls networks as ``(x, t, c)`` -- confirm against the wrapped
        model's signature.
        """
        return self.model.model(x, c, t)

    def train(self, optimizer, x_0, batchsize, inner_iters):
        """Run training; implemented by subclasses."""
        raise NotImplementedError

class RectifiedFlow(Flow):
    """Flow-matching model over a configurable interpolation path.

    Training pairs interpolate a data sample ``z1`` and Gaussian noise ``z0``
    as ``z_t = alpha_t * z1 + beta_t * z0``, with ``(alpha_t, beta_t)`` from
    ``schedule`` and their time derivatives from ``abdot``. How the network
    output is interpreted is controlled by ``config.repara``:

    * ``None`` -- the output is the velocity ``alpha_t_dot * z1 + beta_t_dot * z0``;
    * ``'x_pred'`` -- the output is the data sample ``z1``;
    * ``'epsilon_pred'`` -- the output is the noise ``z0``.

    Sampling starts from Gaussian noise at ``t = 0`` and takes Euler steps of
    size ``1 / num_steps`` in increasing ``t``.
    """

    def __init__(self, model, trajectory, config, reparam=None, conditional=False):
        """Create the flow and the SDE object associated with ``trajectory``.

        Args:
            model: network called as ``model(x, t)`` or ``model(x, t, c)``.
            trajectory: ``'vp'``, ``'subvp'`` or ``'linear'``.
            config: configuration object read by the other methods.
            reparam: optional tag appended to ``self.name``.
                NOTE(review): the prediction logic reads ``config.repara``, not
                this argument -- keep the two consistent.
            conditional: accepted for interface compatibility; unused here.

        Raises:
            NotImplementedError: if ``trajectory`` is not recognised.
        """
        super().__init__(model, trajectory, config)
        if self.trajectory == 'vp':
            self.sde = VPSDE()
        elif self.trajectory == 'subvp':
            self.sde = subVPSDE()
        elif self.trajectory == 'linear':
            self.sde = Linear()
        else:
            raise NotImplementedError(f'Unknown trajectory: {self.trajectory}')
        assert isinstance(self.sde, SDE)

        self.reparam = reparam
        if self.reparam is not None:
            self.name += f'_{self.reparam}'

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _expand_t(t, like):
        """Reshape per-sample times to ``(N, 1, ..., 1)`` to broadcast against ``like``."""
        return t.view(-1, *([1] * (like.dim() - 1)))

    @staticmethod
    def _batch_bounds(total, batch_size):
        """Yield ``(start, end)`` slice bounds covering ``total`` items in chunks.

        The last chunk may be shorter than ``batch_size`` so no item is dropped;
        at least one chunk is always produced.
        """
        num_batches = max(1, -(-total // batch_size))  # ceil division
        for index in range(num_batches):
            start = index * batch_size
            yield start, start + batch_size

    def _condition_mask(self, like):
        """Boolean mask of the entries selected by ``train_set.get_condition``.

        ``get_condition`` is applied to an all-ones tensor shaped like ``like``
        and the result is thresholded at 0.5.
        """
        ones = torch.ones_like(like)
        return self.config.train_set.get_condition(ones, type=self.config.condition_type) > 0.5

    def _sampling_velocity(self, z, t, c):
        """Velocity used by the sampler: ``pred`` (unconditional) or ``cfg_pred``."""
        if not self.config.conditional_model:
            return self.pred(z, t, c=None)
        return self.cfg_pred(z, t, c)

    # -------------------------------------------------------------- prediction

    def pred(self, x, t, c=None):
        """Return the velocity predicted at state ``x`` and time ``t``.

        Args:
            x: current state, batch first.
            t: per-sample times.
            c: condition; required when ``config.conditional_model`` is true.

        Returns:
            The velocity. For ``config.repara`` equal to ``'x_pred'`` /
            ``'epsilon_pred'`` the network output is treated as the data /
            noise estimate and converted to
            ``alpha_t_dot * data + beta_t_dot * noise``.

        Raises:
            NotImplementedError: for an unrecognised ``config.repara``.
        """
        if self.config.conditional_model:
            assert c is not None
            output = self.model(x, t, c)
        else:
            output = self.model(x, t)

        repara = self.config.repara
        if repara is None:
            return output
        if repara not in ('x_pred', 'epsilon_pred'):
            raise NotImplementedError(f'Unknown reparameterisation: {repara}')

        _t = self._expand_t(t, x)
        alpha_t_dot, beta_t_dot = abdot(_t, self.trajectory)
        alpha_t, beta_t = schedule(_t, self.trajectory)
        if repara == 'x_pred':
            # Recover the noise from z_t = alpha_t * data + beta_t * noise.
            data, noise = output, (x - alpha_t * output) / beta_t
        else:
            # Recover the data from z_t = alpha_t * data + beta_t * noise.
            data, noise = (x - beta_t * output) / alpha_t, output
        return alpha_t_dot * data + beta_t_dot * noise

    def cfg_pred(self, x, t, c):
        """Guided velocity ``w * pred(x, t, c) + (1 - w) * pred(x, t, 0)``.

        ``w`` is ``config.cfg_weight``. With ``w == 1`` only the conditional
        prediction is evaluated. The unconditional branch uses an all-zero
        condition, matching the condition dropout applied in ``train``.
        """
        cfg_weight = self.config.cfg_weight
        if cfg_weight == 1:
            return self.pred(x, t, c)
        return (cfg_weight * self.pred(x, t, c)
                + (1 - cfg_weight) * self.pred(x, t, torch.zeros_like(c)))

    def get_train_tuple(self, z1=None, t=None):
        """Build a flow-matching training example from a data batch ``z1``.

        Args:
            z1: data batch, batch first.
            t: optional per-sample times (one value per row of ``z1``); drawn
                uniformly from ``[0, 1 - 1e-5)`` with shape ``(N, 1)`` when omitted.

        Returns:
            ``(z_t, t, target)`` where ``z_t = alpha_t * z1 + beta_t * z0`` with
            fresh Gaussian noise ``z0``, and ``target`` is the regression target
            for ``config.repara`` (velocity, ``z1`` or ``z0``).

        Raises:
            NotImplementedError: for an unrecognised ``config.repara``.
        """
        if t is None:
            t = torch.rand(z1.shape[0], 1).to(z1.device) * (1 - 1e-5)
        _t = self._expand_t(t, z1)

        z0 = torch.randn_like(z1)
        alpha_t, beta_t = schedule(_t, self.trajectory)
        z_t = alpha_t * z1 + beta_t * z0

        repara = self.config.repara
        if repara == 'epsilon_pred':
            target = z0
        elif repara == 'x_pred':
            target = z1
        elif repara is None:
            alpha_t_dot, beta_t_dot = abdot(_t, self.trajectory)
            target = alpha_t_dot * z1 + beta_t_dot * z0
        else:
            raise NotImplementedError(f'Unknown reparameterisation: {repara}')
        return z_t, t, target

    def one_step_x0_pred(self, zt, t, c=None, predicted_target=None):
        """Estimate the clean data sample (``z1`` in ``get_train_tuple``) from ``zt``.

        Args:
            zt: noisy state at time ``t``.
            t: per-sample times.
            c: condition passed to the network when it is conditional.
            predicted_target: precomputed network output; when ``None`` the
                network is evaluated here.

        Returns:
            The one-step data estimate implied by the output under ``config.repara``.

        Raises:
            NotImplementedError: for an unrecognised ``config.repara``.
        """
        _t = self._expand_t(t, zt)
        alpha_t, beta_t = schedule(_t, self.trajectory)
        alpha_t_dot, beta_t_dot = abdot(_t, self.trajectory)

        if predicted_target is not None:
            pred = predicted_target
        elif self.config.conditional_model:
            pred = self.model(zt, t, c)
        else:
            pred = self.model(zt, t)

        repara = self.config.repara
        if repara is None:
            # Solve z_t = a*z1 + b*z0 and v = a_dot*z1 + b_dot*z0 for z1.
            return (beta_t_dot * zt - beta_t * pred) / (alpha_t * beta_t_dot - alpha_t_dot * beta_t)
        if repara == 'x_pred':
            return pred
        if repara == 'epsilon_pred':
            # Same inversion of z_t = alpha_t * z1 + beta_t * noise used in ``pred``.
            return (zt - beta_t * pred) / alpha_t
        raise NotImplementedError(f'Unknown reparameterisation: {repara}')

    # ---------------------------------------------------------------- sampling

    def sampling_step(self, z, t, dt, c=None, vis_x0_prediction=False,
                      start_noise=None, param=None):
        """Advance the sampler state ``z`` from time ``t`` to ``t + dt``.

        ``config.sampling`` selects the update:

        * ``'vanilla'`` -- one Euler step with the sampler velocity.
        * ``'eci'`` -- ``config.eci_n_mix`` rounds of: estimate the clean sample,
          overwrite the masked entries with ``c``, and re-noise. Intermediate
          rounds re-noise to time ``t``; the last round re-noises to ``t + dt``.
        * ``'impainting'`` -- replace the masked entries of ``z`` with ``c``
          noised to time ``t``, then take one Euler step.

        Args:
            z: current state.
            t: per-sample times, shape ``(N, 1)`` as produced by ``sample_ode``.
            dt: step size.
            c: condition / observed values (required by ``'eci'`` and ``'impainting'``).
            vis_x0_prediction: if true, save images of the clean-sample estimate
                and of the new state under ``./saved/``.
            start_noise, param: accepted for interface compatibility; unused.

        Returns:
            The updated state; ``z`` itself is not modified in place.

        Raises:
            NotImplementedError: for an unrecognised ``config.sampling``.
        """
        sampling = self.config.sampling
        if sampling == 'vanilla':
            z = z.detach().clone() + self._sampling_velocity(z, t, c) * dt
        elif sampling == 'eci':
            z = self._eci_step(z, t, dt, c)
        elif sampling == 'impainting':
            noise = torch.randn_like(z)
            alpha_t, beta_t = schedule(self._expand_t(t, z), self.trajectory)
            mask = self._condition_mask(z)
            # Work on a copy so the caller's tensor is untouched; the noised
            # observation uses each sample's own alpha_t / beta_t.
            z = z.clone()
            z[mask] = (alpha_t * c + noise * beta_t)[mask]
            z = z.detach().clone() + self._sampling_velocity(z, t, c) * dt
        else:
            raise NotImplementedError(f'Unknown sampling method: {sampling}')

        if vis_x0_prediction:
            self._save_visualisation(z, t, c)
        return z

    def _eci_step(self, z, t, dt, c):
        """One ``'eci'`` sampling step (see ``sampling_step``)."""
        _t = self._expand_t(t, z)
        alpha_t, beta_t = schedule(_t, self.trajectory)
        alpha_t_next, beta_t_next = schedule(_t + dt, self.trajectory)

        n_mix = self.config.eci_n_mix
        for mix in range(n_mix):
            x0 = self.one_step_x0_pred(z, t, c)
            eps = torch.randn_like(x0)
            # Recomputed every round in case get_condition is stochastic.
            mask = self._condition_mask(x0)
            x0[mask] = c[mask]
            if mix < n_mix - 1:
                z = alpha_t * x0 + eps * beta_t
            else:
                z = alpha_t_next * x0 + eps * beta_t_next
        return z

    def _save_visualisation(self, z, t, c):
        """Save images of the clean-sample estimate and of ``z`` under ``./saved/``.

        Plots the slice ``[0][:, :, 0]`` (first sample, index 0 of the last axis).
        """
        os.makedirs('./saved', exist_ok=True)
        time_value = t.flatten()[0].item()
        x0 = self.one_step_x0_pred(z, t, c)

        plt.imshow(x0.detach().cpu().numpy()[0][:, :, 0])
        plt.axis('off')
        plt.savefig(f'./saved/ex_0_{time_value}.png', bbox_inches='tight', pad_inches=0)
        plt.show()
        plt.close()

        plt.imshow(z.detach().cpu().numpy()[0][:, :, 0], cmap='gray')
        plt.axis('off')
        plt.savefig(f'./saved/noise_{time_value}.png', bbox_inches='tight', pad_inches=0)
        plt.show()
        plt.close()

    def sample_dflow(self, num_samples, num_steps, c=None,
                     return_trajectory=False, batch_size=None, return_z0=False,
                     param=None):
        """Sample by optimising the initial noise so the ODE endpoint matches ``c``.

        For each batch, the initial noise is optimised with L-BFGS so that the
        Euler-integrated endpoint agrees with ``c`` on the masked entries. The
        endpoint for the optimised noise is returned.

        Args:
            num_samples: total number of samples.
            num_steps: Euler steps per ODE solve.
            c: observed values, batch first; required.
            return_trajectory: accepted for interface compatibility; unused.
            batch_size: samples per optimisation batch (default: all at once).
            return_z0: also return the initial (pre-optimisation) noise.
            param: accepted for interface compatibility; unused.

        Returns:
            A tensor with ``num_samples`` rows, or ``(samples, z0)`` when
            ``return_z0`` is true.
        """
        assert c is not None
        resolution = self.config.train_set.resolution
        z0 = torch.randn((num_samples, resolution[0], resolution[1], resolution[2])).to(self.config.device)
        if batch_size is None:
            batch_size = z0.shape[0]

        z_final = []
        for start_idx, end_idx in self._batch_bounds(z0.shape[0], batch_size):
            z0_batch = z0[start_idx:end_idx]
            c_batch = c[start_idx:end_idx]
            mask = self._condition_mask(z0_batch)
            z_final.append(self._dflow_batch(z0_batch, c_batch, mask, num_steps))

        samples = torch.cat(z_final, dim=0)
        return (samples, z0) if return_z0 else samples

    def _dflow_batch(self, z0_batch, c_batch, mask, num_steps):
        """Optimise one batch of initial noise with L-BFGS and return its ODE endpoint."""
        z = z0_batch.detach().clone().requires_grad_()

        def integrate(x):
            # Euler-integrate from t = 0; return (endpoint, masked squared error).
            times = torch.linspace(0, 1, num_steps + 1, device=x.device)
            for time in times[:-1]:
                t = torch.ones((x.shape[0], 1)).to(x.device) * (1 - 1e-5) * time
                # NOTE(review): the velocity is evaluated without autograd, so the
                # gradient w.r.t. the noise flows only through the identity term
                # of each Euler update, not through the network.
                with torch.no_grad():
                    velocity = self.pred(x, t, c_batch)
                x = x + velocity / num_steps
            loss = ((x - c_batch) * mask).square().sum()
            return x, loss

        optimizer = torch.optim.LBFGS([z], max_iter=20, lr=1e-1)
        n_evals = 0

        def closure():
            nonlocal n_evals
            n_evals += 1
            optimizer.zero_grad()
            _, loss = integrate(z)
            loss.backward()
            print(f'Iter {n_evals}: {loss.item():.4f}')
            return loss

        optimizer.step(closure)
        # The final solve only produces the sample; no graph is needed.
        with torch.no_grad():
            x1, _ = integrate(z)
        return x1.detach()

    @torch.no_grad()
    def sample_ode(self, num_samples, num_steps, c=None,
                   return_trajectory=False, batch_size=None, return_z0=False,
                   param=None):
        """Generate samples by Euler-integrating the learned flow from Gaussian noise.

        Args:
            num_samples: total number of samples.
            num_steps: number of Euler steps (step size ``1 / num_steps``).
            c: optional condition, batch first; sliced per batch.
            return_trajectory: return every intermediate state instead of only
                the final one.
            batch_size: samples integrated together (default: all at once).
            return_z0: also return the initial noise.
            param: optional dict of per-sample tensors, sliced per batch and
                forwarded to ``sampling_step``.

        Returns:
            Final states concatenated over batches (``num_samples`` rows), or,
            with ``return_trajectory``, per-step states of shape
            ``(num_steps + 1, num_samples, ...)``. When ``return_z0`` is true
            the initial noise is returned as a second element.
        """
        resolution = self.config.train_set.resolution
        z0 = torch.randn((num_samples, resolution[0], resolution[1], resolution[2])).to(self.config.device)
        if batch_size is None:
            batch_size = z0.shape[0]
        dt = 1. / num_steps

        z_final = []
        z_trajectories = []
        for start_idx, end_idx in self._batch_bounds(z0.shape[0], batch_size):
            z0_batch = z0[start_idx:end_idx]
            c_batch = None if c is None else c[start_idx:end_idx]
            batch_param = (None if param is None
                           else {key: value[start_idx:end_idx] for key, value in param.items()})

            z = z0_batch.detach().clone()
            # A fresh trajectory per batch so stacks have equal length across batches.
            traj = [z.detach().clone()] if return_trajectory else None

            for step in tqdm(range(num_steps), leave=False):
                t = torch.ones((z0_batch.shape[0], 1)).to(z.device) * step / num_steps * (1 - 1e-5)
                z = self.sampling_step(z, t, dt, c_batch, start_noise=z0_batch,
                                       param=batch_param)
                if return_trajectory:
                    traj.append(z.detach().clone())

            z_final.append(z)
            if return_trajectory:
                z_trajectories.append(torch.stack(traj))

        if return_trajectory:
            result = torch.cat(z_trajectories, dim=1)
        else:
            result = torch.cat(z_final, dim=0)
        return (result, z0) if return_z0 else result

    # ---------------------------------------------------------------- training

    def get_loss(self, z_t, t, target, condition, param):
        """Compute the training loss and a record of its components.

        Both modes compute the flow-matching MSE and a PDE residual on the
        one-step data estimate (``compute_pde_error`` squared, averaged and
        scaled by ``config.pde_loss_weight``). ``'vanilla'`` optimises only the
        flow-matching term (the PDE term is logged); ``'pde_loss'`` optimises
        their sum.

        Returns:
            ``(loss, loss_record)`` where ``loss_record`` holds float values.

        Raises:
            NotImplementedError: for an unrecognised ``config.loss_mode``.
        """
        mode = self.config.loss_mode
        if mode not in ('vanilla', 'pde_loss'):
            raise NotImplementedError(f'Unknown loss mode: {mode}')

        predicted_target = self.model(z_t, t, condition)
        fm_loss = (predicted_target - target).pow(2).mean()

        x0_prediction = self.one_step_x0_pred(z_t, t, predicted_target=predicted_target)
        pde_loss = self.config.train_set.compute_pde_error(x0_prediction, **param)
        pde_loss = (pde_loss ** 2).mean() * self.config.pde_loss_weight

        loss = fm_loss + pde_loss if mode == 'pde_loss' else fm_loss
        loss_record = {'flow matching loss': fm_loss.item(),
                       'pde loss': pde_loss.item()}
        return loss, loss_record

    def train(self, optimizer, dataloader):
        """Run the training loop, optionally resuming from a checkpoint.

        Iterations ``0 .. config.iterations`` (inclusive) are executed. When
        ``config.current_iteration`` is non-zero, the model, optimiser and loss
        curve are restored from ``config.model_path + '_iter_<n>'`` and training
        continues with iteration ``n + 1``. The checkpoint is written after
        iteration ``n`` completes, so that iteration is not repeated.

        Args:
            optimizer: torch optimiser whose state is checkpointed with the model.
            dataloader: iterator yielding ``(x, param)`` pairs via ``next``.

        Returns:
            The list of per-iteration losses, including any restored history.
        """
        inner_iters = self.config.iterations
        current_iteration = self.config.current_iteration or 0

        loss_curve = []
        start_iteration = 0
        if current_iteration != 0:
            checkpoint = torch.load(self.config.model_path + f'_iter_{current_iteration}',
                                    map_location=self.config.device)
            self.model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            loss_curve = checkpoint['loss']
            start_iteration = current_iteration + 1

        pbar = tqdm(total=inner_iters + 1, initial=start_iteration, leave=False)
        for i in range(start_iteration, inner_iters + 1):
            optimizer.zero_grad()

            x_1, param = next(dataloader)
            x_1 = x_1.to(self.config.device)
            if self.config.conditional_model:
                condition = self.config.train_set.get_condition(x_1, self.config.condition_type)
                # Zero the condition for a random fraction of samples so the
                # all-zero "unconditional" branch used by cfg_pred is learned.
                drop = torch.rand(condition.shape[0]) < self.config.uncondition_ratio
                condition[drop] = 0
            else:
                condition = None

            z_t, t, target = self.get_train_tuple(z1=x_1)
            loss, loss_record = self.get_loss(z_t, t, target, condition, param)
            loss.backward()
            optimizer.step()

            loss_curve.append(loss.item())
            pbar.update(1)
            if i % 500 == 0:
                print(f'Iteration {i}: Loss: {loss_record}')
            if i % self.config.save_freq == 0 and i != 0:
                torch.save({
                    'loss': loss_curve,
                    'model': self.model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, self.config.model_path + f'_iter_{i}')
        pbar.close()

        return loss_curve