from typing import Optional, List, Dict, Any, Union, Tuple
import os
import time
import numpy as np
import json
import random
import torch
import wandb
import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils import seeding

from PIL import Image
from utils.config import Config
from utils.gpt import GPTService
from experts import create_experts, ExpertInfo
from collections import Counter

class ImageEnv(gym.Env):
    """Gymnasium environment in which each action selects an image expert to run.

    An episode walks through one dataset row (a prompt plus an optional input
    image path). On every step the chosen expert is run on the current sub-task,
    the generated image is scored by ``GPTService.score_image_similarity`` and
    ``GPTService.extract_command`` derives the next sub-task from the feedback.
    Observations are the embedding of the current sub-task text.

    Usage statistics (counts, score sums, step times) are the container objects
    held by ``config`` (not copies), so they can be pre-populated to resume a
    run and updates made here are visible through ``config`` as well.
    """

    def __init__(
            self,
            config: Config,
    ):
        """Build the experts, action/observation spaces and load the JSONL dataset.

        Args:
            config: Run configuration. Besides the statistics containers, the
                fields read here are ``logger``, ``dataset_path``,
                ``text_embedding_dim``, ``max_steps``, ``idx_to_start``,
                ``global_step_counter``, ``eval_mode`` and ``random_seed``.

        Raises:
            ValueError: If the dataset file contains no rows.
        """
        self.config = config
        self.logger = self.config.logger

        self.gpt_service = GPTService()
        self.experts = create_experts()

        # Shared with `config` on purpose (see class docstring).
        self.model_usage_counts = self.config.model_usage_counts
        self.model_usage_type_counts = self.config.model_usage_type_counts
        self.model_usage_type_score_sums = self.config.model_usage_type_score_sums
        self.model_score_sums = self.config.model_score_sums
        self.step_times = self.config.step_times

        self.expert_descriptions = {}
        self.name_to_model = {}  # expert name -> "model_<idx>"
        self.model_to_name = {}  # "model_<idx>" -> expert name
        self.t2i_model_idx = []
        self.i2i_model_idx = []

        for idx, (expert_name, expert) in enumerate(self.experts.items()):
            model_id = f"model_{idx}"
            self.name_to_model[expert_name] = model_id
            self.model_to_name[model_id] = expert_name

            description = ExpertInfo(**expert.info())
            self.expert_descriptions[expert_name] = description
            if description.is_t2i:
                self.t2i_model_idx.append(idx)
            else:
                self.i2i_model_idx.append(idx)

            # Ensure every expert has a statistics entry. setdefault keeps values
            # resumed from a previous run, and also covers experts missing from
            # those stats (an all-or-nothing init left them absent -> KeyError).
            self.model_usage_counts.setdefault(expert_name, 0)
            self.model_score_sums.setdefault(expert_name, 0)
            if not description.is_t2i:
                self.model_usage_type_counts.setdefault(expert_name, Counter())
                self.model_usage_type_score_sums.setdefault(expert_name, Counter())

        self.logger.info(f"Experts / Model Numbers: {self.model_to_name}")

        # One JSON object per line; blank lines are skipped because
        # json.loads("") raises.
        self.dataset = []
        with open(self.config.dataset_path, "r", encoding="utf-8") as file:
            for line in file:
                if line.strip():
                    self.dataset.append(json.loads(line))
        if not self.dataset:
            raise ValueError(f"Dataset at {self.config.dataset_path} is empty")

        self.num_experts = len(self.experts)
        self.action_space = spaces.Discrete(self.num_experts)

        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.config.text_embedding_dim,),
            dtype=np.float32
        )

        self.max_steps = self.config.max_steps
        self.current_data_idx = self.config.idx_to_start
        self.global_step_counter = self.config.global_step_counter
        self.current_task = None
        self.task_type = None
        self.original_image_path = None
        self.original_prompt = None
        self.image_to_use = None
        self.current_state = None
        self.steps_taken = 0
        self.episode_count = 0

        if self.config.eval_mode:
            # original prompt -> path of the last image of its episode (filled in `step`)
            self.images_dict = {}

        self.seed(self.config.random_seed)

    def seed(self, seed: Optional[int] = None) -> List[int]:
        """Seed the env RNG and the global ``random``/NumPy/PyTorch RNGs.

        Also switches cuDNN to deterministic, non-benchmarking mode.

        Args:
            seed: Seed to use; when None, one is drawn by ``seeding.np_random``.

        Returns:
            A one-element list with the seed used for the env RNG.
        """
        self.np_random, used_seed = seeding.np_random(seed)
        # With seed=None, `used_seed` is 128-bit OS entropy, which np.random.seed
        # (limited to [0, 2**32)) and torch.manual_seed reject. Reduce it for the
        # global RNGs; explicit seeds below 2**32 are passed through unchanged.
        global_seed = used_seed % (2 ** 32)
        random.seed(global_seed)
        np.random.seed(global_seed)
        torch.manual_seed(global_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        return [used_seed]

    def _get_embedding(self, text: str) -> np.ndarray:
        """Return the GPTService embedding of ``text`` as a float32 array."""
        embedding = self.gpt_service.get_embedding(text)
        return np.array(embedding, dtype=np.float32)

    def _load_image(self, image_path: Optional[str]) -> Optional['Image.Image']:
        """Open ``image_path`` with PIL and decode it eagerly.

        Returns:
            The loaded image, or None (after logging a warning) when the path is
            empty or the file cannot be opened/decoded.
        """
        if not image_path:
            return None
        try:
            image = Image.open(image_path)
            try:
                # Image.open is lazy and keeps the file open. load() decodes now, so
                # corrupt files surface here instead of inside an expert. For
                # single-frame formats such as PNG, it also lets Pillow close the file.
                image.load()
            except Exception:
                image.close()
                raise
            return image
        except Exception as e:
            self.logger.warning(f"Error loading image from {image_path}: {e}")
            return None

    def _save_generated_image(self, image_input: 'Image.Image', expert_name: str) -> Tuple[str, str]:
        """Save an expert's output under ``<experiment_dir>/episode_<n>/step_<k>/``.

        Returns:
            ``(image_path, step_dir)``: the ``.png`` file written and its directory.
        """
        step_dir = os.path.join(
            self.config.experiment_dir,
            f"episode_{self.episode_count}",
            f"step_{self.steps_taken}",
        )
        os.makedirs(step_dir, exist_ok=True)  # also creates the episode directory

        image_path = os.path.join(step_dir, f"{self.steps_taken}_{expert_name}.png")
        image_input.save(image_path)
        return image_path, step_dir

    def _save_info(self, expert_name: str, reward: float, remaining_tasks: str, dir_path: str):
        """Write this step's metadata to ``<dir_path>/info.json``."""
        info = {
            "current_task": self.current_task,
            "remaining_tasks": remaining_tasks,
            "original_prompt": self.original_prompt,
            "expert_name": expert_name,
            "reward": reward,
            "task_type": self.task_type,
            "data_idx": self.current_data_idx,
            "steps_taken": self.steps_taken,
            "global_step_counter": self.global_step_counter,
        }

        with open(os.path.join(dir_path, "info.json"), "w") as f:
            json.dump(info, f, indent=4)

    def _save_model_stats(self):
        """Dump the usage statistics to ``<stats_dir>/<global_step>_model_stats.json``."""
        stats = {
            "global_step_counter": self.global_step_counter,
            "data_idx": self.current_data_idx,
            "model_usage_counts": self.model_usage_counts,
            "model_score_sums": self.model_score_sums,
            "model_usage_type_counts": {k: dict(v) for k, v in self.model_usage_type_counts.items()},
            "model_usage_type_score_sums": {k: dict(v) for k, v in self.model_usage_type_score_sums.items()},
            "step_times": self.step_times,
        }
        os.makedirs(self.config.stats_dir, exist_ok=True)
        stats_path = os.path.join(self.config.stats_dir, f"{self.global_step_counter}_model_stats.json")
        with open(stats_path, "w") as f:
            json.dump(stats, f, indent=4)

    def _log_task_distribution_data(self):
        """Log per-expert wandb bar charts of task-type average scores and counts."""
        for expert_name, task_counts in self.model_usage_type_counts.items():
            score_sums = self.model_usage_type_score_sums.get(expert_name, {})
            average_scores = [
                [task_type, score_sums.get(task_type, 0) / count if count > 0 else 0]
                for task_type, count in task_counts.items()
            ]
            table = wandb.Table(data=average_scores, columns=["Task Type", "Average Score"])
            wandb.log({
                f"task_distributions/{expert_name}_task_type_average_scores": wandb.plot.bar(
                    table,
                    "Task Type",
                    "Average Score",
                    title=f"Task Type Average Scores for {expert_name}"
                )
            })

            data_for_table = [[task_type, count] for task_type, count in task_counts.items()]
            table = wandb.Table(data=data_for_table, columns=["Task Type", "Count"])
            wandb.log({
                f"task_distributions/{expert_name}_task_distribution": wandb.plot.bar(
                    table,
                    "Task Type",
                    "Count",
                    title=f"Task Distribution for {expert_name}"
                )
            })

    def _log_model_usage_and_average_scores(self):
        """Log wandb bar charts of per-expert usage counts and average scores."""
        data_for_table = [[expert_name, count] for expert_name, count in self.model_usage_counts.items()]
        table = wandb.Table(data=data_for_table, columns=["Expert", "Count"])
        wandb.log({
            "model_metrics/model_usage_counts": wandb.plot.bar(
                table,
                "Expert",
                "Count",
                title="Model Usage Counts"
            )
        })

        average_scores = [
            [expert_name, self.model_score_sums.get(expert_name, 0) / count if count > 0 else 0]
            for expert_name, count in self.model_usage_counts.items()
        ]
        table = wandb.Table(data=average_scores, columns=["Expert", "Average Score"])
        wandb.log({
            "model_metrics/model_average_scores": wandb.plot.bar(
                table,
                "Expert",
                "Average Score",
                title="Model Average Scores"
            )
        })

    def reset(self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None):
        """Start an episode on the dataset row at ``current_data_idx``.

        Args:
            seed: If given, reseeds all RNGs via :meth:`seed`.
            options: Accepted for Gymnasium API compatibility (wrappers pass it);
                currently unused.

        Returns:
            ``(observation, info)``: the embedding of the prompt and an empty dict.
        """
        if seed is not None:
            self.seed(seed)

        self.current_data = self.dataset[self.current_data_idx]
        self.current_task = self.current_data["prompt"]
        self.original_prompt = self.current_data["prompt"]
        # A missing or empty "img_path" means the episode starts without an image.
        self.original_image_path = self.current_data.get("img_path") or None

        self.image_to_use = self.original_image_path
        self.steps_taken = 0

        self.remaining_tasks = ""
        self.task_type = None
        self.current_state = self._get_embedding(self.current_task)

        self.logger.info("#")
        self.logger.info("#")
        self.logger.info("*" * 100)
        self.logger.info(f"STARTING EPISODE {self.episode_count}")
        self.logger.info(f"Current data idx: {self.current_data_idx}")
        self.logger.info(f"Prompt: {self.original_prompt}")
        self.logger.info("*" * 100)
        self.logger.info("#")
        self.logger.info("#")

        return self.current_state, {}

    def step(self, action):
        """Run the expert selected by ``action`` on the current sub-task.

        Args:
            action: Element of ``self.action_space``; action ``i`` runs the
                expert registered as ``model_i``.

        Returns:
            ``(observation, reward, done, False, info)``. The observation is the
            embedding of the next sub-task. The reward is
            ``score / 10 - 0.05 * steps_taken``, with ``steps_taken`` counted
            before this step. ``info`` is an empty dict.

        Raises:
            ValueError: On an invalid action, a missing/unloadable input image,
                or missing required expert inputs.
            Exception: Whatever ``expert.run`` raises on its last retry.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not self.action_space.contains(action):
            raise ValueError(f"Invalid action {action!r} for action space {self.action_space}")

        # Captured now: the end-of-episode bookkeeping below bumps episode_count.
        episode_idx = self.episode_count

        self.logger.info("*" * 50 + " STEP " + str(self.steps_taken) + " " + "*" * 50)

        time_start = time.perf_counter()  # monotonic clock for durations
        model_id = f"model_{int(action)}"

        expert_name = self.model_to_name.get(model_id)
        if not expert_name:
            raise ValueError(f"Unknown model ID: {model_id}")

        expert = self.experts.get(expert_name)
        if not expert:
            raise ValueError(f"No expert found for name: {expert_name}")

        self.logger.info(f"Selected Expert: {expert_name}")
        self.logger.info(f"Current task: {self.current_task}")
        self.logger.info(f"Image to use: {self.image_to_use}")
        self.logger.info("*" * 50)

        payload = {"text_prompt": self.current_task}

        expert_info = self.expert_descriptions[expert_name]
        if "image" in expert_info.required_inputs:
            if not self.image_to_use:
                raise ValueError(f"{expert_name} requires an image but none is available")
            image = self._load_image(self.image_to_use)
            if not image:
                raise ValueError(f"Failed to load image from {self.image_to_use}")
            payload["image"] = image

        # Validate all required inputs are present
        missing_inputs = set(expert_info.required_inputs) - set(payload.keys())
        if missing_inputs:
            raise ValueError(f"Missing required inputs for {expert_name}: {missing_inputs}")

        max_retries = 5
        for attempt in range(max_retries):
            try:
                generated_image = expert.run(**payload)
                break
            except Exception as e:
                self.logger.warning(f"Error on expert.run with {expert_name} (attempt {attempt+1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    raise
                time.sleep(1.5 * (attempt + 1))  # linearly growing back-off

        generated_image_path, dir_path = self._save_generated_image(generated_image, expert_name)

        score, feedback = self.gpt_service.score_image_similarity(
            original_prompt=self.original_prompt,
            current_task=self.current_task,
            previous_feedback=self.remaining_tasks,
            generated_image_path=generated_image_path,
            original_image_path=self.image_to_use,
        )
        # NOTE(review): the /10 suggests a 0-10 score scale; confirm against GPTService.
        reward = score / 10.0 - 0.05 * self.steps_taken

        # task_type is None on the first step of an episode (see reset).
        if self.task_type:
            # setdefault/get: t2i experts get no entry in __init__, and resumed
            # stats may be plain dicts; both would otherwise raise KeyError.
            type_counts = self.model_usage_type_counts.setdefault(expert_name, Counter())
            type_counts[self.task_type] = type_counts.get(self.task_type, 0) + 1
            type_scores = self.model_usage_type_score_sums.setdefault(expert_name, Counter())
            type_scores[self.task_type] = type_scores.get(self.task_type, 0) + score

        next_task, remaining_tasks, next_task_type = self.gpt_service.extract_command(
            original_prompt=self.original_prompt,
            remaining_tasks=feedback,
        )
        self._save_info(expert_name, reward, remaining_tasks, dir_path)
        self.current_task, self.remaining_tasks, self.task_type = next_task, remaining_tasks, next_task_type

        self.image_to_use = generated_image_path
        self.current_state = self._get_embedding(self.current_task)

        is_complete = self.current_task == "" and self.remaining_tasks == ""
        # NOTE(review): steps_taken counts steps completed *before* this one, so the
        # limit fires on step index max_steps, i.e. an episode can run max_steps + 1
        # steps. The step limit is also reported via `done` while `truncated` is
        # always False. Both are kept as-is; confirm they are intended.
        done = is_complete or (self.steps_taken >= self.max_steps)

        if done:
            # Advanced before the stats dump below, so a saved data_idx points at
            # the next prompt to run.
            self.current_data_idx = (self.current_data_idx + 1) % len(self.dataset)
            self.episode_count += 1
            if self.config.eval_mode:
                self.images_dict[self.original_prompt] = self.image_to_use

        self.logger.info(f"Reward: {reward}")
        self.logger.info(f"Remaining tasks: {self.remaining_tasks}")
        self.logger.info(f"Next task: {self.current_task}")
        self.logger.info(f"Task type: {self.task_type}")
        self.logger.info("*" * 50)

        time_taken = time.perf_counter() - time_start
        self.step_times.append(time_taken)
        self.logger.info(f"Step {self.steps_taken} of episode {episode_idx} completed in {time_taken:.2f} seconds")
        self.logger.info("*" * 100)

        self.model_usage_counts[expert_name] += 1
        self.model_score_sums[expert_name] += score

        self.steps_taken += 1
        self.global_step_counter += 1

        if self.global_step_counter % self.config.log_interval == 0:
            self.logger.info(f"########################## SAVING MODEL STATS FOR GLOBAL STEP {self.global_step_counter} ##########################")
            self._save_model_stats()
            self._log_task_distribution_data()
            self._log_model_usage_and_average_scores()
            wandb.log({
                "model_metrics/step_times_average": np.mean(self.step_times)
            })

        return self.current_state, reward, done, False, {}
