import os
from copy import deepcopy

import d4rl
import gym
import hydra
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from cleandiffuser.dataset.d4rl_kitchen_dataset import D4RLKitchenTDDataset
from cleandiffuser.dataset.dataset_utils import loop_dataloader
from cleandiffuser.diffusion import DiscreteDiffusionSDE, DiscreteVPSDE, DDPM
from cleandiffuser.nn_condition import IdentityCondition
from cleandiffuser.nn_diffusion import DQLMlp, IDQLMlp
from cleandiffuser.utils import report_parameters, DQLCritic, FreezeModules, QEnsembleCritic
from utils import set_seed
import wandb


def _new_log():
    """Return a fresh accumulator for the metrics averaged over each log interval."""
    return {"bc_loss": 0., "critic_loss": 0., "target_q_mean": 0., "q_guidance_loss": 0.,
            "actor_loss": 0., "eta": 0., "q_loss": 0.}


def _rollout(actor, critic_target, normalizer, args, obs_dim, act_dim, max_steps=1000, **sample_kwargs):
    """Roll out one batch of vectorized evaluation episodes with Q-based candidate selection.

    For every environment, ``args.num_candidates`` actions are sampled from the diffusion
    actor and one is kept per environment using ``critic_target.q_mean``: drawn from
    softmax(Q * weight_temperature) when ``args.task.weight_temperature > 0``, otherwise
    the arg-max. The vector environment is always closed before returning.

    Args:
        actor: diffusion actor exposing ``sample``.
        critic_target: critic exposing ``q_mean``.
        normalizer: observation normalizer exposing ``normalize``.
        args: hydra config (reads task.env_name, task.weight_temperature, num_envs,
            num_candidates, diffusion_steps, use_ema, temperature, device).
        obs_dim: observation dimensionality.
        act_dim: action dimensionality.
        max_steps: the loop stops once more than this many env steps have been taken.
        **sample_kwargs: extra keyword arguments forwarded to ``actor.sample``
            (e.g. ``solver``).

    Returns:
        The per-environment reward summed over the rollout.
    """
    env_eval = gym.vector.make(args.task.env_name, args.num_envs)
    try:
        n_samples = args.num_envs * args.num_candidates
        prior = torch.zeros((n_samples, act_dim), device=args.device)

        obs, ep_reward, cum_done, t = env_eval.reset(), 0., 0., 0
        while not np.all(cum_done) and t <= max_steps:
            # Normalize observations and replicate each one num_candidates times.
            obs = torch.tensor(normalizer.normalize(obs), device=args.device, dtype=torch.float32)
            obs = obs.unsqueeze(1).repeat(1, args.num_candidates, 1).view(-1, obs_dim)

            # Sample num_candidates actions per environment.
            act, _ = actor.sample(
                prior,
                n_samples=n_samples,
                sample_steps=args.diffusion_steps,
                condition_cfg=obs, w_cfg=1.0,
                use_ema=args.use_ema, temperature=args.temperature,
                **sample_kwargs)

            # Keep one candidate per environment, scored by the target critic.
            with torch.no_grad():
                q = critic_target.q_mean(obs, act).view(-1, args.num_candidates, 1)
                act = act.view(-1, args.num_candidates, act_dim)
                if args.task.weight_temperature > 0:
                    w = torch.softmax(q * args.task.weight_temperature, 1)
                    indices = torch.multinomial(w.squeeze(-1), 1).squeeze(-1)
                else:
                    indices = torch.argmax(q.squeeze(-1), 1)
                sampled_act = act[torch.arange(act.shape[0]), indices].cpu().numpy()

            obs, rew, done, _ = env_eval.step(sampled_act)
            t += 1
            cum_done = np.logical_or(cum_done, done)
            ep_reward += rew
    finally:
        # gym.vector envs may own worker processes; release them after every evaluation.
        env_eval.close()
    return ep_reward


def _binarized_scores(env, episode_rewards):
    """Convert raw returns to 0/1 flags: 1 where ``env.get_normalized_score`` is > 0.

    NOTE(review): callers report the mean of these flags, i.e. the fraction of episodes
    with any positive normalized score, not the mean normalized score — confirm intent.
    """
    scores = np.array([[env.get_normalized_score(x) for x in r] for r in episode_rewards])
    return (scores > 0).astype(float)


@hydra.main(config_path="../configs/TDP/kitchen", config_name="kitchen", version_base=None)
def pipeline(args):
    """Train or evaluate a diffusion actor with a Q-ensemble critic on a D4RL kitchen task.

    ``args.mode == "train"``: every step performs a TD update of the critic ensemble, with
    targets built from next actions sampled by the EMA actor. Every second step updates the
    actor with an adaptively weighted BC (denoising) loss plus Q-guidance and Q-maximization
    terms. Evaluation rollouts run every ``args.eval_interval`` steps and checkpoints are
    written every ``args.save_interval`` steps.

    ``args.mode == "inference"``: loads checkpoint ``args.ckpt`` and runs one batch of
    evaluation episodes.

    Raises:
        ValueError: if ``args.mode`` is not "train"/"inference", or (in train mode) if
            ``args.q_target`` is not "lcb"/"min".
    """
    set_seed(args.seed)

    save_path = f'results/{args.pipeline_name}/{args.seed}/{args.task.env_name}/'
    os.makedirs(save_path, exist_ok=True)

    # ---------------------- Create Dataset ----------------------
    env = gym.make(args.task.env_name)
    dataset = D4RLKitchenTDDataset(d4rl.qlearning_dataset(env), args.normalize_reward)
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    obs_dim, act_dim = dataset.o_dim, dataset.a_dim

    # --------------- Network Architecture -----------------
    nn_diffusion = DQLMlp(obs_dim, act_dim, emb_dim=64, timestep_emb_type="fourier").to(args.device)
    nn_condition = IdentityCondition(dropout=0.0).to(args.device)

    # Additional hyperparameters
    rho = args.task.rho                 # std multiplier in the "lcb" target
    eta = args.task.eta                 # BC-loss weight, adapted toward bc_target during training
    eta_lr = args.task.eta_lr
    q_target = args.q_target            # "lcb" or "min"
    bc_target = args.task.bc_target
    num_a_train = args.num_a_train      # next actions sampled per transition for the backup
    max_q_backup = args.max_q_backup
    lamb = args.task.lamb               # trades off Q-guidance loss vs. direct Q loss
    grad_norm = args.task.gn            # max norm for gradient clipping

    print(f"======================= Parameter Report of Diffusion Model =======================")
    report_parameters(nn_diffusion)
    print(f"==============================================================================")

    # --------------- Diffusion Model Actor --------------------
    actor = DDPM(
        nn_diffusion, nn_condition, predict_noise=args.predict_noise, beta_schedule="vp",
        optim_params={"lr": args.actor_learning_rate},
        x_max=+1. * torch.ones((1, act_dim), device=args.device),
        x_min=-1. * torch.ones((1, act_dim), device=args.device),
        diffusion_steps=args.diffusion_steps, ema_rate=args.ema_rate, device=args.device)

    # ------------------ Critic ---------------------
    critic = QEnsembleCritic(obs_dim, act_dim, hidden_dim=args.hidden_dim, num_q=args.num_q).to(args.device)
    # The target critic is never trained directly and is kept in eval mode throughout.
    critic_target = deepcopy(critic).requires_grad_(False).eval()
    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_learning_rate)

    # ---------------------- Training ----------------------
    if args.mode == "train":
        if q_target not in ("lcb", "min"):
            raise ValueError(f"Invalid q_target: {q_target}")

        wandb.init(project="TDP-kitchen", config=dict(args), name=f"{args.task.env_name}-{args.task.bc_target}-{args.task.rho}-{args.task.lamb}-{args.seed}")
        actor_lr_scheduler = CosineAnnealingLR(actor.optimizer, T_max=args.gradient_steps, eta_min=args.actor_learning_rate / 10)
        critic_lr_scheduler = CosineAnnealingLR(critic_optim, T_max=args.gradient_steps, eta_min=args.critic_learning_rate / 10)

        actor.train()
        critic.train()

        bar_alpha = actor.bar_alpha
        noise_level = torch.sqrt(1 - bar_alpha)

        n_gradient_step = 0
        log = _new_log()
        prior = torch.zeros((args.batch_size * num_a_train, act_dim), device=args.device)

        for batch in loop_dataloader(dataloader):
            obs, next_obs = batch["obs"]["state"].to(args.device), batch["next_obs"]["state"].to(args.device)
            act = batch["act"].to(args.device)
            rew = batch["rew"].to(args.device)
            tml = batch["tml"].to(args.device)

            # ----------- Critic update -----------
            # next_obs is tiled num_a_train times (copy-major) so each copy gets its own sampled action.
            obs_next = next_obs.unsqueeze(0).repeat(num_a_train, 1, 1).view(-1, obs_dim)
            next_act, _ = actor.sample(
                prior,
                n_samples=args.batch_size * num_a_train, sample_steps=args.diffusion_steps, use_ema=True,
                temperature=1.0, condition_cfg=obs_next, w_cfg=1.0, requires_grad=False)
            act_next = next_act.reshape(-1, act_dim)

            # Viewed as (num_q, num_a_train, batch) — assumes one scalar Q per sample.
            next_v = torch.stack(critic_target(obs_next, act_next)).view(args.num_q, num_a_train, -1)
            # Reduce over the sampled next actions, then over the ensemble.
            if max_q_backup:
                next_v, _ = next_v.max(1)
            else:
                next_v = next_v.mean(1)
            if q_target == 'lcb':
                target_q = next_v.mean(0) - rho * next_v.std(0, correction=0)
            else:  # 'min' (validated above)
                target_q, _ = next_v.min(0)
            target_q = (rew + (1 - tml) * args.discount * target_q.unsqueeze(-1)).detach()

            critic_loss = F.mse_loss(torch.stack(critic(obs, act)), target_q.unsqueeze(0).repeat(critic.num_q, 1, 1))
            critic_optim.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), grad_norm)
            critic_optim.step()

            # ----------- Actor update (every other step) -----------
            update_actor = (n_gradient_step + 1) % 2 == 0
            if update_actor:
                xt, t, eps = actor.add_noise(act)
                condition = actor.model["condition"](obs)
                eps_model = actor.model["diffusion"](xt, t, condition)
                # Weighted denoising (behavior-cloning) loss.
                bc_loss = ((eps_model - eps) ** 2 * actor.loss_weight * (1 - actor.fix_mask)).mean()

                # Q terms are computed on a random half of the batch.
                idx_1 = np.random.choice(args.batch_size, args.batch_size // 2, replace=False)
                xt_1 = xt[idx_1]
                t_1 = t[idx_1]
                eps_1 = eps[idx_1]
                eps_model_1 = eps_model[idx_1]
                obs_1 = obs[idx_1]
                xt_1.requires_grad_(True)
                sqrt_bar_alpha = bar_alpha[t_1].unsqueeze(-1).repeat(1, act_dim).sqrt()
                sqrt_one_minus_bar_alpha = (1 - bar_alpha[t_1]).unsqueeze(-1).repeat(1, act_dim).sqrt()

                # x0 estimate from the true noise; its Q-gradient w.r.t. x_t drives the guidance loss.
                x0_hat_1 = (xt_1 - sqrt_one_minus_bar_alpha * eps_1) / sqrt_bar_alpha
                with FreezeModules([critic, ]):
                    q_val = critic.q_mean(obs_1, x0_hat_1).sum()
                q_val.backward()
                grad_model = xt_1.grad
                with torch.no_grad():
                    q_scale_1 = torch.mean(torch.abs(torch.stack(critic_target(obs_1, x0_hat_1))))
                q_guidance_loss = (grad_model / q_scale_1 * eps_model_1
                                   * noise_level[t_1].unsqueeze(-1).repeat(1, eps_1.shape[1])).mean()

                # x0 estimate from the predicted noise (clipped to the action box); Q is maximized through it.
                x0_hat_2 = torch.clamp((xt_1 - sqrt_one_minus_bar_alpha * eps_model_1) / sqrt_bar_alpha, max=1., min=-1.)
                with FreezeModules([critic, ]):
                    q_mean_2 = critic.q_mean(obs_1, x0_hat_2).mean()
                    with torch.no_grad():
                        q_scale_2 = torch.mean(torch.abs(torch.stack(critic_target(obs_1, x0_hat_2))))
                q_loss = q_mean_2 / q_scale_2

                actor_loss = eta * bc_loss + (1 - lamb) * q_guidance_loss - lamb * q_loss
                actor.optimizer.zero_grad()
                actor_loss.backward()
                torch.nn.utils.clip_grad_norm_(actor.model.parameters(), grad_norm)
                actor.optimizer.step()

            actor_lr_scheduler.step()
            critic_lr_scheduler.step()

            # Raise eta when bc_loss exceeds bc_target and lower it otherwise, clamped to [1e-4, 100].
            if update_actor:
                eta += eta_lr * (bc_loss.detach() - bc_target)
                eta = eta.clamp(0.0001, 100)

            # -- ema
            if n_gradient_step % args.ema_update_interval == 0:
                if n_gradient_step >= 1000:
                    actor.ema_update()
                # NOTE(review): this sets target = 0.995 * online + 0.005 * target, so the target
                # nearly copies the online critic; a conventional Polyak update weights the target
                # by 0.995 instead. Kept unchanged to preserve reported results — confirm intent.
                for param, target_param in zip(critic.parameters(), critic_target.parameters()):
                    target_param.data.copy_(0.995 * param.data + (1 - 0.995) * target_param.data)

            # ----------- Logging ------------
            log["critic_loss"] += critic_loss.item()
            log["target_q_mean"] += target_q.mean().item()
            if update_actor:
                log["actor_loss"] += actor_loss.item()
                log["bc_loss"] += bc_loss.item()
                log["q_guidance_loss"] += q_guidance_loss.item()
                log["q_loss"] += q_loss.item()
                log["eta"] += float(eta)

            if (n_gradient_step + 1) % args.log_interval == 0:
                log["gradient_steps"] = n_gradient_step + 1
                for key in ("critic_loss", "target_q_mean"):
                    log[key] /= args.log_interval
                # Actor metrics are accumulated only on every other step.
                for key in ("actor_loss", "q_guidance_loss", "bc_loss", "q_loss", "eta"):
                    log[key] /= args.log_interval / 2
                print(log)
                wandb.log(log)
                log = _new_log()

            # ----------- Saving ------------
            if (n_gradient_step + 1) % args.save_interval == 0:
                actor.save(save_path + f"diffusion_ckpt_{n_gradient_step + 1}.pt")
                actor.save(save_path + "diffusion_ckpt_latest.pt")
                critic_ckpt = {
                    "critic": critic.state_dict(),
                    "critic_target": critic_target.state_dict(),
                }
                torch.save(critic_ckpt, save_path + f"critic_ckpt_{n_gradient_step + 1}.pt")
                torch.save(critic_ckpt, save_path + "critic_ckpt_latest.pt")

            # ----------- Evaluation ------------
            if (n_gradient_step + 1) % args.eval_interval == 0:
                actor.eval()
                critic.eval()
                ep_reward = _rollout(actor, critic_target, dataset.get_normalizer(), args, obs_dim, act_dim)
                episode_rewards = _binarized_scores(env, [ep_reward])
                print(episode_rewards)
                print(np.mean(episode_rewards), np.std(episode_rewards))
                wandb.log({"normalized_score_mean": 100*np.mean(episode_rewards), "normalized_score_std": 100*np.std(episode_rewards)})
                actor.train()
                critic.train()

            n_gradient_step += 1
            if n_gradient_step >= args.gradient_steps:
                break

    # ---------------------- Inference ----------------------
    elif args.mode == "inference":
        actor.load(save_path + f"diffusion_ckpt_{args.ckpt}.pt")
        # map_location lets a checkpoint saved on one device be loaded on another.
        critic_ckpt = torch.load(save_path + f"critic_ckpt_{args.ckpt}.pt", map_location=args.device)
        critic.load_state_dict(critic_ckpt["critic"])
        critic_target.load_state_dict(critic_ckpt["critic_target"])

        actor.eval()
        critic.eval()
        critic_target.eval()

        ep_reward = _rollout(actor, critic_target, dataset.get_normalizer(), args, obs_dim, act_dim,
                             solver=args.solver)
        episode_rewards = _binarized_scores(env, [ep_reward])
        print(np.mean(episode_rewards, -1), np.std(episode_rewards, -1))

    else:
        raise ValueError(f"Invalid mode: {args.mode}")


if __name__ == "__main__":
    # No arguments here: @hydra.main builds `args` from the composed config and command line.
    pipeline()
