import torch
import torch.nn as nn
import torch.optim as optim
from openrlhf.models import Actor
from torch.utils.data import DataLoader

class MockStrategy:
    """Single-process stand-in for a training strategy object.

    Every method is a thin wrapper over plain PyTorch calls: gradients come
    from ``loss.backward()``, parameter updates from ``optimizer.step()``,
    collectives return their input unchanged, and the config getters return
    ``None``.
    """

    def __init__(self, args):
        """Store ``args`` and set the fixed single-process attributes.

        Args:
            args: Configuration object; must expose ``train_batch_size``.
        """
        self.args = args
        # NOTE(review): copied verbatim from train_batch_size; confirm that
        # callers expect this value rather than a ratio against a micro-batch
        # size.
        self.accumulated_gradient = args.train_batch_size
        self.world_size = 1

    def is_rank_0(self):
        """Return ``True``; this strategy only ever has one process."""
        return True

    def backward(
        self,
        loss: torch.Tensor,
        model: nn.Module,
        optimizer: optim.Optimizer,
        **kwargs,
    ) -> None:
        """Back-propagate ``loss``, adding to any gradients already present.

        Gradients are deliberately not cleared here, so several losses can be
        back-propagated into the same parameters before one
        ``optimizer_step`` call; clearing happens in ``optimizer_step``.

        Args:
            loss: Scalar tensor to differentiate.
            model: Unused; accepted for interface compatibility.
            optimizer: Unused; accepted for interface compatibility.
            **kwargs: Ignored.
        """
        loss.backward()

    def optimizer_step(
        self,
        optimizer: optim.Optimizer,
        model: nn.Module,
        scheduler,
        name="model",
        **kwargs,
    ) -> None:
        """Apply one optimizer update, advance the scheduler, clear gradients.

        Args:
            optimizer: Optimizer whose ``step`` is invoked.
            model: Unused; accepted for interface compatibility.
            scheduler: Learning-rate scheduler stepped after the optimizer,
                or ``None`` to skip scheduling.
            name: Unused; accepted for interface compatibility.
            **kwargs: Ignored.
        """
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        # Clear only after the update so that every backward() call made
        # since the previous step has contributed to it.
        optimizer.zero_grad()

    def all_reduce(self, data, op="mean"):
        """Return ``data`` unchanged; ``op`` is ignored."""
        return data

    def print(self, *msg):
        """Print ``msg`` when ``is_rank_0()`` is true."""
        if self.is_rank_0():
            print(*msg)

    def create_optimizer(self, model, **kwargs):
        """Build an ``AdamW`` optimizer over ``model.parameters()``.

        Args:
            model: Object exposing ``parameters()``.
            **kwargs: Forwarded to ``torch.optim.AdamW``.

        Returns:
            The newly created optimizer.
        """
        return torch.optim.AdamW(model.parameters(), **kwargs)

    def prepare(self, *args, is_rlhf=False):
        """Move the given models to the compute device.

        The device is ``"cuda"`` when CUDA is available and ``"cpu"``
        otherwise.

        Args:
            *args: Each item is either an object with a ``.to(device)``
                method, a tuple whose first element has one (the remaining
                elements are kept as they are), or ``None`` (passed through).
            is_rlhf: Unused; accepted for interface compatibility.

        Returns:
            The single prepared item when exactly one argument was given,
            otherwise a list of prepared items in argument order.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        ret = []
        for arg in args:
            if arg is None:
                ret.append(None)
            elif isinstance(arg, tuple):
                # Only the leading element is moved; the rest of the tuple
                # is returned untouched.
                ret.append((arg[0].to(device), *arg[1:]))
            else:
                ret.append(arg.to(device))
        return ret[0] if len(ret) == 1 else ret

    def setup_dataloader(
        self,
        replay_buffer,
        batch_size: int,
        pin_memory: bool = False,
        shuffle=True,
        collate_fn=None,
        drop_last=True,
        sampler=None,
        consumed_samples=0,
    ):
        """Wrap ``replay_buffer`` in a ``DataLoader``.

        Args:
            replay_buffer: Dataset to iterate over.
            batch_size: Number of samples per batch.
            pin_memory: Forwarded to ``DataLoader``.
            shuffle: Whether to shuffle the data. Only applied when no
                ``sampler`` is supplied and ``replay_buffer`` is not an
                ``IterableDataset``, because ``DataLoader`` rejects those
                combinations.
            collate_fn: Forwarded to ``DataLoader``.
            drop_last: Forwarded to ``DataLoader``.
            sampler: Optional sampler; takes precedence over ``shuffle``.
            consumed_samples: Unused; accepted for interface compatibility.

        Returns:
            The configured ``DataLoader``.
        """
        can_shuffle = sampler is None and not isinstance(
            replay_buffer, torch.utils.data.IterableDataset
        )
        return DataLoader(
            replay_buffer,
            batch_size=batch_size,
            shuffle=bool(shuffle) and can_shuffle,
            sampler=sampler,
            drop_last=drop_last,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
        )

    def _unwrap_model(self, model) -> nn.Module:
        """Strip wrapper layers from ``model``.

        An ``Actor`` is unwrapped through its ``model`` attribute
        (recursively); otherwise an object with a ``module`` attribute yields
        that attribute; anything else is returned as is.
        """
        if isinstance(model, Actor):
            return self._unwrap_model(model.model)
        if hasattr(model, "module"):
            return model.module
        return model

    def save_model(self, model: nn.Module, tokenizer=None, output_dir=None, **kwargs):
        """Unwrap ``model`` and call its ``save_pretrained(output_dir)``.

        Args:
            model: Model to save, possibly wrapped.
            tokenizer: Unused; accepted for interface compatibility.
            output_dir: Destination passed to ``save_pretrained``.
            **kwargs: Ignored.
        """
        model = self._unwrap_model(model)
        model.save_pretrained(output_dir)

    def get_ds_train_config(self, is_actor):
        """Return ``None``; ``is_actor`` is ignored."""
        return None

    def get_ds_eval_config(self, offload=False):
        """Return ``None``; ``offload`` is ignored."""
        return None