import argparse
import json
import os

import yaml

import torch
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms

from simulator.functions import LogisticRegressionFunction
from simulator.fixed_time_utils import run_pipeline
from simulator.algorithms.ringmaster_sgd import RingmasterSGDServer
from simulator.algorithms.rennala_sgd import RennalaSGDServer, RennalaSGDWorker
from simulator.algorithms.rennala_softreduce_sgd import RennalaSGDSoftReduceServer, RennalaSGDSoftReduceWorker
from simulator.algorithms.synchronized_sgd import SynchronizedSGDServer
from simulator.algorithms.subset_ring_reduce import SubsetRingReduceServer, SubsetRingReduceWorker
from simulator.worker import Worker, WorkerWithLocalSteps, WorkerWithTargetComputeCommunicateRatio

def load_mnist_binary():
    """Load the MNIST train and test splits as flattened float tensors.

    Despite the name, all ten digit classes are kept: labels are the float
    class indices 0..9 and ``num_classes`` is 10.

    The dataset is downloaded into ``./data`` if it is not already there.

    Returns:
        (X, y, X_t, y_t, num_classes):
            X, X_t -- images flattened to one row per sample, scaled to [0, 1]
                      in torch's default float dtype (same as ``ToTensor``).
            y, y_t -- float32 label tensors.
            num_classes -- always 10.
    """
    mnist_train = datasets.MNIST(root='./data', train=True, download=True)
    mnist_test = datasets.MNIST(root='./data', train=False, download=True)

    def _get_X_y(mnist):
        # Vectorized equivalent of iterating the dataset through
        # ``transforms.ToTensor()`` and flattening each image.
        # ToTensor converts uint8 pixels with ``.to(default_dtype).div(255)``.
        # Doing that once on the whole ``.data`` tensor avoids a PIL round trip
        # for every one of the ~70k images.
        images = mnist.data
        X = images.reshape(len(images), -1).to(torch.get_default_dtype()).div(255)
        y = mnist.targets.to(torch.float32)
        return X, y

    X, y = _get_X_y(mnist_train)
    X_t, y_t = _get_X_y(mnist_test)
    num_classes = 10
    return X, y, X_t, y_t, num_classes

class LogisticRegressionParameters:
    """Lightweight holder that exposes a ``parameters()`` method.

    It wraps a plain collection of tensors so it can be passed around as the
    model ``point``. Metric code in this file calls ``point.parameters()`` on
    it. It presumably mimics the ``nn.Module.parameters()`` accessor expected
    by the simulator.
    """

    def __init__(self, parameters: object) -> None:
        # Kept by reference (no copy): callers and this wrapper share one object.
        self._parameters = parameters

    def parameters(self) -> object:
        """Return the exact object that was passed to the constructor."""
        wrapped = self._parameters
        return wrapped

def _select_server_and_worker(server_name):
    """Map a ``config['server']`` name to its ``(server_cls, worker_cls)`` pair.

    Raises:
        RuntimeError: if ``server_name`` is not a supported server.
    """
    # Built at call time so the simulator classes are only looked up when used.
    # NOTE: 'rennala_softreduce_sgd' and 'rennala_sgd_history_window' used to map to
    # (RennalaSGDSoftReduceServer, RennalaSGDSoftReduceWorker); they are disabled.
    registry = {
        'ringmaster_sgd': (RingmasterSGDServer, Worker),
        'ringmaster_sgd_compcomm': (RingmasterSGDServer, WorkerWithTargetComputeCommunicateRatio),
        'rennala_sgd': (RennalaSGDServer, RennalaSGDWorker),
        'synchronized_sgd': (SynchronizedSGDServer, Worker),
        'local_sgd': (RennalaSGDServer, WorkerWithLocalSteps),
        'subset_ring_reduce': (SubsetRingReduceServer, SubsetRingReduceWorker),
    }
    try:
        return registry[server_name]
    except KeyError:
        raise RuntimeError(
            f"Unknown server {server_name!r}; expected one of {sorted(registry)}"
        ) from None


def _select_optimizer(optimizer_name):
    """Map a ``config['optimizer']`` name to a ``torch.optim`` optimizer class.

    Raises:
        RuntimeError: if ``optimizer_name`` is not ``'adam'`` or ``'sgd'``.
    """
    optimizers = {'adam': torch.optim.Adam, 'sgd': torch.optim.SGD}
    try:
        return optimizers[optimizer_name]
    except KeyError:
        raise RuntimeError(
            f"Unknown optimizer {optimizer_name!r}; expected one of {sorted(optimizers)}"
        ) from None


def _temp_save_path(save_path):
    """Return the intermediate-results path ``<root>_temp<ext>`` for ``save_path``.

    Only the final extension is touched. Directory names that contain ``.json``
    are left intact, and an extension-less path still gets a distinct temp file
    instead of aliasing the final output path.
    """
    root, ext = os.path.splitext(save_path)
    return f"{root}_temp{ext}"


def train_logistic_regression(X, y, X_t, y_t, num_classes, save_path, config):
    """Run one simulated distributed-training experiment and save its stats as JSON.

    Each of ``config['num_workers']`` workers gets its own
    ``LogisticRegressionFunction`` over ``(X, y)``, seeded with its worker id.
    The model starts from a fixed (seed 42) standard-normal parameter vector.
    The simulation is driven by ``run_pipeline``.

    Args:
        X, y: training features and labels (layout as expected by
            ``LogisticRegressionFunction``).
        X_t, y_t: test features and labels, used only for ``accuracy_test``.
        num_classes: number of classes passed to ``LogisticRegressionFunction``.
        save_path: path of the final JSON results. Periodic snapshots are
            requested at ``<root>_temp<ext>``.
        config: experiment mapping.
            Required keys: ``batch_size``, ``gamma``, ``sim_time``,
            ``num_workers``, ``metric_check_num``, ``times_to_calculate``,
            ``times_to_communicate``, ``server``, ``optimizer``,
            ``server_params``.
            Optional keys: ``reg`` (regularization coefficient, default 0.0),
            ``local_steps`` (default None), ``worker_params`` (default {}).

    Raises:
        KeyError: if a required config key is missing.
        RuntimeError: if ``server`` or ``optimizer`` names an unsupported choice.
    """
    print("Dataset shape:", X.shape)

    # Hyperparameters
    reg = config.get('reg', 0.0)
    batch_size = config['batch_size']
    gamma = config['gamma']
    sim_time = config['sim_time']
    num_workers = config['num_workers']
    metric_check_num = config['metric_check_num']
    # Metrics are evaluated at most once per this much simulated time.
    metric_check_period = sim_time / metric_check_num
    times_to_calculate = config['times_to_calculate']
    times_to_communicate = config['times_to_communicate']
    local_steps = config.get('local_steps', None)

    # Resolve named choices before any expensive setup so a typo fails fast.
    server_cls, worker_cls = _select_server_and_worker(config['server'])
    optimizer_cls = _select_optimizer(config['optimizer'])

    functions = [
        LogisticRegressionFunction(
            X, y, num_classes=num_classes,
            reg=reg, rng=np.random.default_rng(seed=w_id), batch_size=batch_size,
        )
        for w_id in range(num_workers)
    ]
    dim = functions[0].dim()
    rng = np.random.default_rng(seed=42)
    initial = torch.tensor(rng.standard_normal(dim).astype(np.float32), requires_grad=True)
    point = LogisticRegressionParameters([torch.nn.Parameter(initial)])

    # Evaluation-only functions. Per the original note, rng is not used by
    # these, so None is passed.
    function_train = LogisticRegressionFunction(X, y, num_classes=num_classes, reg=reg, rng=None, batch_size=batch_size)
    function_test = LogisticRegressionFunction(X_t, y_t, num_classes=num_classes, reg=reg, rng=None, batch_size=batch_size)
    temp_save_path = _temp_save_path(save_path)

    class Metric:
        """Throttled metrics callback: evaluates at most once per check period of simulated time."""

        def __init__(self, metric_check_period, saving_period=20):
            self._metric_check_period = metric_check_period
            # Index k of the next grid point ``k * period`` at which an evaluation is due.
            self._metric_checked_times = 0
            # Evaluations actually performed; drives the temp-save cadence.
            self._num_evaluations = 0
            self._saving_period = saving_period

        def calculate_metrics(self, env, iter, point):
            """Return ``(metrics, temp_path_or_None, config)``, or three Nones if not yet due."""
            period = self._metric_check_period
            if env.now < period * self._metric_checked_times:
                return None, None, None
            # Skip past every grid point that has already elapsed. Otherwise a long
            # gap between calls would trigger a burst of back-to-back evaluations
            # while the threshold "catches up" one period at a time.
            next_index = self._metric_checked_times + 1
            if period > 0:
                next_index = max(next_index, int(env.now // period) + 1)
            self._metric_checked_times = next_index
            self._num_evaluations += 1

            value = function_train.value(point.parameters())
            accuracy = function_train.accuracy(point.parameters())
            accuracy_test = function_test.accuracy(point.parameters())
            print(f"Time {env.now}, Loss: {value}, Accuracy: {accuracy}, Accuracy Test: {accuracy_test}")
            due_for_save = (self._num_evaluations + 1) % self._saving_period == 0
            save_path_temporal = temp_save_path if due_for_save else None
            return {'value': float(value), 'time': env.now, "iter": iter,
                    'accuracy': accuracy, 'accuracy_test': accuracy_test}, save_path_temporal, config

    print(config['server_params'])
    _, stats = run_pipeline(server_cls, worker_cls,
                            functions, point, gamma, optimizer_cls, sim_time=sim_time,
                            times_to_calculate=times_to_calculate,
                            times_to_communicate=times_to_communicate,
                            server_params=config['server_params'],
                            worker_params=config.get('worker_params', {}),
                            calculate_metrics=Metric(metric_check_period).calculate_metrics,
                            local_steps=local_steps)

    # NOTE(review): assumes run_pipeline returns a JSON-serializable dict for stats.
    stats["params"] = config
    with open(save_path, 'w') as f:
        json.dump(stats, f, indent=4)
    print(f"Results saved to {save_path}")


if __name__ == "__main__":
    # CLI entry point: read a YAML experiment config, load MNIST, run training.
    parser = argparse.ArgumentParser(description="Train Logistic Regression on MNIST (binary classification) and save experiment results.")
    parser.add_argument('--save_path', type=str, required=True, help='Path to save experiment results')
    parser.add_argument('--config', type=str, required=True, help='Config with params')
    args = parser.parse_args()

    # Use a context manager so the config file handle is closed promptly.
    with open(args.config, encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)
    # safe_load yields None for an empty file. Reject anything that is not a
    # mapping here rather than failing later with an obscure TypeError.
    if not isinstance(config, dict):
        parser.error(f"config file {args.config!r} must contain a YAML mapping")

    X, y, X_t, y_t, num_classes = load_mnist_binary()
    train_logistic_regression(X, y, X_t, y_t, num_classes, save_path=args.save_path, config=config)
    print("Training completed.")
