import time
import tqdm
import numpy as np
import torch
from tensordict import TensorDict
from torchrl.collectors import SyncDataCollector
from collectors import JumpStartSyncDataCollector
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.data.utils import DEVICE_TYPING
from tensordict.nn import TensorDictModule
from utils import eval_model, get_env_mask
import wandb
import json
from torch import Tensor
from collections import deque

class EpisodeHist():
    def __init__(self, buffer_size: int = 100):
        self.buffer_size = buffer_size
        self.buffer = {'reward': [], 'length': []}

    def update(self, episodes_this_iter: dict[Tensor]):
        stats = {}
        for k, v in episodes_this_iter.items():
            assert k in self.buffer, f'unknown key {k}'
            self.buffer[k].extend(v.tolist())
            if len(self.buffer[k]) > self.buffer_size:
                self.buffer[k] = self.buffer[k][-self.buffer_size:]

            stats[k] = np.mean(self.buffer[k])

        return stats


def _build_lr_scheduler(config, optim, total_frames: int):
    """Create the LR scheduler selected by ``config['lr_scheduler']``.

    ``'constant'`` -> ConstantLR(factor=1); ``'linear'`` -> LinearLR decaying
    from 1x to 0.1x; anything else (including a missing key) -> OneCycleLR
    peaking at ``config['lr']``.
    """
    scheduler_name = config.get('lr_scheduler')
    if scheduler_name == 'constant':
        lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optim, factor=1)
        print("using constant lr scheduler")
        return lr_scheduler

    # The scheduler is stepped once per collected batch. Round the horizon up
    # so that a final extra batch (total_frames not a multiple of the batch
    # size) does not overrun it: OneCycleLR raises ValueError when stepped
    # past total_steps. Identical to floor division when sizes divide evenly.
    num_updates = -(-total_frames // config['train_batch_size'])
    if scheduler_name == 'linear':
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optim,
            start_factor=1,
            end_factor=0.1,
            total_iters=num_updates,
        )
        print("using linear lr scheduler")
    else:
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optim,
            max_lr=config['lr'],
            pct_start=0.3,
            div_factor=10, final_div_factor=1,
            total_steps=num_updates,
        )
        print("using one cycle lr scheduler")
    return lr_scheduler


def _sample_archive(archive_buffer, config):
    """Sample up to ``config['archive_buffer_size']`` (default 3000) entries.

    Returns None when ``archive_buffer`` is empty.
    """
    if len(archive_buffer) == 0:
        return None
    num_samples = config.get("archive_buffer_size", 3000)
    return archive_buffer.sample(min(num_samples, len(archive_buffer)))


def _run_dir(config) -> str:
    """Directory where checkpoints, sample images and logs are written."""
    return f"{config['exp_dir']}/{config['exp_name']}"


def _as_python_scalar(value):
    """Unwrap values exposing ``.item()`` (0-d tensors, numpy scalars).

    Other values are returned unchanged. This keeps the collected logs
    JSON-serializable.
    """
    return value.item() if hasattr(value, "item") else value


def train_ppo(
        config,
        actor: TensorDictModule,
        critic: TensorDictModule,
        collector: SyncDataCollector,
        data_buffer,
        adv_module,
        loss_module,
        total_frames: int,
        num_sgd_iter: int,
        num_mini_batches: int,
        optim,
        test_env,
        test_interval: int,
        device: DEVICE_TYPING,
        archive_buffer=None,
    ):
    """Run the PPO training loop over every batch yielded by ``collector``.

    Per batch this:
      1. advances the entropy schedule, and for ``env == 'single_cook'`` calls
         ``anneal_reward_shaping_factor`` on the collector's envs;
      2. logs advice rates, env-mask breakdowns (``multi_grid`` /
         ``single_cook``) and windowed episode reward/length/success;
      3. if ``config['archive_buffer']`` or ``config['save_images']`` is set,
         adds the batch to ``archive_buffer``, and with ``archive_buffer`` set
         refreshes ``loss_module``'s id batch from an archive sample;
      4. runs ``num_sgd_iter`` epochs of GAE + minibatch updates over
         ``data_buffer``;
      5. steps the LR scheduler and ``loss_module.update_kl``;
      6. on the first batch and whenever the frame count crosses a multiple of
         ``test_interval`` (assuming a constant batch size), evaluates ``actor``
         on ``test_env``, saves checkpoints, may advance a JumpStart curriculum
         stage and saves archive sample images.

    Unless ``config['debug']`` is truthy, each iteration's metrics go to wandb
    and all of them are written to ``<exp_dir>/<exp_name>/logs.json`` at the end.

    Args:
        config: experiment config dict. Keys read include ``lr_scheduler``,
            ``lr``, ``train_batch_size``, ``env``, ``name``, ``env_config``,
            ``debug``, ``exp_dir``, ``exp_name`` and the optional
            ``archive_buffer``, ``save_images``, ``archive_buffer_size`` and
            ``max_grad_norm``.
        num_mini_batches: must be at least the number of minibatches
            ``data_buffer`` yields per epoch (it sizes the loss log).
        archive_buffer: replay buffer. It is required when
            ``config['archive_buffer']`` or ``config['save_images']`` is set.

    Raises:
        ValueError: if archive features are enabled but ``archive_buffer`` is None.
    """
    use_archive = bool(config.get("archive_buffer", False) or config.get("save_images", False))
    set_id_batch_enabled = config.get("archive_buffer", False)
    if use_archive and archive_buffer is None:
        raise ValueError(
            "config enables 'archive_buffer'/'save_images' but archive_buffer is None"
        )

    env_name = config.get("env")
    jumping_envs = ("multi_grid", "single_cook")
    is_jumping_env = env_name in jumping_envs

    logs = []
    episodes_hist = EpisodeHist()
    lr_scheduler = _build_lr_scheduler(config, optim, total_frames)

    losses = TensorDict({}, batch_size=[num_sgd_iter, num_mini_batches])
    pbar = tqdm.tqdm(total=total_frames)
    max_reward = float('-inf')
    n_steps = 0

    eval_horizon = 5
    recent_eval_rewards = deque(maxlen=eval_horizon)
    # Starts at a small positive floor rather than -inf; it feeds the JumpStart threshold.
    best_moving_avg_reward = 1e-3

    try:
        for i, data in enumerate(collector):
            frames_in_batch = data.numel()
            pbar.update(frames_in_batch)
            n_steps += frames_in_batch
            loss_module.set_entropy(n_steps, total_frames)  # update current_entropy based on schedule

            if env_name == "single_cook":
                try:
                    collector.env.env_method("anneal_reward_shaping_factor", n_steps)
                except RuntimeError as e:
                    # A closed / not-started env is tolerated; anything else is re-raised.
                    if "closed/non started" not in str(e):
                        raise

            log_info = {
                'train/steps': n_steps,
                'train/cur_entropy_coef': loss_module.entropy_coef.item(),
            }

            if 'issue_advice' in data.keys():
                log_info['train/issue_advice_rate'] = data['issue_advice'].float().mean().item()
            if 'take_advice' in data.keys():
                log_info['train/take_advice_rate'] = data['take_advice'].float().mean().item()
            if "lambda_k" in data.keys():
                log_info["train/lambda_k"] = data["lambda_k"].float().mean().item()

            if is_jumping_env:
                # Flatten the leading (B, T) dims so there is one mask entry per step.
                images = data["image"]
                B, T = images.shape[:2]
                images = images.reshape(B * T, *images.shape[2:])

                if 'recipe' in data.keys():
                    recipe = data['recipe']
                    recipe = recipe.reshape(B * T, *recipe.shape[2:])
                    env_mask = get_env_mask(images, name=env_name, recipe=recipe)
                else:
                    env_mask = get_env_mask(images, name=env_name, specific_name=config["name"])

                # Fraction of collected steps carrying each distinct mask value.
                mask_values = torch.unique(env_mask)
                for val in mask_values:
                    log_info[f"train/env_mask_{val.item()}"] = (env_mask == val).float().mean().item()

                if 'issue_advice' in data.keys():
                    issue_advice = data['issue_advice'].reshape(-1)
                    assert issue_advice.shape[0] == env_mask.shape[0], "Mismatch between advice and env_mask size"

                    for val in mask_values:
                        mask = env_mask == val
                        log_info[f"train/issue_advice_{val.item()}"] = (
                            issue_advice[mask].float().mean().item() if mask.any() else None
                        )

            done = data["next", "done"]
            episode_rewards = data["next", "reward"][done]
            if len(episode_rewards) > 0:
                # Check the same nested key that is read below (not the root "info").
                if ("next", "info") in data.keys(include_nested=True):
                    episode_info = data["next", "info"][done]
                    episode_rewards = episode_info['sparse_r_by_agent']
                    log_info["train/sparse"] = episode_info['sparse_r_by_agent'].mean().item()
                    log_info["train/shape"] = episode_info['shaped_r_by_agent'].mean().item()

                if is_jumping_env:
                    done_mask = done.reshape(-1)  # [B*T]
                    # Keep only the mask entries of steps that ended an episode.
                    selected_env_mask = env_mask.to(done_mask.device)[done_mask]
                    selected_rewards = episode_rewards.to(done_mask.device)

                    for val in torch.unique(selected_env_mask):
                        reward_vals = selected_rewards[selected_env_mask == val]
                        log_info[f'train/reward_{val.item()}'] = (
                            reward_vals.mean().item() if reward_vals.numel() > 0 else None
                        )

                episode_length = data["next", "step_count"][done]
                episode_stats = episodes_hist.update({'reward': episode_rewards, 'length': episode_length})
                log_info["train/reward"] = episode_stats['reward']
                log_info["train/episode_length"] = episode_stats['length']

                if ("next", "success") in data.keys(include_nested=True):
                    success_rate = data["next", "success"][..., None][done].float().mean()
                    log_info['train/success_rate'] = success_rate.item()

            if use_archive:
                archive_buffer.extend(data.reshape(-1))
                # Sampled unconditionally (as before), even when only saving images.
                archive_sample = _sample_archive(archive_buffer, config)
                if (
                    archive_sample is not None
                    and "image" in archive_sample.keys()
                    and set_id_batch_enabled
                ):
                    images = archive_sample["image"]
                    if "recipe" in archive_sample.keys():
                        loss_module.set_id_batch({"image": images, "recipe": archive_sample["recipe"]})
                    else:
                        loss_module.set_id_batch(images)

            for j in range(num_sgd_iter):
                # Recompute GAE each epoch, with the value network in eval mode.
                with torch.no_grad():
                    adv_module.value_network.eval()
                    data = adv_module(data.to(device, non_blocking=True))
                    adv_module.value_network.train()

                # Refill the minibatch buffer with this epoch's advantages.
                data_buffer.extend(data.reshape(-1))

                for k, batch in enumerate(data_buffer):
                    batch = batch.to(device, non_blocking=True)
                    loss = loss_module(batch)

                    losses[j, k] = loss.select(
                        "loss_critic", "loss_entropy", "loss_objective", "loss_kl", "loss_energy", "iso",
                        "loss_unclipped", "loss_clipped", "loss_distill"
                    ).detach()

                    loss_sum = (
                        loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + loss["loss_kl"]
                        + loss["loss_energy"] + loss["loss_distill"]
                    )

                    loss_sum.backward()
                    if "max_grad_norm" in config.keys():
                        torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=config["max_grad_norm"])

                    optim.step()
                    optim.zero_grad()

            lr_scheduler.step()
            log_info['train/lr'] = lr_scheduler.get_last_lr()[-1]

            losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
            for key, value in losses_mean.items():
                log_info[f'train/{key}'] = value.item()
            # kl_coeff's type is not visible here; unwrap it if it is a tensor so logs.json can be written.
            log_info['train/cur_kl_coeff'] = _as_python_scalar(loss_module.kl_coeff)

            loss_module.update_kl(losses_mean['loss_kl'])

            with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
                # True when this batch crosses a multiple of test_interval; also
                # true for i == 0, since (-frames) // test_interval < 0.
                if ((i - 1) * frames_in_batch) // test_interval < (i * frames_in_batch) // test_interval:
                    actor.eval()
                    eval_start = time.time()
                    test_logs = eval_model(
                        actor, test_env, max_steps=config['env_config']['max_steps'], num_episodes=100
                    )
                    if env_name == "single_cook":
                        # Use the sparse reward as the headline eval reward.
                        test_logs['eval/reward'] = test_logs['eval/sparse']

                    cur_reward = test_logs['eval/reward']
                    recent_eval_rewards.append(cur_reward)
                    cur_avg_reward = np.mean(recent_eval_rewards)
                    best_moving_avg_reward = max(best_moving_avg_reward, cur_avg_reward)

                    max_reward = max(cur_reward, max_reward)
                    eval_time = time.time() - eval_start
                    log_info.update(test_logs)
                    log_info["eval/time"] = eval_time
                    log_info["eval/moving_avg_reward"] = cur_avg_reward
                    log_info["eval/best_moving_avg_reward"] = best_moving_avg_reward

                    if not config['debug']:
                        checkpoint = {'actor': actor.state_dict(), 'critic': critic.state_dict()}
                        torch.save(checkpoint, f"{_run_dir(config)}/model-{i * frames_in_batch}.pt")
                        if max_reward == cur_reward:
                            torch.save(checkpoint, f"{_run_dir(config)}/model-best.pt")

                    if isinstance(collector, JumpStartSyncDataCollector):
                        # Advance the curriculum once the moving-average eval reward
                        # is within `tolerance` of the best moving average seen.
                        threshold = (1 - collector.tolerance) * best_moving_avg_reward
                        print(f"the threshold is {threshold} and the cur reward is {cur_avg_reward}")
                        if cur_avg_reward >= threshold:
                            collector.set_current_stage()

                    if use_archive:
                        archive_sample = _sample_archive(archive_buffer, config)
                        if archive_sample is not None and "image" in archive_sample.keys():
                            images = archive_sample["image"]
                            save_path = f"{_run_dir(config)}/sample_images_{i * frames_in_batch}.pt"
                            if "recipe" in archive_sample.keys():
                                recipe = archive_sample["recipe"]
                                torch.save({"image": images.cpu(), "recipe": recipe.cpu()}, save_path)
                            else:
                                torch.save(images.cpu(), save_path)
                    actor.train()

            if not config['debug']:
                wandb.log(log_info)

            logs.append(log_info)

            collector.update_policy_weights_()
    finally:
        # Release the progress bar and collector workers even if training fails.
        pbar.close()
        collector.shutdown()

    if not config['debug']:
        with open(f"{_run_dir(config)}/logs.json", 'w') as f:
            json.dump(logs, f)
