import math
import torch
from torch.utils.data.dataloader import DataLoader
import numpy as np
import wandb
from .timer import Timer
from tqdm import tqdm
from src.envs.toy_car.toy_car import ToyCar
def to(xs, device):
    """Move every element of ``xs`` onto ``device`` and return them as a list.

    Args:
        xs: Iterable whose elements are either dicts (whose values expose
            ``.to(device)``) or objects that expose ``.to(device)`` directly.
        device: Target passed unchanged to each ``.to`` call.

    Returns:
        A new list, in the same order as ``xs``. Dict elements are updated
        in place (each value is replaced by ``value.to(device)``) and the
        same dict object is placed in the result; every other element is
        replaced by the return value of ``element.to(device)``.
    """

    def _move(item):
        # Non-dict elements are moved directly.
        if not isinstance(item, dict):
            return item.to(device)
        # Dict elements are mutated in place so callers holding a reference
        # to the dict see the moved values too.
        for key in item:
            item[key] = item[key].to(device)
        return item

    return [_move(item) for item in xs]


class Trainer:
    """Training loop for a single model, followed by a sub-goal evaluation.

    The trainer keeps state across calls to :meth:`train`: the number of
    completed epochs, the running token count that drives the learning-rate
    schedule, and the lazily created optimizer.
    """

    def __init__(self, config):
        """Store ``config`` and initialise the counters.

        Args:
            config: Object read for ``device`` here and, during training, for
                ``batch_size``, ``num_workers``, ``grad_norm_clip``,
                ``lr_decay``, ``warmup_tokens``, ``final_tokens`` and
                ``learning_rate``.
        """
        self.config = config
        self.device = config.device

        self.n_epochs = 0
        self.n_tokens = 0  # counter used for learning rate decay
        self.optimizer = None

    def get_optimizer(self, model):
        """Return the cached optimizer, creating it from ``model`` on first use.

        The optimizer is built by ``model.configure_optimizers(self.config)``
        and reused on every later call, regardless of the ``model`` passed.
        """
        if self.optimizer is None:
            print(f'[ utils/training ] Making optimizer at epoch {self.n_epochs}')
            self.optimizer = model.configure_optimizers(self.config)
        return self.optimizer

    def _update_learning_rate(self, optimizer, mask):
        """Advance the token counter once and compute the scheduled LR.

        Args:
            optimizer: Object with ``param_groups``; its groups' ``'lr'`` is
                overwritten only when ``config.lr_decay`` is truthy.
            mask: Tensor whose sum is added to ``self.n_tokens``.

        Returns:
            ``(lr, lr_mult)``: the scheduled learning rate and its multiplier.
        """
        config = self.config

        # Count each batch exactly once. ``.item()`` keeps the counter a
        # plain Python number instead of a tensor.
        self.n_tokens += mask.sum().item()

        if self.n_tokens < config.warmup_tokens:
            # linear warmup
            lr_mult = float(self.n_tokens) / float(max(1, config.warmup_tokens))
        elif config.lr_decay:
            # cosine learning rate decay, floored at 10% of the base rate
            progress = float(self.n_tokens - config.warmup_tokens) / float(
                max(1, config.final_tokens - config.warmup_tokens))
            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
        else:
            lr_mult = 1.0
        lr = config.learning_rate * lr_mult

        if config.lr_decay:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        # NOTE(review): without lr_decay the warmup value is only reported,
        # never written to the optimizer (unchanged from the previous
        # behaviour) -- confirm whether warmup should apply in that case.
        return lr, lr_mult

    def _evaluate_subgoals(self, model, subgoals_segmented, obs_segmented,
                           n_rollouts=100, horizon=40):
        """Print predicted vs. reference sub-goals and their mean squared error.

        The model is switched to eval mode and left there. The rollout and
        step counts are clamped to the available data so shorter arrays do
        not raise ``IndexError``.

        Args:
            model: Callable returning ``(action, pred_subgoal)`` for an
                observation.
            subgoals_segmented: Indexed as ``[i, t, :]``; assumed to expose
                ``len()`` and ``.shape`` (numpy array or tensor) -- TODO confirm.
            obs_segmented: Indexed as ``[i, 0, :]``; its first four entries
                seed the environment state.
            n_rollouts: Maximum number of segments to evaluate.
            horizon: Maximum number of steps compared per segment.
        """
        model.train(False)
        env = ToyCar(stop_sign=30, seed=0)

        n_rollouts = min(n_rollouts, len(obs_segmented), len(subgoals_segmented))
        horizon = min(horizon, subgoals_segmented.shape[1])

        # Inference only: do not build autograd graphs for these forward passes.
        with torch.no_grad():
            for i in tqdm(range(n_rollouts)):
                env.reset(testing=True)

                observation = obs_segmented[i, 0, :]
                env.ego_x = observation[0]
                env.ego_vel = observation[1]
                env.other_x = observation[2]
                env.other_vel = observation[3]

                # NOTE(review): the environment is never stepped and
                # ``observation`` is not updated inside this loop, so the
                # model sees the same input at every ``t`` -- confirm intent.
                for t in range(horizon):
                    _, pred_subgoal = model(observation)
                    print('pred: ', pred_subgoal)

                    true_subgoal = subgoals_segmented[i, t, :]
                    print('true: ', true_subgoal)

                    print(np.square(np.subtract(true_subgoal, pred_subgoal)).mean())
                    print('--------------------------')

    def train(self, model, dataset, subgoals_segmented, obs_segmented, n_epochs=1, log_freq=100):
        """Train ``model`` for ``n_epochs`` epochs, then run the sub-goal evaluation.

        Each batch is moved to ``self.device``; ``model(batch[0])`` is fed to
        ``model.compute_loss(outputs, *batch)``, gradients are clipped to
        ``config.grad_norm_clip`` and the optimizer is stepped. ``batch[-1]``
        is summed to advance the token counter used by the LR schedule.

        The mean and standard deviation of the losses are sent to ``wandb``
        once per epoch; an epoch that produced no batches logs nothing.

        Args:
            model: Module providing ``configure_optimizers``, ``compute_loss``,
                ``train``, ``zero_grad`` and ``parameters``.
            dataset: Dataset handed to ``DataLoader``.
            subgoals_segmented: Reference sub-goals for the evaluation pass.
            obs_segmented: Observations for the evaluation pass.
            n_epochs: Number of passes over ``dataset``.
            log_freq: A progress line is printed every ``log_freq`` iterations.
        """
        config = self.config
        optimizer = self.get_optimizer(model)
        model.train(True)

        loader = DataLoader(dataset, shuffle=True, pin_memory=True,
                            batch_size=config.batch_size,
                            num_workers=config.num_workers)

        for _ in range(n_epochs):

            losses = []
            timer = Timer()
            for it, batch in enumerate(loader):

                batch = to(batch, self.device)

                # forward the model
                with torch.set_grad_enabled(True):
                    outputs = model(batch[0])
                    loss = model.compute_loss(outputs, *batch)
                    loss_value = loss.item()
                    losses.append(loss_value)

                # backprop and update the parameters
                model.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                optimizer.step()

                # count this batch's tokens and refresh the learning rate
                lr, lr_mult = self._update_learning_rate(optimizer, batch[-1])

                # report progress
                if it % log_freq == 0:
                    print(
                        f'[ utils/training ] epoch {self.n_epochs} [ {it:4d} / {len(loader):4d} ] ',
                        f'train loss {loss_value:.5f} | lr {lr:.3e} | lr_mult: {lr_mult:.4f} | '
                        f't: {timer():.2f}')

            self.n_epochs += 1

            # Log per epoch so every epoch is reported and an empty run does
            # not reference an undefined or empty loss list.
            if losses:
                wandb.log({
                    'training/train_loss_mean': np.mean(losses),
                    'training/train_loss_std': np.std(losses),
                })

        self._evaluate_subgoals(model, subgoals_segmented, obs_segmented)

    


class Hier_Trainer:

    def __init__(self, config):
        self.config = config
        self.device = config.device

        self.n_epochs = 0
        self.n_tokens = 0 # counter used for learning rate decay
        self.optimizer = None
        self.command_optimizer = None

    def get_optimizer(self, model):
        if self.optimizer is None:
            print(f'[ utils/training ] Making optimizer at epoch {self.n_epochs}')
            self.optimizer = model.configure_optimizers(self.config)
        return self.optimizer
    
    def get_command_optimizer(self, model):
        if self.command_optimizer is None:
            print(f'[ utils/training ] Making command optimizer at epoch {self.n_epochs}')
            self.command_optimizer = model.configure_optimizers(self.config)
        return self.command_optimizer

    def train(self, model, command_model, dataset, n_epochs=1, log_freq=100):

        config = self.config

        optimizer = self.get_optimizer(model)
        model.train(True)

        command_optimizer = self.get_command_optimizer(command_model)
        command_model.train(True)

        

        loader = DataLoader(dataset, shuffle=True, pin_memory=True,
                            batch_size=config.batch_size,
                            num_workers=config.num_workers)

        
        for _ in range(n_epochs):

            losses = []
            command_losses = []
            timer = Timer()
            for it, batch in enumerate(loader):

                batch = to(batch, self.device)
               
                # forward the model
                with torch.set_grad_enabled(True):

                    command_outputs = command_model(batch[0])
                    command_loss = command_model.compute_loss(command_outputs, *batch)
                    command_losses.append(command_loss.item())

                 

                    outputs = model(batch[0])
                    loss = model.compute_loss(outputs, *batch)
                    losses.append(loss.item())

                # print('optimizers: ', optimizer) # 2
                # backprop and update the parameters
                command_model.zero_grad()
                model.zero_grad()
                # print('update loss: ', loss)
                command_loss.backward()
                loss.backward()

                torch.nn.utils.clip_grad_norm_(command_model.parameters(), config.grad_norm_clip)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                command_optimizer.step()
                optimizer.step()

                mask = batch[-1]
                self.n_tokens += mask.sum() # number of tokens processed this step
                # decay the learning rate based on our progress
                if config.lr_decay:
                    mask = batch[-1]
                    self.n_tokens += mask.sum() # number of tokens processed this step
                    if self.n_tokens < config.warmup_tokens:
                        # linear warmup
                        lr_mult = float(self.n_tokens) / float(max(1, config.warmup_tokens))
                    else:
                        # cosine learning rate decay
                        progress = float(self.n_tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                        lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                    lr = config.learning_rate * lr_mult
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                else:
                    if self.n_tokens < config.warmup_tokens:
                        # linear warmup
                        lr_mult = float(self.n_tokens) / float(max(1, config.warmup_tokens))
                    else:
                        lr_mult = 1.0
                    lr = lr_mult * config.learning_rate

                # report progress
                if it % log_freq == 0:
                    print(
                        f'[ utils/training ] epoch {self.n_epochs} [ {it:4d} / {len(loader):4d} ] ',
                        f'train loss {loss.item():.5f} | command train loss {command_loss.item():.5f} | lr {lr:.3e} | lr_mult: {lr_mult:.4f} | '
                        f't: {timer():.2f}')

            self.n_epochs += 1
        logs = dict()
        logs['training/train_loss_mean'] = np.mean(losses)
        logs['training/train_loss_std'] = np.std(losses)
        logs['training/train_command_loss_mean'] = np.mean(command_losses)
        logs['training/train_command_loss_std'] = np.std(command_losses)

        wandb.log(logs)


       
