import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import seaborn as sns
from scipy import stats
import pandas as pd

# Global matplotlib look for every figure produced by this module. The style
# sheet is applied first so the explicit rcParams below take precedence.
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams.update({
    'figure.figsize': [12, 8],
    'font.size': 12,
})

class ExperimentAnalyzer:
    """Load per-optimizer metric curves from JSON files and summarize them.

    Every ``*.json`` file found in ``results_dir`` is loaded as one entry of
    ``self.data``, keyed by the file stem (used as the optimizer name). Each
    entry is looked up by the split names ``'val'`` and ``'test'``; the value
    under a split is converted to a NumPy array and treated as
    ``(seeds, epochs)``, or as a single run of ``epochs`` values when 1-D.
    """

    def __init__(self, results_dir="results"):
        """Remember the results directory and create it if it is missing.

        Args:
            results_dir: Directory holding the ``*.json`` inputs; the plot and
                CSV outputs are written to the same directory.
        """
        self.results_dir = Path(results_dir)
        # parents=True: callers may pass a nested path whose intermediate
        # folders do not exist yet; without it mkdir raises FileNotFoundError.
        self.results_dir.mkdir(parents=True, exist_ok=True)
        self.data = {}

    def load_results(self):
        """Read every ``*.json`` file in ``results_dir`` into ``self.data``.

        Files that cannot be read or parsed are skipped with a printed
        warning instead of being dropped silently.
        """
        # Sorted so that plot colors / legend order do not depend on the
        # file-system enumeration order.
        for json_file in sorted(self.results_dir.glob("*.json")):
            try:
                with open(json_file, 'r', encoding='utf-8') as f:
                    self.data[json_file.stem] = json.load(f)
            except (OSError, ValueError) as e:
                # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
                print(f"Warning: skipping {json_file}: {e}")

    @staticmethod
    def _confidence_percent(confidence):
        """Return the confidence level as a whole percent (0.95 -> 95).

        Rounded rather than truncated so that float error in
        ``confidence * 100`` (e.g. ``0.57 * 100 == 56.99999999999999``)
        cannot shift the label by one.
        """
        return int(round(confidence * 100))

    def _calculate_confidence_interval(self, values, confidence=0.95):
        """Compute a symmetric confidence interval around the mean.

        Args:
            values: Sequence of samples.
            confidence: Two-sided confidence level in (0, 1).

        Returns:
            Tuple ``(mean, margin, lower, upper)``. With fewer than two
            samples the margin is 0 and both bounds equal the mean.
        """
        n = len(values)
        mean = np.mean(values)
        if n < 2:
            return mean, 0, mean, mean

        sem = stats.sem(values)

        # NOTE(review): a normal quantile is used here; for a handful of
        # seeds a Student-t quantile would give wider intervals. Left as-is
        # so previously reported numbers stay reproducible.
        margin = stats.norm.ppf((1 + confidence) / 2) * sem

        return mean, margin, mean - margin, mean + margin

    def plot_both_splits_comparison_with_ci(self, confidence=0.95, save=True):
        """Plot mean metric curves with confidence bands for 'val' and 'test'.

        One subplot is drawn per split and one line per optimizer; the last
        point of every curve is marked and annotated with ``mean±margin``.

        Args:
            confidence: Confidence level used for the shaded band.
            save: When true, write ``both_splits_comparison.png`` into
                ``results_dir`` before showing the figure.
        """
        if not self.data:
            return

        fig, axes = plt.subplots(1, 2, figsize=(18, 7))

        colors = plt.cm.tab20c(np.linspace(0, 1, len(self.data)))

        for split, ax in zip(['val', 'test'], axes):
            for opt_idx, (optimizer_name, optimizer_data) in enumerate(self.data.items()):
                if split not in optimizer_data:
                    continue

                # atleast_2d: accept a single flat run the same way
                # calculate_final_scores does, as one seed.
                seeds_data = np.atleast_2d(np.array(optimizer_data[split]))
                if seeds_data.size == 0:
                    continue

                color = colors[opt_idx]
                mean_per_epoch = np.mean(seeds_data, axis=0)
                epochs = np.arange(1, len(mean_per_epoch) + 1)

                # One (lower, upper) pair per epoch, computed across seeds.
                bounds = np.array([
                    self._calculate_confidence_interval(
                        seeds_data[:, epoch_idx], confidence
                    )[2:]
                    for epoch_idx in range(seeds_data.shape[1])
                ])
                ci_lower, ci_upper = bounds[:, 0], bounds[:, 1]

                ax.plot(epochs, mean_per_epoch,
                        label=f'{optimizer_name}',
                        linewidth=2.5,
                        alpha=0.6,
                        color=color)

                ax.fill_between(epochs, ci_lower, ci_upper,
                                alpha=0.2,
                                color=color)

                last_mean = mean_per_epoch[-1]
                last_margin = ci_upper[-1] - last_mean
                ax.plot(epochs[-1], last_mean, 'o',
                        markersize=8,
                        color=color,
                        markeredgecolor='white',
                        alpha=0.6,
                        markeredgewidth=1.5)

                ax.annotate(f'{last_mean:.3f}±{last_margin:.3f}',
                            xy=(epochs[-1], last_mean),
                            xytext=(5, 5), textcoords='offset points',
                            fontsize=9,
                            bbox=dict(boxstyle="round,pad=0.3",
                                      facecolor=color,
                                      alpha=0.7,
                                      edgecolor='none'))

            ax.set_title(split, fontsize=16, fontweight='bold', pad=15)
            ax.set_xlabel('epoch', fontsize=14)
            ax.set_ylabel('metric', fontsize=14)
            # Only build a legend when something was drawn for this split.
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend(fontsize=12,
                          loc='lower right' if split == 'val' else 'upper right')
            ax.grid(True, alpha=0.3, linestyle='--')
            ax.set_axisbelow(True)

        fig.tight_layout()

        if save:
            plot_path = self.results_dir / 'both_splits_comparison.png'
            fig.savefig(plot_path, dpi=300, bbox_inches='tight')

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def calculate_final_scores(self, epoch=-1, confidence=0.95):
        """Summarize per-optimizer scores, print them and write them to CSV.

        Args:
            epoch: ``-1`` (default) selects, for every seed, the epoch with
                the highest 'val' value and reports both splits at that
                epoch. Any other value is used directly as an index into the
                epoch axis.
            confidence: Confidence level for the reported interval; it also
                determines the ``ci_<pct>%_lower`` / ``ci_<pct>%_upper``
                column names.

        Returns:
            A ``pandas.DataFrame`` with one row per (optimizer, split),
            sorted by split then optimizer. It is empty when there is
            nothing to report. Non-empty results are also written to
            ``final_scores_with_CI.csv`` in ``results_dir``.
        """
        pct = self._confidence_percent(confidence)
        lower_key = f'ci_{pct}%_lower'
        upper_key = f'ci_{pct}%_upper'
        results = []

        for optimizer_name, optimizer_data in self.data.items():
            best_epochs_per_seed = None
            if epoch == -1:
                # Validation data is only required for best-epoch selection.
                if 'val' not in optimizer_data:
                    print(f"Warning: skipping {optimizer_name}: "
                          f"no 'val' data to select the best epoch")
                    continue
                val_data = np.array(optimizer_data['val'])
                if val_data.ndim > 1:
                    best_epochs_per_seed = np.argmax(val_data, axis=-1)
                else:
                    best_epochs_per_seed = np.array([np.argmax(val_data)])

            for split in ['val', 'test']:
                if split not in optimizer_data:
                    continue

                seeds_data = np.array(optimizer_data[split])

                if epoch == -1:
                    if seeds_data.ndim > 1:
                        final_scores = np.array([
                            seeds_data[seed_idx, best_epoch]
                            for seed_idx, best_epoch in enumerate(best_epochs_per_seed)
                        ])
                        avg_best_epoch = int(np.mean(best_epochs_per_seed))
                        epoch_label = f'best_val (avg: {avg_best_epoch})'
                    else:
                        final_scores = np.array([seeds_data[best_epochs_per_seed[0]]])
                        epoch_label = f'best_val ({best_epochs_per_seed[0]})'
                else:
                    if seeds_data.ndim > 1:
                        final_scores = seeds_data[:, epoch]
                    else:
                        final_scores = np.array([seeds_data[epoch]])
                    # shape[-1] is the epoch count for both 1-D and 2-D data.
                    epoch_label = (epoch if epoch >= 0
                                   else f'last ({seeds_data.shape[-1] + epoch})')

                mean, margin, lower, upper = self._calculate_confidence_interval(
                    final_scores, confidence
                )

                results.append({
                    'optimizer': optimizer_name,
                    'split': split,
                    'epoch': epoch_label,
                    'mean': mean,
                    # Population std (ddof=0); the interval itself is built
                    # from the sample-based standard error.
                    'std': np.std(final_scores),
                    'min': np.min(final_scores),
                    'max': np.max(final_scores),
                    'margin_of_error': margin,
                    lower_key: lower,
                    upper_key: upper,
                    'num_seeds': len(final_scores)
                })

        df = pd.DataFrame(results)
        if df.empty:
            # Nothing to sort or print; sort_values would raise on the
            # missing columns.
            print("No scores to report.")
            return df

        df = df.sort_values(['split', 'optimizer'])

        csv_path = self.results_dir / 'final_scores_with_CI.csv'
        df.to_csv(csv_path, index=False, float_format='%.6f')

        for split in ['val', 'test']:
            split_df = df[df['split'] == split]
            if not split_df.empty:
                print("-" * 80)
                for _, row in split_df.iterrows():
                    print(f"{row['optimizer']:20} | "
                          f"Mean: {row['mean']:.4f} ± {row['margin_of_error']:.4f} | "
                          f"CI{pct}%: [{row[lower_key]:.4f}, {row[upper_key]:.4f}] | "
                          f"Seeds: {row['num_seeds']}")

        print("="*80)

        return df

    def run_full_analysis(self, confidence=0.95):
        """Load results, draw the comparison plot and compute the score table.

        Args:
            confidence: Confidence level forwarded to plotting and scoring.

        Returns:
            The DataFrame from ``calculate_final_scores``, or ``None`` when
            no results could be loaded.
        """
        self.load_results()

        if not self.data:
            print(f"No results found in {self.results_dir}")
            return

        self.plot_both_splits_comparison_with_ci(confidence=confidence, save=True)
        df_scores = self.calculate_final_scores(epoch=-1, confidence=confidence)
        return df_scores

def get_arguments():
    """Parse the command-line options that select which results to analyze.

    Returns:
        The ``argparse.Namespace`` with ``user``, ``dataset``, ``model``,
        ``unbalance_coef`` and ``balanced`` attributes.
    """
    user_choices = ['user1', 'user2', 'user3']
    dataset_choices = ['cifar10', 'tiny_imagenet', 'unbalanced_cifar10', 'unbalanced_tiny_imagenet']
    model_choices = [
        'SimpleCNN', 'ResNet18_32x32', 'SWIN_tiny',
        'SimpleCNNBinClass', 'ResNet18_32x32BinClass', 'char_lstm', 'char_transformer',
    ]
    coef_choices = [1, 2, 5, 10, 20, 30, 40, 50, 60, 80, 100, 200, 300]

    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--user', type=str, required=True, choices=user_choices,
        help='Name of the user folder for getting the optimizer hyperparameters.',
    )
    cli.add_argument(
        '--dataset', type=str, default='cifar10', choices=dataset_choices,
        help='Dataset to use.',
    )
    cli.add_argument(
        '--model', type=str, default='SimpleCNN', choices=model_choices,
        help='Model architecture to use.',
    )
    cli.add_argument(
        '--unbalance_coef', type=int, default=1, choices=coef_choices,
        help='Unbalance coefficient to use.',
    )
    cli.add_argument(
        '--balanced', action='store_true',
        help='If set, use balanced test dataset results (f1_balanced_results folder).',
    )
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: resolve the results folder from the CLI options
    # and run the full analysis with a 95% confidence level.
    args = get_arguments()

    if args.balanced:
        results_folder = 'f1_balanced_results'
    else:
        results_folder = 'f1_results'

    # Layout: ./tuning/<user>/<dataset>/<model lower-cased>_<coef>/<results folder>
    model_dir = f'{args.model.lower()}_{args.unbalance_coef}'
    path = f'./tuning/{args.user}/{args.dataset}/{model_dir}/{results_folder}'

    analyzer = ExperimentAnalyzer(results_dir=path)
    results_df = analyzer.run_full_analysis(confidence=0.95)
