from collections import defaultdict
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
import torch_geometric.data as gd
from torch_scatter import scatter_add, scatter_logsumexp

from gflownet.models import PolicyNet



class BaseAlgorithm:
    """Common interface for algorithms that train a policy model from a batch.

    Concrete subclasses implement :meth:`update`: one optimisation step using
    ``self.optimizer`` (gradients clipped to ``self.grad_norm_clip``), returning
    a dict of scalar training metrics.
    """

    def __init__(self, optimizer, grad_norm_clip=1.0):
        # Optimizer that subclasses step once per gradient update.
        self.optimizer = optimizer
        # Maximum total gradient norm, intended for ``clip_grad_norm_``.
        self.grad_norm_clip = grad_norm_clip

    def update(self, model: PolicyNet, batch: gd.Batch) -> dict:
        """Run one training update of ``model`` on ``batch``.

        Returns:
            A dict mapping metric names to scalar values.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()


class TrajectoryBalance(BaseAlgorithm):
    """Trajectory-balance (TB) training objective.

    For every trajectory the loss is the squared difference between
    ``logZ + sum_t log P_F`` and ``log R + sum_t log P_B``, averaged over the
    trajectories in the batch.

    Args:
        optimizer: optimizer stepped once per :meth:`update`.
        grad_norm_clip: maximum gradient norm passed to ``clip_grad_norm_``.
        learn_backward: if True, backward log-probs come from the model's
            backward distribution; otherwise the precomputed per-trajectory
            ``batch.tr_bck_log_probs`` are used.
    """

    def __init__(self, optimizer, grad_norm_clip=1.0, learn_backward=False):
        super().__init__(optimizer, grad_norm_clip)
        self.learn_backward = learn_backward

    def update(self, model, batch):
        """Take one TB gradient step on ``batch``.

        Returns:
            dict with the scalar ``loss``, the total gradient norm before
            clipping (``grad_norm``) and the current ``logZ``.
        """
        # Call the module itself (not ``.forward``) so nn.Module hooks run,
        # consistent with the other algorithms in this file.
        fwd_dist, bck_dist, _ = model(batch)

        # Sum per-step forward log-probs into one value per trajectory.
        fwd_log_probs = fwd_dist.log_prob(batch.fwd_action_idxs)
        tr_fwd_log_probs = scatter_add(fwd_log_probs, batch.tbatch)
        fwd_log_flow = model.logZ + tr_fwd_log_probs

        if self.learn_backward:
            bck_log_probs = bck_dist.log_prob(batch.bck_action_idxs)
            # NOTE(review): backward steps are grouped by ``rbatch`` rather
            # than ``tbatch``; presumably a different step ordering — confirm.
            tr_bck_log_probs = scatter_add(bck_log_probs, batch.rbatch)
        else:
            tr_bck_log_probs = batch.tr_bck_log_probs
        bck_log_flow = batch.log_rewards + tr_bck_log_probs

        loss = (fwd_log_flow - bck_log_flow).square().mean()

        model.zero_grad(set_to_none=True)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.grad_norm_clip)
        self.optimizer.step()

        return {
            'loss': loss.item(),
            'grad_norm': grad_norm.item(),
            'logZ': model.logZ.item()
        }

class DetailedBalance(BaseAlgorithm):
    """Detailed-balance (DB) training objective.

    Regresses the log edge flow ``log F(s) + log P_F(a|s)`` onto a target built
    from ``batch.targets``. At the positions listed in ``batch.non_terminals``,
    the next state's log-flow is added to that target.
    NOTE(review): ``batch.targets`` presumably already holds log-rewards /
    backward log-probs — confirm against the batch builder.

    Args:
        optimizer: optimizer stepped once per :meth:`update`.
        grad_norm_clip: maximum gradient norm passed to ``clip_grad_norm_``.
        sampling_tau: Polyak coefficient for ``update_sampling_model``; required
            whenever ``sampling_model`` is given.
        sampling_model: optional slowly-updated copy of the model used to compute
            next-state flows. When None, the trained model itself is used
            (without gradients).
    """

    def __init__(self, optimizer, grad_norm_clip=1.0, sampling_tau=None, sampling_model=None):
        super().__init__(optimizer, grad_norm_clip)
        self.sampling_tau = sampling_tau
        self.sampling_model = sampling_model

    def update_sampling_model(self, model):
        """Polyak-average ``model`` into ``self.sampling_model`` in place.

        Each sampling parameter becomes ``tau * sampling + (1 - tau) * model``.
        """
        for a, b in zip(model.parameters(), self.sampling_model.parameters()):
            b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau))

    def update(self, model, batch):
        """Take one DB gradient step on ``batch``.

        Returns:
            dict with the scalar ``loss`` and the total gradient norm before
            clipping (``grad_norm``).
        """
        fwd_dist, _, log_flows = model(batch)
        log_flows = log_flows.flatten()
        log_PF = fwd_dist.log_prob(batch.fwd_action_idxs)
        log_edge_flow = log_flows + log_PF

        # Next-state flows come from the sampling (target) model when one is
        # configured, otherwise from the online model. In both cases no
        # gradients flow through them. Without this fallback the default
        # ``sampling_model=None`` would crash here.
        target_model = model if self.sampling_model is None else self.sampling_model
        with torch.no_grad():
            _, _, next_log_flows = target_model(batch.next_batch)
            next_log_flows = next_log_flows.flatten()

        # Out-of-place scatter: ``batch.targets`` is left untouched, so reusing
        # the batch does not accumulate next-state flows into it again.
        targets = batch.targets.scatter_add(0, batch.non_terminals, next_log_flows)

        loss = (log_edge_flow - targets).square().mean()

        model.zero_grad(set_to_none=True)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.grad_norm_clip)
        self.optimizer.step()

        if self.sampling_model is not None:
            self.update_sampling_model(model)

        return {
            'loss': loss.item(),
            'grad_norm': grad_norm.item(),
        }


class PPOAlgorithm(BaseAlgorithm):
    """Proximal Policy Optimization with GAE-style advantages.

    The batch is treated as a concatenation of trajectories: ``batch.tbatch``
    gives each state's trajectory id, and ``batch.log_rewards`` holds one
    log-reward per trajectory. The reward is only given on each trajectory's
    final transition.

    Args:
        optimizer: optimizer stepped once per epoch in :meth:`update`.
        grad_norm_clip: maximum gradient norm passed to ``clip_grad_norm_``.
        gamma: discount factor.
        gae_lamb: GAE lambda.
        n_epochs: number of optimisation passes over the same batch.
        value_coef: weight of the value-function MSE loss.
        entropy_coef: weight of the (negative) entropy bonus.
        clip_range: PPO ratio clipping epsilon.
        normalize_advantages: standardise advantages before the policy loss.
    """

    def __init__(
            self,
            optimizer,
            grad_norm_clip=1.0,
            gamma=0.99,
            gae_lamb=0.95,
            n_epochs=5,
            value_coef=0.5,
            entropy_coef=0.1,
            clip_range=0.2,
            normalize_advantages=True
        ):
        super().__init__(optimizer, grad_norm_clip)
        self.gamma = gamma
        self.gae_lamb = gae_lamb
        self.n_epochs = n_epochs
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.clip_range = clip_range
        self.normalize_advantages = normalize_advantages

    def compute_returns_and_advantage(self, state_values: Tensor, batch: gd.Batch):
        """Compute value targets and per-state advantages.

        Args:
            state_values: 1-D tensor of V(s_t), one entry per state in ``batch``.
            batch: must provide ``tbatch`` (trajectory id of each state) and
                ``log_rewards`` (one per trajectory). ``tbatch`` is assumed
                sorted and contiguous, because trajectory boundaries are derived
                with ``bincount``/``cumsum``.

        Returns:
            ``(returns, advantages)``.
            - ``returns[t] = gamma**(T - t) * R``, where ``T`` is the last index
              of t's trajectory and ``R`` its reward.
            - ``advantages`` follow the GAE recursion
              ``A_t = delta_t + gamma * lambda * A_{t+1}``. Each trajectory of
              length ``L`` is rescaled by ``(1 - lambda) / (1 - lambda**L)``,
              i.e. ``1 / L`` when ``lambda == 1``. With ``lambda == 1`` the
              unscaled advantage equals ``returns - state_values``.
        """
        rewards = torch.exp(batch.log_rewards)
        batch_sizes = torch.bincount(batch.tbatch)
        # Index of the last state of every trajectory.
        ptr = torch.cumsum(batch_sizes, 0) - 1

        # Steps remaining until the end of each state's trajectory.
        positions = torch.arange(batch.tbatch.numel(), device=batch.tbatch.device)
        exponents = ptr[batch.tbatch] - positions
        returns = (self.gamma ** exponents) * rewards[batch.tbatch]

        # One-step TD errors. The final transition receives R and bootstraps
        # from nothing, so delta_T = R - V(s_T). This matches returns[T] = R.
        next_values = state_values.roll(-1)
        next_values[ptr] = 0.0
        deltas = self.gamma * next_values - state_values
        deltas[ptr] = deltas[ptr] + rewards.to(deltas.dtype)

        advantages = torch.zeros_like(deltas)
        decay = self.gamma * self.gae_lamb
        start = 0
        for end in ptr.tolist():
            length = end - start + 1
            if length <= 0:
                continue  # trajectory id with no states; nothing to fill
            # Rescale so the weights lambda**k, k < length, sum to one.
            # lambda == 1 uses the limit 1/length; the closed form would be 0/0.
            if self.gae_lamb == 1.0:
                bias = 1.0 / length
            else:
                bias = (1.0 - self.gae_lamb) / (1.0 - self.gae_lamb ** length)
            running = 0.0
            for t in range(end, start - 1, -1):
                running = deltas[t] + decay * running
                advantages[t] = bias * running
            start = end + 1

        return returns, advantages

    def update(self, model, batch):
        """Run ``n_epochs`` clipped-surrogate PPO steps on ``batch``.

        Old log-probs, values, returns and advantages are computed once, without
        gradients, before the first epoch.

        Returns:
            dict of per-epoch metrics averaged over epochs, or an empty dict
            when ``n_epochs`` is 0.
        """
        with torch.no_grad():
            old_dist, _, old_values = model(batch)
            old_log_probs = old_dist.log_prob(batch.fwd_action_idxs)
            returns, advantages = self.compute_returns_and_advantage(old_values.flatten(), batch)

        # std() of a single element is NaN, hence the size guard.
        if self.normalize_advantages and batch.num_graphs > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        infos = []
        for _ in range(self.n_epochs):
            dist, _, value_pred = model(batch)
            log_probs = dist.log_prob(batch.fwd_action_idxs)

            log_ratios = log_probs - old_log_probs
            ratios = torch.exp(log_ratios)

            surr1 = ratios * advantages
            surr2 = torch.clip(ratios, 1.0 - self.clip_range, 1.0 + self.clip_range) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(value_pred.flatten(), returns)
            entropy_loss = -dist.entropy().mean()
            loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss

            model.zero_grad(set_to_none=True)
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.grad_norm_clip)
            self.optimizer.step()

            with torch.no_grad():
                # Approximate KL from the log-ratios directly. Unlike
                # ratios.log(), this cannot yield -inf when a ratio underflows.
                approx_kl_div = torch.mean((ratios - 1) - log_ratios)

            infos.append({
                "loss": loss.item(),
                "policy_loss": policy_loss.item(),
                "value_loss": value_loss.item(),
                "entropy_loss": entropy_loss.item(),
                "approx_kl_div": approx_kl_div.item(),
                "grad_norm": grad_norm.item()
            })

        if not infos:
            return {}
        return {k: np.mean([d[k] for d in infos]) for k in infos[0]}