from datetime import datetime, timedelta
from collections import defaultdict
import json
from multiprocessing import Process, Manager

class TradeMarket:
    """Replay high/low quotes for several stocks and drive trading agents.

    Quotes are loaded from one JSON file per stock. The simulation only
    visits timestamps that are present for every stock, and at each visited
    timestamp asks the agents whether they want to trade.

    Attributes:
        market_data: ``{stock: {date: {time: {"high": ..., "low": ...}}}}``.
        common_timestamps: Sorted ``datetime`` values present for all stocks.
        current_index: Index of the next entry of ``common_timestamps`` to
            examine.
        sim_time: Last timestamp returned by ``get_next_time`` (``None``
            before the first step).
        time_cursor: Initialised to the first shared timestamp, then kept
            equal to the last timestamp returned by ``get_next_time``.
        real_time: Wall-clock time at which the most recent round of agent
            decisions was started.
        trade_count: Per-timestamp counter, reset to 0 when a timestamp is
            visited by ``run_simulation``.
        trade_in_second: Per-timestamp record; ``run_simulation`` stores the
            list of agents that triggered at that timestamp.
    """

    def __init__(
        self,
        json_files,
        stock_names,
        step_size_sec: int = 1,
        history_points: int = 20,
        history_min_gap_sec: int = 60
    ):
        """Load the quote files and compute the timestamps shared by all stocks.

        Args:
            json_files: Paths of JSON files shaped as
                ``{"%Y-%m-%d": {"%H:%M:%S": {"high": ..., "low": ...}}}``.
            stock_names: Stock names, paired positionally with ``json_files``.
            step_size_sec: Minimum number of seconds between two simulated
                timestamps.
            history_points: Stored on the instance; not used by this class.
            history_min_gap_sec: Stored on the instance; not used by this
                class.

        Raises:
            ValueError: If no stock data was provided, or if the stocks share
                no common timestamp.
        """
        self.market_data = {}
        self.trade_in_second = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
        self.sim_time = None
        self.real_time = None
        self.trade_count = defaultdict(int)
        self.step_size = step_size_sec
        self.history_points = history_points
        self.history_min_gap_sec = history_min_gap_sec
        self.current_date = None
        self.time_cursor = None

        stock_times = []

        for file_path, stock_name in zip(json_files, stock_names):
            # JSON text is UTF-8 by specification; do not depend on the locale.
            with open(file_path, "r", encoding="utf-8") as f:
                raw_data = json.load(f)

            quotes_by_date = {}
            stock_time_set = set()
            for date, time_data in raw_data.items():
                # Keep only the fields the simulation reads.
                quotes_by_date[date] = {
                    timestamp: {"high": ohlc["high"], "low": ohlc["low"]}
                    for timestamp, ohlc in time_data.items()
                }
                stock_time_set.update(
                    datetime.strptime(f"{date} {timestamp}", "%Y-%m-%d %H:%M:%S")
                    for timestamp in time_data
                )

            self.market_data[stock_name] = quotes_by_date
            stock_times.append(stock_time_set)

        if not stock_times:
            # set.intersection() with no operands raises an opaque TypeError.
            raise ValueError("No stock data files were provided.")

        # Build the intersection of timestamps across all stocks.
        self.common_timestamps = sorted(set.intersection(*stock_times))
        self.current_index = 0

        if not self.common_timestamps:
            raise ValueError("No overlapping timestamps across stocks.")

        self.current_date = self.common_timestamps[0].strftime("%Y-%m-%d")
        self.time_cursor = self.common_timestamps[0]

    def _apply_linear_decay(self, high, low, rank, total_agents):
        """Shrink the high/low spread linearly towards its midpoint.

        ``rank`` 0 keeps the full spread; ``rank == total_agents - 1`` (or
        larger) collapses both prices onto the midpoint.

        Returns:
            A ``(new_high, new_low)`` tuple.
        """
        avg = (high + low) / 2
        if total_agents <= 1:
            alpha = 0.0
        else:
            alpha = min(rank / (total_agents - 1), 1.0)
        new_high = high - alpha * (high - avg)
        new_low = low + alpha * (avg - low)
        return new_high, new_low

    def get_current_price(self, dt: datetime, rank: int = 0, total_agents: int = 1):
        """Return the buy/sell prices of every stock at ``dt``.

        The "buy" price is the recorded low and the "sell" price the recorded
        high. When ``total_agents`` is greater than 1 both are first pulled
        towards the midpoint according to ``rank``.

        Args:
            dt: Timestamp to look up (date and second resolution).
            rank: Position of the requesting agent, 0 being the best.
            total_agents: Number of agents competing at this timestamp.

        Returns:
            ``{stock: {"buy": ..., "sell": ...}}``. Both prices are ``None``
            for a stock that has no quote at ``dt``.
        """
        date_str = dt.strftime("%Y-%m-%d")
        sec_str = dt.strftime("%H:%M:%S")
        price_dict = {}
        for stock, data in self.market_data.items():
            # A missing date is treated like a missing second: no quote.
            entry = data.get(date_str, {}).get(sec_str)
            if entry is None:
                price_dict[stock] = {"buy": None, "sell": None}
                continue
            high, low = entry["high"], entry["low"]
            if total_agents > 1:
                high, low = self._apply_linear_decay(high, low, rank, total_agents)
            price_dict[stock] = {"buy": low, "sell": high}
        return price_dict

    def get_next_time(self):
        """Advance the simulation clock and return the new timestamp.

        Shared timestamps that are less than ``step_size`` seconds after the
        previously returned one are skipped.

        Returns:
            The next ``datetime``, or ``None`` once the timestamps are
            exhausted.
        """
        while self.current_index < len(self.common_timestamps):
            next_time = self.common_timestamps[self.current_index]
            self.current_index += 1
            if self.sim_time is None or (next_time - self.sim_time).total_seconds() >= self.step_size:
                self.time_cursor = next_time
                self.sim_time = next_time
                return next_time
        return None  # simulation end

    @staticmethod
    def _run_decision(agent, index, prices_snapshot, decisions, timings):
        """Worker-process entry point: compute and time one agent's decision.

        This is a static method rather than a closure because process start
        methods that pickle the target ("spawn", "forkserver") cannot pickle
        a function defined inside another function.
        """
        import time

        start = time.perf_counter()
        decisions[index] = agent.decide_trades(prices_snapshot)
        # Written last, so a timing entry implies a decision entry exists.
        timings[index] = time.perf_counter() - start

    def _execute_decisions(self, manager, triggering_agents, prices_snapshot, dt):
        """Run every triggering agent in its own process, then apply trades.

        Trades are applied in ascending order of decision time.
        """
        decisions = manager.dict()
        timings = manager.dict()

        workers = []
        for i, agent in enumerate(triggering_agents):
            p = Process(
                target=self._run_decision,
                args=(agent, i, prices_snapshot, decisions, timings),
            )
            p.start()
            workers.append((i, p))

        for i, p in workers:
            p.join()
            if p.exitcode != 0:
                # The worker died before reporting; its agent has no trade.
                print(f"[⚠️] Decision process for agent #{i} exited with code {p.exitcode}; skipping.")

        # NOTE(review): the resulting rank is not forwarded to apply_trade, so
        # the rank-based decay of get_current_price is not applied here.
        for i, _elapsed in sorted(timings.items(), key=lambda item: item[1]):
            triggering_agents[i].apply_trade(decisions[i], self, dt)

    def run_simulation(self, agents: list):
        """Step through the shared timestamps and let the agents trade.

        At every step the agents whose ``should_call_llm`` returns true are
        run in parallel processes and their trades applied. The loop ends
        when the timestamps are exhausted or after the first visited
        timestamp whose time of day is later than 15:30:00.

        Args:
            agents: Objects exposing ``model``, ``should_call_llm(prices)``,
                ``decide_trades(prices)`` and ``apply_trade(decision, market,
                dt)``.
        """
        print("[⏳] Waiting for all agents to load...")
        for a in agents:
            # Attribute access only; presumably forces lazy model loading.
            _ = a.model
        print("[✅] All agents loaded. Starting simulation.")

        # Created on first use and reused: each Manager is a server process.
        manager = None
        try:
            while True:
                dt = self.get_next_time()
                if dt is None:
                    print("[🔚] Reached end of timestamp list.")
                    break

                self.trade_count[dt] = 0
                prices_snapshot = self.get_current_price(dt, total_agents=0)
                triggering_agents = [a for a in agents if a.should_call_llm(prices_snapshot)]
                self.trade_in_second[dt] = triggering_agents

                if triggering_agents:
                    self.real_time = datetime.now()
                    if manager is None:
                        manager = Manager()
                    self._execute_decisions(manager, triggering_agents, prices_snapshot, dt)

                # Checked on every step, including steps where nobody traded.
                if self.time_cursor.strftime("%H:%M:%S") > "15:30:00":
                    print("[🕒] Simulation cutoff reached at 15:30.")
                    break
        finally:
            if manager is not None:
                manager.shutdown()
