from typing import Optional, List, Tuple, Dict, Any, Union
from gym import spaces
from gym.spaces import Box, Discrete
from xuance.environment import RawEnvironment
from xuance.environment import RawMultiAgentEnv

import os
import time
import math
import torch
import random
import copy
from itertools import product
from dataclasses import dataclass
from distance import hamming
from collections import Counter
from RNA import fold
import numpy as np
import pandas as pd
import torch.nn.functional as F
from xuance.torch.utils.operations import set_seed


@dataclass
class RnaDesignEnvironmentConfig:
    """
    Dataclass holding (part of) the configuration of the RNA design environment.

    Within the code visible in this file only ``state_radius``, ``use_conv`` and
    ``use_embedding`` are read; ``mutation_threshold``, ``reward_exponent`` and
    ``diversity_loss`` are not referenced there.

    NOTE(review): ``_Design`` and ``MultiAgentRNAEnv`` also read attributes that are
    not declared here (e.g. ``state_radius_action``, ``seed``, ``state_dim``,
    ``test``, ``tolerance``, ``max_episode_steps``), so the object actually passed
    to them is presumably a different or extended config -- confirm against callers.
    """

    mutation_threshold: int = 5
    reward_exponent: float = 1.0
    # Number of padding sites added on each side of the encoded target structure;
    # the per-agent target window in MultiAgentRNAEnv is 2 * state_radius + 1 wide.
    state_radius: int = 5
    # When True, encodings are column vectors of shape (length, 1) instead of 1-D.
    use_conv: bool = True
    # Only switches the reported state type in MultiAgentRNAEnv.states ("int" vs "float").
    use_embedding: bool = False
    diversity_loss: bool = False


def _string_difference_indices(s1, s2):
    """
    Return, in ascending order, every position at which ``s1`` and ``s2`` differ.

    Positions are taken from ``s1``; ``s2`` must be at least as long as ``s1``
    (a shorter ``s2`` raises IndexError).
    """
    return [pos for pos, item in enumerate(s1) if item != s2[pos]]


def _encode_dot_bracket(secondary, env_config):
    """
    Binary-encode a dot-bracket target structure, padded on both sides.

    Unpaired sites ('.') and padding map to 0, paired sites ('(' or ')') map to 1.
    ``env_config.state_radius`` padding sites are added on each side.

    Returns:
        A 1-D integer array, or an array of shape (length, 1) when
        ``env_config.use_conv`` is set.
    """
    lookup = {".": 0, "(": 1, ")": 1, "=": 0}
    pad = "=" * env_config.state_radius
    codes = [lookup[symbol] for symbol in pad + secondary + pad]

    if not env_config.use_conv:
        return np.array(codes)
    # Column layout: one single-channel row per site.
    return np.array([[code] for code in codes])


def _encode_nucleotide(design, padding_length, env_config):
    """
    Integer-encode a nucleotide sequence (A=1, U=2, G=3, C=4) with zero padding.

    Args:
        design: iterable of nucleotide characters.
        padding_length: number of zero codes added on each side.
        env_config: only ``use_conv`` is read; when set the result is a column
            vector of shape (length, 1).
    """
    codes = {"A": 1, "U": 2, "G": 3, "C": 4}
    core = np.array([codes[base] for base in design])
    padded = np.pad(core, padding_length, 'constant', constant_values=0)
    return padded.reshape(-1, 1) if env_config.use_conv else padded


def _encode_pairing(secondary):
    """
    Map every site of a dot-bracket structure to the index of its partner.

    Unpaired sites (and '(' without a matching ')') map to None. A ')' without a
    matching '(' raises IndexError (pop from an empty stack).
    """
    partners = [None] * len(secondary)
    open_sites = []

    for pos, symbol in enumerate(secondary):
        if symbol == "(":
            open_sites.append(pos)
        elif symbol == ")":
            mate = open_sites.pop()
            partners[mate], partners[pos] = pos, mate

    return partners


class _Target(object):
    """
    A target secondary structure plus the encodings the environment needs.

    Attributes:
        id: the caller-supplied ``target_id`` or, when none is given, the value of
            the class-wide construction counter.
        dot_bracket: the dot-bracket structure string.
        padded_encoding: ``_encode_dot_bracket`` output for this structure.
    """

    # Incremented on every construction, even when an explicit id is supplied.
    _id_counter = 0

    def __init__(self, dot_bracket, env_config, target_id=None):
        """
        Initialize a target structure.

        Args:
            dot_bracket: dot-bracket encoded target structure.
            env_config: environment configuration (passed to ``_encode_dot_bracket``).
            target_id: optional explicit id; defaults to the construction counter.
        """
        _Target._id_counter += 1
        self.id = target_id if target_id is not None else _Target._id_counter
        self.dot_bracket = dot_bracket
        self._pairing_encoding = _encode_pairing(dot_bracket)
        self.padded_encoding = _encode_dot_bracket(dot_bracket, env_config)

    def __len__(self):
        """Number of sites in the structure."""
        return len(self.dot_bracket)

    def get_paired_site(self, site):
        """Return the index paired with ``site``, or None when it is unpaired."""
        partner = self._pairing_encoding[site]
        return partner


class _Design(object):
    """
    Mutable candidate RNA sequence together with its padded integer encoding.

    ``_primary_list`` holds one nucleotide character per site and
    ``_design_padded_encoding`` mirrors it as integer codes (A=1, U=2, G=3, C=4)
    with ``env_config.state_radius_action`` zero codes on each side. Every method
    that changes the sequence keeps the two in sync.
    """

    # Base written by an action at an unpaired site.
    action_to_base = {0: "G", 1: "A", 2: "U", 3: "C"}
    # (site, partner) bases written by an action at a paired site.
    action_to_pair = {0: "GC", 1: "CG", 2: "AU", 3: "UA"}
    site_encoding = {"A": 1, "U": 2, "G": 3, "C": 4}

    def __init__(self, length=None, primary=None, env_config=None):
        """
        Create a design from ``primary`` or, if that is empty/None, a random one.

        Args:
            length: number of sites; only used when ``primary`` is falsy.
            primary: initial sequence (iterable of nucleotide characters). It is
                copied, so later edits of this design never touch the caller's list.
            env_config: must provide ``state_radius_action`` and ``use_conv``.
        """
        if primary:
            self._primary_list = list(primary)
        else:
            self._primary_list = self.random_initialize(length)
        self._dot_bracket = None
        self.env_config = env_config

        self.padding_length = env_config.state_radius_action

        self._design_padded_encoding = self._encode()

    def _encode(self):
        """Return the padded integer encoding of the current ``_primary_list``."""
        return _encode_nucleotide(self._primary_list, self.padding_length, self.env_config)

    def random_initialize(self, length):
        """Return ``length`` copies of one nucleotide drawn uniformly via ``np.random``."""
        nucleotide = np.random.choice(['A', 'U', 'G', 'C'])
        return [nucleotide] * length

    def random_mutate(self, mutation_rate=0.3):
        """
        Replace ``int(len * mutation_rate)`` distinct random sites with a different base.

        Uses the stdlib ``random`` module. The padded encoding is rebuilt afterwards
        so ``primary_encoding`` matches the mutated sequence.
        """
        nucleotides = ['A', 'U', 'G', 'C']
        length = len(self._primary_list)
        num_mutations = int(length * mutation_rate)

        for idx in random.sample(range(length), num_mutations):
            original_nucleotide = self._primary_list[idx]
            new_nucleotide = original_nucleotide
            # Redraw until the base actually changes.
            while new_nucleotide == original_nucleotide:
                new_nucleotide = random.choice(nucleotides)
            self._primary_list[idx] = new_nucleotide

        self._design_padded_encoding = self._encode()

    def get_mutated(self, mutations, sites):
        """Return a new design equal to this one with ``mutations[i]`` written at ``sites[i]``."""
        mutatedprimary = self._primary_list.copy()
        for site, mutation in zip(sites, mutations):
            mutatedprimary[site] = mutation
        return _Design(primary=mutatedprimary, env_config=self.env_config)

    def assign_sites(self, action, site, paired_site=None):
        """
        Write the base(s) chosen by ``action`` at ``site`` and, if given, ``paired_site``.

        With a partner the pair from ``action_to_pair`` is written (first base at
        ``site``, second at ``paired_site``); otherwise the single base from
        ``action_to_base``. ``paired_site`` is compared against None explicitly
        because index 0 is a valid partner.
        """
        offset = self.padding_length
        if paired_site is not None:
            base_current, base_paired = self.action_to_pair[action]
            self._primary_list[site] = base_current
            self._primary_list[paired_site] = base_paired

            self._design_padded_encoding[site + offset] = self.site_encoding[base_current]
            self._design_padded_encoding[paired_site + offset] = self.site_encoding[base_paired]
        else:
            base_current = self.action_to_base[action]
            self._primary_list[site] = base_current

            self._design_padded_encoding[site + offset] = self.site_encoding[base_current]

    def first_unassigned_site(self, agent_id, target, last_site):
        """
        Return the next site an agent should work on, cycling through its site kind.

        Agents 0 and 1 scan forward over sites whose target code equals 0 and 1
        respectively, taking the first matching index after ``last_site`` (wrapping
        to the first match). Agents 2 and 3 scan backward over codes 0 and 1,
        taking the last matching index before ``last_site`` (wrapping to the last
        match). Returns None when no site has the agent's code, or for any other
        ``agent_id``.

        Args:
            agent_id: agent index.
            target: unpadded binary target encoding (0 = unpaired, 1 = paired).
            last_site: site used previously by this agent, or None to start fresh.
        """
        if agent_id in (0, 1):
            indices = np.where(target == agent_id)[0]
            if indices.size == 0:
                return None
            if last_site is None:
                last_site = -1
            later = indices[indices > last_site]
            return later[0] if later.size > 0 else indices[0]

        if agent_id in (2, 3):
            indices = np.where(target == agent_id % 2)[0]
            if indices.size == 0:
                return None
            if last_site is None:
                last_site = len(target)
            earlier = indices[indices < last_site]
            return earlier[-1] if earlier.size > 0 else indices[-1]

        return None

    def change_current_design(self, primary_list):
        """
        Replace the sequence with a copy of ``primary_list`` and rebuild its encoding.

        The encoding is rebuilt so that ``primary_encoding`` (used for agent
        observations) always describes ``primary``.
        """
        self._primary_list = list(primary_list)
        self._design_padded_encoding = self._encode()

    @property
    def primary(self):
        """The sequence as a single string."""
        return "".join(self._primary_list)

    @property
    def primary_list(self):
        """The (mutable) underlying list of nucleotide characters."""
        return self._primary_list

    @property
    def primary_encoding(self):
        """The padded integer encoding of the sequence."""
        return self._design_padded_encoding


def _random_epoch_gen(data):
    """
    Yield the items of ``data`` forever, in a fresh ``np.random`` permutation per epoch.

    Raises:
        ValueError: on the first ``next()`` if ``data`` is empty. Without this
            check the generator would loop forever without ever yielding.
    """
    if len(data) == 0:
        raise ValueError("cannot draw epochs from an empty collection of targets")
    while True:
        for i in np.random.permutation(len(data)):
            yield data[i]


def _sequential_epoch_gen(data):
    """
    Yield the items of ``data`` in order, cycling forever.

    Raises:
        ValueError: on the first ``next()`` if ``data`` is empty. Without this
            check the generator would loop forever without ever yielding.
    """
    if len(data) == 0:
        raise ValueError("cannot cycle over an empty collection of targets")
    while True:
        for i in range(len(data)):
            yield data[i]


@dataclass
class EpisodeInfo:
    """
    Record describing the outcome of one design episode.

    Uses ``__slots__`` so instances carry no per-instance ``__dict__``. It is not
    constructed anywhere in the visible part of this file.
    NOTE(review): field meanings below are inferred from names and from how
    MultiAgentRNAEnv computes similar values -- confirm against the producer.
    """
    __slots__ = ["target_id",
                 "time",
                 "normalized_hamming_distance",
                 "hamming_distance",
                 "structure",
                 "sequence",
                 ]
    target_id: int
    # presumably elapsed wall-clock time of the episode; units unconfirmed
    time: float
    # presumably hamming_distance / len(target), as in MultiAgentRNAEnv
    normalized_hamming_distance: float
    hamming_distance: int
    # presumably the target (or folded) dot-bracket structure -- confirm
    structure: str
    # designed nucleotide sequence
    sequence: str


def read_rna_files(directory):
    """
    Load every ``*.rna`` file in ``directory`` as an ``(id, dot_bracket)`` pair.

    The id is the file name up to its first '.'; the structure is the whole file
    content with surrounding whitespace stripped. Files are visited in sorted
    name order so the resulting target order is deterministic (``os.listdir``
    order is arbitrary).

    Returns:
        List of ``(file_id, structure)`` tuples; empty if no ``.rna`` files exist.
    """
    dot_brackets = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".rna"):
            continue
        file_id = filename.split('.')[0]
        file_path = os.path.join(directory, filename)
        with open(file_path, 'r') as file:
            structure = file.read().strip()
        dot_brackets.append((file_id, structure))
    return dot_brackets


def read_task_description(path):
    """
    Parse a comma-separated task file into ``(id, dot_bracket)`` pairs.

    The first line is treated as a header and skipped. Blank lines and lines that
    do not split into exactly two comma-separated fields are ignored; whitespace
    around each field is stripped.

    Returns:
        List of ``(task_id, structure)`` string tuples. Empty when the file has no
        valid rows (previously this case raised KeyError from pandas).
    """
    dot_brackets = []
    with open(path, 'r') as file:
        next(file, None)  # header line
        for line in file:
            stripped = line.strip()
            if not stripped:
                continue
            parts = stripped.split(',')
            if len(parts) == 2:
                task_id, structure = (part.strip() for part in parts)
                dot_brackets.append((task_id, structure))
    return dot_brackets


def split_sequence(sequence_length, n_agents):
    """
    Split ``sequence_length`` sites into ``n_agents`` near-equal chunk lengths.

    The first ``sequence_length % n_agents`` chunks get one extra site.
    """
    base, extra = divmod(sequence_length, n_agents)
    return [base + (1 if i < extra else 0) for i in range(n_agents)]


class MultiAgentRNAEnv(RawMultiAgentEnv):
    """
    Cooperative four-agent environment for RNA sequence design (inverse folding).

    Each episode pairs a target dot-bracket structure (``_Target``) with a candidate
    sequence (``_Design``). ``agent_0``/``agent_2`` visit unpaired target sites and
    ``agent_1``/``agent_3`` visit paired sites (0/1 scan forward, 2/3 backward; see
    ``_Design.first_unassigned_site``). On every step each agent that has a site
    writes a base, or a GC/CG/AU/UA pair at paired sites, there. The shared reward is
    the decrease of the normalized Hamming distance between ``RNA.fold`` of the
    design and the target, plus a bonus when the distance reaches
    ``env_config.tolerance``.
    """

    # Local search is only attempted when the design is already this close ...
    _LOCAL_SEARCH_MAX_DISTANCE = 0.2
    # ... and only during the first this-many environment steps in total.
    _LOCAL_SEARCH_MAX_TOTAL_STEPS = 200000
    # Added to the shared reward when the normalized distance reaches tolerance.
    _TERMINAL_BONUS = 3

    def __init__(self, env_config):
        """
        Build the spaces, load the target structures and reset episode bookkeeping.

        Args:
            env_config: configuration object; the code reads ``seed``,
                ``state_radius``, ``state_radius_action``, ``state_dim``,
                ``input_file``, ``dot_brackets_dir``/``dot_brackets``, ``test``,
                ``env_id``, ``max_episode_steps``, ``tolerance``, ``use_conv`` and
                ``use_embedding``.

        Raises:
            ValueError: if no target structures are available.
        """
        # NOTE(review): ``super(RawMultiAgentEnv, self)`` starts the MRO lookup after
        # RawMultiAgentEnv, so RawMultiAgentEnv.__init__ itself is skipped -- confirm
        # that this is intended.
        super(RawMultiAgentEnv, self).__init__()
        self._env_config = env_config
        set_seed(self._env_config.seed)
        # Observation = target window (2*state_radius+1) followed by the design
        # window (2*state_radius_action+1); see _get_state.
        self._state_dim = (self._env_config.state_radius * 2 + 1) + (self._env_config.state_radius_action * 2 + 1)
        self.agents = ['agent_0', 'agent_1', 'agent_2', 'agent_3']
        agent_observation = Box(low=-np.inf, high=np.inf, shape=(self._state_dim, 1,), dtype=np.float32)
        self.observation_space = {k: agent_observation for k in self.agents}
        self.state_space = Box(low=-np.inf, high=np.inf, shape=(self._env_config.state_dim, 1,), dtype=np.float32)
        agent_action_space = Discrete(4)
        self.action_space = {k: agent_action_space for k in self.agents}
        self.num_agents: Optional[int] = 4
        self.teams_info = {
            "names": ['agent'],
            "num_teams": 1,
            "agents_in_team": [['agent_0', 'agent_1', 'agent_2', 'agent_3']]
        }
        if self._env_config.input_file:
            dot_brackets = read_task_description(self._env_config.dot_brackets_dir)
        else:
            dot_brackets = self._env_config.dot_brackets
        if len(dot_brackets) == 0:
            raise ValueError("no target structures were provided to MultiAgentRNAEnv")

        # Entries are either bare dot-bracket strings or (id, dot_bracket) pairs.
        if isinstance(dot_brackets[0], str):
            targets = [_Target(dot_bracket, self._env_config) for dot_bracket in dot_brackets]
        else:
            targets = [_Target(dot_bracket, self._env_config, target_id=i) for i, dot_bracket in dot_brackets]

        # Test mode walks the targets in order; training reshuffles every epoch.
        if self._env_config.test:
            self._target_gen = _sequential_epoch_gen(targets)
        else:
            self._target_gen = _random_epoch_gen(targets)

        self.target = None
        self.target_nonpadded = None
        self.design = None
        self.episodes_info = []
        self.predictions = []
        self.split_lengths = []
        self.env_id = env_config.env_id
        self.max_episode_steps = env_config.max_episode_steps
        self._current_step = 0
        self.total_step = 0
        self.best_design = None
        self.mismatch = None
        self.last_normalized_distance = None
        self.best_normalized_distance = None
        self.last_site = [None] * self.num_agents
        self.last_target = None
        self.last_primary_list = None
        self.last_completed = False
        self.redesign_num = 0

    def reset(self, **kwargs):
        """
        Start a new episode and return ``(observation_dict, {})``.

        In test mode, a target whose previous attempt did not complete is attempted
        again, starting from the best sequence saved at truncation; otherwise the
        next target is drawn with a fresh design. ``kwargs`` are accepted for API
        compatibility and ignored.
        """
        if self._env_config.test and not self.last_completed and self.last_target is not None:
            self.target = copy.deepcopy(self.last_target)
            self.design = _Design(len(self.target), primary=self.last_primary_list, env_config=self._env_config)
            self.redesign_num += 1
        else:
            self.target = next(self._target_gen)
            self.design = _Design(len(self.target), env_config=self._env_config)
            self.redesign_num = 0
        radius = self._env_config.state_radius
        # Strip the padding by explicit length: ``[r:-r]`` would be empty for r == 0.
        self.target_nonpadded = copy.deepcopy(self.target.padded_encoding[radius:radius + len(self.target)])
        self.split_lengths = split_sequence(len(self.target), self.num_agents)
        self._current_step = 0
        self.best_design = copy.deepcopy(self.design.primary_list)
        hamming_distance, _, _, mismatch = self.my_hamming(self.design.primary, self.target.dot_bracket)
        self.mismatch = mismatch
        self.last_normalized_distance = hamming_distance / len(self.target)
        self.best_normalized_distance = self.last_normalized_distance
        self.last_site = [None] * self.num_agents
        # _get_state assigns each agent its first site, so call it in agent order.
        observation_dict = {key: self._get_state(agent_id) for agent_id, key in enumerate(self.agents)}

        return observation_dict, {}

    def step(self, action_dict):
        """
        Apply one joint action.

        Returns:
            ``(observations, rewards, terminated, truncated, info)``. Every agent
            receives the same reward. An agent is terminated when it has no site of
            its kind in the target or when the episode reached ``tolerance``.
        """
        actions_list = [action_dict[key] for key in self.agents]
        agent_terminated = [False] * self.num_agents

        for agent_id, current_site in enumerate(self.last_site):
            if current_site is None:
                agent_terminated[agent_id] = True
            else:
                self._apply_action(actions_list[agent_id], current_site)

        self._current_step += 1
        self.total_step += 1
        truncated = self._current_step >= self.max_episode_steps
        reward, rna_info, normalized_distance, diff, self.mismatch = self._get_reward(truncated, actions_list)

        terminal = (normalized_distance <= self._env_config.tolerance)
        if terminal:
            reward += self._TERMINAL_BONUS
            if self._env_config.test:
                self.last_target = self.target
                self.last_completed = True
                self.redesign_num = 0
        reward_dict = {k: reward for k in self.agents}

        observation_dict = {}
        terminated_dict = {}
        for agent_id, key in enumerate(self.agents):
            terminated_dict[key] = agent_terminated[agent_id] or terminal
            observation_dict[key] = self._get_state(agent_id)
        return observation_dict, reward_dict, terminated_dict, truncated, rna_info

    def _get_reward(self, truncated, actions):
        """
        Score the current design and update the per-episode bookkeeping.

        Near-solutions (0 < distance <= _LOCAL_SEARCH_MAX_DISTANCE, early in
        training) are first refined with ``_local_improvement``. In test mode the
        best design so far is tracked and, at truncation, restored and reported.

        Returns:
            ``(reward, rna_info, normalized_distance, diff, mismatch)`` where
            ``reward`` is the decrease of normalized Hamming distance since the
            previous step.
        """
        primary = self.design.primary
        hamming_distance, diff, folded_design, mismatch = self.my_hamming(primary, self.target.dot_bracket)

        normalized_distance = hamming_distance / len(self.target)
        new_actions = None

        if (0 < normalized_distance <= self._LOCAL_SEARCH_MAX_DISTANCE
                and self.total_step < self._LOCAL_SEARCH_MAX_TOTAL_STEPS):
            candidates, new_actions = self._local_improvement(actions, hamming_distance, diff)
            if candidates is not None:
                self.design.change_current_design(candidates[1].primary_list)
                primary = candidates[1].primary
                hamming_distance, diff, folded_design, mismatch = self.my_hamming(primary, self.target.dot_bracket)
                normalized_distance = hamming_distance / len(self.target)
        reward = self.last_normalized_distance - normalized_distance
        self.last_normalized_distance = normalized_distance

        if self._env_config.test:
            if normalized_distance < self.best_normalized_distance:
                self.best_normalized_distance = normalized_distance
                self.best_design = copy.deepcopy(self.design.primary_list)

            if truncated:
                # Report (and keep) the best design seen in this episode, not the last one.
                self.design.change_current_design(self.best_design)
                normalized_distance = self.best_normalized_distance
                primary = self.design.primary
                hamming_distance, _, folded_design, _ = self.my_hamming(primary, self.target.dot_bracket)

                self.last_target = self.target
                if self.redesign_num < 1:
                    # Allow one more attempt at this target, seeded with the best design.
                    self.last_primary_list = copy.deepcopy(self.design.primary_list)
                    self.last_completed = False
                else:
                    self.last_primary_list = None
                    self.last_completed = True

        if hamming_distance == 0:
            self.last_completed = True

        rna_info = {"target": self.target.dot_bracket,
                    "rna_sequence": primary,
                    "folded_design": folded_design,
                    "hamming": hamming_distance,
                    "rate": 1 - normalized_distance,
                    "new_actions": new_actions,
                    "last_completed": self.last_completed,
                    "num_redesigned": self.redesign_num}

        return reward, rna_info, normalized_distance, diff, mismatch

    def _local_improvement(self, actions, current_hamming, diff):
        """
        Brute-force better actions for the agents whose last site is mis-folded.

        For every agent whose last site has ``diff[site] == 1``, all 4**k
        replacement action combinations are tried (k = number of such agents),
        skipping the combination that was actually played. The design with the
        strictly smallest Hamming distance is kept; the search stops early on a
        perfect fold.

        Returns:
            ``(candidates, new_actions)``. ``candidates`` is
            ``(hamming, design, folded)`` for the best improving design, or None.
            ``new_actions`` is a copy of ``actions`` with the replacements written in,
            or None. ``actions`` itself is also updated in place on improvement.
        """
        candidates = None
        new_actions = None
        mut_agent_ids = []
        mut_sites = []
        old_actions = []
        for agent_id, site in enumerate(self.last_site):
            # Agents without a site did not act this step (see step); ``diff[None]``
            # would also not select a single element.
            if site is None:
                continue
            if diff[site] == 1:
                mut_agent_ids.append(agent_id)
                mut_sites.append(site)
                old_actions.append(actions[agent_id])

        search_num = len(mut_agent_ids)
        if search_num == 0:
            return candidates, new_actions

        for mutated_acts in product([0, 1, 2, 3], repeat=search_num):
            if search_num == 1 and mutated_acts[0] == old_actions[0]:
                continue
            if search_num > 1 and (mutated_acts == np.array(old_actions)).all():
                continue
            mutated_design = _Design(primary=copy.deepcopy(self.design.primary_list), env_config=self._env_config)

            for act, current_site in zip(mutated_acts, mut_sites):
                paired_site = self.target.get_paired_site(current_site)
                mutated_design.assign_sites(act, current_site, paired_site)

            folded_mutated, _ = fold(mutated_design.primary)
            hamming_distance = hamming(folded_mutated, self.target.dot_bracket)
            if hamming_distance < current_hamming:
                current_hamming = hamming_distance
                for agent_id, act in zip(mut_agent_ids, mutated_acts):
                    actions[agent_id] = act
                new_actions = copy.deepcopy(actions)
                # mutated_design is freshly built each iteration, so no copy is needed.
                candidates = (hamming_distance, mutated_design, folded_mutated)

            if hamming_distance == 0:
                break

        return candidates, new_actions

    def _apply_action(self, action, current_site):
        """Write ``action`` at ``current_site`` (and at its partner when it is paired)."""
        paired_site = self.target.get_paired_site(current_site)
        self.design.assign_sites(action, current_site, paired_site)

    def _get_state(self, agent_id):
        """
        Advance ``agent_id`` to its next site and return that agent's observation.

        The observation is the padded target window of 2*state_radius+1 codes
        centred on the site, followed by the padded design window of
        2*state_radius_action+1 codes. It is all zeros when the agent has no site.
        """
        start = self.design.first_unassigned_site(agent_id, self.target_nonpadded, self.last_site[agent_id])
        self.last_site[agent_id] = start
        if start is None:
            if self._env_config.use_conv:
                return np.zeros(self._state_dim).reshape(-1, 1)
            return np.zeros(self._state_dim)
        padded_target = self.target.padded_encoding[start: start + 2 * self._env_config.state_radius + 1]
        padded_action = self.design.primary_encoding[start: start + 2 * self._env_config.state_radius_action + 1]
        return np.concatenate((padded_target, padded_action))

    def state(self):
        """
        Return the global state.

        The state is the unpadded target encoding followed by the unpadded design
        encoding, each zero-padded up to ``int(state_dim / 2)`` entries.
        """
        radius = self._env_config.state_radius
        radius_action = self._env_config.state_radius_action
        # Strip the padding by explicit length: ``[r:-r]`` would be empty for r == 0.
        nonpadded_target = self.target.padded_encoding[radius:radius + len(self.target)]
        nonpadded_action = self.design.primary_encoding[
                           radius_action:radius_action + len(self.design.primary_list)]

        padding = np.zeros(int(self._env_config.state_dim / 2) - len(nonpadded_target))
        if self._env_config.use_conv:
            padding = padding.reshape(-1, 1)
        state_target = np.concatenate((nonpadded_target, padding))
        state_action = np.concatenate((nonpadded_action, padding))
        return np.concatenate((state_target, state_action))

    def close(self):
        """Nothing to release."""
        pass

    @property
    def states(self):
        """Per-agent observation spec: value type and shape."""
        value_type = "int" if self._env_config.use_embedding else "float"
        if self._env_config.use_conv:
            return dict(type=value_type, shape=(self._state_dim, 1))
        return dict(type=value_type, shape=(self._state_dim,))

    def my_hamming(self, design, target):
        """
        Fold ``design`` with ``RNA.fold`` and compare the fold to ``target``.

        Returns:
            ``(hamming_distance, diff, folded_design, mismatch)``; see generate_diff.
        """
        folded_design, _ = fold(design)
        diff, hamming_distance, mismatch = self.generate_diff(folded_design, target)
        return hamming_distance, diff, folded_design, mismatch

    def generate_diff(self, target_encoding, design_encoding):
        """
        Compare two structure strings position by position (over the shorter length).

        Note: my_hamming passes the folded design first and the target second. The
        comparison is symmetric, so the parameter names do not matter.

        Returns:
            ``(diff, hamming_distance, mismatch)``. ``diff`` is a float32 0/1 mask of
            differing positions, ``hamming_distance`` is its sum (a NumPy float32
            scalar), and ``mismatch`` is an array of the differing indices.
        """
        diff = np.zeros(len(target_encoding), dtype=np.float32)
        mismatch = []
        for i, (t, d) in enumerate(zip(target_encoding, design_encoding)):
            if t != d:
                diff[i] = 1
                mismatch.append(i)
        hamming_distance = np.sum(diff)
        mismatch = np.array(mismatch)
        return diff, hamming_distance, mismatch

    @property
    def actions(self):
        """Action spec: four discrete actions (one per base / base pair)."""
        return dict(type="int", num_actions=4)

    def render(self, *args, **kwargs):
        """Placeholder rendering: a constant all-ones (64, 64, 64) array."""
        return np.ones([64, 64, 64])

    def __str__(self):
        return "RnaDesignEnvironment"

    def seed(self, seed):
        """No-op; seeding happens once in __init__ via ``set_seed(env_config.seed)``."""
        return None
