import argparse
import json
import random
import yaml

import torch
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms

from simulator.functions import LogisticRegressionFunction, ResnetClassificationFunction
from simulator.fixed_time_utils import run_pipeline
from simulator.algorithms.ringmaster_sgd import RingmasterSGDServer
from simulator.algorithms.rennala_sgd import RennalaSGDServer, RennalaSGDWorker
from simulator.algorithms.synchronized_sgd import SynchronizedSGDServer
from simulator.worker import Worker, WorkerWithLocalSteps, WorkerWithTargetComputeCommunicateRatio

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

def get_cifar10_dataloaders(batch_size=128, num_workers=4):
    """Build the CIFAR-10 train and test dataloaders.

    Both splits are stored under ``./data`` and downloaded if missing. The
    training split is augmented with a 32px random crop (4px padding) and a
    random horizontal flip. Both splits are converted to tensors and
    normalized with the same per-channel mean/std.

    Args:
        batch_size: Number of samples per batch, for both loaders.
        num_workers: Number of ``DataLoader`` worker processes, for both loaders.

    Returns:
        Tuple ``(trainloader, testloader)``. Only the training loader shuffles.
    """
    def tensor_and_normalize():
        # Common tail of both pipelines; built fresh for each split.
        return [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]

    def loader_for(is_train, steps):
        # The training split is the only one that shuffles.
        split = torchvision.datasets.CIFAR10(
            root='./data', train=is_train, download=True, transform=transforms.Compose(steps))
        return DataLoader(
            split, batch_size=batch_size, shuffle=is_train, num_workers=num_workers)

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    trainloader = loader_for(True, augmentations + tensor_and_normalize())
    testloader = loader_for(False, tensor_and_normalize())
    return trainloader, testloader

def get_imagenet_dataloaders(data_dir, batch_size=256, num_workers=4):
    """Build the ImageNet train and validation dataloaders.

    The training split gets random-resized-crop (224px), horizontal-flip and
    color-jitter augmentation. The validation split is resized to 256px and
    center-cropped to 224px. Both splits are normalized with the standard
    ImageNet mean/std, and both loaders use pinned memory.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.ImageNet``.
        batch_size: Number of samples per batch, for both loaders.
        num_workers: Number of ``DataLoader`` worker processes, for both loaders.

    Returns:
        Tuple ``(trainloader, valloader)``. Only the training loader shuffles.
    """
    def tensor_and_normalize():
        # Common tail of both pipelines; built fresh for each split.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    def loader_for(split_name, steps, shuffle):
        split = torchvision.datasets.ImageNet(
            root=data_dir, split=split_name, transform=transforms.Compose(steps))
        return DataLoader(
            split, batch_size=batch_size, shuffle=shuffle,
            num_workers=num_workers, pin_memory=True)

    train_steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    ] + tensor_and_normalize()
    val_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ] + tensor_and_normalize()

    trainloader = loader_for('train', train_steps, shuffle=True)
    valloader = loader_for('val', val_steps, shuffle=False)
    return trainloader, valloader



def _resolve_algorithm(server_name):
    """Return the ``(server_cls, worker_cls)`` pair for a config ``server`` name.

    Raises:
        RuntimeError: If ``server_name`` is not a known algorithm.
    """
    algorithms = {
        'ringmaster_sgd': (RingmasterSGDServer, Worker),
        'rennala_sgd': (RennalaSGDServer, RennalaSGDWorker),
        'synchronized_sgd': (SynchronizedSGDServer, Worker),
        # Local SGD reuses the Rennala server together with WorkerWithLocalSteps.
        'local_sgd': (RennalaSGDServer, WorkerWithLocalSteps),
        'ringmaster_sgd_compcomm': (RingmasterSGDServer, WorkerWithTargetComputeCommunicateRatio),
    }
    if server_name not in algorithms:
        raise RuntimeError(
            f"Unknown server {server_name!r}; expected one of {sorted(algorithms)}")
    return algorithms[server_name]


def _resolve_optimizer(optimizer_name):
    """Return the torch optimizer class for a config ``optimizer`` name.

    Raises:
        RuntimeError: If ``optimizer_name`` is not ``'adam'`` or ``'sgd'``.
    """
    optimizers = {'adam': torch.optim.Adam, 'sgd': torch.optim.SGD}
    if optimizer_name not in optimizers:
        raise RuntimeError(
            f"Unknown optimizer {optimizer_name!r}; expected one of {sorted(optimizers)}")
    return optimizers[optimizer_name]


def _temporary_save_path(save_path):
    """Derive the intermediate-results path from the final ``save_path``.

    Only a trailing ``.json`` suffix is rewritten (``out.json`` becomes
    ``out_temp.json``). Any other path gets ``_temp`` appended, so the
    temporary file never coincides with the final one.
    """
    suffix = ".json"
    if save_path.endswith(suffix):
        return save_path[:-len(suffix)] + "_temp" + suffix
    return save_path + "_temp"


def train_image_classification(train_dataloader, test_dataloader, save_path, config, config_name):
    """Run one simulated distributed-training experiment and save its stats as JSON.

    One ``ResnetClassificationFunction`` is built per simulated worker,
    assigned round-robin over the visible CUDA devices. One extra instance is
    built whose model is the server-side parameter point passed to
    ``run_pipeline``; metrics are evaluated on that instance.

    Args:
        train_dataloader: Training loader. Its ``.dataset`` length is printed.
        test_dataloader: Evaluation loader. Its ``.dataset`` length is printed.
        save_path: File the final stats (plus ``config`` under ``"params"``)
            are written to as JSON. A derived ``*_temp`` path is returned from
            the metric callback every few checks, presumably for intermediate
            snapshots written by ``run_pipeline``.
        config: Experiment settings.
            Required keys: ``batch_size``, ``gamma``, ``sim_time``,
            ``num_workers``, ``metric_check_num``, ``times_to_calculate``,
            ``times_to_communicate``, ``server``, ``optimizer``,
            ``server_params``.
            Optional keys: ``local_steps`` (default ``None``) and
            ``worker_params`` (default ``{}``).
        config_name: Label printed alongside the experiment parameters.

    Raises:
        KeyError: If a required config key is missing.
        RuntimeError: If ``config['server']`` or ``config['optimizer']`` is unknown.
    """
    print("Dataset shape:", len(train_dataloader.dataset))
    print("Test dataset shape:", len(test_dataloader.dataset))

    # Reject unknown algorithm/optimizer names before building any models.
    server_cls, worker_cls = _resolve_algorithm(config['server'])
    optimizer_cls = _resolve_optimizer(config['optimizer'])

    # Hyperparameters
    batch_size = config['batch_size']
    gamma = config['gamma']
    sim_time = config['sim_time']
    num_workers = config['num_workers']
    metric_check_num = config['metric_check_num']
    metric_check_period = sim_time / metric_check_num
    times_to_calculate = config['times_to_calculate']
    times_to_communicate = config['times_to_communicate']
    local_steps = config.get('local_steps', None)
    # Read once with the default, so the log line below and run_pipeline agree.
    worker_params = config.get('worker_params', {})

    num_of_gpus = torch.cuda.device_count()
    print(f"Devices available: {num_of_gpus}")
    # With no visible GPU, `w_id % 0` / `randint(0, -1)` would raise; fall back to CPU.
    # NOTE(review): assumes ResnetClassificationFunction accepts device="cpu".
    functions = [
        ResnetClassificationFunction(
            train_dataloader, test_dataloader, batch_size=batch_size,
            device=f"cuda:{w_id % num_of_gpus}" if num_of_gpus else "cpu",
        )
        for w_id in range(num_workers)
    ]

    # Function whose parameters are used and updated on the server side.
    server_device = f"cuda:{random.randint(0, num_of_gpus - 1)}" if num_of_gpus else "cpu"
    function_test = ResnetClassificationFunction(
        train_dataloader, test_dataloader, batch_size=batch_size, device=server_device
    )

    # The point is the model itself; otherwise .parameters() and the optimizer
    # would not be tied to the same tensors.
    point = function_test.model
    temp_save_path = _temporary_save_path(save_path)

    class Metric:
        """Metric callback handed to ``run_pipeline``.

        It evaluates only once simulated time ``env.now`` reaches
        ``k * metric_check_period``, where ``k`` is the number of evaluations
        done so far; otherwise it returns ``(None, None, None)``.
        """

        def __init__(self, metric_check_period, saving_period=20):
            self._metric_check_period = metric_check_period
            self._metric_checked_times = 0
            self._saving_period = saving_period

        def calculate_metrics(self, env, iter, point):
            if env.now < self._metric_check_period * self._metric_checked_times:
                return None, None, None
            self._metric_checked_times += 1

            # Pass None: function_test's model is the server point, so it
            # already holds the current parameters.
            value = function_test.value(None)
            accuracy = function_test.accuracy(None, train=True)
            accuracy_test = function_test.accuracy(None)

            is_saving_step = (self._metric_checked_times + 1) % self._saving_period == 0
            save_path_temporal = temp_save_path if is_saving_step else None
            print(f"Time {env.now}, Loss: {value}, Accuracy: {accuracy}, Accuracy Test: {accuracy_test}")
            return {'value': float(value), 'time': env.now, "iter": iter,
                    'accuracy': accuracy, 'accuracy_test': accuracy_test}, save_path_temporal, config

    print(f"Experiment name: {config_name}, Server params: {config['server_params']}, worker params: {worker_params}")
    _, stats = run_pipeline(server_cls, worker_cls,
                            functions, point, gamma, optimizer_cls, sim_time=sim_time,
                            times_to_calculate=times_to_calculate,
                            times_to_communicate=times_to_communicate,
                            server_params=config['server_params'],
                            worker_params=worker_params,
                            calculate_metrics=Metric(metric_check_period).calculate_metrics,
                            local_steps=local_steps)

    stats["params"] = config
    with open(save_path, 'w') as f:
        json.dump(stats, f, indent=4)
    print(f"Results saved to {save_path}")


if __name__ == "__main__":
    # Entry point: load a YAML experiment config, train on CIFAR-10, save the results.
    parser = argparse.ArgumentParser(
        description="Train a ResNet classifier on CIFAR-10 with a simulated distributed SGD algorithm and save experiment results.")
    parser.add_argument('--save_path', type=str, required=True, help='Path to save experiment results')
    parser.add_argument('--config', type=str, required=True, help='Path to the YAML experiment config')
    args = parser.parse_args()

    # Use a context manager so the config file handle is closed after parsing.
    with open(args.config) as config_file:
        config = yaml.safe_load(config_file)
    train_dataloader, test_dataloader = get_cifar10_dataloaders()
    train_image_classification(train_dataloader, test_dataloader, save_path=args.save_path, config=config, config_name=args.config)
    print("Training completed.")
