from typing import Tuple

import numpy as np
import scipy
import scipy.stats
import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.utils import dense_to_sparse
from tqdm import tqdm


class Diffusion(nn.Module):
    def __init__(
        self,
        config,
    ):
        super().__init__()
        assert config.diffusion in ['ve', 'vp'], 'diffusion noise scheduler not implemented'

        if config.loss.startswith('implicit_energy'):
            self.loss = 'implicit_energy'
            self.loss_weight = float(config.loss.split('_')[-1])
        elif config.loss.startswith('explicit_energy'):
            self.loss = 'explicit_energy'
            self.loss_weight = list(map(float, config.loss.split('_')[2:]))
        elif config.loss.startswith('momentum'):
            self.loss = 'momentum'
            self.loss_weight = float(config.loss.split('_')[-1])
        elif config.loss.startswith('jensen'):
            self.loss = 'jensen'
            self.loss_weight = float(config.loss.split('_')[-1])
        elif config.loss == 'naive':
            self.loss = 'naive'
        else:
            raise NotImplementedError
            

        self.steps = 1000
        self.config = config
        self.device = config.device

        if config.diffusion == 've':
            self.alpha_t = torch.ones(size=(self.steps, ), device=self.device)
            self.sigma_t = torch.sigmoid(torch.linspace(-5, 5, self.steps, device=self.device))
            self.sigma_t = 5 * (self.sigma_t - self.sigma_t.min() + 1e-5)
            self.lambda_t = (self.alpha_t / self.sigma_t).log()
            self.diffusion_weight = self.sigma_t.square()
        
        else:
            self.alpha_org = 1 - torch.linspace(1e-4, 1e-2, self.steps, device=self.device)
            self.alpha_t = self.alpha_org.cumprod(dim=0).sqrt() 
            self.sigma_t = (1 - self.alpha_t.square()).sqrt()
            self.lambda_t = (self.alpha_t / self.sigma_t).log()
            self.diffusion_weight = (-2 * self.alpha_t.log()).diff() * self.steps
            self.diffusion_weight = torch.concatenate([self.diffusion_weight, self.diffusion_weight[-1].reshape(-1)])

        self.sampling_cfg = config.sampling

        backbone = config.model.lower()

        if backbone == 'egnn_gru':
            from models.EGNN_GRU.egnn_gru import EGNN_GRU
            self.network = EGNN_GRU
        elif backbone == 'gru':
            from models.GRU.gru import GRU
            self.network = GRU

        else:
            raise NotImplementedError('backbone not found')

        self.network = self.network(
            n_diff_time=self.steps,
            repara=True if self.loss.__contains__('energy') else False,
            **config.network
        ).to(self.device)

        self.nballs = config.network.n_system

        print(f"{self.network._get_name()} #Params: {sum(p.numel() for p in self.network.parameters())}")

    def diff_forward(self, x0: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        extend_dim = [-1] + [1] * (x0.ndim - 1)
        t = torch.randint(0, self.steps, (x0.size(0), ), device=self.device)

        noise = torch.randn_like(x0)
        alpha = self.alpha_t[t].reshape(extend_dim)
        sigma = self.sigma_t[t].reshape(extend_dim)
        return alpha * x0 + sigma * noise, t, noise

    def network_predict(self, xt: Tensor, t: Tensor, edge: Tensor, energy: Tensor) -> Tensor:
        eps_model = self.network(xt, t, edge, energy)
        if isinstance(eps_model, Tensor):
            return eps_model
        else:
            return eps_model[0]
        
    def compute_xyr2(self, pos: Tensor, edge_index: Tensor) -> Tensor:
        return (pos[edge_index[0]] - pos[edge_index[1]]).square()

    def get_loss(self, x0: Tensor, edge: Tensor, energy: Tensor) -> Tuple[dict, int]:
        batch_size, n_time = x0.shape[:2]
        n_spring = edge.size(-1)
        spatial_dim = 2
        loss_dict = {}

        rots = torch.from_numpy(np.stack([
            scipy.stats.special_ortho_group.rvs(spatial_dim) for _ in range(batch_size)
        ])).to(torch.float32).to(x0.device).reshape(batch_size, 1, 1, spatial_dim, spatial_dim)
        permute = torch.randperm(n_spring)

        # feature.shape == (batch_size, n_time, n_system, 2)
        loc_feature, vel_feature = x0.reshape(batch_size, n_time, spatial_dim, spatial_dim, self.nballs).permute(3, 0, 1, 4, 2)

        loc_feature = loc_feature[:, :, permute, :]
        vel_feature = vel_feature[:, :, permute, :]
        edge = edge[:, permute][:, :, permute]

        feature_dim = spatial_dim * self.nballs
        x0[:, :, :feature_dim] = torch.matmul(
            rots,
            loc_feature.unsqueeze(dim=-1)
        ).reshape(batch_size, n_time, feature_dim)

        x0[:, :, feature_dim:] = torch.matmul(
            rots,
            vel_feature.unsqueeze(dim=-1)
        ).reshape(batch_size, n_time, feature_dim)

        xt, t, eps_noise = self.diff_forward(x0)
        loss_w = self.diffusion_weight[t].reshape(-1, *[1] * (xt.ndim - 1))

        if self.config.diffusion == 've':
            t = torch.ones_like(t)
        eps_model = self.network(xt, t, edge, energy)
        
        if isinstance(eps_model, Tensor):
            score_loss = (eps_model - eps_noise).square()
            loss_dict['score'] = (loss_w * score_loss).mean()

            alpha_t_condition = xt - self.sigma_t[t].reshape(-1, *[1] * (xt.ndim - 1)) * eps_model

            if self.loss == 'jensen':
                model_pred = alpha_t_condition / self.alpha_t[t].reshape(-1, *[1] * (xt.ndim - 1))
                pred_loc, pred_vel = model_pred.reshape(batch_size, n_time, 2, 2, n_spring).permute(3, 0, 1, 4, 2)
                loc0, vel0 = x0.reshape(batch_size, n_time, 2, 2, n_spring).permute(3, 0, 1, 4, 2)

                v2_loss = loss_w * (pred_vel.square() - vel0.square()).square().reshape(batch_size, n_time, -1)

                time_edge = edge.expand(n_time, *edge.shape).permute(1, 0, 2, 3).reshape(-1, n_spring, n_spring)
                edge_index = dense_to_sparse(adj=time_edge)[0]
                gt_xyr2 = self.compute_xyr2(loc0.reshape(-1, loc0.size(-1)), edge_index)
                pred_xyr2 = self.compute_xyr2(pred_loc.reshape(-1, pred_loc.size(-1)), edge_index)
                xyr2_loss = (gt_xyr2 - pred_xyr2).square()

                loss_dict['distance_gt'] = xyr2_loss.mean() / 1e4
                loss_dict['velocity_gt'] = v2_loss.mean() / 1e4

                # print(loss_dict['score'].mean(), xyr2_loss.mean(), v2_loss.mean())
                loss = loss_dict['score'] + loss_dict['distance_gt'] + loss_dict['velocity_gt']


            else:
                condition_momentum = alpha_t_condition.reshape(batch_size, n_time, 2, 2, n_spring)[:, :, :, 1, :].sum(dim=-1)
                momentum_loss = (condition_momentum - condition_momentum.mean(dim=1, keepdim=True).detach()).square()
                loss_dict['momentum'] = (loss_w * momentum_loss).mean()

                if self.loss == 'naive':
                    loss = score_loss
                else:
                    loss = torch.concatenate([
                        score_loss,
                        self.loss_weight * score_loss.size(-1) / momentum_loss.size(-1) * momentum_loss
                    ], dim=-1)
            
        else:
            eps_model, model_pred_vel, model_pred_edge = eps_model
            score_loss = (loss_w * (eps_model - eps_noise).square()).mean()
            loss_dict['score'] = score_loss

            loc0, vel0 = x0.reshape(batch_size, n_time, 2, 2, n_spring).permute(3, 0, 1, 4, 2)

            v2_loss = loss_w * (model_pred_vel - vel0.square()).square().reshape(batch_size, n_time, -1)

            time_edge = edge.expand(n_time, *edge.shape).permute(1, 0, 2, 3).reshape(-1, n_spring, n_spring)
            gt_xyr2 = self.compute_xyr2(loc0.reshape(-1, loc0.size(-1)), dense_to_sparse(adj=time_edge)[0])
            xyr2_loss = (gt_xyr2 - model_pred_edge).square()

            loss_dict['distance_gt'] = xyr2_loss.mean() / 10.0
            loss_dict['velocity_gt'] = v2_loss.mean()

            energy_gt_loss = loss_dict['distance_gt'] + loss_dict['velocity_gt']
            loss_dict['energy_gt'] = energy_gt_loss

            if self.loss == 'implicit_energy': 
                loss = score_loss \
                    + self.loss_weight * energy_gt_loss

            elif self.loss == 'momentum_energy':
                score_loss = (loss_w * (eps_model - eps_noise).square()).mean()
                loss_dict['score'] = score_loss

                alpha_t_condition = xt - self.sigma_t[t].reshape(-1, *[1] * (xt.ndim - 1)) * eps_model
                condition_momentum = alpha_t_condition.reshape(batch_size, n_time, 2, 2, n_spring)[:, :, :, 1, :].sum(dim=-1)
                momentum_loss = (loss_w * (condition_momentum - condition_momentum.mean(dim=1, keepdim=True).detach()).square()).mean()
                loss_dict['momentum'] = momentum_loss

                loss = score_loss \
                    + self.loss_weight[0] * momentum_loss \
                    + self.loss_weight[1] * energy_gt_loss
                
            else:
                raise NotImplementedError


        if self.loss.__contains__('energy') or self.loss == 'jensen':
            loss_dict['total'] = loss
        else:
            loss_dict['total'] = (loss_w * loss).mean()
        return loss_dict, x0.size(0)

    @torch.no_grad()
    def ode_reverse(self, s: Tensor, t: Tensor, xt: Tensor, edge: Tensor, energy: Tensor) -> Tensor:
        eps_model = self.network_predict(xt, s.repeat(xt.size(0)), edge, energy)

        model_out = self.sigma_t[t] * torch.expm1(self.lambda_t[t] - self.lambda_t[s]) * eps_model
        return self.alpha_t[t] / self.alpha_t[s] * xt - model_out
    
    @torch.no_grad()
    def ddpm_reverse(self, s: Tensor, t: Tensor, xt: Tensor, edge: Tensor, energy: Tensor) -> Tensor:
        eps_model = self.network_predict(xt, s.repeat(xt.size(0)), edge, energy)
        eps_model = (1 - self.alpha_org[s]) / (1 - self.alpha_t[s]).sqrt() * eps_model

        mu = (xt - eps_model) / self.alpha_org[s].sqrt()
        sigma2 = (1 - self.alpha_t[t]) / (1 - self.alpha_t[s]) * (1 - self.alpha_org[s])
        # sigma2 = 1 - self.alpha_org[s]
        return mu + sigma2.sqrt() * torch.randn_like(mu)

    @torch.no_grad()
    def dpm2_reverse(self, s: Tensor, t: Tensor, xt: Tensor, edge: Tensor, energy: Tensor) -> Tensor:
        mid_t = torch.argmin(((self.lambda_t[s] + self.lambda_t[t]) / 2 - self.lambda_t).abs())
        h = self.lambda_t[t] - self.lambda_t[s]

        u = self.alpha_t[mid_t] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_t] * torch.expm1(h/2) * self.network_predict(xt, s.repeat(xt.size(0)), edge, energy)
        
        return self.alpha_t[t] / self.alpha_t[s] * xt \
            - self.sigma_t[t] * torch.expm1(h) * self.network_predict(u, mid_t.repeat(xt.size(0)), edge, energy)

    @torch.no_grad()
    def dpm3_reverse(self, s: Tensor, t: Tensor, xt: Tensor, edge: Tensor, energy: Tensor) -> Tensor:
        r1, r2, h = 1/3, 2/3, self.lambda_t[t] - self.lambda_t[s]
        mid_1 = torch.argmin((self.lambda_t[s] + r1 * h - self.lambda_t).abs())
        mid_2 = torch.argmin((self.lambda_t[s] + r2 * h - self.lambda_t).abs())

        eps_s = self.network_predict(xt, s.repeat(xt.size(0)), edge, energy)
        u1 = self.alpha_t[mid_1] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_1] * torch.expm1(r1 * h) * eps_s
        
        d1 = self.network_predict(u1, mid_1.repeat(xt.size(0)), edge, energy) - eps_s

        u2 = self.alpha_t[mid_2] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_2] * torch.expm1(r2 * h) * eps_s \
            - self.sigma_t[mid_2] * r2 / r1 * (torch.expm1(r2 * h) / (r2 * h) - 1) * d1
        
        d2 = self.network_predict(u2, mid_2.repeat(xt.size(0)), edge, energy) - eps_s

        return self.alpha_t[t] / self.alpha_t[s] * xt \
            - self.sigma_t[t] * torch.expm1(h) * eps_s \
            - self.sigma_t[t] / r2 * (torch.expm1(h) / h - 1) * d2
    
    @torch.no_grad()
    def ld_reverse(self, s: Tensor, t: Tensor, xt: Tensor, edge: Tensor, energy: Tensor) -> Tensor:
        sigma_t = self.sigma_t[s]
        eps_model = self.network_predict(
            xt,
            torch.ones(size=(xt.size(0), ), device=xt.device).to(torch.long),
            edge, energy
        )
        step_size = 0.001 * sigma_t.square()
        xt = xt + step_size * eps_model / sigma_t + torch.sqrt(step_size*2) * torch.rand_like(xt)
        return xt
    
    def sampling(self, data_shape: Tuple, edge: Tensor, energy: Tensor, method: str | None = None) -> Tensor:
        if method is None and self.config.diffusion == 'vp':
            method = self.sampling_cfg.method

        if self.config.diffusion == 've':
            method = 'ld'
            itertor = self.ld_reverse
            timestamps = torch.arange(0, self.steps)
        else:
            if method == 'ddpm':
                itertor = self.ddpm_reverse
                timestamps = torch.arange(-1, self.steps)
            elif method == 'ode':
                itertor = self.ode_reverse
                timestamps = torch.arange(0, self.steps)
            elif method.startswith('dpm3'):
                steps = int(method.split('_')[-1])
                assert steps > 2
                itertor = self.dpm3_reverse
                timestamps = torch.linspace(0, self.steps - 1, steps + 1)
            else:
                raise NotImplementedError('sampling method not found')
        
        
        timestamps = timestamps.__reversed__().int().to(self.device)

        if self.config.diffusion == 'vp':
            xt = torch.randn(size=data_shape, device=self.device)
        else:
            xt = self.sigma_t[-1] * torch.randn(size=data_shape, device=self.device)

        pbar = tqdm(zip(timestamps[:-1], timestamps[1:]), leave=False, total=len(timestamps)-1, dynamic_ncols=True)
        for s, t in pbar:
            xt = itertor(s, t, xt, edge, energy)
            
        return xt





        