import torch
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets,transforms
from torchvision.transforms import ToTensor
from torch.nn.parallel import DataParallel

from model import resnet18,resnet34,resnet50,resnet101,resnet152
import torchvision.models as models

from numpy import corrcoef

import torchmetrics

import ssl
# SECURITY NOTE(review): this replaces the stdlib's default HTTPS context factory
# with one that skips certificate verification (the PEP 476 opt-out), so any
# urllib/http.client HTTPS connection in this process that relies on the default
# context is no longer authenticated. Presumably added so dataset/weight downloads
# succeed behind a broken certificate chain — prefer fixing the cert store.
ssl._create_default_https_context = ssl._create_unverified_context

from CIFAR100.CIFAR100_Dataloader import CIFAR100DataLoader
from FER.FER2013_Dataloader import FER2013DataLoader
from AffectNet.AffectNet_Dataloader import AffectNetDataLoader

import argparse

import wandb
# Start a Weights & Biases run; everything passed to wandb.log() in this
# script is recorded under the "BigNoise" project.
wandb.init(project="BigNoise")


# Exponential-moving-average helper used to smooth model weights between epochs.
def running_estimation(prev_param, new_param, prev_weight=0.1):
    """Blend a running estimate with a new value.

    Computes ``prev_weight * prev_param + (1 - prev_weight) * new_param``.

    Args:
        prev_param: Previous running estimate (a number or a tensor).
        new_param: Newly observed value, broadcast-compatible with ``prev_param``.
        prev_weight: Weight kept on the previous estimate. The default 0.1
            gives the fixed 0.1 / 0.9 blend this script has always used
            (``1 - 0.1`` is exactly ``0.9`` in floating point).

    Returns:
        The blended value as a new object; the inputs are not modified.
    """
    return prev_weight * prev_param + (1 - prev_weight) * new_param

def _str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` must not be used for flags: argparse hands over the raw
    string and ``bool("False")`` is ``True``, so ``--flag False`` would turn
    the flag ON.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line arguments.
parser = argparse.ArgumentParser()
# Only these datasets / models are wired up further down; anything else would
# otherwise fail later with a NameError far from the cause.
parser.add_argument('--dataset', type=str, required=True,
                    choices=["CIFAR100", "FER2013", "AffectNet"])
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--epochs', default=50, type=int)
# nargs='?' + const=True: "--flag", "--flag True" and "--flag False" all work.
parser.add_argument('--pretrained_true', default=False, type=_str2bool,
                    nargs='?', const=True)
parser.add_argument('--pretrained_path', default=None, type=str,
                    help="state_dict file loaded when --pretrained_true is set")
parser.add_argument('--weight_running_estimation', default=False, type=_str2bool,
                    nargs='?', const=True)
parser.add_argument('--model', type=str, required=True,
                    choices=["resnet18", "resnet34"])
parser.add_argument('--today', type=float)
parser.add_argument('--type', type=str)
parser.add_argument('--checkpointFolderPath', type=str)
opt = parser.parse_args()


# Build train/test loaders and record the number of target classes for the
# selected dataset. For FER2013 the private test split is used for evaluation;
# the public test split returned by FER2013DataLoader is not used here.
if opt.dataset == "CIFAR100":
    train_dataloader, test_dataloader = CIFAR100DataLoader(opt.batch_size)
    num_classes = 100
elif opt.dataset == "FER2013":
    trainloader, PublicTestloader, PrivateTestloader = FER2013DataLoader(opt.batch_size)
    train_dataloader = trainloader
    test_dataloader = PrivateTestloader
    num_classes = 7
elif opt.dataset == "AffectNet":
    train_dataloader, test_dataloader = AffectNetDataLoader(opt.batch_size)
    num_classes = 8
else:
    # Fail here with a clear message instead of a NameError on num_classes /
    # train_dataloader further down the script.
    raise ValueError(
        f"Unsupported --dataset {opt.dataset!r}; "
        "expected 'CIFAR100', 'FER2013' or 'AffectNet'"
    )



# Pick the compute backend: CUDA if present, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


# Build the requested network on the selected device.
if opt.model == "resnet18":
    model = resnet18(num_classes).to(device)
elif opt.model == "resnet34":
    model = resnet34(num_classes).to(device)
else:
    # Fail early rather than with a NameError on `model` below.
    # NOTE(review): resnet50/101/152 are imported at the top of the file but not
    # wired up; TODO confirm they take num_classes positionally before adding them.
    raise ValueError(
        f"Unsupported --model {opt.model!r}; expected 'resnet18' or 'resnet34'"
    )

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=0.0001)
train_F1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes).to(device)
test_F1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes).to(device)

if opt.weight_running_estimation:
    print("weight_running_estimation")
    # Running weight estimate, updated in the training loop. NOTE: state_dict()
    # returns tensors that share storage with the live parameters/buffers, so
    # until an entry is replaced by a blended copy it tracks the current weights.
    prev_param_dict = model.state_dict()

if opt.pretrained_true:
    print("pretrained")
    # torch.load() needs a file to read; it comes from --pretrained_path.
    # getattr keeps this working with a parser that does not define that option.
    pretrained_path = getattr(opt, "pretrained_path", None)
    if not pretrained_path:
        raise ValueError(
            "--pretrained_true requires --pretrained_path pointing at a saved state_dict"
        )
    # map_location lets a checkpoint written on another device (e.g. CUDA) load here.
    model.load_state_dict(torch.load(pretrained_path, map_location=device))


def train(dataloader, model, loss_fn, optimizer, metric):
    """Run one training epoch and report whether the batch loss trended upward.

    Args:
        dataloader: Iterable of ``(X, y)`` batches with a ``len()``.
        model: Network to train; switched to training mode here.
        loss_fn: Criterion called as ``loss_fn(pred, y)``.
        optimizer: Optimizer over ``model``'s parameters.
        metric: Callable ``metric(pred, y)`` returning a per-batch score
            tensor (the F1 metric in this script).

    Returns:
        1 if the Pearson correlation between batch index and batch loss is
        positive (loss rose over the epoch), otherwise 0. This includes the
        cases where the correlation is undefined: fewer than two batches, or
        a correlation of nan.
    """
    # test() puts the model in eval mode; without switching back, every epoch
    # after the first would train with eval-mode BatchNorm/dropout behaviour.
    model.train()
    num_batches = len(dataloader)
    F1 = 0
    losses = []

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Forward pass and prediction error.
        pred = model(X)
        loss = loss_fn(pred, y)
        loss_value = loss.item()
        losses.append(loss_value)

        F1 += metric(pred.detach(), y).item()
        wandb.log({"train loss": loss_value})

        # Backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if num_batches:
        wandb.log({"train f1": F1 / num_batches})

    # corrcoef needs at least two points; with fewer it warns and returns nan.
    if len(losses) < 2:
        return 0
    corr = corrcoef(range(len(losses)), losses)[0][1]
    wandb.log({"corr": corr})
    return 1 if corr > 0 else 0

def test(dataloader, model, loss_fn, metric):
    """Evaluate ``model`` on ``dataloader`` and return the mean per-batch metric.

    Logs each batch loss as ``"test loss"`` and the mean metric as
    ``"test f1"``. The model is left in eval mode.

    Args:
        dataloader: Iterable of ``(X, y)`` batches with a ``len()``.
        model: Network to evaluate.
        loss_fn: Criterion called as ``loss_fn(pred, y)``.
        metric: Callable ``metric(pred, y)`` returning a per-batch score tensor.

    Returns:
        The mean of the per-batch metric values (the same number logged as
        ``"test f1"``), or 0.0 when the dataloader is empty.
    """
    num_batches = len(dataloader)
    F1 = 0
    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)

            loss = loss_fn(pred, y).item()
            F1 += metric(pred, y).item()
            wandb.log({"test loss": loss})

    if num_batches == 0:
        return 0.0
    # Return the true per-batch mean. Averaging running prefix sums, as a list
    # of cumulative F1 values would, over-weights early batches and can exceed 1.
    mean_F1 = F1 / num_batches
    wandb.log({"test f1": mean_F1})
    return mean_F1

# Epoch loop: train, evaluate, optionally smooth the weights, and checkpoint
# whenever the test F1 beats the best value seen so far.
MaxF1 = 0
for t in range(opt.epochs):

    print(f"Epoch {t+1}\n-------------------------------")
    switch = train(train_dataloader, model, loss_fn, optimizer, train_F1)
    F1 = test(test_dataloader, model, loss_fn, test_F1)

    # train() returns 1 when the epoch's batch losses trended upward.
    if switch:
        # Model-specific hook from model.py; its effect is not visible here.
        model.envSet()

        if opt.weight_running_estimation:
            # Blend the current parameters into the running estimate. no_grad
            # keeps autograd from recording the blend. Otherwise each stored
            # tensor would require grad and hang onto the previous estimate's
            # graph, so the chain would grow every time this branch runs.
            with torch.no_grad():
                for name, new_param in model.named_parameters():
                    prev_param_dict[name] = running_estimation(
                        prev_param=prev_param_dict[name], new_param=new_param
                    )
            # Copy the blended values (plus any untouched buffer entries) back.
            model.load_state_dict(prev_param_dict)

    if F1 > MaxF1:
        MaxF1 = F1
        torch.save(
            model.state_dict(),
            f"{opt.checkpointFolderPath}/{opt.type}+{opt.batch_size}+{opt.lr}+{opt.model}+{opt.dataset}+{opt.today}+{opt.type}.pth",
        )


