import torch
from torch.optim import Adam
from src.dataset import SingleImageDataset
from tqdm import tqdm

from src.model import FlowMatching

class Trainer(object):
    """Optimise a ``FlowMatching`` model on the tensors of a single-image dataset.

    Every training step feeds the complete flattened image and position
    tensors to the model, back-propagates the loss it returns and applies one
    Adam update. ``save`` / ``load`` checkpoint the step counter, the model
    weights and the optimizer state.

    Attributes set by ``__init__``:
        model: the model, moved to ``device`` and cast to ``torch.float``.
        optimizer: ``Adam`` over ``model.parameters()`` with ``train_lr``.
        step: number of optimisation steps completed so far.
        training_losses: one Python float per completed step.
        metric: empty dict; not used by any method of this class.
    """

    def __init__(
        self,
        flow_matching_model: "FlowMatching",
        train_dataset: "SingleImageDataset",
        train_lr = 1e-4,
        train_num_steps = 100000,
        noise_sigma=0,
        device=torch.device('cpu')
    ):
        """Move the model to ``device`` and create the optimizer.

        Args:
            flow_matching_model: model to train. It is called as
                ``model(x1=..., position=...)`` and must return a 2-tuple
                whose first element is the scalar loss.
            train_dataset: object exposing ``training_image`` and
                ``training_position`` tensors.
            train_lr: learning rate handed to ``Adam``.
            train_num_steps: value of ``self.step`` at which ``train`` stops.
            noise_sigma: scale of the Gaussian noise added to the image
                tensor at every step (``0`` adds zeros).
            device: device the model and the training tensors are moved to.
        """
        self.model = flow_matching_model.to(device).to(torch.float)
        self.train_lr = train_lr
        self.train_num_steps = train_num_steps
        self.device = device
        self.noise_sigma = noise_sigma

        self.train_dataset = train_dataset

        self.optimizer = Adam(self.model.parameters(), lr=self.train_lr)
        self.step = 0

        self.training_losses = []
        self.metric = {}

        # NOTE: the count is divided by 1024**2 (not 10**6) before printing.
        model_size = sum(param.data.nelement() for param in self.model.parameters())
        print("Model params: %.2f M" % (model_size / 1024 / 1024))

    def train(self):
        """Run full-batch optimisation until ``self.step`` reaches ``train_num_steps``.

        The loss of every step is appended to ``self.training_losses`` and
        shown in the progress bar. Calling ``train`` again resumes from the
        current ``self.step``.
        """
        self.model.train()
        pbar_postfix = {}

        # Flatten once up front; the same tensors are reused for every step.
        # assumes 3 values per image sample and 2 per position — fixed by the
        # reshape calls below; their meaning (e.g. RGB / xy) is not visible here.
        x1 = self.train_dataset.training_image.to(self.device).to(torch.float).reshape(-1, 3)
        pos = self.train_dataset.training_position.to(self.device).to(torch.float).reshape(-1, 2)

        with tqdm(total=self.train_num_steps, initial=self.step, desc='Training...', ncols=0) as pbar:
            while self.step < self.train_num_steps:
                # NOTE(review): presumably redundant with the call above; kept
                # because FlowMatching.forward is not visible from here.
                self.model.train()

                # Fresh Gaussian noise is drawn for every step.
                noisy_x1 = x1 + torch.randn_like(x1) * self.noise_sigma
                loss, _ = self.model(x1=noisy_x1, position=pos)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                # Read the scalar once: each .item() call synchronises with
                # the device, so reuse the value for both consumers.
                loss_value = loss.item()
                pbar_postfix['loss'] = loss_value
                self.training_losses.append(loss_value)

                self.step += 1
                pbar.set_postfix(pbar_postfix)
                pbar.update(1)

    def save(self, milestone):
        """Write a CPU checkpoint of step, model and optimizer to ``milestone``.

        Args:
            milestone: path or file-like object accepted by ``torch.save``.

        The stored dict has the keys ``'step'``, ``'model'`` and
        ``'optimizer'``; every tensor in it lives on the CPU.
        """
        model_state_cpu = {k: v.detach().cpu() for k, v in self.model.state_dict().items()}

        # Optimizer.state_dict() hands back the optimizer's *live*
        # per-parameter state dicts. Writing CPU tensors into them would move
        # the running optimizer state off the training device, so build fresh
        # dicts for the checkpoint and leave the live state untouched.
        opt_state = dict(self.optimizer.state_dict())
        if 'state' in opt_state:
            opt_state['state'] = {
                param_id: {
                    k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
                    for k, v in param_state.items()
                }
                for param_id, param_state in opt_state['state'].items()
            }

        data = {
            'step': self.step,
            'model': model_state_cpu,
            'optimizer': opt_state
        }
        torch.save(data, milestone)

    def load(self, milestone):
        """Restore step, model weights and optimizer state from ``milestone``.

        Args:
            milestone: path or file-like object accepted by ``torch.load``,
                holding a checkpoint in the layout written by ``save``.

        Raises:
            KeyError: if ``'step'``, ``'model'`` or ``'optimizer'`` is missing
                from the checkpoint.
        """
        # NOTE: torch.load is pickle-based; only load checkpoints you trust.
        data = torch.load(milestone, map_location='cpu')

        self.step = data['step']
        self.model.load_state_dict(data['model'])
        self.model.to(self.device)

        self.optimizer.load_state_dict(data['optimizer'])

        # Make sure every tensor in the optimizer state sits on the training
        # device (self.device may be a torch.device or a device string).
        dev = torch.device(self.device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(dev)
