import wandb
import torch
from accelerate import Accelerator
import torch.nn.functional as F

class AlignRunner:
    """Run one pairwise preference-alignment step per call on a batch.

    Each call computes a pairwise loss between the sequences of a batch from
    policy/reference log-probs returned by ``net.get_logp``, optionally steps
    the optimizer, logs metrics to wandb from the local main process, and
    every ``config.save_per_step`` batches saves ``net.policy`` weights.

    Args:
        net: model exposing ``get_logp(inputs, prompt_len, which)`` (``which``
            is ``"policy"`` or ``"ref"``) and a ``policy`` submodule with
            ``save_pretrained``.
        accelerator: ``accelerate.Accelerator`` used for gradient
            accumulation, backward, gradient clipping, gathering and
            state-dict extraction.
        batch_count: number of batches already processed; drives the
            checkpoint schedule and is incremented on every call.
        stage: ``"train"`` puts ``net`` in train mode and enables optimizer
            steps; any other value puts ``net`` in eval mode.
        optimizer: stepped after backward when given and ``stage == "train"``.
        tokenizer: stored on the instance; not used inside this class.
        lr_scheduler: stepped after the optimizer when given.
        config: must provide ``length_norm``, ``mask_rate``,
            ``save_per_step`` and ``ckpt_path``.
    """

    def __init__(self, net, accelerator, batch_count=0, stage="train",
                 optimizer=None, tokenizer=None, lr_scheduler=None, config=None):
        self.net, self.stage = net, stage
        self.optimizer, self.lr_scheduler = optimizer, lr_scheduler
        self.accelerator = accelerator
        self.batch_count = batch_count
        self.tokenizer = tokenizer
        self.config = config

        if self.stage == "train":
            self.net.train()
        else:
            self.net.eval()

    @staticmethod
    def _importance_weight(policy_logp, ref_logp):
        """Return ``min(p_policy / p_ref, 1)`` element-wise, detached from the graph.

        Computed as ``exp(policy_logp - ref_logp)`` instead of
        ``exp(policy_logp) / exp(ref_logp)``: the latter underflows both
        exponentials to 0 for very negative log-probs and yields 0/0 = NaN.
        """
        return torch.exp(policy_logp - ref_logp).clamp(max=1).detach()

    @staticmethod
    def _pair_loss(logp_sum, mask_rate):
        """Return the pairwise loss matrix for a column vector ``logp_sum``.

        Expects ``logp_sum`` of shape (B, 1); entry (i, j) of the (B, B)
        result is ``-(logp_sum[i] - logp_sum[j])``. With probability
        ``mask_rate`` one side of the difference is multiplied by 0, so only
        the column term (j) or only the row term (i) carries gradient; the
        ``0 * x`` form keeps the broadcasted (B, B) shape.
        """
        if torch.rand(1) >= mask_rate:
            return -(logp_sum - logp_sum.transpose(1, 0))
        if torch.rand(1) > 0.5:
            return -(0 * logp_sum - logp_sum.transpose(1, 0))
        return -(logp_sum - 0 * logp_sum.transpose(1, 0))

    @staticmethod
    def _preference_loss(pair_loss, rewards):
        """Average ``pair_loss`` over ordered pairs (i, j) with rewards[i] > rewards[j].

        ``rewards`` is a column vector (B, 1). When all rewards tie there is
        no valid pair; the denominator is clamped to 1 so the loss is 0
        instead of 0/0 = NaN (which would poison the gradients).
        """
        reward_mask = (rewards - rewards.transpose(1, 0)) > 1e-6
        return (pair_loss * reward_mask.float()).sum() / reward_mask.sum().clamp(min=1)

    @staticmethod
    def _policy_state_dict(state_dict, prefix="policy."):
        """Return the entries of ``state_dict`` under ``prefix`` with the prefix removed.

        Uses an exact prefix match so keys that merely contain "policy"
        elsewhere (e.g. ``ref.policy_head.weight``) are not included.
        """
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def __call__(self, batch):
        """Process one batch and return ``(step_losses, step_metrics)``.

        ``batch`` is unpacked as ``(inputs, rewards, prompt_len)``; ``inputs``
        must contain ``"attention_mask"`` and ``rewards`` must convert to a
        tensor holding one value per sequence.
        """
        with self.accelerator.accumulate(self.net):
            inputs, rewards, prompt_len = batch

            if self.config.length_norm:
                # response-token count per sequence; clamp guards empty responses against 0/0
                token_count = (
                    inputs["attention_mask"][:, prompt_len:]
                    .float()
                    .sum(-1, keepdims=True)
                    .clamp(min=1)
                )
            else:
                token_count = 1

            # presumably per-token log-probs of shape (B, T) — TODO confirm against net.get_logp
            policy_logp, entropy = self.net.get_logp(inputs, prompt_len, "policy")
            ref_logp = self.net.get_logp(inputs, prompt_len, "ref")
            # follow the log-probs' device rather than a hard-coded "cuda"
            rewards = torch.as_tensor(rewards, device=policy_logp.device).reshape(-1, 1).float()

            target_p = self._importance_weight(policy_logp, ref_logp)
            logp = (
                target_p * policy_logp
                + (1 - target_p + 1e-6) / (1 - policy_logp.exp().detach() + 1e-6) * entropy
            )
            logp_sum = logp.sum(-1, keepdims=True) / token_count

            # (length-normalised) policy-vs-reference log-ratio; used for printing/metrics only
            scores = (
                policy_logp.sum(-1, keepdims=True) / token_count
                - ref_logp.sum(-1, keepdims=True) / token_count
            )

            loss = self._preference_loss(
                self._pair_loss(logp_sum, self.config.mask_rate), rewards
            )

            if self.accelerator.is_local_main_process:
                print("\n")
                print(scores)
                print(torch.min(target_p, axis=-1)[0])

            if self.optimizer is not None and self.stage == "train":
                self.accelerator.backward(loss)
                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(self.net.parameters(), 1.0)
                # NOTE(review): assumes optimizer/scheduler were prepared by the
                # accelerator so these no-op on non-sync accumulation steps — confirm
                self.optimizer.step()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                self.optimizer.zero_grad()

        _loss = self.accelerator.gather(loss.detach()).mean()
        # NOTE(review): assumes row 0 is the preferred and row 1 the rejected
        # sequence of each batch — confirm with the data collator
        _acc = self.accelerator.gather((scores[0, 0] > scores[1, 0]).float()).mean()
        _win = self.accelerator.gather(scores[0]).mean()
        _loose = self.accelerator.gather(scores[1]).mean()
        _gap = _win - _loose

        step_metrics = {
            "loss": _loss.item(),
            "_win": _win.item(),
            "_loose": _loose.item(),
            "_gap": _gap.item(),
            "_acc": _acc.item(),
        }

        step_losses = {self.stage + "_loss": _loss.item()}
        if self.accelerator.is_local_main_process:
            wandb.log(step_metrics)

        if (self.batch_count + 1) % self.config.save_per_step == 0:
            self.save_ckpt(
                f"{self.config.ckpt_path}_{self.batch_count + 1}",
                accelerator=self.accelerator,
            )

        self.batch_count += 1
        return step_losses, step_metrics

    def save_ckpt(self, path, accelerator=None):
        """Save the ``net.policy`` weights to ``path`` with ``save_pretrained``.

        Every process waits at a barrier and calls ``get_state_dict``; only
        the main process filters the ``policy.`` entries and writes them
        (non-safetensors format).

        Args:
            path: destination passed to ``save_pretrained``.
            accelerator: accelerator to use; defaults to ``self.accelerator``.
        """
        if accelerator is None:
            accelerator = self.accelerator

        accelerator.wait_for_everyone()
        state_dict = accelerator.get_state_dict(self.net)

        if accelerator.is_main_process:
            policy_dict = self._policy_state_dict(state_dict)
            self.net.policy.save_pretrained(path, state_dict=policy_dict, safe_serialization=False)
