import numpy as np
import torch, torchvision
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import math
from typing import *
from PIL import Image
import random
import copy
import pickle
import argparse
import sys
from torch.autograd import Variable, Function
from sklearn.model_selection import StratifiedKFold, train_test_split, ShuffleSplit
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, Subset, DataLoader, random_split, ConcatDataset
import h5py

import datetime, time
# UTC timestamp tag with ':' stripped so it can be embedded in file names.
# datetime.utcnow() is deprecated since Python 3.12; take an aware "now" in
# UTC and drop the tzinfo so the resulting string format is unchanged.
date_time = (
    datetime.datetime.now(datetime.timezone.utc)
    .replace(tzinfo=None)
    .isoformat()
    .replace(":", "")
)

# Command-line interface of the experiment.
parser = argparse.ArgumentParser(description='Curriculum Learning')
parser.add_argument('--dataset', default='mnist', type=str, help='mnist')
# Restrict --arch to the model classes defined in this script so that a typo
# is rejected at parse time instead of surfacing later as an unbound `model`.
parser.add_argument('--arch', default='linear', type=str,
                    choices=['linear', 'twolayernet', 'cnn'],
                    help='linear, twolayernet, cnn')
parser.add_argument('--T', default=1000, type=int, help='epoch')
parser.add_argument('--K', default=10, type=int, help='number of tasks, curriculum step')
parser.add_argument('--sigma', default=1.0, type=float, help='noise level')
parser.add_argument('--seed', default=0, type=int, help='seed')
args = parser.parse_args()


def set_seed(seed):
    """Seed every random number generator the script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus all CUDA devices)
    with ``seed`` and switches cuDNN to deterministic, non-benchmarking mode.

    Returns:
        A callable taking a ``worker_id`` that re-seeds NumPy and ``random``
        with ``seed + worker_id``; intended as a DataLoader ``worker_init_fn``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    def _reseed_worker(worker_id):
        # Give each DataLoader worker its own, but reproducible, seed.
        offset_seed = seed + worker_id
        np.random.seed(offset_seed)
        random.seed(offset_seed)

    return _reseed_worker
    
# Seed all RNGs once, up front, and keep the returned hook so DataLoaders can
# pass it as worker_init_fn.
seed = args.seed
seed_worker = set_seed(seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class linear(nn.Module):
    """Single affine layer mapping a flattened input of size ``d`` to 10 logits."""

    def __init__(self, d=784):
        super().__init__()
        self.net = nn.Linear(in_features=d, out_features=10, bias=True)

    def forward(self, x):
        """Flatten each sample of ``x`` to a vector and return the logits."""
        # view() requires a contiguous layout, hence the explicit contiguous().
        flat = x.contiguous().view(x.size(0), -1)
        return self.net(flat)

class twolayernet(nn.Module):
    """Two-layer MLP: ``d`` inputs -> ``width`` ReLU units -> 10 logits."""

    def __init__(self, d=784, width=100):
        super(twolayernet, self).__init__()
        self.fc1 = nn.Linear(d, width, bias = True)
        self.fc2 = nn.Linear(width, 10, bias = True)

    def forward(self, x):
        """Flatten each sample of ``x`` to a vector and return the logits."""
        # view() raises on non-contiguous tensors (e.g. after a transpose);
        # make the layout contiguous first, exactly as `linear.forward` does.
        x = x.contiguous()
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class cnn(nn.Module):
    """Small ConvNet: two 5x5 conv + 2x2 max-pool stages, then a two-layer head.

    The head expects 320 features per sample, which is what the conv stack
    produces for a 1x28x28 input (20 channels of 4x4).
    """

    def __init__(self, width=100):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.fc1 = nn.Linear(320, width)
        self.fc2 = nn.Linear(width, 10)

    def forward(self, x):
        """Return the 10 class logits for a batch of single-channel images."""
        features = x
        for conv in (self.conv1, self.conv2):
            # conv -> 2x2 max-pool -> ReLU
            features = F.relu(F.max_pool2d(conv(features), 2))
        flat = features.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)
        
class NoisyDataset(torch.utils.data.Dataset):
    """Dataset wrapper that perturbs every sample with fresh Gaussian noise.

    Each access draws new noise of standard deviation ``noise_std``, adds it
    to the float-converted sample and clamps the result to ``[0, 1]``.
    Labels are passed through unchanged.
    """

    def __init__(self, base_dataset, noise_std):
        self.base, self.noise_std = base_dataset, noise_std

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        sample, label = self.base[idx]
        sample = sample.float()
        # Noise is re-sampled on every call, so repeated reads of the same
        # index generally return different tensors.
        perturbed = sample + self.noise_std * torch.randn_like(sample)
        return perturbed.clamp(0.0, 1.0), label


def evaluate(model, dataloader):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled while iterating. Batches are moved to the module-level
    ``device``. Raises ZeroDivisionError if the loader yields no samples.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            # Predicted class = index of the largest logit.
            predictions = model(inputs).argmax(dim=1)
            n_correct += predictions.eq(targets).sum().item()
            n_seen += targets.size(0)
    return n_correct / n_seen
    

def train_model(model, train_loader, val_loader, epochs, lr, l2_reg, previous_params=None):
    """Train ``model`` with SGD and keep the weights with the best validation accuracy.

    Args:
        model: Network to train; moved to the module-level ``device``.
        train_loader: Iterable of ``(x, y)`` training batches.
        val_loader: Loader passed to ``evaluate`` after every epoch.
        epochs: Number of passes over ``train_loader``.
        lr: SGD learning rate (weight decay is fixed at 1e-4).
        l2_reg: Coefficient of the proximity penalty described below.
        previous_params: Optional sequence of tensors aligned with
            ``model.parameters()``. When given (and non-empty), the squared
            distance between the current parameters and these tensors,
            scaled by ``l2_reg``, is added to the cross-entropy loss.

    Returns:
        ``(model, best_val_acc)`` where ``model`` carries the state of the
        epoch with the highest validation accuracy. If ``epochs`` is not
        positive, the model is returned as given and ``best_val_acc`` keeps
        its initial sentinel value of -100.
    """
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    model.train()

    best_val_acc = -100
    best_model = None

    for epoch in range(epochs):
        # evaluate() below leaves the model in eval mode, so re-enable
        # training mode at the start of every epoch.
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            if previous_params:
                # Penalise drifting away from the previous task's solution.
                reg = sum(((p - pp)**2).sum() for p, pp in zip(model.parameters(), previous_params))
                loss += l2_reg * reg
            loss.backward()
            optimizer.step()

        val_acc = evaluate(model, val_loader)
        if (epoch+1)%10==0:
            print("epoch:", epoch, "val_acc:", val_acc)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Snapshot by value; state_dict() tensors alias the live weights.
            best_model = {k: v.clone().detach() for k, v in model.state_dict().items()}

    # With no completed epoch there is no snapshot to restore; calling
    # load_state_dict(None) would raise, so leave the model untouched.
    if best_model is not None:
        model.load_state_dict(best_model)
    return model, best_val_acc



import os  # only needed here, to create the log directory

# --- Experiment configuration ---
T = args.T                 # epochs per train_model() call
N_TASKS = args.K           # value of --K
BATCH_SIZE = 128
VAL_RATIO = 0.2            # fraction of each digit group held out for validation
L2_REG_LIST = [1e-5,1e-4,1e-3,1e-2,1e-1,1,10]   # proximity-penalty grid
LR_LIST = [1e-1,1e-2,1e-3]                      # learning-rate grid



# --- Logging: from here on, everything printed goes to the log file ---
name = 'new_noisy_comb_'+str(args.dataset)+'_'+str(args.arch)+'_T'+str(T)+'_Sigma'+str(args.sigma)+'_'+str(seed)
log_filename = 'splitdata/'+name+'.txt'
# open(..., 'w') does not create missing parent directories; make sure the
# log directory exists so a fresh checkout does not fail here.
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
log = open(log_filename, 'w')
sys.stdout = log


    
# Load the MNIST train and test splits (downloaded into ./data on first use);
# images are converted to tensors by ToTensor.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Digit classes introduced by each curriculum task, in training order.
digit_groups = [
    [4, 7, 9],     # Group 1
    [3, 5, 8],     # Group 2
    [0, 1, 2, 6],  # Group 3
]

# Filter dataset by digits
def filter_digits(dataset, digits):
    """Return a ``Subset`` of ``dataset`` restricted to the labels in ``digits``.

    Args:
        dataset: Object exposing per-sample class labels as ``dataset.targets``
            (a tensor or any sequence of integer-like values).
        digits: Iterable of class labels to keep.

    Returns:
        ``Subset(dataset, indices)`` with the matching indices in ascending
        order.
    """
    wanted = {int(d) for d in digits}
    targets = dataset.targets
    # Convert a tensor/array of labels to plain Python numbers once, instead
    # of running a tensor comparison per sample and per candidate digit.
    labels = targets.tolist() if hasattr(targets, "tolist") else targets
    indices = [i for i, label in enumerate(labels) if int(label) in wanted]
    return Subset(dataset, indices)

# Split each digit group into train/val parts and build the curriculum:
# the training set of task k is the noisy data of groups 1..k combined.
train_loaders = []
train_subsets, val_subsets = [], []

for digits in digit_groups:
    group_data = filter_digits(train_dataset, digits)
    n_val = int(VAL_RATIO * len(group_data))
    n_train = len(group_data) - n_val

    # Fixed split seed: the train/val partition does not depend on --seed.
    split_rng = torch.Generator().manual_seed(42)
    train_part, val_part = random_split(group_data, [n_train, n_val], generator=split_rng)
    val_subsets.append(NoisyDataset(val_part, args.sigma))

    task_data = NoisyDataset(train_part, args.sigma)
    if train_subsets:
        # Previous entry is already cumulative; append this group to it.
        task_data = ConcatDataset([train_subsets[-1], task_data])
    train_subsets.append(task_data)
    train_loaders.append(DataLoader(task_data, batch_size=BATCH_SIZE, shuffle=True))

# One noisy test subset per digit, indexed by the digit itself.
noisy_test_subsets = [
    NoisyDataset(filter_digits(test_dataset, [digit]), args.sigma)
    for digit in range(10)
]

# NOTE(review): the entries of train_subsets are cumulative, so concatenating
# all of them repeats the earlier groups in the direct-training data. Confirm
# this is intended (e.g. to match the curriculum's total step count) rather
# than a leftover from a version with disjoint subsets.
train_dataset = ConcatDataset(train_subsets)
train_fullloader = DataLoader(
    train_dataset,
    BATCH_SIZE,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=torch.Generator().manual_seed(seed),
)
# Validation and test always cover every digit group.
val_dataset = ConcatDataset(val_subsets)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(ConcatDataset(noisy_test_subsets), BATCH_SIZE)

# Curriculum phase: train one task after another. For every task a grid over
# (l2_reg, lr) is searched; each candidate starts from the previous task's
# best weights and is pulled towards them by the proximity penalty. The
# candidate with the highest validation accuracy seeds the next task.
prev_best_params = None
prev_best_model = None
for t in range(len(digit_groups)):
    # Report progress against the number of digit groups actually trained on;
    # N_TASKS comes from --K and need not match it.
    print(f"\n=== Task {t+1}/{len(digit_groups)} ===")
    # Start below any attainable accuracy so the first configuration is always
    # recorded and chosen_* / best_* are guaranteed to be bound afterwards.
    best_acc = float('-inf')
    if t>0:
        L2_REG = L2_REG_LIST
    else:
        # No previous solution to stay close to on the first task.
        L2_REG = [0]
    for l2_reg in L2_REG:
        for lr in LR_LIST:
            if args.arch=="linear":
                model = linear().to(device)
            elif args.arch=="twolayernet":
                model = twolayernet().to(device)
            elif args.arch=="cnn":
                model = cnn().to(device)
            else:
                raise ValueError(f"unknown --arch {args.arch!r}")
            if prev_best_model is not None:
                # Warm-start from the best model of the previous task.
                model.load_state_dict(prev_best_model)

            trained_model, best_val_acc = train_model(model, train_loaders[t], val_loader, T, lr, l2_reg, previous_params=prev_best_params)
            print(f"L2={l2_reg}, LR={lr} => Acc={best_val_acc:.4f}")
            if best_val_acc > best_acc:
                chosen_reg = l2_reg
                chosen_lr = lr
                best_acc = best_val_acc
                # Copy by value: the model object is discarded on the next iteration.
                best_model_state = {k: v.clone().detach() for k, v in trained_model.state_dict().items()}
                best_params = [p.clone().detach() for p in trained_model.parameters()]

    prev_best_model = best_model_state
    prev_best_params = best_params
    print("Task:", t, "chosen L2 REG:", chosen_reg, "chosen LR:", chosen_lr)

# Final evaluation on task N (shared test set): rebuild the architecture and
# load the best weights of the last curriculum task.
if args.arch=="linear":
    final_model = linear().to(device)
elif args.arch=="twolayernet":
    final_model = twolayernet().to(device)
elif args.arch=="cnn":
    final_model = cnn().to(device)
else:
    raise ValueError(f"unknown --arch {args.arch!r}")
final_model.load_state_dict(prev_best_model)
test_acc = evaluate(final_model, test_loader)
print(f"\n Curriculum trained model accuracy on task N: {test_acc:.4f}")

# Per-digit breakdown. `test_loader_digits` was referenced here without ever
# being defined (NameError); build one loader per digit from the per-digit
# noisy test subsets, whose list index is the digit.
test_loader_digits = [DataLoader(digit_subset, BATCH_SIZE) for digit_subset in noisy_test_subsets]
for i, digit_loader in enumerate(test_loader_digits):
    test_acc_sep = evaluate(final_model, digit_loader)
    print("\n Curriculum trained model accuracy on digit ", i, ": test_acc:", test_acc_sep)


print(f"Directly trained on noisy N")
# === Direct training on task N
# Baseline without curriculum: train from scratch on train_fullloader for
# each learning rate and report the test accuracy of the run with the best
# validation accuracy.
best_acc_direct = 0
# Start below any attainable accuracy so the first run is always selected;
# otherwise best_model_state could silently remain the curriculum snapshot
# assigned earlier in the script.
best_val_acc = float('-inf')
for lr in LR_LIST:
    if args.arch=="linear":
        model = linear().to(device)
    elif args.arch=="twolayernet":
        model = twolayernet().to(device)
    elif args.arch=="cnn":
        model = cnn().to(device)
    else:
        raise ValueError(f"unknown --arch {args.arch!r}")
    trained, val_acc = train_model(model, train_fullloader, val_loader, T, lr, l2_reg=0.0)
    acc = evaluate(trained, test_loader)
    print(f"LR={lr} => Test Acc={acc:.4f}")
    # Model selection uses validation accuracy; test accuracy is only reported.
    if val_acc>best_val_acc:
        best_val_acc = val_acc
        best_acc_direct = acc
        best_model_state = {k: v.clone().detach() for k, v in trained.state_dict().items()}

# Rebuild the selected baseline model from its saved weights.
if args.arch=="linear":
    final_model = linear().to(device)
elif args.arch=="twolayernet":
    final_model = twolayernet().to(device)
elif args.arch=="cnn":
    final_model = cnn().to(device)
else:
    raise ValueError(f"unknown --arch {args.arch!r}")
final_model.load_state_dict(best_model_state)



# === Print results
print("\n=== Final Comparison ===")
print(f"Multi-task trained model       : {test_acc:.4f}")
print(f"Directly trained on noisy N    : {best_acc_direct:.4f}")
