import argparse
import csv
import os
import sys
import yaml
import hashlib
import io
from typing import List, Tuple
import random
import urllib

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.models.resnet18_elu_nobn import resnet18_nobn_elu
from src.models.models import linear, conv_linear
from src.data.datasets import BinaryDataset
from src.optim.gd import GradientDescent

def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Return the fraction of rows whose top-scoring column equals the target.

    Args:
        logits: Score matrix; the predicted class of each row is its argmax
            along dimension 1.
        targets: Integer class index for each row.

    Returns:
        Mean of the per-row hit indicator, as a Python float.
    """
    predicted = torch.argmax(logits, dim=1)
    hits = predicted.eq(targets).float()
    return hits.mean().item()


def evaluate_full_batch(model: nn.Module, x_all: torch.Tensor, y_all: torch.Tensor, device: torch.device, use_amp: bool, amp_dtype, loss_name: str) -> Tuple[float, float]:
    """Compute loss and accuracy of ``model`` on a whole dataset in one forward pass.

    The model is switched to eval mode and run under ``torch.inference_mode``.

    Args:
        model: Network to evaluate.
        x_all: All inputs, passed to ``model`` in a single call.
        y_all: Integer class labels (0/1 for the single-logit losses).
        device: Unused; kept for signature parity with the training step.
        use_amp: Unused; kept for signature parity with the training step.
        amp_dtype: Unused; kept for signature parity with the training step.
        loss_name: ``'cross_entropy'`` (two-logit output), or
            ``'binary_cross_entropy'`` / ``'perceptron'`` (single-logit output).

    Returns:
        ``(loss, accuracy)`` as Python floats.

    Raises:
        ValueError: If ``loss_name`` is not one of the supported losses.
    """
    model.eval()
    if loss_name == 'cross_entropy':
        criterion = nn.CrossEntropyLoss()
    elif loss_name == 'binary_cross_entropy':
        criterion = nn.BCEWithLogitsLoss()
    elif loss_name == 'perceptron':
        def criterion(logits, targets):
            # Perceptron loss: -(1/N) * sum of s*z over samples with s*z <= 0,
            # where s = 2*y - 1 maps labels {0, 1} to signs {-1, +1}.
            logits, targets = logits.reshape(-1), targets.reshape(-1)
            signs = 2 * targets - 1
            misclassified = signs * logits <= 0
            return -torch.sum(logits[misclassified] * signs[misclassified]) / len(logits)
    else:
        raise ValueError(f"Invalid loss name: {loss_name}")

    single_logit = loss_name in ('binary_cross_entropy', 'perceptron')
    with torch.inference_mode():
        logits = model(x_all)
        if single_logit:
            logits = torch.squeeze(logits, dim=-1)
            loss = criterion(logits, y_all.float())
            # Expand the scalar logit z into two-class scores [-z, z] so that
            # argmax picks class 1 exactly when z > 0 (ties go to index 0).
            logits = torch.cat([-logits.reshape(-1, 1), logits.reshape(-1, 1)], dim=1)
        else:
            loss = criterion(logits, y_all)
        acc = accuracy(logits, y_all)
        return loss.item(), acc


def train_one_epoch_full_batch(model: nn.Module, x_all: torch.Tensor, y_all: torch.Tensor, optimizer: torch.optim.Optimizer, device: torch.device, use_amp: bool, amp_dtype,
                               loss_name: str) -> Tuple[float, float]:
    """Take one full-batch optimizer step and report the pre-step loss/accuracy.

    The whole dataset goes through ``model`` in a single forward pass. The loss
    is backpropagated, and ``optimizer.step()`` is called once. The returned
    metrics come from the logits computed *before* the step.

    Args:
        model: Network to train (switched to train mode).
        x_all: All training inputs.
        y_all: Integer class labels (0/1 for the single-logit losses).
        optimizer: Optimizer over ``model``'s parameters.
        device: Unused; kept for signature compatibility.
        use_amp: Must be False; mixed precision is not implemented here.
        amp_dtype: Unused; kept for signature compatibility.
        loss_name: ``'cross_entropy'`` (two-logit output), or
            ``'binary_cross_entropy'`` / ``'perceptron'`` (single-logit output).

    Returns:
        ``(loss, accuracy)`` as Python floats.

    Raises:
        ValueError: Unknown ``loss_name``, or a single-logit loss used with a
            model whose output does not have exactly one column.
        NotImplementedError: If ``use_amp`` is True.
    """
    model.train()
    if loss_name == 'cross_entropy':
        criterion = nn.CrossEntropyLoss()
    elif loss_name == 'binary_cross_entropy':
        criterion = nn.BCEWithLogitsLoss()
    elif loss_name == 'perceptron':
        def criterion(logits, targets):
            # Perceptron loss: -(1/N) * sum of s*z over samples with s*z <= 0,
            # where s = 2*y - 1 maps labels {0, 1} to signs {-1, +1}.
            logits, targets = logits.reshape(-1), targets.reshape(-1)
            signs = 2 * targets - 1
            misclassified = signs * logits <= 0
            return -torch.sum(logits[misclassified] * signs[misclassified]) / len(logits)
    else:
        raise ValueError(f"Invalid loss name: {loss_name}")
    if use_amp:
        # Explicit exception instead of a bare assert (which -O would strip).
        raise NotImplementedError("Mixed precision is not supported for full-batch training")

    optimizer.zero_grad(set_to_none=True)
    logits = model(x_all)
    single_logit = loss_name in ('binary_cross_entropy', 'perceptron')
    if single_logit:
        if logits.dim() < 2 or logits.shape[1] != 1:
            raise ValueError(f"Loss '{loss_name}' needs a model with a single output logit, got shape {tuple(logits.shape)}")
        logits = torch.squeeze(logits, dim=-1)
        loss = criterion(logits, y_all.float())
    else:
        loss = criterion(logits, y_all)
    loss.backward()
    optimizer.step()

    # Metric bookkeeping must not extend the autograd graph of `logits`.
    with torch.no_grad():
        scores = logits.detach()
        if single_logit:
            # [-z, z] makes argmax pick class 1 exactly when z > 0.
            scores = torch.cat([-scores.reshape(-1, 1), scores.reshape(-1, 1)], dim=1)
        acc = accuracy(scores, y_all)
    return loss.item(), acc


def parse_lrs(lrs_str: str) -> List[float]:
    """Parse a comma-separated string of learning rates into floats.

    Whitespace around each entry is ignored, and empty entries (for example
    from a trailing comma) are skipped.
    """
    tokens = (piece.strip() for piece in lrs_str.split(','))
    return [float(tok) for tok in tokens if tok]


def get_model_checksum(model: nn.Module) -> str:
    """Return the hex SHA-256 digest of ``torch.save(model.state_dict())``.

    The whole state dict is serialized into memory and hashed, so the digest
    covers everything ``torch.save`` writes. NOTE(review): that includes
    serialization metadata (e.g. the device each tensor is stored on), not
    only the raw parameter values.
    """
    with io.BytesIO() as stream:
        torch.save(model.state_dict(), stream)
        payload = stream.getvalue()
    return hashlib.sha256(payload).hexdigest()


def get_model_norm(model):
    """Return the global L2 norm over all trainable parameters of ``model``.

    Each parameter's L2 norm is squared and summed as a Python float, then
    the square root is taken. Parameters with ``requires_grad=False`` are
    skipped.
    """
    squared_total = sum(
        (weight.data.norm(2).item() ** 2
         for weight in model.parameters()
         if weight.requires_grad),
        0.0,
    )
    return squared_total ** 0.5


def get_tensor_checksum(tensor: torch.Tensor) -> str:
    """Return the hex SHA-256 digest of ``tensor`` serialized with ``torch.save``.

    The tensor is copied to CPU before serialization. NOTE(review):
    ``torch.save`` writes a tensor's whole underlying storage, so for a view
    of a larger tensor the digest also covers elements outside the view.
    """
    host_copy = tensor.cpu()
    with io.BytesIO() as sink:
        torch.save(host_copy, sink)
        digest = hashlib.sha256(sink.getvalue())
    return digest.hexdigest()


def get_per_layer_checksums(model: nn.Module, verbose: bool = False) -> dict:
    """Return a ``{state_dict key: SHA-256 hex digest}`` mapping for ``model``.

    Each entry of ``model.state_dict()`` is moved to CPU, serialized with
    ``torch.save`` and hashed on its own. A reproducibility mismatch can then
    be localized to a single tensor.

    Args:
        model: Module whose state dict is hashed.
        verbose: If True, print each entry's name and shape while hashing.
            Off by default so callers get no stray stdout output.

    Returns:
        Dict preserving the state-dict key order.
    """
    checksums = {}
    for name, tensor in model.state_dict().items():
        if verbose:
            print(f"{name}: shape {tuple(tensor.shape)}")
        buffer = io.BytesIO()
        torch.save(tensor.cpu(), buffer)
        checksums[name] = hashlib.sha256(buffer.getvalue()).hexdigest()
    return checksums


def scale_model_parameters(model: nn.Module, scale_model: float):
    """Multiply every parameter of ``model`` in place by ``scale_model``.

    The update goes through ``.data``, so autograd does not record it.
    Returns the same module object for call chaining.
    """
    for weight in model.parameters():
        weight.data.mul_(scale_model)
    return model


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the full-batch gradient-descent script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default=os.path.expanduser('~/.torch/data'))
    parser.add_argument('--epochs', type=int, default=40)
    parser.add_argument('--lr', type=float, default=0.5)
    parser.add_argument('--seed', type=int, default=123)
    # NOTE(review): train_one_epoch_full_batch does not implement AMP, so
    # enabling this on CUDA makes training fail.
    parser.add_argument('--amp', action='store_true', help='Enable mixed precision (bf16 if available, else fp16)')
    parser.add_argument('--channels-last', action='store_true', help='Use channels_last memory format for model and inputs')
    parser.add_argument('--compile', action='store_true', help='Use torch.compile for the model')
    parser.add_argument('--out-dir', type=str, default=os.path.join(os.getcwd(), 'runs', 'exp'), help='Directory to save results and model')
    parser.add_argument('--loss', type=str, default='cross_entropy',
                        choices=['cross_entropy', 'binary_cross_entropy', 'perceptron'],
                        help='Loss function to use')
    parser.add_argument('--model', type=str, default='resnet18_nobn_elu',
                        choices=['resnet18_nobn_elu', 'conv_linear', 'linear'],
                        help='Model to use')
    parser.add_argument('--num-images', type=int, default=None, help='Number of images to use')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use')
    parser.add_argument('--classes', type=str, default='0,1', help='Classes to use')
    parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use')
    parser.add_argument('--no-normalize', action='store_true', help='Do not normalize the data')
    parser.add_argument('--num-threads', type=int, default=1, help='Number of threads to use')
    parser.add_argument('--test', action='store_true', help='Test mode')
    parser.add_argument('--debug', action='store_true', help='Debug mode')
    parser.add_argument('--resize-size', type=str, default=None, help='Resize size')
    parser.add_argument('--kernel-size', type=int, default=2, help='Kernel size')
    parser.add_argument('--shuffle-features', action='store_true', help='Shuffle features')
    parser.add_argument('--scale-model', type=float, default=None)
    return parser


def main():
    """Train a binary classifier with full-batch gradient descent.

    The following files are written into ``--out-dir``: ``command.txt``,
    ``config.yaml``, ``model_init.pt``, ``init_model_checksum.txt``,
    ``results_gd.csv`` (one row per epoch) and ``model_final.pt``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate argument combinations before any side effects (output files,
    # dataset downloads). parser.error is not stripped under `python -O`,
    # unlike the bare asserts it replaces.
    classes = [int(c) for c in args.classes.split(',')]
    if len(classes) != 2:
        parser.error(f"--classes must list exactly two classes, got {args.classes!r}")
    resize_size = None
    if args.resize_size is not None:
        resize_size = [int(x) for x in args.resize_size.split(',')]
        if len(resize_size) != 2:
            parser.error(f"--resize-size must be 'H,W', got {args.resize_size!r}")
    if args.test and args.shuffle_features:
        parser.error("--shuffle-features cannot be combined with --test")
    out_dir = args.out_dir
    if os.path.exists(out_dir) and not args.debug:
        parser.error(f"Output directory {out_dir} already exists")

    torch.set_num_threads(args.num_threads)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    device = torch.device(args.device)
    use_amp = args.amp and device.type == 'cuda'
    if device.type == 'cuda':
        # NOTE: cuDNN benchmark mode selects conv algorithms by timing, which
        # PyTorch documents as a possible source of run-to-run nondeterminism.
        torch.backends.cudnn.benchmark = True
        amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        amp_dtype = None

    os.makedirs(out_dir, exist_ok=args.debug)
    with open(os.path.join(out_dir, 'command.txt'), 'w') as f:
        f.write(' '.join(sys.argv) + '\n')
    with open(os.path.join(out_dir, 'config.yaml'), 'w') as cf:
        yaml.dump(vars(args), cf, default_flow_style=False, indent=2)

    # Create datasets; train and test share every option except the split and
    # feature shuffling.
    dataset_kwargs = dict(
        dataset_name=args.dataset,
        root=args.data_dir,
        download=True,
        num_images=args.num_images,
        classes=classes,
        no_normalize=args.no_normalize,
        resize_size=resize_size,
    )
    train_ds = BinaryDataset(train=True, shuffle_features=args.shuffle_features, **dataset_kwargs)
    x_test = y_test = None
    if args.test:
        test_ds = BinaryDataset(train=False, **dataset_kwargs)
        x_test = test_ds.data.to(device)
        y_test = test_ds.targets.to(device)
    x_train = train_ds.data.to(device)
    y_train = train_ds.targets.to(device)

    if args.channels_last and device.type == 'cuda':
        x_train = x_train.contiguous(memory_format=torch.channels_last)
        # Without --test there is no test tensor to convert.
        if x_test is not None:
            x_test = x_test.contiguous(memory_format=torch.channels_last)

    # cross_entropy uses two output logits; the binary losses use a single logit.
    if args.loss == 'cross_entropy':
        num_classes = 2
    elif args.loss in ('binary_cross_entropy', 'perceptron'):
        num_classes = 1
    else:
        raise ValueError(f"Invalid loss name: {args.loss}")

    if args.model == 'resnet18_nobn_elu':
        model = resnet18_nobn_elu(num_classes=num_classes)
    else:
        num_features = int(np.prod(x_train.shape[1:]))
        if args.model == 'conv_linear':
            model = conv_linear(num_classes=num_classes, num_features=num_features, kernel_size=args.kernel_size)
        elif args.model == 'linear':
            model = linear(num_classes=num_classes, num_features=num_features)
        else:
            raise ValueError(f"Invalid model name: {args.model}")
    model = model.to(device)

    if args.scale_model is not None:
        model = scale_model_parameters(model, args.scale_model)

    if args.channels_last and device.type == 'cuda':
        model = model.to(memory_format=torch.channels_last)
    if args.compile and hasattr(torch, 'compile'):
        model = torch.compile(model)  # type: ignore

    # Record the initial state and data fingerprints for reproducibility checks.
    torch.save(model.state_dict(), os.path.join(out_dir, 'model_init.pt'))
    init_checksum = get_model_checksum(model)
    x_train_checksum = get_tensor_checksum(x_train)
    y_train_checksum = get_tensor_checksum(y_train)
    per_layer_checksums = get_per_layer_checksums(model)

    with open(os.path.join(out_dir, 'init_model_checksum.txt'), 'w') as f:
        f.write(f"Initial model SHA256: {init_checksum}\n")
        f.write(f"x_train SHA256: {x_train_checksum}\n")
        f.write(f"y_train SHA256: {y_train_checksum}\n")
        f.write(f"Seed: {args.seed}\n")
        f.write("\n" + "=" * 80 + "\n")
        f.write("Per-layer SHA256 checksums:\n")
        f.write("=" * 80 + "\n")
        for layer_name, checksum in per_layer_checksums.items():
            f.write(f"{layer_name}: {checksum}\n")

    results_csv = os.path.join(out_dir, 'results_gd.csv')
    with open(results_csv, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['lr', 'epoch', 'train_loss', 'train_acc', 'test_loss', 'test_acc', 'model_norm'])

        optimizer = GradientDescent(model.parameters(), lr=args.lr)

        pbar = tqdm(range(1, args.epochs + 1), desc=f'GD lr={args.lr}')
        for epoch in pbar:
            train_loss, train_acc = train_one_epoch_full_batch(model, x_train, y_train, optimizer, device, use_amp=use_amp, amp_dtype=amp_dtype, loss_name=args.loss)
            if args.test:
                test_loss, test_acc = evaluate_full_batch(model, x_test, y_test, device, use_amp=use_amp, amp_dtype=amp_dtype, loss_name=args.loss)
            else:
                # Placeholders so every CSV row has the same columns.
                test_loss, test_acc = 0.0, 0.0
            model_norm = get_model_norm(model)
            pbar.set_postfix(train_loss=f"{train_loss:.4f}", train_acc=f"{train_acc:.4f}", test_acc=f"{test_acc:.4f}", model_norm=f"{model_norm:.4f}")
            writer.writerow([args.lr, epoch, f"{train_loss:.6f}", f"{train_acc:.6f}", f"{test_loss:.6f}", f"{test_acc:.6f}", f"{model_norm:.6f}"])
            # Flush each epoch so partial results survive an interrupted run.
            f.flush()

        torch.save(model.state_dict(), os.path.join(out_dir, 'model_final.pt'))


if __name__ == '__main__':
    # A plain `import urllib` does not guarantee the `urllib.error` submodule is
    # loaded, so import it explicitly for the except clause below to resolve.
    import urllib.error

    try:
        main()
    except urllib.error.URLError as e:
        print(e, file=sys.stderr)
        print("Error: Could not download dataset. Please check your internet connection and try again.", file=sys.stderr)
        # NOTE(review): the suggested hack disables TLS certificate verification
        # process-wide; only acceptable as a temporary local workaround.
        print("Try to use this hack: import ssl; ssl._create_default_https_context = ssl._create_unverified_context", file=sys.stderr)
        sys.exit(1)
