import os
import gc
from importlib import import_module
import logging
import random
from typing import List

from omegaconf import DictConfig
import hydra
import torch 
from torch.utils.data import DataLoader
from accelerate import Accelerator
from tqdm import tqdm
from torch.nn.parallel import DistributedDataParallel

from inference_rlhf.code.helpers.io import *
from inference_rlhf.code.helpers.utils import set_seeds, rget_json_files_from_dir, construct_sparse_matrix
from inference_rlhf.code.helpers.constructors import construct_collator, construct_reward_model, construct_policy_model
from inference_rlhf.code.helpers.constructors import dataloader_factory

log = logging.getLogger(__name__)

def get_prompts(cfg):
    """Return the ``questions`` of the dataloader built for the configured task.

    The dataloader is created with ``dataloader_factory(cfg.task.name, cfg)`` and its
    ``questions`` attribute is handed back unchanged.
    """
    task_loader = dataloader_factory(cfg.task.name, cfg)
    questions = task_loader.questions
    return questions

def check_and_add_prompts(cfg, _outputs):
    """Flatten parsed generation outputs and attach their prompt text and a running id.

    Args:
        cfg: Hydra config; forwarded to ``get_prompts``.
        _outputs: Either a dict whose values are lists of output dicts (flattened in the
            dict's iteration order) or an already-flat list of output dicts.

    Returns:
        The flat list of output dicts. Each dict is modified in place: ``'prompt'`` is set
        to ``prompts[int(output['prompt_idx'])]`` and ``'id'`` to its position in the list.

    Raises:
        TypeError: if ``_outputs`` is neither a dict nor a list.
    """
    if isinstance(_outputs, dict):
        flat = []
        for group in _outputs.values():
            flat.extend(group)
    elif isinstance(_outputs, list):
        flat = _outputs
    else:
        raise TypeError("Parsed outputs must be dict or list")

    prompts = get_prompts(cfg)

    for position, output in enumerate(flat):
        output['prompt'] = prompts[int(output['prompt_idx'])]
        output['id'] = position

    return flat

def load_generation_files(cfg: DictConfig, load_path: str) -> List[str]:
    """
    Load generation files from load path and filter them by the sampling/task config.

    A JSON file found (via ``rget_json_files_from_dir``) under ``load_path`` is kept when:
      * its path contains neither ``coreset`` nor ``eval_results``;
      * its path encodes the configured temperature, top-p and (when > 0) min-p; when
        min-p is 0 the path must not mention ``--min-p`` at all;
      * for ``code_contests`` / ``mbpp`` it is not a ``--CHECKED`` file;
      * no ``--{policy.name}-features.tar`` file exists next to it;
      * if ``task.generation.generation_idx`` is set, it belongs to that prompt index.
    The survivors are optionally randomly subsampled (``plot.subsample_size``), truncated
    to 8 files in debug mode, and returned sorted.

    Args:
        cfg: Hydra config
        load_path: Path to load generation files from

    Returns:
        Sorted list of generation file paths.
    """
    # Get all json files from load path
    generation_files = rget_json_files_from_dir(load_path)

    # Filter out LLM direct coreset files and MBPP eval results
    generation_files = [
        gf for gf in generation_files
        if 'coreset' not in gf and 'eval_results' not in gf
    ]

    # Filter based on temperature (two file-naming schemes are in use)
    temperature = cfg.sampling.temperature
    generation_files = [
        gf for gf in generation_files
        if f'--temp-{temperature}-' in gf or f'temp_{temperature}-' in gf
    ]

    # Filter based on top-p
    generation_files = [gf for gf in generation_files if f'--top-p-{cfg.sampling.top_p}' in gf]

    # Filter based on min-p
    if cfg.sampling.min_p > 0.0:
        generation_files = [gf for gf in generation_files if f'--min-p-{cfg.sampling.min_p}' in gf]
    else:
        generation_files = [gf for gf in generation_files if '--min-p' not in gf]

    # Filter out checked files
    if cfg.task.name in ('code_contests', 'mbpp'):
        generation_files = [gf for gf in generation_files if '--CHECKED' not in gf]

    # Filter out files that already have a features file.
    # NOTE(review): this only matches the policy hidden-state features name; reward features,
    # gradient files and checkpoint-tagged features are written under other names by main().
    generation_files = [
        gf for gf in generation_files
        if not os.path.exists(gf.replace('.json', f'--{cfg.policy.name}-features.tar'))
    ]

    # Restrict to the requested prompt BEFORE subsampling, so the random subsample cannot
    # drop the very file that was asked for.
    if cfg.task.generation.generation_idx is not None:
        generation_files = [
            gf for gf in generation_files
            if f'prompt-idx-{cfg.task.generation.generation_idx}-' in gf
        ]

    # Subsample if needed. Sorting first gives a deterministic population for the seeded RNG;
    # the sample size is clamped because random.sample raises when k > population size.
    if cfg.plot.subsample_size:
        generation_files.sort()
        sample_size = min(cfg.plot.subsample_size, len(generation_files))
        generation_files = random.sample(generation_files, sample_size)

    # When debugging, only use the first few files
    if cfg.debug:
        generation_files = generation_files[:8]

    # Sort for reproducibility
    generation_files.sort()

    return generation_files

def _remap_to_save_root(cfg, path: str) -> str:
    """Re-root ``path`` from ``cfg.blob_root`` into ``cfg.io.save_root`` (AMLT runs).

    The ``cfg.blob_root`` prefix is stripped from ``path``, the first character of the
    remainder (the path separator) is dropped, and the rest is joined onto
    ``cfg.io.save_root``. The parent directory of the result is created if missing.

    Apply this exactly once per path: an already re-rooted path would be nested inside
    ``save_root`` a second time.
    """
    remapped = os.path.join(cfg.io.save_root, path.replace(cfg.blob_root, "")[1:])
    os.makedirs(os.path.dirname(remapped), exist_ok=True)
    return remapped


def _unwrap_network(network):
    """Return ``network.module`` when ``network`` is DDP-wrapped, else ``network`` itself."""
    return network.module if isinstance(network, DistributedDataParallel) else network


def _feature_keys(model_type: str, collect_gradients: bool) -> tuple:
    """Names of the per-batch tensors collected for a mode, in the order they are gathered."""
    if model_type == "reward":
        return ("scores", "last_hidden_states")
    if collect_gradients:
        return ("gradients_512", "gradients_1024", "gradient_norms", "gradient_sums")
    return ("second_to_last_hidden_states", "mean_hidden_states", "last_hidden_states")


def _stack_by_id(id_to_row: dict, outputs: list) -> torch.Tensor:
    """Stack rows on CPU in the order of ``outputs``, matching each output by its ``'id'``."""
    return torch.stack([id_to_row[output['id']].cpu() for output in outputs], dim=0)


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg):
    """Label generation files with reward scores or extract policy features.

    For every file returned by ``load_generation_files``:

    * ``evaluation.model_type == "reward"``: runs ``model.get_reward`` on every output,
      writes the outputs (with a ``cfg.reward.name`` score field, without ``prompt``/``id``)
      to a JSON file and the returned last hidden states to a ``.tar`` file.
    * ``evaluation.model_type == "policy"``: saves the hidden states returned by
      ``model.inference`` (selected by the ``evaluation.collect_*_hidden_states`` flags), or,
      with ``evaluation.collect_gradients``, the lm_head gradient of ``output.loss``
      multiplied by the matrices from ``construct_sparse_matrix`` (512 / 1024), plus its
      norm and its sum over dim 0.

    Batches are distributed with ``accelerate``. Results are gathered across processes and
    only the main process writes files.

    Raises:
        ValueError: for an unknown ``evaluation.model_type``, or when gradient collection is
            requested with a reward model or with ``evaluation.batch_size != 1``.
    """
    # Explicit import: np was otherwise only reachable through the wildcard helpers.io import.
    import numpy as np

    set_seeds(cfg.seed)
    task = cfg.task.name
    reward = cfg.reward.name
    model_type = cfg.evaluation.model_type
    collect_gradients = cfg.evaluation.collect_gradients

    # Reject configurations the loop cannot handle before any model is loaded.
    if model_type not in ("reward", "policy"):
        raise ValueError(f"Unknown evaluation.model_type {model_type!r}; expected 'reward' or 'policy'")
    if collect_gradients:
        if model_type != "policy":
            raise ValueError("evaluation.collect_gradients is only supported with evaluation.model_type='policy'")
        # One gradient (of output.loss) is taken per batch. It can only be attributed to a
        # single example id when each batch holds exactly one example.
        if cfg.evaluation.batch_size != 1:
            raise ValueError("evaluation.collect_gradients requires evaluation.batch_size == 1")

    if cfg.amlt:
        cfg.io.save_root = os.environ.get('AMLT_OUTPUT_DIR')
        cfg.io.load_root = cfg.blob_root
        log.info(f"AMLT_OUTPUT_DIR: {cfg.io.save_root}")

    # Get load path
    policy = cfg.policy.name
    load_path = os.path.join(cfg.io.load_root, 'data', task, policy)
    log.info(f"Loading generations from root @ {load_path}")

    generation_files = load_generation_files(cfg, load_path)

    log.info(f'Found the following {len(generation_files)} generation files:')
    for gf in generation_files:
        log.info(f"\t{gf}")

    # Collator and model are constructed for the same (policy or reward) model.
    if model_type == "policy":
        collator = construct_collator(cfg.policy.name, cfg)
        log.info(f'Loading policy model from {cfg.policy.model}')
        model = construct_policy_model(cfg.policy.name, cfg)
    else:
        collator = construct_collator(cfg.reward.name, cfg)
        log.info(f'Loading reward model from {cfg.reward.model}')
        model = construct_reward_model(cfg.reward.name, cfg)

    # Setup accelerator
    accelerator = Accelerator()
    log.info(f"Accelerate main process: {accelerator.is_main_process}")

    # Only the underlying network is passed through accelerate.prepare.
    model.network = accelerator.prepare(model.network)

    if collect_gradients:
        lm_head = _unwrap_network(model.network).lm_head
        n_params = np.prod(lm_head.weight.shape)
        sparse_matrix_512 = construct_sparse_matrix(np.zeros((1, n_params)), 512).to(accelerator.device)
        sparse_matrix_1024 = construct_sparse_matrix(np.zeros((1, n_params)), 1024).to(accelerator.device)

    feature_keys = _feature_keys(model_type, collect_gradients)
    # Feature used to verify that gathered rows stay aligned with their ids.
    check_key = "gradients_512" if collect_gradients else "last_hidden_states"

    # Iterate through generation files
    for generation_file in tqdm(generation_files, desc="Labeling rewards ..."):
        log.info(f'Obtaining reward labels for {generation_file}')

        # Load generation file and add problem prompts back in
        parsed_outputs = check_and_add_prompts(cfg, json_load(generation_file))

        if cfg.debug:
            parsed_outputs = parsed_outputs[:10]

        if not parsed_outputs:
            # Every process reads the same file, so all processes skip it together.
            log.warning(f'No outputs found in {generation_file}; skipping.')
            continue

        dataloader = accelerator.prepare(DataLoader(
            parsed_outputs,
            batch_size=cfg.evaluation.batch_size,
            collate_fn=collator,
            shuffle=False,
        ))

        # Per-process buffers: one tensor per batch for each feature, plus the example ids.
        buffers = {key: [] for key in feature_keys}
        id_batches = []

        for idx, (inputs, ids) in tqdm(enumerate(dataloader), desc="Forwarding through RM ...", total=len(dataloader)):
            if idx % 5 == 0:
                log.info(f"batch #: {idx} / {len(dataloader)}")

            if model_type == "reward":
                scores, last_hidden_states = model.get_reward(inputs, return_hidden_states=True)
                buffers["scores"].append(torch.tensor(scores, device=accelerator.device))
                buffers["last_hidden_states"].append(last_hidden_states)
            elif collect_gradients:
                output = model.forward(inputs)
                grad = torch.autograd.grad(
                    outputs=output.loss,
                    inputs=lm_head.weight,
                    retain_graph=False,
                )[0]
                buffers["gradient_sums"].append(torch.sum(grad.float(), dim=0, keepdim=True))
                grad = grad.view(1, -1).float()
                buffers["gradients_512"].append(grad @ sparse_matrix_512)
                buffers["gradients_1024"].append(grad @ sparse_matrix_1024)
                buffers["gradient_norms"].append(torch.norm(grad, dim=-1, keepdim=True))
            else:
                last_hidden_states, second_to_last_hidden_states, mean_hidden_states = model.inference(inputs, return_hidden_states=True)
                buffers["second_to_last_hidden_states"].append(second_to_last_hidden_states)
                buffers["mean_hidden_states"].append(mean_hidden_states)
                buffers["last_hidden_states"].append(last_hidden_states)

            id_batches.append(torch.tensor(ids, device=accelerator.device))

        # Concatenate this process's batches and release the per-batch tensors.
        local = {key: torch.cat(batches, dim=0) for key, batches in buffers.items()}
        local_ids = torch.cat(id_batches, dim=0)
        del buffers, id_batches
        local_check = {i.item(): row for i, row in zip(local_ids, local[check_key])}

        # Gather across gpus.
        # NOTE: we rely on accelerator.gather() always collecting things in the same order.
        # Every process issues these collectives in the same (feature_keys, then ids) order.
        gathered = {key: accelerator.gather(tensor) for key, tensor in local.items()}
        all_ids = accelerator.gather(local_ids)
        del local

        log.info(f"Unique ids gathered: {len(set(all_ids.tolist()))}")

        # Check that gathering is in the same order, and that lengths are equal
        global_check = {i.item(): row for i, row in zip(all_ids, gathered[check_key])}
        for example_id, row in local_check.items():
            assert torch.allclose(global_check[example_id], row)
        assert len(all_ids) == len(gathered[check_key]), f"Lengths are not equal: {len(all_ids)} != {len(gathered[check_key])}"
        log.info("All asserts passed!")

        # id -> row, for every gathered feature. Padding duplicates map to the same id.
        id_to = {
            key: {i.item(): row for i, row in zip(all_ids, tensor)}
            for key, tensor in gathered.items()
        }

        # Build outputs (reward mode only) and the features to save, in parsed_outputs order.
        if model_type == "reward":
            id_to_score = {i.item(): score.item() for i, score in zip(all_ids, gathered["scores"])}
            updated_outputs = []
            for output in parsed_outputs:
                scored = {**output, reward: id_to_score[output['id']]}
                updated_outputs.append({k: v for k, v in scored.items() if k != 'prompt' and k != 'id'})
            features = {"last_hidden_state": _stack_by_id(id_to["last_hidden_states"], parsed_outputs)}
        elif collect_gradients:
            features = {key: _stack_by_id(id_to[key], parsed_outputs) for key in feature_keys}
        else:
            features = {}
            if cfg.evaluation.collect_last_hidden_states:
                features["last_hidden_state"] = _stack_by_id(id_to["last_hidden_states"], parsed_outputs)
            if cfg.evaluation.collect_second_to_last_hidden_states:
                features["second_to_last_hidden_state"] = _stack_by_id(id_to["second_to_last_hidden_states"], parsed_outputs)
            if cfg.evaluation.collect_mean_hidden_states:
                features["mean_hidden_state"] = _stack_by_id(id_to["mean_hidden_states"], parsed_outputs)

        # Save relevant files only in main process
        if accelerator.is_main_process:
            if cfg.debug:
                debug_file = generation_file.replace('.json', f'--{cfg.reward.name}_debug.json')
                # Only reward mode produces updated outputs; policy mode saves features only.
                if model_type == "reward":
                    json_dump(updated_outputs, debug_file)
                    log.info(f'Successfully saved updated outputs to {debug_file}.')
                features_file = debug_file.replace('.json', '--features.tar')
            elif model_type == "reward":
                reward_file = generation_file.replace('.json', f'--{cfg.reward.name}.json')
                # Derive the features name from the un-remapped reward path so the AMLT
                # re-rooting below is applied to each path exactly once.
                features_file = reward_file.replace('.json', f'--{cfg.reward.name}-features.tar')
                if cfg.amlt:
                    reward_file = _remap_to_save_root(cfg, reward_file)
                    features_file = _remap_to_save_root(cfg, features_file)

                log.info(f'Saving updated outputs to {reward_file} ...')
                json_dump(updated_outputs, reward_file)
                log.info(f'Successfully saved updated outputs to {reward_file}.')
            else:
                if collect_gradients:
                    features_file = generation_file.replace('.json', f'--{cfg.policy.name}-gradients.tar')
                else:
                    model_tag = cfg.policy.name
                    if cfg.checkpoint_dir:
                        model_tag += "-" + cfg.checkpoint_dir.split("checkpoints/")[1].replace("/", "-")
                    features_file = generation_file.replace('.json', f'--{model_tag}-features.tar')

                if cfg.amlt:
                    features_file = _remap_to_save_root(cfg, features_file)

            torch.save(features, features_file)
            log.info(f'Successfully saved features to {features_file}.')

        log.info(f"Waiting for everyone to finish ...")
        accelerator.wait_for_everyone()

        log.info(f"Cleaning up ...")
        del dataloader, parsed_outputs, local_ids, local_check, gathered, all_ids, global_check, id_to, features
        gc.collect()
        torch.cuda.empty_cache()
        accelerator.free_memory()

        log.info(f"Done cleaning!")
        accelerator.wait_for_everyone()
        
    
if __name__ == "__main__":
    # Hydra's decorator on main() supplies cfg from the config files and CLI overrides.
    main()
