import teneva
import torch
import torch.optim as optim
from time import perf_counter as tpc
import numpy as np
import csv
import os

class Protes:
    """PyTorch implementation of the PROTES optimizer with incremental CSV logging.

    The search distribution over the grid ``{0, ..., n-1}^d`` is stored as a
    list of ``d`` tensor-train (TT) cores of shape ``(r_left, n, r_right)``
    whose outer ranks are 1. Each iteration of :meth:`optimize` samples ``k``
    multi-indices from that distribution, evaluates them with ``f`` and takes
    ``k_gd`` Adam steps that raise the log-likelihood of the ``k_top`` best
    samples.

    Every evaluated point is appended to ``<results_prefix>_eval_history.csv``
    and the best-so-far value after each iteration to
    ``<results_prefix>_best_history.csv``, both inside ``results_dir``.
    """

    def __init__(self, f, d, n, is_max=True, r=5, device=None, results_dir='.', results_prefix='protes'):
        """
        Store the problem definition and prepare the log file locations.

        Args:
            f (callable): Black-box objective. Called with an integer numpy
                array of shape (k, d) holding k multi-indices; must return k
                numeric values (anything ``np.asarray`` accepts).
            d (int): Number of dimensions.
            n (int): Grid size per dimension (used for the random initial
                cores).
            is_max (bool): Maximize ``f`` when True, minimize it otherwise.
            r (int): TT-rank of the randomly initialized cores.
            device (torch.device | str | None): Device for the cores; defaults
                to CUDA when available, otherwise CPU.
            results_dir (str): Directory for the incremental CSV logs; it is
                created if it does not exist.
            results_prefix (str): File-name prefix for the two CSV logs.
        """
        self.f = f
        self.d = d
        self.n = n
        self.is_max = is_max
        self.r = r
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        print(f"Protes (PyTorch) is using device: {self.device}")

        # Log file locations; the files themselves are (re)created by optimize().
        self.results_dir = results_dir
        os.makedirs(self.results_dir, exist_ok=True)
        self.eval_history_filepath = os.path.join(self.results_dir, f'{results_prefix}_eval_history.csv')
        self.best_history_filepath = os.path.join(self.results_dir, f'{results_prefix}_best_history.csv')

        self.P = None          # list of TT cores, set by optimize()
        self.optimizer = None  # Adam optimizer over self.P, set by optimize()
        self.info = {}         # run statistics: 'm', 'y_opt', 'i_opt', 't'

    # ------------------------------------------------------------------ #
    # TT distribution: initialization, sampling and likelihood
    # ------------------------------------------------------------------ #
    def _generate_initial(self):
        """Return ``d`` random cores with entries in [0, 1) that require grad.

        Core ``i`` has shape ``(ranks[i], n, ranks[i + 1])`` with
        ``ranks = [1, r, ..., r, 1]``; for ``d == 1`` this is a single
        ``(1, n, 1)`` core.
        """
        ranks = [1] + [self.r] * (self.d - 1) + [1]
        return [
            torch.rand(ranks[i], self.n, ranks[i + 1], device=self.device, requires_grad=True)
            for i in range(self.d)
        ]

    def _interface_matrices(self, P_cores):
        """Return the right interface matrices of the TT cores.

        Entry ``i`` is the product of the mode-summed cores ``i+1 .. d-1``
        (renormalized after each step, only the direction matters), a column
        of shape ``(r_right_i, 1)``. The last entry is the 1x1 identity.
        """
        Z = torch.eye(P_cores[-1].shape[-1], device=self.device)
        Zr = [Z]
        for i in range(self.d - 1, 0, -1):
            Z = torch.sum(P_cores[i], dim=1) @ Z
            Z = Z / (torch.linalg.norm(Z) + 1e-9)
            Zr.append(Z)
        return Zr[::-1]

    @staticmethod
    def _conditional_probs(Q, core, Z):
        """Return the normalized distribution over the current mode index.

        ``Q`` (1, r_left) is the normalized left context, ``core`` the current
        core (r_left, n_i, r_right) and ``Z`` (r_right, 1) the right interface
        matrix. The weights ``|Q @ core[:, j, :] @ Z|`` are normalized to sum
        to one; if that is impossible (non-finite values or all weights zero)
        a uniform distribution over the ``n_i`` values is returned instead.
        """
        Q_core = torch.einsum('pr,rnq->pnq', Q, core)
        # reshape(-1) rather than squeeze(): keeps a 1-D vector even if n_i == 1.
        probs = torch.einsum('pnq,qr->pnr', Q_core, Z).abs().reshape(-1)
        probs = probs / (probs.sum() + 1e-9)
        if not torch.isfinite(probs).all() or probs.sum() <= 0:
            probs = torch.ones_like(probs) / probs.numel()
        return probs

    def _sample_batch(self, P_cores, k):
        """Draw ``k`` multi-indices, one dimension at a time.

        Returns a long tensor of shape ``(k, d)`` on ``self.device``.
        No gradients are tracked.
        """
        with torch.no_grad():
            Z_interfaces = self._interface_matrices(P_cores)
            samples = []
            for _ in range(k):
                sample_indices = []
                Q = torch.eye(1, device=self.device)
                for i in range(self.d):
                    core = P_cores[i]
                    probs = self._conditional_probs(Q, core, Z_interfaces[i])
                    idx = int(torch.multinomial(probs, 1))
                    sample_indices.append(idx)
                    if i < self.d - 1:
                        # Condition the left context on the sampled index.
                        Q = Q @ core[:, idx, :]
                        Q = Q / (torch.linalg.norm(Q) + 1e-9)
                samples.append(sample_indices)
        return torch.tensor(samples, device=self.device, dtype=torch.long)

    def _log_likelihood_batch(self, P_cores, I_batch):
        """Return the log-probability of each row of ``I_batch``.

        Uses the same conditional factorization as :meth:`_sample_batch` and
        is differentiable with respect to ``P_cores``. Result shape: (batch,).
        """
        if len(I_batch) == 0:
            return torch.zeros(0, device=self.device)
        Z_right = self._interface_matrices(P_cores)
        log_probs = []
        for i_sample in I_batch:
            log_prob_sample = 0.
            Q = torch.eye(1, device=self.device)
            for i in range(self.d):
                core = P_cores[i]
                idx = i_sample[i]
                probs = self._conditional_probs(Q, core, Z_right[i])
                log_prob_sample = log_prob_sample + torch.log(probs[idx] + 1e-9)
                if i < self.d - 1:
                    Q = Q @ core[:, idx, :]
                    Q = Q / (torch.linalg.norm(Q) + 1e-9)
            log_probs.append(log_prob_sample)
        return torch.stack(log_probs)

    # ------------------------------------------------------------------ #
    # Optimization loop helpers
    # ------------------------------------------------------------------ #
    def _init_log_files(self):
        """(Re)create both CSV logs containing only their header rows."""
        with open(self.eval_history_filepath, 'w', newline='', encoding='utf-8') as fh:
            header = ['iteration'] + [f'x_{i}' for i in range(self.d)] + ['score']
            csv.writer(fh).writerow(header)
        with open(self.best_history_filepath, 'w', newline='', encoding='utf-8') as fh:
            csv.writer(fh).writerow(['iteration', 'eval_count', 'best_score'])

    def _append_eval_history(self, iteration, I_np, y_np):
        """Append one ``[iteration, x_0..x_{d-1}, score]`` row per evaluated point."""
        rows = [[iteration] + idx.tolist() + [y.item()] for idx, y in zip(I_np, y_np)]
        with open(self.eval_history_filepath, 'a', newline='', encoding='utf-8') as fh:
            csv.writer(fh).writerows(rows)

    def _append_best_history(self, iteration):
        """Append ``[iteration, evaluations so far, best value so far]``."""
        with open(self.best_history_filepath, 'a', newline='', encoding='utf-8') as fh:
            csv.writer(fh).writerow([iteration, self.info['m'], self.info['y_opt']])

    def _initial_cores(self, P_init):
        """Return trainable cores: random ones, or float32 copies of ``P_init``."""
        if P_init is None:
            return self._generate_initial()
        if teneva is None:
            raise ImportError("`teneva` is required to process P_init.")
        if len(P_init) != self.d:
            raise ValueError(f"P_init has {len(P_init)} cores but d={self.d}")
        self.r = teneva.ranks(P_init)[len(P_init) // 2]
        # torch.tensor always copies, so Adam's in-place updates never write
        # back into the caller's arrays (torch.from_numpy would share memory).
        return [
            torch.tensor(core, dtype=torch.float32, device=self.device, requires_grad=True)
            for core in P_init
        ]

    def _evaluate(self, I_np):
        """Call ``f`` on the batch and return its values as a flat numpy array."""
        y_np = np.asarray(self.f(I_np)).reshape(-1)
        if y_np.shape[0] != I_np.shape[0]:
            raise ValueError(f"f returned {y_np.shape[0]} values for {I_np.shape[0]} points")
        return y_np

    def _rank_order(self, y_np):
        """Return indices of ``y_np`` from best to worst (float64, NaNs last, stable)."""
        y = np.asarray(y_np, dtype=np.float64)
        return np.argsort(-y if self.is_max else y, kind='stable')

    def _update_best(self, i_cur, y_cur):
        """Record ``(i_cur, y_cur)`` if it beats the incumbent; return True if it did."""
        y_opt = self.info['y_opt']
        if y_opt is not None and not (y_cur > y_opt if self.is_max else y_cur < y_opt):
            return False
        self.info['y_opt'] = float(y_cur)
        self.info['i_opt'] = np.array(i_cur)  # copy, detached from the batch array
        return True

    def optimize(self, m=10000, k=100, k_top=10, k_gd=1, lr=0.01, P_init=None, log=False):
        """Run PROTES and return the best multi-index found and its value.

        The budget is spent in ``m // k`` iterations of ``k`` evaluations each;
        any remainder ``m % k`` is not used. Both CSV logs are rewritten from
        scratch at the start of every call.

        Args:
            m (int): Total number of points to evaluate with ``f``.
            k (int): Points sampled and evaluated per iteration (>= 1).
            k_top (int): Number of best points per iteration whose likelihood
                is maximized (>= 1).
            k_gd (int): Adam steps per iteration.
            lr (float): Adam learning rate.
            P_init (list | None): Optional initial cores, one per dimension.
                They are copied and never modified. Requires ``teneva``.
            log (bool): Print a line whenever the best value improves.

        Returns:
            tuple: ``(i_opt, y_opt)``. ``i_opt`` is the best multi-index (an
            integer numpy array of shape (d,)). ``y_opt`` is the value ``f``
            returned for it, as a Python float. Returns ``(None, None)`` if
            ``m < k``.

        Raises:
            ValueError: if ``k`` or ``k_top`` is < 1, if ``P_init`` does not
                have ``d`` cores, or if ``f`` returns the wrong number of values.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        if k_top < 1:
            raise ValueError(f"k_top must be >= 1, got {k_top}")

        t_start = tpc()
        self.info = {'m': 0, 'y_opt': None, 'i_opt': None}
        self._init_log_files()

        self.P = self._initial_cores(P_init)
        self.optimizer = optim.Adam(self.P, lr=lr)

        for iteration in range(1, m // k + 1):  # iterations are numbered from 1
            I = self._sample_batch(self.P, k)
            I_np = I.cpu().numpy()
            y_np = self._evaluate(I_np)
            self.info['m'] += k
            self._append_eval_history(iteration, I_np, y_np)

            # Rank in float64 on the host so y_opt is exactly what f returned.
            order = self._rank_order(y_np)
            is_new = self._update_best(I_np[order[0]], y_np[order[0]])
            self._append_best_history(iteration)

            if log and is_new:
                print(f"Iter {iteration:3d} | m {self.info['m']:6d} | y_opt {self.info['y_opt']:.4f}")

            I_top = torch.as_tensor(I_np[order[:k_top]], device=self.device)
            for _ in range(k_gd):
                self.optimizer.zero_grad()
                loss = -self._log_likelihood_batch(self.P, I_top).mean()
                loss.backward()
                self.optimizer.step()

        self.info['t'] = tpc() - t_start
        print(f"\nOptimization finished in {self.info['t']:.2f} seconds.")

        return self.info['i_opt'], self.info['y_opt']