import numpy as np
import pandas as pd
import os
import cv2
import json
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
from PIL import Image
import math
from pathlib import Path
from tqdm import tqdm
from typing import Dict, Tuple, List

# Notebook-style display defaults for pandas tables and matplotlib figures.
pd.set_option("display.max_rows", 10)
plt.rcParams["font.size"] = 15

# NOTE(review): not referenced anywhere in the visible source; the paths
# below are relative to the current working directory instead.
BASE_DATA_PATH = '/content'

# Defect annotations and the submission template (the latter is not
# referenced elsewhere in the visible source).
train_df = pd.read_csv("train.csv")
sample_df = pd.read_csv("sample_submission.csv")
train_image_path = "train_images"
# File names of every training image, defective or not; the __main__ block
# uses this to find images that have no row in train_df.
total_training_images = os.listdir(train_image_path)

# Attach per-image intensity scores to each annotation by ImageId. A left
# merge keeps every train_df row, so images missing from the scores file end
# up with NaN in the merged-in columns (e.g. defect_intensity_score).
intensity_df = pd.read_csv("predicted_defect_intensity_scores.csv")
df = pd.merge(train_df, intensity_df, how='left', on='ImageId')

# NOTE(review): 200 random class-3 rows, unseeded (not reproducible) and not
# referenced elsewhere in the visible source; DataFrame.sample raises
# ValueError if fewer than 200 class-3 rows exist.
class3 = df[df['ClassId'] == 3].sample(n=200).reset_index(drop=True)

class SteelDefectParameterAugmentor:
    """Attach synthetic manufacturing-process parameters to defective samples.

    Every defect class (1-4) has its own (min, max) range per process
    parameter. A sample's ``defect_intensity_score`` linearly interpolates
    between the normal operating range (intensity 0) and the class range
    (intensity 1). Optional cross-parameter interactions and equipment drift
    are then applied on top.
    """

    def __init__(self):
        # (min, max) operating ranges for in-spec production.
        self.normal_ranges = {
            'surface_cleanliness': (95, 100),
            'ambient_humidity': (40, 50),
            'coating_spray_pressure': (2.5, 3.0),
            'coating_viscosity': (80, 100),
            'curing_temperature': (180, 200),
            'curing_time': (20, 25),
            'water_jet_pressure': (180, 200),
            'flow_rate': (100, 120),
            'vibration': (2, 4),
            'drive_load': (10, 15)
        }

        # Per-defect-class (min, max) ranges, keyed by ClassId.
        self.defect_ranges = {
            1: {
                'surface_cleanliness': (85, 95),
                'ambient_humidity': (40, 60),
                'coating_spray_pressure': (2.0, 2.8),
                'coating_viscosity': (85, 110),
                'curing_temperature': (140, 175),
                'curing_time': (10, 18),
                'water_jet_pressure': (160, 185),
                'flow_rate': (90, 110),
                'vibration': (3, 6),
                'drive_load': (12, 18)
            },
            2: {
                'surface_cleanliness': (70, 85),
                'ambient_humidity': (50, 75),
                'coating_spray_pressure': (2.0, 2.6),
                'coating_viscosity': (90, 120),
                'curing_temperature': (160, 190),
                'curing_time': (15, 22),
                'water_jet_pressure': (140, 170),
                'flow_rate': (75, 95),
                'vibration': (3, 5),
                'drive_load': (13, 17)
            },
            3: {
                'surface_cleanliness': (75, 90),
                'ambient_humidity': (55, 80),
                'coating_spray_pressure': (1.5, 2.2),
                'coating_viscosity': (110, 150),
                'curing_temperature': (140, 170),
                'curing_time': (10, 18),
                'water_jet_pressure': (150, 175),
                'flow_rate': (80, 100),
                'vibration': (4, 8),
                'drive_load': (16, 22)
            },
            4: {
                'surface_cleanliness': (70, 85),
                'ambient_humidity': (60, 85),
                'coating_spray_pressure': (1.8, 2.4),
                'coating_viscosity': (100, 130),
                'curing_temperature': (150, 175),
                'curing_time': (12, 20),
                'water_jet_pressure': (145, 165),
                'flow_rate': (80, 95),
                'vibration': (3, 6),
                'drive_load': (14, 19)
            }
        }

    def generate_parameter_value(self, param_name: str, defect_class: int,
                                 defect_intensity: float, use_correlation: bool = True) -> float:
        """Draw one value of ``param_name`` for a defect of ``defect_class``.

        Args:
            param_name: Key of ``self.normal_ranges``.
            defect_class: Key of ``self.defect_ranges`` (1-4).
            defect_intensity: Interpolation weight between the normal range
                (0.0) and the class defect range (1.0). Values outside
                [0, 1] extrapolate.
            use_correlation: When False, ignore the intensity and sample
                directly from the defect range.

        Returns:
            A uniform draw from the (interpolated) range plus Gaussian noise
            whose standard deviation is 2% of the range width.

        Raises:
            KeyError: Unknown ``param_name`` or ``defect_class``.
        """
        normal_min, normal_max = self.normal_ranges[param_name]
        defect_min, defect_max = self.defect_ranges[defect_class][param_name]

        if use_correlation:
            min_val = normal_min + defect_intensity * (defect_min - normal_min)
            max_val = normal_max + defect_intensity * (defect_max - normal_max)
        else:
            min_val, max_val = defect_min, defect_max

        base_value = np.random.uniform(min_val, max_val)
        # abs(): strong extrapolation can invert the bounds, and a negative
        # scale would make np.random.normal raise ValueError.
        noise_scale = abs(max_val - min_val) * 0.02
        noise = np.random.normal(0, noise_scale)

        return base_value + noise

    def apply_parameter_interactions(self, params: Dict[str, float],
                                     defect_class: int) -> Dict[str, float]:
        """Return a copy of ``params`` with simple cross-parameter couplings.

        Every condition is evaluated on the *input* values, so one adjustment
        never triggers another. ``defect_class`` is accepted for interface
        compatibility but is not used.
        """
        adjusted_params = params.copy()

        # Viscosity above 120: spray pressure +10%, capped at 3.0.
        if params['coating_viscosity'] > 120:
            adjusted_params['coating_spray_pressure'] = min(
                params['coating_spray_pressure'] * 1.1, 3.0
            )

        # Humidity above 70: curing time -10%, floored at 8.
        if params['ambient_humidity'] > 70:
            adjusted_params['curing_time'] = max(
                params['curing_time'] * 0.9, 8
            )

        # Drive load above 18: vibration +20%, capped at 10.
        if params['drive_load'] > 18:
            adjusted_params['vibration'] = min(
                params['vibration'] * 1.2, 10
            )

        # Water-jet pressure below 160: cleanliness -5%, floored at 60.
        if params['water_jet_pressure'] < 160:
            adjusted_params['surface_cleanliness'] = max(
                params['surface_cleanliness'] * 0.95, 60
            )

        return adjusted_params

    def add_temporal_drift(self, params: Dict[str, float]) -> Dict[str, float]:
        """Return a copy of ``params`` scaled by random equipment-drift factors.

        Water-jet pressure and surface cleanliness lose up to 5%, spray
        pressure up to 3%, and vibration gains up to 15%. The factors are
        drawn in that dict order on every call.
        """
        drifted_params = params.copy()

        drift_factors = {
            'water_jet_pressure': np.random.uniform(0.95, 1.0),
            'coating_spray_pressure': np.random.uniform(0.97, 1.0),
            'vibration': np.random.uniform(1.0, 1.15),
            'surface_cleanliness': np.random.uniform(0.95, 1.0),
        }

        for param, factor in drift_factors.items():
            if param in drifted_params:
                drifted_params[param] *= factor

        return drifted_params

    def augment_dataset(self, df: pd.DataFrame, add_interactions: bool = True,
                        add_temporal_drift: bool = True) -> pd.DataFrame:
        """Return a copy of ``df`` with one column per synthetic process parameter.

        Args:
            df: Rows with integer-like ``ClassId`` (1-4) and a numeric
                ``defect_intensity_score``. The index may be non-unique.
            add_interactions: Apply :meth:`apply_parameter_interactions`.
            add_temporal_drift: Apply :meth:`add_temporal_drift`.

        Returns:
            ``df.copy()`` with the keys of ``self.normal_ranges`` appended
            (or overwritten) as float columns, row-aligned by position.
        """
        augmented_df = df.copy()
        param_columns = list(self.normal_ranges.keys())
        # Fill a positional buffer instead of writing cells by index label:
        # `.at[label, col]` on a non-unique index writes every row sharing
        # that label, so duplicates would all end up with the last row's draw.
        generated = np.zeros((len(df), len(param_columns)), dtype=float)

        print(f"Augmenting {len(df)} defective samples...")

        for position, (_, row) in enumerate(df.iterrows()):
            defect_class = int(row['ClassId'])
            defect_intensity = row['defect_intensity_score']

            params = {}
            for param_name in param_columns:
                params[param_name] = self.generate_parameter_value(
                    param_name, defect_class, defect_intensity
                )

            if add_interactions:
                params = self.apply_parameter_interactions(params, defect_class)

            if add_temporal_drift:
                params = self.add_temporal_drift(params)

            generated[position] = [params[name] for name in param_columns]

            if position % 1000 == 0 and position > 0:
                print(f"Processed {position} samples...")

        for col_idx, col in enumerate(param_columns):
            augmented_df[col] = generated[:, col_idx]

        return augmented_df

    def validate_augmentation(self, df: pd.DataFrame) -> Dict[str, float]:
        """Print per-class summaries and parameter/intensity correlations.

        Returns:
            Pearson correlation of ``defect_intensity_score`` with each of
            four key parameters, keyed by parameter name.
        """
        print("\n=== AUGMENTATION VALIDATION ===")

        for defect_class in [1, 2, 3, 4]:
            class_data = df[df['ClassId'] == defect_class]
            if len(class_data) == 0:
                continue

            print(f"\nDefect Class {defect_class} (n={len(class_data)}):")

            for param in ['surface_cleanliness', 'coating_spray_pressure',
                          'ambient_humidity', 'vibration']:
                values = class_data[param]
                print(f"  {param}: {values.mean():.2f} ± {values.std():.2f} "
                      f"[{values.min():.2f}, {values.max():.2f}]")

        print("\n=== PARAMETER-INTENSITY CORRELATIONS ===")
        correlations = {}
        for param in ['surface_cleanliness', 'coating_spray_pressure',
                      'ambient_humidity', 'curing_temperature']:
            corr = df['defect_intensity_score'].corr(df[param])
            correlations[param] = corr
            print(f"{param}: {corr:.3f}")

        return correlations

    def plot_parameter_distributions(self, df: pd.DataFrame, save_path: str = None):
        """Plot per-class density histograms of six key parameters.

        When ``save_path`` is given, the figure is also saved there at 300 dpi
        before being shown.
        """
        _, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.ravel()

        key_params = ['surface_cleanliness', 'ambient_humidity', 'coating_spray_pressure',
                      'curing_temperature', 'vibration', 'drive_load']

        for i, param in enumerate(key_params):
            ax = axes[i]

            for defect_class in [1, 2, 3, 4]:
                class_data = df[df['ClassId'] == defect_class][param]
                if len(class_data) > 0:
                    ax.hist(class_data, alpha=0.6, bins=20,
                            label=f'Class {defect_class}', density=True)

            ax.set_xlabel(param.replace('_', ' ').title())
            ax.set_ylabel('Density')
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

def augment_steel_defect_dataset(df, output_file_path: str = None):
    """Run the full defective-sample augmentation pipeline on ``df``.

    Prints a summary of the input, generates process parameters with both
    interactions and temporal drift enabled, prints validation statistics,
    shows the per-class distribution plots, and optionally saves a CSV.

    Args:
        df: Defect table with ``ClassId`` and ``defect_intensity_score``.
        output_file_path: CSV destination; nothing is written when falsy.

    Returns:
        The augmented DataFrame.
    """
    print("Loading dataset...")
    print(f"Loaded dataset with {len(df)} samples")
    class_counts = df['ClassId'].value_counts()
    print(f"Defect class distribution:\n{class_counts}")

    engine = SteelDefectParameterAugmentor()
    result = engine.augment_dataset(df, add_interactions=True, add_temporal_drift=True)

    engine.validate_augmentation(result)
    engine.plot_parameter_distributions(result)

    if not output_file_path:
        return result

    result.to_csv(output_file_path, index=False)
    print(f"\nAugmented dataset saved to: {output_file_path}")
    return result

class NonDefectiveParameterGenerator:
    """Synthesize in-spec process parameters for non-defective steel images.

    Values are drawn around per-parameter optima inside the normal operating
    ranges, with a quality tier controlling how far from the optimum a draw
    may wander. Mild correlations and equipment jitter are then applied.
    """

    def __init__(self):
        # (min, max) operating ranges for in-spec production.
        self.normal_ranges = {
            'surface_cleanliness': (95, 100),
            'ambient_humidity': (40, 50),
            'coating_spray_pressure': (2.5, 3.0),
            'coating_viscosity': (80, 100),
            'curing_temperature': (180, 200),
            'curing_time': (20, 25),
            'water_jet_pressure': (180, 200),
            'flow_rate': (100, 120),
            'vibration': (2, 4),
            'drive_load': (10, 15)
        }

        # Target value per parameter (the midpoint of each normal range).
        self.optimal_values = {
            'surface_cleanliness': 97.5,
            'ambient_humidity': 45.0,
            'coating_spray_pressure': 2.75,
            'coating_viscosity': 90.0,
            'curing_temperature': 190.0,
            'curing_time': 22.5,
            'water_jet_pressure': 190.0,
            'flow_rate': 110.0,
            'vibration': 3.0,
            'drive_load': 12.5
        }

    def generate_normal_parameter(self, param_name: str, quality_level: str = 'high') -> float:
        """Draw one in-spec value for ``param_name``.

        'high' and 'medium' deviate from the optimum by at most 10% / 30% of
        the range width. Any other tier (e.g. 'low') samples uniformly across
        the normal range. The draw is clipped to the range, then Gaussian
        noise with sd = 1% of the width is added, so the result can land
        slightly outside the range.
        """
        min_val, max_val = self.normal_ranges[param_name]
        width = max_val - min_val
        spread = {'high': 0.1, 'medium': 0.3}.get(quality_level)

        if spread is None:
            value = np.random.uniform(min_val, max_val)
        else:
            value = self.optimal_values[param_name] + np.random.uniform(-spread, spread) * width

        value = np.clip(value, min_val, max_val)
        noise = np.random.normal(0, width * 0.01)

        return value + noise

    def apply_normal_parameter_correlations(self, params: Dict[str, float]) -> Dict[str, float]:
        """Return a copy of ``params`` with small in-spec couplings applied.

        Conditions are evaluated on the input values. Each adjustment is
        capped or clipped to the parameter's normal range.
        """
        adjusted_params = params.copy()

        # Strong water jet (> 185): cleanliness +2%, capped at 100.
        if params['water_jet_pressure'] > 185:
            adjusted_params['surface_cleanliness'] = min(
                params['surface_cleanliness'] * 1.02, 100
            )

        # Viscosity within ±10% of 90: spray pressure +1%, kept in [2.5, 3.0].
        viscosity_ratio = params['coating_viscosity'] / 90.0
        if 0.9 <= viscosity_ratio <= 1.1:
            adjusted_params['coating_spray_pressure'] = np.clip(
                params['coating_spray_pressure'] * 1.01, 2.5, 3.0
            )

        # Low humidity (< 47): curing temperature +1%, capped at 200.
        if params['ambient_humidity'] < 47:
            adjusted_params['curing_temperature'] = min(
                params['curing_temperature'] * 1.01, 200
            )

        # Low vibration (< 3.5): drive load -2%, kept in [10, 15].
        if params['vibration'] < 3.5:
            adjusted_params['drive_load'] = np.clip(
                params['drive_load'] * 0.98, 10, 15
            )

        return adjusted_params

    def add_equipment_stability(self, params: Dict[str, float]) -> Dict[str, float]:
        """Return a copy of ``params`` with small random jitter on four parameters.

        Each jittered value is clipped back into its normal range.
        """
        stable_params = params.copy()

        stability_factors = {
            'water_jet_pressure': np.random.uniform(0.98, 1.02),
            'coating_spray_pressure': np.random.uniform(0.99, 1.01),
            'vibration': np.random.uniform(0.95, 1.05),
            'flow_rate': np.random.uniform(0.98, 1.02),
        }

        for param, factor in stability_factors.items():
            if param in stable_params:
                stable_params[param] *= factor
                min_val, max_val = self.normal_ranges[param]
                stable_params[param] = np.clip(stable_params[param], min_val, max_val)

        return stable_params

    def generate_non_defective_dataset(self, image_list: List[str],
                                       quality_distribution: Dict[str, float] = None) -> pd.DataFrame:
        """Build a ClassId-0 table with synthetic parameters for ``image_list``.

        Args:
            image_list: Image file names, one output row each (in order).
            quality_distribution: Tier -> probability, passed to
                ``np.random.choice``, so the values must sum to 1. Defaults to
                60% high / 30% medium / 10% low.

        Returns:
            Columns ``ImageId``, ``ClassId`` (0), ``defect_intensity_score``
            (0.0), then one float column per key of ``self.normal_ranges``.
        """
        if quality_distribution is None:
            quality_distribution = {'high': 0.6, 'medium': 0.3, 'low': 0.1}

        n_images = len(image_list)
        df = pd.DataFrame({
            'ImageId': image_list,
            'ClassId': [0] * n_images,
            'defect_intensity_score': [0.0] * n_images,
        })
        param_columns = list(self.normal_ranges.keys())

        print(f"Generating parameters for {n_images} non-defective images...")

        quality_levels = np.random.choice(
            list(quality_distribution.keys()),
            size=n_images,
            p=list(quality_distribution.values())
        )

        # Collect draws positionally and assign whole columns at the end,
        # instead of iterating df.iterrows() (the row was never used) and
        # writing cell by cell. The RNG draw order is unchanged.
        generated = np.zeros((n_images, len(param_columns)), dtype=float)
        for idx, quality_level in enumerate(quality_levels):
            params = {name: self.generate_normal_parameter(name, quality_level)
                      for name in param_columns}
            params = self.apply_normal_parameter_correlations(params)
            params = self.add_equipment_stability(params)
            generated[idx] = [params[name] for name in param_columns]

            if idx % 100 == 0 and idx > 0:
                print(f"Generated {idx} samples...")

        for col_idx, col in enumerate(param_columns):
            df[col] = generated[:, col_idx]

        return df

    def validate_non_defective_data(self, df: pd.DataFrame,
                                    defective_df: pd.DataFrame = None) -> None:
        """Print sanity checks and per-parameter statistics for ``df``.

        When ``defective_df`` is given, also print mean differences for four
        key parameters.
        """
        print("\n=== NON-DEFECTIVE DATA VALIDATION ===")
        print(f"Dataset shape: {df.shape}")
        print(f"All ClassId values are 0: {(df['ClassId'] == 0).all()}")
        print(f"All defect_intensity_score values are 0: {(df['defect_intensity_score'] == 0.0).all()}")

        print("\n=== PARAMETER STATISTICS ===")
        for param, (normal_min, normal_max) in self.normal_ranges.items():
            values = df[param]
            in_range_pct = ((values >= normal_min) & (values <= normal_max)).mean() * 100

            print(f"{param}:")
            print(f"  Range: [{values.min():.2f}, {values.max():.2f}]")
            print(f"  Mean ± Std: {values.mean():.2f} ± {values.std():.2f}")
            print(f"  In normal range: {in_range_pct:.1f}%")

        if defective_df is not None:
            print("\n=== COMPARISON WITH DEFECTIVE DATA ===")
            for param in ['surface_cleanliness', 'ambient_humidity', 'coating_spray_pressure', 'vibration']:
                non_def_mean = df[param].mean()
                def_mean = defective_df[param].mean()
                difference = non_def_mean - def_mean
                print(f"{param}: Non-defective {non_def_mean:.2f} vs Defective {def_mean:.2f} "
                      f"(Diff: {difference:+.2f})")

    def plot_parameter_comparison(self, non_def_df: pd.DataFrame,
                                  defective_df: pd.DataFrame = None,
                                  save_path: str = None):
        """Overlay non-defective (and optionally defective) histograms.

        The normal range is shaded for each of six key parameters. When
        ``save_path`` is given, the figure is also saved at 300 dpi.
        """
        _, axes = plt.subplots(2, 3, figsize=(18, 12))
        axes = axes.ravel()

        key_params = ['surface_cleanliness', 'ambient_humidity', 'coating_spray_pressure',
                      'curing_temperature', 'vibration', 'drive_load']

        for i, param in enumerate(key_params):
            ax = axes[i]

            non_def_values = non_def_df[param]
            ax.hist(non_def_values, alpha=0.7, bins=25, label='Non-defective',
                    color='green', density=True)

            if defective_df is not None:
                def_values = defective_df[param]
                ax.hist(def_values, alpha=0.7, bins=25, label='Defective',
                        color='red', density=True)

            normal_min, normal_max = self.normal_ranges[param]
            ax.axvspan(normal_min, normal_max, alpha=0.2, color='blue',
                       label='Normal Range')

            ax.set_xlabel(param.replace('_', ' ').title())
            ax.set_ylabel('Density')
            ax.legend()
            ax.grid(True, alpha=0.3)
            ax.set_title(f'{param.replace("_", " ").title()} Distribution')

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

def create_non_defective_dataset(image_list: List[str],
                                 defective_df: pd.DataFrame = None,
                                 output_file_path: str = None,
                                 quality_distribution: Dict[str, float] = None) -> pd.DataFrame:
    """Generate, validate, optionally plot, and optionally save non-defective data.

    Args:
        image_list: Image file names, one output row each.
        defective_df: Optional defective table. When given, it is used for
            the mean comparison and the overlay plot.
        output_file_path: CSV destination; nothing is written when falsy.
        quality_distribution: Quality-tier probabilities (must sum to 1).
            Defaults to 70% high / 25% medium / 5% low, the mix this function
            always used before the parameter existed.

    Returns:
        The generated non-defective DataFrame.
    """
    if quality_distribution is None:
        quality_distribution = {'high': 0.7, 'medium': 0.25, 'low': 0.05}

    generator = NonDefectiveParameterGenerator()

    non_def_df = generator.generate_non_defective_dataset(
        image_list,
        quality_distribution=quality_distribution
    )

    generator.validate_non_defective_data(non_def_df, defective_df)

    if defective_df is not None:
        generator.plot_parameter_comparison(non_def_df, defective_df)

    if output_file_path:
        non_def_df.to_csv(output_file_path, index=False)
        print(f"\nNon-defective dataset saved to: {output_file_path}")

    return non_def_df

if __name__ == "__main__":
    augmented_data = augment_steel_defect_dataset(
        df,
        output_file_path='steel_defects_augmented.csv'
    )

    print("\n=== SAMPLE AUGMENTED DATA ===")
    sample_cols = ['ImageId', 'ClassId', 'defect_intensity_score',
                   'surface_cleanliness', 'ambient_humidity', 'coating_spray_pressure',
                   'curing_temperature', 'vibration']
    print(augmented_data[sample_cols].head(10).to_string(index=False))

    def calculate_sample_attributions(row):
        """Return heuristic attribution weights for one augmented defect row.

        Each parameter in the row's ClassId weight table gets
        ``base_weight * intensity * (1 + relative deviation from reference)``.
        Rows whose ClassId has no weight table yield an empty dict.
        """
        defect_class = int(row['ClassId'])
        intensity = row['defect_intensity_score']

        params = {
            'surface_cleanliness': row['surface_cleanliness'],
            'ambient_humidity': row['ambient_humidity'],
            'coating_spray_pressure': row['coating_spray_pressure'],
            'curing_temperature': row['curing_temperature'],
            'vibration': row['vibration'],
            'drive_load': row['drive_load']
        }

        # Reference value per parameter (the midpoint of its normal range);
        # deviations are measured relative to these.
        reference_values = {
            'surface_cleanliness': 97.5,
            'ambient_humidity': 45,
            'coating_spray_pressure': 2.75,
            'curing_temperature': 190,
            'vibration': 3,
            'drive_load': 12.5
        }

        weight_matrices = {
            1: {'curing_temperature': 0.45, 'ambient_humidity': 0.25, 'vibration': 0.15, 'drive_load': 0.15},
            2: {'surface_cleanliness': 0.50, 'coating_spray_pressure': 0.25, 'ambient_humidity': 0.25},
            3: {'coating_spray_pressure': 0.40, 'ambient_humidity': 0.25, 'vibration': 0.20, 'drive_load': 0.15},
            4: {'ambient_humidity': 0.35, 'surface_cleanliness': 0.25, 'curing_temperature': 0.20, 'vibration': 0.20}
        }

        weights = {}
        if defect_class in weight_matrices:
            for param, base_weight in weight_matrices[defect_class].items():
                if param in params:
                    deviation = abs(params[param] - reference_values[param]) / reference_values[param]
                    weights[param] = base_weight * intensity * (1 + deviation)

        return weights

    sample_row = augmented_data.iloc[0]
    attributions = calculate_sample_attributions(sample_row)

    print("\n=== SAMPLE PARAMETER ATTRIBUTION ===")
    print(f"Image: {sample_row['ImageId']}")
    print(f"Class: {sample_row['ClassId']}")
    print(f"Intensity: {sample_row['defect_intensity_score']:.3f}")
    print("Parameter attributions:")
    for param, weight in sorted(attributions.items(), key=lambda x: x[1], reverse=True):
        print(f"  {param}: {weight:.3f}")

    final_defect_data = augmented_data[['ImageId', 'ClassId', 'defect_intensity_score',
                                        'surface_cleanliness', 'ambient_humidity', 'coating_spray_pressure',
                                        'coating_viscosity', 'curing_temperature', 'curing_time',
                                        'water_jet_pressure', 'flow_rate', 'vibration', 'drive_load']]

    # Build the lookup set once. Testing `in` against the numpy array from
    # .unique() scans the whole array for every training file.
    defective_ids = set(train_df['ImageId'])
    non_defective_list = [name for name in total_training_images if name not in defective_ids]
    print(non_defective_list[:10])
    print(len(non_defective_list))

    non_def_data = create_non_defective_dataset(
        image_list=non_defective_list,
        defective_df=None,
        output_file_path='non_defective_steel_data.csv'
    )

    print("\n=== SAMPLE NON-DEFECTIVE DATA ===")
    print(non_def_data.head())

    print("\n=== PARAMETER RANGES IN NON-DEFECTIVE DATA ===")
    param_cols = ['surface_cleanliness', 'ambient_humidity', 'coating_spray_pressure', 'vibration']
    for col in param_cols:
        print(f"{col}: [{non_def_data[col].min():.2f}, {non_def_data[col].max():.2f}] "
              f"(Mean: {non_def_data[col].mean():.2f})")

    # Up to 200 rows per defect class, each class sampled with random_state=42.
    # Iterating the groups keeps the ClassId column: GroupBy.apply operating
    # on grouping columns is deprecated (pandas 2.2) and will drop it.
    # min() avoids a ValueError when a class has fewer than 200 rows.
    per_class_samples = [
        group.sample(n=min(200, len(group)), random_state=42)
        for _, group in final_defect_data.groupby("ClassId")
    ]
    if per_class_samples:
        df_sampled = pd.concat(per_class_samples)
    else:
        df_sampled = final_defect_data.iloc[0:0]

    df_sampled = df_sampled.reset_index(drop=True)
    final_df = pd.concat([non_def_data, df_sampled], ignore_index=True)

    final_df.to_csv('steel_data_augmented.csv', index=False)
    print(final_df['ClassId'].value_counts())