import torch
import sys
import os
import pandas as pd
import argparse
import numpy as np
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import models
import time
from .data import load_data
import datetime
import random

def random_seed(seed):
    """Seed every RNG used by this module for reproducible runs.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices).

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # The interpreter reads PYTHONHASHSEED only at startup, so this does not
    # change str/bytes hashing in the *current* process; it does propagate to
    # Python subprocesses started afterwards (e.g. spawned workers).
    # (The previous key "PYTHONSEED" is not a variable Python recognises.)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: score/logit tensor whose class dimension is dim 1
            (presumably shaped (batch, num_classes) -- confirm with callers).
        target: integer class labels, one per sample.
        topk: k values to evaluate; the largest must not exceed the number
            of classes.

    Returns:
        A list with one 1-element float tensor per k, equal to
        100 * (number of samples whose label is among the top-k scores)
        / batch size.
    """
    with torch.no_grad():
        n = target.size(0)
        # Indices of the best max(topk) scores per sample, best first.
        _, ranked = output.topk(max(topk), 1, True, True)
        # Transpose so row r holds every sample's r-th best guess.
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        scale = 100.0 / n
        return [hits[:k].reshape(-1).float().sum(0, keepdim=True) * scale for k in topk]




def test(encoder, classifier, test_loader, device, data):
    """Evaluate ``classifier(encoder(x))`` on ``test_loader``.

    Both modules are switched to eval mode for the evaluation and restored
    to their previous train/eval mode afterwards, so calling this from a
    training loop does not leave the models in the wrong mode.

    Args:
        encoder: feature extractor; its output is flattened per sample.
        classifier: head applied to the flattened features.
        test_loader: iterable of ``(x_batch, y_batch)`` pairs.
        device: device the batches are moved to.
        data: dataset name. For ``'mnist'`` / ``'fashion-mnist'`` the channel
            dim is expanded to 3 (assumes single-channel (N, 1, H, W) input).

    Returns:
        ``(top1, top5)`` accuracy in percent as Python floats, averaged over
        all samples (not over batches, so a smaller last batch is weighted
        correctly). With fewer than 5 classes, "top5" is top-num_classes.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    # Remember the caller's modes; nn.Module.train(mode) restores them below.
    was_training = (encoder.training, classifier.training)
    encoder.eval()
    classifier.eval()

    top1_correct = 0.0
    top5_correct = 0.0
    n_samples = 0

    try:
        with torch.no_grad():
            for x_batch, y_batch in test_loader:
                if data in ("mnist", "fashion-mnist"):
                    x_batch = x_batch.expand(-1, 3, -1, -1)

                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)

                h = encoder(x_batch)
                logits = classifier(h.view(h.size(0), -1))

                batch_size = y_batch.size(0)
                # topk(5) would fail for heads with fewer than 5 outputs.
                k5 = min(5, logits.size(1))
                top1, top5 = accuracy(logits, y_batch, topk=(1, k5))
                # Weight each batch's percentage by its size; accumulated as
                # tensors to avoid a device sync per batch.
                top1_correct += top1[0] * batch_size
                top5_correct += top5[0] * batch_size
                n_samples += batch_size
    finally:
        encoder.train(was_training[0])
        classifier.train(was_training[1])

    if n_samples == 0:
        raise ValueError("test_loader yielded no samples")

    return float(top1_correct) / n_samples, float(top5_correct) / n_samples







def _build_optimizer(encoder, head, args, mode):
    """Set encoder trainability for ``mode`` and return Adam over the trainable params."""
    if mode == "FTLL":
        print(">> Training mode: FTLL (freeze encoder, only train head)")
        for p in encoder.parameters():
            p.requires_grad = False
        params = list(head.parameters())
    elif mode == "FTAL":
        print(">> Training mode: FTAL (fine-tune encoder + head)")
        for p in encoder.parameters():
            p.requires_grad = True
        params = list(encoder.parameters()) + list(head.parameters())
    else:
        raise ValueError(f"unknown mode {mode!r}; expected 'FTLL' or 'FTAL'")
    return torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)


def _train_one_epoch(encoder, head, loader, optimizer, criterion, device, data, train_encoder):
    """Run one training epoch; return (mean loss, top-1 accuracy %) over all samples."""
    # Set modes every epoch: evaluation in between may have changed them.
    encoder.train(train_encoder)
    head.train()

    loss_sum = 0.0
    top1_sum = 0.0
    n_samples = 0

    for x_batch, y_batch in loader:
        if data in ("mnist", "fashion-mnist"):
            x_batch = x_batch.expand(-1, 3, -1, -1)

        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()

        # A frozen encoder needs no autograd graph; the head always does.
        with torch.set_grad_enabled(train_encoder):
            h = encoder(x_batch)
        logits = head(h.view(h.size(0), -1))
        loss = criterion(logits, y_batch)

        loss.backward()
        optimizer.step()

        batch_size = y_batch.size(0)
        (top1,) = accuracy(logits, y_batch, topk=(1,))
        # Sample-weighted sums so a short final batch is not over-weighted.
        loss_sum += loss.detach() * batch_size
        top1_sum += top1 * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_loader yielded no samples")

    return float(loss_sum) / n_samples, float(top1_sum) / n_samples


def classify(data, encoder, head, args, mode="FTLL"):
    """Train a classification head on top of ``encoder`` and save the result.

    mode: "FTLL" = freeze encoder, only train head
          "FTAL" = fine-tune encoder + head

    Args:
        data: dataset name passed to ``load_data``; also used in the output
            folder name and for the MNIST-style channel expansion.
        encoder: feature extractor whose output is flattened per sample.
        head: classifier module trained on the flattened features.
        args: namespace providing ``arch``, ``batch_size``, ``device``,
            ``lr``, ``weight_decay`` and ``epochs``.
        mode: "FTLL" or "FTAL" (see above).

    Side effects:
        Creates ``f"{data}_{args.arch}_{mode}"``, writes one line per epoch to
        ``records.txt`` in it, and saves ``encoder_state_dict.pth`` and
        ``head_state_dict.pth`` there after training.

    Returns:
        The trained ``head`` (the same object that was passed in).

    Raises:
        ValueError: for an unknown ``mode`` or an empty data loader.
    """
    # Validate before touching the filesystem or loading data.
    if mode not in ("FTLL", "FTAL"):
        raise ValueError(f"unknown mode {mode!r}; expected 'FTLL' or 'FTAL'")

    # 1. Output folder
    save_dir = f"{data}_{args.arch}_{mode}"
    os.makedirs(save_dir, exist_ok=True)

    # 2. Data and devices
    train_loader, test_loader = load_data(data, args.batch_size)
    encoder.to(args.device)
    head.to(args.device)

    # 3. Optimisation setup (optimizer created after .to() so it sees the
    #    parameters on their final device)
    optimizer = _build_optimizer(encoder, head, args, mode)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)
    criterion = torch.nn.CrossEntropyLoss().to(args.device)
    train_encoder = mode == "FTAL"

    # 4. Train / evaluate; `with` guarantees the log is closed on errors.
    with open(os.path.join(save_dir, "records.txt"), "w") as metrics_file:
        for epoch in range(args.epochs):
            start = time.time()

            train_loss, top1_train = _train_one_epoch(
                encoder, head, train_loader, optimizer, criterion,
                args.device, data, train_encoder,
            )
            top1_test, _ = test(encoder, head, test_loader, args.device, data)
            scheduler.step()

            duration = time.time() - start
            timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            # "loss" is the sample-weighted mean over the epoch.
            msg = (
                f"[{timestamp}] epoch {epoch}: "
                f"train_top1={top1_train:.4f}, "
                f"test_top1={top1_test:.4f}, "
                f"loss={train_loss:.4f}, "
                f"time={duration:.2f}s"
            )
            print(msg)
            metrics_file.write(msg + "\n")
            metrics_file.flush()  # keep the log current during long runs

    # 5. Save
    torch.save(encoder.state_dict(), os.path.join(save_dir, "encoder_state_dict.pth"))
    torch.save(head.state_dict(), os.path.join(save_dir, "head_state_dict.pth"))

    print(f"\nModels saved in {save_dir}/")
    print("Training complete.\n")

    return head



