import numpy as np
import torch, torchvision
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import math
from typing import *
from PIL import Image
import random
import copy
import pickle
import argparse
import sys
from torch.autograd import Variable, Function
from sklearn.model_selection import StratifiedKFold, train_test_split, ShuffleSplit
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, Subset, DataLoader, random_split
import h5py

import datetime, time
# UTC timestamp in ISO-8601 form with the colons removed.
# datetime.utcnow() is deprecated since Python 3.12, so take an aware "now"
# and drop the tzinfo: the resulting naive-UTC string has the same format.
date_time = (
    datetime.datetime.now(datetime.timezone.utc)
    .replace(tzinfo=None)
    .isoformat()
    .replace(":", "")
)

# Command-line interface. (gamma1, std1) parameterise the first ("easy")
# distribution and (gamma2, std2) the second ("hard") one used further down.
# NOTE: the arguments are parsed at import time, so importing this module
# consumes sys.argv.
parser = argparse.ArgumentParser(description='Curriculum Learning')
parser.add_argument('--N', default=2000, type=int, help='number of samples')
parser.add_argument('--d', default=100, type=int, help='dimension')
parser.add_argument('--T', default=2000, type=int, help='epoch')
parser.add_argument('--gamma1', default=3.0, type=float, help='large margin')
parser.add_argument('--gamma2', default=0.5, type=float, help='small margin')
parser.add_argument('--std1', default=0.05, type=float, help='small standard deviation')
parser.add_argument('--std2', default=0.3, type=float, help='large standard deviation')
args = parser.parse_args()


def set_seed(seed):
    """Seed every RNG used by the script and request deterministic cuDNN.

    Seeds, in order: Python's ``random``, NumPy's global generator, and the
    PyTorch CPU / current-GPU / all-GPU generators. It then turns cuDNN
    determinism on and cuDNN benchmarking off.

    Args:
        seed (int): base seed applied to every generator.

    Returns:
        A callable ``seed_worker(worker_id)`` that reseeds NumPy and
        ``random`` with ``seed + worker_id`` (intended for DataLoader
        workers).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    def seed_worker(worker_id):
        # Give each worker its own, reproducible stream.
        derived = seed + worker_id
        np.random.seed(derived)
        random.seed(derived)

    return seed_worker
    


# Device string used for the models and tensors: the GPU when CUDA is
# available, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class LinearClassifier(nn.Module):
    """Single affine layer mapping ``d`` input features to one logit."""

    def __init__(self, d):
        """Build the classifier.

        Args:
            d (int): number of input features.
        """
        super().__init__()
        # The attribute name determines the state_dict keys
        # ("linear.weight" / "linear.bias").
        self.linear = nn.Linear(in_features=d, out_features=1)

    def forward(self, x):
        """Return the raw logit(s) for ``x``; the last dimension becomes 1."""
        logits = self.linear(x)
        return logits


def hinge_loss(y_pred, y_true):
    """Mean hinge loss of raw scores against labels given as 0 / 1.

    Args:
        y_pred: tensor of real-valued scores.
        y_true: tensor of labels in {0, 1}; it is combined element-wise
            (with broadcasting) with ``y_pred``.

    Returns:
        Scalar tensor ``mean(max(0, 1 - s * y_pred))`` with
        ``s = 2 * y_true - 1``.
    """
    signs = 2 * y_true - 1  # relabel 0 -> -1 and 1 -> +1
    margins = 1 - signs * y_pred
    return margins.clamp(min=0).mean()
    

def generate_data(N, gamma=3, std=0.5, d=100, seed=42):
    """Sample a two-class isotropic Gaussian classification problem.

    Class 0 is centred at the origin and class 1 at ``gamma * e_0`` (the two
    centres differ only in the first coordinate). Every sample is its class
    centre plus ``std``-scaled standard-normal noise in all ``d`` dimensions.

    The draws come from a private ``RandomState(seed)``. That produces the
    same sample stream as seeding NumPy's global generator with ``seed``, but
    leaves the global generator untouched, so a call no longer overrides
    seeding done elsewhere.

    Args:
        N (int): total number of samples.
        gamma (float): distance between the two class centres along
            feature 0.
        std (float): standard deviation of the per-coordinate Gaussian noise.
        d (int): feature dimension (default 100).
        seed (int): seed for the private random generator.

    Returns:
        X: (N, d) float array; the first ``N // 2`` rows are class 0, the
            next ``N // 2`` rows are class 1.
        y: (N,) integer labels (0 or 1), in the same order as ``X``.

        When ``N`` is odd, one additional class-0 sample is appended last.
    """
    rng = np.random.RandomState(seed)
    n_per_class = N // 2

    # Class centres: identical except for a shift of gamma on feature 0.
    center_0 = np.zeros(d)
    center_1 = np.zeros(d)
    center_1[0] = gamma

    # Draw order (class 0 first, then class 1) fixes the sample stream.
    X0 = center_0 + std * rng.randn(n_per_class, d)
    X1 = center_1 + std * rng.randn(n_per_class, d)

    X = np.vstack([X0, X1])
    # Explicit integer dtype so the labels stay integral even when empty.
    y = np.concatenate([
        np.zeros(n_per_class, dtype=int),
        np.ones(n_per_class, dtype=int),
    ])

    # Odd N: top up with one extra sample, which always belongs to class 0.
    if N % 2 == 1:
        extra_sample = center_0 + std * rng.randn(1, d)
        X = np.vstack([X, extra_sample])
        y = np.append(y, 0)

    return X, y


def train(model, X_train, y_train, X_val, y_val, lr=0.1, epochs=1000, old_weights=None, lambda_reg=0.0):
    """Full-batch SGD on the hinge loss, keeping the best-validation weights.

    Each epoch performs one gradient step on the whole training set and then
    measures validation accuracy. The weights of the first epoch that reaches
    the highest validation accuracy are restored before returning.

    Args:
        model: module producing one score per sample; it is updated in place.
        X_train, y_train: training inputs and 0/1 labels.
        X_val, y_val: validation inputs and 0/1 labels.
        lr (float): SGD learning rate.
        epochs (int): number of full-batch steps.
        old_weights: optional sequence of tensors aligned with
            ``model.parameters()``; used as the anchor of the penalty.
        lambda_reg (float): weight of the penalty, i.e. the squared L2
            distance between the current parameters and ``old_weights``.
            The penalty is active only when ``old_weights`` is given and
            ``lambda_reg > 0``.

    Returns:
        (model, best_val_acc): the same model object and the best validation
        accuracy observed (0 if no epoch scored above 0).
    """
    # To train with the logistic loss instead, use nn.BCEWithLogitsLoss() here.
    criterion = hinge_loss
    optimizer = optim.SGD(model.parameters(), lr=lr)
    use_penalty = old_weights is not None and lambda_reg > 0.0

    best_val_acc = 0
    best_model_state = None

    for epoch in range(epochs):
        # --- one full-batch gradient step ---
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(X_train).squeeze(), y_train)

        if use_penalty:
            # Pull the parameters towards the anchor weights.
            drift = 0.0
            for current, anchor in zip(model.parameters(), old_weights):
                drift = drift + ((current - anchor) ** 2).sum()
            loss = loss + lambda_reg * drift

        loss.backward()
        optimizer.step()

        # --- validation accuracy after the step ---
        model.eval()
        with torch.no_grad():
            scores = model(X_val).squeeze()
            # Strictly positive score -> class 1, otherwise class 0.
            predicted = (scores > 0).int()
            val_acc = (predicted == y_val.int()).float().mean().item()

        # Strict improvement only, so ties keep the earliest snapshot.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = {
                key: tensor.clone().detach()
                for key, tensor in model.state_dict().items()
            }

        # Report on every 10th epoch (zero-based indices 9, 19, ...).
        if epoch % 10 == 9:
            print(f"Epoch {epoch}: train loss = {loss.item():.4f}, val acc = {val_acc:.4f}")

    # Roll back to the best snapshot, if any epoch produced one.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, best_val_acc

def evaluate(model, X_test, y_test):
    """Return the classification accuracy of ``model`` on a labelled set.

    The model is switched to eval mode. A strictly positive score is read as
    class 1; a zero or negative score is read as class 0.

    Args:
        model: module producing one score per sample.
        X_test: input tensor.
        y_test: tensor of 0/1 labels.

    Returns:
        float: fraction of samples whose predicted class equals the label.
    """
    model.eval()
    with torch.no_grad():
        scores = model(X_test).squeeze()
        predicted = (scores > 0).int()
        hits = predicted == y_test.int()
        return hits.float().mean().item()
    
    
    
# Imports used only by the experiment driver below.
import contextlib
import os

# ---------------------------------------------------------------------------
# Experiment driver.
#
# For 10 seeds, compares
#   * a baseline: a linear classifier trained on the hard distribution B, with
#     the learning rate picked on a validation set, and
#   * a curriculum: train on the easy distribution A first, then continue on B
#     with an L2 penalty towards the stage-1 weights, with (lr, reg) picked on
#     the same validation set.
# All output is written to a log file under marginlog/.
#
# NOTE(review): both the baseline and curriculum stage 2 train on X_B only;
# the word "mixture" in the log messages is kept unchanged so existing logs
# remain comparable.
# ---------------------------------------------------------------------------
name = 'hinge_N'+str(args.N)+'gamma'+str(args.gamma1)+'std'+str(args.std1)+'gamma'+str(args.gamma2)+'std'+str(args.std2)+'_d'+str(args.d)+'_T'+str(args.T)
log_filename = 'marginlog/'+name+'.txt'
# open() does not create missing directories, so make sure the log dir exists.
os.makedirs(os.path.dirname(log_filename), exist_ok=True)


# Problem configuration taken from the command line.
N = args.N
d = args.d
gamma1 = args.gamma1
gamma2 = args.gamma2
std1 = args.std1
std2 = args.std2

# Hyper-parameter grids searched with the validation set.
lr_list = [0.001,0.01, 0.05, 0.1, 0.5, 1.0]
reg_list = [1e-5,0.0001,0.001,0.01,0.1,1.0,10.0]


def _to_tensor(array):
    """Convert a NumPy array to a float32 tensor on the selected device."""
    return torch.tensor(array, dtype=torch.float32).to(device)


curr_testacc = []
baseline_testacc = []

# Redirect every print below into the log file; the context managers restore
# stdout and close the file even if a run raises.
with open(log_filename, 'w') as log, contextlib.redirect_stdout(log):
    for seed in range(10):
        seed_worker = set_seed(seed)

        # Each split needs its own seed. Reusing one seed for every split
        # made the validation and test sets identical and made them share
        # noise with the training data, so the test accuracy was not held out.
        seed_A, seed_B, seed_val, seed_test = (4 * seed + k for k in range(4))

        # Distribution A (easy): large margin, small variance
        X_A, y_A = generate_data(int(N*0.5), gamma=gamma1, std=std1, d=d, seed=seed_A)

        # Distribution B (hard): small margin, larger variance
        X_B, y_B = generate_data(int(N*0.5), gamma=gamma2, std=std2, d=d, seed=seed_B)

        # Validation and test sets are drawn from the hard distribution.
        X_val, y_val = generate_data(int(N*0.2), gamma=gamma2, std=std2, d=d, seed=seed_val)
        X_test, y_test = generate_data(int(N*0.2), gamma=gamma2, std=std2, d=d, seed=seed_test)

        # Convert to PyTorch tensors
        X_A = _to_tensor(X_A)
        y_A = _to_tensor(y_A)
        X_B = _to_tensor(X_B)
        y_B = _to_tensor(y_B)
        X_val_torch = _to_tensor(X_val)
        y_val_torch = _to_tensor(y_val)
        X_test_torch = _to_tensor(X_test)
        y_test_torch = _to_tensor(y_test)

        # ===== Baseline: train on distribution B only =====
        # Start below any reachable accuracy so the first configuration is
        # always recorded and chosen_lr can never be left unbound.
        best_val_acc = -1.0
        best_acc = 0
        chosen_lr = None
        for lr in lr_list:
            model = LinearClassifier(d).to(device)
            trained_model, val_acc = train(model, X_B, y_B, X_val_torch, y_val_torch, lr=lr, epochs=args.T, old_weights=None, lambda_reg=0.0)
            test_acc = evaluate(trained_model, X_test_torch, y_test_torch)
            print(f"Baseline (train on mixture): Test accuracy = {test_acc:.4f}")
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_acc = test_acc
                chosen_lr = lr
        baseline_test_acc = best_acc
        print("Baseline on mixture select lr: ", chosen_lr, "and test accuracy is ", best_acc)
        print("\n")

        # ===== Curriculum: first A, then B =====
        # Step 1: train on easy data only (Distribution A); keep the weights
        # of the learning rate with the best validation accuracy.
        best_val_acc = -1.0
        stage1_state = None
        best_params = None
        for lr in lr_list:
            model_curriculum = LinearClassifier(d).to(device)
            trained_model, val_acc = train(model_curriculum, X_A, y_A, X_val_torch, y_val_torch, lr=lr, epochs=args.T, old_weights=None, lambda_reg=0.0)
            print(f"LR={lr} => Acc={val_acc:.4f}\n")
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                stage1_state = {k: v.clone().detach() for k, v in trained_model.state_dict().items()}
                best_params = [p.clone().detach() for p in trained_model.parameters()]

        # Step 2: continue training on distribution B, penalising the
        # distance to the stage-1 weights.
        best_val_acc = -1.0
        best_acc = 0
        best_model_state = None
        chosen_lr = None
        chosen_reg = None
        for lr in lr_list:
            for reg in reg_list:
                # Every configuration restarts from the stage-1 weights;
                # train() updates its model in place, so reusing one model
                # would chain the configurations together.
                model_curriculum = LinearClassifier(d).to(device)
                model_curriculum.load_state_dict(stage1_state)
                trained_model, val_acc = train(model_curriculum, X_B, y_B, X_val_torch, y_val_torch, lr=lr, epochs=args.T, old_weights=best_params, lambda_reg=reg)
                test_acc = evaluate(trained_model, X_test_torch, y_test_torch)
                print(f"REG={reg}, LR={lr} => Val Acc={val_acc:.4f}, Test Acc={test_acc:.4f}\n")
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_acc = test_acc
                    best_model_state = {k: v.clone().detach() for k, v in trained_model.state_dict().items()}
                    chosen_lr = lr
                    chosen_reg = reg

        print("Curriculum select lr: ", chosen_lr, "select reg: ", chosen_reg, "and test accuracy is ", best_acc)
        # Re-evaluate the selected weights on the test set.
        model_curriculum = LinearClassifier(d).to(device)
        model_curriculum.load_state_dict(best_model_state)
        test_acc_curriculum = evaluate(model_curriculum, X_test_torch, y_test_torch)
        baseline_testacc.append(baseline_test_acc)
        curr_testacc.append(test_acc_curriculum)

    # Summary over all seeds.
    print("Curriculum (easy first + regularized full training): Test accuracy=", np.mean(curr_testacc))
    print("Curriculum: Test std=", np.std(curr_testacc))
    print("Baseline (train on mixture): Test accuracy=", np.mean(baseline_testacc))
    print("Baseline: Test std=", np.std(baseline_testacc))
