import torch
import torch.nn as nn

from torch import Tensor

from tqdm import tqdm
from typing import Tuple

import scipy
import numpy as np

from itertools import combinations

class Diffusion(nn.Module):
    def __init__(
        self,
        config,
    ):
        """Build the noise schedule, parse the loss specification and create the backbone.

        Fields read from ``config``: ``diffusion`` ('ve' or 'vp'), ``loss``
        (a loss name, optionally followed by ``_``-separated weights),
        ``device``, ``sampling``, ``model`` (backbone name, case-insensitive)
        and ``network`` (keyword arguments forwarded to the backbone).

        Raises:
            AssertionError: if ``config.diffusion`` is neither 've' nor 'vp'.
            NotImplementedError: if the loss name or the backbone is unknown.
        """
        super().__init__()
        assert config.diffusion in ['ve', 'vp'], 'diffusion noise scheduler not implemented'

        # Constants for the energy terms used by the losses.
        # NOTE(review): mass_nd / length_nd / speed_nd look like
        # non-dimensionalisation scales (1.989e30 matches the solar mass in kg)
        # — confirm units against the data pipeline.
        self.m = 1.0
        grav_const = 6.67408e-11
        mass_nd = 1.989e+30
        length_nd = 5.326e+12
        speed_nd = 30000
        norm_constant = 30000000.0
        self.G1 = grav_const * mass_nd / length_nd / norm_constant
        self.G2 = (speed_nd)**2 / norm_constant

        # Loss spec: "<name>_<w>" for single-weight losses, "<name>_<w0>_<w1>..."
        # for multi-weight ones. Order matters: 'momentum_energy' has to be
        # tested before its own prefix 'momentum'.
        loss_spec = config.loss
        weighted_losses = (
            ('implicit_energy', False),
            ('explicit_energy', True),
            ('momentum_energy', True),
            ('momentum', False),
            ('ablation', False),
            ('jensen', False),
        )
        for prefix, has_weight_list in weighted_losses:
            if loss_spec.startswith(prefix):
                fields = loss_spec.split('_')
                self.loss = prefix
                if has_weight_list:
                    self.loss_weight = list(map(float, fields[2:]))
                else:
                    self.loss_weight = float(fields[-1])
                break
        else:
            # 'naive' carries no weight; anything else is unsupported.
            if loss_spec != 'naive':
                raise NotImplementedError
            self.loss = loss_spec

        self.steps = 1000
        self.config = config
        self.device = config.device

        if config.diffusion == 've':
            # Signal is never scaled; sigma follows a shifted sigmoid ramp
            # whose smallest value is 5e-5.
            ramp = torch.sigmoid(torch.linspace(-5, 5, self.steps, device=self.device))
            self.alpha_t = torch.ones(size=(self.steps, ), device=self.device)
            self.sigma_t = 5 * (ramp - ramp.min() + 1e-5)
            self.diffusion_weight = self.sigma_t.square()
        else:
            # alpha_org is the per-step factor; alpha_t is the square root of
            # its cumulative product and sigma_t = sqrt(1 - alpha_t^2).
            self.alpha_org = 1 - torch.linspace(1e-4, 1e-2, self.steps, device=self.device)
            self.alpha_t = self.alpha_org.cumprod(dim=0).sqrt()
            self.sigma_t = (1 - self.alpha_t.square()).sqrt()
            # Finite difference of -2*log(alpha_t), scaled by the step count;
            # the last entry is repeated to keep one weight per step.
            step_weight = (-2 * self.alpha_t.log()).diff() * self.steps
            self.diffusion_weight = torch.concatenate([step_weight, step_weight[-1].reshape(-1)])

        # log(alpha / sigma), used by the ODE / DPM samplers.
        self.lambda_t = (self.alpha_t / self.sigma_t).log()

        self.sampling_cfg = config.sampling

        backbone = config.model.lower()
        if backbone == 'paragru':
            from models.ParaGRU.paragru import ParaGRU as network_cls
        elif backbone == 'paraphygru':
            from models.ParaPhyGRU.paraphygru import ParaPhyGRU as network_cls
        else:
            raise NotImplementedError('backbone not found')

        self.network = network_cls(
            n_diff_time=self.steps,
            **config.network
        ).to(self.device)

        n_params = sum(p.numel() for p in self.network.parameters())
        print(f"{self.network._get_name()} #Params: {n_params}")

    def diff_forward(self, x0: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        extend_dim = [-1] + [1] * (x0.ndim - 1)
        t = torch.randint(0, self.steps, (x0.size(0), ), device=self.device)

        noise = torch.randn_like(x0)
        alpha = self.alpha_t[t].reshape(extend_dim)
        sigma = self.sigma_t[t].reshape(extend_dim)
        return alpha * x0 + sigma * noise, t, noise

    def network_predict(self, xt: Tensor, t: Tensor) -> Tensor:
        eps_model = self.network(xt, t)
        if isinstance(eps_model, Tensor):
            return eps_model
        else:
            return eps_model[0]
        
    def get_loss(self, x0: Tensor) -> Tuple[dict, int]:
        loss_dict = {}


        batch_size, n_time = x0.shape[:2]
        rots = torch.from_numpy(np.stack([
            scipy.stats.special_ortho_group.rvs(3) for _ in range(batch_size)
        ])).to(torch.float32).to(x0.device).reshape(batch_size, 1, 1, 3, 3)
        permute = torch.randperm(3, device=x0.device)

        # pos.shape == (batch_size, n_time, n_system, xyz)
        pos = x0[:, :, :9].reshape(batch_size, n_time, 3, 3)[:, :, permute, :]
        x0[:, :, :9] = torch.matmul(
            rots,
            pos.unsqueeze(dim=-1)
        ).reshape(batch_size, n_time, 9)

        vel = x0[:, :, 9:].reshape(batch_size, n_time, 3, 3)[:, :, permute, :]
        x0[:, :, 9:] = torch.matmul(
            rots,
            vel.unsqueeze(dim=-1)
        ).reshape(batch_size, n_time, 9)


        xt, t, eps_noise = self.diff_forward(x0)
        loss_w = self.diffusion_weight[t].reshape(-1, *[1] * (xt.ndim - 1))

        if self.config.diffusion == 've':
            t = torch.ones_like(t)
        eps_model = self.network(xt, t)

        if isinstance(eps_model, Tensor):
            if self.loss == 'naive':
                loss = (eps_model - eps_noise).square()
                loss_dict['score'] = (loss_w * loss).mean()

            elif self.loss == 'momentum':
                loss1 = (eps_model - eps_noise).square()
                loss_dict['score'] = (loss_w * loss1).mean()

                alpha_t_condition = xt[:, :, 9:] - self.sigma_t[t].reshape(-1, *[1] * (xt.ndim - 1)) * eps_model[:, :, 9:]
                alpha_t_condition = alpha_t_condition.reshape(*xt.shape[:2], 3, 3).sum(dim=2)
                loss2 = (alpha_t_condition - alpha_t_condition.mean(dim=1, keepdim=True).detach()).square()
                loss = torch.concatenate([
                    loss1,
                    self.loss_weight * loss1.size(-1) / loss2.size(-1) * loss2
                ], dim=-1)
                loss_dict['momentum'] = (loss_w * loss2).mean()

            elif self.loss == 'jensen':
                loss1 = (eps_model - eps_noise).square()
                loss_dict['score'] = (loss_w * loss1).mean()

                pred_x0 = xt - self.sigma_t[t].reshape(-1, *[1] * (xt.ndim - 1)) * eps_model
                mutual_traj = list(combinations(pred_x0[:, :, :9].reshape(*pred_x0.shape[:2], 3, 3).permute(2, 1, 0, 3), 2))
                mutual_traj = torch.stack([torch.stack(c) for c in mutual_traj]).permute(1, 0, 3, 2, 4)
                pred_r_reciprocal = 1.0 / (mutual_traj[0] - mutual_traj[1]).norm(dim=-1).permute(1, 2, 0)
                pred_v2 = torch.square(pred_x0[:, :, 9:])
                pred = torch.concatenate([pred_r_reciprocal, pred_v2], dim=-1)


                mutual_traj = list(combinations(x0[:, :, :9].reshape(*x0.shape[:2], 3, 3).permute(2, 1, 0, 3), 2))
                mutual_traj = torch.stack([torch.stack(c) for c in mutual_traj]).permute(1, 0, 3, 2, 4)
                gt_r_reciprocal = 1.0 / (mutual_traj[0] - mutual_traj[1]).norm(dim=-1).permute(1, 2, 0)
                gt_v2 = torch.square(x0[:, :, 9:])
                gt = torch.concatenate([gt_r_reciprocal, gt_v2], dim=-1)

                loss2 = (pred - gt).square()

                loss_dict['jensen_gt'] = (loss_w * loss2).mean()

                loss = torch.concatenate([
                    loss1,
                    self.loss_weight * loss1.size(-1) / loss2.size(-1) * loss2
                ], dim=-1)

            elif self.loss == 'ablation':
                loss1 = (eps_model - eps_noise).square()
                loss_dict['score'] = (loss_w * loss1).mean()

                pred_x0 = xt - self.sigma_t[t].reshape(-1, *[1] * (xt.ndim - 1)) * eps_model
                mutual_traj = list(combinations(pred_x0[:, :, :9].reshape(*pred_x0.shape[:2], 3, 3).permute(2, 1, 0, 3), 2))
                mutual_traj = torch.stack([torch.stack(c) for c in mutual_traj]).permute(1, 0, 3, 2, 4)

                pred_r_reciprocal = 1.0 / (mutual_traj[0] - mutual_traj[1]).norm(dim=-1).permute(1, 2, 0)
                pred_v2 = torch.square(pred_x0[:, :, 9:])

                grav_energy = self.G1 * (- (self.m**2) * pred_r_reciprocal).sum(dim=-1)
                vel_energy  = self.G2 * 0.5 * self.m * pred_v2.sum(dim=-1)

                total_energy = grav_energy + vel_energy
                mean_energy = total_energy.mean(dim=-1, keepdim=True).detach()
                loss2 = (total_energy - mean_energy).square().unsqueeze(dim=-1)

                loss_dict['ablation_gt'] = (loss_w * loss2).mean()

                loss = torch.concatenate([
                    loss1,
                    self.loss_weight * loss1.size(-1) / loss2.size(-1) * loss2
                ], dim=-1)

            else:
                raise NotImplementedError
            
        else:
            eps_model, condition_expect = eps_model
            loss1 = (eps_model - eps_noise).square()
            loss_dict['score'] = (loss_w * loss1).mean()

            # gt_r2.shape == (batch_size, n_time, num_combinations = 3)
            mutual_traj = list(combinations(x0[:, :, :9].reshape(*x0.shape[:2], 3, 3).permute(2, 1, 0, 3), 2))
            mutual_traj = torch.stack([torch.stack(c) for c in mutual_traj]).permute(1, 0, 3, 2, 4)

            gt_r_reciprocal = 1.0 / (mutual_traj[0] - mutual_traj[1]).norm(dim=-1).permute(1, 2, 0)
            gt_v2 = torch.square(x0[:, :, 9:])
            loss2 = (condition_expect - torch.concatenate([gt_r_reciprocal, gt_v2], dim=-1)).square()
            loss_dict['energy_gt'] = (loss_w * loss2).mean()

            if self.loss == 'explicit_energy':
                grav_energy = self.G1 * (- (self.m**2) * condition_expect[:, :, :3]).sum(dim=-1)
                vel_energy  = self.G2 * 0.5 * self.m * condition_expect[:, :, 3:].sum(dim=-1)

                total_energy = grav_energy + vel_energy
                mean_energy = total_energy.mean(dim=-1, keepdim=True).detach()
                loss3 = (total_energy - mean_energy).square()
                loss_dict['energy'] = (loss_w * loss3).mean()

                loss = torch.concatenate([
                    loss1,
                    self.loss_weight[0] * loss1.size(-1) / loss2.size(-1) * loss2,
                    self.loss_weight[1] * loss1.size(-1) * loss3.unsqueeze(dim=-1)
                ], dim=-1)
                
            elif self.loss == 'implicit_energy': 
                loss = torch.concatenate([
                    loss1,
                    self.loss_weight * loss1.size(-1) / loss2.size(-1) * loss2
                ], dim=-1)

            elif self.loss == 'momentum_energy':
                alpha_t_condition = xt[:, :, 9:] - self.sigma_t[t].reshape(-1, *[1] * (xt.ndim - 1)) * eps_model[:, :, 9:]
                alpha_t_condition = alpha_t_condition.reshape(*xt.shape[:2], 3, 3).sum(dim=2)
                loss3 = (alpha_t_condition - alpha_t_condition.mean(dim=1, keepdim=True).detach()).square()

                loss = torch.concatenate([
                    loss1,
                    self.loss_weight[1] * loss1.size(-1) / loss2.size(-1) * loss2,
                    self.loss_weight[0] * loss1.size(-1) / loss3.size(-1) * loss3
                ], dim=-1)
                loss_dict['momentum'] = (loss_w * loss3).mean()



        loss_dict['total'] = (loss_w * loss).mean()
        return loss_dict, loss.size(0)

    @torch.no_grad()
    def ode_reverse(self, s: Tensor, t: Tensor, xt: Tensor) -> Tensor:
        eps_model = self.network_predict(xt, s.repeat(xt.size(0)))

        model_out = self.sigma_t[t] * torch.expm1(self.lambda_t[t] - self.lambda_t[s]) * eps_model
        return self.alpha_t[t] / self.alpha_t[s] * xt - model_out
    
    @torch.no_grad()
    def ddpm_reverse(self, s: Tensor, t: Tensor, xt: Tensor) -> Tensor:
        eps_model = self.network_predict(xt, s.repeat(xt.size(0)))
        eps_model = (1 - self.alpha_org[s]) / (1 - self.alpha_t[s]).sqrt() * eps_model

        mu = (xt - eps_model) / self.alpha_org[s].sqrt()
        sigma2 = (1 - self.alpha_t[t]) / (1 - self.alpha_t[s]) * (1 - self.alpha_org[s])
        # sigma2 = 1 - self.alpha_org[s]
        return mu + sigma2.sqrt() * torch.randn_like(mu)

    @torch.no_grad()
    def dpm2_reverse(self, s: Tensor, t: Tensor, xt: Tensor) -> Tensor:
        mid_t = torch.argmin(((self.lambda_t[s] + self.lambda_t[t]) / 2 - self.lambda_t).abs())
        h = self.lambda_t[t] - self.lambda_t[s]

        u = self.alpha_t[mid_t] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_t] * torch.expm1(h/2) * self.network_predict(xt, s.repeat(xt.size(0)))
        
        return self.alpha_t[t] / self.alpha_t[s] * xt \
            - self.sigma_t[t] * torch.expm1(h) * self.network_predict(u, mid_t.repeat(xt.size(0)))

    @torch.no_grad()
    def dpm3_reverse(self, s: Tensor, t: Tensor, xt: Tensor) -> Tensor:
        r1, r2, h = 1/3, 2/3, self.lambda_t[t] - self.lambda_t[s]
        mid_1 = torch.argmin((self.lambda_t[s] + r1 * h - self.lambda_t).abs())
        mid_2 = torch.argmin((self.lambda_t[s] + r2 * h - self.lambda_t).abs())

        eps_s = self.network_predict(xt, s.repeat(xt.size(0)))
        u1 = self.alpha_t[mid_1] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_1] * torch.expm1(r1 * h) * eps_s
        
        d1 = self.network_predict(u1, mid_1.repeat(xt.size(0))) - eps_s

        u2 = self.alpha_t[mid_2] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_2] * torch.expm1(r2 * h) * eps_s \
            - self.sigma_t[mid_2] * r2 / r1 * (torch.expm1(r2 * h) / (r2 * h) - 1) * d1
        
        d2 = self.network_predict(u2, mid_2.repeat(xt.size(0))) - eps_s

        return self.alpha_t[t] / self.alpha_t[s] * xt \
            - self.sigma_t[t] * torch.expm1(h) * eps_s \
            - self.sigma_t[t] / r2 * (torch.expm1(h) / h - 1) * d2
    
    @torch.no_grad()
    def ld_reverse(self, s: Tensor, t: Tensor, xt: Tensor) -> Tensor:
        sigma_t = self.sigma_t[s]
        eps_model = self.network_predict(xt, torch.ones(size=(xt.size(0), ), device=xt.device).to(torch.long))
        step_size = 0.001 * sigma_t.square()
        xt = xt + step_size * eps_model / sigma_t + torch.sqrt(step_size*2) * torch.rand_like(xt)
        return xt
    
    def sampling(self, data_shape: Tuple, method: str | None = None) -> Tensor:
        """Generate samples by iterating a reverse step over the schedule.

        Args:
            data_shape: shape of the tensor to generate (batch axis first).
            method: sampler for the 'vp' schedule — ``'ddpm'``, ``'ode'``,
                ``'dpm2_<n>'`` or ``'dpm3_<n>'`` where ``<n>`` is the number
                of solver steps. ``None`` falls back to
                ``self.sampling_cfg.method``. Ignored for the 've' schedule,
                which always uses Langevin dynamics.

        Returns:
            The generated tensor of shape ``data_shape``.

        Raises:
            NotImplementedError: if the method name is not recognised.
            AssertionError: if fewer than 3 steps are requested for ``dpm3``.
            ValueError: if fewer than 1 step is requested for ``dpm2``.
        """
        if self.config.diffusion == 've':
            step_fn = self.ld_reverse
            timestamps = torch.arange(0, self.steps)
        else:
            if method is None:
                method = self.sampling_cfg.method

            if method == 'ddpm':
                step_fn = self.ddpm_reverse
                # The trailing -1 gives the last transition (index 0 -> clean sample).
                timestamps = torch.arange(-1, self.steps)
            elif method == 'ode':
                step_fn = self.ode_reverse
                timestamps = torch.arange(0, self.steps)
            elif isinstance(method, str) and method.startswith('dpm3'):
                steps = int(method.split('_')[-1])
                assert steps > 2
                step_fn = self.dpm3_reverse
                timestamps = torch.linspace(0, self.steps - 1, steps + 1)
            elif isinstance(method, str) and method.startswith('dpm2'):
                steps = int(method.split('_')[-1])
                if steps < 1:
                    raise ValueError('dpm2 needs at least one solver step')
                step_fn = self.dpm2_reverse
                timestamps = torch.linspace(0, self.steps - 1, steps + 1)
            else:
                raise NotImplementedError('sampling method not found')

        # Walk from the noisiest index down; linspace values are truncated to
        # integer indices here.
        timestamps = timestamps.flip(0).int().to(self.device)

        if self.config.diffusion == 'vp':
            xt = torch.randn(size=data_shape, device=self.device)
        else:
            # VE prior: Gaussian scaled by the largest noise level.
            xt = self.sigma_t[-1] * torch.randn(size=data_shape, device=self.device)

        pbar = tqdm(zip(timestamps[:-1], timestamps[1:]), leave=False, total=len(timestamps)-1, dynamic_ncols=True)
        for s, t in pbar:
            xt = step_fn(s, t, xt)

        return xt





        