import os
import os.path as osp
import pandas as pd

from dataset_produce import SmilesRepeat

# Online SMILES checker/depicter, handy for validating generated SMILES: https://cdb.ics.uci.edu/cgibin/Smi2DepictWeb.py
class CopolymerProcessor:
    """Build copolymer SMILES strings from pairs of polymer repeat units.

    Repetition and joining of SMILES are delegated to a ``SmilesRepeat``
    helper (``dfs`` for repetition, ``edit_mol`` for joining). This class only
    decides which units are joined and in what order.
    """

    def __init__(self):
        # NOTE(review): the meaning of the positional ``1`` and of the empty
        # task_name/root is defined by SmilesRepeat - confirm in dataset_produce.
        self.smiles_repeat = SmilesRepeat(1, task_name='', root='')

    def repeat_polymer(self, polymer_smiles, repeat_times):
        """
        Repeat a polymer unit a specified number of times.

        Parameters:
            polymer_smiles: SMILES string of the polymer repeat unit
            repeat_times: Number of repetitions

        Returns:
            SMILES string of the repeated polymer. With ``repeat_times == 1``
            the input is returned unchanged. Any other count is delegated to
            ``SmilesRepeat.dfs``; its handling of 0 or negative counts is not
            visible here (TODO confirm).
        """
        if repeat_times == 1:
            return polymer_smiles
        return self.smiles_repeat.dfs(polymer_smiles, repeat_times)

    def combine_polymers(self, polymer_a, polymer_b):
        """
        Join two polymer SMILES into one.

        Parameters:
            polymer_a: SMILES string of the first polymer
            polymer_b: SMILES string of the second polymer

        Returns:
            SMILES string of the combined polymer, as produced by
            ``SmilesRepeat.edit_mol``.
        """
        return self.smiles_repeat.edit_mol(polymer_a, polymer_b)

    @staticmethod
    def _alternating_sequence(unit_a, unit_b, count_a, count_b):
        """Return the ordered list of units for an alternating copolymer.

        Units are taken round-robin (A, B, A, B, ...) while both still have
        repeats left. After that, the remaining units of the other are
        appended. A count <= 0 contributes no units.
        """
        sequence = []
        remaining_a, remaining_b = count_a, count_b
        while remaining_a > 0 or remaining_b > 0:
            if remaining_a > 0:
                sequence.append(unit_a)
                remaining_a -= 1
            if remaining_b > 0:
                sequence.append(unit_b)
                remaining_b -= 1
        return sequence

    def alternate_polymer(self, polymer_a, polymer_b, ratio_a, ratio_b):
        """
        Alternately concatenate polymer A and B until one runs out,
        then append the remaining units of the other.

        Parameters:
            polymer_a: SMILES string of polymer A
            polymer_b: SMILES string of polymer B
            ratio_a: number of repeat units of A
            ratio_b: number of repeat units of B

        Returns:
            SMILES string of the resulting alternating copolymer

        Raises:
            ValueError: if neither ratio yields a repeat unit (both <= 0),
                because there is nothing to build.
        """
        units = self._alternating_sequence(polymer_a, polymer_b, ratio_a, ratio_b)
        if not units:
            # Fail with a clear message rather than an IndexError on units[0].
            raise ValueError(
                f"Alternating copolymer needs at least one repeat unit "
                f"(ratio_a={ratio_a}, ratio_b={ratio_b})"
            )
        copolymer = units[0]
        for unit in units[1:]:
            copolymer = self.combine_polymers(copolymer, unit)
        return copolymer

    def process_copolymers(self, input_file):
        """
        Generate a copolymer SMILES for every row of a CSV file and write it back.

        Each row is read from the columns ``poly_type`` ('alternating' or
        'block'), ``smiles_a``, ``smiles_b``, ``ratio_a`` and ``ratio_b``.
        A ``SMILES`` column is (re)created with the results, and the file is
        overwritten in place. A row that fails is reported on stdout, keeps
        ``None`` in ``SMILES``, and processing continues with the next row.

        Parameters:
            input_file: Path to the input CSV file (overwritten with the results)
        """
        df = pd.read_csv(input_file)

        # Any pre-existing SMILES column is discarded and regenerated.
        df['SMILES'] = None

        for index, row in df.iterrows():
            try:
                poly_type = row['poly_type']
                smiles_a = row['smiles_a']
                smiles_b = row['smiles_b']
                ratio_a = int(row['ratio_a'])
                ratio_b = int(row['ratio_b'])
                if poly_type == 'alternating':
                    copolymer = self.alternate_polymer(smiles_a, smiles_b, ratio_a, ratio_b)
                elif poly_type == 'block':
                    # Block copolymer: repeat each unit on its own, then join the two blocks.
                    repeated_a = self.repeat_polymer(smiles_a, ratio_a)
                    repeated_b = self.repeat_polymer(smiles_b, ratio_b)
                    copolymer = self.combine_polymers(repeated_a, repeated_b)
                else:
                    raise ValueError(f"Unknown poly_type: {poly_type}")
                # Store the result
                df.at[index, 'SMILES'] = copolymer

                print(f"Generated copolymer: {copolymer}")

            except Exception as e:
                # Best-effort: report the failing row and move on to the next one.
                print(f"Error processing row {index+1}: {e}")

        # Save the results back to the original file
        df.to_csv(input_file, index=False)
        print(f"Copolymer generation completed. Results saved to {input_file}")

if __name__ == "__main__":
    # No command-line behaviour: running this file directly does nothing.
    # Use it from calling code, e.g. CopolymerProcessor().process_copolymers(csv_path).
    ...