import hydra
import torch
import logging
import random
import json
import jsonlines
import einops
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from omegaconf import DictConfig, OmegaConf
from transformers import LlamaForCausalLM
from peft import PeftModel
from vqlm.vqvae_muse import get_tokenizer_muse
from layout_parser import ActionParser
from collections import defaultdict, Counter
from accelerate.utils import is_peft_model
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

def seed_everything(seed = 42):
    """Seed the torch, ``random`` and NumPy generators with the same value.

    The CUDA generators are seeded as well whenever CUDA is available.
    """
    # Same order as before: torch first, then the stdlib, then NumPy.
    for apply_seed in (torch.manual_seed, random.seed, np.random.seed):
        apply_seed(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

class TokenizedDataset(Dataset):
    """In-memory dataset backed by a JSON-lines file.

    Every record of the file must provide the keys ``input_tokens``,
    ``output_tokens``, ``input_state`` and ``meta``.  The two token lists are
    converted to ``torch.long`` tensors; the other two values are stored as-is.
    """

    def __init__(self, filepath):
        inputs, targets, states, metas = [], [], [], []

        with jsonlines.open(filepath) as reader:
            for record in reader:
                inputs.append(torch.tensor(record['input_tokens'], dtype=torch.long))
                targets.append(torch.tensor(record['output_tokens'], dtype=torch.long))
                states.append(record['input_state'])
                metas.append(record['meta'])

        self.input_ids_list = inputs
        self.target_ids_list = targets
        self.input_state_list = states
        self.meta_list = metas

    def __len__(self):
        """Return the number of records that were loaded."""
        return len(self.input_ids_list)

    def __getitem__(self, idx):
        """Return the record at *idx* as a dict of its four fields."""
        sample = {}
        sample["input_ids"] = self.input_ids_list[idx]
        sample["target_ids"] = self.target_ids_list[idx]
        sample["input_state"] = self.input_state_list[idx]
        sample["meta"] = self.meta_list[idx]
        return sample

    
def collate_fn(batch):
    """Merge a list of dataset samples into a single batch dict.

    ``input_ids`` and ``target_ids`` are stacked into tensors with a new
    leading batch dimension (so all samples must share the same shape), while
    ``input_state`` and ``meta`` are kept as plain per-sample lists.
    """
    def column(key):
        # Collect one field across every sample of the batch.
        return [sample[key] for sample in batch]

    return {
        "input_ids": torch.stack(column('input_ids')),
        "target_ids": torch.stack(column('target_ids')),
        "input_state": column('input_state'),
        "meta": column('meta'),
    }


def _print_length_distribution(label, lengths):
    """Print min/max/avg and a per-length histogram for a non-empty list of lengths."""
    print(f"  - Action List Length Distribution (Complete={label}):")
    print(f"    Min: {min(lengths)}, Max: {max(lengths)}, Avg: {sum(lengths)/len(lengths):.2f}")
    for length, count in sorted(Counter(lengths).items()):
        percentage = (count / len(lengths)) * 100
        print(f"    Length {length}: {count} ({percentage:.2f}%)")


def show_stats(base_dir, initial_level=3, test_size=1):
    """Print per-level statistics for the results stored under *base_dir*.

    Only sub-directories with a purely numeric name are considered.  Each one
    is expected to hold a ``parsed_actions.json`` file containing a dict with
    a ``complete`` flag and, optionally, an ``action_list`` whose entries carry
    the validity marker at index 1.  Directories without such a file, or whose
    JSON is not a dict with a ``complete`` key, are ignored.

    Args:
        base_dir: Folder containing one numbered sub-directory per run.
        initial_level: Level assigned to the first group of runs.
        test_size: Number of consecutive runs per level; run ``i`` belongs to
            level ``i // test_size + initial_level``.

    Returns:
        The names of the sub-directories whose run has ``complete`` equal to
        ``False``, ordered by level and then by run index.
    """
    # Data structure for statistics
    level_stats = defaultdict(lambda: {
        "true": 0,
        "false": 0,
        "false_dirs": [],
        "true_action_lengths": [],
        'all_actions_valid_false': [],
        "any_actions_invalid_false": [],
        "false_action_lengths": [],
        "valid" : 0,
        "invalid" : 0
    })

    # Non-numeric entries were never processed, so drop them up front and
    # visit the remaining ones in numerical order.
    numbered_dirs = sorted((name for name in os.listdir(base_dir) if name.isdigit()), key=int)

    for subdir in numbered_dirs:
        subdir_path = os.path.join(base_dir, subdir)
        level = (int(subdir) // test_size) + initial_level

        if not os.path.isdir(subdir_path):
            continue
        json_file_path = os.path.join(subdir_path, "parsed_actions.json")
        if not os.path.exists(json_file_path):
            continue

        with open(json_file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if not (isinstance(data, dict) and "complete" in data):
            continue

        # A missing (or null) action list is treated as empty everywhere,
        # instead of only when its length is computed.
        actions = data.get("action_list") or []
        action_list_length = len(actions)

        if data["complete"] is True:
            level_stats[level]["true"] += 1
            level_stats[level]["true_action_lengths"].append(action_list_length)
        elif data["complete"] is False:
            level_stats[level]["false"] += 1
            level_stats[level]["false_dirs"].append(subdir)
            level_stats[level]["false_action_lengths"].append(action_list_length)
            if all(action[1] != "invalid" for action in actions):
                level_stats[level]["all_actions_valid_false"].append(subdir)
            else:
                level_stats[level]["any_actions_invalid_false"].append(subdir)

        for action in actions:
            bucket = 'valid' if action[1] != "invalid" else 'invalid'
            level_stats[level][bucket] += 1

    false_list = []
    acc = []

    # Print statistics per level
    for level in sorted(level_stats.keys()):
        stats = level_stats[level]
        true_count = stats["true"]
        false_count = stats["false"]
        total_count = true_count + false_count
        true_percentage = (true_count / total_count * 100) if total_count > 0 else 0
        acc.append(true_percentage)
        # Collect the failed runs so that the returned list is actually useful.
        false_list.extend(stats["false_dirs"])

        print(f"\nLevel {level}:")
        print(f"  - Number of True: {true_count}")
        print(f"  - Number of False: {false_count}")
        print(f"        - All actions valid but not optimal: {len(stats['all_actions_valid_false'])}")
        print(f"        - Any actions invalid: {len(stats['any_actions_invalid_false'])}")
        print(f"  - Accuracy (True Percentage): {true_percentage:.2f}%")
        print(f"  - Number of Valid: {stats['valid']}")
        print(f"  - Number of Invalid: {stats['invalid']}")

        # Print action list length distribution
        if stats["true_action_lengths"]:
            _print_length_distribution("True", stats["true_action_lengths"])
        if stats["false_action_lengths"]:
            _print_length_distribution("False", stats["false_action_lengths"])

    if acc:
        print(f"Average accuracy: {sum(acc)/len(acc):.2f}%")
    else:
        # Nothing was evaluated: avoid dividing by zero.
        print("Average accuracy: N/A (no evaluated environments found)")
    return false_list

@hydra.main(config_path="configs", config_name="eval", version_base=None)
def main(cfg: DictConfig):
    """Evaluate the model on every test environment, or only print statistics.

    When ``cfg.show_stats`` is set, the statistics of the existing results
    folder are printed and nothing else is done.  Otherwise, for each sample of
    ``cfg.test_dataset_pth`` the model generates ``expected_move`` additional
    256-token frames, all frames are decoded to images, each pair of
    consecutive images is parsed into an action, and an image strip plus a
    ``parsed_actions.json`` file are written to
    ``<cfg.evaluation_result_folder_pth>/<idx>/``.  Environments for which both
    output files already exist are not evaluated again.

    Config keys read: ``show_stats``, ``evaluation_result_folder_pth``,
    ``batch_size``, ``seed``, ``model_path``, ``lora_path`` and
    ``test_dataset_pth``.

    Raises:
        ValueError: If ``cfg.batch_size`` is not 1.
    """
    if cfg.show_stats:
        show_stats(cfg.evaluation_result_folder_pth)
        return

    # Checked before the expensive model load, and with an exception rather
    # than an assert so the check survives `python -O`: the loop below only
    # reads the first sample of every batch.
    if cfg.batch_size != 1:
        raise ValueError("for now, we only support batch size 1 for evaluation.")

    seed_everything(cfg.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = get_tokenizer_muse().to(device)

    action_parser = ActionParser(tokenizer)

    base_model = LlamaForCausalLM.from_pretrained(
        cfg.model_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=True
    )

    model = PeftModel.from_pretrained(base_model, cfg.lora_path, torch_dtype=torch.bfloat16) if cfg.lora_path else base_model
    if is_peft_model(model):
        print(f"PEFT model detected. Loading checkpoint from {cfg.lora_path}")
        model = model.merge_and_unload()
    model.to(device)

    dataset = TokenizedDataset(cfg.test_dataset_pth)
    dataloader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)

    correct_count = 0
    total_count = len(dataset)

    # Every row of 256 tokens is decoded into one image further down.
    tokens_per_frame = 256

    # check if evaluation_result_folder_pth exist, if not, create it
    os.makedirs(cfg.evaluation_result_folder_pth, exist_ok=True)

    for idx, batch in enumerate(tqdm(dataloader, desc="Evaluating")):
        result_folder = os.path.join(cfg.evaluation_result_folder_pth, str(idx))
        save_path = os.path.join(result_folder, f"evaluation_{idx}.png")
        json_path = os.path.join(result_folder, "parsed_actions.json")

        # Resume support: reuse the stored verdict when both outputs exist.
        if os.path.exists(save_path) and os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as f:
                previous = json.load(f)
            if previous["complete"]:
                correct_count += 1
            log.info(f"Skipping env {idx} as it has already been evaluated.")
            continue
        os.makedirs(result_folder, exist_ok=True)

        input_ids = batch["input_ids"].to(device)
        input_state = batch["input_state"][0][0]
        meta = batch["meta"][0]
        # The distance map is keyed by the stringified start state; its value
        # is used as the number of frames to generate.
        expected_move = meta["distance_map"][str(input_state)]

        log.info(f"\nEvaluating env {idx}\nStart pos: {input_state}, Expected move: {expected_move}")

        frames = [input_ids]
        with torch.no_grad():
            for _ in range(expected_move):
                # Only the most recent frame is used as context.
                context = input_ids[:, -tokens_per_frame:]
                output_ids = model.generate(
                    input_ids=context,
                    attention_mask=torch.ones_like(context),
                    pad_token_id=8192,
                    max_new_tokens=tokens_per_frame,
                    do_sample=False,
                    # Restrict generation to token ids below 8192.
                    suppress_tokens=list(range(8192, model.vocab_size)),
                )
                input_ids = output_ids[:, -tokens_per_frame:]
                frames.append(input_ids)
        new_tokens = torch.cat(frames, dim=1).view(-1, tokens_per_frame)

        # Decode in two halves and move each result to the CPU right away to
        # keep peak GPU memory down.  An empty half (single-frame case) is
        # skipped instead of being handed to the decoder.
        half = len(new_tokens) // 2
        decoded = []
        for chunk in (new_tokens[:half], new_tokens[half:]):
            if chunk.shape[0] == 0:
                continue
            images = tokenizer.decode_code(chunk)
            decoded.append(torch.clamp(images, 0.0, 1.0).detach().cpu())
            del images
            torch.cuda.empty_cache()

        # concatenate
        all_images = torch.cat(decoded, dim=0)
        new_images = einops.rearrange(all_images, 'b c h w -> b h w c').numpy()

        data = {
            "start_coords": [ActionParser.get_coordinate_from_state(input_state, meta['level'])],
            "action_list": [],
            "complete" : False,
            "level" : meta['level'],
            "start_pos" : input_state,
            "target_pos" : meta['target_pos'],
            "layout" : meta['layout'],
            "distance_map" : meta['distance_map']
        }

        # Parse every consecutive image pair, chaining the predicted
        # coordinate into the next step.
        for input_img, pred_img in zip(new_images[:-1], new_images[1:]):
            parsed = action_parser.parse_action_in_imgs(input_img, pred_img, meta['level'], data['start_coords'][-1], meta['target_pos'])
            data['start_coords'].append(parsed['pred_coord'])
            data['action_list'].append(parsed['action'])

        # Success: the target is reached after exactly `expected_move` steps
        # and none of those steps was parsed as invalid.
        target_coord = ActionParser.get_coordinate_from_state(meta['target_pos'], meta['level'])
        if target_coord == data['start_coords'][expected_move]:
            if all(action[1] != "invalid" for action in data['action_list'][:expected_move]):
                data['complete'] = True
                correct_count += 1
                log.info("Correctly found the optimal path!")

        # squeeze=False always yields a 2-D array of axes, so a single image
        # no longer breaks the iteration below.
        fig, axes = plt.subplots(1, len(new_images), figsize=(len(new_images) * 3, 3), squeeze=False)

        for ax, image in zip(axes[0], new_images):
            ax.imshow(image)
            ax.axis("off")

        plt.tight_layout()
        plt.savefig(save_path, bbox_inches="tight")

        plt.close(fig)

        log.info(f"Saved evaluation results to {save_path}")

        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)

        del input_ids, new_tokens, frames
        torch.cuda.empty_cache()

    if total_count:
        log.info(f"Evaluation Accuracy: {correct_count/total_count:.2%}")
    else:
        # An empty dataset would otherwise divide by zero.
        log.warning("Evaluation Accuracy: n/a (the test dataset is empty).")

# Script entry point; the `hydra.main` decorator on `main` supplies `cfg`.
if __name__ == "__main__":
    main()