import os
import copy

import gym
import numpy as np
import torch
import torch.nn.functional as F
import random
from tqdm import tqdm
import time
import logging

def data_iter(batch_size, x):
    """Return a factory that, each time it is called, yields one shuffled epoch of ``x``.

    Args:
        batch_size: maximum number of examples per batch; the last batch may be smaller.
        x: dataset exposing a ``length`` attribute and supporting indexing with a
            1-D ``torch`` tensor of indices.

    Returns:
        A zero-argument callable. Every call reshuffles the example order with the
        ``random`` module and returns a generator over ``x[indices]`` batches.
    """
    def _epoch():
        size = x.length
        order = list(range(size))
        random.shuffle(order)
        # Consecutive slices of the shuffled order; slicing past the end simply
        # truncates, so the final chunk holds the remainder.
        chunks = (order[start:start + batch_size] for start in range(0, size, batch_size))
        for chunk in chunks:
            yield x[torch.tensor(chunk)]

    return _epoch

def cycle(dl):
    """Yield items from ``dl()`` forever, calling ``dl`` again to start each new pass.

    Args:
        dl: zero-argument callable returning a fresh iterable for every pass
            (e.g. the factory returned by ``data_iter``).

    Raises:
        ValueError: if a complete pass over ``dl()`` yields no items. Without this
            check the ``while True`` loop would spin forever without yielding, and
            the caller's ``next()`` would hang instead of reporting the problem.
    """
    while True:
        produced = False
        for data in dl():
            produced = True
            yield data
        if not produced:
            raise ValueError("cycle(): data loader produced no items; cannot cycle an empty dataset")

class EMA:
    """Exponential moving average (EMA) of model parameters.

    Every update moves the averaged value toward the current one:
    ``avg <- avg * beta + (1 - beta) * current``.
    """

    def __init__(self, beta):
        super().__init__()
        # Decay factor: the closer to 1, the more slowly the average follows the model.
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend each parameter of ``current_model`` into the matching one of ``ma_model``.

        Parameters are paired positionally via ``.parameters()``; each averaged
        parameter's ``.data`` is replaced by the newly computed average.
        """
        pairs = zip(ma_model.parameters(), current_model.parameters())
        for averaged, live in pairs:
            averaged.data = self.update_average(averaged.data, live.data)

    def update_average(self, old, new):
        """Return one EMA step from ``old`` toward ``new``; ``new`` itself if ``old`` is None."""
        return new if old is None else old * self.beta + (1 - self.beta) * new
class Trainer(object):
    """Train a diffusion model and a reward model from one shared dataset.

    Both models are optimized with Adam. ``self.ema_model`` is a deep copy of the
    diffusion model that tracks it via ``EMA``. Log files and checkpoints are written
    under ``<cwd>/<envname>``.

    The dataset must expose ``length``, tensor indexing and ``to_tensor()``. Batches
    must expose ``outputs``, ``conditions``, ``rewards`` and ``terminals``.
    """

    def __init__(
        self,
        diffusion_model,
        rw_model,
        dataset,
        ema_decay=0.995,
        train_batch_size=256,
        train_lr=3e-4,
        gradient_accumulate_every=2,
        step_start_ema=2000,
        update_ema_every=10,
        log_freq=10000,
        sample_freq=1000,
        save_freq=100000,
        label_freq=100000,
        save_parallel=False,
        n_reference=8,
        bucket=None,
        train_device='cuda',
        save_checkpoints=False,
        envname='dummy_env',
        savename=None,
        weight_decay=0.0,
    ):
        """Set up models, optimizers, the output directory and the infinite dataloader.

        Args:
            diffusion_model: model exposing ``loss(x, conditions) -> (loss, infos)``.
            rw_model: reward model called as ``rw_model(inputs)`` and exposing ``save(path)``.
            dataset: training data (see the class docstring for the required interface).
            ema_decay: EMA decay factor for ``self.ema_model``.
            step_start_ema: before this step, ``step_ema`` copies the model instead of averaging.
            update_ema_every: run ``step_ema`` every this many steps.
            log_freq / save_freq: logging / checkpointing period in steps.
            envname: name of the output sub-directory and checkpoint filename prefix.
            savename: optional tag appended to checkpoint and log filenames.
            weight_decay: Adam weight decay for both optimizers. It defaults to 0.0,
                which is the value that was previously in effect: the old
                ``int(1e-4)`` cast truncated to 0, so decay was silently disabled.
            gradient_accumulate_every, sample_freq, label_freq, save_parallel,
            n_reference, bucket, train_device, save_checkpoints: stored on the
                instance; not used by the code in this class.
        """
        super().__init__()
        self.model = diffusion_model
        self.rw_model = rw_model
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)
        self.update_ema_every = update_ema_every
        self.save_checkpoints = save_checkpoints
        self.start_time = time.time()
        self.envname = envname

        self.path = os.path.join(os.getcwd(), envname)
        if os.path.exists(self.path):
            print(f"Directory '{envname}' already exists at {self.path}")
        else:
            # exist_ok guards against another process creating it after the check.
            os.makedirs(self.path, exist_ok=True)
            print(f"Directory '{envname}' created at {self.path}")

        self.step_start_ema = step_start_ema
        self.log_freq = log_freq
        self.sample_freq = sample_freq
        self.save_freq = save_freq
        self.label_freq = label_freq
        self.save_parallel = save_parallel

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.dataset = dataset
        self.df_optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=train_lr, weight_decay=weight_decay)
        self.rw_optimizer = torch.optim.Adam(rw_model.parameters(), lr=train_lr, weight_decay=weight_decay)

        self.bucket = bucket
        self.n_reference = n_reference

        self.dataloader = cycle(data_iter(self.batch_size, self.dataset))

        self.reset_parameters()
        # Global step counter shared by train() and train_reward().
        self.step = 0

        self.device = train_device
        self.savename = savename

        self.save_suffix = ""
        if self.savename:
            self.save_suffix = "_" + self.savename

    def reset_parameters(self):
        """Overwrite the EMA model's state with the live diffusion model's state."""
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        """Copy the model into the EMA model before ``step_start_ema``; average afterwards."""
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    #-----------------------------------------------------------------------------#
    #------------------------------------ api ------------------------------------#
    #-----------------------------------------------------------------------------#

    def save(self, save_path, model=None):
        """Write ``model.state_dict()`` to ``save_path`` (defaults to ``self.model``)."""
        if model is None:
            model = self.model
        torch.save(model.state_dict(), save_path)

    def _checkpoint_path(self, kind):
        """Return ``<path>/<envname>_<kind>_state_<step><save_suffix>.pt``."""
        return self.path + '/{}_{}_state_{}{}.pt'.format(self.envname, kind, self.step, self.save_suffix)

    def _log_progress(self, tmp_loss, infos):
        """Print the current loss and append one formatted line to the active log file."""
        print("loss:", tmp_loss)
        infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()])
        current_time = time.time() - self.start_time
        logging.info(f'{self.step}: {tmp_loss:8.4f} | {infos_str} | t: {current_time:8.4f}')

    def rw_loss(self, rw, conditions):
        """Compute the reward-model loss.

        Args:
            rw: targets; column 0 is the reward, column 1 the terminal flag.
            conditions: model inputs passed straight to ``self.rw_model``.

        Returns:
            ``(loss, info)``. ``loss`` is the reward MSE only. ``info`` holds both
            ``reward_loss`` and ``terminal_loss``.
        """
        predicts = self.rw_model(conditions)
        reward_predicts = predicts[:, 0]
        # Assumes rw_model already outputs a probability in column 1: F.binary_cross_entropy
        # requires inputs in [0, 1].
        terminal_probs = predicts[:, 1]
        reward_loss = F.mse_loss(reward_predicts, rw[:, 0])
        terminal_loss = F.binary_cross_entropy(terminal_probs, rw[:, 1])
        # NOTE(review): terminal_loss is only logged, never optimized. Confirm this is intended.
        loss = reward_loss
        info = {'reward_loss': reward_loss, 'terminal_loss': terminal_loss}

        return loss, info

    def train_reward(self, n_train_steps):
        """Train ``self.rw_model`` for ``n_train_steps`` optimizer steps.

        The inputs are ``batch.conditions`` concatenated with ``batch.outputs``. The
        targets are ``batch.rewards`` concatenated with ``batch.terminals``. A checkpoint
        is written every ``save_freq`` steps and once more at the end.

        Returns:
            The last step's loss as a float, or None if ``n_train_steps`` is 0.
        """
        logging.basicConfig(filename=self.path + '/rw_losses' + self.save_suffix + '.log',
                            filemode='w', level=logging.INFO, force=True)
        self.dataset.to_tensor()
        tmp_loss = None  # returned as-is when no step runs
        for _ in tqdm(range(n_train_steps)):
            batch = next(self.dataloader)
            rw = torch.cat((batch.rewards, batch.terminals), dim=1)
            inputs = torch.cat((batch.conditions, batch.outputs), dim=1)
            loss, new_infos = self.rw_loss(rw, inputs)

            loss.backward()
            tmp_loss = loss.detach().item()
            # Only the current step's values are logged. Detach them so logging never
            # keeps the autograd graph alive.
            infos = {key: val.detach() for key, val in new_infos.items()}
            self.rw_optimizer.step()
            self.rw_optimizer.zero_grad()

            # NOTE(review): this advances the EMA of the *diffusion* model, which is not
            # trained here. It is kept to preserve existing behaviour; confirm it is intended.
            if self.step % self.update_ema_every == 0:
                self.step_ema()

            if self.step % self.save_freq == 0:
                self.rw_model.save(self._checkpoint_path('rwmodel'))

            if self.step % self.log_freq == 0:
                self._log_progress(tmp_loss, infos)

            self.step += 1
        self.rw_model.save(self._checkpoint_path('rwmodel'))
        return tmp_loss

    def train(self, n_train_steps):
        """Train the diffusion model for ``n_train_steps`` optimizer steps.

        The EMA model is updated every ``update_ema_every`` steps. A model checkpoint is
        written every ``save_freq`` steps. At the end, both the model and the EMA model
        are saved.

        Returns:
            The last step's loss as a float, or None if ``n_train_steps`` is 0.
        """
        logging.basicConfig(filename=self.path + '/losses{}.log'.format(self.save_suffix),
                            filemode='w', level=logging.INFO, force=True)
        self.dataset.to_tensor()
        tmp_loss = None  # returned as-is when no step runs
        for _ in tqdm(range(n_train_steps)):
            batch = next(self.dataloader)
            loss, infos = self.model.loss(batch.outputs, batch.conditions)

            loss.backward()
            tmp_loss = loss.detach().item()

            self.df_optimizer.step()
            self.df_optimizer.zero_grad()

            if self.step % self.update_ema_every == 0:
                self.step_ema()

            if self.step % self.save_freq == 0:
                self.save(self._checkpoint_path('model'))

            if self.step % self.log_freq == 0:
                self._log_progress(tmp_loss, infos)

            self.step += 1

        self.save(self._checkpoint_path('model'))
        # Save the EMA weights themselves. Previously this file held the raw model's weights.
        self.save(self._checkpoint_path('ema_model'), model=self.ema_model)
        return tmp_loss
