import numpy as np
import torch
import torchvision
import torch.nn as nn
import copy
import random
import json
import os
from torch.utils.data import TensorDataset, DataLoader, Subset
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torchvision import transforms
from utils import get_args, __getDirichletData__, test_img, LinearCNN, initialize_weights, measure_alignment_linear_cnn, Resnet, measure_alignment
from train_tiny_imagenet import TinyImageNetDataset, download_tiny_imagenet


args = get_args()

# Each sweep dimension takes the command-line value when one was supplied and
# otherwise falls back to a default list. "Unset" is signalled by the sentinel
# -1 (or "x" for the algorithm name).
alpha_range = [0.05] if args.alpha == -1 else [args.alpha]
seed_range = [30] if args.seed == -1 else [args.seed]
run_range = [0, 1] if args.run == -1 else [args.run]
if args.alg == "x":
    algs = ['random', 'pretrained-cifar', 'pretrained-tiny-imagenet']
else:
    algs = [args.alg]

print(f"Running for alpha {alpha_range}, seed {seed_range}, run {run_range}, alg range {algs}")

class LocalUpdate:
    """Run local SGD on one client's data and report the resulting weight change.

    Args:
        args: run configuration. This class reads ``args.bs`` (batch size),
            ``args.cp`` (number of local epochs) and ``args.device``.
        dataset: the client's training dataset (shuffled every epoch).
        use_squared: use ``nn.MSELoss`` instead of ``nn.CrossEntropyLoss``.
    """

    def __init__(self, args, dataset, use_squared=False):
        self.args = args
        self.loss_func = nn.MSELoss() if use_squared else nn.CrossEntropyLoss()
        self.ldr_train = DataLoader(dataset, batch_size=args.bs, shuffle=True)

    def train(self, net):
        """Train ``net`` in place for ``args.cp`` epochs and return ``w_final - w_initial``.

        The model is moved to ``args.device`` for training and back to the CPU
        afterwards (``nn.Module.to`` moves it in place). The returned flat
        parameter delta lives on ``args.device`` and carries no autograd history.
        """
        net = net.to(self.args.device)
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
        is_linear_cnn = isinstance(net, LinearCNN)

        # Snapshot the starting weights without autograd tracking. Otherwise the
        # returned delta would carry a graph that references this snapshot, and
        # in-place accumulation by the caller would keep those graphs alive.
        with torch.no_grad():
            vec_init = parameters_to_vector(net.parameters())

        for _ in range(self.args.cp):
            for images, labels in self.ldr_train:
                images, labels = images.to(self.args.device), labels.to(self.args.device)

                if is_linear_cnn:
                    # LinearCNN takes each sample flattened to (batch, 1, features)
                    # and a binary target: class 1 vs. everything else.
                    images = images.view(images.size(0), 1, -1)
                    labels = (labels == 1).long()
                # Other models (ResNet/VGG) use the images and labels unchanged.

                optimizer.zero_grad()
                outputs = net(images)
                loss = self.loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

        with torch.no_grad():
            delta = parameters_to_vector(net.parameters()) - vec_init

        net.to('cpu')
        return delta

# LinearCNN is trained as a binary (2-class) model; every other model type
# uses all classes of the selected dataset.
use_linear_cnn = args.model_type == 'linear_cnn'

# Build the global train/test datasets. Any dataset other than 'cifar10' is
# treated as TinyImageNet.
if args.dataset == 'cifar10':
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    if use_linear_cnn:
        # Keep only classes 0 and 1. Filter on the stored `targets` labels
        # instead of iterating the dataset: iterating would load and transform
        # every image just to read its label. The resulting indices are the same.
        kept_classes = {0, 1}
        train_indices = [i for i, label in enumerate(train_dataset.targets) if label in kept_classes]
        test_indices = [i for i, label in enumerate(test_dataset.targets) if label in kept_classes]
        dataset_train_global = Subset(train_dataset, train_indices)
        dataset_test_global = Subset(test_dataset, test_indices)
        args.num_c = 2
    else:
        # All 10 CIFAR-10 classes.
        dataset_train_global = train_dataset
        dataset_test_global = test_dataset
        args.num_c = 10
else:  # TinyImageNet
    # Downscale to 32x32 so the inputs match the CIFAR-10 resolution.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # NOTE(review): download_tiny_imagenet is imported but never called here,
    # so the dataset must already be present at this path.
    tiny_imagenet_path = './data/tiny-imagenet-200'

    if use_linear_cnn:
        # Two fixed WordNet classes: albatross and monarch butterfly.
        classes = ['n02058221', 'n02279972']
        args.num_c = 2
    else:
        # All 200 classes, taken from the train/ sub-directories in sorted order.
        train_dir = os.path.join(tiny_imagenet_path, 'train')
        classes = sorted([d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d))])
        args.num_c = 200

    dataset_train_global = TinyImageNetDataset(tiny_imagenet_path, classes, transform=transform, train=True)
    dataset_test_global = TinyImageNetDataset(tiny_imagenet_path, classes, transform=transform, train=False)

# Materialise the full train/test splits as in-memory arrays once. The data
# does not depend on seed/alpha/alg/run, so it is loaded before the sweep.
# Each split is read in a single pass (images and labels together).
X_train, Y_train = next(iter(DataLoader(dataset_train_global, batch_size=len(dataset_train_global))))
X_test, Y_test = next(iter(DataLoader(dataset_test_global, batch_size=len(dataset_test_global))))
X_train, Y_train = X_train.numpy(), Y_train.numpy()
X_test, Y_test = X_test.numpy(), Y_test.numpy()

dataset_test_global = TensorDataset(torch.Tensor(X_test), torch.LongTensor(Y_test))
dataset_train_global = TensorDataset(torch.Tensor(X_train), torch.LongTensor(Y_train))
# Fixed-order loader over the whole training set, used for alignment measurement.
dataloader = DataLoader(dataset_train_global, batch_size=512, shuffle=False)

n = args.n
num_c = args.num_c

# args.eta is decayed in place during training. Remember the configured value
# so every run starts from (and names its result file after) the same eta.
base_eta = args.eta

os.makedirs("results", exist_ok=True)
os.makedirs("models_new", exist_ok=True)

for seed in seed_range:
    for alpha in alpha_range:
        for alg in algs:
            for run in run_range:
                args.eta = base_eta
                torch.manual_seed(seed)
                random.seed(seed)
                torch.backends.cudnn.deterministic = True

                print(f"Running for alpha {alpha}, seed {seed}, run {run}")
                filename = f"results/check_alignment_alg_{alg}_alpha_{alpha}_run_{run}_seed_{seed}_dataset_{args.dataset}_eta_{args.eta}.json"
                # Checkpoint written at the end of run 0 and loaded as the reference
                # model in run 1. It uses the loop's alpha; args.alpha is -1 when the
                # default alpha range is swept.
                model_path = f"models_new/{alg}_{args.model_type}_final_noniid_alpha_{alpha}_opt_300_rounds_seed_{seed}_dataset_{args.dataset}.pt"

                # Split the training set across n clients. NumPy is reseeded with a
                # fixed 0 (not `seed`), presumably so the split is identical across
                # seeds. Confirm against __getDirichletData__.
                np.random.seed(0)
                inds = __getDirichletData__(Y_train, n, alpha, num_c)

                # One TensorDataset per client.
                dataset_train = [
                    TensorDataset(torch.Tensor(X_train[ind]), torch.LongTensor(Y_train[ind]))
                    for ind in inds
                ]

                # Aggregation weights: each client's share of the training samples.
                p = np.array([len(dataset_train[i]) for i in range(n)])
                p = p / np.sum(p)

                # Per-round metrics (train_* are not filled in this script).
                data = {
                    'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [], 'misalignment': []
                }

                # Global model, chosen by model_type and initialisation algorithm.
                if args.model_type == 'linear_cnn':
                    net_glob = LinearCNN(input_dim=3072, out_channel=64, patch_num=2)
                    net_glob.apply(initialize_weights)

                    if alg == 'pretrained-cifar':
                        net_glob.load_state_dict(torch.load('linear_cnn_weights.pth', map_location='cpu'))
                    elif alg == 'pretrained-tiny-imagenet':
                        net_glob.load_state_dict(torch.load('linear_cnn_tiny_imagenet_weights.pth', map_location='cpu'))
                else:  # resnet
                    # NOTE(review): 'pretrained-cifar' and 'pretrained-tiny-imagenet'
                    # both get Resnet(pretrained=True) here; only 'random' differs.
                    if 'random' in alg:
                        net_glob = Resnet(num_classes=args.num_c, resnet_size=18, pretrained=False)
                    else:
                        net_glob = Resnet(num_classes=args.num_c, resnet_size=18, pretrained=True)
                    net_glob.fc = nn.Linear(512, args.num_c)

                d = parameters_to_vector(net_glob.parameters()).numel()
                merge_model_vector = torch.zeros(d).to(args.device)
                # Scratch copy that each client resets to the global weights and trains.
                net = copy.deepcopy(net_glob)

                # Reference ("optimal") model used for comparison in run 1.
                if args.model_type == 'linear_cnn':
                    # NOTE(review): patch_num=32 here vs patch_num=2 for net_glob. If
                    # LinearCNN's parameter shapes depend on patch_num, loading the
                    # run-0 checkpoint below fails. Confirm against utils.LinearCNN.
                    net_opt = LinearCNN(input_dim=3072, out_channel=64, patch_num=32)
                    net_opt.apply(initialize_weights)
                else:  # resnet
                    net_opt = Resnet(num_classes=args.num_c, resnet_size=18, pretrained=True)
                    net_opt.fc = nn.Linear(512, args.num_c)

                if run == 1:
                    net_opt.load_state_dict(torch.load(model_path, map_location='cpu'))

                net_glob_init = copy.deepcopy(net_glob)
                # Run 0 measures alignment against the initial global model; any
                # other run measures it against net_opt.
                reference_model = net_glob_init if run == 0 else net_opt

                for t in range(30):
                    print(f"Round {t}")

                    # Evaluate the global model on the test set every round.
                    sum_acc_test, sum_loss_test = test_img(net_glob, dataset_test_global, args, use_squared=False)
                    print(f"Testing Global Model {sum_acc_test} {sum_loss_test}")
                    data['test_loss'].append(sum_loss_test)
                    data['test_acc'].append(sum_acc_test)

                    # Measure misalignment with the reference model every round.
                    if args.model_type == 'linear_cnn':
                        _, ratio, magnitude = measure_alignment_linear_cnn(dataloader, net_glob, reference_model)
                    else:
                        _, ratio, magnitude = measure_alignment(dataloader, net_glob, reference_model)
                    print(f"Misalignment {ratio} {magnitude}")
                    data["misalignment"].append(ratio)

                    if args.json:
                        with open(filename, 'w') as f:
                            json.dump(data, f, indent=4)

                    # FedAvg round with full participation. Each client starts from
                    # the global weights, trains locally and returns its weight delta.
                    # The deltas are averaged with weights p.
                    merge_model_vector.zero_()
                    sum_p_i = 0
                    for i in range(n):
                        net.load_state_dict(net_glob.state_dict())
                        local = LocalUpdate(args=args, dataset=dataset_train[i])
                        model_vector = local.train(net)
                        # detach(): accumulating a delta that carries autograd history
                        # in place would chain graphs across clients and rounds.
                        merge_model_vector += p[i] * model_vector.detach()
                        sum_p_i += p[i]
                    merge_model_vector /= sum_p_i

                    # NOTE(review): eta is decayed every round, but LocalUpdate uses a
                    # fixed lr=0.01, and no code visible here reads the decayed value.
                    args.eta *= 0.998

                    # Apply the averaged delta to the (CPU-resident) global model.
                    with torch.no_grad():
                        w_vec_estimate = parameters_to_vector(net_glob.parameters()) + merge_model_vector.cpu()
                        vector_to_parameters(w_vec_estimate, net_glob.parameters())

                avg_test_acc = np.mean(data['test_acc'])
                avg_test_loss = np.mean(data['test_loss'])
                print(data['test_acc'])
                print(f"Average Test Accuracy: {avg_test_acc:.2f}%")
                print(f"Average Test Loss: {avg_test_loss:.4f}")

                # Run 0 saves its final model as the reference for run 1.
                net_glob.eval()
                if run == 0:
                    torch.save(net_glob.state_dict(), model_path)