# Script bootstrap: only runs when this file is executed directly (not when it
# is imported). It must stay above the project imports below, because it
# extends sys.path before they are resolved.
if __name__ == "__main__":
    import sys
    import os
    import pathlib

    # Directory three levels above this file.
    # NOTE(review): presumably the repository root - confirm against the project layout.
    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    # Append ROOT_DIR to the module search path and make it the working directory.
    sys.path.append(ROOT_DIR)
    os.chdir(ROOT_DIR)
import torch.nn.functional as F

import os
import h5py
import hydra
import torch
from omegaconf import OmegaConf
import pathlib
from torch.utils.data import DataLoader
import copy
import random
import dill
import wandb
import tqdm
import numpy as np
import shutil
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.policy.diffusion_transformer_hybrid_image_policy import DiffusionTransformerHybridImagePolicy
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
from diffusion_policy.common.json_logger import JsonLogger
from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to
from diffusion_policy.model.diffusion.ema_model import EMAModel
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from pathlib import Path
import cv2
import time

# Register Python's built-in ``eval`` as the OmegaConf resolver named "eval"
# (replacing any resolver already registered under that name).
# NOTE(review): this evaluates arbitrary Python expressions found in config
# files - only load configs from trusted sources.
OmegaConf.register_new_resolver("eval", eval, replace=True)


def create_video_from_images(image_list, output_filename, fps=30, frame_size=None):
    """Encode a sequence of frames into a video file using the ``mp4v`` codec.

    Args:
        image_list: Frames to write. Either an array whose first axis indexes
            frames, or a sequence of per-frame arrays. Each frame is indexed
            as ``(height, width, ...)``.
        output_filename: Destination path handed to ``cv2.VideoWriter``.
        fps: Frame rate of the output video.
        frame_size: ``(width, height)`` of the output video. When ``None`` it
            is derived from the input frames.

    Raises:
        ValueError: If ``frame_size`` is ``None`` and it cannot be derived
            because ``image_list`` is empty.
    """
    if frame_size is None:
        stacked_shape = getattr(image_list, 'shape', None)
        if stacked_shape is not None and len(stacked_shape) >= 3:
            # Stacked array: axis 0 is the frame index, then height, width.
            height, width = stacked_shape[1], stacked_shape[2]
        elif len(image_list) > 0:
            height, width = image_list[0].shape[:2]
        else:
            raise ValueError("cannot infer frame_size from an empty image_list")
        # OpenCV expects sizes as (width, height), i.e. the reverse of the
        # array's (rows, cols) order.
        frame_size = (width, height)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_filename, fourcc, fps, frame_size)
    try:
        for img in image_list:
            # Frames that do not match the target (width, height) are resized.
            if img.shape[1] != frame_size[0] or img.shape[0] != frame_size[1]:
                img = cv2.resize(img, frame_size)
            out.write(img)
    finally:
        # Always release the writer so the file is finalised even on error.
        out.release()


from torch.nn import Module, Sequential, Linear, Tanh, Parameter, Embedding
import torch.nn as nn


class Discriminator(Module):
    """Scores (observation feature, action sequence) pairs with a conv + MLP stack.

    The observation features (``obs_step`` vectors of size ``state_dim``) and
    the action sequence (``horizon`` entries along dim 1) are each collapsed
    to a single vector by two 1-D convolutions, concatenated, and passed
    through a Tanh MLP that emits one logit per sample.

    Args:
        state_dim: Size of one observation feature vector.
        action_dim: Size of one action vector (or the number of embedding
            entries when ``discrete`` is true).
        obs_step: Number of observation steps per sample.
        horizon: Number of action steps per sample.
        discrete: When true, the pooled action vector is passed through an
            embedding table before concatenation.
    """

    def __init__(self, state_dim, action_dim, obs_step, horizon, discrete=False) -> None:
        super().__init__()

        self.state_dim, self.action_dim = state_dim, action_dim
        self.discrete = discrete
        self.horizon, self.obs_step = horizon, obs_step

        if discrete:
            self.act_emb = Embedding(action_dim, state_dim)
        self.net_in_dim = 2 * state_dim if discrete else state_dim + action_dim

        # Observation branch: the obs_step axis is treated as Conv1d channels.
        self.expand_dim = nn.Conv1d(in_channels=obs_step, out_channels=50, kernel_size=1)
        self.bn0 = nn.BatchNorm1d(50)
        self.conv1 = nn.Conv1d(in_channels=50, out_channels=1, kernel_size=3, padding=1)

        # Action branch: the horizon axis is treated as Conv1d channels.
        reduced_horizon = horizon // 4
        self.action_expand_dim = nn.Conv1d(
            in_channels=horizon, out_channels=reduced_horizon, kernel_size=3, padding=1)
        self.action_bn0 = nn.BatchNorm1d(reduced_horizon)
        self.action_conv1 = nn.Conv1d(
            in_channels=reduced_horizon, out_channels=1, kernel_size=3, padding=1)

        # MLP head: three hidden Tanh layers of half the input width, one logit out.
        hidden = self.net_in_dim // 2
        layers = [Linear(self.net_in_dim, hidden), Tanh()]
        for _ in range(2):
            layers += [Linear(hidden, hidden), Tanh()]
        layers.append(Linear(hidden, 1))
        self.net = Sequential(*layers)

    def forward(self, obs_states, actions):
        """Return the sigmoid of :meth:`get_logits` for the given pair."""
        logits = self.get_logits(obs_states, actions)
        return torch.sigmoid(logits)

    def _pool_states(self, states):
        """Collapse observation features to one ``state_dim`` vector per sample."""
        feats = states.reshape(-1, self.obs_step, self.state_dim)
        feats = torch.relu(self.bn0(self.expand_dim(feats)))
        return self.conv1(feats).squeeze(1)

    def _pool_actions(self, actions):
        """Collapse the action sequence to one vector per sample."""
        feats = torch.relu(self.action_bn0(self.action_expand_dim(actions)))
        return self.action_conv1(feats).squeeze(1)

    def get_logits(self, states, actions):
        """Return the raw (pre-sigmoid) discriminator score, one per sample.

        Args:
            states: Observation features; reshaped to
                ``(-1, obs_step, state_dim)`` before use.
            actions: Action sequence with ``horizon`` entries along dim 1.
        """
        state_vec = self._pool_states(states)
        action_vec = self._pool_actions(actions)

        if self.discrete:
            action_vec = self.act_emb(action_vec.long())

        joint = torch.cat((state_vec, action_vec), dim=-1)
        return self.net(joint)


def kl_divergence_loss(student_logits, teacher_logits, temperature=3.0):
    """Temperature-scaled KL divergence between teacher and student logits.

    Both logit tensors are divided by ``temperature`` and normalised over the
    last dimension; the result is ``KL(teacher || student)`` with
    ``batchmean`` reduction, multiplied by ``temperature ** 2``.

    Args:
        student_logits: Logits of the student; softmax is taken over dim -1.
        teacher_logits: Logits of the teacher, same shape as the student's.
        temperature: Positive softening factor. Defaults to 3.0, the value
            that used to be hard-coded.

    Returns:
        A scalar tensor holding the loss.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    # F.kl_div expects log-probabilities as input and probabilities as target.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)
    return loss


def hinge_d_loss(logits_real, logits_fake):
    """Hinge loss for a discriminator.

    Real logits are penalised when below +1 and fake logits when above -1;
    the two mean penalties are averaged.

    Args:
        logits_real: Discriminator logits for real samples.
        logits_fake: Discriminator logits for generated samples.

    Returns:
        A scalar tensor: ``0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake)))``.
    """
    real_penalty = F.relu(1.0 - logits_real)
    fake_penalty = F.relu(1.0 + logits_fake)
    return 0.5 * (torch.mean(real_penalty) + torch.mean(fake_penalty))


def backup_project_to_output(output_dir):
    """Copy a snapshot of selected ``diffusion_policy`` sources into ``output_dir``.

    The ``env_runner``, ``policy``, ``workspace`` and ``dataset`` folders under
    ``<cwd>/diffusion_policy`` are walked recursively and every file is copied
    to the same relative location below ``output_dir``. Any directory whose
    path contains ``output_dir``'s final component or the text ``libero`` is
    skipped together with everything beneath it.

    Args:
        output_dir: Destination directory (``str`` or ``Path``). It is created,
            including missing parent directories, when it does not exist.
    """
    package_dir = Path.cwd() / 'diffusion_policy'
    backup_dirs = [package_dir / name
                   for name in ('env_runner', 'policy', 'workspace', 'dataset')]

    output_dir = Path(output_dir)
    # parents=True: the run directory may be nested below folders that have
    # not been created yet.
    output_dir.mkdir(parents=True, exist_ok=True)

    for backup_dir in backup_dirs:
        for root, dirs, files in os.walk(backup_dir):
            if output_dir.name in root or 'libero' in root:
                # Every descendant path contains the same text, so prune the
                # walk here instead of visiting and re-testing each child.
                dirs[:] = []
                continue

            target_dir = output_dir / Path(root).relative_to(package_dir)
            target_dir.mkdir(parents=True, exist_ok=True)

            for file_name in files:
                shutil.copy(Path(root) / file_name, target_dir / file_name)


class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
    """Workspace that trains a diffusion transformer image policy with an
    additional adversarial (discriminator) loss term.

    Besides the policy, its optional EMA copy and the policy optimizer, the
    workspace owns a :class:`Discriminator` (``self.d``) and an Adam optimizer
    for it (``self.opt_d``). :meth:`run` alternates discriminator and policy
    updates on every batch and, at configurable epoch intervals, performs
    dataset augmentation from rollouts, evaluation rollouts, validation,
    sampling diagnostics and checkpointing.
    """
    include_keys = ['global_step', 'epoch']

    def __init__(self, cfg: OmegaConf, output_dir=None):
        """Build the policy, optimizers and discriminator from ``cfg``.

        Args:
            cfg: Hydra/OmegaConf configuration of the run.
            output_dir: Optional output directory forwarded to the base class.
        """
        super().__init__(cfg, output_dir=output_dir)
        # Keep a copy of the relevant project sources next to the run outputs.
        backup_project_to_output(self.output_dir)

        # Seed every random number generator used here.
        seed = cfg.training.seed
        device = cfg.training.device
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        self.model: DiffusionTransformerHybridImagePolicy = hydra.utils.instantiate(cfg.policy)

        from diffusers.schedulers.scheduling_ddim import DDIMScheduler

        # A DDIM scheduler needs its inference timesteps configured up front.
        if isinstance(self.model.noise_scheduler, DDIMScheduler):
            self.model.noise_scheduler.set_timesteps(num_inference_steps=cfg.policy.num_inference_steps)
        self.ema_model: DiffusionTransformerHybridImagePolicy = None

        if cfg.training.use_ema:
            self.ema_model = copy.deepcopy(self.model)

        # Prefer the policy's own optimizer factory when it provides one.
        if hasattr(self.model, 'get_optimizer'):
            self.optimizer = self.model.get_optimizer(**cfg.optimizer)
        else:
            self.optimizer = hydra.utils.instantiate(
                cfg.optimizer, params=self.model.parameters())

        self.global_step = 0
        self.epoch = 0
        self.best_ckpt_score = 0
        self.best_ckpt_score_epoch = 0

        # Discriminator over (observation feature, action sequence) pairs.
        self.d = Discriminator(self.model.obs_encoder.output_shape()[0],
                               self.cfg.shape_meta['action']['shape'][0], self.model.n_obs_steps, self.model.horizon).to(device)

        self.opt_d = torch.optim.Adam(self.d.parameters())

    @staticmethod
    def _resolve_save_name(dataset_path):
        """Return the task name whose keyword occurs in ``dataset_path``.

        Keywords are tested in a fixed order; the first hit wins. Returns
        ``None`` when no keyword matches.
        """
        keyword_to_name = (
            ('tool_hang', 'tool_hang'),
            ('transport', 'transport'),
            ('square', 'square'),
            ('can', 'can'),
            ('kitchen', 'kitchen'),
            ('block', 'block_push'),
            ('pusht', 'pusht'),
        )
        for keyword, name in keyword_to_name:
            if keyword in dataset_path:
                return name
        return None

    @staticmethod
    def _relabel_actions_with_policy(policy, traj_list):
        """Replace every trajectory's ``'actions'`` with ``policy`` predictions.

        For step ``j`` the policy is queried with a two-step observation
        window (steps ``j-1`` and ``j``; step 0 is duplicated for ``j == 0``)
        and the prediction at index 1 of ``'action_pred'`` is kept.
        ``traj_list`` is modified in place.
        """
        for i, traj in enumerate(traj_list):
            obs_dict = copy.deepcopy(traj['obs'])
            # Convert each observation to a tensor once, checking the entry
            # itself rather than an unrelated variable.
            for key in obs_dict.keys():
                if not isinstance(obs_dict[key], torch.Tensor):
                    obs_dict[key] = torch.tensor(obs_dict[key])

            traj_len = traj['actions'].shape[0]
            action_dagger = []
            for j in range(traj_len):
                obs_input = dict()
                for key, obs_seq in obs_dict.items():
                    if j == 0:
                        # No previous step exists: repeat step 0 twice.
                        trailing_ones = [1] * (len(obs_seq.shape) - 1)
                        window = obs_seq[j:j + 1].unsqueeze(0).repeat(1, 2, *trailing_ones)
                    else:
                        window = obs_seq[j - 1:j + 1].unsqueeze(0)
                    obs_input[key] = window.to(policy.device)
                dagger_action = policy.predict_action(obs_input)['action_pred'][:, 1:2].detach().to(
                    'cpu').squeeze(0)
                action_dagger.append(dagger_action)
            action_dagger = torch.cat(action_dagger, dim=0)
            assert action_dagger.shape[0] == traj_len
            traj['actions'] = action_dagger
            traj_list[i] = traj

    def run(self):
        """Run the full training loop described by ``self.cfg``.

        Raises:
            ValueError: If ``training.empty`` is set but no known task name
                can be derived from the dataset path.
        """
        cfg = copy.deepcopy(self.cfg)

        # Resume from the latest checkpoint when requested and available.
        if cfg.training.resume:
            lastest_ckpt_path = self.get_checkpoint_path()
            if lastest_ckpt_path.is_file():
                print(f"Resuming from checkpoint {lastest_ckpt_path}")
                self.load_checkpoint(path=lastest_ckpt_path)

        dataset: BaseImageDataset

        self.dataset_path = cfg.task.dataset.dataset_path

        if cfg.training.get('empty', False):
            # Redirect the dataset to the task-specific initial-data file.
            save_name = self._resolve_save_name(self.dataset_path)
            if save_name is None:
                raise ValueError(
                    f"cannot derive a task name from dataset path {self.dataset_path!r}")
            filename = f'./{save_name}_init_data.hdf5'
            OmegaConf.update(cfg.task.dataset, "extra", True, force_add=True)
            cfg.task.dataset.subdataset = 1
            cfg.task.dataset.dataset_path = filename
            cfg.task.dataset_path = filename
            self.dataset_path = cfg.task.dataset.dataset_path
            print('redefine dataset path')

        dataset = hydra.utils.instantiate(cfg.task.dataset)

        assert isinstance(dataset, BaseImageDataset)
        train_dataloader = DataLoader(dataset, **cfg.dataloader)
        normalizer = dataset.get_normalizer()

        # Validation uses its own deep copy of the validation dataset.
        val_dataset = dataset.get_validation_dataset()
        val_dataset_true = copy.deepcopy(val_dataset)
        val_dataloader = DataLoader(val_dataset_true, **cfg.val_dataloader)

        self.model.set_normalizer(normalizer)
        if cfg.training.use_ema:
            self.ema_model.set_normalizer(normalizer)

        # NOTE(review): the step count is divided by gradient_accumulate_every,
        # yet the loop below steps the optimizer and scheduler on every batch.
        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=(len(train_dataloader) * cfg.training.num_epochs)
            // cfg.training.gradient_accumulate_every,
            last_epoch=self.global_step - 1
        )

        ema: EMAModel = None
        if cfg.training.use_ema:
            ema = hydra.utils.instantiate(
                cfg.ema,
                model=self.ema_model)

        env_runner: BaseImageRunner

        env_runner = hydra.utils.instantiate(
            cfg.task.env_runner,
            output_dir=self.output_dir)
        assert isinstance(env_runner, BaseImageRunner)

        wandb_run = wandb.init(
            dir='wandb' + str(self.output_dir),
            config=OmegaConf.to_container(cfg, resolve=True),
            **cfg.logging
        )
        wandb.config.update(
            {
                "output_dir": self.output_dir,
            }
        )

        topk_manager = TopKCheckpointManager(
            save_dir=os.path.join(self.output_dir, 'checkpoints'),
            **cfg.checkpoint.topk
        )

        # Move models and optimizer state to the training device.
        device = torch.device(cfg.training.device)
        self.model.to(device)
        if self.ema_model is not None:
            self.ema_model.to(device)
        optimizer_to(self.optimizer, device)

        # First training batch, kept for the periodic sampling diagnostic.
        train_sampling_batch = None

        # Debug mode shrinks the run to a couple of tiny epochs.
        if cfg.training.debug:
            cfg.training.num_epochs = 2
            cfg.training.max_train_steps = 3
            cfg.training.max_val_steps = 3
            cfg.training.rollout_every = 1
            cfg.training.checkpoint_every = 1
            cfg.training.val_every = 1
            cfg.training.sample_every = 1

        log_path = os.path.join(self.output_dir, 'logs.json.txt')

        # Last per-batch diffusion loss; defined up front so the rollout
        # report below cannot hit an unbound name if no batch was processed.
        raw_loss_cpu = float('nan')

        with JsonLogger(log_path) as json_logger:
            for local_epoch_idx in range(cfg.training.num_epochs):
                step_log = dict()
                train_losses = list()
                # ========= train for this epoch ==========
                with tqdm.tqdm(train_dataloader,
                               desc=f"Train e:{self.epoch} Best e:{self.best_ckpt_score_epoch} Best score:{self.best_ckpt_score}",
                               leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
                    for batch_idx, batch in enumerate(tepoch):
                        batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))

                        if train_sampling_batch is None:
                            train_sampling_batch = batch

                        self.model.train()

                        # ---- discriminator update ----
                        self.d.train()
                        self.opt_d.zero_grad()
                        raw_loss, stu_traj, nobs_stu = self.model.compute_loss(batch)

                        nactions = self.model.normalizer['action'].normalize(batch['action'])
                        dataset_action = nactions

                        # Inputs are detached so this step only trains the
                        # discriminator: dataset actions are labelled 1 and
                        # the policy's trajectory is labelled 0.
                        exp_scores = self.d.get_logits(nobs_stu.clone().detach(), dataset_action.clone().detach())
                        nov_scores = self.d.get_logits(nobs_stu.clone().detach(), stu_traj.clone().detach())
                        d_train_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                            exp_scores, torch.ones_like(exp_scores)
                        ) + torch.nn.functional.binary_cross_entropy_with_logits(
                            nov_scores, torch.zeros_like(nov_scores)
                        )

                        d_train_loss.backward()
                        self.opt_d.step()

                        # ---- policy update ----
                        self.d.eval()
                        self.optimizer.zero_grad()

                        raw_loss = raw_loss / cfg.training.gradient_accumulate_every

                        # Not detached this time: the policy is pushed towards
                        # trajectories the discriminator labels as 1.
                        nov_scores = self.d.get_logits(nobs_stu, stu_traj)

                        d_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                            nov_scores, torch.ones_like(nov_scores)
                        )

                        # Fixed weight of the adversarial term in the policy loss.
                        weight_d_loss = 0.002
                        loss = raw_loss + d_loss * weight_d_loss

                        loss.backward()

                        self.optimizer.step()
                        lr_scheduler.step()

                        if cfg.training.use_ema:
                            ema.step(self.model)

                        # ---- logging ----
                        d_train_loss_cpu = d_train_loss.item()
                        raw_loss_cpu = raw_loss.item()
                        d_loss_cpu = d_loss.item()
                        train_loss_cpu = loss.item()
                        tepoch.set_postfix(loss=train_loss_cpu, refresh=False)
                        train_losses.append(train_loss_cpu)
                        step_log = {
                            'train_loss': train_loss_cpu,
                            'raw_loss': raw_loss_cpu,
                            'd_train_loss_cpu': d_train_loss_cpu,
                            'd_loss_cpu': d_loss_cpu,
                            'global_step': self.global_step,
                            'epoch': self.epoch,
                            'lr': lr_scheduler.get_last_lr()[0]
                        }

                        is_last_batch = (batch_idx == (len(train_dataloader) - 1))
                        if self.global_step % cfg.training.val_every == 0:
                            wandb_run.log(step_log, step=self.global_step)
                            json_logger.log(step_log)
                        # The last batch's step is counted at the end of the
                        # epoch, together with the epoch-level log entry.
                        if not is_last_batch:
                            self.global_step += 1

                        if (cfg.training.max_train_steps is not None) \
                                and batch_idx >= (cfg.training.max_train_steps - 1):
                            break

                # Replace the last batch loss with the epoch mean.
                train_loss = np.mean(train_losses)
                step_log['train_loss'] = train_loss

                # ========= eval for this epoch ==========
                policy = self.model
                if cfg.training.use_ema:
                    policy = self.ema_model
                policy.eval()

                # ---- dataset augmentation from rollouts ----
                if (self.epoch % cfg.training.agpd_every) == 0 and self.epoch > 0:
                    if self.epoch == cfg.training.agpd_every:
                        _, traj_list = env_runner.augwog_run(policy)
                    else:
                        # NOTE(review): the guided branches read self.s_model,
                        # which is only assigned in the rollout branch below -
                        # confirm the configured intervals guarantee it exists.
                        if cfg.training.guide_type == 'dagger':
                            _, traj_list = env_runner.dagger_run(self.s_model, policy, self.epoch)
                            self._relabel_actions_with_policy(policy, traj_list)
                        elif cfg.training.guide_type == 'diffusion':
                            _, traj_list = env_runner.aug_run(self.s_model, policy, self.epoch)
                        else:
                            _, traj_list = env_runner.augwog_run(policy)

                    aug_num = 0

                    previous_episode_num = dataset.replay_buffer.n_episodes

                    for traj in traj_list:
                        aug = dataset.finetune_data(traj)
                        if aug:
                            aug_num += 1

                    print(f'Epoch:{self.epoch}: aug_{aug_num}demos')

                    current_episode_num = dataset.replay_buffer.n_episodes
                    if getattr(self, 'init_agpd_every', None) is None:
                        self.init_agpd_every = cfg.training.agpd_every

                    # Rescale the augmentation interval by the ratio of episode
                    # counts after/before this augmentation round.
                    cfg.training.agpd_every = round(
                        self.init_agpd_every * (current_episode_num / previous_episode_num)) + self.init_agpd_every

                    # Refresh the normalizers after the augmentation step.
                    normalizer = dataset.get_normalizer()
                    self.model.set_normalizer(normalizer)
                    self.model.to(device)
                    if cfg.training.use_ema:
                        self.ema_model.set_normalizer(normalizer)
                        self.ema_model.to(device)

                # ---- evaluation rollout ----
                if (self.epoch % cfg.training.rollout_every) == 0 and self.epoch > 0:
                    runner_log = env_runner.run(policy)
                    step_log.update(runner_log)
                    metric_dict = {}
                    for key, value in step_log.items():
                        new_key = key.replace('/', '_')
                        metric_dict[new_key] = value
                    print(f'Epoch:{self.epoch}_test_mean_score', metric_dict['test_mean_score'])
                    print(f'Epoch:{self.epoch}_raw_loss', raw_loss_cpu)
                    if metric_dict['test_mean_score'] > self.best_ckpt_score:
                        self.best_ckpt_score_epoch = self.epoch
                        self.best_ckpt_score = metric_dict['test_mean_score']
                        print('**Current Best**:', self.best_ckpt_score)
                        # NOTE(review): a perfect score aborts the run here,
                        # before the 'best' checkpoint and epoch log below are
                        # written (and not at all under ``python -O``).
                        if self.best_ckpt_score == 1:
                            assert False, 'alright full!'
                        self.save_checkpoint(tag='best')
                        epoch_text = os.path.join(self.output_dir, 'epoch_log.txt')
                        with open(epoch_text, 'a') as f:
                            f.write(f'Epoch:{self.epoch}_test_mean_score:{metric_dict["test_mean_score"]}\n')
                            f.write(f'Epoch:{self.epoch}_best_ckpt_score:{self.best_ckpt_score}\n')

                    if cfg.training.guide_type == 'diffusion' or cfg.training.guide_type == 'dagger':
                        # Frozen snapshot of the evaluation policy used as the
                        # guide in later augmentation rounds. ``policy`` is the
                        # EMA model when EMA is enabled, otherwise the raw
                        # model, so this also works with use_ema disabled.
                        self.s_model = copy.deepcopy(policy)
                        # requires_grad_ freezes the parameters; assigning a
                        # plain ``requires_grad`` attribute on a Module does not.
                        self.s_model.requires_grad_(False)

                # ---- validation ----
                if (self.epoch % cfg.training.val_every) == 0:
                    with torch.no_grad():
                        val_losses = list()
                        with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
                                       leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
                            for batch_idx, batch in enumerate(tepoch):
                                batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
                                loss, _, _ = self.model.compute_loss(batch)
                                val_losses.append(loss)
                                if (cfg.training.max_val_steps is not None) \
                                        and batch_idx >= (cfg.training.max_val_steps - 1):
                                    break
                        if len(val_losses) > 0:
                            val_loss = torch.mean(torch.tensor(val_losses)).item()
                            step_log['val_loss'] = val_loss

                # ---- sampling diagnostic on the stored training batch ----
                if (self.epoch % cfg.training.sample_every) == 0 and self.epoch > 0:
                    with torch.no_grad():
                        batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
                        obs_dict = batch['obs']
                        gt_action = batch['action']

                        result = policy.predict_action(obs_dict)
                        pred_action = result['action_pred']
                        mse = torch.nn.functional.mse_loss(pred_action, gt_action)
                        step_log['train_action_mse_error'] = mse.item()
                        # Drop the references to the temporaries right away.
                        del batch
                        del obs_dict
                        del gt_action
                        del result
                        del pred_action
                        del mse

                # ---- checkpointing ----
                if (self.epoch % cfg.training.checkpoint_every) == 0 and self.epoch > 0:
                    if cfg.checkpoint.save_last_ckpt:
                        self.save_checkpoint()
                    if cfg.checkpoint.save_last_snapshot:
                        self.save_snapshot()

                    # Metric names are sanitised ('/' -> '_') before being
                    # handed to the top-k manager.
                    metric_dict = dict()
                    for key, value in step_log.items():
                        new_key = key.replace('/', '_')
                        metric_dict[new_key] = value

                    topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)

                    if topk_ckpt_path is not None:
                        self.save_checkpoint(path=topk_ckpt_path)

                # ========= end of epoch ==========
                policy.train()

                wandb_run.log(step_log, step=self.global_step)
                json_logger.log(step_log)
                self.global_step += 1
                self.epoch += 1


@hydra.main(
    version_base=None,
    config_path=str(pathlib.Path(__file__).parent.parent / "config"),
    config_name=pathlib.Path(__file__).stem,
)
def main(cfg):
    """Hydra entry point: build the training workspace from ``cfg`` and run it.

    The config is looked up in the ``config`` folder two levels above this
    file, under a name equal to this file's stem.
    """
    TrainDiffusionTransformerHybridWorkspace(cfg).run()


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()