# Process-level setup. The statements below are interleaved with the imports
# on purpose: the warning filter and environment variables are put in place
# before torch / gym / the environment modules are imported.
import warnings
# Suppress every Python warning for the lifetime of the process.
warnings.filterwarnings('ignore')
import os
# NOTE(review): presumably works around an MKL threading-layer conflict at
# import time -- confirm it is still needed on the target machines.
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# Rendering backend used by MuJoCo. NOTE(review): 'egl' is assumed to be the
# headless GPU backend and 'osmesa' the software fallback -- verify against
# the simulator documentation for the installed version.
os.environ['MUJOCO_GL'] = 'egl'
#os.environ['MUJOCO_GL'] = 'osmesa'
import torch
import numpy as np
import gym
# Raise gym's logger threshold to 40 (NOTE(review): presumably the ERROR
# level, hiding info/warn messages -- confirm for the installed gym).
gym.logger.set_level(40)
import time
import random
from pathlib import Path
from cfg import parse_cfg
from env import make_env
from algorithm.tdmpc import TDMPC
from algorithm.helper import Episode, ReplayBuffer
import logger
# Enable cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True
# Directory names, resolved against the current working directory: configs
# are read from `__CONFIG__`, run outputs are written under `__LOGS__`.
__CONFIG__, __LOGS__ = 'cfgs', 'logs'

import torchvision
from PIL import Image, ImageFilter
from sd_utils import StableDiffusion
from animate_utils import VideoAnimateDiffusion
import open_clip
from transformers import CLIPTextModel, CLIPTokenizer
from goal_diffusion import GoalGaussianDiffusion_SD, Trainer
from unet import UnetMW_SD as Unet
from avdc_utils.reward_calculation import disentangled_avdc_alignment


# Natural-language prompt for each supported task, keyed by the task name
# that remains after the domain prefix is split off `cfg.task` in `train`.
# `train` looks a prompt up here (and stores it in `cfg.text_prompt`) when the
# domain is 'metaworld' and the reward method is Stable-Diffusion- or
# CLIP-based; a task missing from this table raises KeyError there.
task2prompt = {
    'drawer_open': 'a robot arm is opening the drawer',
    'drawer_close': 'a robot arm is closing the drawer',
    'door_open': 'a robot arm is opening the door of a black safe',
    'door_close': 'a robot arm is closing the door of a black safe',
    'window_open': 'a robot arm is opening the window',
    'window_close': 'a robot arm is closing the window',
    'coffee_push': 'a robot arm is pushing the white mug towards the coffee machine',
    'button_press': 'a robot arm is pressing the button',
    'soccer': 'a robot arm is pushing a soccer ball into the net',
    'lever_pull': 'a robot arm is pulling the lever',
    'peg_insert_side': 'a robot arm is inserting the peg into the slot',
    'shelf_place': 'a robot arm is picking up an object and placing it on the shelf'
}


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator,
    torch's default generator and torch's CUDA generators, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def evaluate(env, agent, num_episodes, step, env_step, video):
    """Roll the agent out in evaluation mode and summarise the results.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``; ``step``
            returns ``(obs, reward, done, info)``.
        agent: Planner exposing ``plan(obs, eval_mode=..., step=..., t0=...)``
            whose result supports ``.cpu().numpy()``.
        num_episodes: Number of evaluation episodes to run.
        step: Training step forwarded to ``agent.plan``.
        env_step: Environment step count forwarded to ``video.save``.
        video: Optional recorder; when truthy it is initialised at the start
            of each episode (enabled only for the first), fed every step and
            saved after each episode.

    Returns:
        Tuple ``(mean_reward, mean_success)``, each a NaN-ignoring mean over
        the episodes. Success is read from ``info['success']`` of the final
        step of each episode and defaults to 0 when that key is absent.
    """
    returns = []
    successes = []
    for episode_index in range(num_episodes):
        obs = env.reset()
        done = False
        episode_return = 0
        t = 0
        if video:
            video.init(env, enabled=(episode_index == 0))
        while not done:
            # t0 tells the planner that this is the first step of the episode.
            action = agent.plan(obs, eval_mode=True, step=step, t0=(t == 0))
            obs, reward, done, info = env.step(action.cpu().numpy())
            episode_return += reward
            if video:
                video.record(env)
            t += 1
        returns.append(episode_return)
        successes.append(int(info.get('success', 0)))
        if video:
            video.save(env_step)
    mean_return = np.nanmean(returns)
    mean_success = np.nanmean(successes)
    return mean_return, mean_success


def train(cfg):
    """Train a TD-MPC agent, optionally with a text-conditioned learned reward.

    The reward stored in the replay buffer is chosen by ``cfg.reward_method``:

    * ``'sd...tadpole'``   -- Stable Diffusion alignment on the rendered frame.
    * ``'avdc...tadpole'`` -- in-domain AVDC video-diffusion alignment on the
      frames rendered before and after each environment step.
    * ``'clip'``           -- cosine similarity between CLIP text and image
      features of the rendered frame.
    * anything else        -- sparse-only reward (``info['success']``).

    For the three learned rewards ``cfg.sparse_scale * info['success']`` is
    added on top. Any other ``*tadpole`` method raises ``NotImplementedError``.

    Args:
        cfg: Experiment configuration. Attributes read here: ``seed``,
            ``task``, ``modality``, ``exp_name``, ``reward_method``,
            ``noise_level``, ``alignment_scale``, ``recon_scale``,
            ``text_prompt``, ``sparse_scale``, ``train_steps``,
            ``episode_length``, ``seed_steps``, ``action_repeat``,
            ``eval_freq`` and ``eval_episodes``. For 'metaworld' tasks
            ``cfg.text_prompt`` is overwritten with a task-derived prompt.

    Raises:
        AssertionError: If CUDA is unavailable, or a collected episode does
            not have exactly ``cfg.episode_length`` steps.
        NotImplementedError: For an unsupported ``*tadpole`` reward method.
    """
    assert torch.cuda.is_available(), 'train() requires a CUDA-enabled device'
    set_seed(cfg.seed)
    work_dir = Path().cwd() / __LOGS__ / cfg.task / cfg.modality / cfg.exp_name / str(cfg.seed)
    env, agent, buffer = make_env(cfg), TDMPC(cfg), ReplayBuffer(cfg)
    device = torch.device('cuda')

    # Pixel-rendering configuration for the frames fed to the reward models.
    domain, task = cfg.task.replace('-', '_').split('_', 1)
    camera_id = dict(quadruped=2).get(domain, 0)  # TODO: make sure this is set back to 0 if this works
    dim = 480
    render_kwargs = dict(height=dim, width=dim, camera_id=camera_id)

    def render_frame():
        """Render the current env state as a batched, channel-first tensor.

        Assumes ``env.render`` returns an (H, W, C) array -- the result is
        given a leading batch axis and permuted to (1, C, H, W) on ``device``.
        """
        frame = env.render(**render_kwargs).copy()[np.newaxis, ...]
        return torch.Tensor(frame).permute(0, 3, 1, 2).to(device)

    # Resolve the reward method once instead of re-parsing the string on
    # every environment step.
    reward_method = cfg.reward_method
    use_tadpole = reward_method.endswith('tadpole')
    use_sd = use_tadpole and reward_method.startswith('sd')
    use_avdc = use_tadpole and not use_sd and reward_method.startswith('avdc')
    use_clip = reward_method == 'clip'

    # Set up the reward model(s).
    if use_tadpole:
        noise_level = cfg.noise_level
        align_scale = cfg.alignment_scale
        recon_scale = cfg.recon_scale
        if use_sd:
            if domain == 'metaworld':
                cfg.text_prompt = task2prompt[task]
            guidance = StableDiffusion(device, False, False, '2.1', None, [0.02, 0.98], cfg.text_prompt)
        elif use_avdc:
            # Load the in-domain AVDC model.
            if domain == 'metaworld':
                cfg.text_prompt = task.replace('_', ' ')
            guidance = VideoAnimateDiffusion(device)
            avdc_trained_video_length = 9
            avdc_ckpt_number = 24
            avdc_pretrained_model = "openai/clip-vit-base-patch32"
            avdc_tokenizer = CLIPTokenizer.from_pretrained(avdc_pretrained_model)
            avdc_text_encoder = CLIPTextModel.from_pretrained(avdc_pretrained_model)
            avdc_unet = Unet()
            avdc = GoalGaussianDiffusion_SD(
                channels=4*(avdc_trained_video_length-1),
                model=avdc_unet,
                image_size=(64, 64),
                timesteps=100,
                sampling_timesteps=100,
                loss_type='l2',
                objective='pred_noise',
                beta_schedule='cosine',
                min_snr_loss_weight=True,
            )

            # The Trainer is only used to load a checkpoint and encode text,
            # so the datasets are single-element placeholders.
            train_set = valid_set = [None]

            avdc_trainer = Trainer(
                diffusion_model=avdc,
                tokenizer=avdc_tokenizer,
                text_encoder=avdc_text_encoder,
                train_set=train_set,
                valid_set=valid_set,
                train_lr=1e-4,
                train_num_steps=60000,
                save_and_sample_every=2500,
                ema_update_every=10,
                ema_decay=0.999,
                train_batch_size=16,
                valid_batch_size=32,
                gradient_accumulate_every=1,
                num_samples=1,
                results_folder=f'../avdc_ckpts/mw_cond_{avdc_trained_video_length}',
                fp16=True,
                amp=True,
            )

            # Load the AVDC checkpoint and pre-compute the prompt embedding.
            avdc_trainer.load(avdc_ckpt_number)
            avdc_text_embeddings = avdc_trainer.encode_batch_text([cfg.text_prompt])
        else:
            raise NotImplementedError(f'unsupported reward_method: {reward_method!r}')
    elif use_clip:
        if domain == 'metaworld':
            cfg.text_prompt = task2prompt[task]
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-bigG-14-laion2B-39B-b160k')
        tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-bigG-14-laion2B-39B-b160k')

        model.eval().to(device)
        prompt_tokens = tokenizer([cfg.text_prompt]).to(device)
        with torch.no_grad():
            prompt_features = model.encode_text(prompt_tokens).float()
        # Unit-normalise so the later dot product is a cosine similarity.
        prompt_features /= prompt_features.norm(dim=-1, keepdim=True)
    # Sparse-only reward needs no model.

    # Run training
    L = logger.Logger(work_dir, cfg)
    episode_idx, start_time = 0, time.time()
    for step in range(0, cfg.train_steps+cfg.episode_length, cfg.episode_length):

        # Collect trajectory
        obs = env.reset()
        episode = Episode(cfg, obs)
        # TODO: make sure to denote which experiment this is (stable or unstable).
        # The source noise is re-sampled once per episode and shared by every
        # step of that episode. It is sampled for all reward methods so the
        # global RNG stream stays the same regardless of the method in use.
        source_noise = torch.randn((1, 4, 64, 64)).to(device)
        while not episode.done:
            # Only the AVDC reward consumes the pre-step frame, so skip the
            # extra render and GPU upload for every other reward method.
            if use_avdc:
                prev_rendered = render_frame()
            action = agent.plan(obs, step=step, t0=episode.first)
            obs, gt_reward, done, info = env.step(action.cpu().numpy())
            if use_sd:
                rendered = render_frame()
                sd_reward = guidance.disentangled_sds_alignment(
                    rendered,
                    alignment_scale=align_scale,
                    recon_scale=recon_scale,
                    noise_level=noise_level,
                    noise=source_noise,
                )
                reward = sd_reward[0] + cfg.sparse_scale * info['success']
            elif use_avdc:
                rendered = render_frame()
                avdc_reward = disentangled_avdc_alignment(
                    guidance,
                    avdc_trainer,
                    prev_rendered,
                    rendered,
                    avdc_text_embeddings,
                    alignment_scale=align_scale,
                    recon_scale=recon_scale,
                    noise_level=noise_level,
                    noise=source_noise,
                )
                reward = avdc_reward[0] + cfg.sparse_scale * info['success']
            elif use_clip:
                pil_frame = Image.fromarray(env.render(**render_kwargs))
                clip_frame = preprocess_val(pil_frame).to(device)
                with torch.no_grad():
                    image_features = model.encode_image(clip_frame.unsqueeze(0)).float()
                image_features /= image_features.norm(dim=-1, keepdim=True)
                clip_reward = (prompt_features @ image_features.T)[0][0]
                reward = clip_reward + cfg.sparse_scale * info['success']
            else:
                # Default: sparse-only reward.
                reward = info['success']
            episode += (obs, action, reward, gt_reward, done)
        assert len(episode) == cfg.episode_length
        buffer += episode

        # Update model
        train_metrics = {}
        if step >= cfg.seed_steps:
            # The first update round catches up on all seed steps at once.
            num_updates = cfg.seed_steps if step == cfg.seed_steps else cfg.episode_length
            for i in range(num_updates):
                train_metrics.update(agent.update(buffer, step+i))

        # Log training episode
        episode_idx += 1
        env_step = int(step*cfg.action_repeat)
        common_metrics = {
            'episode': episode_idx,
            'step': step,
            'env_step': env_step,
            'total_time': time.time() - start_time,
            'episode_reward': episode.cumulative_reward,
            'episode_gt_reward': episode.cumulative_gt_reward}
        train_metrics.update(common_metrics)
        L.log(train_metrics, category='train', agent=agent)

        # Evaluate agent periodically
        if env_step % cfg.eval_freq == 0:
            eval_reward, eval_success = evaluate(
                env, agent, cfg.eval_episodes, step, env_step, L.video)
            common_metrics['episode_reward'] = eval_reward
            common_metrics['episode_success'] = eval_success
            L.log(common_metrics, category='eval', agent=agent)

    L.finish(agent)
    print('Training completed successfully')


if __name__ == '__main__':
    # Parse the configuration found in the `__CONFIG__` directory under the
    # current working directory, then launch training with it.
    config_dir = Path().cwd() / __CONFIG__
    train(parse_cfg(config_dir))
